TODO:  

0. Prepare the dataset (text, image and metadata) to extract embeddings
1. Extract embeddings with Gecko Multimodal using text and/or image
2. Create search datastore with grounding (citation, summary, snippets)
3. Create an API that automatically extracts the embeddings and uploads them to the Vertex AI Search datastore
4. Create API to search using multimodal (image + text)
5. Verify the impact of this new API on the Front End

# 0. Prepare dataset

In [1]:
import base64
from google.cloud import aiplatform
from google.protobuf import struct_pb2
import typing
import numpy as np

In [2]:
class EmbeddingResponse(typing.NamedTuple):
    """Embeddings returned by ``EmbeddingPredictionClient.get_embedding``.

    A field is ``None`` when the corresponding input (text or image) was
    not supplied in the request.
    """

    # Embedding of the request text, or None if no text was sent.
    text_embedding: typing.Optional[typing.Sequence[float]]
    # Embedding of the request image, or None if no image was sent.
    image_embedding: typing.Optional[typing.Sequence[float]]

class EmbeddingPredictionClient:
    """Wrapper around the aiplatform Prediction Service client.

    Sends text and/or image content to the ``multimodalembedding@001``
    publisher model and returns the embeddings of the first prediction.
    """

    def __init__(
        self,
        project: str,
        location: str = "us-central1",
        api_regional_endpoint: str = "us-central1-aiplatform.googleapis.com",
    ):
        """Create the underlying prediction client.

        Args:
            project: Project used to build the model resource path.
            location: Location used to build the model resource path.
            api_regional_endpoint: Host passed to the client as
                ``api_endpoint``. NOTE(review): it is not derived from
                ``location``; presumably both must name the same region --
                confirm before overriding only one of them.
        """
        client_options = {"api_endpoint": api_regional_endpoint}
        # The client only needs to be created once and can be reused for
        # multiple requests.
        self.client = aiplatform.gapic.PredictionServiceClient(client_options=client_options)
        self.location = location
        self.project = project

    def get_embedding(
        self,
        text: typing.Optional[str] = None,
        image_bytes: typing.Optional[bytes] = None,
    ) -> "EmbeddingResponse":
        """Request embeddings for a text, an image, or both.

        Args:
            text: Text to embed. Empty or None means "no text".
            image_bytes: Raw image bytes to embed. Empty or None means
                "no image".

        Returns:
            An ``EmbeddingResponse`` whose ``text_embedding`` /
            ``image_embedding`` is a list of values when the matching input
            was given and ``None`` otherwise.

        Raises:
            ValueError: If neither ``text`` nor ``image_bytes`` is given.
        """
        if not text and not image_bytes:
            raise ValueError('At least one of text or image_bytes must be specified.')

        instance = struct_pb2.Struct()
        if text:
            instance.fields['text'].string_value = text

        if image_bytes:
            # The image is sent as base64 text inside the request struct.
            encoded_content = base64.b64encode(image_bytes).decode("utf-8")
            image_struct = instance.fields['image'].struct_value
            image_struct.fields['bytesBase64Encoded'].string_value = encoded_content

        endpoint = (f"projects/{self.project}/locations/{self.location}"
            "/publishers/google/models/multimodalembedding@001")
        response = self.client.predict(endpoint=endpoint, instances=[instance])

        # A single instance is sent, so only the first prediction is read.
        prediction = response.predictions[0]
        text_embedding = list(prediction['textEmbedding']) if text else None
        image_embedding = list(prediction['imageEmbedding']) if image_bytes else None

        return EmbeddingResponse(
            text_embedding=text_embedding,
            image_embedding=image_embedding)

