In [1]:
import os

import numpy as np
import requests
import torch
import torch.nn.functional as F
from datasets import load_dataset
from transformers import AutoModel, AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


## Load the Text Model (do not need Vision Model)


In [2]:
gpu = False

tokenizer = AutoTokenizer.from_pretrained(
    "nomic-ai/nomic-embed-text-v1.5", model_max_length=8192
)
text_model = AutoModel.from_pretrained(
    "nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True, rotary_scaling_factor=2
)
if gpu:
    text_model.to("cuda")
_ = text_model.eval()

<All keys matched successfully>


In [3]:
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]
    input_mask_expanded = (
        attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    )
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
        input_mask_expanded.sum(1), min=1e-9
    )


def get_text_embedding(text: str):
    encoded_input = tokenizer(text, padding=True, truncation=True, return_tensors="pt")

    if gpu:
        encoded_input = encoded_input.to("cuda")

    with torch.no_grad():
        model_output = text_model(**encoded_input)

    text_embeddings = mean_pooling(model_output, encoded_input["attention_mask"])
    text_embeddings = F.normalize(text_embeddings, p=2, dim=1)
    return text_embeddings[0].cpu().detach().numpy()

In [4]:
sentences = [
    "search_query: What are cute animals to cuddle with?",
]

get_text_embedding(sentences).shape

(768,)

## Load the Previously Generated Image Embeddings


In [5]:
dataset = load_dataset("parquet", data_files="pokemon_embeddings.parquet")["train"]

# embeddings must be numpy arrays
dataset.set_format(
    type="numpy", columns=["text_embedding", "image_embedding"], output_all_columns=True
)
dataset

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Dataset({
    features: ['id', 'text_embedding', 'image_embedding', 'umap_2d_x', 'umap_2d_y'],
    num_rows: 1302
})

In [6]:
poke_ids = dataset["id"]
image_embeddings = dataset["image_embedding"]
image_embeddings.shape

(1302, 768)

## Set up Multimodal QA

First, get Pokemon Names/IDs for populating.


In [7]:
graphql_query = """
{
  pokemon_v2_pokemon(where: {id: {_lt: 10000}}, order_by: {id: asc}) {
    id
    name
  }
}
"""

r = requests.post(
    "https://beta.pokeapi.co/graphql/v1beta",
    json={
        "query": graphql_query,
    },
)

pokemon = r.json()["data"]["pokemon_v2_pokemon"]
poke_dict = {x["id"]: x["name"].title() for x in pokemon}

In [8]:
# https://stackoverflow.com/a/38250088
def softmax(x):
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum(axis=0)


def image_lookup_from_text_query(query, n=5):
    embed = get_text_embedding("search_query: " + query)
    cossims_text_image = embed @ image_embeddings.T
    # softmaxed_cossims = softmax(cossims_text_image)
    top_idx = np.argsort(cossims_text_image)[::-1]

    count = 0
    for idx in top_idx:
        poke_id = poke_ids[idx]
        if poke_id in poke_dict.keys():
            poke_name = poke_dict[idx]
            prob = cossims_text_image[idx]

            print(poke_name + ": " + str(prob))
            count += 1
            if count == n:
                break

In [10]:
query = "ice cream cone"

image_lookup_from_text_query(query)

Vanillite: 0.086028405
Milcery: 0.07552162
Swanna: 0.06919962
Bergmite: 0.066432506
Magcargo: 0.062358268
