From ee9f271c0cb6bb5da9e23bc4ed7485ae809662be Mon Sep 17 00:00:00 2001 From: tjandy98 <3953059+tjandy98@users.noreply.github.com> Date: Sat, 2 Mar 2024 16:12:53 +0800 Subject: [PATCH 1/7] Add gcs storage spec support Signed-off-by: tjandy98 <3953059+tjandy98@users.noreply.github.com> --- pkg/credentials/service_account_credentials.go | 4 ++-- .../service_account_credentials_test.go | 2 +- python/kserve/kserve/storage/storage.py | 15 +++++++++++++++ 3 files changed, 18 insertions(+), 3 deletions(-) diff --git a/pkg/credentials/service_account_credentials.go b/pkg/credentials/service_account_credentials.go index 13f9eecd2e..214b4787b1 100644 --- a/pkg/credentials/service_account_credentials.go +++ b/pkg/credentials/service_account_credentials.go @@ -52,8 +52,8 @@ const ( ) var ( - SupportedStorageSpecTypes = []string{"s3", "hdfs", "webhdfs"} - StorageBucketTypes = []string{"s3"} + SupportedStorageSpecTypes = []string{"s3", "hdfs", "webhdfs", "gs"} + StorageBucketTypes = []string{"s3", "gs"} ) type CredentialConfig struct { diff --git a/pkg/credentials/service_account_credentials_test.go b/pkg/credentials/service_account_credentials_test.go index 866ef8a831..b6e564fe87 100644 --- a/pkg/credentials/service_account_credentials_test.go +++ b/pkg/credentials/service_account_credentials_test.go @@ -1429,7 +1429,7 @@ func TestCredentialBuilder_CreateStorageSpecSecretEnvs(t *testing.T) { Name: "storage-secret", Namespace: namespace, }, - StringData: map[string]string{"minio": "{\n \"type\": \"gs\",\n \"access_key_id\": \"minio\",\n \"secret_access_key\": \"minio123\",\n \"endpoint_url\": \"http://minio-service.kubeflow:9000\",\n \"bucket\": \"test-bucket\",\n \"region\": \"us-south\"\n }"}, + StringData: map[string]string{"minio": "{\n \"type\": \"gss\",\n \"access_key_id\": \"minio\",\n \"secret_access_key\": \"minio123\",\n \"endpoint_url\": \"http://minio-service.kubeflow:9000\",\n \"bucket\": \"test-bucket\",\n \"region\": \"us-south\"\n }"}, }, storageKey: "minio", storageSecretName: "storage-secret", diff --git a/python/kserve/kserve/storage/storage.py b/python/kserve/kserve/storage/storage.py index 32777962ba..21a5e68a45 100644 --- a/python/kserve/kserve/storage/storage.py +++ b/python/kserve/kserve/storage/storage.py @@ -139,6 +139,21 @@ def _update_with_storage_spec(): f.write(value) f.flush() + if storage_secret_json.get("type", "") == "gs": + temp_dir = tempfile.mkdtemp() + credential_dir = temp_dir+ "/" + "google_application_credentials.json" + os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = credential_dir + if storage_secret_json.get("base64_service_account", "") != "": + try: + with open(credential_dir, "w") as f: + base64_service_account = storage_secret_json.get("base64_service_account", "") + service_account= base64.b64decode(base64_service_account).decode('utf-8') + f.write(service_account) + f.flush() + except binascii.Error: + raise RuntimeError("Error: Invalid base64 encoding.") + except UnicodeDecodeError: + raise RuntimeError("Error: Cannot decode string.") @staticmethod def get_S3_config(): # default s3 config From 98c58bfcf61457dfd278f7e1e82cae45270fa15e Mon Sep 17 00:00:00 2001 From: tjandy98 <3953059+tjandy98@users.noreply.github.com> Date: Thu, 11 Apr 2024 21:27:30 +0800 Subject: [PATCH 2/7] Add test for gs storage type Signed-off-by: tjandy98 <3953059+tjandy98@users.noreply.github.com> --- .../kserve/storage/test/test_storage.py | 252 +++++++++++++----- 1 file changed, 179 insertions(+), 73 deletions(-) diff --git a/python/kserve/kserve/storage/test/test_storage.py 
b/python/kserve/kserve/storage/test/test_storage.py index 500a11b438..e320132c7a 100644 --- a/python/kserve/kserve/storage/test/test_storage.py +++ b/python/kserve/kserve/storage/test/test_storage.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import base64 import io +import json import os import tempfile import binascii @@ -25,52 +27,51 @@ from kserve.storage import Storage -STORAGE_MODULE = 'kserve.storage.storage' -HTTPS_URI_TARGZ = 'https://foo.bar/model.tar.gz' -HTTPS_URI_TARGZ_WITH_QUERY = HTTPS_URI_TARGZ + '?foo=bar' +STORAGE_MODULE = "kserve.storage.storage" +HTTPS_URI_TARGZ = "https://foo.bar/model.tar.gz" +HTTPS_URI_TARGZ_WITH_QUERY = HTTPS_URI_TARGZ + "?foo=bar" # *.tar.gz contains a single empty file model.pth -FILE_TAR_GZ_RAW = binascii.unhexlify('1f8b0800bac550600003cbcd4f49cdd12b28c960a01d3030303033315100d1e666a660dac008c28' - '701054313a090a189919981998281a1b1b1a1118382010ddd0407a5c525894540a754656466e464e' - '2560754969686c71ca83fe0f4281805a360140c7200009f7e1bb400060000') +FILE_TAR_GZ_RAW = binascii.unhexlify( + "1f8b0800bac550600003cbcd4f49cdd12b28c960a01d3030303033315100d1e666a660dac008c28" + "701054313a090a189919981998281a1b1b1a1118382010ddd0407a5c525894540a754656466e464e" + "2560754969686c71ca83fe0f4281805a360140c7200009f7e1bb400060000" +) # *.zip contains a single empty file model.pth -FILE_ZIP_RAW = binascii.unhexlify('504b030414000800080035b67052000000000000000000000000090020006d6f64656c2e70746855540' - 'd000786c5506086c5506086c5506075780b000104f501000004140000000300504b07080000000002000' - '00000000000504b0102140314000800080035b6705200000000020000000000000009002000000000000' - '0000000a481000000006d6f64656c2e70746855540d000786c5506086c5506086c5506075780b000104f' - '50100000414000000504b0506000000000100010057000000590000000000') +FILE_ZIP_RAW = binascii.unhexlify( + "504b030414000800080035b67052000000000000000000000000090020006d6f64656c2e70746855540" + "d000786c5506086c5506086c5506075780b000104f501000004140000000300504b07080000000002000" + "00000000000504b0102140314000800080035b6705200000000020000000000000009002000000000000" + "0000000a481000000006d6f64656c2e70746855540d000786c5506086c5506086c5506075780b000104f" + "50100000414000000504b0506000000000100010057000000590000000000" +) def test_storage_local_path(): - abs_path = 'file:///' - relative_path = 'file://.' + abs_path = "file:///" + relative_path = "file://." assert Storage.download(abs_path) == abs_path.replace("file://", "", 1) assert Storage.download(relative_path) == relative_path.replace("file://", "", 1) def test_storage_local_path_exception(): - not_exist_path = 'file:///some/random/path' + not_exist_path = "file:///some/random/path" with pytest.raises(Exception): Storage.download(not_exist_path) def test_no_prefix_local_path(): - abs_path = '/' - relative_path = '.' + abs_path = "/" + relative_path = "." 
assert Storage.download(abs_path) == abs_path assert Storage.download(relative_path) == relative_path class MockHttpResponse(object): - def __init__( - self, - status_code=404, - raw=b'', - content_type='' - ): + def __init__(self, status_code=404, raw=b"", content_type=""): self.status_code = status_code self.raw = io.BytesIO(raw) - self.headers = {'Content-Type': content_type} + self.headers = {"Content-Type": content_type} def __enter__(self): return self @@ -79,85 +80,169 @@ def __exit__(self, ex_type, ex_val, traceback): pass -@mock.patch('requests.get', return_value=MockHttpResponse(status_code=200, content_type='application/octet-stream')) +@mock.patch( + "requests.get", + return_value=MockHttpResponse( + status_code=200, content_type="application/octet-stream" + ), +) def test_http_uri_path(_): - http_uri = 'http://foo.bar/model.joblib' - http_with_query_uri = 'http://foo.bar/model.joblib?foo=bar' - out_dir = '.' + http_uri = "http://foo.bar/model.joblib" + http_with_query_uri = "http://foo.bar/model.joblib?foo=bar" + out_dir = "." assert Storage.download(http_uri, out_dir=out_dir) == out_dir assert Storage.download(http_with_query_uri, out_dir=out_dir) == out_dir - os.remove('./model.joblib') + os.remove("./model.joblib") -@mock.patch('requests.get', return_value=MockHttpResponse(status_code=200, content_type='application/octet-stream')) +@mock.patch( + "requests.get", + return_value=MockHttpResponse( + status_code=200, content_type="application/octet-stream" + ), +) def test_https_uri_path(_): - https_uri = 'https://foo.bar/model.joblib' - https_with_query_uri = 'https://foo.bar/model.joblib?foo=bar' - out_dir = '.' + https_uri = "https://foo.bar/model.joblib" + https_with_query_uri = "https://foo.bar/model.joblib?foo=bar" + out_dir = "." assert Storage.download(https_uri, out_dir=out_dir) == out_dir assert Storage.download(https_with_query_uri, out_dir=out_dir) == out_dir - os.remove('./model.joblib') + os.remove("./model.joblib") http_uri_path_testparams = [ - (HTTPS_URI_TARGZ, MockHttpResponse(200, FILE_TAR_GZ_RAW, 'application/x-tar'), None), - (HTTPS_URI_TARGZ, MockHttpResponse(200, FILE_TAR_GZ_RAW, 'application/x-gtar'), None), - (HTTPS_URI_TARGZ, MockHttpResponse(200, FILE_TAR_GZ_RAW, 'application/x-gzip'), None), - (HTTPS_URI_TARGZ, MockHttpResponse(200, FILE_TAR_GZ_RAW, 'application/gzip'), None), - (HTTPS_URI_TARGZ, MockHttpResponse(200, FILE_TAR_GZ_RAW, 'application/zip'), RuntimeError), - (HTTPS_URI_TARGZ_WITH_QUERY, MockHttpResponse(200, FILE_TAR_GZ_RAW, 'application/x-tar'), None), - (HTTPS_URI_TARGZ_WITH_QUERY, MockHttpResponse(200, FILE_TAR_GZ_RAW, 'application/x-gtar'), None), - (HTTPS_URI_TARGZ_WITH_QUERY, MockHttpResponse(200, FILE_TAR_GZ_RAW, 'application/x-gzip'), None), - (HTTPS_URI_TARGZ_WITH_QUERY, MockHttpResponse(200, FILE_TAR_GZ_RAW, 'application/gzip'), None), - ('https://foo.bar/model.zip', MockHttpResponse(200, FILE_ZIP_RAW, 'application/zip'), None), - ('https://foo.bar/model.zip', MockHttpResponse(200, FILE_ZIP_RAW, 'application/x-zip-compressed'), None), - ('https://foo.bar/model.zip', MockHttpResponse(200, FILE_ZIP_RAW, 'application/zip-compressed'), None), - ('https://foo.bar/model.zip?foo=bar', MockHttpResponse(200, FILE_ZIP_RAW, 'application/zip'), None), - ('https://foo.bar/model.zip?foo=bar', MockHttpResponse(200, FILE_ZIP_RAW, 'application/x-zip-compressed'), None), - ('https://foo.bar/model.zip?foo=bar', MockHttpResponse(200, FILE_ZIP_RAW, 'application/zip-compressed'), None), - ('https://theabyss.net/model.joblib', MockHttpResponse(404), 
RuntimeError), - ('https://some.site.com/test.model', MockHttpResponse(status_code=200, content_type='text/html'), RuntimeError), - ('https://foo.bar/test/', MockHttpResponse(200), ValueError), - ('https://foo.bar/download?path=/20210530/model.zip', MockHttpResponse(200, FILE_ZIP_RAW, 'application/zip'), None), - ('https://foo.bar/download?path=/20210530/model.zip', MockHttpResponse(200, FILE_ZIP_RAW, 'application/x-zip' - '-compressed'), None), - ('https://foo.bar/download?path=/20210530/model.zip', MockHttpResponse(200, FILE_ZIP_RAW, 'application/zip' - '-compressed'), None), + ( + HTTPS_URI_TARGZ, + MockHttpResponse(200, FILE_TAR_GZ_RAW, "application/x-tar"), + None, + ), + ( + HTTPS_URI_TARGZ, + MockHttpResponse(200, FILE_TAR_GZ_RAW, "application/x-gtar"), + None, + ), + ( + HTTPS_URI_TARGZ, + MockHttpResponse(200, FILE_TAR_GZ_RAW, "application/x-gzip"), + None, + ), + (HTTPS_URI_TARGZ, MockHttpResponse(200, FILE_TAR_GZ_RAW, "application/gzip"), None), + ( + HTTPS_URI_TARGZ, + MockHttpResponse(200, FILE_TAR_GZ_RAW, "application/zip"), + RuntimeError, + ), + ( + HTTPS_URI_TARGZ_WITH_QUERY, + MockHttpResponse(200, FILE_TAR_GZ_RAW, "application/x-tar"), + None, + ), + ( + HTTPS_URI_TARGZ_WITH_QUERY, + MockHttpResponse(200, FILE_TAR_GZ_RAW, "application/x-gtar"), + None, + ), + ( + HTTPS_URI_TARGZ_WITH_QUERY, + MockHttpResponse(200, FILE_TAR_GZ_RAW, "application/x-gzip"), + None, + ), + ( + HTTPS_URI_TARGZ_WITH_QUERY, + MockHttpResponse(200, FILE_TAR_GZ_RAW, "application/gzip"), + None, + ), + ( + "https://foo.bar/model.zip", + MockHttpResponse(200, FILE_ZIP_RAW, "application/zip"), + None, + ), + ( + "https://foo.bar/model.zip", + MockHttpResponse(200, FILE_ZIP_RAW, "application/x-zip-compressed"), + None, + ), + ( + "https://foo.bar/model.zip", + MockHttpResponse(200, FILE_ZIP_RAW, "application/zip-compressed"), + None, + ), + ( + "https://foo.bar/model.zip?foo=bar", + MockHttpResponse(200, FILE_ZIP_RAW, "application/zip"), + None, + ), + ( + "https://foo.bar/model.zip?foo=bar", + MockHttpResponse(200, FILE_ZIP_RAW, "application/x-zip-compressed"), + None, + ), + ( + "https://foo.bar/model.zip?foo=bar", + MockHttpResponse(200, FILE_ZIP_RAW, "application/zip-compressed"), + None, + ), + ("https://theabyss.net/model.joblib", MockHttpResponse(404), RuntimeError), + ( + "https://some.site.com/test.model", + MockHttpResponse(status_code=200, content_type="text/html"), + RuntimeError, + ), + ("https://foo.bar/test/", MockHttpResponse(200), ValueError), + ( + "https://foo.bar/download?path=/20210530/model.zip", + MockHttpResponse(200, FILE_ZIP_RAW, "application/zip"), + None, + ), + ( + "https://foo.bar/download?path=/20210530/model.zip", + MockHttpResponse(200, FILE_ZIP_RAW, "application/x-zip" "-compressed"), + None, + ), + ( + "https://foo.bar/download?path=/20210530/model.zip", + MockHttpResponse(200, FILE_ZIP_RAW, "application/zip" "-compressed"), + None, + ), ] -@pytest.mark.parametrize('uri,response,expected_error', http_uri_path_testparams) +@pytest.mark.parametrize("uri,response,expected_error", http_uri_path_testparams) def test_http_uri_paths(uri, response, expected_error): if expected_error: + def test(_): with pytest.raises(expected_error): Storage.download(uri) + else: + def test(_): with tempfile.TemporaryDirectory() as out_dir: assert Storage.download(uri, out_dir=out_dir) == out_dir - assert os.path.exists(os.path.join(out_dir, 'model.pth')) - mock.patch('requests.get', return_value=response)(test)() + assert os.path.exists(os.path.join(out_dir, "model.pth")) + + 
mock.patch("requests.get", return_value=response)(test)() -@mock.patch(STORAGE_MODULE + '.storage') +@mock.patch(STORAGE_MODULE + ".storage") def test_mock_gcs(mock_storage): - gcs_path = 'gs://foo/bar' + gcs_path = "gs://foo/bar" mock_obj = mock.MagicMock() - mock_obj.name = 'mock.object' + mock_obj.name = "mock.object" mock_storage.Client().bucket().list_blobs().__iter__.return_value = [mock_obj] assert Storage.download(gcs_path) def test_storage_blob_exception(): - blob_path = 'https://accountname.blob.core.windows.net/container/some/blob/' + blob_path = "https://accountname.blob.core.windows.net/container/some/blob/" with pytest.raises(Exception): Storage.download(blob_path) -@mock.patch(STORAGE_MODULE + '.boto3') +@mock.patch(STORAGE_MODULE + ".boto3") def test_storage_s3_exception(mock_boto3): - path = 's3://foo/bar' + path = "s3://foo/bar" # Create mock client mock_s3_resource = mock.MagicMock() mock_s3_resource.Bucket.side_effect = Exception() @@ -167,8 +252,8 @@ def test_storage_s3_exception(mock_boto3): Storage.download(path) -@mock.patch(STORAGE_MODULE + '.boto3') -@mock.patch('urllib3.PoolManager') +@mock.patch(STORAGE_MODULE + ".boto3") +@mock.patch("urllib3.PoolManager") def test_no_permission_buckets(mock_connection, mock_boto3): bad_s3_path = "s3://random/path" # Access private buckets without credentials @@ -186,20 +271,41 @@ def test_no_permission_buckets(mock_connection, mock_boto3): def test_unpack_tar_file(): - out_dir = '.' + out_dir = "." tar_file = os.path.join(out_dir, "model.tgz") Path(tar_file).write_bytes(FILE_TAR_GZ_RAW) mimetype, _ = mimetypes.guess_type(tar_file) Storage._unpack_archive_file(tar_file, mimetype, out_dir) - assert os.path.exists(os.path.join(out_dir, 'model.pth')) - os.remove(os.path.join(out_dir, 'model.pth')) + assert os.path.exists(os.path.join(out_dir, "model.pth")) + os.remove(os.path.join(out_dir, "model.pth")) def test_unpack_zip_file(): - out_dir = '.' + out_dir = "." 
tar_file = os.path.join(out_dir, "model.zip")
     Path(tar_file).write_bytes(FILE_ZIP_RAW)
     mimetype, _ = mimetypes.guess_type(tar_file)
     Storage._unpack_archive_file(tar_file, mimetype, out_dir)
-    assert os.path.exists(os.path.join(out_dir, 'model.pth'))
-    os.remove(os.path.join(out_dir, 'model.pth'))
+    assert os.path.exists(os.path.join(out_dir, "model.pth"))
+    os.remove(os.path.join(out_dir, "model.pth"))
+
+
+@mock.patch("os.environ")
+def test_gs_storage(mock_os):
+    def side_effect(key, default=None):
+        if key == "STORAGE_CONFIG":
+            return json.dumps(
+                {
+                    "type": "gs",
+                    "base64_service_account": base64.b64encode(
+                        b"service_account_content"
+                    ).decode(),
+                }
+            )
+        return default
+
+    mock_os.get.side_effect = side_effect
+    Storage._update_with_storage_spec()
+    credential_dir = mock_os.__setitem__.call_args_list[0][0][1]
+    with open(credential_dir, "r") as f:
+        assert f.read() == "service_account_content"

From 1b99800a9243a1666fce04b070dad00b7dfbeadf Mon Sep 17 00:00:00 2001
From: tjandy98 <3953059+tjandy98@users.noreply.github.com>
Date: Thu, 11 Apr 2024 21:37:58 +0800
Subject: [PATCH 3/7] Fix formatting

Signed-off-by: tjandy98 <3953059+tjandy98@users.noreply.github.com>
---
 python/kserve/kserve/storage/storage.py       | 260 +++++++++++-------
 .../kserve/storage/test/test_azure_storage.py | 228 ++++++++-------
 .../kserve/storage/test/test_s3_storage.py    | 119 ++++----
 3 files changed, 354 insertions(+), 253 deletions(-)

diff --git a/python/kserve/kserve/storage/storage.py b/python/kserve/kserve/storage/storage.py
index 21a5e68a45..c7c2a7bd0c 100644
--- a/python/kserve/kserve/storage/storage.py
+++ b/python/kserve/kserve/storage/storage.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import base64
+import binascii
 import glob
 import gzip
 import json
@@ -96,9 +97,12 @@ def download(uri: str, out_dir: str = None) -> str:
             # serving mode. The model agent will download models.
             return out_dir
         else:
-            raise Exception("Cannot recognize storage type for " + uri +
-                            "\n'%s', '%s', '%s', and '%s' are the current available storage type." %
-                            (_GCS_PREFIX, _S3_PREFIX, _LOCAL_PREFIX, _HTTP_PREFIX))
+            raise Exception(
+                "Cannot recognize storage type for "
+                + uri
+                + "\n'%s', '%s', '%s', and '%s' are the current available storage types."
+ % (_GCS_PREFIX, _S3_PREFIX, _LOCAL_PREFIX, _HTTP_PREFIX) + ) logging.info("Successfully copied %s to %s", uri, out_dir) return out_dir @@ -106,7 +110,9 @@ def download(uri: str, out_dir: str = None) -> str: @staticmethod def _update_with_storage_spec(): storage_secret_json = json.loads(os.environ.get("STORAGE_CONFIG", "{}")) - storage_secret_override_params = json.loads(os.environ.get("STORAGE_OVERRIDE_CONFIG", "{}")) + storage_secret_override_params = json.loads( + os.environ.get("STORAGE_OVERRIDE_CONFIG", "{}") + ) if storage_secret_override_params: for key, value in storage_secret_override_params.items(): storage_secret_json[key] = value @@ -124,7 +130,10 @@ def _update_with_storage_spec(): if key in storage_secret_json: os.environ[env_var] = storage_secret_json.get(key) - if storage_secret_json.get("type", "") == "hdfs" or storage_secret_json.get("type", "") == "webhdfs": + if ( + storage_secret_json.get("type", "") == "hdfs" + or storage_secret_json.get("type", "") == "webhdfs" + ): temp_dir = tempfile.mkdtemp() os.environ["HDFS_SECRET_DIR"] = temp_dir for key, value in storage_secret_json.items(): @@ -141,32 +150,37 @@ def _update_with_storage_spec(): if storage_secret_json.get("type", "") == "gs": temp_dir = tempfile.mkdtemp() - credential_dir = temp_dir+ "/" + "google_application_credentials.json" + credential_dir = temp_dir + "/" + "google_application_credentials.json" os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = credential_dir if storage_secret_json.get("base64_service_account", "") != "": try: with open(credential_dir, "w") as f: - base64_service_account = storage_secret_json.get("base64_service_account", "") - service_account= base64.b64decode(base64_service_account).decode('utf-8') + base64_service_account = storage_secret_json.get( + "base64_service_account", "" + ) + service_account = base64.b64decode( + base64_service_account + ).decode("utf-8") f.write(service_account) f.flush() except binascii.Error: raise RuntimeError("Error: Invalid base64 encoding.") except UnicodeDecodeError: raise RuntimeError("Error: Cannot decode string.") + @staticmethod def get_S3_config(): # default s3 config c = Config() # anon environment variable defined in s3_secret.go - anon = ("true" == os.getenv("awsAnonymousCredential", "false").lower()) + anon = "true" == os.getenv("awsAnonymousCredential", "false").lower() # S3UseVirtualBucket environment variable defined in s3_secret.go # use virtual hosted-style URLs if enabled - virtual = ("true" == os.getenv("S3_USER_VIRTUAL_BUCKET", "false").lower()) + virtual = "true" == os.getenv("S3_USER_VIRTUAL_BUCKET", "false").lower() # S3UseAccelerate environment variable defined in s3_secret.go # use transfer acceleration if enabled - accelerate = ("true" == os.getenv("S3_USE_ACCELERATE", "false").lower()) + accelerate = "true" == os.getenv("S3_USE_ACCELERATE", "false").lower() if anon: c = c.merge(Config(signature_version=UNSIGNED)) @@ -185,9 +199,7 @@ def _download_s3(uri, temp_dir: str): # if awsAnonymousCredential env var true, passed in via config # 2. Environment variables # 3. 
~/.aws/config file - kwargs = { - "config": Storage.get_S3_config() - } + kwargs = {"config": Storage.get_S3_config()} endpoint_url = os.getenv("AWS_ENDPOINT_URL") if endpoint_url: kwargs.update({"endpoint_url": endpoint_url}) @@ -206,18 +218,23 @@ def _download_s3(uri, temp_dir: str): if isvc_aws_ca_bundle_path and isvc_aws_ca_bundle_path != "": ca_bundle_full_path = isvc_aws_ca_bundle_path else: - global_ca_bundle_volume_mount_path = os.getenv("CA_BUNDLE_VOLUME_MOUNT_POINT") - ca_bundle_full_path = global_ca_bundle_volume_mount_path + "/cabundle.crt" + global_ca_bundle_volume_mount_path = os.getenv( + "CA_BUNDLE_VOLUME_MOUNT_POINT" + ) + ca_bundle_full_path = ( + global_ca_bundle_volume_mount_path + "/cabundle.crt" + ) if os.path.exists(ca_bundle_full_path): - logging.info('ca bundle file(%s) exists.' % (ca_bundle_full_path)) + logging.info("ca bundle file(%s) exists." % (ca_bundle_full_path)) kwargs.update({"verify": ca_bundle_full_path}) else: raise RuntimeError( - "Failed to find ca bundle file(%s)." % ca_bundle_full_path) + "Failed to find ca bundle file(%s)." % ca_bundle_full_path + ) s3 = boto3.resource("s3", **kwargs) - parsed = urlparse(uri, scheme='s3') + parsed = urlparse(uri, scheme="s3") bucket_name = parsed.netloc - bucket_path = parsed.path.lstrip('/') + bucket_path = parsed.path.lstrip("/") file_count = 0 exact_obj_found = False @@ -226,8 +243,10 @@ def _download_s3(uri, temp_dir: str): # Skip where boto3 lists the directory as an object if obj.key.endswith("/"): continue - # In the case where bucket_path points to a single object, set the target key to bucket_path - # Otherwise, remove the bucket_path prefix, strip any extra slashes, then prepend the target_dir + # In the case where bucket_path points to a single object, + # set the target key to bucket_path + # Otherwise, remove the bucket_path prefix, strip any extra slashes, + # then prepend the target_dir # Example: # s3://test-bucket # Objects: /a/b/c/model.bin /a/model.bin /model.bin @@ -248,7 +267,9 @@ def _download_s3(uri, temp_dir: str): if bucket_path == obj.key: target_key = obj.key.rsplit("/", 1)[-1] exact_obj_found = True - elif bucket_path_last_part and object_last_path.startswith(bucket_path_last_part): + elif bucket_path_last_part and object_last_path.startswith( + bucket_path_last_part + ): target_key = object_last_path else: target_key = obj.key.replace(bucket_path, "").lstrip("/") @@ -257,15 +278,17 @@ def _download_s3(uri, temp_dir: str): if not os.path.exists(os.path.dirname(target)): os.makedirs(os.path.dirname(target), exist_ok=True) bucket.download_file(obj.key, target) - logging.info('Downloaded object %s to %s' % (obj.key, target)) + logging.info("Downloaded object %s to %s" % (obj.key, target)) file_count += 1 - # If the exact object is found, then it is sufficient to download that and break the loop + # If the exact object is found, then it is sufficient to download that + # and break the loop if exact_obj_found: break if file_count == 0: raise RuntimeError( - "Failed to fetch model. No model found in %s." % bucket_path) + "Failed to fetch model. No model found in %s." % bucket_path + ) # Unpack compressed file, supports .tgz, tar.gz and zip file formats. 
if file_count == 1: @@ -294,7 +317,9 @@ def _download_gcs(uri, temp_dir: str): # Create necessary subdirectory to store the object locally if "/" in subdir_object_key: - local_object_dir = os.path.join(temp_dir, subdir_object_key.rsplit("/", 1)[0]) + local_object_dir = os.path.join( + temp_dir, subdir_object_key.rsplit("/", 1)[0] + ) if not os.path.isdir(local_object_dir): os.makedirs(local_object_dir, exist_ok=True) if subdir_object_key.strip() != "" and not subdir_object_key.endswith("/"): @@ -303,8 +328,7 @@ def _download_gcs(uri, temp_dir: str): blob.download_to_filename(dest_path) file_count += 1 if file_count == 0: - raise RuntimeError( - "Failed to fetch model. No model found in %s." % uri) + raise RuntimeError("Failed to fetch model. No model found in %s." % uri) # Unpack compressed file, supports .tgz, tar.gz and zip file formats. if file_count == 1: @@ -359,9 +383,9 @@ def _download_hdfs(uri, out_dir: str): # Remove hdfs:// or webhdfs:// from the uri to get just the path # e.g. hdfs://user/me/model -> user/me/model if uri.startswith(_HDFS_PREFIX): - path = uri[len(_HDFS_PREFIX):] + path = uri[len(_HDFS_PREFIX) :] # noqa: E203 else: - path = uri[len(_WEBHDFS_PREFIX):] + path = uri[len(_WEBHDFS_PREFIX) :] # noqa: E203 if not config["HDFS_ROOTPATH"]: path = "/" + path @@ -384,21 +408,21 @@ def _download_hdfs(uri, out_dir: str): context = krbContext( using_keytab=True, principal=config["KERBEROS_PRINCIPAL"], - keytab_file=config["KERBEROS_KEYTAB"] + keytab_file=config["KERBEROS_KEYTAB"], ) context.init_with_keytab() client = KerberosClient( config["HDFS_NAMENODE"], proxy=config["USER_PROXY"], root=config["HDFS_ROOTPATH"], - session=s + session=s, ) else: client = Client( config["HDFS_NAMENODE"], proxy=config["USER_PROXY"], root=config["HDFS_ROOTPATH"], - session=s + session=s, ) file_count = 0 dest_file_path = "" @@ -416,7 +440,9 @@ def _download_hdfs(uri, out_dir: str): files = client.list(path) file_count += len(files) for f in files: - client.download(f"{path}/{f}", out_dir, n_threads=int(config["N_THREADS"])) + client.download( + f"{path}/{f}", out_dir, n_threads=int(config["N_THREADS"]) + ) dest_file_path = f"{out_dir}/{f}" if file_count == 1: @@ -426,14 +452,24 @@ def _download_hdfs(uri, out_dir: str): @staticmethod def _download_azure_blob(uri, out_dir: str): # pylint: disable=too-many-locals - account_name, account_url, container_name, prefix = Storage._parse_azure_uri(uri) - logging.info("Connecting to BLOB account: [%s], container: [%s], prefix: [%s]", - account_name, - container_name, - prefix) - token = Storage._get_azure_storage_token() or Storage._get_azure_storage_access_key() + account_name, account_url, container_name, prefix = Storage._parse_azure_uri( + uri + ) + logging.info( + "Connecting to BLOB account: [%s], container: [%s], prefix: [%s]", + account_name, + container_name, + prefix, + ) + token = ( + Storage._get_azure_storage_token() + or Storage._get_azure_storage_access_key() + ) if token is None: - logging.warning("Azure credentials or shared access signature token not found, retrying anonymous access") + logging.warning( + "Azure credentials or shared access signature token not found, \ + retrying anonymous access" + ) blob_service_client = BlobServiceClient(account_url, credential=token) container_client = blob_service_client.get_container_client(container_name) @@ -445,15 +481,17 @@ def _download_azure_blob(uri, out_dir: str): # pylint: disable=too-many-locals curr_prefix, depth = stack.pop() if depth < 0: continue - for item in 
container_client.walk_blobs( - name_starts_with=curr_prefix): + for item in container_client.walk_blobs(name_starts_with=curr_prefix): if isinstance(item, BlobPrefix): stack.append((item.name, depth - 1)) else: - blobs += container_client.list_blobs(name_starts_with=item.name, - include=['snapshots']) + blobs += container_client.list_blobs( + name_starts_with=item.name, include=["snapshots"] + ) for blob in blobs: - dest_path = os.path.join(out_dir, blob.name.replace(prefix, "", 1).lstrip("/")) + dest_path = os.path.join( + out_dir, blob.name.replace(prefix, "", 1).lstrip("/") + ) Path(os.path.dirname(dest_path)).mkdir(parents=True, exist_ok=True) logging.info("Downloading: %s to %s", blob.name, dest_path) downloader = container_client.download_blob(blob.name) @@ -461,8 +499,7 @@ def _download_azure_blob(uri, out_dir: str): # pylint: disable=too-many-locals f.write(downloader.readall()) file_count += 1 if file_count == 0: - raise RuntimeError( - "Failed to fetch model. No model found in %s." % (uri)) + raise RuntimeError("Failed to fetch model. No model found in %s." % (uri)) # Unpack compressed file, supports .tgz, tar.gz and zip file formats. if file_count == 1: @@ -471,15 +508,21 @@ def _download_azure_blob(uri, out_dir: str): # pylint: disable=too-many-locals Storage._unpack_archive_file(dest_path, mimetype, out_dir) @staticmethod - def _download_azure_file_share(uri, out_dir: str): # pylint: disable=too-many-locals + def _download_azure_file_share( + uri, out_dir: str + ): # pylint: disable=too-many-locals account_name, account_url, share_name, prefix = Storage._parse_azure_uri(uri) - logging.info("Connecting to file share account: [%s], container: [%s], prefix: [%s]", - account_name, - share_name, - prefix) + logging.info( + "Connecting to file share account: [%s], container: [%s], prefix: [%s]", + account_name, + share_name, + prefix, + ) access_key = Storage._get_azure_storage_access_key() if access_key is None: - logging.warning("Azure storage access key not found, retrying anonymous access") + logging.warning( + "Azure storage access key not found, retrying anonymous access" + ) share_service_client = ShareServiceClient(account_url, credential=access_key) share_client = share_service_client.get_share_client(share_name) @@ -492,15 +535,18 @@ def _download_azure_file_share(uri, out_dir: str): # pylint: disable=too-many-l if depth < 0: continue for item in share_client.list_directories_and_files( - directory_name=curr_prefix): + directory_name=curr_prefix + ): if item.is_directory: - stack.append(('/'.join([curr_prefix, item.name]).strip('/'), depth - 1)) + stack.append( + ("/".join([curr_prefix, item.name]).strip("/"), depth - 1) + ) else: share_files.append((curr_prefix, item)) for prefix, file_item in share_files: parts = [prefix] if prefix else [] parts.append(file_item.name) - file_path = '/'.join(parts).lstrip('/') + file_path = "/".join(parts).lstrip("/") dest_path = os.path.join(out_dir, file_path) Path(os.path.dirname(dest_path)).mkdir(parents=True, exist_ok=True) logging.info("Downloading: %s to %s", file_item.name, dest_path) @@ -510,8 +556,7 @@ def _download_azure_file_share(uri, out_dir: str): # pylint: disable=too-many-l data.readinto(f) file_count += 1 if file_count == 0: - raise RuntimeError( - "Failed to fetch model. No model found in %s." % (uri)) + raise RuntimeError("Failed to fetch model. No model found in %s." % (uri)) # Unpack compressed file, supports .tgz, tar.gz and zip file formats. 
if file_count == 1: @@ -522,10 +567,12 @@ def _download_azure_file_share(uri, out_dir: str): # pylint: disable=too-many-l @staticmethod def _parse_azure_uri(uri): # pylint: disable=too-many-locals parsed = urlparse(uri) - account_name = parsed.netloc.split('.')[0] - account_url = 'https://{}{}'.format(parsed.netloc, '?' + parsed.query if parsed.query else '') - object_name, prefix = parsed.path.lstrip('/').split("/", 1) - prefix = prefix.strip('/') + account_name = parsed.netloc.split(".")[0] + account_url = "https://{}{}".format( + parsed.netloc, "?" + parsed.query if parsed.query else "" + ) + object_name, prefix = parsed.path.lstrip("/").split("/", 1) + prefix = prefix.strip("/") return account_name, account_url, object_name, prefix @staticmethod @@ -549,10 +596,10 @@ def _get_azure_storage_token(): # note the SP must have "Storage Blob Data Owner" perms for this to work from azure.identity import DefaultAzureCredential + token_credential = DefaultAzureCredential() - logging.info("Retrieved SP token credential for client_id: %s", - client_id) + logging.info("Retrieved SP token credential for client_id: %s", client_id) return token_credential @staticmethod @@ -584,8 +631,7 @@ def _download_local(uri, out_dir=None): logging.info("File %s already exist", dest_path) file_count += 1 if file_count == 0: - raise RuntimeError( - "Failed to fetch model. No model found in %s." % (uri)) + raise RuntimeError("Failed to fetch model. No model found in %s." % (uri)) # Unpack compressed file, supports .tgz, tar.gz and zip file formats. if file_count == 1: mimetype, _ = mimetypes.guess_type(dest_path) @@ -599,14 +645,14 @@ def _download_from_uri(uri, out_dir=None): url = urlparse(uri) filename = os.path.basename(url.path) # Determine if the symbol '?' exists in the path - if mimetypes.guess_type(url.path)[0] is None and url.query != '': + if mimetypes.guess_type(url.path)[0] is None and url.query != "": mimetype, encoding = mimetypes.guess_type(url.query) else: mimetype, encoding = mimetypes.guess_type(url.path) local_path = os.path.join(out_dir, filename) - if filename == '': - raise ValueError('No filename contained in URI: %s' % (uri)) + if filename == "": + raise ValueError("No filename contained in URI: %s" % (uri)) # Get header information from host url headers = {} @@ -617,28 +663,52 @@ def _download_from_uri(uri, out_dir=None): with requests.get(uri, stream=True, headers=headers) as response: if response.status_code != 200: - raise RuntimeError("URI: %s returned a %s response code." 
% (uri, response.status_code))
-            zip_content_types = ('application/x-zip-compressed', 'application/zip', 'application/zip-compressed')
-            if mimetype == 'application/zip' and not response.headers.get('Content-Type', '') \
-                    .startswith(zip_content_types):
-                raise RuntimeError("URI: %s did not respond with any of following \'Content-Type\': " % uri +
-                                   ", ".join(zip_content_types))
-            tar_content_types = ('application/x-tar', 'application/x-gtar', 'application/x-gzip', 'application/gzip')
-            if mimetype == 'application/x-tar' and not response.headers.get('Content-Type', '') \
-                    .startswith(tar_content_types):
-                raise RuntimeError("URI: %s did not respond with any of following \'Content-Type\': " % uri +
-                                   ", ".join(tar_content_types))
-            if (mimetype != 'application/zip' and mimetype != 'application/x-tar') and \
-                    not response.headers.get('Content-Type', '').startswith('application/octet-stream'):
-                raise RuntimeError("URI: %s did not respond with \'Content-Type\': \'application/octet-stream\'"
-                                   % uri)
-
-            if encoding == 'gzip':
+            raise RuntimeError(
+                "URI: %s returned a %s response code." % (uri, response.status_code)
+            )
+            zip_content_types = (
+                "application/x-zip-compressed",
+                "application/zip",
+                "application/zip-compressed",
+            )
+            if mimetype == "application/zip" and not response.headers.get(
+                "Content-Type", ""
+            ).startswith(zip_content_types):
+                raise RuntimeError(
+                    "URI: %s did not respond with any of the following 'Content-Type': "
+                    % uri
+                    + ", ".join(zip_content_types)
+                )
+            tar_content_types = (
+                "application/x-tar",
+                "application/x-gtar",
+                "application/x-gzip",
+                "application/gzip",
+            )
+            if mimetype == "application/x-tar" and not response.headers.get(
+                "Content-Type", ""
+            ).startswith(tar_content_types):
+                raise RuntimeError(
+                    "URI: %s did not respond with any of the following 'Content-Type': "
+                    % uri
+                    + ", ".join(tar_content_types)
+                )
+            if (
+                mimetype != "application/zip" and mimetype != "application/x-tar"
+            ) and not response.headers.get("Content-Type", "").startswith(
+                "application/octet-stream"
+            ):
+                raise RuntimeError(
+                    "URI: %s did not respond with 'Content-Type': 'application/octet-stream'"
+                    % uri
+                )
+
+            if encoding == "gzip":
                 stream = gzip.GzipFile(fileobj=response.raw)
-                local_path = os.path.join(out_dir, f'{filename}.tar')
+                local_path = os.path.join(out_dir, f"{filename}.tar")
             else:
                 stream = response.raw
-            with open(local_path, 'wb') as out:
+            with open(local_path, "wb") as out:
                 shutil.copyfileobj(stream, out)
 
         if mimetype in ["application/x-tar", "application/zip"]:
@@ -654,12 +724,14 @@ def _unpack_archive_file(file_path, mimetype, target_dir=None):
         try:
             logging.info("Unpacking: %s", file_path)
             if mimetype == "application/x-tar":
-                archive = tarfile.open(file_path, 'r', encoding='utf-8')
+                archive = tarfile.open(file_path, "r", encoding="utf-8")
             else:
-                archive = zipfile.ZipFile(file_path, 'r')
+                archive = zipfile.ZipFile(file_path, "r")
             archive.extractall(target_dir)
             archive.close()
         except (tarfile.TarError, zipfile.BadZipfile):
-            raise RuntimeError("Failed to unpack archive file. \
-The file format is not valid.")
+            raise RuntimeError(
+                "Failed to unpack archive file. \
+The file format is not valid."
+ ) os.remove(file_path) diff --git a/python/kserve/kserve/storage/test/test_azure_storage.py b/python/kserve/kserve/storage/test/test_azure_storage.py index ff837fd4ba..bb134b1c02 100644 --- a/python/kserve/kserve/storage/test/test_azure_storage.py +++ b/python/kserve/kserve/storage/test/test_azure_storage.py @@ -18,7 +18,7 @@ from kserve.storage import Storage -STORAGE_MODULE = 'kserve.storage.storage' +STORAGE_MODULE = "kserve.storage.storage" def create_mock_item(path): @@ -76,21 +76,20 @@ def get_call_args(call_args_list): # pylint: disable=protected-access -@pytest.fixture(scope='session', autouse=True) +@pytest.fixture(scope="session", autouse=True) def test_cleanup(): yield None # Will be executed after the last test - shutil.rmtree('some', ignore_errors=True) - shutil.rmtree('dest_path', ignore_errors=True) + shutil.rmtree("some", ignore_errors=True) + shutil.rmtree("dest_path", ignore_errors=True) -@mock.patch(STORAGE_MODULE + '.os.makedirs') -@mock.patch(STORAGE_MODULE + '.BlobServiceClient') +@mock.patch(STORAGE_MODULE + ".os.makedirs") +@mock.patch(STORAGE_MODULE + ".BlobServiceClient") def test_blob(mock_storage, mock_makedirs): # pylint: disable=unused-argument - # given - blob_path = 'https://kfserving.blob.core.windows.net/triton/simple_string/' - paths = ['simple_string/1/model.graphdef', 'simple_string/config.pbtxt'] + blob_path = "https://kfserving.blob.core.windows.net/triton/simple_string/" + paths = ["simple_string/1/model.graphdef", "simple_string/config.pbtxt"] mock_blob, mock_container = create_mock_blob(mock_storage, paths) # when @@ -98,20 +97,23 @@ def test_blob(mock_storage, mock_makedirs): # pylint: disable=unused-argument # then arg_list = get_call_args(mock_container.download_blob.call_args_list) - assert set(arg_list) == set([('simple_string/1/model.graphdef',), - ('simple_string/config.pbtxt',)]) - - mock_storage.assert_called_with('https://kfserving.blob.core.windows.net', - credential=None) + assert set(arg_list) == set( + [("simple_string/1/model.graphdef",), ("simple_string/config.pbtxt",)] + ) + mock_storage.assert_called_with( + "https://kfserving.blob.core.windows.net", credential=None + ) -@mock.patch(STORAGE_MODULE + '.os.makedirs') -@mock.patch(STORAGE_MODULE + '.Storage._get_azure_storage_token') -@mock.patch(STORAGE_MODULE + '.BlobServiceClient') -def test_secure_blob(mock_storage, mock_get_token, mock_makedirs): # pylint: disable=unused-argument +@mock.patch(STORAGE_MODULE + ".os.makedirs") +@mock.patch(STORAGE_MODULE + ".Storage._get_azure_storage_token") +@mock.patch(STORAGE_MODULE + ".BlobServiceClient") +def test_secure_blob( + mock_storage, mock_get_token, mock_makedirs +): # pylint: disable=unused-argument # given - blob_path = 'https://kfsecured.blob.core.windows.net/triton/simple_string/' + blob_path = "https://kfsecured.blob.core.windows.net/triton/simple_string/" mock_get_token.return_value = "some_token" # when @@ -124,17 +126,18 @@ def test_secure_blob(mock_storage, mock_get_token, mock_makedirs): # pylint: di for call in mock_storage.call_args_list: _, kwargs = call arg_list.append(kwargs) - assert arg_list == [{'credential': 'some_token'}] + assert arg_list == [{"credential": "some_token"}] -@mock.patch(STORAGE_MODULE + '.os.makedirs') -@mock.patch(STORAGE_MODULE + '.BlobServiceClient') +@mock.patch(STORAGE_MODULE + ".os.makedirs") +@mock.patch(STORAGE_MODULE + ".BlobServiceClient") def test_deep_blob(mock_storage, mock_makedirs): # pylint: disable=unused-argument - # given - blob_path = 
'https://accountname.blob.core.windows.net/container/some/deep/blob/path' - paths = ['f1', 'f2', 'd1/f11', 'd1/d2/f21', 'd1/d2/d3/f1231', 'd4/f41'] - fq_item_paths = ['some/deep/blob/path/' + p for p in paths] + blob_path = ( + "https://accountname.blob.core.windows.net/container/some/deep/blob/path" + ) + paths = ["f1", "f2", "d1/f11", "d1/d2/f21", "d1/d2/d3/f1231", "d4/f41"] + fq_item_paths = ["some/deep/blob/path/" + p for p in paths] expected_calls = [(f,) for f in fq_item_paths] # when @@ -149,13 +152,12 @@ def test_deep_blob(mock_storage, mock_makedirs): # pylint: disable=unused-argum assert set(actual_calls) == set(expected_calls) -@mock.patch(STORAGE_MODULE + '.os.makedirs') -@mock.patch(STORAGE_MODULE + '.BlobServiceClient') +@mock.patch(STORAGE_MODULE + ".os.makedirs") +@mock.patch(STORAGE_MODULE + ".BlobServiceClient") def test_blob_file(mock_storage, mock_makedirs): # pylint: disable=unused-argument - # given - blob_path = 'https://accountname.blob.core.windows.net/container/somefile.text' - paths = ['somefile'] + blob_path = "https://accountname.blob.core.windows.net/container/somefile.text" + paths = ["somefile"] fq_item_paths = paths expected_calls = [(f,) for f in fq_item_paths] @@ -168,14 +170,15 @@ def test_blob_file(mock_storage, mock_makedirs): # pylint: disable=unused-argum assert actual_calls == expected_calls -@mock.patch(STORAGE_MODULE + '.os.makedirs') -@mock.patch(STORAGE_MODULE + '.BlobServiceClient') +@mock.patch(STORAGE_MODULE + ".os.makedirs") +@mock.patch(STORAGE_MODULE + ".BlobServiceClient") def test_blob_fq_file(mock_storage, mock_makedirs): # pylint: disable=unused-argument - # given - blob_path = 'https://accountname.blob.core.windows.net/container/folder/somefile.text' - paths = ['somefile'] - fq_item_paths = ['folder/' + p for p in paths] + blob_path = ( + "https://accountname.blob.core.windows.net/container/folder/somefile.text" + ) + paths = ["somefile"] + fq_item_paths = ["folder/" + p for p in paths] expected_calls = [(f,) for f in fq_item_paths] # when @@ -187,14 +190,13 @@ def test_blob_fq_file(mock_storage, mock_makedirs): # pylint: disable=unused-ar assert actual_calls == expected_calls -@mock.patch(STORAGE_MODULE + '.os.makedirs') -@mock.patch(STORAGE_MODULE + '.BlobServiceClient') +@mock.patch(STORAGE_MODULE + ".os.makedirs") +@mock.patch(STORAGE_MODULE + ".BlobServiceClient") def test_blob_no_prefix(mock_storage, mock_makedirs): # pylint: disable=unused-argument - # given - blob_path = 'https://accountname.blob.core.windows.net/container/' - paths = ['somefile.text', 'somefolder/somefile.text'] - fq_item_paths = ['' + p for p in paths] + blob_path = "https://accountname.blob.core.windows.net/container/" + paths = ["somefile.text", "somefolder/somefile.text"] + fq_item_paths = ["" + p for p in paths] expected_calls = [(f,) for f in fq_item_paths] # when @@ -206,53 +208,72 @@ def test_blob_no_prefix(mock_storage, mock_makedirs): # pylint: disable=unused- assert set(actual_calls) == set(expected_calls) -@mock.patch(STORAGE_MODULE + '.os.makedirs') -@mock.patch(STORAGE_MODULE + '.Storage._get_azure_storage_access_key') -@mock.patch(STORAGE_MODULE + '.ShareServiceClient') -def test_file_share(mock_storage, mock_get_access_key, mock_makedirs): # pylint: disable=unused-argument - +@mock.patch(STORAGE_MODULE + ".os.makedirs") +@mock.patch(STORAGE_MODULE + ".Storage._get_azure_storage_access_key") +@mock.patch(STORAGE_MODULE + ".ShareServiceClient") +def test_file_share( + mock_storage, mock_get_access_key, mock_makedirs +): # pylint: 
disable=unused-argument # given - file_share_path = 'https://kfserving.file.core.windows.net/triton/simple_string/' + file_share_path = "https://kfserving.file.core.windows.net/triton/simple_string/" mock_get_access_key.return_value = "some_token" mock_file_share, mock_file, mock_data = create_mock_objects_for_file_share( - mock_storage, [[create_mock_dir('1'), create_mock_file('config.pbtxt')], - [create_mock_file('model.graphdef')], - []]) + mock_storage, + [ + [create_mock_dir("1"), create_mock_file("config.pbtxt")], + [create_mock_file("model.graphdef")], + [], + ], + ) # when Storage._download_azure_file_share(file_share_path, "dest_path") # then arg_list = get_call_args(mock_file.get_file_client.call_args_list) - assert set(arg_list) == set([('simple_string/1/model.graphdef',), - ('simple_string/config.pbtxt',)]) + assert set(arg_list) == set( + [("simple_string/1/model.graphdef",), ("simple_string/config.pbtxt",)] + ) # then mock_get_access_key.assert_called() - mock_storage.assert_called_with('https://kfserving.file.core.windows.net', - credential='some_token') - - -@mock.patch(STORAGE_MODULE + '.os.makedirs') -@mock.patch(STORAGE_MODULE + '.Storage._get_azure_storage_access_key') -@mock.patch(STORAGE_MODULE + '.ShareServiceClient') -def test_deep_file_share(mock_storage, mock_get_access_key, mock_makedirs): # pylint: disable=unused-argument - - file_share_path = 'https://accountname.file.core.windows.net/container/some/deep/blob/path' - paths = ['f1', 'f2', 'd1/f11', 'd1/d2/f21', 'd1/d2/d3/f1231', 'd4/f41'] - fq_item_paths = ['some/deep/blob/path/' + p for p in paths] + mock_storage.assert_called_with( + "https://kfserving.file.core.windows.net", credential="some_token" + ) + + +@mock.patch(STORAGE_MODULE + ".os.makedirs") +@mock.patch(STORAGE_MODULE + ".Storage._get_azure_storage_access_key") +@mock.patch(STORAGE_MODULE + ".ShareServiceClient") +def test_deep_file_share( + mock_storage, mock_get_access_key, mock_makedirs +): # pylint: disable=unused-argument + file_share_path = ( + "https://accountname.file.core.windows.net/container/some/deep/blob/path" + ) + paths = ["f1", "f2", "d1/f11", "d1/d2/f21", "d1/d2/d3/f1231", "d4/f41"] + fq_item_paths = ["some/deep/blob/path/" + p for p in paths] expected_calls = [(f,) for f in fq_item_paths] mock_get_access_key.return_value = "some_token" # when mock_file_share, mock_file, mock_data = create_mock_objects_for_file_share( - mock_storage, [[create_mock_dir('d1'), create_mock_dir('d4'), create_mock_file('f1'), create_mock_file('f2')], - [create_mock_file('f41')], - [create_mock_dir('d2'), create_mock_file('f11')], - [create_mock_dir('d3'), create_mock_file('f21')], - [create_mock_file('f1231')], - []]) + mock_storage, + [ + [ + create_mock_dir("d1"), + create_mock_dir("d4"), + create_mock_file("f1"), + create_mock_file("f2"), + ], + [create_mock_file("f41")], + [create_mock_dir("d2"), create_mock_file("f11")], + [create_mock_dir("d3"), create_mock_file("f21")], + [create_mock_file("f1231")], + [], + ], + ) try: Storage._download_azure_file_share(file_share_path, "some/dest/path") except OSError: # Permissions Error Handling @@ -264,26 +285,28 @@ def test_deep_file_share(mock_storage, mock_get_access_key, mock_makedirs): # p # then mock_get_access_key.assert_called() - mock_storage.assert_called_with('https://accountname.file.core.windows.net', - credential='some_token') - + mock_storage.assert_called_with( + "https://accountname.file.core.windows.net", credential="some_token" + ) -@mock.patch(STORAGE_MODULE + '.os.makedirs') 
-@mock.patch(STORAGE_MODULE + '.Storage._get_azure_storage_access_key') -@mock.patch(STORAGE_MODULE + '.ShareServiceClient') -def test_file_share_fq_file(mock_storage, mock_get_access_key, mock_makedirs): # pylint: disable=unused-argument +@mock.patch(STORAGE_MODULE + ".os.makedirs") +@mock.patch(STORAGE_MODULE + ".Storage._get_azure_storage_access_key") +@mock.patch(STORAGE_MODULE + ".ShareServiceClient") +def test_file_share_fq_file( + mock_storage, mock_get_access_key, mock_makedirs +): # pylint: disable=unused-argument # given - file_share_path = 'https://accountname.file.core.windows.net/container/folder/' - paths = ['somefile.text'] - fq_item_paths = ['folder/' + p for p in paths] + file_share_path = "https://accountname.file.core.windows.net/container/folder/" + paths = ["somefile.text"] + fq_item_paths = ["folder/" + p for p in paths] expected_calls = [(f,) for f in fq_item_paths] # when mock_get_access_key.return_value = "some_token" mock_file_share, mock_file, mock_data = create_mock_objects_for_file_share( - mock_storage, [[create_mock_file('somefile.text')], - []]) + mock_storage, [[create_mock_file("somefile.text")], []] + ) Storage._download_azure_file_share(file_share_path, "some/dest/path") # then @@ -292,27 +315,33 @@ def test_file_share_fq_file(mock_storage, mock_get_access_key, mock_makedirs): # then mock_get_access_key.assert_called() - mock_storage.assert_called_with('https://accountname.file.core.windows.net', - credential='some_token') - + mock_storage.assert_called_with( + "https://accountname.file.core.windows.net", credential="some_token" + ) -@mock.patch(STORAGE_MODULE + '.os.makedirs') -@mock.patch(STORAGE_MODULE + '.Storage._get_azure_storage_access_key') -@mock.patch(STORAGE_MODULE + '.ShareServiceClient') -def test_file_share_no_prefix(mock_storage, mock_get_access_key, mock_makedirs): # pylint: disable=unused-argument +@mock.patch(STORAGE_MODULE + ".os.makedirs") +@mock.patch(STORAGE_MODULE + ".Storage._get_azure_storage_access_key") +@mock.patch(STORAGE_MODULE + ".ShareServiceClient") +def test_file_share_no_prefix( + mock_storage, mock_get_access_key, mock_makedirs +): # pylint: disable=unused-argument # given - file_share_path = 'https://accountname.file.core.windows.net/container/' - paths = ['somefile.text', 'somefolder/somefile.text'] - fq_item_paths = ['' + p for p in paths] + file_share_path = "https://accountname.file.core.windows.net/container/" + paths = ["somefile.text", "somefolder/somefile.text"] + fq_item_paths = ["" + p for p in paths] expected_calls = [(f,) for f in fq_item_paths] # when mock_get_access_key.return_value = "some_token" mock_file_share, mock_file, mock_data = create_mock_objects_for_file_share( - mock_storage, [[create_mock_dir('somefolder'), create_mock_file('somefile.text')], - [create_mock_file('somefile.text')], - []]) + mock_storage, + [ + [create_mock_dir("somefolder"), create_mock_file("somefile.text")], + [create_mock_file("somefile.text")], + [], + ], + ) Storage._download_azure_file_share(file_share_path, "some/dest/path") # then @@ -321,5 +350,6 @@ def test_file_share_no_prefix(mock_storage, mock_get_access_key, mock_makedirs): # then mock_get_access_key.assert_called() - mock_storage.assert_called_with('https://accountname.file.core.windows.net', - credential='some_token') + mock_storage.assert_called_with( + "https://accountname.file.core.windows.net", credential="some_token" + ) diff --git a/python/kserve/kserve/storage/test/test_s3_storage.py b/python/kserve/kserve/storage/test/test_s3_storage.py index 
e621d7b3f4..a4f9eecb31 100644 --- a/python/kserve/kserve/storage/test/test_s3_storage.py +++ b/python/kserve/kserve/storage/test/test_s3_storage.py @@ -20,7 +20,7 @@ from botocore import UNSIGNED from kserve.storage import Storage -STORAGE_MODULE = 'kserve.storage.storage' +STORAGE_MODULE = "kserve.storage.storage" def create_mock_obj(path): @@ -50,133 +50,129 @@ def get_call_args(call_args_list): def expected_call_args_list_single_obj(dest, path): - return [( - f'{path}'.strip('/'), - f'{dest}/{path.rsplit("/", 1)[-1]}'.strip('/'))] + return [(f"{path}".strip("/"), f'{dest}/{path.rsplit("/", 1)[-1]}'.strip("/"))] def expected_call_args_list(parent_key, dest, paths): - return [(f'{parent_key}/{p}'.strip('/'), f'{dest}/{p}'.strip('/')) - for p in paths] + return [(f"{parent_key}/{p}".strip("/"), f"{dest}/{p}".strip("/")) for p in paths] # pylint: disable=protected-access -@mock.patch(STORAGE_MODULE + '.boto3') +@mock.patch(STORAGE_MODULE + ".boto3") def test_parent_key(mock_storage): # given - bucket_name = 'foo' - paths = ['models/weights.pt', '0002.h5', 'a/very/long/path/config.json'] - object_paths = ['bar/' + p for p in paths] + bucket_name = "foo" + paths = ["models/weights.pt", "0002.h5", "a/very/long/path/config.json"] + object_paths = ["bar/" + p for p in paths] # when mock_boto3_bucket = create_mock_boto3_bucket(mock_storage, object_paths) - Storage._download_s3(f's3://{bucket_name}/bar', 'dest_path') + Storage._download_s3(f"s3://{bucket_name}/bar", "dest_path") # then arg_list = get_call_args(mock_boto3_bucket.download_file.call_args_list) - assert arg_list == expected_call_args_list('bar', 'dest_path', paths) + assert arg_list == expected_call_args_list("bar", "dest_path", paths) - mock_boto3_bucket.objects.filter.assert_called_with(Prefix='bar') + mock_boto3_bucket.objects.filter.assert_called_with(Prefix="bar") -@mock.patch(STORAGE_MODULE + '.boto3') +@mock.patch(STORAGE_MODULE + ".boto3") def test_no_key(mock_storage): # given - bucket_name = 'foo' - object_paths = ['models/weights.pt', '0002.h5', 'a/very/long/path/config.json'] + bucket_name = "foo" + object_paths = ["models/weights.pt", "0002.h5", "a/very/long/path/config.json"] # when mock_boto3_bucket = create_mock_boto3_bucket(mock_storage, object_paths) - Storage._download_s3(f's3://{bucket_name}/', 'dest_path') + Storage._download_s3(f"s3://{bucket_name}/", "dest_path") # then arg_list = get_call_args(mock_boto3_bucket.download_file.call_args_list) - assert arg_list == expected_call_args_list('', 'dest_path', object_paths) + assert arg_list == expected_call_args_list("", "dest_path", object_paths) - mock_boto3_bucket.objects.filter.assert_called_with(Prefix='') + mock_boto3_bucket.objects.filter.assert_called_with(Prefix="") -@mock.patch(STORAGE_MODULE + '.boto3') +@mock.patch(STORAGE_MODULE + ".boto3") def test_full_name_key(mock_storage): # given - bucket_name = 'foo' - object_key = 'path/to/model/name.pt' + bucket_name = "foo" + object_key = "path/to/model/name.pt" # when mock_boto3_bucket = create_mock_boto3_bucket(mock_storage, [object_key]) - Storage._download_s3(f's3://{bucket_name}/{object_key}', 'dest_path') + Storage._download_s3(f"s3://{bucket_name}/{object_key}", "dest_path") # then arg_list = get_call_args(mock_boto3_bucket.download_file.call_args_list) - assert arg_list == expected_call_args_list_single_obj('dest_path', - object_key) + assert arg_list == expected_call_args_list_single_obj("dest_path", object_key) mock_boto3_bucket.objects.filter.assert_called_with(Prefix=object_key) 
-@mock.patch(STORAGE_MODULE + '.boto3') +@mock.patch(STORAGE_MODULE + ".boto3") def test_full_name_key_root_bucket_dir(mock_storage): # given - bucket_name = 'foo' - object_key = 'name.pt' + bucket_name = "foo" + object_key = "name.pt" # when mock_boto3_bucket = create_mock_boto3_bucket(mock_storage, [object_key]) - Storage._download_s3(f's3://{bucket_name}/{object_key}', 'dest_path') + Storage._download_s3(f"s3://{bucket_name}/{object_key}", "dest_path") # then arg_list = get_call_args(mock_boto3_bucket.download_file.call_args_list) - assert arg_list == expected_call_args_list_single_obj('dest_path', - object_key) + assert arg_list == expected_call_args_list_single_obj("dest_path", object_key) mock_boto3_bucket.objects.filter.assert_called_with(Prefix=object_key) -AWS_TEST_CREDENTIALS = {"AWS_ACCESS_KEY_ID": "testing", - "AWS_SECRET_ACCESS_KEY": "testing", - "AWS_SECURITY_TOKEN": "testing", - "AWS_SESSION_TOKEN": "testing"} +AWS_TEST_CREDENTIALS = { + "AWS_ACCESS_KEY_ID": "testing", + "AWS_SECRET_ACCESS_KEY": "testing", + "AWS_SECURITY_TOKEN": "testing", + "AWS_SESSION_TOKEN": "testing", +} -@mock.patch(STORAGE_MODULE + '.boto3') +@mock.patch(STORAGE_MODULE + ".boto3") def test_multikey(mock_storage): # given - bucket_name = 'foo' - paths = ['b/model.bin'] - object_paths = ['test/a/' + p for p in paths] + bucket_name = "foo" + paths = ["b/model.bin"] + object_paths = ["test/a/" + p for p in paths] # when mock_boto3_bucket = create_mock_boto3_bucket(mock_storage, object_paths) - Storage._download_s3(f's3://{bucket_name}/test/a', 'dest_path') + Storage._download_s3(f"s3://{bucket_name}/test/a", "dest_path") # then arg_list = get_call_args(mock_boto3_bucket.download_file.call_args_list) - assert arg_list == expected_call_args_list('test/a', 'dest_path', paths) + assert arg_list == expected_call_args_list("test/a", "dest_path", paths) - mock_boto3_bucket.objects.filter.assert_called_with(Prefix='test/a') + mock_boto3_bucket.objects.filter.assert_called_with(Prefix="test/a") -@mock.patch(STORAGE_MODULE + '.boto3') +@mock.patch(STORAGE_MODULE + ".boto3") def test_files_with_no_extension(mock_storage): - # given - bucket_name = 'foo' - paths = ['churn-pickle', 'churn-pickle-logs', 'churn-pickle-report'] - object_paths = ['test/' + p for p in paths] + bucket_name = "foo" + paths = ["churn-pickle", "churn-pickle-logs", "churn-pickle-report"] + object_paths = ["test/" + p for p in paths] # when mock_boto3_bucket = create_mock_boto3_bucket(mock_storage, object_paths) - Storage._download_s3(f's3://{bucket_name}/test/churn-pickle', 'dest_path') + Storage._download_s3(f"s3://{bucket_name}/test/churn-pickle", "dest_path") # then arg_list = get_call_args(mock_boto3_bucket.download_file.call_args_list) # Download only the exact file if found; otherwise, download all files with the given prefix - assert arg_list[0] == expected_call_args_list('test', 'dest_path', paths)[0] + assert arg_list[0] == expected_call_args_list("test", "dest_path", paths)[0] - mock_boto3_bucket.objects.filter.assert_called_with(Prefix='test/churn-pickle') + mock_boto3_bucket.objects.filter.assert_called_with(Prefix="test/churn-pickle") def test_get_S3_config(): @@ -221,7 +217,10 @@ def test_get_S3_config(): with mock.patch.dict(os.environ, {"S3_USE_ACCELERATE": "True"}): config7 = Storage.get_S3_config() - assert config7.s3["use_accelerate_endpoint"] == USE_ACCELERATE_CONFIG.s3["use_accelerate_endpoint"] + assert ( + config7.s3["use_accelerate_endpoint"] + == USE_ACCELERATE_CONFIG.s3["use_accelerate_endpoint"] + ) def 
@@ -271,20 +270,20 @@ def test_update_with_storage_spec_s3(monkeypatch):
     os.environ = previous_env
 
 
-@mock.patch(STORAGE_MODULE + '.boto3')
+@mock.patch(STORAGE_MODULE + ".boto3")
 def test_target_startswith_parent_folder_name(mock_storage):
-    bucket_name = 'foo'
+    bucket_name = "foo"
     paths = ["model.pkl", "a/model.pkl", "conda.yaml"]
-    object_paths = ['test/artifacts/model/' + p for p in paths]
+    object_paths = ["test/artifacts/model/" + p for p in paths]
 
     # when
     mock_boto3_bucket = create_mock_boto3_bucket(mock_storage, object_paths)
-    Storage._download_s3(
-        f's3://{bucket_name}/test/artifacts/model', 'dest_path')
+    Storage._download_s3(f"s3://{bucket_name}/test/artifacts/model", "dest_path")
 
     # then
     arg_list = get_call_args(mock_boto3_bucket.download_file.call_args_list)
-    assert arg_list[0] == expected_call_args_list(
-        'test/artifacts/model', 'dest_path', paths)[0]
-    mock_boto3_bucket.objects.filter.assert_called_with(
-        Prefix='test/artifacts/model')
+    assert (
+        arg_list[0]
+        == expected_call_args_list("test/artifacts/model", "dest_path", paths)[0]
+    )
+    mock_boto3_bucket.objects.filter.assert_called_with(Prefix="test/artifacts/model")

From f917fa18fc9746d48e6c5ff3e15e156b1f2cd5bc Mon Sep 17 00:00:00 2001
From: tjandy98 <3953059+tjandy98@users.noreply.github.com>
Date: Thu, 11 Apr 2024 21:43:38 +0800
Subject: [PATCH 4/7] Flake8

Signed-off-by: tjandy98 <3953059+tjandy98@users.noreply.github.com>
---
 python/kserve/kserve/storage/test/test_azure_storage.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/kserve/kserve/storage/test/test_azure_storage.py b/python/kserve/kserve/storage/test/test_azure_storage.py
index 5f19912e5c..05b8280255 100644
--- a/python/kserve/kserve/storage/test/test_azure_storage.py
+++ b/python/kserve/kserve/storage/test/test_azure_storage.py
@@ -384,4 +384,4 @@ def test_file_share_no_prefix(
     mock_get_access_key.assert_called()
     mock_storage.assert_called_with(
         "https://accountname.file.core.windows.net", credential="some_token"
-    )
\ No newline at end of file
+    )

From 36a0b8035eadf5e8714ae915aec4b4ec5b064cc0 Mon Sep 17 00:00:00 2001
From: tjandy98 <3953059+tjandy98@users.noreply.github.com>
Date: Sun, 5 May 2024 15:12:03 +0800
Subject: [PATCH 5/7] Use `from_service_account_info` for loading service account

Signed-off-by: tjandy98 <3953059+tjandy98@users.noreply.github.com>
---
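Note: the point of this change is that the decoded key no longer has to be
written to a GOOGLE_APPLICATION_CREDENTIALS file on disk; google-auth can
build credentials straight from the in-memory dict. A minimal sketch of that
flow outside the patch, where `encoded` is a placeholder standing in for the
base64 value carried by the storage secret:

    import base64
    import json

    from google.cloud import storage
    from google.oauth2 import service_account

    # Decode the secret value into the service-account key as a dict.
    encoded = "<base64-encoded service account key>"  # hypothetical input
    sa_info = json.loads(base64.b64decode(encoded).decode("utf-8"))

    # from_service_account_info accepts the parsed key directly, so no
    # temporary credentials file is needed.
    credentials = service_account.Credentials.from_service_account_info(sa_info)
    client = storage.Client(credentials=credentials)
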
 python/kserve/kserve/storage/storage.py          | 34 +++++++++++-------
 .../kserve/storage/test/test_storage.py          | 35 +++++++++----------
 2 files changed, 38 insertions(+), 31 deletions(-)

diff --git a/python/kserve/kserve/storage/storage.py b/python/kserve/kserve/storage/storage.py
index c6b023d61d..4f58e9aa1d 100644
--- a/python/kserve/kserve/storage/storage.py
+++ b/python/kserve/kserve/storage/storage.py
@@ -37,6 +37,7 @@
 from botocore import UNSIGNED
 from botocore.client import Config
 from google.auth import exceptions
+from google.oauth2 import service_account
 from google.cloud import storage
 
 MODEL_MOUNT_DIRS = "/mnt/models"
@@ -149,20 +150,18 @@ def _update_with_storage_spec():
                 f.flush()
 
         if storage_secret_json.get("type", "") == "gs":
-            temp_dir = tempfile.mkdtemp()
-            credential_dir = temp_dir + "/" + "google_application_credentials.json"
-            os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = credential_dir
             if storage_secret_json.get("base64_service_account", "") != "":
                 try:
-                    with open(credential_dir, "w") as f:
-                        base64_service_account = storage_secret_json.get(
-                            "base64_service_account", ""
-                        )
-                        service_account = base64.b64decode(
-                            base64_service_account
-                        ).decode("utf-8")
-                        f.write(service_account)
-                        f.flush()
+                    base64_service_account = storage_secret_json.get(
+                        "base64_service_account", ""
+                    )
+                    service_account_str = base64.b64decode(
+                        base64_service_account
+                    ).decode("utf-8")
+                    service_account_dict = json.loads(service_account_str)
+                    os.environ["GOOGLE_SERVICE_ACCOUNT"] = json.dumps(
+                        service_account_dict
+                    )
                 except binascii.Error:
                     raise RuntimeError("Error: Invalid base64 encoding.")
                 except UnicodeDecodeError:
@@ -299,9 +298,18 @@ def _download_s3(uri, temp_dir: str):
     @staticmethod
     def _download_gcs(uri, temp_dir: str):
         try:
-            storage_client = storage.Client()
+            credentials = None
+            if "GOOGLE_SERVICE_ACCOUNT" in os.environ:
+                google_service_account = json.loads(
+                    os.environ["GOOGLE_SERVICE_ACCOUNT"]
+                )
+                credentials = service_account.Credentials.from_service_account_info(
+                    google_service_account
+                )
+            storage_client = storage.Client(credentials=credentials)
         except exceptions.DefaultCredentialsError:
             storage_client = storage.Client.create_anonymous_client()
+
         bucket_args = uri.replace(_GCS_PREFIX, "", 1).split("/", 1)
         bucket_name = bucket_args[0]
         bucket_path = bucket_args[1] if len(bucket_args) > 1 else ""
diff --git a/python/kserve/kserve/storage/test/test_storage.py b/python/kserve/kserve/storage/test/test_storage.py
index e320132c7a..23d2de9f78 100644
--- a/python/kserve/kserve/storage/test/test_storage.py
+++ b/python/kserve/kserve/storage/test/test_storage.py
@@ -290,22 +290,21 @@ def test_unpack_zip_file():
     os.remove(os.path.join(out_dir, "model.pth"))
 
 
-@mock.patch("os.environ")
-def test_gs_storage(mock_os):
-    def side_effect(key, default=None):
-        if key == "STORAGE_CONFIG":
-            return json.dumps(
-                {
-                    "type": "gs",
-                    "base64_service_account": base64.b64encode(
-                        b"service_account_content"
-                    ).decode(),
-                }
-            )
-        return default
-
-    mock_os.get.side_effect = side_effect
+@mock.patch.dict(
+    "os.environ",
+    {
+        "STORAGE_CONFIG": json.dumps(
+            {
+                "type": "gs",
+                "base64_service_account": base64.b64encode(
+                    json.dumps({"key": "value"}).encode("utf-8")
+                ).decode("utf-8"),
+            }
+        )
+    },
+    clear=True,
+)
+def test_gs_storage_spec():
     Storage._update_with_storage_spec()
-    credential_dir = mock_os.__setitem__.call_args_list[0][0][1]
-    with open(credential_dir, "r") as f:
-        assert f.read() == "service_account_content"
+    assert "GOOGLE_SERVICE_ACCOUNT" in os.environ
+    assert json.loads(os.environ["GOOGLE_SERVICE_ACCOUNT"]) == {"key": "value"}

From 7041a31ed6b548f07c1f5475e8ff2bc8aed8abd3 Mon Sep 17 00:00:00 2001
From: tjandy98 <3953059+tjandy98@users.noreply.github.com>
Date: Tue, 7 May 2024 22:46:59 +0800
Subject: [PATCH 6/7] Trigger CI

Signed-off-by: tjandy98 <3953059+tjandy98@users.noreply.github.com>

From e6e48b7fd982011c7f5b74665f01462283e035bd Mon Sep 17 00:00:00 2001
From: tjandy98 <3953059+tjandy98@users.noreply.github.com>
Date: Tue, 27 Aug 2024 14:34:04 +0800
Subject: [PATCH 7/7] Rename base64_service_account to base64_service_account_key_file

Signed-off-by: tjandy98 <3953059+tjandy98@users.noreply.github.com>
---
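Note: only the key name inside the storage-spec secret changes; anything that
populates the secret must use the new name. A short sketch, with placeholder
values, of what a "gs" entry could carry after this rename (the `bucket` and
project fields are illustrative assumptions, not mandated by this patch):

    import base64
    import json

    # A truncated, placeholder service-account key file.
    key_file = {"type": "service_account", "project_id": "my-project"}

    storage_spec_entry = json.dumps(
        {
            "type": "gs",
            "bucket": "my-models",  # hypothetical bucket name
            "base64_service_account_key_file": base64.b64encode(
                json.dumps(key_file).encode("utf-8")
            ).decode("utf-8"),
        }
    )
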
 python/kserve/kserve/storage/storage.py           | 8 ++++----
 python/kserve/kserve/storage/test/test_storage.py | 2 +-
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/python/kserve/kserve/storage/storage.py b/python/kserve/kserve/storage/storage.py
index 8179f1286d..5934cb0fbb 100644
--- a/python/kserve/kserve/storage/storage.py
+++ b/python/kserve/kserve/storage/storage.py
@@ -155,13 +155,13 @@ def _update_with_storage_spec():
                 f.flush()
 
         if storage_secret_json.get("type", "") == "gs":
-            if storage_secret_json.get("base64_service_account", "") != "":
+            if storage_secret_json.get("base64_service_account_key_file", "") != "":
                 try:
-                    base64_service_account = storage_secret_json.get(
-                        "base64_service_account", ""
+                    base64_service_account_key_file = storage_secret_json.get(
+                        "base64_service_account_key_file", ""
                     )
                     service_account_str = base64.b64decode(
-                        base64_service_account
+                        base64_service_account_key_file
                     ).decode("utf-8")
                     service_account_dict = json.loads(service_account_str)
                     os.environ["GOOGLE_SERVICE_ACCOUNT"] = json.dumps(
diff --git a/python/kserve/kserve/storage/test/test_storage.py b/python/kserve/kserve/storage/test/test_storage.py
index 535b8af176..09ca2e9e15 100644
--- a/python/kserve/kserve/storage/test/test_storage.py
+++ b/python/kserve/kserve/storage/test/test_storage.py
@@ -313,7 +313,7 @@ def test_unpack_zip_file():
         "STORAGE_CONFIG": json.dumps(
             {
                 "type": "gs",
-                "base64_service_account": base64.b64encode(
+                "base64_service_account_key_file": base64.b64encode(
                     json.dumps({"key": "value"}).encode("utf-8")
                 ).decode("utf-8"),
             },