Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Google Cloud Storage support to Storage Spec #3495

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
4 changes: 2 additions & 2 deletions pkg/credentials/service_account_credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ const (
)

var (
SupportedStorageSpecTypes = []string{"s3", "hdfs", "webhdfs"}
StorageBucketTypes = []string{"s3"}
SupportedStorageSpecTypes = []string{"s3", "hdfs", "webhdfs", "gs"}
StorageBucketTypes = []string{"s3", "gs"}
)

type CredentialConfig struct {
Expand Down
2 changes: 1 addition & 1 deletion pkg/credentials/service_account_credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1429,7 +1429,7 @@ func TestCredentialBuilder_CreateStorageSpecSecretEnvs(t *testing.T) {
Name: "storage-secret",
Namespace: namespace,
},
StringData: map[string]string{"minio": "{\n \"type\": \"gs\",\n \"access_key_id\": \"minio\",\n \"secret_access_key\": \"minio123\",\n \"endpoint_url\": \"http://minio-service.kubeflow:9000\",\n \"bucket\": \"test-bucket\",\n \"region\": \"us-south\"\n }"},
StringData: map[string]string{"minio": "{\n \"type\": \"gss\",\n \"access_key_id\": \"minio\",\n \"secret_access_key\": \"minio123\",\n \"endpoint_url\": \"http://minio-service.kubeflow:9000\",\n \"bucket\": \"test-bucket\",\n \"region\": \"us-south\"\n }"},
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this intentional? Why wasn't this failing before?

Copy link
Contributor Author

@tjandy98 tjandy98 Mar 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this intentional? Why wasn't this failing before?

Hi @terrytangyuan , thanks for your feedback.

Yes, this change is intentional. Before this PR the 'gs' type wasn't supported, so the test would have failed.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Quite related to Dan's comment, but I don't think it makes sense to force this to fail by making this change. The test should be reworked to properly validate the GS support (or fully removed, with a new one created from scratch).

},
storageKey: "minio",
storageSecretName: "storage-secret",
Expand Down
47 changes: 40 additions & 7 deletions python/kserve/kserve/storage/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import base64
import binascii
import glob
import gzip
import json
Expand All @@ -36,6 +37,7 @@
from botocore import UNSIGNED
from botocore.client import Config
from google.auth import exceptions
from google.oauth2 import service_account
from google.cloud import storage

MODEL_MOUNT_DIRS = "/mnt/models"
Expand Down Expand Up @@ -147,6 +149,24 @@ def _update_with_storage_spec():
f.write(value)
f.flush()

if storage_secret_json.get("type", "") == "gs":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a test?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe move gs and other supported types to constants?

if storage_secret_json.get("base64_service_account", "") != "":
try:
base64_service_account = storage_secret_json.get(
"base64_service_account", ""
)
service_account_str = base64.b64decode(
base64_service_account
).decode("utf-8")
service_account_dict = json.loads(service_account_str)
os.environ["GOOGLE_SERVICE_ACCOUNT"] = json.dumps(
service_account_dict
)
except binascii.Error:
raise RuntimeError("Error: Invalid base64 encoding.")
except UnicodeDecodeError:
raise RuntimeError("Error: Cannot decode string.")

@staticmethod
def get_S3_config():
# default s3 config
Expand Down Expand Up @@ -222,8 +242,10 @@ def _download_s3(uri, temp_dir: str):
# Skip where boto3 lists the directory as an object
if obj.key.endswith("/"):
continue
# In the case where bucket_path points to a single object, set the target key to bucket_path
# Otherwise, remove the bucket_path prefix, strip any extra slashes, then prepend the target_dir
# In the case where bucket_path points to a single object,
# set the target key to bucket_path
# Otherwise, remove the bucket_path prefix, strip any extra slashes,
# then prepend the target_dir
# Example:
# s3://test-bucket
# Objects: /a/b/c/model.bin /a/model.bin /model.bin
Expand Down Expand Up @@ -258,7 +280,8 @@ def _download_s3(uri, temp_dir: str):
logging.info("Downloaded object %s to %s" % (obj.key, target))
file_count += 1

