diff --git a/pkg/azurefile/azurefile.go b/pkg/azurefile/azurefile.go index e360a3f9fe..e482d27942 100644 --- a/pkg/azurefile/azurefile.go +++ b/pkg/azurefile/azurefile.go @@ -1036,7 +1036,6 @@ func (d *Driver) copyFileShare(ctx context.Context, req *csi.CreateVolumeRequest out, copyErr := cmd.CombinedOutput() if accountSASToken == "" && strings.Contains(string(out), authorizationPermissionMismatch) && copyErr != nil { klog.Warningf("azcopy list failed with AuthorizationPermissionMismatch error, should assign \"Storage File Data SMB Share Elevated Contributor\" role to controller identity, fall back to use sas token, original output: %v", string(out)) - d.azcopySasTokenCache.Set(accountName, "") var sasToken string if sasToken, _, err = d.getAzcopyAuth(ctx, accountName, "", storageEndpointSuffix, accountOptions, secrets, secretName, secretNamespace, true); err != nil { return err diff --git a/pkg/azurefile/controllerserver.go b/pkg/azurefile/controllerserver.go index d717598419..b865f17041 100644 --- a/pkg/azurefile/controllerserver.go +++ b/pkg/azurefile/controllerserver.go @@ -1234,42 +1234,44 @@ func (d *Driver) authorizeAzcopyWithIdentity() ([]string, error) { // 4. parameter useSasToken is true func (d *Driver) getAzcopyAuth(ctx context.Context, accountName, accountKey, storageEndpointSuffix string, accountOptions *azure.AccountOptions, secrets map[string]string, secretName, secretNamespace string, useSasToken bool) (string, []string, error) { var authAzcopyEnv []string + var err error if !useSasToken && len(secrets) == 0 && len(secretName) == 0 { - var err error + // search in cache first + if cache, err := d.azcopySasTokenCache.Get(accountName, azcache.CacheReadTypeDefault); err == nil && cache != nil { + klog.V(2).Infof("use sas token for account(%s) since this account is found in azcopySasTokenCache", accountName) + return cache.(string), nil, nil + } authAzcopyEnv, err = d.authorizeAzcopyWithIdentity() if err != nil { klog.Warningf("failed to authorize azcopy with identity, error: %v", err) - } else { - if len(authAzcopyEnv) > 0 { - // search in cache first - cache, err := d.azcopySasTokenCache.Get(accountName, azcache.CacheReadTypeDefault) - if err != nil { - return "", nil, fmt.Errorf("get(%s) from azcopySasTokenCache failed with error: %v", accountName, err) - } - if cache != nil { - klog.V(2).Infof("use sas token for account(%s) since this account is found in azcopySasTokenCache", accountName) - useSasToken = true - } - } } } if len(secrets) > 0 || len(secretName) > 0 || len(authAzcopyEnv) == 0 || useSasToken { - var err error if accountKey == "" { if accountKey, err = d.GetStorageAccesskey(ctx, accountOptions, secrets, secretName, secretNamespace); err != nil { return "", nil, err } } klog.V(2).Infof("generate sas token for account(%s)", accountName) - sasToken, err := generateSASToken(accountName, accountKey, storageEndpointSuffix, d.sasTokenExpirationMinutes) + sasToken, err := d.generateSASToken(accountName, accountKey, storageEndpointSuffix, d.sasTokenExpirationMinutes) return sasToken, nil, err } return "", authAzcopyEnv, nil } // generateSASToken generate a sas token for storage account -func generateSASToken(accountName, accountKey, storageEndpointSuffix string, expiryTime int) (string, error) { +func (d *Driver) generateSASToken(accountName, accountKey, storageEndpointSuffix string, expiryTime int) (string, error) { + // search in cache first + cache, err := d.azcopySasTokenCache.Get(accountName, azcache.CacheReadTypeDefault) + if err != nil { + return "", fmt.Errorf("get(%s) from azcopySasTokenCache failed with error: %v", accountName, err) + } + if cache != nil { + klog.V(2).Infof("use sas token for account(%s) since this account is found in azcopySasTokenCache", accountName) + return cache.(string), nil + } + credential, err := service.NewSharedKeyCredential(accountName, accountKey) if err != nil { return "", status.Errorf(codes.Internal, fmt.Sprintf("failed to generate sas token in creating new shared key credential, accountName: %s, err: %s", accountName, err.Error())) @@ -1290,5 +1292,7 @@ func generateSASToken(accountName, accountKey, storageEndpointSuffix string, exp if err != nil { return "", err } - return "?" + u.RawQuery, nil + sasToken := "?" + u.RawQuery + d.azcopySasTokenCache.Set(accountName, sasToken) + return sasToken, nil } diff --git a/pkg/azurefile/controllerserver_test.go b/pkg/azurefile/controllerserver_test.go index 03df137064..d67e4fddc1 100644 --- a/pkg/azurefile/controllerserver_test.go +++ b/pkg/azurefile/controllerserver_test.go @@ -2717,6 +2717,7 @@ func TestSetAzureCredentials(t *testing.T) { } func TestGenerateSASToken(t *testing.T) { + d := NewFakeDriver() storageEndpointSuffix := "core.windows.net" tests := []struct { name string @@ -2742,7 +2743,7 @@ func TestGenerateSASToken(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - sas, err := generateSASToken(tt.accountName, tt.accountKey, storageEndpointSuffix, 30) + sas, err := d.generateSASToken(tt.accountName, tt.accountKey, storageEndpointSuffix, 30) if !reflect.DeepEqual(err, tt.expectedErr) { t.Errorf("generateSASToken error = %v, expectedErr %v, sas token = %v, want %v", err, tt.expectedErr, sas, tt.want) return