Skip to content

Commit

Permalink
nsg select ip version based on lb ip
Browse files Browse the repository at this point in the history
  • Loading branch information
jwtty authored and k8s-infra-cherrypick-robot committed Aug 31, 2022
1 parent fc2064f commit e77c1ff
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 62 deletions.
25 changes: 20 additions & 5 deletions pkg/provider/azure_loadbalancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,8 @@ func (az *Cloud) reconcileService(ctx context.Context, clusterName string, servi
serviceIP = &lbStatus.Ingress[0].IP
}

backendPrivateIPs := az.LoadBalancerBackendPool.GetBackendPrivateIPs(clusterName, service, lb)
klog.V(2).Infof("reconcileService: reconciling security group for service %q with IP %q, wantLb = true", serviceName, logSafe(serviceIP))
if _, err := az.reconcileSecurityGroup(clusterName, service, serviceIP, &backendPrivateIPs, true /* wantLb */); err != nil {
if _, err := az.reconcileSecurityGroup(clusterName, service, serviceIP, lb.Name, true /* wantLb */); err != nil {
klog.Errorf("reconcileSecurityGroup(%s) failed: %#v", serviceName, err)
return nil, err
}
Expand Down Expand Up @@ -221,7 +220,7 @@ func (az *Cloud) EnsureLoadBalancerDeleted(ctx context.Context, clusterName stri
}

klog.V(2).Infof("EnsureLoadBalancerDeleted: reconciling security group for service %q with IP %q, wantLb = false", serviceName, serviceIPToCleanup)
_, err = az.reconcileSecurityGroup(clusterName, service, &serviceIPToCleanup, &[]string{}, false /* wantLb */)
_, err = az.reconcileSecurityGroup(clusterName, service, &serviceIPToCleanup, nil, false /* wantLb */)
if err != nil {
return err
}
Expand Down Expand Up @@ -2286,7 +2285,7 @@ func (az *Cloud) getExpectedHAModeLoadBalancingRuleProperties(

// This reconciles the Network Security Group similar to how the LB is reconciled.
// This entails adding required, missing SecurityRules and removing stale rules.
func (az *Cloud) reconcileSecurityGroup(clusterName string, service *v1.Service, lbIP *string, backendIPAddresses *[]string, wantLb bool) (*network.SecurityGroup, error) {
func (az *Cloud) reconcileSecurityGroup(clusterName string, service *v1.Service, lbIP *string, lbName *string, wantLb bool) (*network.SecurityGroup, error) {
serviceName := getServiceName(service)
klog.V(5).Infof("reconcileSecurityGroup(%s): START clusterName=%q", serviceName, clusterName)

Expand All @@ -2309,6 +2308,22 @@ func (az *Cloud) reconcileSecurityGroup(clusterName string, service *v1.Service,
disableFloatingIP = true
}

backendIPAddresses := make([]string, 0)
if disableFloatingIP {
lb, exist, err := az.getAzureLoadBalancer(to.String(lbName), azcache.CacheReadTypeDefault)
if err != nil {
return nil, err
}
if !exist {
return nil, fmt.Errorf("unable to get lb %s", to.String(lbName))
}
backendPrivateIPv4s, backendPrivateIPv6s := az.LoadBalancerBackendPool.GetBackendPrivateIPs(clusterName, service, &lb)
backendIPAddresses = backendPrivateIPv4s
if utilnet.IsIPv6String(*lbIP) {
backendIPAddresses = backendPrivateIPv6s
}
}

destinationIPAddress := ""
if wantLb && lbIP == nil {
return nil, fmt.Errorf("no load balancer IP for setting up security rules for service %s", service.Name)
Expand Down Expand Up @@ -2352,7 +2367,7 @@ func (az *Cloud) reconcileSecurityGroup(clusterName string, service *v1.Service,
sourceAddressPrefixes = append(sourceAddressPrefixes, serviceTags...)
}

expectedSecurityRules, err := az.getExpectedSecurityRules(wantLb, ports, sourceAddressPrefixes, service, destinationIPAddresses, sourceRanges, *backendIPAddresses, disableFloatingIP)
expectedSecurityRules, err := az.getExpectedSecurityRules(wantLb, ports, sourceAddressPrefixes, service, destinationIPAddresses, sourceRanges, backendIPAddresses, disableFloatingIP)
if err != nil {
return nil, err
}
Expand Down
32 changes: 21 additions & 11 deletions pkg/provider/azure_loadbalancer_backendpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ type BackendPool interface {
ReconcileBackendPools(clusterName string, service *v1.Service, lb *network.LoadBalancer) (bool, bool, error)

// GetBackendPrivateIPs returns the private IPs of LoadBalancer's backend pool
GetBackendPrivateIPs(clusterName string, service *v1.Service, lb *network.LoadBalancer) []string
GetBackendPrivateIPs(clusterName string, service *v1.Service, lb *network.LoadBalancer) ([]string, []string)
}

type backendPoolTypeNodeIPConfig struct {
Expand Down Expand Up @@ -232,14 +232,14 @@ func (bc *backendPoolTypeNodeIPConfig) ReconcileBackendPools(clusterName string,
return isBackendPoolPreConfigured, changed, err
}

func (bc *backendPoolTypeNodeIPConfig) GetBackendPrivateIPs(clusterName string, service *v1.Service, lb *network.LoadBalancer) []string {
func (bc *backendPoolTypeNodeIPConfig) GetBackendPrivateIPs(clusterName string, service *v1.Service, lb *network.LoadBalancer) ([]string, []string) {
serviceName := getServiceName(service)
lbBackendPoolName := getBackendPoolName(clusterName, service)
if lb.LoadBalancerPropertiesFormat == nil || lb.LoadBalancerPropertiesFormat.BackendAddressPools == nil {
return nil
return nil, nil
}

backendPrivateIPs := sets.NewString()
backendPrivateIPv4s, backendPrivateIPv6s := sets.NewString(), sets.NewString()
for _, bp := range *lb.BackendAddressPools {
if strings.EqualFold(to.String(bp.Name), lbBackendPoolName) {
klog.V(10).Infof("bc.GetBackendPrivateIPs for service (%s): found wanted backendpool %s", serviceName, to.String(bp.Name))
Expand All @@ -258,15 +258,21 @@ func (bc *backendPoolTypeNodeIPConfig) GetBackendPrivateIPs(clusterName string,
}
if privateIPs != nil {
klog.V(2).Infof("bc.GetBackendPrivateIPs for service (%s): lb backendpool - found private IPs %v of node %s", serviceName, privateIPs, nodeName)
backendPrivateIPs.Insert(privateIPs...)
for _, ip := range privateIPs {
if utilnet.IsIPv4String(ip) {
backendPrivateIPv4s.Insert(ip)
} else {
backendPrivateIPv6s.Insert(ip)
}
}
}
}
}
} else {
klog.V(10).Infof("bc.GetBackendPrivateIPs for service (%s): found unmanaged backendpool %s", serviceName, to.String(bp.Name))
}
}
return backendPrivateIPs.List()
return backendPrivateIPv4s.List(), backendPrivateIPv6s.List()
}

type backendPoolTypeNodeIP struct {
Expand Down Expand Up @@ -495,14 +501,14 @@ func (bi *backendPoolTypeNodeIP) ReconcileBackendPools(clusterName string, servi
return isBackendPoolPreConfigured, changed, nil
}

func (bi *backendPoolTypeNodeIP) GetBackendPrivateIPs(clusterName string, service *v1.Service, lb *network.LoadBalancer) []string {
func (bi *backendPoolTypeNodeIP) GetBackendPrivateIPs(clusterName string, service *v1.Service, lb *network.LoadBalancer) ([]string, []string) {
serviceName := getServiceName(service)
lbBackendPoolName := getBackendPoolName(clusterName, service)
if lb.LoadBalancerPropertiesFormat == nil || lb.LoadBalancerPropertiesFormat.BackendAddressPools == nil {
return nil
return nil, nil
}

backendPrivateIPs := sets.NewString()
backendPrivateIPv4s, backendPrivateIPv6s := sets.NewString(), sets.NewString()
for _, bp := range *lb.BackendAddressPools {
if strings.EqualFold(to.String(bp.Name), lbBackendPoolName) {
klog.V(10).Infof("bi.GetBackendPrivateIPs for service (%s): found wanted backendpool %s", serviceName, to.String(bp.Name))
Expand All @@ -511,7 +517,11 @@ func (bi *backendPoolTypeNodeIP) GetBackendPrivateIPs(clusterName string, servic
ipAddress := backendAddress.IPAddress
if ipAddress != nil {
klog.V(2).Infof("bi.GetBackendPrivateIPs for service (%s): lb backendpool - found private IP %q", serviceName, *ipAddress)
backendPrivateIPs.Insert(*ipAddress)
if utilnet.IsIPv4String(*ipAddress) {
backendPrivateIPv4s.Insert(*ipAddress)
} else {
backendPrivateIPv6s.Insert(*ipAddress)
}
} else {
klog.V(4).Infof("bi.GetBackendPrivateIPs for service (%s): lb backendpool - found null private IP")
}
Expand All @@ -521,7 +531,7 @@ func (bi *backendPoolTypeNodeIP) GetBackendPrivateIPs(clusterName string, servic
klog.V(10).Infof("bi.GetBackendPrivateIPs for service (%s): found unmanaged backendpool %s", serviceName, to.String(bp.Name))
}
}
return backendPrivateIPs.List()
return backendPrivateIPv4s.List(), backendPrivateIPv6s.List()
}

func newBackendPool(lb *network.LoadBalancer, isBackendPoolPreConfigured bool, preConfiguredBackendPoolLoadBalancerTypes, serviceName, lbBackendPoolName string) bool {
Expand Down
47 changes: 43 additions & 4 deletions pkg/provider/azure_loadbalancer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,7 @@ func TestEnsureLoadBalancerDeleted(t *testing.T) {
mockLBBackendPool := az.LoadBalancerBackendPool.(*MockBackendPool)
mockLBBackendPool.EXPECT().ReconcileBackendPools(gomock.Any(), gomock.Any(), gomock.Any()).Return(false, false, nil).AnyTimes()
mockLBBackendPool.EXPECT().EnsureHostsInPool(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
mockLBBackendPool.EXPECT().GetBackendPrivateIPs(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
mockLBBackendPool.EXPECT().GetBackendPrivateIPs(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()

clusterResources, expectedInterfaces, expectedVirtualMachines := getClusterResources(az, vmCount, availabilitySetCount)
setMockEnv(az, ctrl, expectedInterfaces, expectedVirtualMachines, 4)
Expand Down Expand Up @@ -3092,6 +3092,7 @@ func TestReconcileSecurityGroup(t *testing.T) {
testCases := []struct {
desc string
lbIP *string
lbName *string
service v1.Service
existingSgs map[string]network.SecurityGroup
expectedSg *network.SecurityGroup
Expand Down Expand Up @@ -3309,6 +3310,7 @@ func TestReconcileSecurityGroup(t *testing.T) {
SecurityGroupPropertiesFormat: &network.SecurityGroupPropertiesFormat{},
}},
lbIP: to.StringPtr("1.2.3.4"),
lbName: to.StringPtr("lb"),
wantLb: true,
expectedSg: &network.SecurityGroup{
Name: to.StringPtr("nsg"),
Expand All @@ -3321,7 +3323,38 @@ func TestReconcileSecurityGroup(t *testing.T) {
SourcePortRange: to.StringPtr("*"),
DestinationPortRange: to.StringPtr(strconv.Itoa(int(getBackendPort(80)))),
SourceAddressPrefix: to.StringPtr("Internet"),
DestinationAddressPrefixes: to.StringSlicePtr([]string{}),
DestinationAddressPrefixes: to.StringSlicePtr([]string{"1.2.3.4", "5.6.7.8"}),
Access: network.SecurityRuleAccess("Allow"),
Priority: to.Int32Ptr(500),
Direction: network.SecurityRuleDirection("Inbound"),
},
},
},
},
},
},
{
desc: "reconcileSecurityGroup shall create sgs with only IPv6 destination addresses for IPv6 services with floating IP disabled",
service: getTestService("test1", v1.ProtocolTCP, map[string]string{consts.ServiceAnnotationDisableLoadBalancerFloatingIP: "true"}, false, 80),
existingSgs: map[string]network.SecurityGroup{"nsg": {
Name: to.StringPtr("nsg"),
SecurityGroupPropertiesFormat: &network.SecurityGroupPropertiesFormat{},
}},
lbIP: to.StringPtr("1234::5"),
lbName: to.StringPtr("lb"),
wantLb: true,
expectedSg: &network.SecurityGroup{
Name: to.StringPtr("nsg"),
SecurityGroupPropertiesFormat: &network.SecurityGroupPropertiesFormat{
SecurityRules: &[]network.SecurityRule{
{
Name: to.StringPtr("atest1-TCP-80-Internet"),
SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{
Protocol: network.SecurityRuleProtocol("Tcp"),
SourcePortRange: to.StringPtr("*"),
DestinationPortRange: to.StringPtr(strconv.Itoa(int(getBackendPort(80)))),
SourceAddressPrefix: to.StringPtr("Internet"),
DestinationAddressPrefixes: to.StringSlicePtr([]string{"fc00::1", "fc00::2"}),
Access: network.SecurityRuleAccess("Allow"),
Priority: to.Int32Ptr(500),
Direction: network.SecurityRuleDirection("Inbound"),
Expand All @@ -3347,7 +3380,13 @@ func TestReconcileSecurityGroup(t *testing.T) {
t.Fatalf("TestCase[%d] meets unexpected error: %v", i, err)
}
}
sg, err := az.reconcileSecurityGroup("testCluster", &test.service, test.lbIP, &[]string{}, test.wantLb)
mockLBClient := az.LoadBalancerClient.(*mockloadbalancerclient.MockInterface)
mockLBBackendPool := az.LoadBalancerBackendPool.(*MockBackendPool)
if test.lbName != nil {
mockLBBackendPool.EXPECT().GetBackendPrivateIPs(gomock.Any(), gomock.Any(), gomock.Any()).Return([]string{"1.2.3.4", "5.6.7.8"}, []string{"fc00::1", "fc00::2"}).AnyTimes()
mockLBClient.EXPECT().Get(gomock.Any(), "rg", *test.lbName, gomock.Any()).Return(network.LoadBalancer{}, nil)
}
sg, err := az.reconcileSecurityGroup("testCluster", &test.service, test.lbIP, test.lbName, test.wantLb)
assert.Equal(t, test.expectedSg, sg, "TestCase[%d]: %s", i, test.desc)
assert.Equal(t, test.expectedError, err != nil, "TestCase[%d]: %s", i, test.desc)
}
Expand Down Expand Up @@ -3403,7 +3442,7 @@ func TestReconcileSecurityGroupLoadBalancerSourceRanges(t *testing.T) {
mockSGClient := az.SecurityGroupsClient.(*mocksecuritygroupclient.MockInterface)
mockSGClient.EXPECT().Get(gomock.Any(), az.ResourceGroup, gomock.Any(), gomock.Any()).Return(existingSg, nil)
mockSGClient.EXPECT().CreateOrUpdate(gomock.Any(), az.ResourceGroup, gomock.Any(), gomock.Any(), gomock.Any()).Return(nil)
sg, err := az.reconcileSecurityGroup("testCluster", &service, lbIP, &[]string{}, true)
sg, err := az.reconcileSecurityGroup("testCluster", &service, lbIP, nil, true)
assert.NoError(t, err)
assert.Equal(t, expectedSg, *sg)
}
Expand Down
5 changes: 3 additions & 2 deletions pkg/provider/azure_mock_loadbalancer_backendpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,12 @@ func (mr *MockBackendPoolMockRecorder) ReconcileBackendPools(clusterName, servic
}

// GetBackendPrivateIPs mocks base method
func (m *MockBackendPool) GetBackendPrivateIPs(clusterName string, service *v1.Service, lb *network.LoadBalancer) []string {
func (m *MockBackendPool) GetBackendPrivateIPs(clusterName string, service *v1.Service, lb *network.LoadBalancer) ([]string, []string) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetBackendPrivateIPs", clusterName, service, lb)
ret0, _ := ret[0].([]string)
return ret0
ret1, _ := ret[1].([]string)
return ret0, ret1
}

// GetBackendPrivateIPs indicates an expected call of GetBackendPrivateIPs
Expand Down

0 comments on commit e77c1ff

Please sign in to comment.