Skip to content

Commit

Permalink
Merge pull request #104449 from feiskyer/automated-cherry-pick-of-#10…
Browse files Browse the repository at this point in the history
…4384-#104382-upstream-release-1.19

Automated cherry pick of #104384: fix: skip case sensitivity when checking Azure NSG rules
#104382: fix: ensure InstanceShutdownByProviderID return false for
  • Loading branch information
k8s-ci-robot committed Sep 9, 2021
2 parents 44a3aa9 + 8eb2fbc commit 29af011
Show file tree
Hide file tree
Showing 11 changed files with 444 additions and 60 deletions.
19 changes: 17 additions & 2 deletions staging/src/k8s.io/legacy-cloud-providers/azure/azure_instances.go
Expand Up @@ -20,6 +20,7 @@ package azure

import (
"context"
"errors"
"fmt"
"os"
"strings"
Expand All @@ -29,6 +30,8 @@ import (
cloudprovider "k8s.io/cloud-provider"
"k8s.io/klog/v2"
azcache "k8s.io/legacy-cloud-providers/azure/cache"

"github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2019-12-01/compute"
)

const (
Expand Down Expand Up @@ -233,10 +236,22 @@ func (az *Cloud) InstanceShutdownByProviderID(ctx context.Context, providerID st

return false, err
}
klog.V(5).Infof("InstanceShutdownByProviderID gets power status %q for node %q", powerStatus, nodeName)
klog.V(3).Infof("InstanceShutdownByProviderID gets power status %q for node %q", powerStatus, nodeName)

provisioningState, err := az.VMSet.GetProvisioningStateByNodeName(string(nodeName))
if err != nil {
// Returns false, so the controller manager will continue to check InstanceExistsByProviderID().
if errors.Is(err, cloudprovider.InstanceNotFound) {
return false, nil
}

return false, err
}
klog.V(3).Infof("InstanceShutdownByProviderID gets provisioning state %q for node %q", provisioningState, nodeName)

status := strings.ToLower(powerStatus)
return status == vmPowerStateStopped || status == vmPowerStateDeallocated || status == vmPowerStateDeallocating, nil
provisioningSucceeded := strings.EqualFold(strings.ToLower(provisioningState), strings.ToLower(string(compute.ProvisioningStateSucceeded)))
return provisioningSucceeded && (status == vmPowerStateStopped || status == vmPowerStateDeallocated || status == vmPowerStateDeallocating), nil
}

func (az *Cloud) isCurrentInstance(name types.NodeName, metadataVMName string) (bool, error) {
Expand Down
Expand Up @@ -63,6 +63,7 @@ func setTestVirtualMachines(c *Cloud, vmList map[string]string, isDataDisksFull
},
}
vm.VirtualMachineProperties = &compute.VirtualMachineProperties{
ProvisioningState: to.StringPtr(string(compute.ProvisioningStateSucceeded)),
HardwareProfile: &compute.HardwareProfile{
VMSize: compute.VirtualMachineSizeTypesStandardA0,
},
Expand Down Expand Up @@ -252,12 +253,13 @@ func TestInstanceID(t *testing.T) {

func TestInstanceShutdownByProviderID(t *testing.T) {
testcases := []struct {
name string
vmList map[string]string
nodeName string
providerID string
expected bool
expectedErrMsg error
name string
vmList map[string]string
nodeName string
providerID string
provisioningState string
expected bool
expectedErrMsg error
}{
{
name: "InstanceShutdownByProviderID should return false if the vm is in PowerState/Running status",
Expand Down Expand Up @@ -294,6 +296,7 @@ func TestInstanceShutdownByProviderID(t *testing.T) {
providerID: "azure:///subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/vm5",
expected: true,
},

{
name: "InstanceShutdownByProviderID should return false if the vm is in PowerState/Stopping status",
vmList: map[string]string{"vm6": "PowerState/Stopping"},
Expand All @@ -315,13 +318,23 @@ func TestInstanceShutdownByProviderID(t *testing.T) {
providerID: "azure:///subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/vm8",
expected: false,
},
{
name: "InstanceShutdownByProviderID should return false if the vm is in PowerState/Stopped state with Creating provisioning state",
vmList: map[string]string{"vm9": "PowerState/Stopped"},
nodeName: "vm9",
provisioningState: "Creating",
providerID: "azure:///subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/vm9",
expected: false,
},
{
name: "InstanceShutdownByProviderID should report error if providerID is null",
nodeName: "vmm",
expected: false,
},
{
name: "InstanceShutdownByProviderID should report error if providerID is invalid",
providerID: "azure:///subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Compute/VM/vm9",
providerID: "azure:///subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Compute/VM/vm10",
nodeName: "vm10",
expected: false,
expectedErrMsg: fmt.Errorf("error splitting providerID"),
},
Expand All @@ -332,11 +345,14 @@ func TestInstanceShutdownByProviderID(t *testing.T) {
for _, test := range testcases {
cloud := GetTestCloud(ctrl)
expectedVMs := setTestVirtualMachines(cloud, test.vmList, false)
if test.provisioningState != "" {
expectedVMs[0].ProvisioningState = to.StringPtr(test.provisioningState)
}
mockVMsClient := cloud.VirtualMachinesClient.(*mockvmclient.MockInterface)
for _, vm := range expectedVMs {
mockVMsClient.EXPECT().Get(gomock.Any(), cloud.ResourceGroup, *vm.Name, gomock.Any()).Return(vm, nil).AnyTimes()
}
mockVMsClient.EXPECT().Get(gomock.Any(), cloud.ResourceGroup, "vm8", gomock.Any()).Return(compute.VirtualMachine{}, &retry.Error{HTTPStatusCode: http.StatusNotFound, RawError: cloudprovider.InstanceNotFound}).AnyTimes()
mockVMsClient.EXPECT().Get(gomock.Any(), cloud.ResourceGroup, test.nodeName, gomock.Any()).Return(compute.VirtualMachine{}, &retry.Error{HTTPStatusCode: http.StatusNotFound, RawError: cloudprovider.InstanceNotFound}).AnyTimes()

hasShutdown, err := cloud.InstanceShutdownByProviderID(context.Background(), test.providerID)
assert.Equal(t, test.expectedErrMsg, err, test.name)
Expand Down
Expand Up @@ -1230,18 +1230,18 @@ func (az *Cloud) reconcileSecurityGroup(clusterName string, service *v1.Service,
sharedRuleName := az.getSecurityRuleName(service, port, sourceAddressPrefix)
sharedIndex, sharedRule, sharedRuleFound := findSecurityRuleByName(updatedRules, sharedRuleName)
if !sharedRuleFound {
klog.V(4).Infof("Expected to find shared rule %s for service %s being deleted, but did not", sharedRuleName, service.Name)
return nil, fmt.Errorf("Expected to find shared rule %s for service %s being deleted, but did not", sharedRuleName, service.Name)
klog.V(4).Infof("Didn't find shared rule %s for service %s", sharedRuleName, service.Name)
continue
}
if sharedRule.DestinationAddressPrefixes == nil {
klog.V(4).Infof("Expected to have array of destinations in shared rule for service %s being deleted, but did not", service.Name)
return nil, fmt.Errorf("Expected to have array of destinations in shared rule for service %s being deleted, but did not", service.Name)
klog.V(4).Infof("Didn't find DestinationAddressPrefixes in shared rule for service %s", service.Name)
continue
}
existingPrefixes := *sharedRule.DestinationAddressPrefixes
addressIndex, found := findIndex(existingPrefixes, destinationIPAddress)
if !found {
klog.V(4).Infof("Expected to find destination address %s in shared rule %s for service %s being deleted, but did not", destinationIPAddress, sharedRuleName, service.Name)
return nil, fmt.Errorf("Expected to find destination address %s in shared rule %s for service %s being deleted, but did not", destinationIPAddress, sharedRuleName, service.Name)
klog.V(4).Infof("Didn't find destination address %v in shared rule %s for service %s", destinationIPAddress, sharedRuleName, service.Name)
continue
}
if len(existingPrefixes) == 1 {
updatedRules = append(updatedRules[:sharedIndex], updatedRules[sharedIndex+1:]...)
Expand Down Expand Up @@ -1658,7 +1658,7 @@ func findSecurityRule(rules []network.SecurityRule, rule network.SecurityRule) b
if !strings.EqualFold(to.String(existingRule.Name), to.String(rule.Name)) {
continue
}
if existingRule.Protocol != rule.Protocol {
if !strings.EqualFold(string(existingRule.Protocol), string(rule.Protocol)) {
continue
}
if !strings.EqualFold(to.String(existingRule.SourcePortRange), to.String(rule.SourcePortRange)) {
Expand All @@ -1675,10 +1675,10 @@ func findSecurityRule(rules []network.SecurityRule, rule network.SecurityRule) b
continue
}
}
if existingRule.Access != rule.Access {
if !strings.EqualFold(string(existingRule.Access), string(rule.Access)) {
continue
}
if existingRule.Direction != rule.Direction {
if !strings.EqualFold(string(existingRule.Direction), string(rule.Direction)) {
continue
}
return true
Expand Down
Expand Up @@ -1943,6 +1943,36 @@ func TestReconcileSecurityGroup(t *testing.T) {
},
},
},
{
desc: "reconcileSecurityGroup shall create shared sgs for service with azure-shared-securityrule annotations",
service: getTestService("test1", v1.ProtocolTCP, map[string]string{ServiceAnnotationSharedSecurityRule: "true"}, true, 80),
existingSgs: map[string]network.SecurityGroup{"nsg": {
Name: to.StringPtr("nsg"),
SecurityGroupPropertiesFormat: &network.SecurityGroupPropertiesFormat{},
}},
lbIP: to.StringPtr("1.2.3.4"),
wantLb: true,
expectedSg: &network.SecurityGroup{
Name: to.StringPtr("nsg"),
SecurityGroupPropertiesFormat: &network.SecurityGroupPropertiesFormat{
SecurityRules: &[]network.SecurityRule{
{
Name: to.StringPtr("shared-TCP-80-Internet"),
SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{
Protocol: network.SecurityRuleProtocol("Tcp"),
SourcePortRange: to.StringPtr("*"),
DestinationPortRange: to.StringPtr("80"),
SourceAddressPrefix: to.StringPtr("Internet"),
DestinationAddressPrefixes: to.StringSlicePtr([]string{"1.2.3.4"}),
Access: network.SecurityRuleAccess("Allow"),
Priority: to.Int32Ptr(500),
Direction: network.SecurityRuleDirection("Inbound"),
},
},
},
},
},
},
}

for i, test := range testCases {
Expand Down
14 changes: 14 additions & 0 deletions staging/src/k8s.io/legacy-cloud-providers/azure/azure_standard.go
Expand Up @@ -444,6 +444,20 @@ func (as *availabilitySet) GetPowerStatusByNodeName(name string) (powerState str
return vmPowerStateStopped, nil
}

// GetProvisioningStateByNodeName returns the provisioningState for the specified node.
func (as *availabilitySet) GetProvisioningStateByNodeName(name string) (provisioningState string, err error) {
vm, err := as.getVirtualMachine(types.NodeName(name), azcache.CacheReadTypeDefault)
if err != nil {
return provisioningState, err
}

if vm.VirtualMachineProperties == nil || vm.VirtualMachineProperties.ProvisioningState == nil {
return provisioningState, nil
}

return to.String(vm.VirtualMachineProperties.ProvisioningState), nil
}

// GetNodeNameByProviderID gets the node name by provider ID.
func (as *availabilitySet) GetNodeNameByProviderID(providerID string) (types.NodeName, error) {
// NodeName is part of providerID for standard instances.
Expand Down
Expand Up @@ -985,6 +985,69 @@ func TestGetStandardVMPowerStatusByNodeName(t *testing.T) {
}
}

func TestGetStandardVMProvisioningStateByNodeName(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
cloud := GetTestCloud(ctrl)

testcases := []struct {
name string
nodeName string
vm compute.VirtualMachine
expectedProvisioningState string
getErr *retry.Error
expectedErrMsg error
}{
{
name: "GetProvisioningStateByNodeName should report error if node don't exist",
nodeName: "vm1",
vm: compute.VirtualMachine{},
getErr: &retry.Error{
HTTPStatusCode: http.StatusNotFound,
RawError: cloudprovider.InstanceNotFound,
},
expectedErrMsg: fmt.Errorf("instance not found"),
},
{
name: "GetProvisioningStateByNodeName should return Succeeded for running VM",
nodeName: "vm2",
vm: compute.VirtualMachine{
Name: to.StringPtr("vm2"),
VirtualMachineProperties: &compute.VirtualMachineProperties{
ProvisioningState: to.StringPtr("Succeeded"),
InstanceView: &compute.VirtualMachineInstanceView{
Statuses: &[]compute.InstanceViewStatus{
{
Code: to.StringPtr("PowerState/Running"),
},
},
},
},
},
expectedProvisioningState: "Succeeded",
},
{
name: "GetProvisioningStateByNodeName should return empty string when vm.ProvisioningState is nil",
nodeName: "vm3",
vm: compute.VirtualMachine{
Name: to.StringPtr("vm3"),
VirtualMachineProperties: &compute.VirtualMachineProperties{
ProvisioningState: nil,
},
},
expectedProvisioningState: "",
},
}
for _, test := range testcases {
mockVMClient := cloud.VirtualMachinesClient.(*mockvmclient.MockInterface)
mockVMClient.EXPECT().Get(gomock.Any(), cloud.ResourceGroup, test.nodeName, gomock.Any()).Return(test.vm, test.getErr).AnyTimes()

provisioningState, err := cloud.VMSet.GetProvisioningStateByNodeName(test.nodeName)
assert.Equal(t, test.expectedErrMsg, err, test.name)
assert.Equal(t, test.expectedProvisioningState, provisioningState, test.name)
}
}

func TestGetStandardVMZoneByNodeName(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
Expand Down

0 comments on commit 29af011

Please sign in to comment.