In [3]:
def reduce_embedding_dimesion(
        vector_text: typing.Optional[typing.Sequence[float]] = None,
        vector_image: typing.Optional[typing.Sequence[float]] = None,
):
    """Halve an embedding's dimension with max pooling.

    When both vectors are given they are first merged with an element-wise
    maximum. Adjacent pairs ``(v[0], v[1]), (v[2], v[3]), ...`` of the
    resulting vector are then each replaced by their maximum, so a
    1408-value embedding becomes 704 values. The length is taken from the
    input instead of being fixed at 1408.

    Args:
        vector_text: Text embedding, or None/empty when absent.
        vector_image: Image embedding, or None/empty when absent.

    Returns:
        A list of Python floats, half as long as the input vector(s).

    Raises:
        ValueError: If neither vector is given, if the two vectors differ
            in length, or if the vector is not flat with an even length.
    """
    # len() rather than truthiness so NumPy arrays are accepted as well.
    has_text = vector_text is not None and len(vector_text) > 0
    has_image = vector_image is not None and len(vector_image) > 0
    if not has_text and not has_image:
        raise ValueError(
            'At least one of vector_text or vector_image must be specified.')

    if has_text and has_image:
        text_values = np.asarray(vector_text, dtype=float)
        image_values = np.asarray(vector_image, dtype=float)
        # Checked explicitly: broadcasting would hide a length mismatch.
        if text_values.shape != image_values.shape:
            raise ValueError(
                'vector_text and vector_image must have the same length.')
        pooled = np.maximum(text_values, image_values)
    else:
        pooled = np.asarray(
            vector_text if has_text else vector_image, dtype=float)

    if pooled.ndim != 1 or pooled.size % 2 != 0:
        raise ValueError(
            'Embedding vectors must be one-dimensional with an even length.')

    # One row per adjacent pair; the row-wise maximum halves the dimension.
    return pooled.reshape(-1, 2).max(axis=1).tolist()

In [4]:
# Embedding client reused by the cells below; location and endpoint keep
# the class defaults.
embeddings_client = EmbeddingPredictionClient(project="rl-llm-dev")

In [None]:
# Read the sample image, then embed it together with a short text.
with open("image_00.jpg", "rb") as f:
    image_contents = f.read()

response = embeddings_client.get_embedding(
    image_bytes=image_contents,
    text="Celular Google Pixel 7 with good camera",
)

In [None]:
# Max-pool the text and image embeddings into a single reduced vector.
reduced_vector = reduce_embedding_dimesion(
    vector_image=response.image_embedding,
    vector_text=response.text_embedding
)

In [None]:
import pandas as pd
# Spreadsheet whose `title` and `description` columns are read further down.
input_file = pd.read_excel("csm-dataset.xlsx")

In [None]:
# Build one datastore document per product: its id, title/description,
# reduced multimodal embedding and a link to its HTML page.
metadata = []

for i in range(10):
    product_id = str(i)
    html_uri = f"gs://csm-dataset/website-search/{i}.html"
    title = input_file.title[i]
    description = input_file.description[i]

    # Embed the product title together with its image.
    with open(f"image_0{i}.jpg", "rb") as f:
        image_contents = f.read()
    response = embeddings_client.get_embedding(
        image_bytes=image_contents,
        text=title,
    )

    reduced_vector = reduce_embedding_dimesion(
        vector_text=response.text_embedding,
        vector_image=response.image_embedding,
    )

    struct_data = {
        "title": title,
        "description": description,
        "embedding_vector": list(reduced_vector),
    }
    document = {
        "id": product_id,
        "structData": struct_data,
        "content": {"mimeType": "text/html", "uri": html_uri},
    }
    metadata.append(document)

In [None]:
import json

# Serialize the documents as JSON Lines: one JSON object per line.
with open("metadata.jsonl", "w") as f:
    for m in metadata:
        f.write(f"{json.dumps(m)}\n")

In [None]:
# Copy the generated JSONL file to the Cloud Storage bucket.
! gsutil cp metadata.jsonl gs://csm-dataset/embeddings-search/metadata.jsonl

# Update Schema

In [None]:
from google.cloud import discoveryengine_v1beta as discoveryengine
import json

