diff --git a/pkg/credentials/service_account_credentials.go b/pkg/credentials/service_account_credentials.go
index 086d6a7552..8e55510c2a 100644
--- a/pkg/credentials/service_account_credentials.go
+++ b/pkg/credentials/service_account_credentials.go
@@ -51,8 +51,8 @@ const (
 )
 
 var (
-	SupportedStorageSpecTypes = []string{"s3", "hdfs", "webhdfs"}
-	StorageBucketTypes        = []string{"s3"}
+	SupportedStorageSpecTypes = []string{"s3", "hdfs", "webhdfs", "gs"}
+	StorageBucketTypes        = []string{"s3", "gs"}
 )
 
 type CredentialConfig struct {
diff --git a/pkg/credentials/service_account_credentials_test.go b/pkg/credentials/service_account_credentials_test.go
index e305b7920d..1c0b52ee1f 100644
--- a/pkg/credentials/service_account_credentials_test.go
+++ b/pkg/credentials/service_account_credentials_test.go
@@ -1429,7 +1429,7 @@ func TestCredentialBuilder_CreateStorageSpecSecretEnvs(t *testing.T) {
					Name:      "storage-secret",
					Namespace: namespace,
				},
-				StringData: map[string]string{"minio": "{\n \"type\": \"gs\",\n \"access_key_id\": \"minio\",\n \"secret_access_key\": \"minio123\",\n \"endpoint_url\": \"http://minio-service.kubeflow:9000\",\n \"bucket\": \"test-bucket\",\n \"region\": \"us-south\"\n }"},
+				StringData: map[string]string{"minio": "{\n \"type\": \"gss\",\n \"access_key_id\": \"minio\",\n \"secret_access_key\": \"minio123\",\n \"endpoint_url\": \"http://minio-service.kubeflow:9000\",\n \"bucket\": \"test-bucket\",\n \"region\": \"us-south\"\n }"},
			},
			storageKey:        "minio",
			storageSecretName: "storage-secret",
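The Go change above registers `gs` both as a supported storage-spec type and as a bucket-style type, so the controller now accepts `"type": "gs"` in a storage-spec secret; the unsupported-type test case accordingly switches its invalid type from `gs` to the still-invalid `gss`. Below is a minimal sketch of how a `gs` storage-spec payload might be produced; the service-account dict is a placeholder, not a real credential:

```python
import base64
import json

# Placeholder payload; a real GCP service-account JSON would go here.
sa_info = {"type": "service_account", "project_id": "example-project"}

storage_config = {
    "type": "gs",
    "base64_service_account": base64.b64encode(
        json.dumps(sa_info).encode("utf-8")
    ).decode("utf-8"),
}

# This JSON is what _update_with_storage_spec() reads from the
# STORAGE_CONFIG env var (see the storage.py diff below).
print(json.dumps(storage_config))
```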
diff --git a/python/kserve/kserve/storage/storage.py b/python/kserve/kserve/storage/storage.py
index 3a46c969b6..4f58e9aa1d 100644
--- a/python/kserve/kserve/storage/storage.py
+++ b/python/kserve/kserve/storage/storage.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import base64
+import binascii
 import glob
 import gzip
 import json
@@ -36,6 +37,7 @@
 from botocore import UNSIGNED
 from botocore.client import Config
 from google.auth import exceptions
+from google.oauth2 import service_account
 from google.cloud import storage
 
 MODEL_MOUNT_DIRS = "/mnt/models"
@@ -147,6 +149,24 @@ def _update_with_storage_spec():
                     f.write(value)
                     f.flush()
 
+        if storage_secret_json.get("type", "") == "gs":
+            if storage_secret_json.get("base64_service_account", "") != "":
+                try:
+                    base64_service_account = storage_secret_json.get(
+                        "base64_service_account", ""
+                    )
+                    service_account_str = base64.b64decode(
+                        base64_service_account
+                    ).decode("utf-8")
+                    service_account_dict = json.loads(service_account_str)
+                    os.environ["GOOGLE_SERVICE_ACCOUNT"] = json.dumps(
+                        service_account_dict
+                    )
+                except binascii.Error:
+                    raise RuntimeError("Error: Invalid base64 encoding.")
+                except UnicodeDecodeError:
+                    raise RuntimeError("Error: Cannot decode string.")
+
     @staticmethod
     def get_S3_config():
         # default s3 config
@@ -222,8 +242,10 @@ def _download_s3(uri, temp_dir: str):
             # Skip where boto3 lists the directory as an object
             if obj.key.endswith("/"):
                 continue
-            # In the case where bucket_path points to a single object, set the target key to bucket_path
-            # Otherwise, remove the bucket_path prefix, strip any extra slashes, then prepend the target_dir
+            # In the case where bucket_path points to a single object,
+            # set the target key to bucket_path
+            # Otherwise, remove the bucket_path prefix, strip any extra slashes,
+            # then prepend the target_dir
             # Example:
             # s3://test-bucket
             # Objects: /a/b/c/model.bin /a/model.bin /model.bin
@@ -258,7 +280,8 @@ def _download_s3(uri, temp_dir: str):
             logging.info("Downloaded object %s to %s" % (obj.key, target))
             file_count += 1
 
-            # If the exact object is found, then it is sufficient to download that and break the loop
+            # If the exact object is found, then it is sufficient to download that
+            # and break the loop
             if exact_obj_found:
                 break
         if file_count == 0:
@@ -275,9 +298,18 @@ def _download_s3(uri, temp_dir: str):
     @staticmethod
     def _download_gcs(uri, temp_dir: str):
         try:
-            storage_client = storage.Client()
+            credentials = None
+            if "GOOGLE_SERVICE_ACCOUNT" in os.environ:
+                google_service_account = json.loads(
+                    os.environ["GOOGLE_SERVICE_ACCOUNT"]
+                )
+                credentials = service_account.Credentials.from_service_account_info(
+                    google_service_account
+                )
+            storage_client = storage.Client(credentials=credentials)
         except exceptions.DefaultCredentialsError:
             storage_client = storage.Client.create_anonymous_client()
+
         bucket_args = uri.replace(_GCS_PREFIX, "", 1).split("/", 1)
         bucket_name = bucket_args[0]
         bucket_path = bucket_args[1] if len(bucket_args) > 1 else ""
@@ -359,9 +391,9 @@ def _download_hdfs(uri, out_dir: str):
         # Remove hdfs:// or webhdfs:// from the uri to get just the path
         # e.g. hdfs://user/me/model -> user/me/model
         if uri.startswith(_HDFS_PREFIX):
-            path = uri[len(_HDFS_PREFIX) :]
+            path = uri[len(_HDFS_PREFIX) :]  # noqa: E203
         else:
-            path = uri[len(_WEBHDFS_PREFIX) :]
+            path = uri[len(_WEBHDFS_PREFIX) :]  # noqa: E203
 
         if not config["HDFS_ROOTPATH"]:
             path = "/" + path
@@ -443,7 +475,8 @@ def _download_azure_blob(uri, out_dir: str):  # pylint: disable=too-many-locals
             )
             if token is None:
                 logging.warning(
-                    "Azure credentials or shared access signature token not found, retrying anonymous access"
+                    "Azure credentials or shared access signature token not found, "
+                    "retrying anonymous access"
                 )
 
             blob_service_client = BlobServiceClient(account_url, credential=token)
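Taken together, the storage.py changes split GCS credential handling into two steps: `_update_with_storage_spec()` decodes `base64_service_account` from the storage-spec secret into the `GOOGLE_SERVICE_ACCOUNT` env var, and `_download_gcs()` turns that env var into explicit credentials, falling back first to application default credentials and finally to an anonymous client. A standalone sketch of that resolution order, using only calls that appear in the diff:

```python
import json
import os

from google.auth import exceptions
from google.cloud import storage
from google.oauth2 import service_account


def make_gcs_client() -> storage.Client:
    # Mirrors the credential resolution in _download_gcs above.
    try:
        credentials = None
        if "GOOGLE_SERVICE_ACCOUNT" in os.environ:
            info = json.loads(os.environ["GOOGLE_SERVICE_ACCOUNT"])
            credentials = service_account.Credentials.from_service_account_info(info)
        # credentials=None lets the client fall back to application
        # default credentials.
        return storage.Client(credentials=credentials)
    except exceptions.DefaultCredentialsError:
        # No usable credentials at all: anonymous access to public buckets.
        return storage.Client.create_anonymous_client()
```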
diff --git a/python/kserve/kserve/storage/test/test_s3_storage.py b/python/kserve/kserve/storage/test/test_s3_storage.py
index 00ef82d0ac..a4f9eecb31 100644
--- a/python/kserve/kserve/storage/test/test_s3_storage.py
+++ b/python/kserve/kserve/storage/test/test_s3_storage.py
@@ -157,7 +157,6 @@ def test_multikey(mock_storage):
 
 
 @mock.patch(STORAGE_MODULE + ".boto3")
 def test_files_with_no_extension(mock_storage):
-    # given
     bucket_name = "foo"
     paths = ["churn-pickle", "churn-pickle-logs", "churn-pickle-report"]
diff --git a/python/kserve/kserve/storage/test/test_storage.py b/python/kserve/kserve/storage/test/test_storage.py
index 63b16893eb..23d2de9f78 100644
--- a/python/kserve/kserve/storage/test/test_storage.py
+++ b/python/kserve/kserve/storage/test/test_storage.py
@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import base64
 import io
+import json
 import os
 import tempfile
 import binascii
@@ -286,3 +288,23 @@ def test_unpack_zip_file():
     Storage._unpack_archive_file(tar_file, mimetype, out_dir)
     assert os.path.exists(os.path.join(out_dir, "model.pth"))
     os.remove(os.path.join(out_dir, "model.pth"))
+
+
+@mock.patch.dict(
+    "os.environ",
+    {
+        "STORAGE_CONFIG": json.dumps(
+            {
+                "type": "gs",
+                "base64_service_account": base64.b64encode(
+                    json.dumps({"key": "value"}).encode("utf-8")
+                ).decode("utf-8"),
+            }
+        )
+    },
+    clear=True,
+)
+def test_gs_storage_spec():
+    Storage._update_with_storage_spec()
+    assert "GOOGLE_SERVICE_ACCOUNT" in os.environ
+    assert json.loads(os.environ["GOOGLE_SERVICE_ACCOUNT"]) == {"key": "value"}
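The new `test_gs_storage_spec` covers the happy path end to end: a base64 round trip of `{"key": "value"}` lands in `GOOGLE_SERVICE_ACCOUNT` unchanged. A companion sketch for the error path could look like the following; the test name and invalid payload are hypothetical, and the imports assume the surrounding test module's conventions:

```python
import json
from unittest import mock

import pytest

from kserve.storage import Storage


@mock.patch.dict(
    "os.environ",
    {
        "STORAGE_CONFIG": json.dumps(
            {"type": "gs", "base64_service_account": "not-base64!"}
        )
    },
    clear=True,
)
def test_gs_storage_spec_invalid_base64():
    # Malformed base64 should surface as the RuntimeError raised
    # in _update_with_storage_spec.
    with pytest.raises(RuntimeError):
        Storage._update_with_storage_spec()
```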