Skip to content

Commit

Permalink
Fix possible nil VMSS VM in VMSS cache If VMSS VM in cache is nil, it…
Browse files Browse the repository at this point in the history
… leads to invalid memory panic. Fix the bug and add UT.

Signed-off-by: Zhecheng Li <zhechengli@microsoft.com>
  • Loading branch information
lzhecheng authored and k8s-infra-cherrypick-robot committed Apr 28, 2022
1 parent 111ca7c commit f57bd1b
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 2 deletions.
8 changes: 6 additions & 2 deletions pkg/provider/azure_vmss.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,12 @@ func (ss *ScaleSet) getVmssVMByNodeIdentity(node *nodeIdentity, crt azcache.Azur
}

virtualMachines := cached.(*sync.Map)
if vm, ok := virtualMachines.Load(nodeName); ok {
result := vm.(*vmssVirtualMachinesEntry)
if entry, ok := virtualMachines.Load(nodeName); ok {
result := entry.(*vmssVirtualMachinesEntry)
if result.virtualMachine == nil {
klog.Warningf("failed to get VM with vmssVirtualMachinesEntry on Node %q", nodeName)
return nil, false, nil
}
found = true
return result.vmssName, result.instanceID, result.virtualMachine, found, nil
}
Expand Down
83 changes: 83 additions & 0 deletions pkg/provider/azure_vmss_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package provider
import (
"fmt"
"strings"
"sync"
"testing"

"github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2021-07-01/compute"
Expand Down Expand Up @@ -905,6 +906,88 @@ func TestGetVmssVMByInstanceID(t *testing.T) {
}
}

func TestGetVmssVMByNodeIdentity(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

testCases := []struct {
description string
instanceID string
vmList []string
goneVMList []string
expectedErr error
goneVMExpectedErr error
}{
{
description: "getVmssVMByNodeIdentity should return the correct VMSS VM",
vmList: []string{"vmss-vm-000000"},
goneVMList: []string{},
},
{
description: "getVmssVMByNodeIdentity should not panic with a gone VMSS VM but cache for the VMSS VM entry still exists",
vmList: []string{"vmss-vm-000000"},
goneVMList: []string{"vmss-vm-000001"},
goneVMExpectedErr: cloudprovider.InstanceNotFound,
},
}

for _, test := range testCases {
t.Run(test.description, func(t *testing.T) {
ss, err := NewTestScaleSet(ctrl)
assert.NoError(t, err, "unexpected error when creating test VMSS")

expectedVMSS := compute.VirtualMachineScaleSet{
Name: to.StringPtr(testVMSSName),
VirtualMachineScaleSetProperties: &compute.VirtualMachineScaleSetProperties{
VirtualMachineProfile: &compute.VirtualMachineScaleSetVMProfile{},
},
}
mockVMSSClient := ss.cloud.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface)
mockVMSSClient.EXPECT().List(gomock.Any(), ss.ResourceGroup).Return([]compute.VirtualMachineScaleSet{expectedVMSS}, nil).AnyTimes()

expectedVMSSVMs, _, _ := buildTestVirtualMachineEnv(ss.cloud, testVMSSName, "", 0, test.vmList, "", false)
mockVMSSVMClient := ss.cloud.VirtualMachineScaleSetVMsClient.(*mockvmssvmclient.MockInterface)
mockVMSSVMClient.EXPECT().List(gomock.Any(), ss.ResourceGroup, testVMSSName, gomock.Any()).Return(expectedVMSSVMs, nil).AnyTimes()

// Make some nil VMSS VM in cache.
cacheKey, cache, err := ss.getVMSSVMCache(ss.ResourceGroup, testVMSSName)
assert.Nil(t, err)
cached, err := cache.Get(cacheKey, azcache.CacheReadTypeDefault)
assert.Nil(t, err)
virtualMachines := cached.(*sync.Map)
for _, vm := range test.goneVMList {
entry := vmssVirtualMachinesEntry{
resourceGroup: ss.ResourceGroup,
vmssName: testVMSSName,
}
virtualMachines.Store(vm, &entry)
}

for i := 0; i < len(test.vmList); i++ {
node := nodeIdentity{ss.ResourceGroup, testVMSSName, test.vmList[i]}
vm, err := ss.getVmssVMByNodeIdentity(&node, azcache.CacheReadTypeDefault)
assert.Equal(t, test.expectedErr, err)
assert.Equal(t, *virtualmachine.FromVirtualMachineScaleSetVM(&expectedVMSSVMs[i], virtualmachine.ByVMSS(testVMSSName)), *vm)
}
for i := 0; i < len(test.goneVMList); i++ {
node := nodeIdentity{ss.ResourceGroup, testVMSSName, test.goneVMList[i]}
_, err := ss.getVmssVMByNodeIdentity(&node, azcache.CacheReadTypeDefault)
assert.Equal(t, test.goneVMExpectedErr, err)
}

cacheKey, cache, err = ss.getVMSSVMCache(ss.ResourceGroup, testVMSSName)
assert.Nil(t, err)
cached, err = cache.Get(cacheKey, azcache.CacheReadTypeDefault)
assert.Nil(t, err)
virtualMachines = cached.(*sync.Map)
for _, vm := range test.goneVMList {
_, ok := virtualMachines.Load(vm)
assert.False(t, ok)
}
})
}
}

func TestGetInstanceTypeByNodeName(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
Expand Down

0 comments on commit f57bd1b

Please sign in to comment.