# If the exact object is found, then it is sufficient to download that and break the loop
# If the exact object is found, then it is sufficient to download that
# and break the loop
if exact_obj_found:
break
if file_count == 0:
Expand All @@ -275,9 +298,18 @@ def _download_s3(uri, temp_dir: str):
@staticmethod
def _download_gcs(uri, temp_dir: str):
try:
storage_client = storage.Client()
credentials = None
if "GOOGLE_SERVICE_ACCOUNT" in os.environ:
google_service_account = json.loads(
os.environ["GOOGLE_SERVICE_ACCOUNT"]
)
credentials = service_account.Credentials.from_service_account_info(
google_service_account
)
storage_client = storage.Client(credentials=credentials)
except exceptions.DefaultCredentialsError:
storage_client = storage.Client.create_anonymous_client()

bucket_args = uri.replace(_GCS_PREFIX, "", 1).split("/", 1)
bucket_name = bucket_args[0]
bucket_path = bucket_args[1] if len(bucket_args) > 1 else ""
Expand Down Expand Up @@ -359,9 +391,9 @@ def _download_hdfs(uri, out_dir: str):
# Remove hdfs:// or webhdfs:// from the uri to get just the path
# e.g. hdfs://user/me/model -> user/me/model
if uri.startswith(_HDFS_PREFIX):
path = uri[len(_HDFS_PREFIX) :]
path = uri[len(_HDFS_PREFIX) :] # noqa: E203
else:
path = uri[len(_WEBHDFS_PREFIX) :]
path = uri[len(_WEBHDFS_PREFIX) :] # noqa: E203

if not config["HDFS_ROOTPATH"]:
path = "/" + path
Expand Down Expand Up @@ -443,7 +475,8 @@ def _download_azure_blob(uri, out_dir: str): # pylint: disable=too-many-locals
)
if token is None:
logging.warning(
"Azure credentials or shared access signature token not found, retrying anonymous access"
"Azure credentials or shared access signature token not found, \
retrying anonymous access"
)

blob_service_client = BlobServiceClient(account_url, credential=token)
Expand Down
1 change: 0 additions & 1 deletion python/kserve/kserve/storage/test/test_s3_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,6 @@ def test_multikey(mock_storage):

@mock.patch(STORAGE_MODULE + ".boto3")
def test_files_with_no_extension(mock_storage):

# given
bucket_name = "foo"
paths = ["churn-pickle", "churn-pickle-logs", "churn-pickle-report"]
Expand Down
22 changes: 22 additions & 0 deletions python/kserve/kserve/storage/test/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import io
import json
import os
import tempfile
import binascii
Expand Down Expand Up @@ -286,3 +288,23 @@ def test_unpack_zip_file():
Storage._unpack_archive_file(tar_file, mimetype, out_dir)
assert os.path.exists(os.path.join(out_dir, "model.pth"))
os.remove(os.path.join(out_dir, "model.pth"))


# Minimal fake service-account payload used to exercise the gs storage-spec path.
_FAKE_SERVICE_ACCOUNT = {"key": "value"}
_ENCODED_SERVICE_ACCOUNT = base64.b64encode(
    json.dumps(_FAKE_SERVICE_ACCOUNT).encode("utf-8")
).decode("utf-8")


@mock.patch.dict(
    "os.environ",
    {
        "STORAGE_CONFIG": json.dumps(
            {
                "type": "gs",
                "base64_service_account": _ENCODED_SERVICE_ACCOUNT,
            }
        )
    },
    clear=True,
)
def test_gs_storage_spec():
    """A gs storage spec decodes base64_service_account into GOOGLE_SERVICE_ACCOUNT."""
    Storage._update_with_storage_spec()
    assert "GOOGLE_SERVICE_ACCOUNT" in os.environ
    assert json.loads(os.environ["GOOGLE_SERVICE_ACCOUNT"]) == _FAKE_SERVICE_ACCOUNT