Skip to content

Commit

Permalink
fix: LVM - Added support for GCS storage.googleapis.com URL import …
Browse files Browse the repository at this point in the history
…in `vision_models.Image`

PiperOrigin-RevId: 612948528
  • Loading branch information
vertex-sdk-bot authored and Copybara-Service committed Mar 5, 2024
1 parent 9eb5a52 commit 2690e72
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 1 deletion.
46 changes: 46 additions & 0 deletions tests/unit/aiplatform/test_vision_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,12 @@ def generate_image_from_gcs_uri(
return ga_vision_models.Image.load_from_file(gcs_uri)


def generate_image_from_storage_url(
gcs_uri: str = "https://storage.googleapis.com/cloud-samples-data/vertex-ai/llm/prompts/landmark1.png",
) -> ga_vision_models.Image:
return ga_vision_models.Image.load_from_file(gcs_uri)


def generate_video_from_gcs_uri(
gcs_uri: str = "gs://cloud-samples-data/vertex-ai-vision/highway_vehicles.mp4",
) -> ga_vision_models.Video:
Expand Down Expand Up @@ -894,6 +900,46 @@ def test_image_embedding_model_with_gcs_uri(self):
assert embedding_response.image_embedding == test_embeddings
assert embedding_response.text_embedding == test_embeddings

def test_image_embedding_model_with_storage_url(self):
aiplatform.init(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
)
with mock.patch.object(
target=model_garden_service_client.ModelGardenServiceClient,
attribute="get_publisher_model",
return_value=gca_publisher_model.PublisherModel(
_IMAGE_EMBEDDING_PUBLISHER_MODEL_DICT
),
):
model = preview_vision_models.MultiModalEmbeddingModel.from_pretrained(
"multimodalembedding@001"
)

test_embeddings = [0, 0]
gca_predict_response = gca_prediction_service.PredictResponse()
gca_predict_response.predictions.append(
{"imageEmbedding": test_embeddings, "textEmbedding": test_embeddings}
)

image = generate_image_from_storage_url()
assert (
image._gcs_uri
== "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
)

with mock.patch.object(
target=prediction_service_client.PredictionServiceClient,
attribute="predict",
return_value=gca_predict_response,
):
embedding_response = model.get_embeddings(
image=image, contextual_text="hello world"
)

assert embedding_response.image_embedding == test_embeddings
assert embedding_response.text_embedding == test_embeddings

def test_video_embedding_model_with_only_video(self):
aiplatform.init(
project=_TEST_PROJECT,
Expand Down
14 changes: 13 additions & 1 deletion vertexai/vision_models/_vision_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import pathlib
import typing
from typing import Any, Dict, List, Optional, Union
import urllib

from google.cloud import storage

Expand Down Expand Up @@ -80,9 +81,20 @@ def load_from_file(location: str) -> "Image":
Returns:
Loaded image as an `Image` object.
"""
if location.startswith("gs://"):
parsed_url = urllib.parse.urlparse(location)
if (
parsed_url.scheme == "https"
and parsed_url.netloc == "storage.googleapis.com"
):
parsed_url = parsed_url._replace(
scheme="gs", netloc="", path=f"/{urllib.parse.unquote(parsed_url.path)}"
)
location = urllib.parse.urlunparse(parsed_url)

if parsed_url.scheme == "gs":
return Image(gcs_uri=location)

# Load image from local path
image_bytes = pathlib.Path(location).read_bytes()
image = Image(image_bytes=image_bytes)
return image
Expand Down

0 comments on commit 2690e72

Please sign in to comment.