diff --git a/apis/elbv2/v1alpha1/targetgroupbinding_types.go b/apis/elbv2/v1alpha1/targetgroupbinding_types.go index 8d95e28eea..c21009b457 100644 --- a/apis/elbv2/v1alpha1/targetgroupbinding_types.go +++ b/apis/elbv2/v1alpha1/targetgroupbinding_types.go @@ -79,7 +79,7 @@ type NetworkingPort struct { // The port which traffic must match. // If unspecified, defaults to all port. // +optional - Port *intstr.IntOrString `json:"port,omitempty"` + Port *int64 `json:"port,omitempty"` // The protocol which traffic must match. // If unspecified, defaults to all protocol. diff --git a/apis/elbv2/v1alpha1/zz_generated.deepcopy.go b/apis/elbv2/v1alpha1/zz_generated.deepcopy.go index 89e1fc72d3..36ff7e2c0a 100644 --- a/apis/elbv2/v1alpha1/zz_generated.deepcopy.go +++ b/apis/elbv2/v1alpha1/zz_generated.deepcopy.go @@ -22,7 +22,6 @@ package v1alpha1 import ( runtime "k8s.io/apimachinery/pkg/runtime" - "k8s.io/apimachinery/pkg/util/intstr" ) // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. @@ -99,7 +98,7 @@ func (in *NetworkingPort) DeepCopyInto(out *NetworkingPort) { *out = *in if in.Port != nil { in, out := &in.Port, &out.Port - *out = new(intstr.IntOrString) + *out = new(int64) **out = **in } if in.Protocol != nil { diff --git a/config/crd/bases/elbv2.k8s.aws_targetgroupbindings.yaml b/config/crd/bases/elbv2.k8s.aws_targetgroupbindings.yaml index b6fc50e2da..f4777594a0 100644 --- a/config/crd/bases/elbv2.k8s.aws_targetgroupbindings.yaml +++ b/config/crd/bases/elbv2.k8s.aws_targetgroupbindings.yaml @@ -81,12 +81,10 @@ spec: items: properties: port: - anyOf: - - type: integer - - type: string description: The port which traffic must match. If unspecified, defaults to all port. - x-kubernetes-int-or-string: true + format: int64 + type: integer protocol: description: The protocol which traffic must match. If unspecified, defaults to all protocol. diff --git a/controllers/ingress/group_controller.go b/controllers/ingress/group_controller.go index 010b746341..cacf2b09cc 100644 --- a/controllers/ingress/group_controller.go +++ b/controllers/ingress/group_controller.go @@ -20,7 +20,8 @@ import ( const controllerName = "ingress" // NewGroupReconciler constructs new GroupReconciler -func NewGroupReconciler(k8sClient client.Client, eventRecorder record.EventRecorder, ec2Client services.EC2, elbv2Client services.ELBV2, vpcID string, clusterName string, +func NewGroupReconciler(k8sClient client.Client, eventRecorder record.EventRecorder, ec2Client services.EC2, elbv2Client services.ELBV2, + networkingSGManager networkingpkg.SecurityGroupManager, networkingSGReconciler networkingpkg.SecurityGroupReconciler, vpcID string, clusterName string, subnetsResolver networkingpkg.SubnetsResolver, logger logr.Logger) *GroupReconciler { annotationParser := annotations.NewSuffixAnnotationParser("alb.ingress.kubernetes.io") authConfigBuilder := ingress.NewDefaultAuthConfigBuilder(annotationParser) @@ -28,7 +29,7 @@ func NewGroupReconciler(k8sClient client.Client, eventRecorder record.EventRecor modelBuilder := ingress.NewDefaultModelBuilder(k8sClient, eventRecorder, ec2Client, vpcID, clusterName, annotationParser, subnetsResolver, authConfigBuilder, enhancedBackendBuilder) stackMarshaller := deploy.NewDefaultStackMarshaller() - stackDeployer := deploy.NewDefaultStackDeployer(k8sClient, elbv2Client, vpcID, clusterName, "ingress.k8s.aws", logger) + stackDeployer := deploy.NewDefaultStackDeployer(k8sClient, ec2Client, elbv2Client, networkingSGManager, networkingSGReconciler, vpcID, clusterName, "ingress.k8s.aws", logger) groupLoader := ingress.NewDefaultGroupLoader(k8sClient, annotationParser, "alb") finalizerManager := ingress.NewDefaultFinalizerManager(k8sClient) diff --git a/controllers/service/service_controller.go b/controllers/service/service_controller.go index c64cd8ae14..90e05319a1 100644 --- a/controllers/service/service_controller.go +++ b/controllers/service/service_controller.go @@ -28,7 +28,9 @@ const ( controllerName = "service" ) -func NewServiceReconciler(k8sClient client.Client, elbv2Client services.ELBV2, vpcID string, clusterName string, resolver networking.SubnetsResolver, logger logr.Logger) *ServiceReconciler { +func NewServiceReconciler(k8sClient client.Client, ec2Client services.EC2, elbv2Client services.ELBV2, + sgManager networking.SecurityGroupManager, sgReconciler networking.SecurityGroupReconciler, + vpcID string, clusterName string, resolver networking.SubnetsResolver, logger logr.Logger) *ServiceReconciler { return &ServiceReconciler{ k8sClient: k8sClient, logger: logger, @@ -36,7 +38,7 @@ func NewServiceReconciler(k8sClient client.Client, elbv2Client services.ELBV2, v finalizerManager: k8s.NewDefaultFinalizerManager(k8sClient, logger), subnetsResolver: resolver, stackMarshaller: deploy.NewDefaultStackMarshaller(), - stackDeployer: deploy.NewDefaultStackDeployer(k8sClient, elbv2Client, vpcID, clusterName, DefaultTagPrefix, logger), + stackDeployer: deploy.NewDefaultStackDeployer(k8sClient, ec2Client, elbv2Client, sgManager, sgReconciler, vpcID, clusterName, DefaultTagPrefix, logger), } } diff --git a/main.go b/main.go index 07881edd56..9096611915 100644 --- a/main.go +++ b/main.go @@ -101,24 +101,33 @@ func main() { os.Exit(1) } - subnetResolver := networking.NewSubnetsResolver(cloud.EC2(), cloud.VpcID(), k8sClusterName, ctrl.Log.WithName("subnets-resolver")) - ingGroupReconciler := ingress.NewGroupReconciler(mgr.GetClient(), mgr.GetEventRecorderFor("ingress"), cloud.EC2(), cloud.ELBV2(), cloud.VpcID(), k8sClusterName, subnetResolver, ctrl.Log) - if err = ingGroupReconciler.SetupWithManager(mgr); err != nil { - setupLog.Error(err, "unable to create controller", "controller", "Ingress") - os.Exit(1) - } - + podENIResolver := networking.NewDefaultPodENIInfoResolver(cloud.EC2(), cloud.VpcID(), ctrl.Log) + nodeENIResolver := networking.NewDefaultNodeENIInfoResolver(cloud.EC2(), ctrl.Log) + sgManager := networking.NewDefaultSecurityGroupManager(cloud.EC2(), ctrl.Log) + sgReconciler := networking.NewDefaultSecurityGroupReconciler(sgManager, ctrl.Log) finalizerManager := k8s.NewDefaultFinalizerManager(mgr.GetClient(), ctrl.Log) - tgbResManager := targetgroupbinding.NewDefaultResourceManager(mgr.GetClient(), cloud.ELBV2(), ctrl.Log) + tgbResManager := targetgroupbinding.NewDefaultResourceManager(mgr.GetClient(), cloud.ELBV2(), + podENIResolver, nodeENIResolver, sgManager, sgReconciler, cloud.VpcID(), k8sClusterName, ctrl.Log) + + subnetResolver := networking.NewSubnetsResolver(cloud.EC2(), cloud.VpcID(), k8sClusterName, ctrl.Log.WithName("subnets-resolver")) + ingGroupReconciler := ingress.NewGroupReconciler(mgr.GetClient(), mgr.GetEventRecorderFor("ingress"), cloud.EC2(), cloud.ELBV2(), + sgManager, sgReconciler, cloud.VpcID(), k8sClusterName, subnetResolver, ctrl.Log) tgbReconciler := elbv2controller.NewTargetGroupBindingReconciler(mgr.GetClient(), mgr.GetFieldIndexer(), finalizerManager, tgbResManager, ctrl.Log.WithName("controllers").WithName("TargetGroupBinding")) svcReconciler := service.NewServiceReconciler( mgr.GetClient(), + cloud.EC2(), cloud.ELBV2(), + sgManager, + sgReconciler, cloud.VpcID(), k8sClusterName, subnetResolver, ctrl.Log.WithName("controllers").WithName("Service")) + if err = ingGroupReconciler.SetupWithManager(mgr); err != nil { + setupLog.Error(err, "unable to create controller", "controller", "Ingress") + os.Exit(1) + } if err := tgbReconciler.SetupWithManager(context.Background(), mgr); err != nil { setupLog.Error(err, "unable to create controller", "controller", "TargetGroupBinding") os.Exit(1) diff --git a/pkg/deploy/ec2/security_group_manager.go b/pkg/deploy/ec2/security_group_manager.go new file mode 100644 index 0000000000..85b8b37dde --- /dev/null +++ b/pkg/deploy/ec2/security_group_manager.go @@ -0,0 +1,195 @@ +package ec2 + +import ( + "context" + "errors" + awssdk "github.com/aws/aws-sdk-go/aws" + ec2sdk "github.com/aws/aws-sdk-go/service/ec2" + "github.com/go-logr/logr" + "k8s.io/apimachinery/pkg/util/sets" + "sigs.k8s.io/aws-alb-ingress-controller/pkg/algorithm" + "sigs.k8s.io/aws-alb-ingress-controller/pkg/aws/services" + "sigs.k8s.io/aws-alb-ingress-controller/pkg/deploy/tagging" + ec2model "sigs.k8s.io/aws-alb-ingress-controller/pkg/model/ec2" + "sigs.k8s.io/aws-alb-ingress-controller/pkg/networking" +) + +// SecurityGroupManager is responsible for create/update/delete SecurityGroup resources. +type SecurityGroupManager interface { + Create(ctx context.Context, resSG *ec2model.SecurityGroup) (ec2model.SecurityGroupStatus, error) + + Update(ctx context.Context, resSG *ec2model.SecurityGroup, sdkSG networking.SecurityGroupInfo) (ec2model.SecurityGroupStatus, error) + + Delete(ctx context.Context, sdkSG networking.SecurityGroupInfo) error +} + +// NewDefaultSecurityGroupManager constructs new defaultSecurityGroupManager. +func NewDefaultSecurityGroupManager(ec2Client services.EC2, taggingProvider tagging.Provider, networkingSGReconciler networking.SecurityGroupReconciler, vpcID string, logger logr.Logger) *defaultSecurityGroupManager { + return &defaultSecurityGroupManager{ + ec2Client: ec2Client, + taggingProvider: taggingProvider, + networkingSGReconciler: networkingSGReconciler, + vpcID: vpcID, + logger: logger, + } +} + +// default implementation for SecurityGroupManager. +type defaultSecurityGroupManager struct { + ec2Client services.EC2 + taggingProvider tagging.Provider + networkingSGReconciler networking.SecurityGroupReconciler + vpcID string + logger logr.Logger +} + +func (m *defaultSecurityGroupManager) Create(ctx context.Context, resSG *ec2model.SecurityGroup) (ec2model.SecurityGroupStatus, error) { + sgTags := m.taggingProvider.ResourceTags(resSG.Stack(), resSG, resSG.Spec.Tags) + sdkTags := convertTagsToSDKTags(sgTags) + permissionInfos, err := buildIPPermissionInfos(resSG.Spec.Ingress) + if err != nil { + return ec2model.SecurityGroupStatus{}, err + } + + req := &ec2sdk.CreateSecurityGroupInput{ + VpcId: awssdk.String(m.vpcID), + GroupName: awssdk.String(resSG.Spec.GroupName), + Description: awssdk.String(resSG.Spec.Description), + TagSpecifications: []*ec2sdk.TagSpecification{ + { + ResourceType: awssdk.String("security-group"), + Tags: sdkTags, + }, + }, + } + m.logger.Info("creating securityGroup", + "resourceID", resSG.ID()) + resp, err := m.ec2Client.CreateSecurityGroupWithContext(ctx, req) + if err != nil { + return ec2model.SecurityGroupStatus{}, err + } + sgID := awssdk.StringValue(resp.GroupId) + m.logger.Info("created securityGroup", + "resourceID", resSG.ID(), + "securityGroupID", sgID) + + if err := m.networkingSGReconciler.ReconcileIngress(ctx, sgID, permissionInfos); err != nil { + return ec2model.SecurityGroupStatus{}, err + } + + return ec2model.SecurityGroupStatus{ + GroupID: sgID, + }, nil +} + +func (m *defaultSecurityGroupManager) Update(ctx context.Context, resSG *ec2model.SecurityGroup, sdkSG networking.SecurityGroupInfo) (ec2model.SecurityGroupStatus, error) { + permissionInfos, err := buildIPPermissionInfos(resSG.Spec.Ingress) + if err != nil { + return ec2model.SecurityGroupStatus{}, err + } + if err := m.updateSDKSecurityGroupGroupWithTags(ctx, resSG, sdkSG); err != nil { + return ec2model.SecurityGroupStatus{}, err + } + if err := m.networkingSGReconciler.ReconcileIngress(ctx, sdkSG.SecurityGroupID, permissionInfos); err != nil { + return ec2model.SecurityGroupStatus{}, err + } + return ec2model.SecurityGroupStatus{ + GroupID: sdkSG.SecurityGroupID, + }, nil +} + +func (m *defaultSecurityGroupManager) Delete(ctx context.Context, sdkSG networking.SecurityGroupInfo) error { + req := &ec2sdk.DeleteSecurityGroupInput{ + GroupId: awssdk.String(sdkSG.SecurityGroupID), + } + m.logger.Info("deleting securityGroup", + "securityGroupID", sdkSG.SecurityGroupID) + if _, err := m.ec2Client.DeleteSecurityGroupWithContext(ctx, req); err != nil { + return err + } + m.logger.Info("deleted securityGroup", + "securityGroupID", sdkSG.SecurityGroupID) + return nil +} + +func (m *defaultSecurityGroupManager) updateSDKSecurityGroupGroupWithTags(ctx context.Context, resSG *ec2model.SecurityGroup, sdkSG networking.SecurityGroupInfo) error { + desiredTags := m.taggingProvider.ResourceTags(resSG.Stack(), resSG, resSG.Spec.Tags) + tagsToUpdate, tagsToRemove := algorithm.DiffStringMap(desiredTags, sdkSG.Tags) + if len(tagsToUpdate) > 0 { + req := &ec2sdk.CreateTagsInput{ + Resources: []*string{awssdk.String(sdkSG.SecurityGroupID)}, + Tags: convertTagsToSDKTags(tagsToUpdate), + } + + m.logger.Info("adding securityGroup tags", + "securityGroupID", sdkSG.SecurityGroupID, + "change", tagsToUpdate) + if _, err := m.ec2Client.CreateTagsWithContext(ctx, req); err != nil { + return err + } + m.logger.Info("added securityGroup tags", + "securityGroupID", sdkSG.SecurityGroupID) + } + + if len(tagsToRemove) > 0 { + req := &ec2sdk.DeleteTagsInput{ + Resources: []*string{awssdk.String(sdkSG.SecurityGroupID)}, + Tags: convertTagsToSDKTags(tagsToRemove), + } + + m.logger.Info("removing securityGroup tags", + "securityGroupID", sdkSG.SecurityGroupID, + "change", tagsToRemove) + if _, err := m.ec2Client.DeleteTagsWithContext(ctx, req); err != nil { + return err + } + m.logger.Info("removed securityGroup tags", + "securityGroupID", sdkSG.SecurityGroupID) + } + return nil +} + +func buildIPPermissionInfos(permissions []ec2model.IPPermission) ([]networking.IPPermissionInfo, error) { + permissionInfos := make([]networking.IPPermissionInfo, 0, len(permissions)) + for _, permission := range permissions { + permissionInfo, err := buildIPPermissionInfo(permission) + if err != nil { + return nil, err + } + permissionInfos = append(permissionInfos, permissionInfo) + } + return permissionInfos, nil +} + +func buildIPPermissionInfo(permission ec2model.IPPermission) (networking.IPPermissionInfo, error) { + protocol := permission.IPProtocol + if len(permission.IPRanges) == 1 { + labels := networking.NewIPPermissionLabelsForRawDescription(permission.IPRanges[0].Description) + return networking.NewCIDRIPPermission(protocol, permission.FromPort, permission.ToPort, permission.IPRanges[0].CIDRIP, labels), nil + } + if len(permission.IPV6Range) == 1 { + labels := networking.NewIPPermissionLabelsForRawDescription(permission.IPV6Range[0].Description) + return networking.NewCIDRIPPermission(protocol, permission.FromPort, permission.ToPort, permission.IPV6Range[0].CIDRIPv6, labels), nil + } + if len(permission.UserIDGroupPairs) == 1 { + labels := networking.NewIPPermissionLabelsForRawDescription(permission.UserIDGroupPairs[0].Description) + return networking.NewGroupIDIPPermission(protocol, permission.FromPort, permission.ToPort, permission.UserIDGroupPairs[0].GroupID, labels), nil + } + return networking.IPPermissionInfo{}, errors.New("invalid ipPermission") +} + +// convert tags into AWS SDK tag presentation. +func convertTagsToSDKTags(tags map[string]string) []*ec2sdk.Tag { + if len(tags) == 0 { + return nil + } + sdkTags := make([]*ec2sdk.Tag, 0, len(tags)) + + for _, key := range sets.StringKeySet(tags).List() { + sdkTags = append(sdkTags, &ec2sdk.Tag{ + Key: awssdk.String(key), + Value: awssdk.String(tags[key]), + }) + } + return sdkTags +} diff --git a/pkg/deploy/ec2/security_group_synthesizer.go b/pkg/deploy/ec2/security_group_synthesizer.go new file mode 100644 index 0000000000..e3b348a106 --- /dev/null +++ b/pkg/deploy/ec2/security_group_synthesizer.go @@ -0,0 +1,174 @@ +package ec2 + +import ( + "context" + "fmt" + awssdk "github.com/aws/aws-sdk-go/aws" + ec2sdk "github.com/aws/aws-sdk-go/service/ec2" + "github.com/go-logr/logr" + "github.com/pkg/errors" + "k8s.io/apimachinery/pkg/util/sets" + "sigs.k8s.io/aws-alb-ingress-controller/pkg/aws/services" + "sigs.k8s.io/aws-alb-ingress-controller/pkg/deploy/tagging" + "sigs.k8s.io/aws-alb-ingress-controller/pkg/model/core" + ec2model "sigs.k8s.io/aws-alb-ingress-controller/pkg/model/ec2" + "sigs.k8s.io/aws-alb-ingress-controller/pkg/networking" +) + +// NewListenerRuleSynthesizer constructs new listenerRuleSynthesizer. +func NewSecurityGroupSynthesizer(ec2Client services.EC2, taggingProvider tagging.Provider, + networkingSGManager networking.SecurityGroupManager, sgManager SecurityGroupManager, + vpcID string, logger logr.Logger, stack core.Stack) *securityGroupSynthesizer { + return &securityGroupSynthesizer{ + ec2Client: ec2Client, + taggingProvider: taggingProvider, + networkingSGManager: networkingSGManager, + sgManager: sgManager, + vpcID: vpcID, + logger: logger, + stack: stack, + unmatchedSDKSGs: nil, + } +} + +type securityGroupSynthesizer struct { + ec2Client services.EC2 + taggingProvider tagging.Provider + networkingSGManager networking.SecurityGroupManager + sgManager SecurityGroupManager + vpcID string + logger logr.Logger + + stack core.Stack + unmatchedSDKSGs []networking.SecurityGroupInfo +} + +func (s *securityGroupSynthesizer) Synthesize(ctx context.Context) error { + var resSGs []*ec2model.SecurityGroup + s.stack.ListResources(&resSGs) + sdkSGs, err := s.findSDKSecurityGroups(ctx) + if err != nil { + return err + } + matchedResAndSDKSGs, unmatchedResSGs, unmatchedSDKSGs, err := matchResAndSDKSecurityGroups(resSGs, sdkSGs, s.taggingProvider.ResourceIDTagKey()) + if err != nil { + return err + } + + // For SecurityGroup, we delete unmatched ones during post synthesize. + s.unmatchedSDKSGs = unmatchedSDKSGs + + for _, resSG := range unmatchedResSGs { + sgStatus, err := s.sgManager.Create(ctx, resSG) + if err != nil { + return err + } + resSG.SetStatus(sgStatus) + } + for _, resAndSDKSG := range matchedResAndSDKSGs { + sgStatus, err := s.sgManager.Update(ctx, resAndSDKSG.resSG, resAndSDKSG.sdkSG) + if err != nil { + return err + } + resAndSDKSG.resSG.SetStatus(sgStatus) + } + return nil +} + +func (s *securityGroupSynthesizer) PostSynthesize(ctx context.Context) error { + for _, sdkSG := range s.unmatchedSDKSGs { + if err := s.sgManager.Delete(ctx, sdkSG); err != nil { + return err + } + } + return nil +} + +// findSDKSecurityGroups will find all AWS SecurityGroups created for stack. +func (s *securityGroupSynthesizer) findSDKSecurityGroups(ctx context.Context) ([]networking.SecurityGroupInfo, error) { + req := &ec2sdk.DescribeSecurityGroupsInput{ + Filters: []*ec2sdk.Filter{ + { + Name: awssdk.String("vpc-id"), + Values: awssdk.StringSlice([]string{s.vpcID}), + }, + }, + } + stackTags := s.taggingProvider.StackTags(s.stack) + for tagKey, tagValue := range stackTags { + tagFilterName := fmt.Sprintf("tag:%v", tagKey) + req.Filters = append(req.Filters, &ec2sdk.Filter{ + Name: awssdk.String(tagFilterName), + Values: awssdk.StringSlice([]string{tagValue}), + }) + } + sdkSGsByID, err := s.networkingSGManager.FetchSGInfosByRequest(ctx, req) + if err != nil { + return nil, err + } + sdkSGs := make([]networking.SecurityGroupInfo, 0, len(sdkSGsByID)) + for _, sdkSG := range sdkSGsByID { + sdkSGs = append(sdkSGs, sdkSG) + } + return sdkSGs, nil +} + +type resAndSDKSecurityGroupPair struct { + resSG *ec2model.SecurityGroup + sdkSG networking.SecurityGroupInfo +} + +func matchResAndSDKSecurityGroups(resSGs []*ec2model.SecurityGroup, sdkSGs []networking.SecurityGroupInfo, + resourceIDTagKey string) ([]resAndSDKSecurityGroupPair, []*ec2model.SecurityGroup, []networking.SecurityGroupInfo, error) { + var matchedResAndSDKSGs []resAndSDKSecurityGroupPair + var unmatchedResSGs []*ec2model.SecurityGroup + var unmatchedSDKSGs []networking.SecurityGroupInfo + + resSGsByID := mapResSecurityGroupByResourceID(resSGs) + sdkSGsByID, err := mapSDKSecurityGroupByResourceID(sdkSGs, resourceIDTagKey) + if err != nil { + return nil, nil, nil, err + } + + resSGIDs := sets.StringKeySet(resSGsByID) + sdkSGIDs := sets.StringKeySet(sdkSGsByID) + for _, resID := range resSGIDs.Intersection(sdkSGIDs).List() { + resSG := resSGsByID[resID] + sdkSGs := sdkSGsByID[resID] + matchedResAndSDKSGs = append(matchedResAndSDKSGs, resAndSDKSecurityGroupPair{ + resSG: resSG, + sdkSG: sdkSGs[0], + }) + for _, sdkSG := range sdkSGs[1:] { + unmatchedSDKSGs = append(unmatchedSDKSGs, sdkSG) + } + } + for _, resID := range resSGIDs.Difference(sdkSGIDs).List() { + unmatchedResSGs = append(unmatchedResSGs, resSGsByID[resID]) + } + for _, resID := range sdkSGIDs.Difference(resSGIDs).List() { + unmatchedSDKSGs = append(unmatchedSDKSGs, sdkSGsByID[resID]...) + } + + return matchedResAndSDKSGs, unmatchedResSGs, unmatchedSDKSGs, nil +} + +func mapResSecurityGroupByResourceID(resSGs []*ec2model.SecurityGroup) map[string]*ec2model.SecurityGroup { + resSGsByID := make(map[string]*ec2model.SecurityGroup, len(resSGs)) + for _, resSG := range resSGs { + resSGsByID[resSG.ID()] = resSG + } + return resSGsByID +} + +func mapSDKSecurityGroupByResourceID(sdkSGs []networking.SecurityGroupInfo, resourceIDTagKey string) (map[string][]networking.SecurityGroupInfo, error) { + sdkSGsByID := make(map[string][]networking.SecurityGroupInfo, len(sdkSGs)) + for _, sdkSG := range sdkSGs { + resourceID, ok := sdkSG.Tags[resourceIDTagKey] + if !ok { + return nil, errors.Errorf("unexpected securityGroup with no resourceID: %v", sdkSG.SecurityGroupID) + } + sdkSGsByID[resourceID] = append(sdkSGsByID[resourceID], sdkSG) + } + return sdkSGsByID, nil +} diff --git a/pkg/deploy/stack_deployer.go b/pkg/deploy/stack_deployer.go index a4e79c106f..908692fb35 100644 --- a/pkg/deploy/stack_deployer.go +++ b/pkg/deploy/stack_deployer.go @@ -4,9 +4,11 @@ import ( "context" "github.com/go-logr/logr" "sigs.k8s.io/aws-alb-ingress-controller/pkg/aws/services" + "sigs.k8s.io/aws-alb-ingress-controller/pkg/deploy/ec2" "sigs.k8s.io/aws-alb-ingress-controller/pkg/deploy/elbv2" "sigs.k8s.io/aws-alb-ingress-controller/pkg/deploy/tagging" "sigs.k8s.io/aws-alb-ingress-controller/pkg/model/core" + "sigs.k8s.io/aws-alb-ingress-controller/pkg/networking" "sigs.k8s.io/controller-runtime/pkg/client" ) @@ -17,19 +19,26 @@ type StackDeployer interface { } // NewDefaultStackDeployer constructs new defaultStackDeployer. -func NewDefaultStackDeployer(k8sClient client.Client, elbv2Client services.ELBV2, vpcID string, clusterName string, tagPrefix string, logger logr.Logger) *defaultStackDeployer { +func NewDefaultStackDeployer(k8sClient client.Client, ec2Client services.EC2, elbv2Client services.ELBV2, + networkingSGManager networking.SecurityGroupManager, networkingSGReconciler networking.SecurityGroupReconciler, + vpcID string, clusterName string, tagPrefix string, logger logr.Logger) *defaultStackDeployer { taggingProvider := tagging.NewDefaultProvider(tagPrefix, clusterName) elbv2TaggingManager := elbv2.NewDefaultTaggingManager(elbv2Client, logger) + return &defaultStackDeployer{ k8sClient: k8sClient, + ec2Client: ec2Client, elbv2Client: elbv2Client, taggingProvider: taggingProvider, + networkingSGManager: networkingSGManager, + ec2SGManager: ec2.NewDefaultSecurityGroupManager(ec2Client, taggingProvider, networkingSGReconciler, vpcID, logger), elbv2TaggingManager: elbv2TaggingManager, elbv2LBManager: elbv2.NewDefaultLoadBalancerManager(elbv2Client, taggingProvider, elbv2TaggingManager, logger), elbv2LSManager: elbv2.NewDefaultListenerManager(elbv2Client, logger), elbv2LRManager: elbv2.NewDefaultListenerRuleManager(elbv2Client, logger), elbv2TGManager: elbv2.NewDefaultTargetGroupManager(elbv2Client, taggingProvider, elbv2TaggingManager, vpcID, logger), elbv2TGBManager: elbv2.NewDefaultTargetGroupBindingManager(k8sClient, taggingProvider, logger), + vpcID: vpcID, logger: logger, } } @@ -39,14 +48,18 @@ var _ StackDeployer = &defaultStackDeployer{} // defaultStackDeployer is the default implementation for StackDeployer type defaultStackDeployer struct { k8sClient client.Client + ec2Client services.EC2 elbv2Client services.ELBV2 taggingProvider tagging.Provider + networkingSGManager networking.SecurityGroupManager + ec2SGManager ec2.SecurityGroupManager elbv2TaggingManager elbv2.TaggingManager elbv2LBManager elbv2.LoadBalancerManager elbv2LSManager elbv2.ListenerManager elbv2LRManager elbv2.ListenerRuleManager elbv2TGManager elbv2.TargetGroupManager elbv2TGBManager elbv2.TargetGroupBindingManager + vpcID string logger logr.Logger } @@ -59,6 +72,7 @@ type ResourceSynthesizer interface { // Deploy a resource stack. func (d *defaultStackDeployer) Deploy(ctx context.Context, stack core.Stack) error { synthesizers := []ResourceSynthesizer{ + ec2.NewSecurityGroupSynthesizer(d.ec2Client, d.taggingProvider, d.networkingSGManager, d.ec2SGManager, d.vpcID, d.logger, stack), elbv2.NewTargetGroupSynthesizer(d.elbv2Client, d.taggingProvider, d.elbv2TaggingManager, d.elbv2TGManager, d.logger, stack), elbv2.NewTargetGroupBindingSynthesizer(d.k8sClient, d.taggingProvider, d.elbv2TGBManager, d.logger, stack), elbv2.NewLoadBalancerSynthesizer(d.elbv2Client, d.taggingProvider, d.elbv2TaggingManager, d.elbv2LBManager, d.logger, stack), diff --git a/pkg/equality/ec2/compare_option_for_ip_permission.go b/pkg/equality/ec2/compare_option_for_ip_permission.go new file mode 100644 index 0000000000..bae642fdee --- /dev/null +++ b/pkg/equality/ec2/compare_option_for_ip_permission.go @@ -0,0 +1,80 @@ +package ec2 + +import ( + "github.com/aws/aws-sdk-go/aws" + ec2sdk "github.com/aws/aws-sdk-go/service/ec2" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "sigs.k8s.io/aws-alb-ingress-controller/pkg/equality" +) + +// CompareOptionForIPRange returns the compare option for ec2 IPRange +func CompareOptionForIPRange() cmp.Option { + return equality.IgnoreOtherFields(ec2sdk.IpRange{}, "CidrIp") +} + +// CompareOptionForIPRanges returns the compare option for ec2 IPRange slice +func CompareOptionForIPRanges() cmp.Options { + return cmp.Options{ + cmpopts.SortSlices(func(lhs *ec2sdk.IpRange, rhs *ec2sdk.IpRange) bool { + return aws.StringValue(lhs.CidrIp) < aws.StringValue(rhs.CidrIp) + }), + CompareOptionForIPRange(), + } +} + +// CompareOptionForIPv6Range returns the compare option for ec2 IPv6Range +func CompareOptionForIPv6Range() cmp.Option { + return equality.IgnoreOtherFields(ec2sdk.Ipv6Range{}, "CidrIpv6") +} + +// CompareOptionForIPV6Ranges returns the compare option for ec2 IPv6Range slice +func CompareOptionForIPv6Ranges() cmp.Options { + return cmp.Options{ + cmpopts.SortSlices(func(lhs *ec2sdk.Ipv6Range, rhs *ec2sdk.Ipv6Range) bool { + return aws.StringValue(lhs.CidrIpv6) < aws.StringValue(rhs.CidrIpv6) + }), + CompareOptionForIPv6Range(), + } +} + +// CompareOptionForIPV6Ranges returns the compare option for ec2 UserIDGroupPair +func CompareOptionForUserIDGroupPair() cmp.Option { + return equality.IgnoreOtherFields(ec2sdk.UserIdGroupPair{}, "GroupId") +} + +// CompareOptionForIPV6Ranges returns the compare option for ec2 UserIDGroupPair slice +func CompareOptionForUserIDGroupPairs() cmp.Option { + return cmp.Options{ + cmpopts.SortSlices(func(lhs *ec2sdk.UserIdGroupPair, rhs *ec2sdk.UserIdGroupPair) bool { + return aws.StringValue(lhs.GroupId) < aws.StringValue(rhs.GroupId) + }), + CompareOptionForUserIDGroupPair(), + } +} + +// CompareOptionForIPV6Ranges returns the compare option for ec2 prefixListId +func CompareOptionForPrefixListId() cmp.Option { + return equality.IgnoreOtherFields(ec2sdk.PrefixListId{}, "PrefixListId") +} + +// CompareOptionForIPV6Ranges returns the compare option for ec2 prefixListId slice +func CompareOptionForPrefixListIds() cmp.Option { + return cmp.Options{ + cmpopts.SortSlices(func(lhs *ec2sdk.PrefixListId, rhs *ec2sdk.PrefixListId) bool { + return aws.StringValue(lhs.PrefixListId) < aws.StringValue(rhs.PrefixListId) + }), + CompareOptionForPrefixListId(), + } +} + +// CompareOptionForIPPermission returns the compare option for ec2 IPPermission object. +func CompareOptionForIPPermission() cmp.Option { + return cmp.Options{ + cmpopts.EquateEmpty(), + CompareOptionForIPRanges(), + CompareOptionForIPv6Ranges(), + CompareOptionForUserIDGroupPairs(), + CompareOptionForPrefixListIds(), + } +} diff --git a/pkg/equality/ignore_other_fields.go b/pkg/equality/ignore_other_fields.go new file mode 100644 index 0000000000..6196e1e39b --- /dev/null +++ b/pkg/equality/ignore_other_fields.go @@ -0,0 +1,23 @@ +package equality + +import ( + "github.com/google/go-cmp/cmp" + "k8s.io/apimachinery/pkg/util/sets" + "reflect" +) + +// IgnoreOtherFields is option that only compare specific structures fields. +func IgnoreOtherFields(typ interface{}, fields ...string) cmp.Option { + t := reflect.TypeOf(typ) + fieldsSet := sets.NewString(fields...) + return cmp.FilterPath(func(path cmp.Path) bool { + if len(path) < 2 || path.Index(-2).Type() != t { + return false + } + ps, ok := path.Last().(cmp.StructField) + if !ok || fieldsSet.Has(ps.Name()) { + return false + } + return true + }, cmp.Ignore()) +} diff --git a/pkg/equality/ignore_other_fields_test.go b/pkg/equality/ignore_other_fields_test.go new file mode 100644 index 0000000000..9adabaf4c6 --- /dev/null +++ b/pkg/equality/ignore_other_fields_test.go @@ -0,0 +1,125 @@ +package equality + +import ( + "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestIgnoreOtherFields(t *testing.T) { + type testStruct struct { + StrFieldA string + StrFieldB string + StrFieldC string + } + + type args struct { + typ interface{} + fields []string + } + tests := []struct { + name string + args args + argLeft testStruct + argRight testStruct + wantEquals bool + }{ + { + name: "only consider strFieldA - all fields equals", + args: args{ + typ: testStruct{}, + fields: []string{"StrFieldA"}, + }, + argLeft: testStruct{ + StrFieldA: "A0", + StrFieldB: "B0", + StrFieldC: "C0", + }, + argRight: testStruct{ + StrFieldA: "A0", + StrFieldB: "B0", + StrFieldC: "C0", + }, + wantEquals: true, + }, + { + name: "only consider strFieldA - equals if StrFieldA equals", + args: args{ + typ: testStruct{}, + fields: []string{"StrFieldA"}, + }, + argLeft: testStruct{ + StrFieldA: "A0", + StrFieldB: "B0", + StrFieldC: "C0", + }, + argRight: testStruct{ + StrFieldA: "A0", + StrFieldB: "B1", + StrFieldC: "C1", + }, + wantEquals: true, + }, + { + name: "only consider strFieldA - not equals if StrFieldA not equals", + args: args{ + typ: testStruct{}, + fields: []string{"StrFieldA"}, + }, + argLeft: testStruct{ + StrFieldA: "A0", + StrFieldB: "B0", + StrFieldC: "C0", + }, + argRight: testStruct{ + StrFieldA: "A1", + StrFieldB: "B0", + StrFieldC: "C0", + }, + wantEquals: false, + }, + { + name: "only consider strFieldA and strFieldC - equals if StrFieldA and strFieldC equals", + args: args{ + typ: testStruct{}, + fields: []string{"StrFieldA", "StrFieldC"}, + }, + argLeft: testStruct{ + StrFieldA: "A0", + StrFieldB: "B0", + StrFieldC: "C0", + }, + argRight: testStruct{ + StrFieldA: "A0", + StrFieldB: "B1", + StrFieldC: "C0", + }, + wantEquals: true, + }, + { + name: "only consider strFieldA and strFieldC - not equals if StrFieldA or strFieldC not equals", + args: args{ + typ: testStruct{}, + fields: []string{"StrFieldA", "StrFieldC"}, + }, + argLeft: testStruct{ + StrFieldA: "A0", + StrFieldB: "B0", + StrFieldC: "C0", + }, + argRight: testStruct{ + StrFieldA: "A0", + StrFieldB: "B0", + StrFieldC: "C1", + }, + wantEquals: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + opts := IgnoreOtherFields(tt.args.typ, tt.args.fields...) + gotEquals := cmp.Equal(tt.argLeft, tt.argRight, opts) + assert.Equal(t, tt.wantEquals, gotEquals) + }) + } +} diff --git a/pkg/model/ec2/security_group.go b/pkg/model/ec2/security_group.go index 3eaa7e3577..c746e6b5b4 100644 --- a/pkg/model/ec2/security_group.go +++ b/pkg/model/ec2/security_group.go @@ -54,8 +54,7 @@ type SecurityGroupSpec struct { GroupName string `json:"groupName"` // A description for the security group. - // +optional - Description *string `json:"description,omitempty"` + Description string `json:"description"` // +optional Tags map[string]string `json:"tags,omitempty"` diff --git a/pkg/networking/security_group_info.go b/pkg/networking/security_group_info.go index 82e6b16885..94964b7a7d 100644 --- a/pkg/networking/security_group_info.go +++ b/pkg/networking/security_group_info.go @@ -1,9 +1,12 @@ package networking import ( + "fmt" awssdk "github.com/aws/aws-sdk-go/aws" ec2sdk "github.com/aws/aws-sdk-go/service/ec2" + "k8s.io/apimachinery/pkg/util/sets" "regexp" + "strings" ) const ( @@ -13,11 +16,14 @@ const ( // SecurityGroupInfo wraps necessary information about a SecurityGroup. type SecurityGroupInfo struct { - // securityGroup's ID. + // SecurityGroup's ID. SecurityGroupID string // Ingress permission for securityGroup. - IngressPermissions []IPPermissionInfo + Ingress []IPPermissionInfo + + // Tags for securityGroup. + Tags map[string]string } type IPPermissionInfo struct { @@ -29,43 +35,124 @@ type IPPermissionInfo struct { Labels map[string]string } -func buildSecurityGroupInfo(sdkSG *ec2sdk.SecurityGroup) SecurityGroupInfo { +// NewRawSecurityGroupInfo constructs new SecurityGroupInfo with raw ec2SDK's SecurityGroup object. +func NewRawSecurityGroupInfo(sdkSG *ec2sdk.SecurityGroup) SecurityGroupInfo { sgID := awssdk.StringValue(sdkSG.GroupId) - var ingressPermissions []IPPermissionInfo + var ingress []IPPermissionInfo for _, sdkPermission := range sdkSG.IpPermissions { for _, expandedPermission := range expandSDKIPPermission(*sdkPermission) { - ingressPermissions = append(ingressPermissions, buildIPPermissionInfo(expandedPermission)) + ingress = append(ingress, NewRawIPPermission(expandedPermission)) } } + tags := buildSecurityGroupTags(sdkSG) return SecurityGroupInfo{ - SecurityGroupID: sgID, - IngressPermissions: ingressPermissions, + SecurityGroupID: sgID, + Ingress: ingress, + Tags: tags, + } +} + +// NewCIDRIPPermission constructs new IPPermissionInfo with CIDR configuration. +func NewCIDRIPPermission(ipProtocol string, fromPort *int64, toPort *int64, cidr string, labels map[string]string) IPPermissionInfo { + description := buildIPPermissionDescriptionForLabels(labels) + return IPPermissionInfo{ + Permission: ec2sdk.IpPermission{ + IpProtocol: awssdk.String(ipProtocol), + FromPort: fromPort, + ToPort: toPort, + IpRanges: []*ec2sdk.IpRange{ + { + CidrIp: awssdk.String(cidr), + Description: awssdk.String(description), + }, + }, + }, + Labels: labels, + } +} + +// NewCIDRv6IPPermission constructs new IPPermissionInfo with CIDRv6 configuration. +func NewCIDRv6IPPermission(ipProtocol string, fromPort *int64, toPort *int64, cidrV6 string, labels map[string]string) IPPermissionInfo { + description := buildIPPermissionDescriptionForLabels(labels) + return IPPermissionInfo{ + Permission: ec2sdk.IpPermission{ + IpProtocol: awssdk.String(ipProtocol), + FromPort: fromPort, + ToPort: toPort, + Ipv6Ranges: []*ec2sdk.Ipv6Range{ + { + CidrIpv6: awssdk.String(cidrV6), + Description: awssdk.String(description), + }, + }, + }, + Labels: labels, + } +} + +// NewCIDRv6IPPermission constructs new IPPermissionInfo with groupID configuration. +func NewGroupIDIPPermission(ipProtocol string, fromPort *int64, toPort *int64, groupID string, labels map[string]string) IPPermissionInfo { + description := buildIPPermissionDescriptionForLabels(labels) + return IPPermissionInfo{ + Permission: ec2sdk.IpPermission{ + IpProtocol: awssdk.String(ipProtocol), + FromPort: fromPort, + ToPort: toPort, + UserIdGroupPairs: []*ec2sdk.UserIdGroupPair{ + { + GroupId: awssdk.String(groupID), + Description: awssdk.String(description), + }, + }, + }, + Labels: labels, + } +} + +// NewPrefixListIDPermission constructs new IPPermissionInfo with prefixListID configuration +func NewPrefixListIDPermission(ipProtocol string, fromPort *int64, toPort *int64, prefixListID string, labels map[string]string) IPPermissionInfo { + description := buildIPPermissionDescriptionForLabels(labels) + return IPPermissionInfo{ + Permission: ec2sdk.IpPermission{ + IpProtocol: awssdk.String(ipProtocol), + FromPort: fromPort, + ToPort: toPort, + PrefixListIds: []*ec2sdk.PrefixListId{ + { + PrefixListId: awssdk.String(prefixListID), + Description: awssdk.String(description), + }, + }, + }, + Labels: labels, } } -func buildIPPermissionInfo(sdkPermission ec2sdk.IpPermission) IPPermissionInfo { +// NewRawIPPermission constructs new IPPermissionInfo with raw ec2SDK's IpPermission object. +// Note: this IpPermission should be expanded(i.e. only contains one source configuration) +func NewRawIPPermission(sdkPermission ec2sdk.IpPermission) IPPermissionInfo { if len(sdkPermission.IpRanges) == 1 { return IPPermissionInfo{ Permission: sdkPermission, - Labels: buildIPPermissionLabelForDescription(awssdk.StringValue(sdkPermission.IpRanges[0].Description)), + Labels: buildIPPermissionLabelsForDescription(awssdk.StringValue(sdkPermission.IpRanges[0].Description)), } } if len(sdkPermission.Ipv6Ranges) == 1 { return IPPermissionInfo{ Permission: sdkPermission, - Labels: buildIPPermissionLabelForDescription(awssdk.StringValue(sdkPermission.Ipv6Ranges[0].Description)), + Labels: buildIPPermissionLabelsForDescription(awssdk.StringValue(sdkPermission.Ipv6Ranges[0].Description)), } } if len(sdkPermission.PrefixListIds) == 1 { return IPPermissionInfo{ Permission: sdkPermission, - Labels: buildIPPermissionLabelForDescription(awssdk.StringValue(sdkPermission.PrefixListIds[0].Description)), + Labels: buildIPPermissionLabelsForDescription(awssdk.StringValue(sdkPermission.PrefixListIds[0].Description)), } } if len(sdkPermission.UserIdGroupPairs) == 1 { return IPPermissionInfo{ Permission: sdkPermission, - Labels: buildIPPermissionLabelForDescription(awssdk.StringValue(sdkPermission.UserIdGroupPairs[0].Description)), + Labels: buildIPPermissionLabelsForDescription(awssdk.StringValue(sdkPermission.UserIdGroupPairs[0].Description)), } } return IPPermissionInfo{ @@ -74,6 +161,20 @@ func buildIPPermissionInfo(sdkPermission ec2sdk.IpPermission) IPPermissionInfo { } } +// NewIPPermissionLabelsForRawDescription constructs permission labels from description only. +func NewIPPermissionLabelsForRawDescription(description string) map[string]string { + return map[string]string{labelKeyRawDescription: description} +} + +// buildSecurityGroupTags generates the tags for securityGroup. +func buildSecurityGroupTags(sdkSG *ec2sdk.SecurityGroup) map[string]string { + sgTags := make(map[string]string, len(sdkSG.Tags)) + for _, tag := range sdkSG.Tags { + sgTags[awssdk.StringValue(tag.Key)] = awssdk.StringValue(tag.Value) + } + return sgTags +} + // expandSDKIPPermission will expand the IPPermission so that each permission only contain single entry. // EC2 api automatically group IPPermissions, so we need to expand first before further processing. func expandSDKIPPermission(sdkPermission ec2sdk.IpPermission) []ec2sdk.IpPermission { @@ -95,6 +196,7 @@ func expandSDKIPPermission(sdkPermission ec2sdk.IpPermission) []ec2sdk.IpPermiss perm.Ipv6Ranges = []*ec2sdk.Ipv6Range{ipRange} expandedPermissions = append(expandedPermissions, perm) } + for _, prefixListID := range sdkPermission.PrefixListIds { perm := base perm.PrefixListIds = []*ec2sdk.PrefixListId{prefixListID} @@ -115,11 +217,26 @@ func expandSDKIPPermission(sdkPermission ec2sdk.IpPermission) []ec2sdk.IpPermiss var commaSeparatedKVPairPattern = regexp.MustCompile(`(?P[^\s,=]+)=(?P[^\s,=]+)(?:,|$)`) -// buildIPPermissionLabelForDescription constructs a set of labels parsed from IPPermission description -func buildIPPermissionLabelForDescription(description string) map[string]string { +// buildIPPermissionLabelsForDescription computes labels parsed from IPPermission description +func buildIPPermissionLabelsForDescription(description string) map[string]string { labels := map[string]string{labelKeyRawDescription: description} for _, groups := range commaSeparatedKVPairPattern.FindAllStringSubmatch(description, -1) { labels[groups[1]] = groups[2] } return labels } + +// buildIPPermissionDescriptionForLabels compute a description from labels. +func buildIPPermissionDescriptionForLabels(labels map[string]string) string { + if rawDescription, exists := labels[labelKeyRawDescription]; exists { + return rawDescription + } + + kvPairs := make([]string, 0, len(labels)) + sortedLabelKeys := sets.StringKeySet(labels).List() + for _, key := range sortedLabelKeys { + value := labels[key] + kvPairs = append(kvPairs, fmt.Sprintf("%v=%v", key, value)) + } + return strings.Join(kvPairs, ",") +} diff --git a/pkg/networking/security_group_info_test.go b/pkg/networking/security_group_info_test.go deleted file mode 100644 index 2b7e1ae62a..0000000000 --- a/pkg/networking/security_group_info_test.go +++ /dev/null @@ -1,63 +0,0 @@ -package networking - -import ( - "github.com/stretchr/testify/assert" - "testing" -) - -func Test_buildIPPermissionLabelForDescription(t *testing.T) { - type args struct { - description string - } - tests := []struct { - name string - args args - want map[string]string - }{ - { - name: "empty description", - args: args{ - description: "", - }, - want: map[string]string{ - labelKeyRawDescription: "", - }, - }, - { - name: "non-empty description", - args: args{ - description: "some-description", - }, - want: map[string]string{ - labelKeyRawDescription: "some-description", - }, - }, - { - name: "single k-v pair description", - args: args{ - description: "key1=value1", - }, - want: map[string]string{ - labelKeyRawDescription: "key1=value1", - "key1": "value1", - }, - }, - { - name: "multiple k-v pair description", - args: args{ - description: "key1=value1,key2=value2", - }, - want: map[string]string{ - labelKeyRawDescription: "key1=value1,key2=value2", - "key1": "value1", - "key2": "value2", - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := buildIPPermissionLabelForDescription(tt.args.description) - assert.Equal(t, tt.want, got) - }) - } -} diff --git a/pkg/networking/security_group_manager.go b/pkg/networking/security_group_manager.go new file mode 100644 index 0000000000..6edfc48d6d --- /dev/null +++ b/pkg/networking/security_group_manager.go @@ -0,0 +1,188 @@ +package networking + +import ( + "context" + awssdk "github.com/aws/aws-sdk-go/aws" + ec2sdk "github.com/aws/aws-sdk-go/service/ec2" + "github.com/go-logr/logr" + "github.com/pkg/errors" + "k8s.io/apimachinery/pkg/util/cache" + "k8s.io/apimachinery/pkg/util/sets" + "sigs.k8s.io/aws-alb-ingress-controller/pkg/aws/services" + "sync" + "time" +) + +const ( + // we cache securityGroup's information by 10 minutes. + defaultSGInfoCacheTTL = 10 * time.Minute +) + +// SecurityGroupManager is an abstraction around EC2's SecurityGroup API. +type SecurityGroupManager interface { + // FetchSGInfosByID will fetch SecurityGroupInfo with SecurityGroup IDs. + FetchSGInfosByID(ctx context.Context, sgIDs ...string) (map[string]SecurityGroupInfo, error) + + // FetchSGInfosByRequest will fetch SecurityGroupInfo with raw DescribeSecurityGroupsInput request. + FetchSGInfosByRequest(ctx context.Context, req *ec2sdk.DescribeSecurityGroupsInput) (map[string]SecurityGroupInfo, error) + + // AuthorizeSGIngress will authorize Ingress permissions to SecurityGroup. + AuthorizeSGIngress(ctx context.Context, sgID string, permissions []IPPermissionInfo) error + + // RevokeSGIngress will revoke Ingress permissions from SecurityGroup. + RevokeSGIngress(ctx context.Context, sgID string, permissions []IPPermissionInfo) error +} + +// NewDefaultSecurityGroupManager constructs new defaultSecurityGroupManager. +func NewDefaultSecurityGroupManager(ec2Client services.EC2, logger logr.Logger) *defaultSecurityGroupManager { + return &defaultSecurityGroupManager{ + ec2Client: ec2Client, + logger: logger, + + sgInfoCache: cache.NewExpiring(), + sgInfoCacheMutex: sync.RWMutex{}, + sgInfoCacheTTL: defaultSGInfoCacheTTL, + } +} + +var _ SecurityGroupManager = &defaultSecurityGroupManager{} + +// default implementation for SecurityGroupManager +type defaultSecurityGroupManager struct { + ec2Client services.EC2 + logger logr.Logger + + sgInfoCache *cache.Expiring + sgInfoCacheMutex sync.RWMutex + sgInfoCacheTTL time.Duration +} + +func (m *defaultSecurityGroupManager) FetchSGInfosByID(ctx context.Context, sgIDs ...string) (map[string]SecurityGroupInfo, error) { + sgInfoByID := m.fetchSGInfosFromCache(sgIDs) + sgIDsSet := sets.NewString(sgIDs...) + fetchedSGIDsSet := sets.StringKeySet(sgInfoByID) + unFetchedSGIDs := sgIDsSet.Difference(fetchedSGIDsSet).List() + if len(unFetchedSGIDs) > 0 { + req := &ec2sdk.DescribeSecurityGroupsInput{ + GroupIds: awssdk.StringSlice(unFetchedSGIDs), + } + sgInfoByIDFromAWS, err := m.fetchSGInfosFromAWS(ctx, req) + if err != nil { + return nil, err + } + m.saveSGInfosToCache(sgInfoByIDFromAWS) + for sgID, sgInfo := range sgInfoByIDFromAWS { + sgInfoByID[sgID] = sgInfo + } + } + + fetchedSGIDsSet = sets.StringKeySet(sgInfoByID) + if !sgIDsSet.Equal(fetchedSGIDsSet) { + return nil, errors.Errorf("couldn't fetch SecurityGroupInfos: %v", sgIDsSet.Difference(fetchedSGIDsSet).List()) + } + return sgInfoByID, nil +} + +func (m *defaultSecurityGroupManager) FetchSGInfosByRequest(ctx context.Context, req *ec2sdk.DescribeSecurityGroupsInput) (map[string]SecurityGroupInfo, error) { + sgInfosByID, err := m.fetchSGInfosFromAWS(ctx, req) + if err != nil { + return nil, err + } + m.saveSGInfosToCache(sgInfosByID) + return sgInfosByID, nil +} + +func (m *defaultSecurityGroupManager) AuthorizeSGIngress(ctx context.Context, sgID string, permissions []IPPermissionInfo) error { + sdkIPPermissions := buildSDKIPPermissions(permissions) + req := &ec2sdk.AuthorizeSecurityGroupIngressInput{ + GroupId: awssdk.String(sgID), + IpPermissions: sdkIPPermissions, + } + m.logger.Info("authorizing securityGroup ingress", + "securityGroupID", sgID, + "permission", sdkIPPermissions) + if _, err := m.ec2Client.AuthorizeSecurityGroupIngressWithContext(ctx, req); err != nil { + return err + } + m.logger.Info("authorized securityGroup ingress", + "securityGroupID", sgID) + + // TODO: ideally we can remember the permissions we granted to save DescribeSecurityGroup API calls. + m.clearSGInfosFromCache(sgID) + return nil +} + +func (m *defaultSecurityGroupManager) RevokeSGIngress(ctx context.Context, sgID string, permissions []IPPermissionInfo) error { + sdkIPPermissions := buildSDKIPPermissions(permissions) + req := &ec2sdk.RevokeSecurityGroupIngressInput{ + GroupId: awssdk.String(sgID), + IpPermissions: sdkIPPermissions, + } + m.logger.Info("revoking securityGroup ingress", + "securityGroupID", sgID, + "permission", sdkIPPermissions) + if _, err := m.ec2Client.RevokeSecurityGroupIngressWithContext(ctx, req); err != nil { + return err + } + m.logger.Info("revoked securityGroup ingress", + "securityGroupID", sgID) + + // TODO: ideally we can remember the permissions we revoked to save DescribeSecurityGroup API calls. + m.clearSGInfosFromCache(sgID) + return nil +} + +func (m *defaultSecurityGroupManager) fetchSGInfosFromCache(sgIDs []string) map[string]SecurityGroupInfo { + m.sgInfoCacheMutex.RLock() + defer m.sgInfoCacheMutex.RUnlock() + + sgInfoByID := make(map[string]SecurityGroupInfo, len(sgIDs)) + for _, sgID := range sgIDs { + if rawCacheItem, exists := m.sgInfoCache.Get(sgID); exists { + sgInfo := rawCacheItem.(SecurityGroupInfo) + sgInfoByID[sgID] = sgInfo + } + } + + return sgInfoByID +} + +func (m *defaultSecurityGroupManager) saveSGInfosToCache(sgInfoByID map[string]SecurityGroupInfo) { + m.sgInfoCacheMutex.Lock() + defer m.sgInfoCacheMutex.Unlock() + + for sgID, sgInfo := range sgInfoByID { + m.sgInfoCache.Set(sgID, sgInfo, m.sgInfoCacheTTL) + } +} + +func (m *defaultSecurityGroupManager) clearSGInfosFromCache(sgID string) { + m.sgInfoCache.Delete(sgID) +} + +func (m *defaultSecurityGroupManager) fetchSGInfosFromAWS(ctx context.Context, req *ec2sdk.DescribeSecurityGroupsInput) (map[string]SecurityGroupInfo, error) { + sgs, err := m.ec2Client.DescribeSecurityGroupsAsList(ctx, req) + if err != nil { + return nil, err + } + sgInfoByID := make(map[string]SecurityGroupInfo, len(sgs)) + for _, sg := range sgs { + sgID := awssdk.StringValue(sg.GroupId) + sgInfo := NewRawSecurityGroupInfo(sg) + sgInfoByID[sgID] = sgInfo + } + return sgInfoByID, nil +} + +// buildSDKIPPermissions converts slice of IPPermissionInfo into slice of pointers to IPPermission +// if targets is empty or nil, nil will be returned. +func buildSDKIPPermissions(permissions []IPPermissionInfo) []*ec2sdk.IpPermission { + if len(permissions) == 0 { + return nil + } + sdkPermissions := make([]*ec2sdk.IpPermission, 0, len(permissions)) + for i := range permissions { + sdkPermissions = append(sdkPermissions, &permissions[i].Permission) + } + return sdkPermissions +} diff --git a/pkg/networking/security_group_reconciler.go b/pkg/networking/security_group_reconciler.go index 217721ca2e..24481e1bc1 100644 --- a/pkg/networking/security_group_reconciler.go +++ b/pkg/networking/security_group_reconciler.go @@ -2,42 +2,59 @@ package networking import ( "context" - awssdk "github.com/aws/aws-sdk-go/aws" - ec2sdk "github.com/aws/aws-sdk-go/service/ec2" "github.com/go-logr/logr" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "k8s.io/apimachinery/pkg/labels" - "k8s.io/apimachinery/pkg/util/cache" - "sigs.k8s.io/aws-alb-ingress-controller/pkg/aws/services" - ec2model "sigs.k8s.io/aws-alb-ingress-controller/pkg/model/ec2" - "sync" - "time" -) - -const ( - defaultSGInfoCacheTTL = 10 * time.Minute + ec2equality "sigs.k8s.io/aws-alb-ingress-controller/pkg/equality/ec2" ) +// configuration options for SecurityGroup Reconcile options. type SecurityGroupReconcileOptions struct { // PermissionSelector defines the selector to identify permissions that should be managed. + // Permissions that are not managed shouldn't be altered or deleted. // By default, it selects every permission. PermissionSelector labels.Selector + + // Whether only Authorize permissions. + // By default, it grants and revoke permission. + AuthorizeOnly bool +} + +// Apply SecurityGroupReconcileOption options +func (opts *SecurityGroupReconcileOptions) ApplyOptions(options ...SecurityGroupReconcileOption) { + for _, option := range options { + option(opts) + } } type SecurityGroupReconcileOption func(opts *SecurityGroupReconcileOptions) +// WithPermissionSelector is a option that sets the PermissionSelector. +func WithPermissionSelector(permissionSelector labels.Selector) SecurityGroupReconcileOption { + return func(opts *SecurityGroupReconcileOptions) { + opts.PermissionSelector = permissionSelector + } +} + +// WithAuthorizeOnly is a option that sets the AuthorizeOnly. +func WithAuthorizeOnly(authorizeOnly bool) SecurityGroupReconcileOption { + return func(opts *SecurityGroupReconcileOptions) { + opts.AuthorizeOnly = authorizeOnly + } +} + // SecurityGroupReconciler manages securityGroup rules on securityGroup. type SecurityGroupReconciler interface { - ReconcileIngress(ctx context.Context, sgID string, permissions []ec2model.IPPermission, opts ...SecurityGroupReconcileOption) error + // ReconcileIngress will reconcile Ingress permission on SecurityGroup to be desiredPermission. + ReconcileIngress(ctx context.Context, sgID string, desiredPermissions []IPPermissionInfo, opts ...SecurityGroupReconcileOption) error } // NewDefaultSecurityGroupReconciler constructs new defaultSecurityGroupReconciler. -func NewDefaultSecurityGroupReconciler(ec2Client services.EC2, logger logr.Logger) *defaultSecurityGroupReconciler { +func NewDefaultSecurityGroupReconciler(sgManager SecurityGroupManager, logger logr.Logger) *defaultSecurityGroupReconciler { return &defaultSecurityGroupReconciler{ - ec2Client: ec2Client, - logger: logger, - sgInfoCache: cache.NewExpiring(), - sgInfoCacheMutex: sync.RWMutex{}, - sgInfoCacheTTL: defaultSGInfoCacheTTL, + sgManager: sgManager, + logger: logger, } } @@ -45,31 +62,62 @@ var _ SecurityGroupReconciler = &defaultSecurityGroupReconciler{} // default implementation for SecurityGroupReconciler. type defaultSecurityGroupReconciler struct { - ec2Client services.EC2 + sgManager SecurityGroupManager logger logr.Logger - - sgInfoCache *cache.Expiring - sgInfoCacheMutex sync.RWMutex - sgInfoCacheTTL time.Duration } -func (r *defaultSecurityGroupReconciler) ReconcileIngress(ctx context.Context, sgID string, permissions []ec2model.IPPermission, opts ...SecurityGroupReconcileOption) error { +func (r *defaultSecurityGroupReconciler) ReconcileIngress(ctx context.Context, sgID string, desiredPermissions []IPPermissionInfo, opts ...SecurityGroupReconcileOption) error { + reconcileOpts := SecurityGroupReconcileOptions{ + PermissionSelector: labels.Everything(), + } + reconcileOpts.ApplyOptions(opts...) + + sgInfoByID, err := r.sgManager.FetchSGInfosByID(ctx, sgID) + if err != nil { + return err + } + sgInfo := sgInfoByID[sgID] + + extraPermissions := diffIPPermissionInfos(sgInfo.Ingress, desiredPermissions) + permissionsToRevoke := make([]IPPermissionInfo, 0, len(extraPermissions)) + for _, permission := range extraPermissions { + if reconcileOpts.PermissionSelector.Matches(labels.Set(permission.Labels)) { + permissionsToRevoke = append(permissionsToRevoke, permission) + } + } + permissionsToGrant := diffIPPermissionInfos(desiredPermissions, sgInfo.Ingress) + if len(permissionsToRevoke) > 0 && !reconcileOpts.AuthorizeOnly { + if err := r.sgManager.RevokeSGIngress(ctx, sgID, permissionsToRevoke); err != nil { + return err + } + } + if len(permissionsToGrant) > 0 { + if err := r.sgManager.AuthorizeSGIngress(ctx, sgID, permissionsToGrant); err != nil { + return err + } + } + return nil } -func (r *defaultSecurityGroupReconciler) fetchSGInfosFromAWS(ctx context.Context, sgIDs []string) (map[string]SecurityGroupInfo, error) { - req := &ec2sdk.DescribeSecurityGroupsInput{ - GroupIds: awssdk.StringSlice(sgIDs), - } - sgs, err := r.ec2Client.DescribeSecurityGroupsAsList(ctx, req) - if err != nil { - return nil, err +// diffIPPermissionInfos calculates set_difference as source - target +func diffIPPermissionInfos(source []IPPermissionInfo, target []IPPermissionInfo) []IPPermissionInfo { + opts := cmp.Options{ + ec2equality.CompareOptionForIPPermission(), + cmpopts.IgnoreFields(IPPermissionInfo{}, "Labels"), } - sgInfoByID := make(map[string]SecurityGroupInfo, len(sgs)) - for _, sg := range sgs { - sgID := awssdk.StringValue(sg.GroupId) - sgInfo := buildSecurityGroupInfo(sg) - sgInfoByID[sgID] = sgInfo + var diffs []IPPermissionInfo + for _, sPermission := range source { + containsInTarget := false + for _, tPermission := range target { + if cmp.Equal(sPermission, tPermission, opts) { + containsInTarget = true + break + } + } + if !containsInTarget { + diffs = append(diffs, sPermission) + } } - return sgInfoByID, nil + return diffs } diff --git a/pkg/targetgroupbinding/networking_manager.go b/pkg/targetgroupbinding/networking_manager.go index e07fdc83df..f126674e9b 100644 --- a/pkg/targetgroupbinding/networking_manager.go +++ b/pkg/targetgroupbinding/networking_manager.go @@ -2,36 +2,348 @@ package targetgroupbinding import ( "context" + "fmt" + awssdk "github.com/aws/aws-sdk-go/aws" + ec2sdk "github.com/aws/aws-sdk-go/service/ec2" + "github.com/go-logr/logr" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/pkg/errors" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/labels" + "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/sets" + "net" elbv2api "sigs.k8s.io/aws-alb-ingress-controller/apis/elbv2/v1alpha1" "sigs.k8s.io/aws-alb-ingress-controller/pkg/backend" + ec2equality "sigs.k8s.io/aws-alb-ingress-controller/pkg/equality/ec2" + "sigs.k8s.io/aws-alb-ingress-controller/pkg/k8s" + "sigs.k8s.io/aws-alb-ingress-controller/pkg/networking" + "sigs.k8s.io/controller-runtime/pkg/client" + "strings" + "sync" +) + +const ( + tgbNetworkingIPPermissionLabelKey = "elbv2.k8s.aws/targetGroupBinding" + tgbNetworkingIPPermissionLabelValue = "shared" ) // NetworkingManager manages the networking for targetGroupBindings. type NetworkingManager interface { + // ReconcileForPodEndpoints reconcile network settings for TargetGroupBindings with podEndpoints. ReconcileForPodEndpoints(ctx context.Context, tgb *elbv2api.TargetGroupBinding, endpoints []backend.PodEndpoint) error + + // ReconcileForNodePortEndpoints reconcile network settings for TargetGroupBindings with nodePortEndpoints. ReconcileForNodePortEndpoints(ctx context.Context, tgb *elbv2api.TargetGroupBinding, endpoints []backend.NodePortEndpoint) error } +// NewDefaultNetworkingManager constructs defaultNetworkingManager. +func NewDefaultNetworkingManager(k8sClient client.Client, podENIResolver networking.PodENIInfoResolver, nodeENIResolver networking.NodeENIInfoResolver, + sgManager networking.SecurityGroupManager, sgReconciler networking.SecurityGroupReconciler, vpcID string, clusterName string, logger logr.Logger) *defaultNetworkingManager { + return &defaultNetworkingManager{ + k8sClient: k8sClient, + podENIResolver: podENIResolver, + nodeENIResolver: nodeENIResolver, + sgManager: sgManager, + sgReconciler: sgReconciler, + vpcID: vpcID, + clusterName: clusterName, + logger: logger, + + endpointSGsByTGB: make(map[types.NamespacedName][]string), + endpointSGsByTGBMutex: sync.Mutex{}, + trackedEndpointSGs: sets.NewString(), + } +} + // default implementation for NetworkingManager. type defaultNetworkingManager struct { + k8sClient client.Client + podENIResolver networking.PodENIInfoResolver + nodeENIResolver networking.NodeENIInfoResolver + sgManager networking.SecurityGroupManager + sgReconciler networking.SecurityGroupReconciler + vpcID string + clusterName string + logger logr.Logger + + // endpointSGsByTGB are the SecurityGroups for each TargetGroupBinding's endpoints. + endpointSGsByTGB map[types.NamespacedName][]string + // endpointSGsByTGBMutex protects endpointSGsByTGB. + endpointSGsByTGBMutex sync.Mutex + + // trackedEndpointSGs are the securityGroups that we have been managing it's rules. + // we'll garbage collect rules from these trackedEndpointSGs if it's not needed. + trackedEndpointSGs sets.String + // trackedEndpointSGsMutex protects managedEndpointSGs + trackedEndpointSGsMutex sync.Mutex } func (m *defaultNetworkingManager) ReconcileForPodEndpoints(ctx context.Context, tgb *elbv2api.TargetGroupBinding, endpoints []backend.PodEndpoint) error { - return nil + if tgb.Spec.Networking == nil { + return nil + } + + endpointSGs, err := m.resolveEndpointSGsForPodEndpoints(ctx, endpoints) + if err != nil { + return err + } + return m.reconcileForEndpointSGs(ctx, tgb, endpointSGs) } -func (m *defaultResourceManager) ReconcileForNodePortEndpoints(ctx context.Context, tgb *elbv2api.TargetGroupBinding, endpoints []backend.NodePortEndpoint) error { - return nil +func (m *defaultNetworkingManager) ReconcileForNodePortEndpoints(ctx context.Context, tgb *elbv2api.TargetGroupBinding, endpoints []backend.NodePortEndpoint) error { + if tgb.Spec.Networking == nil { + return nil + } + + endpointSGs, err := m.resolveEndpointSGsForNodePortEndpoints(ctx, endpoints) + if err != nil { + return err + } + return m.reconcileForEndpointSGs(ctx, tgb, endpointSGs) } -func (m *defaultResourceManager) reconcileForInstanceSGs(ctx context.Context, tgb *elbv2api.TargetGroupBinding, instanceSGs []string) error { +func (m *defaultNetworkingManager) reconcileForEndpointSGs(ctx context.Context, tgb *elbv2api.TargetGroupBinding, endpointSGs []string) error { + m.endpointSGsByTGBMutex.Lock() + defer m.endpointSGsByTGBMutex.Unlock() + + tgbKey := k8s.NamespacedName(tgb) + m.endpointSGsByTGB[tgbKey] = endpointSGs + tgbsWithNetworking, err := m.fetchTGBsWithNetworking(ctx) + if err != nil { + return err + } + ingressIPPermissionsBySG, computedEndpointSGsForAllTGBs, err := m.computeDesiredIngressIPPermissionsBySG(tgbsWithNetworking) + if err != nil { + return err + } + + permissionSelector := labels.SelectorFromSet(labels.Set{tgbNetworkingIPPermissionLabelKey: tgbNetworkingIPPermissionLabelValue}) + for sgID, ipPermissions := range ingressIPPermissionsBySG { + if err := m.sgReconciler.ReconcileIngress(ctx, sgID, ipPermissions, + networking.WithPermissionSelector(permissionSelector), + networking.WithAuthorizeOnly(!computedEndpointSGsForAllTGBs)) + err != nil { + return err + } + } return nil } -func (m *defaultNetworkingManager) computeInstanceSGsForPodEndpoints(ctx context.Context, endpoints []backend.PodEndpoint) []string { - return nil +// fetchTGBsWithNetworking returns all targetGroupsBindings with networking rules in cluster. +func (m *defaultNetworkingManager) fetchTGBsWithNetworking(ctx context.Context) (map[types.NamespacedName]*elbv2api.TargetGroupBinding, error) { + tgbList := &elbv2api.TargetGroupBindingList{} + if err := m.k8sClient.List(ctx, tgbList); err != nil { + return nil, err + } + tgbWithNetworkingByKey := make(map[types.NamespacedName]*elbv2api.TargetGroupBinding, len(tgbList.Items)) + for i := range tgbList.Items { + tgb := &tgbList.Items[i] + if tgb.Spec.Networking != nil { + tgbWithNetworkingByKey[k8s.NamespacedName(tgb)] = tgb + } + } + return tgbWithNetworkingByKey, nil } -func (m *defaultResourceManager) computeSInstanceSGsForNodePortEndpoints(ctx context.Context, endpoints []backend.NodePortEndpoint) []string { - return nil +// computeDesiredIngressIPPermissionsBySG will compute the desired Ingress IPPermissions per SecurityGroup. +// It will also GC unnecessary entries in endpointSGsByTGB, and return whether all tgb have endpointSGs specified. +func (m *defaultNetworkingManager) computeDesiredIngressIPPermissionsBySG(tgbWithNetworkingByKey map[types.NamespacedName]*elbv2api.TargetGroupBinding) (map[string][]networking.IPPermissionInfo, bool, error) { + tgbNetworkingsBySG := make(map[string][]elbv2api.TargetGroupBindingNetworking) + for tgbKey, endpointSGs := range m.endpointSGsByTGB { + tgb, exists := tgbWithNetworkingByKey[tgbKey] + if !exists { + delete(m.endpointSGsByTGB, tgbKey) + continue + } + for _, endpointSG := range endpointSGs { + tgbNetworkingsBySG[endpointSG] = append(tgbNetworkingsBySG[endpointSG], *tgb.Spec.Networking) + } + } + computedEndpointSGsForAllTGBs := len(tgbWithNetworkingByKey) == len(m.endpointSGsByTGB) + ipPermissionsBySG := make(map[string][]networking.IPPermissionInfo) + for sgID, tgbNetworkings := range tgbNetworkingsBySG { + ipPermissions, err := m.computeDesiredIngressIPPermissions(tgbNetworkings) + if err != nil { + return nil, false, err + } + ipPermissionsBySG[sgID] = ipPermissions + } + return ipPermissionsBySG, computedEndpointSGsForAllTGBs, nil +} + +func (m *defaultNetworkingManager) computeDesiredIngressIPPermissions(tgbNetworkings []elbv2api.TargetGroupBindingNetworking) ([]networking.IPPermissionInfo, error) { + var ipPermissions []networking.IPPermissionInfo + opts := cmp.Options{ + ec2equality.CompareOptionForIPPermission(), + cmpopts.IgnoreFields(networking.IPPermissionInfo{}, "Labels"), + } + + for _, tgbNetworking := range tgbNetworkings { + for _, rule := range tgbNetworking.Ingress { + for _, port := range rule.Ports { + for _, peer := range rule.From { + ipPermission, err := m.computeDesiredIngressIPPermission(port, peer) + if err != nil { + return nil, err + } + containsPermission := false + for _, permission := range ipPermissions { + if cmp.Equal(ipPermission, permission, opts) { + containsPermission = true + break + } + } + if !containsPermission { + ipPermissions = append(ipPermissions, ipPermission) + } + } + } + } + } + return ipPermissions, nil +} + +func (m *defaultNetworkingManager) computeDesiredIngressIPPermission(port elbv2api.NetworkingPort, peer elbv2api.NetworkingPeer) (networking.IPPermissionInfo, error) { + permissionLabels := map[string]string{tgbNetworkingIPPermissionLabelKey: tgbNetworkingIPPermissionLabelValue} + protocol := "-1" + if port.Protocol != nil { + switch *port.Protocol { + case elbv2api.NetworkingProtocolTCP: + protocol = "tcp" + case elbv2api.NetworkingProtocolUDP: + protocol = "udp" + } + } + var fromPort *int64 + var toPort *int64 + if port.Port != nil { + fromPort = port.Port + toPort = port.Port + } + if peer.SecurityGroup != nil { + groupID := peer.SecurityGroup.GroupID + return networking.NewGroupIDIPPermission(protocol, fromPort, toPort, groupID, permissionLabels), nil + } + if peer.IPBlock != nil { + cidr := peer.IPBlock.CIDR + _, _, err := net.ParseCIDR(cidr) + if err != nil { + return networking.IPPermissionInfo{}, err + } + if strings.Contains(cidr, ":") { + return networking.NewCIDRv6IPPermission(protocol, fromPort, toPort, cidr, permissionLabels), nil + } + return networking.NewCIDRIPPermission(protocol, fromPort, toPort, cidr, permissionLabels), nil + } + return networking.IPPermissionInfo{}, errors.New("either SecurityGroup or IPBlock should be specified") +} + +func (m *defaultNetworkingManager) resolveEndpointSGsForPodEndpoints(ctx context.Context, endpoints []backend.PodEndpoint) ([]string, error) { + pods := make([]*corev1.Pod, 0, len(endpoints)) + for _, endpoint := range endpoints { + pods = append(pods, endpoint.Pod) + } + eniInfoByPodKey, err := m.podENIResolver.Resolve(ctx, pods) + if err != nil { + return nil, err + } + sgIDs := sets.NewString() + for _, eniInfo := range eniInfoByPodKey { + sgID, err := m.resolveEndpointSGForENI(ctx, eniInfo) + if err != nil { + return nil, err + } + sgIDs.Insert(sgID) + } + return sgIDs.List(), nil +} + +func (m *defaultNetworkingManager) resolveEndpointSGsForNodePortEndpoints(ctx context.Context, endpoints []backend.NodePortEndpoint) ([]string, error) { + nodes := make([]*corev1.Node, 0, len(endpoints)) + for _, endpoint := range endpoints { + nodes = append(nodes, endpoint.Node) + } + eniInfoByNodeKey, err := m.nodeENIResolver.Resolve(ctx, nodes) + if err != nil { + return nil, err + } + sgIDs := sets.NewString() + for _, eniInfo := range eniInfoByNodeKey { + sgID, err := m.resolveEndpointSGForENI(ctx, eniInfo) + if err != nil { + return nil, err + } + sgIDs.Insert(sgID) + } + return sgIDs.List(), nil +} + +// resolveEndpointSGForENI will resolve the endpoint SecurityGroup for specific ENI. +// If there are only a single securityGroup attached, that one will be the endpoint SecurityGroup. +// If there are multiple securityGroup attached, we expect one and only one securityGroup is tagged with the cluster tag. +func (m *defaultNetworkingManager) resolveEndpointSGForENI(ctx context.Context, eniInfo networking.ENIInfo) (string, error) { + sgIDs := eniInfo.SecurityGroups + if len(sgIDs) == 1 { + return sgIDs[0], nil + } + + sgInfoByID, err := m.sgManager.FetchSGInfosByID(ctx, sgIDs...) + if err != nil { + return "", err + } + clusterResourceTagKey := fmt.Sprintf("kubernetes.io/cluster/%s", m.clusterName) + sgIDsWithClusterTag := sets.NewString() + for sgID, sgInfo := range sgInfoByID { + if _, ok := sgInfo.Tags[clusterResourceTagKey]; ok { + sgIDsWithClusterTag.Insert(sgID) + } + } + if len(sgIDsWithClusterTag) != 1 { + return "", errors.Errorf("expect exactly one securityGroup tagged with %v for eni %v, got: %v", + clusterResourceTagKey, eniInfo.NetworkInterfaceID, sgIDsWithClusterTag.List()) + } + sgID, _ := sgIDsWithClusterTag.PopAny() + return sgID, nil +} + +// trackEndpointSGs will track these endpoint SecurityGroups. +func (m *defaultNetworkingManager) trackEndpointSGs(_ context.Context, sgIDs ...string) { + m.trackedEndpointSGsMutex.Lock() + defer m.trackedEndpointSGsMutex.Unlock() + + m.trackedEndpointSGs.Insert(sgIDs...) +} + +// unTrackEndpointSGs will unTrack these endpoint SecurityGroups. +func (m *defaultNetworkingManager) unTrackEndpointSGs(_ context.Context, sgIDs ...string) { + m.trackedEndpointSGsMutex.Lock() + defer m.trackedEndpointSGsMutex.Unlock() + + m.trackedEndpointSGs.Delete(sgIDs...) +} + +// fetchEndpointSGsFromAWS will return all endpoint securityGroups from AWS API. +// we consider a securityGroup as a endpoint securityGroup if it have the cluster tag. +// note: not all endpoint securityGroup have the cluster Tag(e.g. if a ENI only have a single securityGroup, it will still be used as endpoint securityGroup) +func (m *defaultNetworkingManager) fetchEndpointSGsFromAWS(ctx context.Context) ([]string, error) { + clusterResourceTagKey := fmt.Sprintf("kubernetes.io/cluster/%s", m.clusterName) + req := &ec2sdk.DescribeSecurityGroupsInput{ + Filters: []*ec2sdk.Filter{ + { + Name: awssdk.String("tag:" + clusterResourceTagKey), + Values: awssdk.StringSlice([]string{"owned", "shared"}), + }, + { + Name: awssdk.String("vpc-id"), + Values: awssdk.StringSlice([]string{m.vpcID}), + }, + }, + } + sgInfoByID, err := m.sgManager.FetchSGInfosByRequest(ctx, req) + if err != nil { + return nil, err + } + return sets.StringKeySet(sgInfoByID).List(), nil } diff --git a/pkg/targetgroupbinding/resource_manager.go b/pkg/targetgroupbinding/resource_manager.go index 5890b567b6..e5a6b3fd88 100644 --- a/pkg/targetgroupbinding/resource_manager.go +++ b/pkg/targetgroupbinding/resource_manager.go @@ -12,6 +12,7 @@ import ( "sigs.k8s.io/aws-alb-ingress-controller/pkg/aws/services" "sigs.k8s.io/aws-alb-ingress-controller/pkg/backend" "sigs.k8s.io/aws-alb-ingress-controller/pkg/k8s" + "sigs.k8s.io/aws-alb-ingress-controller/pkg/networking" "sigs.k8s.io/controller-runtime/pkg/client" ) @@ -22,13 +23,18 @@ type ResourceManager interface { } // NewDefaultResourceManager constructs new defaultResourceManager. -func NewDefaultResourceManager(k8sClient client.Client, elbv2Client services.ELBV2, logger logr.Logger) *defaultResourceManager { +func NewDefaultResourceManager(k8sClient client.Client, elbv2Client services.ELBV2, + podENIResolver networking.PodENIInfoResolver, nodeENIResolver networking.NodeENIInfoResolver, + sgManager networking.SecurityGroupManager, sgReconciler networking.SecurityGroupReconciler, + vpcID string, clusterName string, logger logr.Logger) *defaultResourceManager { targetsManager := NewCachedTargetsManager(elbv2Client, logger) endpointResolver := backend.NewDefaultEndpointResolver(k8sClient, logger) + networkingManager := NewDefaultNetworkingManager(k8sClient, podENIResolver, nodeENIResolver, sgManager, sgReconciler, vpcID, clusterName, logger) return &defaultResourceManager{ - targetsManager: targetsManager, - endpointResolver: endpointResolver, - logger: logger, + targetsManager: targetsManager, + endpointResolver: endpointResolver, + networkingManager: networkingManager, + logger: logger, } } @@ -36,9 +42,10 @@ var _ ResourceManager = &defaultResourceManager{} // default implementation for ResourceManager. type defaultResourceManager struct { - targetsManager TargetsManager - endpointResolver backend.EndpointResolver - logger logr.Logger + targetsManager TargetsManager + endpointResolver backend.EndpointResolver + networkingManager NetworkingManager + logger logr.Logger } func (m *defaultResourceManager) Reconcile(ctx context.Context, tgb *elbv2api.TargetGroupBinding) error { @@ -75,6 +82,10 @@ func (m *defaultResourceManager) reconcileWithIPTargetType(ctx context.Context, } notDrainingTargets, drainingTargets := partitionTargetsByDrainingStatus(targets) matchedEndpointAndTargets, unmatchedEndpoints, unmatchedTargets := matchPodEndpointWithTargets(endpoints, notDrainingTargets) + + if err := m.networkingManager.ReconcileForPodEndpoints(ctx, tgb, endpoints); err != nil { + return err + } if err := m.deregisterTargets(ctx, tgARN, unmatchedTargets); err != nil { return err } @@ -83,6 +94,7 @@ func (m *defaultResourceManager) reconcileWithIPTargetType(ctx context.Context, } _ = matchedEndpointAndTargets _ = drainingTargets + return nil } @@ -100,6 +112,10 @@ func (m *defaultResourceManager) reconcileWithInstanceTargetType(ctx context.Con return err } _, unmatchedEndpoints, unmatchedTargets := matchNodePortEndpointWithTargets(endpoints, targets) + + if err := m.networkingManager.ReconcileForNodePortEndpoints(ctx, tgb, endpoints); err != nil { + return err + } if err := m.deregisterTargets(ctx, tgARN, unmatchedTargets); err != nil { return err }