From fdfa351bf0ecc188ddc914c5117e2c6faa5dcf30 Mon Sep 17 00:00:00 2001 From: Pengfei Ni Date: Mon, 16 Aug 2021 05:22:22 +0000 Subject: [PATCH 1/2] fix: skip case sensitivity when checking Azure NSG rules --- .../azure/azure_loadbalancer.go | 18 +-- .../azure/azure_loadbalancer_test.go | 30 ++++ .../azure/azure_test.go | 153 ++++++++++++++++++ 3 files changed, 192 insertions(+), 9 deletions(-) diff --git a/staging/src/k8s.io/legacy-cloud-providers/azure/azure_loadbalancer.go b/staging/src/k8s.io/legacy-cloud-providers/azure/azure_loadbalancer.go index 9744507c8cbb..21ff37b4c65b 100644 --- a/staging/src/k8s.io/legacy-cloud-providers/azure/azure_loadbalancer.go +++ b/staging/src/k8s.io/legacy-cloud-providers/azure/azure_loadbalancer.go @@ -1230,18 +1230,18 @@ func (az *Cloud) reconcileSecurityGroup(clusterName string, service *v1.Service, sharedRuleName := az.getSecurityRuleName(service, port, sourceAddressPrefix) sharedIndex, sharedRule, sharedRuleFound := findSecurityRuleByName(updatedRules, sharedRuleName) if !sharedRuleFound { - klog.V(4).Infof("Expected to find shared rule %s for service %s being deleted, but did not", sharedRuleName, service.Name) - return nil, fmt.Errorf("Expected to find shared rule %s for service %s being deleted, but did not", sharedRuleName, service.Name) + klog.V(4).Infof("Didn't find shared rule %s for service %s", sharedRuleName, service.Name) + continue } if sharedRule.DestinationAddressPrefixes == nil { - klog.V(4).Infof("Expected to have array of destinations in shared rule for service %s being deleted, but did not", service.Name) - return nil, fmt.Errorf("Expected to have array of destinations in shared rule for service %s being deleted, but did not", service.Name) + klog.V(4).Infof("Didn't find DestinationAddressPrefixes in shared rule for service %s", service.Name) + continue } existingPrefixes := *sharedRule.DestinationAddressPrefixes addressIndex, found := findIndex(existingPrefixes, destinationIPAddress) if !found { - klog.V(4).Infof("Expected to find destination address %s in shared rule %s for service %s being deleted, but did not", destinationIPAddress, sharedRuleName, service.Name) - return nil, fmt.Errorf("Expected to find destination address %s in shared rule %s for service %s being deleted, but did not", destinationIPAddress, sharedRuleName, service.Name) + klog.V(4).Infof("Didn't find destination address %v in shared rule %s for service %s", destinationIPAddress, sharedRuleName, service.Name) + continue } if len(existingPrefixes) == 1 { updatedRules = append(updatedRules[:sharedIndex], updatedRules[sharedIndex+1:]...) @@ -1658,7 +1658,7 @@ func findSecurityRule(rules []network.SecurityRule, rule network.SecurityRule) b if !strings.EqualFold(to.String(existingRule.Name), to.String(rule.Name)) { continue } - if existingRule.Protocol != rule.Protocol { + if !strings.EqualFold(string(existingRule.Protocol), string(rule.Protocol)) { continue } if !strings.EqualFold(to.String(existingRule.SourcePortRange), to.String(rule.SourcePortRange)) { @@ -1675,10 +1675,10 @@ func findSecurityRule(rules []network.SecurityRule, rule network.SecurityRule) b continue } } - if existingRule.Access != rule.Access { + if !strings.EqualFold(string(existingRule.Access), string(rule.Access)) { continue } - if existingRule.Direction != rule.Direction { + if !strings.EqualFold(string(existingRule.Direction), string(rule.Direction)) { continue } return true diff --git a/staging/src/k8s.io/legacy-cloud-providers/azure/azure_loadbalancer_test.go b/staging/src/k8s.io/legacy-cloud-providers/azure/azure_loadbalancer_test.go index f7eaffd730ee..64e41e7902a5 100644 --- a/staging/src/k8s.io/legacy-cloud-providers/azure/azure_loadbalancer_test.go +++ b/staging/src/k8s.io/legacy-cloud-providers/azure/azure_loadbalancer_test.go @@ -1943,6 +1943,36 @@ func TestReconcileSecurityGroup(t *testing.T) { }, }, }, + { + desc: "reconcileSecurityGroup shall create shared sgs for service with azure-shared-securityrule annotations", + service: getTestService("test1", v1.ProtocolTCP, map[string]string{ServiceAnnotationSharedSecurityRule: "true"}, true, 80), + existingSgs: map[string]network.SecurityGroup{"nsg": { + Name: to.StringPtr("nsg"), + SecurityGroupPropertiesFormat: &network.SecurityGroupPropertiesFormat{}, + }}, + lbIP: to.StringPtr("1.2.3.4"), + wantLb: true, + expectedSg: &network.SecurityGroup{ + Name: to.StringPtr("nsg"), + SecurityGroupPropertiesFormat: &network.SecurityGroupPropertiesFormat{ + SecurityRules: &[]network.SecurityRule{ + { + Name: to.StringPtr("shared-TCP-80-Internet"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocol("Tcp"), + SourcePortRange: to.StringPtr("*"), + DestinationPortRange: to.StringPtr("80"), + SourceAddressPrefix: to.StringPtr("Internet"), + DestinationAddressPrefixes: to.StringSlicePtr([]string{"1.2.3.4"}), + Access: network.SecurityRuleAccess("Allow"), + Priority: to.Int32Ptr(500), + Direction: network.SecurityRuleDirection("Inbound"), + }, + }, + }, + }, + }, + }, } for i, test := range testCases { diff --git a/staging/src/k8s.io/legacy-cloud-providers/azure/azure_test.go b/staging/src/k8s.io/legacy-cloud-providers/azure/azure_test.go index 31aecc8c2beb..83ffc11a9f46 100644 --- a/staging/src/k8s.io/legacy-cloud-providers/azure/azure_test.go +++ b/staging/src/k8s.io/legacy-cloud-providers/azure/azure_test.go @@ -3330,3 +3330,156 @@ func TestInitializeCloudFromConfig(t *testing.T) { expectedErr = fmt.Errorf("useInstanceMetadata must be enabled without Azure credentials") assert.Equal(t, expectedErr, err) } + +func TestFindSecurityRule(t *testing.T) { + testRuleName := "test-rule" + testIP1 := "192.168.192.168" + sg := network.SecurityRule{ + Name: to.StringPtr(testRuleName), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolTCP, + SourcePortRange: to.StringPtr("*"), + SourceAddressPrefix: to.StringPtr("Internet"), + DestinationPortRange: to.StringPtr("80"), + DestinationAddressPrefix: to.StringPtr(testIP1), + Access: network.SecurityRuleAccessAllow, + Direction: network.SecurityRuleDirectionInbound, + }, + } + testCases := []struct { + desc string + testRule network.SecurityRule + expected bool + }{ + { + desc: "false should be returned for an empty rule", + testRule: network.SecurityRule{}, + expected: false, + }, + { + desc: "false should be returned when rule name doesn't match", + testRule: network.SecurityRule{ + Name: to.StringPtr("not-the-right-name"), + }, + expected: false, + }, + { + desc: "false should be returned when protocol doesn't match", + testRule: network.SecurityRule{ + Name: to.StringPtr(testRuleName), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolUDP, + }, + }, + expected: false, + }, + { + desc: "false should be returned when SourcePortRange doesn't match", + testRule: network.SecurityRule{ + Name: to.StringPtr(testRuleName), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolUDP, + SourcePortRange: to.StringPtr("1.2.3.4/32"), + }, + }, + expected: false, + }, + { + desc: "false should be returned when SourceAddressPrefix doesn't match", + testRule: network.SecurityRule{ + Name: to.StringPtr(testRuleName), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolUDP, + SourcePortRange: to.StringPtr("*"), + SourceAddressPrefix: to.StringPtr("2.3.4.0/24"), + }, + }, + expected: false, + }, + { + desc: "false should be returned when DestinationPortRange doesn't match", + testRule: network.SecurityRule{ + Name: to.StringPtr(testRuleName), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolUDP, + SourcePortRange: to.StringPtr("*"), + SourceAddressPrefix: to.StringPtr("Internet"), + DestinationPortRange: to.StringPtr("443"), + }, + }, + expected: false, + }, + { + desc: "false should be returned when DestinationAddressPrefix doesn't match", + testRule: network.SecurityRule{ + Name: to.StringPtr(testRuleName), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolUDP, + SourcePortRange: to.StringPtr("*"), + SourceAddressPrefix: to.StringPtr("Internet"), + DestinationPortRange: to.StringPtr("80"), + DestinationAddressPrefix: to.StringPtr("192.168.0.3"), + }, + }, + expected: false, + }, + { + desc: "false should be returned when Access doesn't match", + testRule: network.SecurityRule{ + Name: to.StringPtr(testRuleName), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolUDP, + SourcePortRange: to.StringPtr("*"), + SourceAddressPrefix: to.StringPtr("Internet"), + DestinationPortRange: to.StringPtr("80"), + DestinationAddressPrefix: to.StringPtr(testIP1), + Access: network.SecurityRuleAccessDeny, + // Direction: network.SecurityRuleDirectionInbound, + }, + }, + expected: false, + }, + { + desc: "false should be returned when Direction doesn't match", + testRule: network.SecurityRule{ + Name: to.StringPtr(testRuleName), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolUDP, + SourcePortRange: to.StringPtr("*"), + SourceAddressPrefix: to.StringPtr("Internet"), + DestinationPortRange: to.StringPtr("80"), + DestinationAddressPrefix: to.StringPtr(testIP1), + Access: network.SecurityRuleAccessAllow, + Direction: network.SecurityRuleDirectionOutbound, + }, + }, + expected: false, + }, + { + desc: "true should be returned when everything matches but protocol is in different case", + testRule: network.SecurityRule{ + Name: to.StringPtr(testRuleName), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocol("TCP"), + SourcePortRange: to.StringPtr("*"), + SourceAddressPrefix: to.StringPtr("Internet"), + DestinationPortRange: to.StringPtr("80"), + DestinationAddressPrefix: to.StringPtr(testIP1), + Access: network.SecurityRuleAccessAllow, + Direction: network.SecurityRuleDirectionInbound, + }, + }, + expected: true, + }, + { + desc: "true should be returned when everything matches", + testRule: sg, + expected: true, + }, + } + + for i := range testCases { + found := findSecurityRule([]network.SecurityRule{sg}, testCases[i].testRule) + assert.Equal(t, testCases[i].expected, found, testCases[i].desc) + } +} From 8eb2fbc1bc5a61ebe1a42047525c675338ee5465 Mon Sep 17 00:00:00 2001 From: Pengfei Ni Date: Mon, 16 Aug 2021 05:05:00 +0000 Subject: [PATCH 2/2] fix: ensure InstanceShutdownByProviderID return false for creating Azure VMs --- .../azure/azure_instances.go | 19 +++- .../azure/azure_instances_test.go | 32 +++++-- .../azure/azure_standard.go | 14 +++ .../azure/azure_standard_test.go | 63 ++++++++++++ .../azure/azure_vmsets.go | 5 +- .../azure/azure_vmss.go | 24 +++++ .../azure/azure_vmss_test.go | 51 ++++++++++ .../azure/mockvmsets/azure_mock_vmsets.go | 95 +++++++++++-------- 8 files changed, 252 insertions(+), 51 deletions(-) diff --git a/staging/src/k8s.io/legacy-cloud-providers/azure/azure_instances.go b/staging/src/k8s.io/legacy-cloud-providers/azure/azure_instances.go index 95f8f49ae733..9cb021709d01 100644 --- a/staging/src/k8s.io/legacy-cloud-providers/azure/azure_instances.go +++ b/staging/src/k8s.io/legacy-cloud-providers/azure/azure_instances.go @@ -20,6 +20,7 @@ package azure import ( "context" + "errors" "fmt" "os" "strings" @@ -29,6 +30,8 @@ import ( cloudprovider "k8s.io/cloud-provider" "k8s.io/klog/v2" azcache "k8s.io/legacy-cloud-providers/azure/cache" + + "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2019-12-01/compute" ) const ( @@ -233,10 +236,22 @@ func (az *Cloud) InstanceShutdownByProviderID(ctx context.Context, providerID st return false, err } - klog.V(5).Infof("InstanceShutdownByProviderID gets power status %q for node %q", powerStatus, nodeName) + klog.V(3).Infof("InstanceShutdownByProviderID gets power status %q for node %q", powerStatus, nodeName) + + provisioningState, err := az.VMSet.GetProvisioningStateByNodeName(string(nodeName)) + if err != nil { + // Returns false, so the controller manager will continue to check InstanceExistsByProviderID(). + if errors.Is(err, cloudprovider.InstanceNotFound) { + return false, nil + } + + return false, err + } + klog.V(3).Infof("InstanceShutdownByProviderID gets provisioning state %q for node %q", provisioningState, nodeName) status := strings.ToLower(powerStatus) - return status == vmPowerStateStopped || status == vmPowerStateDeallocated || status == vmPowerStateDeallocating, nil + provisioningSucceeded := strings.EqualFold(strings.ToLower(provisioningState), strings.ToLower(string(compute.ProvisioningStateSucceeded))) + return provisioningSucceeded && (status == vmPowerStateStopped || status == vmPowerStateDeallocated || status == vmPowerStateDeallocating), nil } func (az *Cloud) isCurrentInstance(name types.NodeName, metadataVMName string) (bool, error) { diff --git a/staging/src/k8s.io/legacy-cloud-providers/azure/azure_instances_test.go b/staging/src/k8s.io/legacy-cloud-providers/azure/azure_instances_test.go index 3005cb7f31b0..2efa0f7fa498 100644 --- a/staging/src/k8s.io/legacy-cloud-providers/azure/azure_instances_test.go +++ b/staging/src/k8s.io/legacy-cloud-providers/azure/azure_instances_test.go @@ -63,6 +63,7 @@ func setTestVirtualMachines(c *Cloud, vmList map[string]string, isDataDisksFull }, } vm.VirtualMachineProperties = &compute.VirtualMachineProperties{ + ProvisioningState: to.StringPtr(string(compute.ProvisioningStateSucceeded)), HardwareProfile: &compute.HardwareProfile{ VMSize: compute.VirtualMachineSizeTypesStandardA0, }, @@ -252,12 +253,13 @@ func TestInstanceID(t *testing.T) { func TestInstanceShutdownByProviderID(t *testing.T) { testcases := []struct { - name string - vmList map[string]string - nodeName string - providerID string - expected bool - expectedErrMsg error + name string + vmList map[string]string + nodeName string + providerID string + provisioningState string + expected bool + expectedErrMsg error }{ { name: "InstanceShutdownByProviderID should return false if the vm is in PowerState/Running status", @@ -294,6 +296,7 @@ func TestInstanceShutdownByProviderID(t *testing.T) { providerID: "azure:///subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/vm5", expected: true, }, + { name: "InstanceShutdownByProviderID should return false if the vm is in PowerState/Stopping status", vmList: map[string]string{"vm6": "PowerState/Stopping"}, @@ -315,13 +318,23 @@ func TestInstanceShutdownByProviderID(t *testing.T) { providerID: "azure:///subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/vm8", expected: false, }, + { + name: "InstanceShutdownByProviderID should return false if the vm is in PowerState/Stopped state with Creating provisioning state", + vmList: map[string]string{"vm9": "PowerState/Stopped"}, + nodeName: "vm9", + provisioningState: "Creating", + providerID: "azure:///subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/vm9", + expected: false, + }, { name: "InstanceShutdownByProviderID should report error if providerID is null", + nodeName: "vmm", expected: false, }, { name: "InstanceShutdownByProviderID should report error if providerID is invalid", - providerID: "azure:///subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Compute/VM/vm9", + providerID: "azure:///subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Compute/VM/vm10", + nodeName: "vm10", expected: false, expectedErrMsg: fmt.Errorf("error splitting providerID"), }, @@ -332,11 +345,14 @@ func TestInstanceShutdownByProviderID(t *testing.T) { for _, test := range testcases { cloud := GetTestCloud(ctrl) expectedVMs := setTestVirtualMachines(cloud, test.vmList, false) + if test.provisioningState != "" { + expectedVMs[0].ProvisioningState = to.StringPtr(test.provisioningState) + } mockVMsClient := cloud.VirtualMachinesClient.(*mockvmclient.MockInterface) for _, vm := range expectedVMs { mockVMsClient.EXPECT().Get(gomock.Any(), cloud.ResourceGroup, *vm.Name, gomock.Any()).Return(vm, nil).AnyTimes() } - mockVMsClient.EXPECT().Get(gomock.Any(), cloud.ResourceGroup, "vm8", gomock.Any()).Return(compute.VirtualMachine{}, &retry.Error{HTTPStatusCode: http.StatusNotFound, RawError: cloudprovider.InstanceNotFound}).AnyTimes() + mockVMsClient.EXPECT().Get(gomock.Any(), cloud.ResourceGroup, test.nodeName, gomock.Any()).Return(compute.VirtualMachine{}, &retry.Error{HTTPStatusCode: http.StatusNotFound, RawError: cloudprovider.InstanceNotFound}).AnyTimes() hasShutdown, err := cloud.InstanceShutdownByProviderID(context.Background(), test.providerID) assert.Equal(t, test.expectedErrMsg, err, test.name) diff --git a/staging/src/k8s.io/legacy-cloud-providers/azure/azure_standard.go b/staging/src/k8s.io/legacy-cloud-providers/azure/azure_standard.go index 174ba45e69d9..a35c4ababf89 100644 --- a/staging/src/k8s.io/legacy-cloud-providers/azure/azure_standard.go +++ b/staging/src/k8s.io/legacy-cloud-providers/azure/azure_standard.go @@ -444,6 +444,20 @@ func (as *availabilitySet) GetPowerStatusByNodeName(name string) (powerState str return vmPowerStateStopped, nil } +// GetProvisioningStateByNodeName returns the provisioningState for the specified node. +func (as *availabilitySet) GetProvisioningStateByNodeName(name string) (provisioningState string, err error) { + vm, err := as.getVirtualMachine(types.NodeName(name), azcache.CacheReadTypeDefault) + if err != nil { + return provisioningState, err + } + + if vm.VirtualMachineProperties == nil || vm.VirtualMachineProperties.ProvisioningState == nil { + return provisioningState, nil + } + + return to.String(vm.VirtualMachineProperties.ProvisioningState), nil +} + // GetNodeNameByProviderID gets the node name by provider ID. func (as *availabilitySet) GetNodeNameByProviderID(providerID string) (types.NodeName, error) { // NodeName is part of providerID for standard instances. diff --git a/staging/src/k8s.io/legacy-cloud-providers/azure/azure_standard_test.go b/staging/src/k8s.io/legacy-cloud-providers/azure/azure_standard_test.go index 530d57f19de9..968137f41110 100644 --- a/staging/src/k8s.io/legacy-cloud-providers/azure/azure_standard_test.go +++ b/staging/src/k8s.io/legacy-cloud-providers/azure/azure_standard_test.go @@ -985,6 +985,69 @@ func TestGetStandardVMPowerStatusByNodeName(t *testing.T) { } } +func TestGetStandardVMProvisioningStateByNodeName(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + cloud := GetTestCloud(ctrl) + + testcases := []struct { + name string + nodeName string + vm compute.VirtualMachine + expectedProvisioningState string + getErr *retry.Error + expectedErrMsg error + }{ + { + name: "GetProvisioningStateByNodeName should report error if node don't exist", + nodeName: "vm1", + vm: compute.VirtualMachine{}, + getErr: &retry.Error{ + HTTPStatusCode: http.StatusNotFound, + RawError: cloudprovider.InstanceNotFound, + }, + expectedErrMsg: fmt.Errorf("instance not found"), + }, + { + name: "GetProvisioningStateByNodeName should return Succeeded for running VM", + nodeName: "vm2", + vm: compute.VirtualMachine{ + Name: to.StringPtr("vm2"), + VirtualMachineProperties: &compute.VirtualMachineProperties{ + ProvisioningState: to.StringPtr("Succeeded"), + InstanceView: &compute.VirtualMachineInstanceView{ + Statuses: &[]compute.InstanceViewStatus{ + { + Code: to.StringPtr("PowerState/Running"), + }, + }, + }, + }, + }, + expectedProvisioningState: "Succeeded", + }, + { + name: "GetProvisioningStateByNodeName should return empty string when vm.ProvisioningState is nil", + nodeName: "vm3", + vm: compute.VirtualMachine{ + Name: to.StringPtr("vm3"), + VirtualMachineProperties: &compute.VirtualMachineProperties{ + ProvisioningState: nil, + }, + }, + expectedProvisioningState: "", + }, + } + for _, test := range testcases { + mockVMClient := cloud.VirtualMachinesClient.(*mockvmclient.MockInterface) + mockVMClient.EXPECT().Get(gomock.Any(), cloud.ResourceGroup, test.nodeName, gomock.Any()).Return(test.vm, test.getErr).AnyTimes() + + provisioningState, err := cloud.VMSet.GetProvisioningStateByNodeName(test.nodeName) + assert.Equal(t, test.expectedErrMsg, err, test.name) + assert.Equal(t, test.expectedProvisioningState, provisioningState, test.name) + } +} + func TestGetStandardVMZoneByNodeName(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() diff --git a/staging/src/k8s.io/legacy-cloud-providers/azure/azure_vmsets.go b/staging/src/k8s.io/legacy-cloud-providers/azure/azure_vmsets.go index 89ac79cb6d4c..087bdfad1eac 100644 --- a/staging/src/k8s.io/legacy-cloud-providers/azure/azure_vmsets.go +++ b/staging/src/k8s.io/legacy-cloud-providers/azure/azure_vmsets.go @@ -71,11 +71,14 @@ type VMSet interface { // DetachDisk detaches a vhd from host. The vhd can be identified by diskName or diskURI. DetachDisk(diskName, diskURI string, nodeName types.NodeName) error // GetDataDisks gets a list of data disks attached to the node. - GetDataDisks(nodeName types.NodeName, string azcache.AzureCacheReadType) ([]compute.DataDisk, error) + GetDataDisks(nodeName types.NodeName, crt azcache.AzureCacheReadType) ([]compute.DataDisk, error) // GetPowerStatusByNodeName returns the power state of the specified node. GetPowerStatusByNodeName(name string) (string, error) + // GetProvisioningStateByNodeName returns the provisioningState for the specified node. + GetProvisioningStateByNodeName(name string) (string, error) + // GetPrivateIPsByNodeName returns a slice of all private ips assigned to node (ipv6 and ipv4) GetPrivateIPsByNodeName(name string) ([]string, error) } diff --git a/staging/src/k8s.io/legacy-cloud-providers/azure/azure_vmss.go b/staging/src/k8s.io/legacy-cloud-providers/azure/azure_vmss.go index 3c40ebdff320..158fee2da20e 100644 --- a/staging/src/k8s.io/legacy-cloud-providers/azure/azure_vmss.go +++ b/staging/src/k8s.io/legacy-cloud-providers/azure/azure_vmss.go @@ -241,6 +241,30 @@ func (ss *scaleSet) GetPowerStatusByNodeName(name string) (powerState string, er return vmPowerStateStopped, nil } +// GetProvisioningStateByNodeName returns the provisioningState for the specified node. +func (ss *scaleSet) GetProvisioningStateByNodeName(name string) (provisioningState string, err error) { + managedByAS, err := ss.isNodeManagedByAvailabilitySet(name, azcache.CacheReadTypeUnsafe) + if err != nil { + klog.Errorf("Failed to check isNodeManagedByAvailabilitySet: %v", err) + return "", err + } + if managedByAS { + // vm is managed by availability set. + return ss.availabilitySet.GetProvisioningStateByNodeName(name) + } + + _, _, vm, err := ss.getVmssVM(name, azcache.CacheReadTypeDefault) + if err != nil { + return provisioningState, err + } + + if vm.VirtualMachineScaleSetVMProperties == nil || vm.VirtualMachineScaleSetVMProperties.ProvisioningState == nil { + return provisioningState, nil + } + + return to.String(vm.VirtualMachineScaleSetVMProperties.ProvisioningState), nil +} + // getCachedVirtualMachineByInstanceID gets scaleSetVMInfo from cache. // The node must belong to one of scale sets. func (ss *scaleSet) getVmssVMByInstanceID(resourceGroup, scaleSetName, instanceID string, crt azcache.AzureCacheReadType) (*compute.VirtualMachineScaleSetVM, error) { diff --git a/staging/src/k8s.io/legacy-cloud-providers/azure/azure_vmss_test.go b/staging/src/k8s.io/legacy-cloud-providers/azure/azure_vmss_test.go index e2adfe25020f..c06ac07b1015 100644 --- a/staging/src/k8s.io/legacy-cloud-providers/azure/azure_vmss_test.go +++ b/staging/src/k8s.io/legacy-cloud-providers/azure/azure_vmss_test.go @@ -822,6 +822,57 @@ func TestGetPowerStatusByNodeName(t *testing.T) { } } +func TestGetProvisioningStateByNodeName(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + testCases := []struct { + description string + vmList []string + provisioningState string + expectedProvisioningState string + expectedErr error + }{ + { + description: "GetProvisioningStateByNodeName should return empty value when the vm.ProvisioningState is nil", + provisioningState: "", + vmList: []string{"vmss-vm-000001"}, + expectedProvisioningState: "", + }, + { + description: "GetProvisioningStateByNodeName should return Succeeded when the vm is running", + provisioningState: "Succeeded", + vmList: []string{"vmss-vm-000001"}, + expectedProvisioningState: "Succeeded", + }, + } + + for _, test := range testCases { + ss, err := newTestScaleSet(ctrl) + assert.NoError(t, err, "unexpected error when creating test VMSS") + + expectedVMSS := buildTestVMSS(testVMSSName, "vmss-vm-") + mockVMSSClient := ss.cloud.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) + mockVMSSClient.EXPECT().List(gomock.Any(), ss.ResourceGroup).Return([]compute.VirtualMachineScaleSet{expectedVMSS}, nil).AnyTimes() + + expectedVMSSVMs, _, _ := buildTestVirtualMachineEnv(ss.cloud, testVMSSName, "", 0, test.vmList, "", false) + mockVMSSVMClient := ss.cloud.VirtualMachineScaleSetVMsClient.(*mockvmssvmclient.MockInterface) + if test.provisioningState != "" { + expectedVMSSVMs[0].ProvisioningState = to.StringPtr(test.provisioningState) + } else { + expectedVMSSVMs[0].ProvisioningState = nil + } + mockVMSSVMClient.EXPECT().List(gomock.Any(), ss.ResourceGroup, testVMSSName, gomock.Any()).Return(expectedVMSSVMs, nil).AnyTimes() + + mockVMsClient := ss.cloud.VirtualMachinesClient.(*mockvmclient.MockInterface) + mockVMsClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]compute.VirtualMachine{}, nil).AnyTimes() + + provisioningState, err := ss.GetProvisioningStateByNodeName("vmss-vm-000001") + assert.Equal(t, test.expectedErr, err, test.description+", but an error occurs") + assert.Equal(t, test.expectedProvisioningState, provisioningState, test.description) + } +} + func TestGetVmssVMByInstanceID(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() diff --git a/staging/src/k8s.io/legacy-cloud-providers/azure/mockvmsets/azure_mock_vmsets.go b/staging/src/k8s.io/legacy-cloud-providers/azure/mockvmsets/azure_mock_vmsets.go index 1746b9bbb9f7..6a29ca1f80b4 100644 --- a/staging/src/k8s.io/legacy-cloud-providers/azure/mockvmsets/azure_mock_vmsets.go +++ b/staging/src/k8s.io/legacy-cloud-providers/azure/mockvmsets/azure_mock_vmsets.go @@ -30,30 +30,30 @@ import ( cache "k8s.io/legacy-cloud-providers/azure/cache" ) -// MockVMSet is a mock of VMSet interface +// MockVMSet is a mock of VMSet interface. type MockVMSet struct { ctrl *gomock.Controller recorder *MockVMSetMockRecorder } -// MockVMSetMockRecorder is the mock recorder for MockVMSet +// MockVMSetMockRecorder is the mock recorder for MockVMSet. type MockVMSetMockRecorder struct { mock *MockVMSet } -// NewMockVMSet creates a new mock instance +// NewMockVMSet creates a new mock instance. func NewMockVMSet(ctrl *gomock.Controller) *MockVMSet { mock := &MockVMSet{ctrl: ctrl} mock.recorder = &MockVMSetMockRecorder{mock} return mock } -// EXPECT returns an object that allows the caller to indicate expected use +// EXPECT returns an object that allows the caller to indicate expected use. func (m *MockVMSet) EXPECT() *MockVMSetMockRecorder { return m.recorder } -// GetInstanceIDByNodeName mocks base method +// GetInstanceIDByNodeName mocks base method. func (m *MockVMSet) GetInstanceIDByNodeName(name string) (string, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetInstanceIDByNodeName", name) @@ -62,13 +62,13 @@ func (m *MockVMSet) GetInstanceIDByNodeName(name string) (string, error) { return ret0, ret1 } -// GetInstanceIDByNodeName indicates an expected call of GetInstanceIDByNodeName +// GetInstanceIDByNodeName indicates an expected call of GetInstanceIDByNodeName. func (mr *MockVMSetMockRecorder) GetInstanceIDByNodeName(name interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInstanceIDByNodeName", reflect.TypeOf((*MockVMSet)(nil).GetInstanceIDByNodeName), name) } -// GetInstanceTypeByNodeName mocks base method +// GetInstanceTypeByNodeName mocks base method. func (m *MockVMSet) GetInstanceTypeByNodeName(name string) (string, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetInstanceTypeByNodeName", name) @@ -77,13 +77,13 @@ func (m *MockVMSet) GetInstanceTypeByNodeName(name string) (string, error) { return ret0, ret1 } -// GetInstanceTypeByNodeName indicates an expected call of GetInstanceTypeByNodeName +// GetInstanceTypeByNodeName indicates an expected call of GetInstanceTypeByNodeName. func (mr *MockVMSetMockRecorder) GetInstanceTypeByNodeName(name interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInstanceTypeByNodeName", reflect.TypeOf((*MockVMSet)(nil).GetInstanceTypeByNodeName), name) } -// GetIPByNodeName mocks base method +// GetIPByNodeName mocks base method. func (m *MockVMSet) GetIPByNodeName(name string) (string, string, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetIPByNodeName", name) @@ -93,13 +93,13 @@ func (m *MockVMSet) GetIPByNodeName(name string) (string, string, error) { return ret0, ret1, ret2 } -// GetIPByNodeName indicates an expected call of GetIPByNodeName +// GetIPByNodeName indicates an expected call of GetIPByNodeName. func (mr *MockVMSetMockRecorder) GetIPByNodeName(name interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetIPByNodeName", reflect.TypeOf((*MockVMSet)(nil).GetIPByNodeName), name) } -// GetPrimaryInterface mocks base method +// GetPrimaryInterface mocks base method. func (m *MockVMSet) GetPrimaryInterface(nodeName string) (network.Interface, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetPrimaryInterface", nodeName) @@ -108,13 +108,13 @@ func (m *MockVMSet) GetPrimaryInterface(nodeName string) (network.Interface, err return ret0, ret1 } -// GetPrimaryInterface indicates an expected call of GetPrimaryInterface +// GetPrimaryInterface indicates an expected call of GetPrimaryInterface. func (mr *MockVMSetMockRecorder) GetPrimaryInterface(nodeName interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPrimaryInterface", reflect.TypeOf((*MockVMSet)(nil).GetPrimaryInterface), nodeName) } -// GetNodeNameByProviderID mocks base method +// GetNodeNameByProviderID mocks base method. func (m *MockVMSet) GetNodeNameByProviderID(providerID string) (types.NodeName, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetNodeNameByProviderID", providerID) @@ -123,13 +123,13 @@ func (m *MockVMSet) GetNodeNameByProviderID(providerID string) (types.NodeName, return ret0, ret1 } -// GetNodeNameByProviderID indicates an expected call of GetNodeNameByProviderID +// GetNodeNameByProviderID indicates an expected call of GetNodeNameByProviderID. func (mr *MockVMSetMockRecorder) GetNodeNameByProviderID(providerID interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNodeNameByProviderID", reflect.TypeOf((*MockVMSet)(nil).GetNodeNameByProviderID), providerID) } -// GetZoneByNodeName mocks base method +// GetZoneByNodeName mocks base method. func (m *MockVMSet) GetZoneByNodeName(name string) (cloudprovider.Zone, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetZoneByNodeName", name) @@ -138,13 +138,13 @@ func (m *MockVMSet) GetZoneByNodeName(name string) (cloudprovider.Zone, error) { return ret0, ret1 } -// GetZoneByNodeName indicates an expected call of GetZoneByNodeName +// GetZoneByNodeName indicates an expected call of GetZoneByNodeName. func (mr *MockVMSetMockRecorder) GetZoneByNodeName(name interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetZoneByNodeName", reflect.TypeOf((*MockVMSet)(nil).GetZoneByNodeName), name) } -// GetPrimaryVMSetName mocks base method +// GetPrimaryVMSetName mocks base method. func (m *MockVMSet) GetPrimaryVMSetName() string { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetPrimaryVMSetName") @@ -152,13 +152,13 @@ func (m *MockVMSet) GetPrimaryVMSetName() string { return ret0 } -// GetPrimaryVMSetName indicates an expected call of GetPrimaryVMSetName +// GetPrimaryVMSetName indicates an expected call of GetPrimaryVMSetName. func (mr *MockVMSetMockRecorder) GetPrimaryVMSetName() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPrimaryVMSetName", reflect.TypeOf((*MockVMSet)(nil).GetPrimaryVMSetName)) } -// GetVMSetNames mocks base method +// GetVMSetNames mocks base method. func (m *MockVMSet) GetVMSetNames(service *v1.Service, nodes []*v1.Node) (*[]string, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetVMSetNames", service, nodes) @@ -167,13 +167,13 @@ func (m *MockVMSet) GetVMSetNames(service *v1.Service, nodes []*v1.Node) (*[]str return ret0, ret1 } -// GetVMSetNames indicates an expected call of GetVMSetNames +// GetVMSetNames indicates an expected call of GetVMSetNames. func (mr *MockVMSetMockRecorder) GetVMSetNames(service, nodes interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVMSetNames", reflect.TypeOf((*MockVMSet)(nil).GetVMSetNames), service, nodes) } -// EnsureHostsInPool mocks base method +// EnsureHostsInPool mocks base method. func (m *MockVMSet) EnsureHostsInPool(service *v1.Service, nodes []*v1.Node, backendPoolID, vmSetName string, isInternal bool) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "EnsureHostsInPool", service, nodes, backendPoolID, vmSetName, isInternal) @@ -181,13 +181,13 @@ func (m *MockVMSet) EnsureHostsInPool(service *v1.Service, nodes []*v1.Node, bac return ret0 } -// EnsureHostsInPool indicates an expected call of EnsureHostsInPool +// EnsureHostsInPool indicates an expected call of EnsureHostsInPool. func (mr *MockVMSetMockRecorder) EnsureHostsInPool(service, nodes, backendPoolID, vmSetName, isInternal interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EnsureHostsInPool", reflect.TypeOf((*MockVMSet)(nil).EnsureHostsInPool), service, nodes, backendPoolID, vmSetName, isInternal) } -// EnsureHostInPool mocks base method +// EnsureHostInPool mocks base method. func (m *MockVMSet) EnsureHostInPool(service *v1.Service, nodeName types.NodeName, backendPoolID, vmSetName string, isInternal bool) (string, string, string, *compute.VirtualMachineScaleSetVM, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "EnsureHostInPool", service, nodeName, backendPoolID, vmSetName, isInternal) @@ -199,13 +199,13 @@ func (m *MockVMSet) EnsureHostInPool(service *v1.Service, nodeName types.NodeNam return ret0, ret1, ret2, ret3, ret4 } -// EnsureHostInPool indicates an expected call of EnsureHostInPool +// EnsureHostInPool indicates an expected call of EnsureHostInPool. func (mr *MockVMSetMockRecorder) EnsureHostInPool(service, nodeName, backendPoolID, vmSetName, isInternal interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EnsureHostInPool", reflect.TypeOf((*MockVMSet)(nil).EnsureHostInPool), service, nodeName, backendPoolID, vmSetName, isInternal) } -// EnsureBackendPoolDeleted mocks base method +// EnsureBackendPoolDeleted mocks base method. func (m *MockVMSet) EnsureBackendPoolDeleted(service *v1.Service, backendPoolID, vmSetName string, backendAddressPools *[]network.BackendAddressPool) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "EnsureBackendPoolDeleted", service, backendPoolID, vmSetName, backendAddressPools) @@ -213,13 +213,13 @@ func (m *MockVMSet) EnsureBackendPoolDeleted(service *v1.Service, backendPoolID, return ret0 } -// EnsureBackendPoolDeleted indicates an expected call of EnsureBackendPoolDeleted +// EnsureBackendPoolDeleted indicates an expected call of EnsureBackendPoolDeleted. func (mr *MockVMSetMockRecorder) EnsureBackendPoolDeleted(service, backendPoolID, vmSetName, backendAddressPools interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EnsureBackendPoolDeleted", reflect.TypeOf((*MockVMSet)(nil).EnsureBackendPoolDeleted), service, backendPoolID, vmSetName, backendAddressPools) } -// AttachDisk mocks base method +// AttachDisk mocks base method. func (m *MockVMSet) AttachDisk(isManagedDisk bool, diskName, diskURI string, nodeName types.NodeName, lun int32, cachingMode compute.CachingTypes, diskEncryptionSetID string, writeAcceleratorEnabled bool) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "AttachDisk", isManagedDisk, diskName, diskURI, nodeName, lun, cachingMode, diskEncryptionSetID, writeAcceleratorEnabled) @@ -227,13 +227,13 @@ func (m *MockVMSet) AttachDisk(isManagedDisk bool, diskName, diskURI string, nod return ret0 } -// AttachDisk indicates an expected call of AttachDisk +// AttachDisk indicates an expected call of AttachDisk. func (mr *MockVMSetMockRecorder) AttachDisk(isManagedDisk, diskName, diskURI, nodeName, lun, cachingMode, diskEncryptionSetID, writeAcceleratorEnabled interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AttachDisk", reflect.TypeOf((*MockVMSet)(nil).AttachDisk), isManagedDisk, diskName, diskURI, nodeName, lun, cachingMode, diskEncryptionSetID, writeAcceleratorEnabled) } -// DetachDisk mocks base method +// DetachDisk mocks base method. func (m *MockVMSet) DetachDisk(diskName, diskURI string, nodeName types.NodeName) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "DetachDisk", diskName, diskURI, nodeName) @@ -241,28 +241,28 @@ func (m *MockVMSet) DetachDisk(diskName, diskURI string, nodeName types.NodeName return ret0 } -// DetachDisk indicates an expected call of DetachDisk +// DetachDisk indicates an expected call of DetachDisk. func (mr *MockVMSetMockRecorder) DetachDisk(diskName, diskURI, nodeName interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DetachDisk", reflect.TypeOf((*MockVMSet)(nil).DetachDisk), diskName, diskURI, nodeName) } -// GetDataDisks mocks base method -func (m *MockVMSet) GetDataDisks(nodeName types.NodeName, string cache.AzureCacheReadType) ([]compute.DataDisk, error) { +// GetDataDisks mocks base method. +func (m *MockVMSet) GetDataDisks(nodeName types.NodeName, crt cache.AzureCacheReadType) ([]compute.DataDisk, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetDataDisks", nodeName, string) + ret := m.ctrl.Call(m, "GetDataDisks", nodeName, crt) ret0, _ := ret[0].([]compute.DataDisk) ret1, _ := ret[1].(error) return ret0, ret1 } -// GetDataDisks indicates an expected call of GetDataDisks -func (mr *MockVMSetMockRecorder) GetDataDisks(nodeName, string interface{}) *gomock.Call { +// GetDataDisks indicates an expected call of GetDataDisks. +func (mr *MockVMSetMockRecorder) GetDataDisks(nodeName, crt interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDataDisks", reflect.TypeOf((*MockVMSet)(nil).GetDataDisks), nodeName, string) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDataDisks", reflect.TypeOf((*MockVMSet)(nil).GetDataDisks), nodeName, crt) } -// GetPowerStatusByNodeName mocks base method +// GetPowerStatusByNodeName mocks base method. func (m *MockVMSet) GetPowerStatusByNodeName(name string) (string, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetPowerStatusByNodeName", name) @@ -271,13 +271,28 @@ func (m *MockVMSet) GetPowerStatusByNodeName(name string) (string, error) { return ret0, ret1 } -// GetPowerStatusByNodeName indicates an expected call of GetPowerStatusByNodeName +// GetPowerStatusByNodeName indicates an expected call of GetPowerStatusByNodeName. func (mr *MockVMSetMockRecorder) GetPowerStatusByNodeName(name interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPowerStatusByNodeName", reflect.TypeOf((*MockVMSet)(nil).GetPowerStatusByNodeName), name) } -// GetPrivateIPsByNodeName mocks base method +// GetProvisioningStateByNodeName mocks base method. +func (m *MockVMSet) GetProvisioningStateByNodeName(name string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetProvisioningStateByNodeName", name) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetProvisioningStateByNodeName indicates an expected call of GetProvisioningStateByNodeName. +func (mr *MockVMSetMockRecorder) GetProvisioningStateByNodeName(name interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProvisioningStateByNodeName", reflect.TypeOf((*MockVMSet)(nil).GetProvisioningStateByNodeName), name) +} + +// GetPrivateIPsByNodeName mocks base method. func (m *MockVMSet) GetPrivateIPsByNodeName(name string) ([]string, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetPrivateIPsByNodeName", name) @@ -286,7 +301,7 @@ func (m *MockVMSet) GetPrivateIPsByNodeName(name string) ([]string, error) { return ret0, ret1 } -// GetPrivateIPsByNodeName indicates an expected call of GetPrivateIPsByNodeName +// GetPrivateIPsByNodeName indicates an expected call of GetPrivateIPsByNodeName. func (mr *MockVMSetMockRecorder) GetPrivateIPsByNodeName(name interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPrivateIPsByNodeName", reflect.TypeOf((*MockVMSet)(nil).GetPrivateIPsByNodeName), name)