Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
cvvz committed Dec 27, 2023
1 parent 09e5b67 commit 44c4812
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 9 deletions.
14 changes: 8 additions & 6 deletions pkg/azurefile/nodeserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func (d *Driver) NodePublishVolume(ctx context.Context, req *csi.NodePublishVolu
context := req.GetVolumeContext()
if context != nil {
// token request
if context[serviceAccountTokenField] != "" && !isClientIDEmpty(context) {
if context[serviceAccountTokenField] != "" && hasClientID(context) {
klog.V(2).Infof("NodePublishVolume: volume(%s) mount on %s with service account token, VolumeContext: %v", volumeID, target, context)
_, err := d.NodeStageVolume(ctx, &csi.NodeStageVolumeRequest{
StagingTargetPath: target,
Expand Down Expand Up @@ -169,7 +169,7 @@ func (d *Driver) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRe
volumeID := req.GetVolumeId()
context := req.GetVolumeContext()

if !isClientIDEmpty(context) && context[serviceAccountTokenField] == "" {
if hasClientID(context) && context[serviceAccountTokenField] == "" {
klog.V(2).Infof("Skip NodeStageVolume for volume(%s) since clientID is provided but service account token is empty", volumeID)
return &csi.NodeStageVolumeResponse{}, nil
}
Expand Down Expand Up @@ -613,9 +613,11 @@ func checkGidPresentInMountFlags(mountFlags []string) bool {
return false
}

func isClientIDEmpty(context map[string]string) bool {
if context[clientIDField] != "" || context[strings.ToLower(clientIDField)] != "" {
return false
func hasClientID(context map[string]string) bool {
for k, v := range context {
if strings.EqualFold(k, clientIDField) && v != "" {
return true
}
}
return true
return false
}
57 changes: 54 additions & 3 deletions pkg/azurefile/nodeserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ import (
"syscall"
"testing"

"sigs.k8s.io/azurefile-csi-driver/test/utils/testutil"

azure2 "github.com/Azure/go-autorest/autorest/azure"
"github.com/container-storage-interface/spec/lib/go/csi"
"github.com/stretchr/testify/assert"
Expand All @@ -38,7 +36,7 @@ import (
mount "k8s.io/mount-utils"
"k8s.io/utils/exec"
testingexec "k8s.io/utils/exec/testing"

"sigs.k8s.io/azurefile-csi-driver/test/utils/testutil"
azure "sigs.k8s.io/cloud-provider-azure/pkg/provider"
)

Expand Down Expand Up @@ -1069,3 +1067,56 @@ func makeFakeOutput(output string, err error) testingexec.FakeAction {
return []byte(o), nil, err
}
}

func Test_hasClientID(t *testing.T) {
type args struct {
context map[string]string
}
tests := []struct {
name string
args args
want bool
}{
{
name: "has client id",
args: args{
context: map[string]string{
clientIDField: "test-client-id",
},
},
want: true,
},
{
name: "case not sensitive client id",
args: args{
context: map[string]string{
"ClientId": "test-client-id",
},
},
want: true,
},
{
name: "no client id",
args: args{
context: map[string]string{},
},
want: false,
},
{
name: "client id empty",
args: args{
context: map[string]string{
clientIDField: "",
},
},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := hasClientID(tt.args.context); got != tt.want {
t.Errorf("hasClientID() = %v, want %v", got, tt.want)
}
})
}
}

0 comments on commit 44c4812

Please sign in to comment.