diff --git a/lib/cloud/clients.go b/lib/cloud/clients.go index ac0deddb52b67..d11a77caa22ab 100644 --- a/lib/cloud/clients.go +++ b/lib/cloud/clients.go @@ -610,10 +610,10 @@ func (c *cloudClients) getAWSSessionForRegion(region string) (*awssession.Sessio // getAWSSessionForRole returns AWS session for the specified region and role. func (c *cloudClients) getAWSSessionForRole(ctx context.Context, region string, options awsAssumeRoleOpts) (*awssession.Session, error) { - assumeRoler := sts.New(options.baseSession) cacheKey := fmt.Sprintf("Region[%s]:RoleARN[%s]:ExternalID[%s]", region, options.assumeRoleARN, options.assumeRoleExternalID) return utils.FnCacheGet(ctx, c.awsSessionsCache, cacheKey, func(ctx context.Context) (*awssession.Session, error) { - return newSessionWithRole(ctx, assumeRoler, region, options.assumeRoleARN, options.assumeRoleExternalID) + stsClient := sts.New(options.baseSession) + return newSessionWithRole(ctx, stsClient, region, options.assumeRoleARN, options.assumeRoleExternalID) }) } diff --git a/lib/integrations/awsoidc/listdatabases_test.go b/lib/integrations/awsoidc/listdatabases_test.go index 696538e5a5cd4..abec450741af8 100644 --- a/lib/integrations/awsoidc/listdatabases_test.go +++ b/lib/integrations/awsoidc/listdatabases_test.go @@ -24,6 +24,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/rds" rdsTypes "github.com/aws/aws-sdk-go-v2/service/rds/types" + "github.com/google/go-cmp/cmp" "github.com/gravitational/trace" "github.com/stretchr/testify/require" @@ -210,7 +211,7 @@ func TestListDatabases(t *testing.T) { }, ) require.NoError(t, err) - require.Equal(t, expectedDB, ldr.Databases[0]) + require.Empty(t, cmp.Diff(expectedDB, ldr.Databases[0])) }, errCheck: noErrorFunc, }, @@ -275,7 +276,7 @@ func TestListDatabases(t *testing.T) { }, ) require.NoError(t, err) - require.Equal(t, expectedDB, ldr.Databases[0]) + require.Empty(t, cmp.Diff(expectedDB, ldr.Databases[0])) }, errCheck: noErrorFunc, }, @@ -328,7 +329,7 @@ func 
TestListDatabases(t *testing.T) { }, ) require.NoError(t, err) - require.Equal(t, expectedDB, ldr.Databases[0]) + require.Empty(t, cmp.Diff(expectedDB, ldr.Databases[0])) }, errCheck: noErrorFunc, }, diff --git a/lib/srv/discovery/discovery.go b/lib/srv/discovery/discovery.go index f96160dd0821c..44ad72e16fb71 100644 --- a/lib/srv/discovery/discovery.go +++ b/lib/srv/discovery/discovery.go @@ -739,10 +739,8 @@ func (s *Server) Start() error { if s.gcpWatcher != nil { go s.handleGCPDiscovery() } - if len(s.kubeFetchers) > 0 { - if err := s.startKubeWatchers(); err != nil { - return trace.Wrap(err) - } + if err := s.startKubeWatchers(); err != nil { + return trace.Wrap(err) } if err := s.startDatabaseWatchers(); err != nil { return trace.Wrap(err) diff --git a/lib/srv/discovery/discovery_test.go b/lib/srv/discovery/discovery_test.go index e3e66116c3eb2..2e093e69330bb 100644 --- a/lib/srv/discovery/discovery_test.go +++ b/lib/srv/discovery/discovery_test.go @@ -506,7 +506,7 @@ func TestDiscoveryKube(t *testing.T) { mustConvertEKSToKubeCluster(t, eksMockClusters[0], mainDiscoveryGroup), mustConvertEKSToKubeCluster(t, eksMockClusters[1], mainDiscoveryGroup), }, - clustersNotUpdated: []string{"eks-cluster1"}, + clustersNotUpdated: []string{mustConvertEKSToKubeCluster(t, eksMockClusters[0], mainDiscoveryGroup).GetName()}, }, { name: "1 cluster in auth that belongs the same discovery group but has unmatched labels + import 2 prod clusters from EKS", @@ -593,7 +593,7 @@ func TestDiscoveryKube(t *testing.T) { mustConvertAKSToKubeCluster(t, aksMockClusters["group1"][0], mainDiscoveryGroup), mustConvertAKSToKubeCluster(t, aksMockClusters["group1"][1], mainDiscoveryGroup), }, - clustersNotUpdated: []string{"aks-cluster1"}, + clustersNotUpdated: []string{mustConvertAKSToKubeCluster(t, aksMockClusters["group1"][0], mainDiscoveryGroup).GetName()}, }, { name: "no clusters in auth server, import 2 prod clusters from GKE", @@ -1061,7 +1061,7 @@ func TestDiscoveryDatabase(t 
*testing.T) { name: "update existing database", existingDatabases: []types.Database{ mustNewDatabase(t, types.Metadata{ - Name: "aws-redshift", + Name: awsRedshiftDB.GetName(), Description: "should be updated", Labels: map[string]string{types.OriginLabel: types.OriginCloud, types.TeleportInternalDiscoveryGroupName: mainDiscoveryGroup}, }, types.DatabaseSpecV3{ @@ -1085,7 +1085,7 @@ func TestDiscoveryDatabase(t *testing.T) { name: "update existing database with assumed role", existingDatabases: []types.Database{ mustNewDatabase(t, types.Metadata{ - Name: "aws-rds", + Name: awsRDSDBWithRole.GetName(), Description: "should be updated", Labels: map[string]string{types.OriginLabel: types.OriginCloud, types.TeleportInternalDiscoveryGroupName: mainDiscoveryGroup}, }, types.DatabaseSpecV3{ @@ -1105,7 +1105,7 @@ func TestDiscoveryDatabase(t *testing.T) { name: "delete existing database", existingDatabases: []types.Database{ mustNewDatabase(t, types.Metadata{ - Name: "aws-redshift", + Name: awsRedshiftDB.GetName(), Description: "should not be deleted", Labels: map[string]string{types.OriginLabel: types.OriginCloud}, }, types.DatabaseSpecV3{ @@ -1120,7 +1120,7 @@ func TestDiscoveryDatabase(t *testing.T) { }}, expectDatabases: []types.Database{ mustNewDatabase(t, types.Metadata{ - Name: "aws-redshift", + Name: awsRedshiftDB.GetName(), Description: "should not be deleted", Labels: map[string]string{types.OriginLabel: types.OriginCloud}, }, types.DatabaseSpecV3{ diff --git a/lib/srv/discovery/fetchers/aks.go b/lib/srv/discovery/fetchers/aks.go index a9eb2af088c4c..91bc688dbef3d 100644 --- a/lib/srv/discovery/fetchers/aks.go +++ b/lib/srv/discovery/fetchers/aks.go @@ -103,9 +103,16 @@ func (a *aksFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, error) kubeClusters = append(kubeClusters, kubeCluster) } + + a.rewriteKubeClusters(kubeClusters) return kubeClusters.AsResources(), nil } +// rewriteKubeClusters rewrites the discovered kube clusters. 
+func (a *aksFetcher) rewriteKubeClusters(clusters types.KubeClusters) { + // no-op +} + func (a *aksFetcher) getAKSClusters(ctx context.Context) ([]*azure.AKSCluster, error) { var ( clusters []*azure.AKSCluster diff --git a/lib/srv/discovery/fetchers/db/aws.go b/lib/srv/discovery/fetchers/db/aws.go index f3e900358c258..0493462dd2623 100644 --- a/lib/srv/discovery/fetchers/db/aws.go +++ b/lib/srv/discovery/fetchers/db/aws.go @@ -17,11 +17,121 @@ limitations under the License. package db import ( + "context" + "fmt" + + "github.com/gravitational/trace" + "github.com/sirupsen/logrus" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/cloud" + "github.com/gravitational/teleport/lib/srv/discovery/common" ) +// awsFetcherPlugin defines an interface that provides database type specific +// functions for use by the common AWS database fetcher. +type awsFetcherPlugin interface { + // GetDatabases fetches databases from AWS API and converts the results to + // Teleport types.Databases. + GetDatabases(context.Context, *awsFetcherConfig) (types.Databases, error) + // ComponentShortName provides the plugin's short component name for + // logging purposes. + ComponentShortName() string +} + +// awsFetcherConfig is the AWS database fetcher configuration. +type awsFetcherConfig struct { + // AWSClients are the AWS API clients. + AWSClients cloud.AWSClients + // Type is the type of DB matcher, for example "rds", "redshift", etc. + Type string + // AssumeRole provides a role ARN and ExternalID to assume an AWS role + // when fetching databases. + AssumeRole types.AssumeRole + // Labels is a selector to match cloud database tags. + Labels types.Labels + // Region is the AWS region selector to match cloud databases. + Region string + // Log is a field logger to provide structured logging for each matcher, + // based on its config settings by default. + Log logrus.FieldLogger +} + +// CheckAndSetDefaults validates the config and sets defaults. 
+func (cfg *awsFetcherConfig) CheckAndSetDefaults(component string) error { + if cfg.AWSClients == nil { + return trace.BadParameter("missing parameter AWSClients") + } + if cfg.Type == "" { + return trace.BadParameter("missing parameter Type") + } + if len(cfg.Labels) == 0 { + return trace.BadParameter("missing parameter Labels") + } + if cfg.Region == "" { + return trace.BadParameter("missing parameter Region") + } + if cfg.Log == nil { + cfg.Log = logrus.WithFields(logrus.Fields{ + trace.Component: "watch:" + component, + "labels": cfg.Labels, + "region": cfg.Region, + "role": cfg.AssumeRole, + }) + } + return nil +} + +// newAWSFetcher returns an AWS database fetcher for the provided selectors +// and AWS database-type specific fetcher plugin. +func newAWSFetcher(cfg awsFetcherConfig, plugin awsFetcherPlugin) (*awsFetcher, error) { + if err := cfg.CheckAndSetDefaults(plugin.ComponentShortName()); err != nil { + return nil, trace.Wrap(err) + } + return &awsFetcher{cfg: cfg, plugin: plugin}, nil +} + // awsFetcher is the common base for AWS database fetchers. type awsFetcher struct { + // cfg is the awsFetcher configuration. + cfg awsFetcherConfig + // plugin does AWS database type specific API calls to fetch databases. + plugin awsFetcherPlugin +} + +// awsFetcher implements common.Fetcher. +var _ common.Fetcher = (*awsFetcher)(nil) + +// Get returns AWS databases matching the fetcher's selectors. +func (f *awsFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, error) { + databases, err := f.getDatabases(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + f.rewriteDatabases(databases) + return databases.AsResources(), nil +} + +func (f *awsFetcher) getDatabases(ctx context.Context) (types.Databases, error) { + databases, err := f.plugin.GetDatabases(ctx, &f.cfg) + if err != nil { + return nil, trace.Wrap(err) + } + return filterDatabasesByLabels(databases, f.cfg.Labels, f.cfg.Log), nil +} + +// rewriteDatabases rewrites the discovered databases. 
+func (f *awsFetcher) rewriteDatabases(databases types.Databases) { + for _, db := range databases { + f.applyAssumeRole(db) + } +} + +// applyAssumeRole sets the database AWS AssumeRole metadata to match the +// fetcher's setting. +func (f *awsFetcher) applyAssumeRole(db types.Database) { + db.SetAWSAssumeRole(f.cfg.AssumeRole.RoleARN) + db.SetAWSExternalID(f.cfg.AssumeRole.ExternalID) } // Cloud returns the cloud the fetcher is operating. @@ -34,6 +144,12 @@ func (f *awsFetcher) ResourceType() string { return types.KindDatabase } +// String returns the fetcher's string description. +func (f *awsFetcher) String() string { + return fmt.Sprintf("awsFetcher(Type: %v, Region=%v, Labels=%v)", + f.cfg.Type, f.cfg.Region, f.cfg.Labels) +} + // maxAWSPages is the maximum number of pages to iterate over when fetching aws // databases. const maxAWSPages = 10 diff --git a/lib/srv/discovery/fetchers/db/aws_elasticache.go b/lib/srv/discovery/fetchers/db/aws_elasticache.go index d16fa61d05127..3851eeab6a8b9 100644 --- a/lib/srv/discovery/fetchers/db/aws_elasticache.go +++ b/lib/srv/discovery/fetchers/db/aws_elasticache.go @@ -17,75 +17,41 @@ package db import ( "context" - "fmt" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/elasticache" "github.com/aws/aws-sdk-go/service/elasticache/elasticacheiface" "github.com/gravitational/trace" - "github.com/sirupsen/logrus" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/cloud" libcloudaws "github.com/gravitational/teleport/lib/cloud/aws" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv/discovery/common" ) -// elastiCacheFetcherConfig is the ElastiCache databases fetcher configuration. -type elastiCacheFetcherConfig struct { - // Labels is a selector to match cloud databases. - Labels types.Labels - // ElastiCache is the ElastiCache API client. 
- ElastiCache elasticacheiface.ElastiCacheAPI - // Region is the AWS region to query databases in. - Region string - // AssumeRole is the AWS IAM role to assume before discovering databases. - AssumeRole types.AssumeRole +// newElastiCacheFetcher returns a new AWS fetcher for ElastiCache databases. +func newElastiCacheFetcher(cfg awsFetcherConfig) (common.Fetcher, error) { + return newAWSFetcher(cfg, &elastiCachePlugin{}) } -// CheckAndSetDefaults validates the config and sets defaults. -func (c *elastiCacheFetcherConfig) CheckAndSetDefaults() error { - if len(c.Labels) == 0 { - return trace.BadParameter("missing parameter Labels") - } - if c.ElastiCache == nil { - return trace.BadParameter("missing parameter ElastiCache") - } - if c.Region == "" { - return trace.BadParameter("missing parameter Region") - } - return nil -} +// elastiCachePlugin retrieves ElastiCache Redis databases. +type elastiCachePlugin struct{} -// elastiCacheFetcher retrieves ElastiCache Redis databases. -type elastiCacheFetcher struct { - awsFetcher - - cfg elastiCacheFetcherConfig - log logrus.FieldLogger +func (f *elastiCachePlugin) ComponentShortName() string { + return "elasticache" } -// newElastiCacheFetcher returns a new ElastiCache databases fetcher instance. -func newElastiCacheFetcher(config elastiCacheFetcherConfig) (common.Fetcher, error) { - if err := config.CheckAndSetDefaults(); err != nil { - return nil, trace.Wrap(err) - } - return &elastiCacheFetcher{ - cfg: config, - log: logrus.WithFields(logrus.Fields{ - trace.Component: "watch:elasticache", - "labels": config.Labels, - "region": config.Region, - "role": config.AssumeRole, - }), - }, nil -} - -// Get returns ElastiCache Redis databases matching the watcher's selectors. +// GetDatabases returns ElastiCache Redis databases matching the watcher's selectors. // // TODO(greedy52) support ElastiCache global datastore. 
-func (f *elastiCacheFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, error) { - clusters, err := getElastiCacheClusters(ctx, f.cfg.ElastiCache) +func (f *elastiCachePlugin) GetDatabases(ctx context.Context, cfg *awsFetcherConfig) (types.Databases, error) { + ecClient, err := cfg.AWSClients.GetAWSElastiCacheClient(ctx, cfg.Region, + cloud.WithAssumeRole(cfg.AssumeRole.RoleARN, cfg.AssumeRole.ExternalID)) + if err != nil { + return nil, trace.Wrap(err) + } + clusters, err := getElastiCacheClusters(ctx, ecClient) if err != nil { return nil, trace.Wrap(err) } @@ -93,12 +59,12 @@ func (f *elastiCacheFetcher) Get(ctx context.Context) (types.ResourcesWithLabels var eligibleClusters []*elasticache.ReplicationGroup for _, cluster := range clusters { if !services.IsElastiCacheClusterSupported(cluster) { - f.log.Debugf("ElastiCache cluster %q is not supported. Skipping.", aws.StringValue(cluster.ReplicationGroupId)) + cfg.Log.Debugf("ElastiCache cluster %q is not supported. Skipping.", aws.StringValue(cluster.ReplicationGroupId)) continue } if !services.IsElastiCacheClusterAvailable(cluster) { - f.log.Debugf("The current status of ElastiCache cluster %q is %q. Skipping.", + cfg.Log.Debugf("The current status of ElastiCache cluster %q is %q. Skipping.", aws.StringValue(cluster.ReplicationGroupId), aws.StringValue(cluster.Status)) continue @@ -108,25 +74,25 @@ func (f *elastiCacheFetcher) Get(ctx context.Context) (types.ResourcesWithLabels } if len(eligibleClusters) == 0 { - return types.ResourcesWithLabels{}, nil + return nil, nil } // Fetch more information to provide extra labels. Do not fail because some // of these labels are missing. 
- allNodes, err := getElastiCacheNodes(ctx, f.cfg.ElastiCache) + allNodes, err := getElastiCacheNodes(ctx, ecClient) if err != nil { if trace.IsAccessDenied(err) { - f.log.WithError(err).Debug("No permissions to describe nodes") + cfg.Log.WithError(err).Debug("No permissions to describe nodes") } else { - f.log.WithError(err).Info("Failed to describe nodes.") + cfg.Log.WithError(err).Info("Failed to describe nodes.") } } - allSubnetGroups, err := getElastiCacheSubnetGroups(ctx, f.cfg.ElastiCache) + allSubnetGroups, err := getElastiCacheSubnetGroups(ctx, ecClient) if err != nil { if trace.IsAccessDenied(err) { - f.log.WithError(err).Debug("No permissions to describe subnet groups") + cfg.Log.WithError(err).Debug("No permissions to describe subnet groups") } else { - f.log.WithError(err).Info("Failed to describe subnet groups.") + cfg.Log.WithError(err).Info("Failed to describe subnet groups.") } } @@ -135,33 +101,25 @@ func (f *elastiCacheFetcher) Get(ctx context.Context) (types.ResourcesWithLabels // Resource tags are not found in elasticache.ReplicationGroup but can // be on obtained by elasticache.ListTagsForResource (one call per // resource). 
- tags, err := getElastiCacheResourceTags(ctx, f.cfg.ElastiCache, cluster.ARN) + tags, err := getElastiCacheResourceTags(ctx, ecClient, cluster.ARN) if err != nil { if trace.IsAccessDenied(err) { - f.log.WithError(err).Debug("No permissions to list resource tags") + cfg.Log.WithError(err).Debug("No permissions to list resource tags") } else { - f.log.WithError(err).Infof("Failed to list resource tags for ElastiCache cluster %q.", aws.StringValue(cluster.ReplicationGroupId)) + cfg.Log.WithError(err).Infof("Failed to list resource tags for ElastiCache cluster %q.", aws.StringValue(cluster.ReplicationGroupId)) } } extraLabels := services.ExtraElastiCacheLabels(cluster, tags, allNodes, allSubnetGroups) if dbs, err := services.NewDatabasesFromElastiCacheReplicationGroup(cluster, extraLabels); err != nil { - f.log.Infof("Could not convert ElastiCache cluster %q to database resources: %v.", + cfg.Log.Infof("Could not convert ElastiCache cluster %q to database resources: %v.", aws.StringValue(cluster.ReplicationGroupId), err) } else { databases = append(databases, dbs...) } } - - applyAssumeRoleToDatabases(databases, f.cfg.AssumeRole) - return filterDatabasesByLabels(databases, f.cfg.Labels, f.log).AsResources(), nil -} - -// String returns the fetcher's string description. -func (f *elastiCacheFetcher) String() string { - return fmt.Sprintf("elastiCacheFetcher(Region=%v, Labels=%v)", - f.cfg.Region, f.cfg.Labels) + return databases, nil } // getElastiCacheClusters fetches all ElastiCache replication groups. 
diff --git a/lib/srv/discovery/fetchers/db/aws_memorydb.go b/lib/srv/discovery/fetchers/db/aws_memorydb.go index f6732aeb9403f..11efb3f51662f 100644 --- a/lib/srv/discovery/fetchers/db/aws_memorydb.go +++ b/lib/srv/discovery/fetchers/db/aws_memorydb.go @@ -17,73 +17,39 @@ package db import ( "context" - "fmt" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/memorydb" "github.com/aws/aws-sdk-go/service/memorydb/memorydbiface" "github.com/gravitational/trace" - "github.com/sirupsen/logrus" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/cloud" libcloudaws "github.com/gravitational/teleport/lib/cloud/aws" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv/discovery/common" ) -// memoryDBFetcherConfig is the MemoryDB databases fetcher configuration. -type memoryDBFetcherConfig struct { - // Labels is a selector to match cloud databases. - Labels types.Labels - // MemoryDB is the MemoryDB API client. - MemoryDB memorydbiface.MemoryDBAPI - // Region is the AWS region to query databases in. - Region string - // AssumeRole is the AWS IAM role to assume before discovering databases. - AssumeRole types.AssumeRole -} +// memoryDBPlugin retrieves MemoryDB Redis databases. +type memoryDBPlugin struct{} -// CheckAndSetDefaults validates the config and sets defaults. -func (c *memoryDBFetcherConfig) CheckAndSetDefaults() error { - if len(c.Labels) == 0 { - return trace.BadParameter("missing parameter Labels") - } - if c.MemoryDB == nil { - return trace.BadParameter("missing parameter MemoryDB") - } - if c.Region == "" { - return trace.BadParameter("missing parameter Region") - } - return nil +// newMemoryDBFetcher returns a new AWS fetcher for MemoryDB databases. +func newMemoryDBFetcher(cfg awsFetcherConfig) (common.Fetcher, error) { + return newAWSFetcher(cfg, &memoryDBPlugin{}) } -// memoryDBFetcher retrieves MemoryDB Redis databases. 
-type memoryDBFetcher struct { - awsFetcher - - cfg memoryDBFetcherConfig - log logrus.FieldLogger +func (f *memoryDBPlugin) ComponentShortName() string { + return "memorydb" } -// newMemoryDBFetcher returns a new MemoryDB databases fetcher instance. -func newMemoryDBFetcher(config memoryDBFetcherConfig) (common.Fetcher, error) { - if err := config.CheckAndSetDefaults(); err != nil { +// GetDatabases returns MemoryDB databases matching the watcher's selectors. +func (f *memoryDBPlugin) GetDatabases(ctx context.Context, cfg *awsFetcherConfig) (types.Databases, error) { + memDBClient, err := cfg.AWSClients.GetAWSMemoryDBClient(ctx, cfg.Region, + cloud.WithAssumeRole(cfg.AssumeRole.RoleARN, cfg.AssumeRole.ExternalID)) + if err != nil { return nil, trace.Wrap(err) } - return &memoryDBFetcher{ - cfg: config, - log: logrus.WithFields(logrus.Fields{ - trace.Component: "watch:memorydb", - "labels": config.Labels, - "region": config.Region, - "role": config.AssumeRole, - }), - }, nil -} - -// Get returns MemoryDB databases matching the watcher's selectors. -func (f *memoryDBFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, error) { - clusters, err := getMemoryDBClusters(ctx, f.cfg.MemoryDB) + clusters, err := getMemoryDBClusters(ctx, memDBClient) if err != nil { return nil, trace.Wrap(err) } @@ -91,12 +57,12 @@ func (f *memoryDBFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, e var eligibleClusters []*memorydb.Cluster for _, cluster := range clusters { if !services.IsMemoryDBClusterSupported(cluster) { - f.log.Debugf("MemoryDB cluster %q is not supported. Skipping.", aws.StringValue(cluster.Name)) + cfg.Log.Debugf("MemoryDB cluster %q is not supported. Skipping.", aws.StringValue(cluster.Name)) continue } if !services.IsMemoryDBClusterAvailable(cluster) { - f.log.Debugf("The current status of MemoryDB cluster %q is %q. Skipping.", + cfg.Log.Debugf("The current status of MemoryDB cluster %q is %q. 
Skipping.", aws.StringValue(cluster.Name), aws.StringValue(cluster.Status)) continue @@ -106,47 +72,40 @@ func (f *memoryDBFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, e } if len(eligibleClusters) == 0 { - return types.ResourcesWithLabels{}, nil + return nil, nil } // Fetch more information to provide extra labels. Do not fail because some // of these labels are missing. - allSubnetGroups, err := getMemoryDBSubnetGroups(ctx, f.cfg.MemoryDB) + allSubnetGroups, err := getMemoryDBSubnetGroups(ctx, memDBClient) if err != nil { if trace.IsAccessDenied(err) { - f.log.WithError(err).Debug("No permissions to describe subnet groups") + cfg.Log.WithError(err).Debug("No permissions to describe subnet groups") } else { - f.log.WithError(err).Info("Failed to describe subnet groups.") + cfg.Log.WithError(err).Info("Failed to describe subnet groups.") } } var databases types.Databases for _, cluster := range eligibleClusters { - tags, err := getMemoryDBResourceTags(ctx, f.cfg.MemoryDB, cluster.ARN) + tags, err := getMemoryDBResourceTags(ctx, memDBClient, cluster.ARN) if err != nil { if trace.IsAccessDenied(err) { - f.log.WithError(err).Debug("No permissions to list resource tags") + cfg.Log.WithError(err).Debug("No permissions to list resource tags") } else { - f.log.WithError(err).Infof("Failed to list resource tags for MemoryDB cluster %q.", aws.StringValue(cluster.Name)) + cfg.Log.WithError(err).Infof("Failed to list resource tags for MemoryDB cluster %q.", aws.StringValue(cluster.Name)) } } extraLabels := services.ExtraMemoryDBLabels(cluster, tags, allSubnetGroups) database, err := services.NewDatabaseFromMemoryDBCluster(cluster, extraLabels) if err != nil { - f.log.WithError(err).Infof("Could not convert memorydb cluster %q configuration endpoint to database resource.", aws.StringValue(cluster.Name)) + cfg.Log.WithError(err).Infof("Could not convert memorydb cluster %q configuration endpoint to database resource.", aws.StringValue(cluster.Name)) } else { 
databases = append(databases, database) } } - applyAssumeRoleToDatabases(databases, f.cfg.AssumeRole) - return filterDatabasesByLabels(databases, f.cfg.Labels, f.log).AsResources(), nil -} - -// String returns the fetcher's string description. -func (f *memoryDBFetcher) String() string { - return fmt.Sprintf("memorydbFetcher(Region=%v, Labels=%v)", - f.cfg.Region, f.cfg.Labels) + return databases, nil } // getMemoryDBClusters fetches all MemoryDB clusters. diff --git a/lib/srv/discovery/fetchers/db/aws_opensearch.go b/lib/srv/discovery/fetchers/db/aws_opensearch.go index f46065af6d511..f47720029d3f0 100644 --- a/lib/srv/discovery/fetchers/db/aws_opensearch.go +++ b/lib/srv/discovery/fetchers/db/aws_opensearch.go @@ -16,80 +16,46 @@ package db import ( "context" - "fmt" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/opensearchservice" "github.com/aws/aws-sdk-go/service/opensearchservice/opensearchserviceiface" "github.com/gravitational/trace" - "github.com/sirupsen/logrus" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/cloud" libcloudaws "github.com/gravitational/teleport/lib/cloud/aws" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv/discovery/common" ) -// openSearchFetcherConfig is the OpenSearch databases fetcher configuration. -type openSearchFetcherConfig struct { - // Labels is a selector to match cloud databases. - Labels types.Labels - // openSearch is the OpenSearch API client. - openSearch opensearchserviceiface.OpenSearchServiceAPI - // Region is the AWS region to query databases in. - Region string - // AssumeRole is the AWS IAM role to assume before discovering databases. - AssumeRole types.AssumeRole +// newOpenSearchFetcher returns a new AWS fetcher for OpenSearch databases. 
+func newOpenSearchFetcher(cfg awsFetcherConfig) (common.Fetcher, error) { + return newAWSFetcher(cfg, &openSearchPlugin{}) } -// CheckAndSetDefaults validates the config and sets defaults. -func (c *openSearchFetcherConfig) CheckAndSetDefaults() error { - if len(c.Labels) == 0 { - return trace.BadParameter("missing parameter Labels") - } - if c.openSearch == nil { - return trace.BadParameter("missing parameter openSearch") - } - if c.Region == "" { - return trace.BadParameter("missing parameter Region") - } - return nil -} - -// openSearchFetcher retrieves OpenSearch databases. -type openSearchFetcher struct { - awsFetcher +// openSearchPlugin retrieves OpenSearch databases. +type openSearchPlugin struct{} - cfg openSearchFetcherConfig - log logrus.FieldLogger +func (f *openSearchPlugin) ComponentShortName() string { + return "opensearch" } -// newOpenSearchFetcher returns a new OpenSearch databases fetcher instance. -func newOpenSearchFetcher(config openSearchFetcherConfig) (common.Fetcher, error) { - if err := config.CheckAndSetDefaults(); err != nil { +// GetDatabases returns OpenSearch databases. +func (f *openSearchPlugin) GetDatabases(ctx context.Context, cfg *awsFetcherConfig) (types.Databases, error) { + opensearchClient, err := cfg.AWSClients.GetAWSOpenSearchClient(ctx, + cfg.Region, cloud.WithAssumeRole(cfg.AssumeRole.RoleARN, cfg.AssumeRole.ExternalID)) + if err != nil { return nil, trace.Wrap(err) } - return &openSearchFetcher{ - cfg: config, - log: logrus.WithFields(logrus.Fields{ - trace.Component: "watch:opensearch", - "labels": config.Labels, - "region": config.Region, - "role": config.AssumeRole, - }), - }, nil -} - -// Get returns OpenSearch databases matching the watcher's selectors. 
-func (f *openSearchFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, error) { - domains, err := getOpenSearchDomains(ctx, f.cfg.openSearch) + domains, err := getOpenSearchDomains(ctx, opensearchClient) if err != nil { return nil, trace.Wrap(err) } var eligibleDomains []*opensearchservice.DomainStatus for _, domain := range domains { if !services.IsOpenSearchDomainAvailable(domain) { - f.log.Debugf("OpenSearch domain %q is unavailable. Skipping.", aws.StringValue(domain.DomainName)) + cfg.Log.Debugf("OpenSearch domain %q is unavailable. Skipping.", aws.StringValue(domain.DomainName)) continue } @@ -97,37 +63,29 @@ func (f *openSearchFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, } if len(eligibleDomains) == 0 { - return types.ResourcesWithLabels{}, nil + return nil, nil } var databases types.Databases for _, domain := range eligibleDomains { - tags, err := getOpenSearchResourceTags(ctx, f.cfg.openSearch, domain.ARN) + tags, err := getOpenSearchResourceTags(ctx, opensearchClient, domain.ARN) if err != nil { if trace.IsAccessDenied(err) { - f.log.WithError(err).Debug("No permissions to list resource tags") + cfg.Log.WithError(err).Debug("No permissions to list resource tags") } else { - f.log.WithError(err).Infof("Failed to list resource tags for OpenSearch domain %q.", aws.StringValue(domain.DomainName)) + cfg.Log.WithError(err).Infof("Failed to list resource tags for OpenSearch domain %q.", aws.StringValue(domain.DomainName)) } } dbs, err := services.NewDatabasesFromOpenSearchDomain(domain, tags) if err != nil { - f.log.WithError(err).Infof("Could not convert OpenSearch domain %q configuration to database resource.", aws.StringValue(domain.DomainName)) + cfg.Log.WithError(err).Infof("Could not convert OpenSearch domain %q configuration to database resource.", aws.StringValue(domain.DomainName)) } else { databases = append(databases, dbs...) 
} } - - applyAssumeRoleToDatabases(databases, f.cfg.AssumeRole) - return filterDatabasesByLabels(databases, f.cfg.Labels, f.log).AsResources(), nil -} - -// String returns the fetcher's string description. -func (f *openSearchFetcher) String() string { - return fmt.Sprintf("openSearchFetcher(Region=%v, Labels=%v)", - f.cfg.Region, f.cfg.Labels) + return databases, nil } // getOpenSearchDomains fetches all OpenSearch domains. diff --git a/lib/srv/discovery/fetchers/db/aws_opensearch_test.go b/lib/srv/discovery/fetchers/db/aws_opensearch_test.go index ea80ffbb67aa0..f3f2d6d517139 100644 --- a/lib/srv/discovery/fetchers/db/aws_opensearch_test.go +++ b/lib/srv/discovery/fetchers/db/aws_opensearch_test.go @@ -123,8 +123,7 @@ func makeOpenSearchDomain(t *testing.T, tagMap map[string][]*opensearchservice.T tagMap[aws.StringValue(domain.ARN)] = tags - database, err := services.NewDatabasesFromOpenSearchDomain(domain, tags) + databases, err := services.NewDatabasesFromOpenSearchDomain(domain, tags) require.NoError(t, err) - - return domain, database + return domain, databases } diff --git a/lib/srv/discovery/fetchers/db/aws_rds.go b/lib/srv/discovery/fetchers/db/aws_rds.go index 18bc4e51c552f..492ef2b2374cc 100644 --- a/lib/srv/discovery/fetchers/db/aws_rds.go +++ b/lib/srv/discovery/fetchers/db/aws_rds.go @@ -18,7 +18,6 @@ package db import ( "context" - "fmt" "strings" "github.com/aws/aws-sdk-go/aws" @@ -28,82 +27,39 @@ import ( "github.com/sirupsen/logrus" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/cloud" libcloudaws "github.com/gravitational/teleport/lib/cloud/aws" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv/discovery/common" ) -// rdsFetcherConfig is the RDS databases fetcher configuration. -type rdsFetcherConfig struct { - // Labels is a selector to match cloud databases. - Labels types.Labels - // RDS is the RDS API client. 
- RDS rdsiface.RDSAPI - // Region is the AWS region to query databases in. - Region string - // AssumeRole is the AWS IAM role to assume before discovering databases. - AssumeRole types.AssumeRole +// newRDSDBInstancesFetcher returns a new AWS fetcher for RDS databases. +func newRDSDBInstancesFetcher(cfg awsFetcherConfig) (common.Fetcher, error) { + return newAWSFetcher(cfg, &rdsDBInstancesPlugin{}) } -// CheckAndSetDefaults validates the config and sets defaults. -func (c *rdsFetcherConfig) CheckAndSetDefaults() error { - if len(c.Labels) == 0 { - return trace.BadParameter("missing parameter Labels") - } - if c.RDS == nil { - return trace.BadParameter("missing parameter RDS") - } - if c.Region == "" { - return trace.BadParameter("missing parameter Region") - } - return nil -} - -// rdsDBInstancesFetcher retrieves RDS DB instances. -type rdsDBInstancesFetcher struct { - awsFetcher +// rdsDBInstancesPlugin retrieves RDS DB instances. +type rdsDBInstancesPlugin struct{} - cfg rdsFetcherConfig - log logrus.FieldLogger +func (f *rdsDBInstancesPlugin) ComponentShortName() string { + return "rds" } -// newRDSDBInstancesFetcher returns a new RDS DB instances fetcher instance. -func newRDSDBInstancesFetcher(config rdsFetcherConfig) (common.Fetcher, error) { - if err := config.CheckAndSetDefaults(); err != nil { - return nil, trace.Wrap(err) - } - return &rdsDBInstancesFetcher{ - cfg: config, - log: logrus.WithFields(logrus.Fields{ - trace.Component: "watch:rds", - "labels": config.Labels, - "region": config.Region, - "role": config.AssumeRole, - }), - }, nil -} - -// Get returns RDS DB instances matching the watcher's selectors. -func (f *rdsDBInstancesFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, error) { - rdsDatabases, err := f.getRDSDatabases(ctx) +// GetDatabases returns a list of database resources representing RDS instances. 
+func (f *rdsDBInstancesPlugin) GetDatabases(ctx context.Context, cfg *awsFetcherConfig) (types.Databases, error) { + rdsClient, err := cfg.AWSClients.GetAWSRDSClient(ctx, cfg.Region, + cloud.WithAssumeRole(cfg.AssumeRole.RoleARN, cfg.AssumeRole.ExternalID)) if err != nil { return nil, trace.Wrap(err) } - - applyAssumeRoleToDatabases(rdsDatabases, f.cfg.AssumeRole) - return filterDatabasesByLabels(rdsDatabases, f.cfg.Labels, f.log).AsResources(), nil -} - -// getRDSDatabases returns a list of database resources representing RDS instances. -func (f *rdsDBInstancesFetcher) getRDSDatabases(ctx context.Context) (types.Databases, error) { - instances, err := getAllDBInstances(ctx, f.cfg.RDS, maxAWSPages, f.log) + instances, err := getAllDBInstances(ctx, rdsClient, maxAWSPages, cfg.Log) if err != nil { return nil, trace.Wrap(libcloudaws.ConvertRequestFailureError(err)) } databases := make(types.Databases, 0, len(instances)) for _, instance := range instances { if !services.IsRDSInstanceSupported(instance) { - f.log.Debugf("RDS instance %q (engine mode %v, engine version %v) doesn't support IAM authentication. Skipping.", + cfg.Log.Debugf("RDS instance %q (engine mode %v, engine version %v) doesn't support IAM authentication. Skipping.", aws.StringValue(instance.DBInstanceIdentifier), aws.StringValue(instance.Engine), aws.StringValue(instance.EngineVersion)) @@ -111,7 +67,7 @@ func (f *rdsDBInstancesFetcher) getRDSDatabases(ctx context.Context) (types.Data } if !services.IsRDSInstanceAvailable(instance.DBInstanceStatus, instance.DBInstanceIdentifier) { - f.log.Debugf("The current status of RDS instance %q is %q. Skipping.", + cfg.Log.Debugf("The current status of RDS instance %q is %q. 
Skipping.", aws.StringValue(instance.DBInstanceIdentifier), aws.StringValue(instance.DBInstanceStatus)) continue @@ -119,7 +75,7 @@ func (f *rdsDBInstancesFetcher) getRDSDatabases(ctx context.Context) (types.Data database, err := services.NewDatabaseFromRDSInstance(instance) if err != nil { - f.log.Warnf("Could not convert RDS instance %q to database resource: %v.", + cfg.Log.Warnf("Could not convert RDS instance %q to database resource: %v.", aws.StringValue(instance.DBInstanceIdentifier), err) } else { databases = append(databases, database) @@ -151,57 +107,34 @@ func getAllDBInstances(ctx context.Context, rdsClient rdsiface.RDSAPI, maxPages return instances, trace.Wrap(err) } -// String returns the fetcher's string description. -func (f *rdsDBInstancesFetcher) String() string { - return fmt.Sprintf("rdsDBInstancesFetcher(Region=%v, Labels=%v)", - f.cfg.Region, f.cfg.Labels) +// newRDSAuroraClustersFetcher returns a new AWS fetcher for RDS Aurora +// databases. +func newRDSAuroraClustersFetcher(cfg awsFetcherConfig) (common.Fetcher, error) { + return newAWSFetcher(cfg, &rdsAuroraClustersPlugin{}) } -// rdsAuroraClustersFetcher retrieves RDS Aurora clusters. -type rdsAuroraClustersFetcher struct { - awsFetcher - - cfg rdsFetcherConfig - log logrus.FieldLogger -} +// rdsAuroraClustersPlugin retrieves RDS Aurora clusters. +type rdsAuroraClustersPlugin struct{} -// newRDSAuroraClustersFetcher returns a new RDS Aurora fetcher instance. 
-func newRDSAuroraClustersFetcher(config rdsFetcherConfig) (common.Fetcher, error) { - if err := config.CheckAndSetDefaults(); err != nil { - return nil, trace.Wrap(err) - } - return &rdsAuroraClustersFetcher{ - cfg: config, - log: logrus.WithFields(logrus.Fields{ - trace.Component: "watch:aurora", - "labels": config.Labels, - "region": config.Region, - "role": config.AssumeRole, - }), - }, nil +func (f *rdsAuroraClustersPlugin) ComponentShortName() string { + return "aurora" } -// Get returns Aurora clusters matching the watcher's selectors. -func (f *rdsAuroraClustersFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, error) { - auroraDatabases, err := f.getAuroraDatabases(ctx) +// GetDatabases returns a list of database resources representing RDS clusters. +func (f *rdsAuroraClustersPlugin) GetDatabases(ctx context.Context, cfg *awsFetcherConfig) (types.Databases, error) { + rdsClient, err := cfg.AWSClients.GetAWSRDSClient(ctx, cfg.Region, + cloud.WithAssumeRole(cfg.AssumeRole.RoleARN, cfg.AssumeRole.ExternalID)) if err != nil { return nil, trace.Wrap(err) } - - applyAssumeRoleToDatabases(auroraDatabases, f.cfg.AssumeRole) - return filterDatabasesByLabels(auroraDatabases, f.cfg.Labels, f.log).AsResources(), nil -} - -// getAuroraDatabases returns a list of database resources representing RDS clusters. -func (f *rdsAuroraClustersFetcher) getAuroraDatabases(ctx context.Context) (types.Databases, error) { - clusters, err := getAllDBClusters(ctx, f.cfg.RDS, maxAWSPages, f.log) + clusters, err := getAllDBClusters(ctx, rdsClient, maxAWSPages, cfg.Log) if err != nil { return nil, trace.Wrap(libcloudaws.ConvertRequestFailureError(err)) } databases := make(types.Databases, 0, len(clusters)) for _, cluster := range clusters { if !services.IsRDSClusterSupported(cluster) { - f.log.Debugf("Aurora cluster %q (engine mode %v, engine version %v) doesn't support IAM authentication. 
Skipping.", + cfg.Log.Debugf("Aurora cluster %q (engine mode %v, engine version %v) doesn't support IAM authentication. Skipping.", aws.StringValue(cluster.DBClusterIdentifier), aws.StringValue(cluster.EngineMode), aws.StringValue(cluster.EngineVersion)) @@ -209,7 +142,7 @@ func (f *rdsAuroraClustersFetcher) getAuroraDatabases(ctx context.Context) (type } if !services.IsRDSClusterAvailable(cluster.Status, cluster.DBClusterIdentifier) { - f.log.Debugf("The current status of Aurora cluster %q is %q. Skipping.", + cfg.Log.Debugf("The current status of Aurora cluster %q is %q. Skipping.", aws.StringValue(cluster.DBClusterIdentifier), aws.StringValue(cluster.Status)) continue @@ -217,7 +150,7 @@ func (f *rdsAuroraClustersFetcher) getAuroraDatabases(ctx context.Context) (type dbs, err := services.NewDatabasesFromRDSCluster(cluster) if err != nil { - f.log.Warnf("Could not convert RDS cluster %q to database resources: %v.", + cfg.Log.Warnf("Could not convert RDS cluster %q to database resources: %v.", aws.StringValue(cluster.DBClusterIdentifier), err) } databases = append(databases, dbs...) @@ -248,12 +181,6 @@ func getAllDBClusters(ctx context.Context, rdsClient rdsiface.RDSAPI, maxPages i return clusters, trace.Wrap(err) } -// String returns the fetcher's string description. -func (f *rdsAuroraClustersFetcher) String() string { - return fmt.Sprintf("rdsAuroraClustersFetcher(Region=%v, Labels=%v)", - f.cfg.Region, f.cfg.Labels) -} - // rdsInstanceEngines returns engines to make sure DescribeDBInstances call returns // only databases with engines Teleport supports. 
func rdsInstanceEngines() []string { diff --git a/lib/srv/discovery/fetchers/db/aws_rds_proxy.go b/lib/srv/discovery/fetchers/db/aws_rds_proxy.go index 8286a1c6c02bb..719924ed2d1d9 100644 --- a/lib/srv/discovery/fetchers/db/aws_rds_proxy.go +++ b/lib/srv/discovery/fetchers/db/aws_rds_proxy.go @@ -15,82 +15,62 @@ package db import ( "context" - "fmt" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/rds" "github.com/aws/aws-sdk-go/service/rds/rdsiface" "github.com/gravitational/trace" - "github.com/sirupsen/logrus" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/cloud" libcloudaws "github.com/gravitational/teleport/lib/cloud/aws" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv/discovery/common" ) -// rdsDBProxyFetcher retrieves RDS Proxies and their custom endpoints. -type rdsDBProxyFetcher struct { - awsFetcher - - cfg rdsFetcherConfig - log logrus.FieldLogger +// newRDSDBProxyFetcher returns a new AWS fetcher for RDS Proxy databases. +func newRDSDBProxyFetcher(cfg awsFetcherConfig) (common.Fetcher, error) { + return newAWSFetcher(cfg, &rdsDBProxyPlugin{}) } -// newRDSDBProxyFetcher returns a new RDS Proxy fetcher instance. -func newRDSDBProxyFetcher(config rdsFetcherConfig) (common.Fetcher, error) { - if err := config.CheckAndSetDefaults(); err != nil { - return nil, trace.Wrap(err) - } - return &rdsDBProxyFetcher{ - cfg: config, - log: logrus.WithFields(logrus.Fields{ - trace.Component: "watch:rdsproxy", - "labels": config.Labels, - "region": config.Region, - "role": config.AssumeRole, - }), - }, nil +// rdsDBProxyPlugin retrieves RDS Proxies and their custom endpoints. +type rdsDBProxyPlugin struct{} + +func (f *rdsDBProxyPlugin) ComponentShortName() string { + return "rdsproxy" } -// Get returns RDS Proxies and proxy endpoints matching the watcher's -// selectors. 
-func (f *rdsDBProxyFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, error) { - databases, err := f.getRDSProxyDatabases(ctx) +// GetDatabases returns a list of database resources representing RDS +// Proxies and custom endpoints. +func (f *rdsDBProxyPlugin) GetDatabases(ctx context.Context, cfg *awsFetcherConfig) (types.Databases, error) { + rdsClient, err := cfg.AWSClients.GetAWSRDSClient(ctx, cfg.Region, + cloud.WithAssumeRole(cfg.AssumeRole.RoleARN, cfg.AssumeRole.ExternalID)) if err != nil { return nil, trace.Wrap(err) } - - applyAssumeRoleToDatabases(databases, f.cfg.AssumeRole) - return filterDatabasesByLabels(databases, f.cfg.Labels, f.log).AsResources(), nil -} - -// getRDSProxyDatabases returns a list of database resources representing RDS -// Proxies and custom endpoints. -func (f *rdsDBProxyFetcher) getRDSProxyDatabases(ctx context.Context) (types.Databases, error) { // Get a list of all RDS Proxies. Each RDS Proxy has one "default" // endpoint. - rdsProxies, err := getRDSProxies(ctx, f.cfg.RDS, maxAWSPages) + rdsProxies, err := getRDSProxies(ctx, rdsClient, maxAWSPages) if err != nil { return nil, trace.Wrap(err) } // Get all RDS Proxy custom endpoints sorted by the name of the RDS Proxy // that owns the custom endpoints. - customEndpointsByProxyName, err := getRDSProxyCustomEndpoints(ctx, f.cfg.RDS, maxAWSPages) + customEndpointsByProxyName, err := getRDSProxyCustomEndpoints(ctx, rdsClient, maxAWSPages) if err != nil { - f.log.Debugf("Failed to get RDS Proxy endpoints: %v.", err) + cfg.Log.Debugf("Failed to get RDS Proxy endpoints: %v.", err) } var databases types.Databases for _, dbProxy := range rdsProxies { if !aws.BoolValue(dbProxy.RequireTLS) { - f.log.Debugf("RDS Proxy %q doesn't support TLS. Skipping.", aws.StringValue(dbProxy.DBProxyName)) + cfg.Log.Debugf("RDS Proxy %q doesn't support TLS. 
Skipping.", aws.StringValue(dbProxy.DBProxyName)) continue } if !services.IsRDSProxyAvailable(dbProxy) { - f.log.Debugf("The current status of RDS Proxy %q is %q. Skipping.", + cfg.Log.Debugf("The current status of RDS Proxy %q is %q. Skipping.", aws.StringValue(dbProxy.DBProxyName), aws.StringValue(dbProxy.Status)) continue @@ -98,23 +78,23 @@ func (f *rdsDBProxyFetcher) getRDSProxyDatabases(ctx context.Context) (types.Dat // rds.DBProxy has no port information. An extra SDK call is made to // find the port from its targets. - port, err := getRDSProxyTargetPort(ctx, f.cfg.RDS, dbProxy.DBProxyName) + port, err := getRDSProxyTargetPort(ctx, rdsClient, dbProxy.DBProxyName) if err != nil { - f.log.Debugf("Failed to get port for RDS Proxy %v: %v.", aws.StringValue(dbProxy.DBProxyName), err) + cfg.Log.Debugf("Failed to get port for RDS Proxy %v: %v.", aws.StringValue(dbProxy.DBProxyName), err) continue } // rds.DBProxy has no tags information. An extra SDK call is made to // fetch the tags. If failed, keep going without the tags. - tags, err := listRDSResourceTags(ctx, f.cfg.RDS, dbProxy.DBProxyArn) + tags, err := listRDSResourceTags(ctx, rdsClient, dbProxy.DBProxyArn) if err != nil { - f.log.Debugf("Failed to get tags for RDS Proxy %v: %v.", aws.StringValue(dbProxy.DBProxyName), err) + cfg.Log.Debugf("Failed to get tags for RDS Proxy %v: %v.", aws.StringValue(dbProxy.DBProxyName), err) } // Add a database from RDS Proxy (default endpoint). database, err := services.NewDatabaseFromRDSProxy(dbProxy, port, tags) if err != nil { - f.log.Debugf("Could not convert RDS Proxy %q to database resource: %v.", + cfg.Log.Debugf("Could not convert RDS Proxy %q to database resource: %v.", aws.StringValue(dbProxy.DBProxyName), err) } else { databases = append(databases, database) @@ -123,7 +103,7 @@ func (f *rdsDBProxyFetcher) getRDSProxyDatabases(ctx context.Context) (types.Dat // Add custom endpoints. 
for _, customEndpoint := range customEndpointsByProxyName[aws.StringValue(dbProxy.DBProxyName)] { if !services.IsRDSProxyCustomEndpointAvailable(customEndpoint) { - f.log.Debugf("The current status of custom endpoint %q of RDS Proxy %q is %q. Skipping.", + cfg.Log.Debugf("The current status of custom endpoint %q of RDS Proxy %q is %q. Skipping.", aws.StringValue(customEndpoint.DBProxyEndpointName), aws.StringValue(customEndpoint.DBProxyName), aws.StringValue(customEndpoint.Status)) @@ -132,7 +112,7 @@ func (f *rdsDBProxyFetcher) getRDSProxyDatabases(ctx context.Context) (types.Dat database, err = services.NewDatabaseFromRDSProxyCustomEndpoint(dbProxy, customEndpoint, port, tags) if err != nil { - f.log.Debugf("Could not convert custom endpoint %q of RDS Proxy %q to database resource: %v.", + cfg.Log.Debugf("Could not convert custom endpoint %q of RDS Proxy %q to database resource: %v.", aws.StringValue(customEndpoint.DBProxyEndpointName), aws.StringValue(customEndpoint.DBProxyName), err) @@ -145,12 +125,6 @@ func (f *rdsDBProxyFetcher) getRDSProxyDatabases(ctx context.Context) (types.Dat return databases, nil } -// String returns the fetcher's string description. -func (f *rdsDBProxyFetcher) String() string { - return fmt.Sprintf("rdsDBProxyFetcher(Region=%v, Labels=%v)", - f.cfg.Region, f.cfg.Labels) -} - // getRDSProxies fetches all RDS Proxies using the provided client, up to the // specified max number of pages. 
func getRDSProxies(ctx context.Context, rdsClient rdsiface.RDSAPI, maxPages int) (rdsProxies []*rds.DBProxy, err error) { diff --git a/lib/srv/discovery/fetchers/db/aws_rds_test.go b/lib/srv/discovery/fetchers/db/aws_rds_test.go index 5bfbb37743735..61b5e4ce5fe47 100644 --- a/lib/srv/discovery/fetchers/db/aws_rds_test.go +++ b/lib/srv/discovery/fetchers/db/aws_rds_test.go @@ -245,7 +245,6 @@ func makeRDSClusterWithExtraEndpoints(t *testing.T, name, region string, labels customDatabases, err := services.NewDatabasesFromRDSClusterCustomEndpoints(cluster) require.NoError(t, err) databases = append(databases, customDatabases...) - return cluster, databases } diff --git a/lib/srv/discovery/fetchers/db/aws_redshift.go b/lib/srv/discovery/fetchers/db/aws_redshift.go index 771e183644dfb..e211b08121a3c 100644 --- a/lib/srv/discovery/fetchers/db/aws_redshift.go +++ b/lib/srv/discovery/fetchers/db/aws_redshift.go @@ -18,73 +18,35 @@ package db import ( "context" - "fmt" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/redshift" "github.com/aws/aws-sdk-go/service/redshift/redshiftiface" "github.com/gravitational/trace" - "github.com/sirupsen/logrus" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/cloud" libcloudaws "github.com/gravitational/teleport/lib/cloud/aws" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv/discovery/common" ) -// redshiftFetcherConfig is the Redshift databases fetcher configuration. -type redshiftFetcherConfig struct { - // Labels is a selector to match cloud databases. - Labels types.Labels - // Redshift is the Redshift API client. - Redshift redshiftiface.RedshiftAPI - // Region is the AWS region to query databases in. - Region string - // AssumeRole is the AWS IAM role to assume before discovering databases. - AssumeRole types.AssumeRole +// newRedshiftFetcher returns a new AWS fetcher for Redshift databases. 
+func newRedshiftFetcher(cfg awsFetcherConfig) (common.Fetcher, error) { + return newAWSFetcher(cfg, &redshiftPlugin{}) } -// CheckAndSetDefaults validates the config and sets defaults. -func (c *redshiftFetcherConfig) CheckAndSetDefaults() error { - if len(c.Labels) == 0 { - return trace.BadParameter("missing parameter Labels") - } - if c.Redshift == nil { - return trace.BadParameter("missing parameter Redshift") - } - if c.Region == "" { - return trace.BadParameter("missing parameter Region") - } - return nil -} - -// redshiftFetcher retrieves Redshift databases. -type redshiftFetcher struct { - awsFetcher +// redshiftPlugin retrieves Redshift databases. +type redshiftPlugin struct{} - cfg redshiftFetcherConfig - log logrus.FieldLogger -} - -// newRedshiftFetcher returns a new Redshift databases fetcher instance. -func newRedshiftFetcher(config redshiftFetcherConfig) (common.Fetcher, error) { - if err := config.CheckAndSetDefaults(); err != nil { +// GetDatabases returns Redshift databases matching the watcher's selectors. +func (f *redshiftPlugin) GetDatabases(ctx context.Context, cfg *awsFetcherConfig) (types.Databases, error) { + redshiftClient, err := cfg.AWSClients.GetAWSRedshiftClient(ctx, cfg.Region, + cloud.WithAssumeRole(cfg.AssumeRole.RoleARN, cfg.AssumeRole.ExternalID)) + if err != nil { return nil, trace.Wrap(err) } - return &redshiftFetcher{ - cfg: config, - log: logrus.WithFields(logrus.Fields{ - trace.Component: "watch:redshift", - "labels": config.Labels, - "region": config.Region, - "role": config.AssumeRole, - }), - }, nil -} - -// Get returns Redshift and Aurora databases matching the watcher's selectors. 
-func (f *redshiftFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, error) { - clusters, err := getRedshiftClusters(ctx, f.cfg.Redshift) + clusters, err := getRedshiftClusters(ctx, redshiftClient) if err != nil { return nil, trace.Wrap(err) } @@ -92,7 +54,7 @@ func (f *redshiftFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, e var databases types.Databases for _, cluster := range clusters { if !services.IsRedshiftClusterAvailable(cluster) { - f.log.Debugf("The current status of Redshift cluster %q is %q. Skipping.", + cfg.Log.Debugf("The current status of Redshift cluster %q is %q. Skipping.", aws.StringValue(cluster.ClusterIdentifier), aws.StringValue(cluster.ClusterStatus)) continue @@ -100,21 +62,18 @@ func (f *redshiftFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, e database, err := services.NewDatabaseFromRedshiftCluster(cluster) if err != nil { - f.log.Infof("Could not convert Redshift cluster %q to database resource: %v.", + cfg.Log.Infof("Could not convert Redshift cluster %q to database resource: %v.", aws.StringValue(cluster.ClusterIdentifier), err) continue } databases = append(databases, database) } - applyAssumeRoleToDatabases(databases, f.cfg.AssumeRole) - return filterDatabasesByLabels(databases, f.cfg.Labels, f.log).AsResources(), nil + return databases, nil } -// String returns the fetcher's string description. 
-func (f *redshiftFetcher) String() string { - return fmt.Sprintf("redshiftFetcher(Region=%v, Labels=%v)", - f.cfg.Region, f.cfg.Labels) +func (f *redshiftPlugin) ComponentShortName() string { + return "redshift" } // getRedshiftClusters fetches all Reshift clusters using the provided client, diff --git a/lib/srv/discovery/fetchers/db/aws_redshift_serverless.go b/lib/srv/discovery/fetchers/db/aws_redshift_serverless.go index 6dc3535330671..c8f719d593e43 100644 --- a/lib/srv/discovery/fetchers/db/aws_redshift_serverless.go +++ b/lib/srv/discovery/fetchers/db/aws_redshift_serverless.go @@ -18,7 +18,6 @@ package db import ( "context" - "fmt" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/redshiftserverless" @@ -27,120 +26,85 @@ import ( "github.com/sirupsen/logrus" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/cloud" libcloudaws "github.com/gravitational/teleport/lib/cloud/aws" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv/discovery/common" ) -// redshiftServerlessFetcherConfig is the Redshift Serverless databases fetcher -// configuration. -type redshiftServerlessFetcherConfig struct { - // Labels is a selector to match cloud databases. - Labels types.Labels - // Region is the AWS region to query databases in. - Region string - // Client is the Redshift Serverless API client. - Client redshiftserverlessiface.RedshiftServerlessAPI - // AssumeRole is the AWS IAM role to assume before discovering databases. - AssumeRole types.AssumeRole +// newRedshiftServerlessFetcher returns a new AWS fetcher for Redshift +// Serverless databases. +func newRedshiftServerlessFetcher(cfg awsFetcherConfig) (common.Fetcher, error) { + return newAWSFetcher(cfg, &redshiftServerlessPlugin{}) } -// CheckAndSetDefaults validates the config and sets defaults. 
-func (c *redshiftServerlessFetcherConfig) CheckAndSetDefaults() error { - if len(c.Labels) == 0 { - return trace.BadParameter("missing parameter Labels") - } - if c.Region == "" { - return trace.BadParameter("missing parameter Region") - } - if c.Client == nil { - return trace.BadParameter("missing parameter Client") - } - return nil -} - -type redshiftServerlessWorkgroupWithTags struct { +type workgroupWithTags struct { *redshiftserverless.Workgroup Tags []*redshiftserverless.Tag } -// redshiftServerlessFetcher retrieves Redshift Serverless databases. -type redshiftServerlessFetcher struct { - awsFetcher +// redshiftServerlessPlugin retrieves Redshift Serverless databases. +type redshiftServerlessPlugin struct{} - cfg redshiftServerlessFetcherConfig - log logrus.FieldLogger +func (f *redshiftServerlessPlugin) ComponentShortName() string { + // (r)ed(s)hift (s)erver(<)less + return "rss<" } -// newRedshiftServerlessFetcher returns a new Redshift Serverless databases -// fetcher instance. -func newRedshiftServerlessFetcher(config redshiftServerlessFetcherConfig) (common.Fetcher, error) { - if err := config.CheckAndSetDefaults(); err != nil { +// rssAPI is a type alias for brevity alone. +type rssAPI = redshiftserverlessiface.RedshiftServerlessAPI + +// GetDatabases returns Redshift Serverless databases matching the watcher's selectors. 
+func (f *redshiftServerlessPlugin) GetDatabases(ctx context.Context, cfg *awsFetcherConfig) (types.Databases, error) { + client, err := cfg.AWSClients.GetAWSRedshiftServerlessClient(ctx, cfg.Region, + cloud.WithAssumeRole(cfg.AssumeRole.RoleARN, cfg.AssumeRole.ExternalID)) + if err != nil { return nil, trace.Wrap(err) } - return &redshiftServerlessFetcher{ - cfg: config, - log: logrus.WithFields(logrus.Fields{ - trace.Component: "watch:rss<", // (r)ed(s)hift (s)erver(<)less - "labels": config.Labels, - "region": config.Region, - "role": config.AssumeRole, - }), - }, nil -} - -// Get returns Redshift Serverless databases matching the watcher's selectors. -func (f *redshiftServerlessFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, error) { - databases, workgroups, err := f.getDatabasesFromWorkgroups(ctx) + databases, workgroups, err := getDatabasesFromWorkgroups(ctx, client, cfg.Log) if err != nil { return nil, trace.Wrap(err) } if len(workgroups) > 0 { - vpcEndpointDatabases, err := f.getDatabasesFromVPCEndpoints(ctx, workgroups) + vpcEndpointDatabases, err := getDatabasesFromVPCEndpoints(ctx, workgroups, client, cfg.Log) if err != nil { if trace.IsAccessDenied(err) { - f.log.Debugf("No permission to get Redshift Serverless VPC endpoints: %v.", err) + cfg.Log.Debugf("No permission to get Redshift Serverless VPC endpoints: %v.", err) } else { - f.log.Warnf("Failed to get Redshift Serverless VPC endpoints: %v.", err) + cfg.Log.Warnf("Failed to get Redshift Serverless VPC endpoints: %v.", err) } } databases = append(databases, vpcEndpointDatabases...) } - applyAssumeRoleToDatabases(databases, f.cfg.AssumeRole) - return filterDatabasesByLabels(databases, f.cfg.Labels, f.log).AsResources(), nil -} - -// String returns the fetcher's string description. 
-func (f *redshiftServerlessFetcher) String() string { - return fmt.Sprintf("redshiftServerlessFetcher(Region=%v, Labels=%v)", f.cfg.Region, f.cfg.Labels) + return databases, nil } -func (f *redshiftServerlessFetcher) getDatabasesFromWorkgroups(ctx context.Context) (types.Databases, []*redshiftServerlessWorkgroupWithTags, error) { - workgroups, err := f.getWorkgroups(ctx) +func getDatabasesFromWorkgroups(ctx context.Context, client rssAPI, log logrus.FieldLogger) (types.Databases, []*workgroupWithTags, error) { + workgroups, err := getRSSWorkgroups(ctx, client) if err != nil { return nil, nil, trace.Wrap(err) } var databases types.Databases - var workgroupsWithTags []*redshiftServerlessWorkgroupWithTags + var workgroupsWithTags []*workgroupWithTags for _, workgroup := range workgroups { if !services.IsAWSResourceAvailable(workgroup, workgroup.Status) { - f.log.Debugf("The current status of Redshift Serverless workgroup %v is %v. Skipping.", aws.StringValue(workgroup.WorkgroupName), aws.StringValue(workgroup.Status)) + log.Debugf("The current status of Redshift Serverless workgroup %v is %v. 
Skipping.", aws.StringValue(workgroup.WorkgroupName), aws.StringValue(workgroup.Status)) continue } - tags := f.getResourceTags(ctx, workgroup.WorkgroupArn) + tags := getRSSResourceTags(ctx, workgroup.WorkgroupArn, client, log) database, err := services.NewDatabaseFromRedshiftServerlessWorkgroup(workgroup, tags) if err != nil { - f.log.WithError(err).Infof("Could not convert Redshift Serverless workgroup %q to database resource.", aws.StringValue(workgroup.WorkgroupName)) + log.WithError(err).Infof("Could not convert Redshift Serverless workgroup %q to database resource.", aws.StringValue(workgroup.WorkgroupName)) continue } databases = append(databases, database) - workgroupsWithTags = append(workgroupsWithTags, &redshiftServerlessWorkgroupWithTags{ + workgroupsWithTags = append(workgroupsWithTags, &workgroupWithTags{ Workgroup: workgroup, Tags: tags, }) @@ -148,8 +112,8 @@ func (f *redshiftServerlessFetcher) getDatabasesFromWorkgroups(ctx context.Conte return databases, workgroupsWithTags, nil } -func (f *redshiftServerlessFetcher) getDatabasesFromVPCEndpoints(ctx context.Context, workgroups []*redshiftServerlessWorkgroupWithTags) (types.Databases, error) { - endpoints, err := f.getVPCEndpoints(ctx) +func getDatabasesFromVPCEndpoints(ctx context.Context, workgroups []*workgroupWithTags, client rssAPI, log logrus.FieldLogger) (types.Databases, error) { + endpoints, err := getRSSVPCEndpoints(ctx, client) if err != nil { return nil, trace.Wrap(err) } @@ -158,12 +122,12 @@ func (f *redshiftServerlessFetcher) getDatabasesFromVPCEndpoints(ctx context.Con for _, endpoint := range endpoints { workgroup, found := findWorkgroupWithName(workgroups, aws.StringValue(endpoint.WorkgroupName)) if !found { - f.log.Debugf("Could not find matching workgroup for Redshift Serverless endpoint %v. Skipping.", aws.StringValue(endpoint.EndpointName)) + log.Debugf("Could not find matching workgroup for Redshift Serverless endpoint %v. 
Skipping.", aws.StringValue(endpoint.EndpointName)) continue } if !services.IsAWSResourceAvailable(endpoint, endpoint.EndpointStatus) { - f.log.Debugf("The current status of Redshift Serverless endpoint %v is %v. Skipping.", aws.StringValue(endpoint.EndpointName), aws.StringValue(endpoint.EndpointStatus)) + log.Debugf("The current status of Redshift Serverless endpoint %v is %v. Skipping.", aws.StringValue(endpoint.EndpointName), aws.StringValue(endpoint.EndpointStatus)) continue } @@ -171,7 +135,7 @@ func (f *redshiftServerlessFetcher) getDatabasesFromVPCEndpoints(ctx context.Con // tags from the workgroups instead. database, err := services.NewDatabaseFromRedshiftServerlessVPCEndpoint(endpoint, workgroup.Workgroup, workgroup.Tags) if err != nil { - f.log.WithError(err).Infof("Could not convert Redshift Serverless endpoint %q to database resource.", aws.StringValue(endpoint.EndpointName)) + log.WithError(err).Infof("Could not convert Redshift Serverless endpoint %q to database resource.", aws.StringValue(endpoint.EndpointName)) continue } databases = append(databases, database) @@ -179,41 +143,41 @@ func (f *redshiftServerlessFetcher) getDatabasesFromVPCEndpoints(ctx context.Con return databases, nil } -func (f *redshiftServerlessFetcher) getResourceTags(ctx context.Context, arn *string) []*redshiftserverless.Tag { - output, err := f.cfg.Client.ListTagsForResourceWithContext(ctx, &redshiftserverless.ListTagsForResourceInput{ +func getRSSResourceTags(ctx context.Context, arn *string, client rssAPI, log logrus.FieldLogger) []*redshiftserverless.Tag { + output, err := client.ListTagsForResourceWithContext(ctx, &redshiftserverless.ListTagsForResourceInput{ ResourceArn: arn, }) if err != nil { // Log errors here and return nil. 
if trace.IsAccessDenied(err) { - f.log.WithError(err).Debugf("No Permission to get tags for %q.", aws.StringValue(arn)) + log.WithError(err).Debugf("No Permission to get tags for %q.", aws.StringValue(arn)) } else { - f.log.WithError(err).Warnf("Failed to get tags for %q.", aws.StringValue(arn)) + log.WithError(err).Warnf("Failed to get tags for %q.", aws.StringValue(arn)) } return nil } return output.Tags } -func (f *redshiftServerlessFetcher) getWorkgroups(ctx context.Context) ([]*redshiftserverless.Workgroup, error) { +func getRSSWorkgroups(ctx context.Context, client rssAPI) ([]*redshiftserverless.Workgroup, error) { var pages [][]*redshiftserverless.Workgroup - err := f.cfg.Client.ListWorkgroupsPagesWithContext(ctx, nil, func(page *redshiftserverless.ListWorkgroupsOutput, lastPage bool) bool { + err := client.ListWorkgroupsPagesWithContext(ctx, nil, func(page *redshiftserverless.ListWorkgroupsOutput, lastPage bool) bool { pages = append(pages, page.Workgroups) return len(pages) <= maxAWSPages }) return flatten(pages), libcloudaws.ConvertRequestFailureError(err) } -func (f *redshiftServerlessFetcher) getVPCEndpoints(ctx context.Context) ([]*redshiftserverless.EndpointAccess, error) { +func getRSSVPCEndpoints(ctx context.Context, client rssAPI) ([]*redshiftserverless.EndpointAccess, error) { var pages [][]*redshiftserverless.EndpointAccess - err := f.cfg.Client.ListEndpointAccessPagesWithContext(ctx, nil, func(page *redshiftserverless.ListEndpointAccessOutput, lastPage bool) bool { + err := client.ListEndpointAccessPagesWithContext(ctx, nil, func(page *redshiftserverless.ListEndpointAccessOutput, lastPage bool) bool { pages = append(pages, page.Endpoints) return len(pages) <= maxAWSPages }) return flatten(pages), libcloudaws.ConvertRequestFailureError(err) } -func findWorkgroupWithName(workgroups []*redshiftServerlessWorkgroupWithTags, name string) (*redshiftServerlessWorkgroupWithTags, bool) { +func findWorkgroupWithName(workgroups []*workgroupWithTags, name 
string) (*workgroupWithTags, bool) { for _, workgroup := range workgroups { if aws.StringValue(workgroup.WorkgroupName) == name { return workgroup, true diff --git a/lib/srv/discovery/fetchers/db/azure.go b/lib/srv/discovery/fetchers/db/azure.go index 127611294cf33..9bfbc7d8f6800 100644 --- a/lib/srv/discovery/fetchers/db/azure.go +++ b/lib/srv/discovery/fetchers/db/azure.go @@ -148,8 +148,13 @@ func (f *azureFetcher[DBType, ListClient]) Get(ctx context.Context) (types.Resou if err != nil { return nil, trace.Wrap(err) } + f.rewriteDatabases(databases) + return databases.AsResources(), nil +} - return filterDatabasesByLabels(databases, f.cfg.Labels, f.log).AsResources(), nil +// rewriteDatabases rewrites the discovered databases. +func (f *azureFetcher[DBType, ListClient]) rewriteDatabases(databases types.Databases) { + // no-op } // getSubscriptions returns the subscriptions that this fetcher is configured to query. @@ -225,7 +230,7 @@ func (f *azureFetcher[DBType, ListClient]) getDatabases(ctx context.Context) (ty databases = append(databases, database) } } - return databases, nil + return filterDatabasesByLabels(databases, f.cfg.Labels, f.log), nil } // String returns the fetcher's string description. 
diff --git a/lib/srv/discovery/fetchers/db/db.go b/lib/srv/discovery/fetchers/db/db.go index 1d6a07a8c3cee..2a756db943be3 100644 --- a/lib/srv/discovery/fetchers/db/db.go +++ b/lib/srv/discovery/fetchers/db/db.go @@ -29,18 +29,18 @@ import ( "github.com/gravitational/teleport/lib/srv/discovery/common" ) -type makeAWSFetcherFunc func(context.Context, cloud.AWSClients, string, types.Labels, types.AssumeRole) (common.Fetcher, error) +type makeAWSFetcherFunc func(awsFetcherConfig) (common.Fetcher, error) type makeAzureFetcherFunc func(azureFetcherConfig) (common.Fetcher, error) var ( makeAWSFetcherFuncs = map[string][]makeAWSFetcherFunc{ - services.AWSMatcherRDS: {makeRDSInstanceFetcher, makeRDSAuroraFetcher}, - services.AWSMatcherRDSProxy: {makeRDSProxyFetcher}, - services.AWSMatcherRedshift: {makeRedshiftFetcher}, - services.AWSMatcherRedshiftServerless: {makeRedshiftServerlessFetcher}, - services.AWSMatcherElastiCache: {makeElastiCacheFetcher}, - services.AWSMatcherMemoryDB: {makeMemoryDBFetcher}, - services.AWSMatcherOpenSearch: {makeOpenSearchFetcher}, + services.AWSMatcherRDS: {newRDSDBInstancesFetcher, newRDSAuroraClustersFetcher}, + services.AWSMatcherRDSProxy: {newRDSDBProxyFetcher}, + services.AWSMatcherRedshift: {newRedshiftFetcher}, + services.AWSMatcherRedshiftServerless: {newRedshiftServerlessFetcher}, + services.AWSMatcherElastiCache: {newElastiCacheFetcher}, + services.AWSMatcherMemoryDB: {newMemoryDBFetcher}, + services.AWSMatcherOpenSearch: {newOpenSearchFetcher}, } makeAzureFetcherFuncs = map[string][]makeAzureFetcherFunc{ @@ -76,7 +76,13 @@ func MakeAWSFetchers(ctx context.Context, clients cloud.AWSClients, matchers []t for _, makeFetcher := range makeFetchers { for _, region := range matcher.Regions { - fetcher, err := makeFetcher(ctx, clients, region, matcher.Tags, assumeRole) + fetcher, err := makeFetcher(awsFetcherConfig{ + AWSClients: clients, + Type: matcherType, + AssumeRole: assumeRole, + Labels: matcher.Tags, + Region: region, + }) if err 
!= nil { return nil, trace.Wrap(err) } @@ -120,125 +126,6 @@ func MakeAzureFetchers(clients cloud.AzureClients, matchers []types.AzureMatcher return result, nil } -// makeRDSInstanceFetcher returns RDS instance fetcher for the provided region and tags. -func makeRDSInstanceFetcher(ctx context.Context, clients cloud.AWSClients, region string, tags types.Labels, assumeRole types.AssumeRole) (common.Fetcher, error) { - rds, err := clients.GetAWSRDSClient(ctx, region, cloud.WithAssumeRole(assumeRole.RoleARN, assumeRole.ExternalID)) - if err != nil { - return nil, trace.Wrap(err) - } - - fetcher, err := newRDSDBInstancesFetcher(rdsFetcherConfig{ - Region: region, - Labels: tags, - RDS: rds, - AssumeRole: assumeRole, - }) - return fetcher, trace.Wrap(err) -} - -// makeRDSAuroraFetcher returns RDS Aurora fetcher for the provided region and tags. -func makeRDSAuroraFetcher(ctx context.Context, clients cloud.AWSClients, region string, tags types.Labels, assumeRole types.AssumeRole) (common.Fetcher, error) { - rds, err := clients.GetAWSRDSClient(ctx, region, cloud.WithAssumeRole(assumeRole.RoleARN, assumeRole.ExternalID)) - if err != nil { - return nil, trace.Wrap(err) - } - - fetcher, err := newRDSAuroraClustersFetcher(rdsFetcherConfig{ - Region: region, - Labels: tags, - RDS: rds, - AssumeRole: assumeRole, - }) - return fetcher, trace.Wrap(err) -} - -// makeRDSProxyFetcher returns RDS proxy fetcher for the provided region and tags. -func makeRDSProxyFetcher(ctx context.Context, clients cloud.AWSClients, region string, tags types.Labels, assumeRole types.AssumeRole) (common.Fetcher, error) { - rds, err := clients.GetAWSRDSClient(ctx, region, cloud.WithAssumeRole(assumeRole.RoleARN, assumeRole.ExternalID)) - if err != nil { - return nil, trace.Wrap(err) - } - - return newRDSDBProxyFetcher(rdsFetcherConfig{ - Region: region, - Labels: tags, - RDS: rds, - AssumeRole: assumeRole, - }) -} - -// makeRedshiftFetcher returns Redshift fetcher for the provided region and tags. 
-func makeRedshiftFetcher(ctx context.Context, clients cloud.AWSClients, region string, tags types.Labels, assumeRole types.AssumeRole) (common.Fetcher, error) { - redshift, err := clients.GetAWSRedshiftClient(ctx, region, cloud.WithAssumeRole(assumeRole.RoleARN, assumeRole.ExternalID)) - if err != nil { - return nil, trace.Wrap(err) - } - return newRedshiftFetcher(redshiftFetcherConfig{ - Region: region, - Labels: tags, - Redshift: redshift, - AssumeRole: assumeRole, - }) -} - -// makeElastiCacheFetcher returns ElastiCache fetcher for the provided region and tags. -func makeElastiCacheFetcher(ctx context.Context, clients cloud.AWSClients, region string, tags types.Labels, assumeRole types.AssumeRole) (common.Fetcher, error) { - elastiCache, err := clients.GetAWSElastiCacheClient(ctx, region, cloud.WithAssumeRole(assumeRole.RoleARN, assumeRole.ExternalID)) - if err != nil { - return nil, trace.Wrap(err) - } - return newElastiCacheFetcher(elastiCacheFetcherConfig{ - Region: region, - Labels: tags, - ElastiCache: elastiCache, - AssumeRole: assumeRole, - }) -} - -// makeMemoryDBFetcher returns MemoryDB fetcher for the provided region and tags. -func makeMemoryDBFetcher(ctx context.Context, clients cloud.AWSClients, region string, tags types.Labels, assumeRole types.AssumeRole) (common.Fetcher, error) { - memorydb, err := clients.GetAWSMemoryDBClient(ctx, region, cloud.WithAssumeRole(assumeRole.RoleARN, assumeRole.ExternalID)) - if err != nil { - return nil, trace.Wrap(err) - } - return newMemoryDBFetcher(memoryDBFetcherConfig{ - Region: region, - Labels: tags, - MemoryDB: memorydb, - AssumeRole: assumeRole, - }) -} - -// makeOpenSearchFetcher returns OpenSearch fetcher for the provided region and tags. 
-func makeOpenSearchFetcher(ctx context.Context, clients cloud.AWSClients, region string, tags types.Labels, assumeRole types.AssumeRole) (common.Fetcher, error) { - opensearch, err := clients.GetAWSOpenSearchClient(ctx, region, cloud.WithAssumeRole(assumeRole.RoleARN, assumeRole.ExternalID)) - if err != nil { - return nil, trace.Wrap(err) - } - - return newOpenSearchFetcher(openSearchFetcherConfig{ - Region: region, - Labels: tags, - openSearch: opensearch, - AssumeRole: assumeRole, - }) -} - -// makeRedshiftServerlessFetcher returns Redshift Serverless fetcher for the -// provided region and tags. -func makeRedshiftServerlessFetcher(ctx context.Context, clients cloud.AWSClients, region string, tags types.Labels, assumeRole types.AssumeRole) (common.Fetcher, error) { - client, err := clients.GetAWSRedshiftServerlessClient(ctx, region, cloud.WithAssumeRole(assumeRole.RoleARN, assumeRole.ExternalID)) - if err != nil { - return nil, trace.Wrap(err) - } - return newRedshiftServerlessFetcher(redshiftServerlessFetcherConfig{ - Region: region, - Labels: tags, - Client: client, - AssumeRole: assumeRole, - }) -} - // filterDatabasesByLabels filters input databases with provided labels. func filterDatabasesByLabels(databases types.Databases, labels types.Labels, log logrus.FieldLogger) types.Databases { var matchedDatabases types.Databases @@ -255,14 +142,6 @@ func filterDatabasesByLabels(databases types.Databases, labels types.Labels, log return matchedDatabases } -// applyAssumeRoleToDatabases applies assume role settings from fetcher to databases. -func applyAssumeRoleToDatabases(databases types.Databases, assumeRole types.AssumeRole) { - for _, db := range databases { - db.SetAWSAssumeRole(assumeRole.RoleARN) - db.SetAWSExternalID(assumeRole.ExternalID) - } -} - // flatten flattens a nested slice [][]T to []T. 
func flatten[T any](s [][]T) (result []T) { for i := range s { diff --git a/lib/srv/discovery/fetchers/db/helpers_test.go b/lib/srv/discovery/fetchers/db/helpers_test.go index a2a81e18c3f2d..a2f089836e7bc 100644 --- a/lib/srv/discovery/fetchers/db/helpers_test.go +++ b/lib/srv/discovery/fetchers/db/helpers_test.go @@ -146,9 +146,11 @@ func copyDatabasesWithAWSAssumeRole(role types.AssumeRole, databases ...types.Da } out := make(types.Databases, 0, len(databases)) for _, db := range databases { - out = append(out, db.Copy()) + dbCopy := db.Copy() + dbCopy.SetAWSAssumeRole(role.RoleARN) + dbCopy.SetAWSExternalID(role.ExternalID) + out = append(out, dbCopy) } - applyAssumeRoleToDatabases(out, role) return out } diff --git a/lib/srv/discovery/fetchers/eks.go b/lib/srv/discovery/fetchers/eks.go index c0c0c0c42bcc5..8e458369bdd5f 100644 --- a/lib/srv/discovery/fetchers/eks.go +++ b/lib/srv/discovery/fetchers/eks.go @@ -87,9 +87,15 @@ func (a *eksFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, error) return nil, trace.Wrap(err) } + a.rewriteKubeClusters(clusters) return clusters.AsResources(), nil } +// rewriteKubeClusters rewrites the discovered kube clusters. 
+func (a *eksFetcher) rewriteKubeClusters(clusters types.KubeClusters) { + // no-op +} + func (a *eksFetcher) getEKSClusters(ctx context.Context) (types.KubeClusters, error) { var ( clusters types.KubeClusters diff --git a/lib/srv/discovery/fetchers/gke.go b/lib/srv/discovery/fetchers/gke.go index acccc17835347..67da404a4849d 100644 --- a/lib/srv/discovery/fetchers/gke.go +++ b/lib/srv/discovery/fetchers/gke.go @@ -84,6 +84,7 @@ func (a *gkeFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, error) return nil, trace.Wrap(err) } + a.rewriteKubeClusters(clusters) return clusters.AsResources(), nil } @@ -108,6 +109,11 @@ func (a *gkeFetcher) getGKEClusters(ctx context.Context) (types.KubeClusters, er return clusters, trace.Wrap(err) } +// rewriteKubeClusters rewrites the discovered kube clusters. +func (a *gkeFetcher) rewriteKubeClusters(clusters types.KubeClusters) { + // no-op +} + func (a *gkeFetcher) ResourceType() string { return types.KindKubernetesCluster }