feat: LVM - Added the `MultiModalEmbeddingModel.get_embeddings(dimension=...)` parameter

PiperOrigin-RevId: 599457605
vertex-sdk-bot authored and Copybara-Service committed Jan 18, 2024
1 parent: 80d5c56 · commit: 1d9bd23
Showing 2 changed files with 63 additions and 11 deletions.
37 changes: 37 additions & 0 deletions tests/unit/aiplatform/test_vision_models.py
@@ -684,6 +684,43 @@ def test_image_embedding_model_with_only_text(self):
         assert not embedding_response.image_embedding
         assert embedding_response.text_embedding == test_embeddings
 
+    def test_image_embedding_model_with_lower_dimensions(self):
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            location=_TEST_LOCATION,
+        )
+        with mock.patch.object(
+            target=model_garden_service_client.ModelGardenServiceClient,
+            attribute="get_publisher_model",
+            return_value=gca_publisher_model.PublisherModel(
+                _IMAGE_EMBEDDING_PUBLISHER_MODEL_DICT
+            ),
+        ):
+            model = preview_vision_models.MultiModalEmbeddingModel.from_pretrained(
+                "multimodalembedding@001"
+            )
+
+        dimension = 128
+        test_embeddings = [0] * dimension
+        gca_predict_response = gca_prediction_service.PredictResponse()
+        gca_predict_response.predictions.append(
+            {"imageEmbedding": test_embeddings, "textEmbedding": test_embeddings}
+        )
+
+        image = generate_image_from_file()
+
+        with mock.patch.object(
+            target=prediction_service_client.PredictionServiceClient,
+            attribute="predict",
+            return_value=gca_predict_response,
+        ):
+            embedding_response = model.get_embeddings(
+                image=image, contextual_text="hello world", dimension=dimension
+            )
+
+        assert embedding_response.image_embedding == test_embeddings
+        assert embedding_response.text_embedding == test_embeddings
+
 
 @pytest.mark.usefixtures("google_auth_mock")
 class ImageTextModelTests:
37 changes: 26 additions & 11 deletions vertexai/vision_models/_vision_models.py
@@ -142,7 +142,7 @@ def _generate_images(
         seed: Optional[int] = None,
         base_image: Optional["Image"] = None,
         mask: Optional["Image"] = None,
-        language:Optional[str] = None,
+        language: Optional[str] = None,
     ) -> "ImageGenerationResponse":
         """Generates images from text prompt.
@@ -641,19 +641,27 @@ class MultiModalEmbeddingModel(_model_garden_models._ModelGardenModel):
     )
 
     def get_embeddings(
-        self, image: Optional[Image] = None, contextual_text: Optional[str] = None
+        self,
+        image: Optional[Image] = None,
+        contextual_text: Optional[str] = None,
+        dimension: Optional[int] = None,
     ) -> "MultiModalEmbeddingResponse":
         """Gets embedding vectors from the provided image.
 
         Args:
-            image (Image):
-                Optional. The image to generate embeddings for. One of `image` or `contextual_text` is required.
-            contextual_text (str):
-                Optional. Contextual text for your input image. If provided, the model will also
-                generate an embedding vector for the provided contextual text. The returned image
-                and text embedding vectors are in the same semantic space with the same dimensionality,
-                and the vectors can be used interchangeably for use cases like searching image by text
-                or searching text by image. One of `image` or `contextual_text` is required.
+            image (Image): Optional. The image to generate embeddings for. One of
+                `image` or `contextual_text` is required.
+            contextual_text (str): Optional. Contextual text for your input image.
+                If provided, the model will also generate an embedding vector for the
+                provided contextual text. The returned image and text embedding
+                vectors are in the same semantic space with the same dimensionality,
+                and the vectors can be used interchangeably for use cases like
+                searching image by text or searching text by image. One of `image` or
+                `contextual_text` is required.
+            dimension (int): Optional. The number of embedding dimensions. Lower
+                values offer decreased latency when using these embeddings for
+                subsequent tasks, while higher values offer better accuracy. Available
+                values: `128`, `256`, `512`, and `1408` (default).
 
         Returns:
             ImageEmbeddingResponse:
@@ -671,7 +679,14 @@ def get_embeddings(
         if contextual_text:
             instance["text"] = contextual_text
 
-        response = self._endpoint.predict(instances=[instance])
+        parameters = {}
+        if dimension:
+            parameters["dimension"] = dimension
+
+        response = self._endpoint.predict(
+            instances=[instance],
+            parameters=parameters,
+        )
         image_embedding = response.predictions[0].get("imageEmbedding")
         text_embedding = (
             response.predictions[0].get("textEmbedding")
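
For reference, a minimal usage sketch of the new parameter (not part of the commit). It assumes the google-cloud-aiplatform SDK at this revision; the project ID, location, and image path are placeholders.

import vertexai
from vertexai.preview.vision_models import Image, MultiModalEmbeddingModel

# Placeholder project and location values.
vertexai.init(project="my-project", location="us-central1")

model = MultiModalEmbeddingModel.from_pretrained("multimodalembedding@001")
image = Image.load_from_file("image.png")  # placeholder image path

# Request 128-dimensional vectors instead of the 1408-dimensional default.
embedding_response = model.get_embeddings(
    image=image,
    contextual_text="hello world",
    dimension=128,
)
print(len(embedding_response.image_embedding))  # 128
print(len(embedding_response.text_embedding))   # 128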