In [None]:
# Datastore schema: title and description strings plus the embedding field.
# "dimension" is 704 because reduce_embedding_dimesion emits one value per
# index pair over range(0, 1408, 2).
new_schema = {
  "$schema": "https://json-schema.org/draft/2020-12/schema",
  "type": "object",
  "properties": {
    "title": {
      "type": "string",
      "keyPropertyMapping": "title"
    },
    "description": {
      "type": "string",
      "keyPropertyMapping": "description"
    },
    "embedding_vector": {
      "type": "array",
      "keyPropertyMapping": "embedding_vector",
      "dimension": 704,
      "items": {
        "type": "number"
      }
    }
  }
}

In [None]:
# Schema service client, used below to list and update the datastore schema.
client = discoveryengine.SchemaServiceClient()

In [None]:
# List the schemas under the embeddings datastore.
client.list_schemas(
    discoveryengine.ListSchemasRequest(
        parent="projects/rl-llm-dev/locations/global/collections/default_collection/dataStores/embeddings_1697652904023"
    )
)

In [None]:
# Describe the datastore's default schema using the JSON schema dict.
schema = discoveryengine.Schema(
    name="projects/244831775715/locations/global/collections/default_collection/dataStores/embeddings_1697652904023/schemas/default_schema",
    json_schema=json.dumps(new_schema),
)
request = discoveryengine.UpdateSchemaRequest(schema=schema)

# Send the update and wait for the returned operation's result.
operation = client.update_schema(request=request)
print("Waiting for operation to complete...")
response = operation.result()

# Show the outcome.
print(response)

# Query datastore

In [5]:
from google.cloud import discoveryengine_v1beta as discoveryengine

In [6]:
# Search service client used to query the datastore.
search_client = discoveryengine.SearchServiceClient()

In [12]:
# Image-only query: embed the picture, then reduce the embedding with the
# same pooling function used for the indexed documents.
with open("image_00.jpg", "rb") as f:
    query_image = f.read()
response = embeddings_client.get_embedding(image_bytes=query_image)

reduced_vector = reduce_embedding_dimesion(vector_image=response.image_embedding)

In [13]:
# Wrap the reduced query vector, targeting the "embedding_vector" field.
query_vector = discoveryengine.SearchRequest.EmbeddingSpec.EmbeddingVector(
    vector=reduced_vector,
    field_path="embedding_vector",
)
embedding_spec = discoveryengine.SearchRequest.EmbeddingSpec(
    embedding_vectors=[query_vector],
)

In [14]:
# Identify the datastore to query.
project_id = "rl-llm-dev"
datastore_location = "global"
datastore_id = "embeddings_1697652904023"

# Build the resource path of the datastore's "default_config" serving config.
serving_config = search_client.serving_config_path(
    project=project_id,
    location=datastore_location,
    data_store=datastore_id,
    serving_config="default_config")

In [15]:
# Search request that carries only the query embedding (no text query is set).
# The ranking expression is a weighted sum: 0.5 * relevance_score plus
# 0.3 * dotProduct over the embedding_vector field.
request = discoveryengine.SearchRequest(
    serving_config=serving_config,
    embedding_spec = embedding_spec,
    ranking_expression="0.5 * relevance_score + 0.3 * dotProduct(embedding_vector)"
)

In [16]:
# Run the search; the notebook displays the returned pager of results.
search_client.search(request)

SearchPager<results {
  id: "4"
  document {
    name: "projects/244831775715/locations/global/collections/default_collection/dataStores/embeddings_1697652904023/branches/0/documents/4"
    id: "4"
    derived_struct_data {
      fields {
        key: "link"
        value {
          string_value: "gs://csm-dataset/website-search/4.html"
        }
      }
    }
  }
}
results {
  id: "8"
  document {
    name: "projects/244831775715/locations/global/collections/default_collection/dataStores/embeddings_1697652904023/branches/0/documents/8"
    id: "8"
    derived_struct_data {
      fields {
        key: "link"
        value {
          string_value: "gs://csm-dataset/website-search/8.html"
        }
      }
    }
  }
}
results {
  id: "6"
  document {
    name: "projects/244831775715/locations/global/collections/default_collection/dataStores/embeddings_1697652904023/branches/0/documents/6"
    id: "6"
    derived_struct_data {
      fields {
        key: "link"
        value {
         