Skip to content

Commit

Permalink
Add vpc-id label to RDS Auto-Discover converter (#35890)
Browse files Browse the repository at this point in the history
This PR adds a new label to RDS Instances and RDS Clusters to identify
its VPC ID.

For Clusters, an additional fetch was necessary because RDS
DescribeDBCluster does not return its VPC. So, a DescribeDBInstance must
be performed to obtain that information from one of its members.

This will allow us to deploy a DatabaseService in a specific VPC and
only proxy the Databases that also belong to that VPC.
  • Loading branch information
marcoandredinis committed Dec 20, 2023
1 parent 4728af3 commit 78c49cd
Show file tree
Hide file tree
Showing 9 changed files with 146 additions and 52 deletions.
26 changes: 23 additions & 3 deletions lib/cloud/mocks/aws_rds.go
Expand Up @@ -320,10 +320,15 @@ func applyInstanceFilters(in []*rds.DBInstance, filters []*rds.Filter) ([]*rds.D
}
var out []*rds.DBInstance
efs := engineFilterSet(filters)
clusterIDs := clusterIdentifierFilterSet(filters)
for _, instance := range in {
if instanceEngineMatches(instance, efs) {
out = append(out, instance)
if len(efs) > 0 && !instanceEngineMatches(instance, efs) {
continue
}
if len(clusterIDs) > 0 && !instanceClusterIDMatches(instance, clusterIDs) {
continue
}
out = append(out, instance)
}
return out, nil
}
Expand All @@ -345,9 +350,18 @@ func applyClusterFilters(in []*rds.DBCluster, filters []*rds.Filter) ([]*rds.DBC

// engineFilterSet builds a string set of engine names from a list of RDS filters.
func engineFilterSet(filters []*rds.Filter) map[string]struct{} {
return filterValues(filters, "engine")
}

// clusterIdentifierFilterSet builds a string set of ClusterIDs from a list of RDS filters.
func clusterIdentifierFilterSet(filters []*rds.Filter) map[string]struct{} {
return filterValues(filters, "db-cluster-id")
}

func filterValues(filters []*rds.Filter, filterKey string) map[string]struct{} {
out := make(map[string]struct{})
for _, f := range filters {
if aws.StringValue(f.Name) != "engine" {
if aws.StringValue(f.Name) != filterKey {
continue
}
for _, v := range f.Values {
Expand All @@ -363,6 +377,12 @@ func instanceEngineMatches(instance *rds.DBInstance, filterSet map[string]struct
return ok
}

// instanceClusterIDMatches returns whether an RDS DBInstance ClusterID matches any ClusterID in a filter set.
func instanceClusterIDMatches(instance *rds.DBInstance, filterSet map[string]struct{}) bool {
_, ok := filterSet[aws.StringValue(instance.DBClusterIdentifier)]
return ok
}

// clusterEngineMatches returns whether an RDS DBCluster engine matches any engine name in a filter set.
func clusterEngineMatches(cluster *rds.DBCluster, filterSet map[string]struct{}) bool {
_, ok := filterSet[aws.StringValue(cluster.Engine)]
Expand Down
1 change: 1 addition & 0 deletions lib/integrations/awsoidc/listdatabases_test.go
Expand Up @@ -316,6 +316,7 @@ func TestListDatabases(t *testing.T) {
"engine-version": "",
"region": "",
"status": "available",
"vpc-id": "vpc-999",
"teleport.dev/cloud": "AWS",
},
},
Expand Down
38 changes: 25 additions & 13 deletions lib/services/database.go
Expand Up @@ -698,6 +698,9 @@ func labelsFromRDSV2Instance(rdsInstance *rdsTypesV2.DBInstance, meta *types.AWS
labels[types.DiscoveryLabelEngineVersion] = aws.StringValue(rdsInstance.EngineVersion)
labels[types.DiscoveryLabelEndpointType] = string(RDSEndpointTypeInstance)
labels[types.DiscoveryLabelStatus] = aws.StringValue(rdsInstance.DBInstanceStatus)
if rdsInstance.DBSubnetGroup != nil {
labels[types.DiscoveryLabelVPCID] = aws.StringValue(rdsInstance.DBSubnetGroup.VpcId)
}
return addLabels(labels, libcloudaws.TagsToLabels(rdsInstance.TagList))
}

Expand All @@ -720,7 +723,7 @@ func NewDatabaseFromRDSV2Cluster(cluster *rdsTypesV2.DBCluster, firstInstance *r
return types.NewDatabaseV3(
setAWSDBName(types.Metadata{
Description: fmt.Sprintf("Aurora cluster in %v", metadata.Region),
Labels: labelsFromRDSV2Cluster(cluster, metadata, RDSEndpointTypePrimary),
Labels: labelsFromRDSV2Cluster(cluster, metadata, RDSEndpointTypePrimary, firstInstance),
}, aws.StringValue(cluster.DBClusterIdentifier)),
types.DatabaseSpecV3{
Protocol: protocol,
Expand Down Expand Up @@ -777,17 +780,20 @@ func MetadataFromRDSV2Cluster(rdsCluster *rdsTypesV2.DBCluster, rdsInstance *rds

// labelsFromRDSV2Cluster creates database labels for the provided RDS cluster.
// It uses aws sdk v2.
func labelsFromRDSV2Cluster(rdsCluster *rdsTypesV2.DBCluster, meta *types.AWS, endpointType RDSEndpointType) map[string]string {
func labelsFromRDSV2Cluster(rdsCluster *rdsTypesV2.DBCluster, meta *types.AWS, endpointType RDSEndpointType, memberInstance *rdsTypesV2.DBInstance) map[string]string {
labels := labelsFromAWSMetadata(meta)
labels[types.DiscoveryLabelEngine] = aws.StringValue(rdsCluster.Engine)
labels[types.DiscoveryLabelEngineVersion] = aws.StringValue(rdsCluster.EngineVersion)
labels[types.DiscoveryLabelEndpointType] = string(endpointType)
labels[types.DiscoveryLabelStatus] = aws.StringValue(rdsCluster.Status)
if memberInstance != nil && memberInstance.DBSubnetGroup != nil {
labels[types.DiscoveryLabelVPCID] = aws.StringValue(memberInstance.DBSubnetGroup.VpcId)
}
return addLabels(labels, libcloudaws.TagsToLabels(rdsCluster.TagList))
}

// NewDatabaseFromRDSCluster creates a database resource from an RDS cluster (Aurora).
func NewDatabaseFromRDSCluster(cluster *rds.DBCluster) (types.Database, error) {
func NewDatabaseFromRDSCluster(cluster *rds.DBCluster, memberInstances []*rds.DBInstance) (types.Database, error) {
metadata, err := MetadataFromRDSCluster(cluster)
if err != nil {
return nil, trace.Wrap(err)
Expand All @@ -799,7 +805,7 @@ func NewDatabaseFromRDSCluster(cluster *rds.DBCluster) (types.Database, error) {
return types.NewDatabaseV3(
setAWSDBName(types.Metadata{
Description: fmt.Sprintf("Aurora cluster in %v", metadata.Region),
Labels: labelsFromRDSCluster(cluster, metadata, RDSEndpointTypePrimary),
Labels: labelsFromRDSCluster(cluster, metadata, RDSEndpointTypePrimary, memberInstances),
}, aws.StringValue(cluster.DBClusterIdentifier)),
types.DatabaseSpecV3{
Protocol: protocol,
Expand All @@ -809,7 +815,7 @@ func NewDatabaseFromRDSCluster(cluster *rds.DBCluster) (types.Database, error) {
}

// NewDatabaseFromRDSClusterReaderEndpoint creates a database resource from an RDS cluster reader endpoint (Aurora).
func NewDatabaseFromRDSClusterReaderEndpoint(cluster *rds.DBCluster) (types.Database, error) {
func NewDatabaseFromRDSClusterReaderEndpoint(cluster *rds.DBCluster, memberInstances []*rds.DBInstance) (types.Database, error) {
metadata, err := MetadataFromRDSCluster(cluster)
if err != nil {
return nil, trace.Wrap(err)
Expand All @@ -821,7 +827,7 @@ func NewDatabaseFromRDSClusterReaderEndpoint(cluster *rds.DBCluster) (types.Data
return types.NewDatabaseV3(
setAWSDBName(types.Metadata{
Description: fmt.Sprintf("Aurora cluster in %v (%v endpoint)", metadata.Region, string(RDSEndpointTypeReader)),
Labels: labelsFromRDSCluster(cluster, metadata, RDSEndpointTypeReader),
Labels: labelsFromRDSCluster(cluster, metadata, RDSEndpointTypeReader, memberInstances),
}, aws.StringValue(cluster.DBClusterIdentifier), string(RDSEndpointTypeReader)),
types.DatabaseSpecV3{
Protocol: protocol,
Expand All @@ -831,7 +837,7 @@ func NewDatabaseFromRDSClusterReaderEndpoint(cluster *rds.DBCluster) (types.Data
}

// NewDatabasesFromRDSClusterCustomEndpoints creates database resources from RDS cluster custom endpoints (Aurora).
func NewDatabasesFromRDSClusterCustomEndpoints(cluster *rds.DBCluster) (types.Databases, error) {
func NewDatabasesFromRDSClusterCustomEndpoints(cluster *rds.DBCluster, memberInstances []*rds.DBInstance) (types.Databases, error) {
metadata, err := MetadataFromRDSCluster(cluster)
if err != nil {
return nil, trace.Wrap(err)
Expand Down Expand Up @@ -859,7 +865,7 @@ func NewDatabasesFromRDSClusterCustomEndpoints(cluster *rds.DBCluster) (types.Da
database, err := types.NewDatabaseV3(
setAWSDBName(types.Metadata{
Description: fmt.Sprintf("Aurora cluster in %v (%v endpoint)", metadata.Region, string(RDSEndpointTypeCustom)),
Labels: labelsFromRDSCluster(cluster, metadata, RDSEndpointTypeCustom),
Labels: labelsFromRDSCluster(cluster, metadata, RDSEndpointTypeCustom, memberInstances),
}, aws.StringValue(cluster.DBClusterIdentifier), string(RDSEndpointTypeCustom), endpointDetails.ClusterCustomEndpointName),
types.DatabaseSpecV3{
Protocol: protocol,
Expand All @@ -885,7 +891,7 @@ func NewDatabasesFromRDSClusterCustomEndpoints(cluster *rds.DBCluster) (types.Da

// NewDatabasesFromRDSCluster creates all database resources from an RDS Aurora
// cluster.
func NewDatabasesFromRDSCluster(cluster *rds.DBCluster) (types.Databases, error) {
func NewDatabasesFromRDSCluster(cluster *rds.DBCluster, memberInstances []*rds.DBInstance) (types.Databases, error) {
var errors []error
var databases types.Databases

Expand All @@ -906,7 +912,7 @@ func NewDatabasesFromRDSCluster(cluster *rds.DBCluster) (types.Databases, error)

// Add a database from primary endpoint, if any writer instances.
if cluster.Endpoint != nil && hasWriterInstance {
database, err := NewDatabaseFromRDSCluster(cluster)
database, err := NewDatabaseFromRDSCluster(cluster, memberInstances)
if err != nil {
errors = append(errors, err)
} else {
Expand All @@ -917,7 +923,7 @@ func NewDatabasesFromRDSCluster(cluster *rds.DBCluster) (types.Databases, error)
// Add a database from reader endpoint, if any reader instances.
// https://docs.aws.amazon.com/AmazonRDS/latest/AuroraUserGuide/Aurora.Overview.Endpoints.html#Aurora.Endpoints.Reader
if cluster.ReaderEndpoint != nil && hasReaderInstance {
database, err := NewDatabaseFromRDSClusterReaderEndpoint(cluster)
database, err := NewDatabaseFromRDSClusterReaderEndpoint(cluster, memberInstances)
if err != nil {
errors = append(errors, err)
} else {
Expand All @@ -927,7 +933,7 @@ func NewDatabasesFromRDSCluster(cluster *rds.DBCluster) (types.Databases, error)

// Add databases from custom endpoints
if len(cluster.CustomEndpoints) > 0 {
customEndpointDatabases, err := NewDatabasesFromRDSClusterCustomEndpoints(cluster)
customEndpointDatabases, err := NewDatabasesFromRDSClusterCustomEndpoints(cluster, memberInstances)
if err != nil {
errors = append(errors, err)
}
Expand Down Expand Up @@ -1626,15 +1632,21 @@ func labelsFromRDSInstance(rdsInstance *rds.DBInstance, meta *types.AWS) map[str
labels[types.DiscoveryLabelEngine] = aws.StringValue(rdsInstance.Engine)
labels[types.DiscoveryLabelEngineVersion] = aws.StringValue(rdsInstance.EngineVersion)
labels[types.DiscoveryLabelEndpointType] = string(RDSEndpointTypeInstance)
if rdsInstance.DBSubnetGroup != nil {
labels[types.DiscoveryLabelVPCID] = aws.StringValue(rdsInstance.DBSubnetGroup.VpcId)
}
return addLabels(labels, libcloudaws.TagsToLabels(rdsInstance.TagList))
}

// labelsFromRDSCluster creates database labels for the provided RDS cluster.
func labelsFromRDSCluster(rdsCluster *rds.DBCluster, meta *types.AWS, endpointType RDSEndpointType) map[string]string {
func labelsFromRDSCluster(rdsCluster *rds.DBCluster, meta *types.AWS, endpointType RDSEndpointType, memberInstances []*rds.DBInstance) map[string]string {
labels := labelsFromAWSMetadata(meta)
labels[types.DiscoveryLabelEngine] = aws.StringValue(rdsCluster.Engine)
labels[types.DiscoveryLabelEngineVersion] = aws.StringValue(rdsCluster.EngineVersion)
labels[types.DiscoveryLabelEndpointType] = string(endpointType)
if len(memberInstances) > 0 && memberInstances[0].DBSubnetGroup != nil {
labels[types.DiscoveryLabelVPCID] = aws.StringValue(memberInstances[0].DBSubnetGroup.VpcId)
}
return addLabels(labels, libcloudaws.TagsToLabels(rdsCluster.TagList))
}

Expand Down
27 changes: 19 additions & 8 deletions lib/services/database_test.go
Expand Up @@ -789,6 +789,7 @@ func TestDatabaseFromRDSV2Instance(t *testing.T) {
types.DiscoveryLabelEngineVersion: "13.0",
types.DiscoveryLabelEndpointType: "instance",
types.DiscoveryLabelStatus: "available",
types.DiscoveryLabelVPCID: "vpc-asd",
"key": "val",
},
}, types.DatabaseSpecV3{
Expand Down Expand Up @@ -893,6 +894,8 @@ func TestDatabaseFromRDSInstanceNameOverride(t *testing.T) {

// TestDatabaseFromRDSCluster tests converting an RDS cluster to a database resource.
func TestDatabaseFromRDSCluster(t *testing.T) {
vpcid := uuid.NewString()
dbInstanceMembers := []*rds.DBInstance{{DBSubnetGroup: &rds.DBSubnetGroup{VpcId: aws.String(vpcid)}}}
cluster := &rds.DBCluster{
DBClusterArn: aws.String("arn:aws:rds:us-east-1:123456789012:cluster:cluster-1"),
DBClusterIdentifier: aws.String("cluster-1"),
Expand Down Expand Up @@ -934,6 +937,7 @@ func TestDatabaseFromRDSCluster(t *testing.T) {
types.DiscoveryLabelEngine: RDSEngineAuroraMySQL,
types.DiscoveryLabelEngineVersion: "8.0.0",
types.DiscoveryLabelEndpointType: "primary",
types.DiscoveryLabelVPCID: vpcid,
"key": "val",
},
}, types.DatabaseSpecV3{
Expand All @@ -942,7 +946,7 @@ func TestDatabaseFromRDSCluster(t *testing.T) {
AWS: expectedAWS,
})
require.NoError(t, err)
actual, err := NewDatabaseFromRDSCluster(cluster)
actual, err := NewDatabaseFromRDSCluster(cluster, dbInstanceMembers)
require.NoError(t, err)
require.Empty(t, cmp.Diff(expected, actual))
})
Expand All @@ -958,6 +962,7 @@ func TestDatabaseFromRDSCluster(t *testing.T) {
types.DiscoveryLabelEngine: RDSEngineAuroraMySQL,
types.DiscoveryLabelEngineVersion: "8.0.0",
types.DiscoveryLabelEndpointType: "reader",
types.DiscoveryLabelVPCID: vpcid,
"key": "val",
},
}, types.DatabaseSpecV3{
Expand All @@ -966,7 +971,7 @@ func TestDatabaseFromRDSCluster(t *testing.T) {
AWS: expectedAWS,
})
require.NoError(t, err)
actual, err := NewDatabaseFromRDSClusterReaderEndpoint(cluster)
actual, err := NewDatabaseFromRDSClusterReaderEndpoint(cluster, dbInstanceMembers)
require.NoError(t, err)
require.Empty(t, cmp.Diff(expected, actual))
})
Expand All @@ -979,6 +984,7 @@ func TestDatabaseFromRDSCluster(t *testing.T) {
types.DiscoveryLabelEngine: RDSEngineAuroraMySQL,
types.DiscoveryLabelEngineVersion: "8.0.0",
types.DiscoveryLabelEndpointType: "custom",
types.DiscoveryLabelVPCID: vpcid,
"key": "val",
}

Expand Down Expand Up @@ -1010,7 +1016,7 @@ func TestDatabaseFromRDSCluster(t *testing.T) {
})
require.NoError(t, err)

databases, err := NewDatabasesFromRDSClusterCustomEndpoints(cluster)
databases, err := NewDatabasesFromRDSClusterCustomEndpoints(cluster, dbInstanceMembers)
require.NoError(t, err)
require.Equal(t, types.Databases{expectedMyEndpoint1, expectedMyEndpoint2}, databases)
})
Expand All @@ -1021,7 +1027,7 @@ func TestDatabaseFromRDSCluster(t *testing.T) {
aws.String("badendpoint1"),
aws.String("badendpoint2"),
}
_, err := NewDatabasesFromRDSClusterCustomEndpoints(&badCluster)
_, err := NewDatabasesFromRDSClusterCustomEndpoints(&badCluster, dbInstanceMembers)
require.Error(t, err)
})
}
Expand Down Expand Up @@ -1127,6 +1133,7 @@ func TestDatabaseFromRDSV2Cluster(t *testing.T) {
types.DiscoveryLabelEngineVersion: "8.0.0",
types.DiscoveryLabelEndpointType: "primary",
types.DiscoveryLabelStatus: "available",
types.DiscoveryLabelVPCID: "vpc-123",
"key": "val",
},
}, types.DatabaseSpecV3{
Expand All @@ -1153,6 +1160,7 @@ func TestDatabaseFromRDSV2Cluster(t *testing.T) {

// TestDatabaseFromRDSClusterNameOverride tests converting an RDS cluster to a database resource with overridden name.
func TestDatabaseFromRDSClusterNameOverride(t *testing.T) {
dbInstanceMembers := []*rds.DBInstance{{DBSubnetGroup: &rds.DBSubnetGroup{VpcId: aws.String("vpc-123")}}}
for _, overrideLabel := range types.AWSDatabaseNameOverrideLabels {
cluster := &rds.DBCluster{
DBClusterArn: aws.String("arn:aws:rds:us-east-1:123456789012:cluster:cluster-1"),
Expand Down Expand Up @@ -1195,6 +1203,7 @@ func TestDatabaseFromRDSClusterNameOverride(t *testing.T) {
types.DiscoveryLabelEngine: RDSEngineAuroraMySQL,
types.DiscoveryLabelEngineVersion: "8.0.0",
types.DiscoveryLabelEndpointType: "primary",
types.DiscoveryLabelVPCID: "vpc-123",
overrideLabel: "mycluster-2",
"key": "val",
},
Expand All @@ -1204,7 +1213,7 @@ func TestDatabaseFromRDSClusterNameOverride(t *testing.T) {
AWS: expectedAWS,
})
require.NoError(t, err)
actual, err := NewDatabaseFromRDSCluster(cluster)
actual, err := NewDatabaseFromRDSCluster(cluster, dbInstanceMembers)
require.NoError(t, err)
require.Empty(t, cmp.Diff(expected, actual))
})
Expand All @@ -1220,6 +1229,7 @@ func TestDatabaseFromRDSClusterNameOverride(t *testing.T) {
types.DiscoveryLabelEngine: RDSEngineAuroraMySQL,
types.DiscoveryLabelEngineVersion: "8.0.0",
types.DiscoveryLabelEndpointType: "reader",
types.DiscoveryLabelVPCID: "vpc-123",
overrideLabel: "mycluster-2",
"key": "val",
},
Expand All @@ -1229,7 +1239,7 @@ func TestDatabaseFromRDSClusterNameOverride(t *testing.T) {
AWS: expectedAWS,
})
require.NoError(t, err)
actual, err := NewDatabaseFromRDSClusterReaderEndpoint(cluster)
actual, err := NewDatabaseFromRDSClusterReaderEndpoint(cluster, dbInstanceMembers)
require.NoError(t, err)
require.Empty(t, cmp.Diff(expected, actual))
})
Expand All @@ -1242,6 +1252,7 @@ func TestDatabaseFromRDSClusterNameOverride(t *testing.T) {
types.DiscoveryLabelEngine: RDSEngineAuroraMySQL,
types.DiscoveryLabelEngineVersion: "8.0.0",
types.DiscoveryLabelEndpointType: "custom",
types.DiscoveryLabelVPCID: "vpc-123",
overrideLabel: "mycluster-2",
"key": "val",
}
Expand Down Expand Up @@ -1274,7 +1285,7 @@ func TestDatabaseFromRDSClusterNameOverride(t *testing.T) {
})
require.NoError(t, err)

databases, err := NewDatabasesFromRDSClusterCustomEndpoints(cluster)
databases, err := NewDatabasesFromRDSClusterCustomEndpoints(cluster, dbInstanceMembers)
require.NoError(t, err)
require.Equal(t, types.Databases{expectedMyEndpoint1, expectedMyEndpoint2}, databases)
})
Expand All @@ -1285,7 +1296,7 @@ func TestDatabaseFromRDSClusterNameOverride(t *testing.T) {
aws.String("badendpoint1"),
aws.String("badendpoint2"),
}
_, err := NewDatabasesFromRDSClusterCustomEndpoints(&badCluster)
_, err := NewDatabasesFromRDSClusterCustomEndpoints(&badCluster, dbInstanceMembers)
require.Error(t, err)
})
}
Expand Down
3 changes: 2 additions & 1 deletion lib/srv/db/cloud/resource_checker_url_aws.go
Expand Up @@ -21,6 +21,7 @@ import (

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/opensearchservice"
"github.com/aws/aws-sdk-go/service/rds"
"github.com/aws/aws-sdk-go/service/rds/rdsiface"
"github.com/aws/aws-sdk-go/service/redshiftserverless/redshiftserverlessiface"
"github.com/gravitational/trace"
Expand Down Expand Up @@ -98,7 +99,7 @@ func (c *urlChecker) checkRDSCluster(ctx context.Context, database types.Databas
if err != nil {
return trace.Wrap(err)
}
databases, err := services.NewDatabasesFromRDSCluster(rdsCluster)
databases, err := services.NewDatabasesFromRDSCluster(rdsCluster, []*rds.DBInstance{})
if err != nil {
c.log.Warnf("Could not convert RDS cluster %q to database resources: %v.",
aws.StringValue(rdsCluster.DBClusterIdentifier), err)
Expand Down
2 changes: 1 addition & 1 deletion lib/srv/db/cloud/resource_checker_url_aws_test.go
Expand Up @@ -53,7 +53,7 @@ func TestURLChecker_AWS(t *testing.T) {
mocks.WithRDSClusterReader,
mocks.WithRDSClusterCustomEndpoint("my-custom"),
)
rdsClusterDBs, err := services.NewDatabasesFromRDSCluster(rdsCluster)
rdsClusterDBs, err := services.NewDatabasesFromRDSCluster(rdsCluster, []*rds.DBInstance{})
require.NoError(t, err)
require.Len(t, rdsClusterDBs, 3) // Primary, reader, custom.
testCases = append(testCases, append(rdsClusterDBs, rdsInstanceDB)...)
Expand Down
2 changes: 1 addition & 1 deletion lib/srv/discovery/common/renaming_test.go
Expand Up @@ -376,7 +376,7 @@ func makeAuroraPrimaryDB(t *testing.T, name, region, accountID, overrideLabel st
overrideLabel: name,
}),
}
database, err := services.NewDatabaseFromRDSCluster(cluster)
database, err := services.NewDatabaseFromRDSCluster(cluster, []*rds.DBInstance{})
require.NoError(t, err)
return database
}
Expand Down

0 comments on commit 78c49cd

Please sign in to comment.