Skip to content

Commit

Permalink
OCM-6592 | fix: filter out local-zone and wave-length zone associated…
Browse files Browse the repository at this point in the history
… subnets during 'rosa create machinepool (HCP)'
  • Loading branch information
davidleerh committed Apr 16, 2024
1 parent a43c113 commit e0fd2b9
Show file tree
Hide file tree
Showing 7 changed files with 187 additions and 28 deletions.
3 changes: 2 additions & 1 deletion cmd/create/machinepool/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ func getSubnetFromUser(cmd *cobra.Command, r *rosa.Runtime, isSubnetSet bool, cl
// getSubnetOptions gets one of the cluster subnets and returns a slice of formatted VPC's private subnets.
func getSubnetOptions(r *rosa.Runtime, cluster *cmv1.Cluster) ([]string, error) {
// Fetch VPC's subnets
privateSubnets, err := r.AWSClient.GetVPCPrivateSubnets(cluster.AWS().SubnetIDs()[0])
privateSubnets, err := r.AWSClient.GetVPCPrivateSubnets(
cluster.Hypershift().Enabled(), cluster.AWS().SubnetIDs()[0])
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion cmd/create/machinepool/nodepool.go
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ func addNodePool(cmd *cobra.Command, clusterKey string, cluster *cmv1.Cluster, r
func getSubnetFromAvailabilityZone(cmd *cobra.Command, r *rosa.Runtime, isAvailabilityZoneSet bool,
cluster *cmv1.Cluster) (string, error) {

privateSubnets, err := r.AWSClient.GetVPCPrivateSubnets(cluster.AWS().SubnetIDs()[0])
privateSubnets, err := r.AWSClient.GetVPCPrivateSubnets(true, cluster.AWS().SubnetIDs()[0])
if err != nil {
return "", err
}
Expand Down
68 changes: 55 additions & 13 deletions pkg/aws/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,10 @@ const (

// Since CloudFormation stacks are region-dependent, we hard-code OCM's default region and
// then use it to ensure that the user always gets the stack from the same region.
DefaultRegion = "us-east-1"
Inline = "inline"
Attached = "attached"

DefaultRegion = "us-east-1"
Inline = "inline"
Attached = "attached"
standardZone = "availability-zone"
LocalZone = "local-zone"
WavelengthZone = "wavelength-zone"

Expand Down Expand Up @@ -111,8 +111,9 @@ type Client interface {
GetSubnetAvailabilityZone(subnetID string) (string, error)
GetAvailabilityZoneType(availabilityZoneName string) (string, error)
GetVPCSubnets(subnetID string) ([]ec2types.Subnet, error)
GetVPCPrivateSubnets(subnetID string) ([]ec2types.Subnet, error)
FilterVPCsPrivateSubnets(subnets []ec2types.Subnet) ([]ec2types.Subnet, error)
GetVPCPrivateSubnets(isHostedCp bool, subnetID string) ([]ec2types.Subnet, error)
FilterVPCsPrivateSubnets(isHostedCp bool, subnets []ec2types.Subnet) ([]ec2types.Subnet, error)
FilterSubnetsWithStandardAvailabilityZones(subnets []ec2types.Subnet) ([]ec2types.Subnet, error)
ValidateQuota() (bool, error)
TagUserRegion(username string, region string) error
GetClusterRegionTagForUser(username string) (string, error)
Expand Down Expand Up @@ -535,13 +536,13 @@ func (c *awsClient) GetSubnetAvailabilityZone(subnetID string) (string, error) {
return *res.Subnets[0].AvailabilityZone, nil
}

func (c *awsClient) GetVPCPrivateSubnets(subnetID string) ([]ec2types.Subnet, error) {
func (c *awsClient) GetVPCPrivateSubnets(isHostedCp bool, subnetID string) ([]ec2types.Subnet, error) {
subnets, err := c.GetVPCSubnets(subnetID)
if err != nil {
return nil, err
}

return c.FilterVPCsPrivateSubnets(subnets)
return c.FilterVPCsPrivateSubnets(isHostedCp, subnets)
}

// getVPCSubnets gets a subnet ID and fetches all the subnets that belong to the same VPC as the provided subnet.
Expand Down Expand Up @@ -584,7 +585,7 @@ func (c *awsClient) GetVPCSubnets(subnetID string) ([]ec2types.Subnet, error) {

// FilterPrivateSubnets gets a slice of subnets that belongs to the same VPC and filters the private subnets.
// Assumption: subnets - non-empty slice.
func (c *awsClient) FilterVPCsPrivateSubnets(subnets []ec2types.Subnet) ([]ec2types.Subnet, error) {
func (c *awsClient) FilterVPCsPrivateSubnets(isHostedCp bool, subnets []ec2types.Subnet) ([]ec2types.Subnet, error) {
// Fetch VPC route tables
vpcID := subnets[0].VpcId
describeRouteTablesOutput, err := c.ec2Client.DescribeRouteTables(context.Background(), &ec2.DescribeRouteTablesInput{
Expand Down Expand Up @@ -613,13 +614,54 @@ func (c *awsClient) FilterVPCsPrivateSubnets(subnets []ec2types.Subnet) ([]ec2ty
}
}

// Temporary filter until local-zone and wave-length zones are supported by HCP
if isHostedCp {
privateSubnets, err = c.FilterSubnetsWithStandardAvailabilityZones(privateSubnets)
if err != nil {
return nil, err
}
}

if len(privateSubnets) < 1 {
return nil, fmt.Errorf("failed to find private subnets associated with VPC '%s'", *subnets[0].VpcId)
}

return privateSubnets, nil
}

func (c *awsClient) FilterSubnetsWithStandardAvailabilityZones(subnets []ec2types.Subnet) ([]ec2types.Subnet, error) {
filteredSubnets := []ec2types.Subnet{}
subnetAvailabilityZonesList := []string{}
standardAvailabilityZonesMap := map[string]bool{}

// list of availability zones to pass as input
for _, subnet := range subnets {
subnetAvailabilityZonesList = append(subnetAvailabilityZonesList, *subnet.AvailabilityZone)
}

describeAvailabilityZonesOutput, err := c.ec2Client.DescribeAvailabilityZones(
context.Background(),
&ec2.DescribeAvailabilityZonesInput{ZoneNames: subnetAvailabilityZonesList})
if err != nil {
return filteredSubnets, err
}

for _, availabilityZone := range describeAvailabilityZonesOutput.AvailabilityZones {
if *availabilityZone.ZoneType == standardZone {
standardAvailabilityZonesMap[*availabilityZone.ZoneName] = true
}
}

for _, subnet := range subnets {
_, exists := standardAvailabilityZonesMap[*subnet.AvailabilityZone]
if exists {
filteredSubnets = append(filteredSubnets, subnet)
}
}

return filteredSubnets, nil
}

// isPublicSubnet a public subnet is a subnet that's associated with a route table that has a route to an
// internet gateway
func (c *awsClient) isPublicSubnet(subnetID *string, routeTables []ec2types.RouteTable) (bool, error) {
Expand All @@ -638,12 +680,12 @@ func (c *awsClient) isPublicSubnet(subnetID *string, routeTables []ec2types.Rout
}

func (c *awsClient) getSubnetRouteTable(subnetID *string,
routeTables []ec2types.RouteTable) (*ec2types.RouteTable, error) {
routeTables []ec2types.RouteTable) (ec2types.RouteTable, error) {
// Subnet route table — A route table that's associated with a subnet
for _, routeTable := range routeTables {
for _, association := range routeTable.Associations {
if aws.ToString(association.SubnetId) == aws.ToString(subnetID) {
return &routeTable, nil
return routeTable, nil
}
}
}
Expand All @@ -653,13 +695,13 @@ func (c *awsClient) getSubnetRouteTable(subnetID *string,
for _, routeTable := range routeTables {
for _, association := range routeTable.Associations {
if aws.ToBool(association.Main) {
return &routeTable, nil
return routeTable, nil
}
}
}

// Each subnet in the VPC must be associated with a route table
return nil, fmt.Errorf("failed to find subnet '%s' route table", *subnetID)
return ec2types.RouteTable{}, fmt.Errorf("failed to find subnet '%s' route table", *subnetID)
}

// getSubnetIDs will return the list of subnetsIDs supported for the region picked.
Expand Down
101 changes: 101 additions & 0 deletions pkg/aws/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -750,4 +750,105 @@ var _ = Describe("Client", func() {
),
)
})

Describe("Filters local-zone and wavelength-zone subnets for HCP node pool", func() {
mockSubnets := []ec2types.Subnet{
{
SubnetId: awsSdk.String("subnet-mockid-1"),
AvailabilityZone: awsSdk.String("us-east-1a"),
VpcId: awsSdk.String("1234"),
},
// wavelength-zone
{
SubnetId: awsSdk.String("subnet-mockid-2"),
AvailabilityZone: awsSdk.String("us-east-1-wl1-atl-wlz-1"),
VpcId: awsSdk.String("1234"),
},
// local-zone
{
SubnetId: awsSdk.String("subnet-mockid-3"),
AvailabilityZone: awsSdk.String("us-east-1-bos-1a"),
VpcId: awsSdk.String("1234"),
},
}

mockDescribeAvailabilityZoneOutput := &ec2.DescribeAvailabilityZonesOutput{
AvailabilityZones: []ec2types.AvailabilityZone{
{
ZoneName: awsSdk.String("us-east-1a"),
ZoneType: awsSdk.String("availability-zone"),
},
{
ZoneName: awsSdk.String("us-east-1-wl1-atl-wlz-1"),
ZoneType: awsSdk.String("wavelength-zone"),
},
{
ZoneName: awsSdk.String("us-east-1-bos-1a"),
ZoneType: awsSdk.String("local-zone"),
},
},
}

mockDescribeRouteTablesOutput := &ec2.DescribeRouteTablesOutput{RouteTables: []ec2types.RouteTable{
{Routes: []ec2types.Route{{GatewayId: awsSdk.String("private-1")}},
Associations: []ec2types.RouteTableAssociation{
{SubnetId: awsSdk.String("subnet-mockid-1")},
{SubnetId: awsSdk.String("subnet-mockid-2")},
{SubnetId: awsSdk.String("subnet-mockid-3")},
},
},
},
}

Context("FilterSubnetsWithStandardAvailabilityZones", func() {
It("Filters out subnets with local-zone, wavelength-zone availability zones", func() {

mockEC2API.EXPECT().DescribeAvailabilityZones(gomock.Any(), gomock.Any()).Return(
mockDescribeAvailabilityZoneOutput, nil)

filteredSubnets, err := client.FilterSubnetsWithStandardAvailabilityZones(mockSubnets)
Expect(err).NotTo(HaveOccurred())

Expect(len(filteredSubnets)).To(Equal(1))
Expect(*filteredSubnets[0].AvailabilityZone).To(Equal("us-east-1a"))
})
})

Context("FilterVPCsPrivateSubnets", func() {
It("Removes subnets with local-zone az, wavelength-zone az for HCP private subnets filter", func() {

mockEC2API.EXPECT().DescribeRouteTables(gomock.Any(), gomock.Any()).Return(
mockDescribeRouteTablesOutput, nil)

mockEC2API.EXPECT().DescribeAvailabilityZones(gomock.Any(), gomock.Any()).Return(
mockDescribeAvailabilityZoneOutput, nil)

// HCP flag set to true
filteredPrivateSubnets, err := client.FilterVPCsPrivateSubnets(true, mockSubnets)
Expect(err).NotTo(HaveOccurred())
Expect(len(filteredPrivateSubnets)).To(Equal(1))
Expect(*filteredPrivateSubnets[0].AvailabilityZone).To(Equal("us-east-1a"))
})

It("Keeps subnets with local-zone az, wavelength-zone az for non-HCP private subnets filter", func() {

mockEC2API.EXPECT().DescribeRouteTables(gomock.Any(), gomock.Any()).Return(
mockDescribeRouteTablesOutput, nil)

// HCP flag set to false
filteredPrivateSubnets, err := client.FilterVPCsPrivateSubnets(false, mockSubnets)
Expect(err).NotTo(HaveOccurred())
Expect(len(filteredPrivateSubnets)).To(Equal(3))

// temp map to verify subnet exists
mapToVerifySubnets := map[string]bool{}
for _, privateSubnet := range filteredPrivateSubnets {
mapToVerifySubnets[*privateSubnet.SubnetId] = true
}
for _, subnet := range mockSubnets {
Expect(mapToVerifySubnets[*subnet.SubnetId]).To(BeTrue())
}
})
})
})
})
35 changes: 25 additions & 10 deletions pkg/aws/mock_client.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pkg/ocm/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -728,7 +728,7 @@ func ValidateHostedClusterSubnets(awsClient aws.Client, isPrivate bool, subnetID
}
}

privateSubnets, privateSubnetsErr := awsClient.FilterVPCsPrivateSubnets(subnets)
privateSubnets, privateSubnetsErr := awsClient.FilterVPCsPrivateSubnets(true, subnets)
if privateSubnetsErr != nil {
return 0, privateSubnetsErr
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/ocm/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -417,12 +417,12 @@ var _ = Describe("ValidateHostedClusterSubnets for Private Cluster", func() {
mockClient.EXPECT().GetVPCSubnets(gomock.Any()).Return(subnets, nil).AnyTimes()
})
It("should not return an error when only private subnets are present for a private cluster", func() {
mockClient.EXPECT().FilterVPCsPrivateSubnets(gomock.Any()).Return([]ec2types.Subnet{subnets[1]}, nil)
mockClient.EXPECT().FilterVPCsPrivateSubnets(gomock.Any(), gomock.Any()).Return([]ec2types.Subnet{subnets[1]}, nil)
_, err := ValidateHostedClusterSubnets(mockClient, true, []string{"subnet-private-2"})
Expect(err).NotTo(HaveOccurred())
})
It("should return an error when public subnets are present for a private cluster", func() {
mockClient.EXPECT().FilterVPCsPrivateSubnets(gomock.Any()).Return([]ec2types.Subnet{}, nil)
mockClient.EXPECT().FilterVPCsPrivateSubnets(gomock.Any(), gomock.Any()).Return([]ec2types.Subnet{}, nil)
_, err := ValidateHostedClusterSubnets(mockClient, true, ids)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(Equal("The number of public subnets for a private hosted cluster should be zero"))
Expand Down

0 comments on commit e0fd2b9

Please sign in to comment.