In [1]:
import os

import numpy as np
import requests
import torch
import torch.nn.functional as F
from datasets import load_dataset
from transformers import AutoModel, AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


## Load the Text Model (do not need Vision Model)


In [2]:
gpu = False

tokenizer = AutoTokenizer.from_pretrained(
    "nomic-ai/nomic-embed-text-v1.5", model_max_length=8192
)
text_model = AutoModel.from_pretrained(
    "nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True, rotary_scaling_factor=2
)
if gpu:
    text_model.to("cuda")
_ = text_model.eval()

<All keys matched successfully>


In [6]:
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]
    input_mask_expanded = (
        attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    )
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
        input_mask_expanded.sum(1), min=1e-9
    )


def get_text_embedding(text: str):
    encoded_input = tokenizer(text, padding=True, truncation=True, return_tensors="pt")

    if gpu:
        encoded_input = encoded_input.to("cuda")

    with torch.no_grad():
        model_output = text_model(**encoded_input)

    text_embeddings = mean_pooling(model_output, encoded_input["attention_mask"])
    text_embeddings = F.layer_norm(
        text_embeddings, normalized_shape=(text_embeddings.shape[1],)
    )

    text_embeddings = F.normalize(text_embeddings, p=2, dim=1)
    return text_embeddings[0].cpu().detach().numpy()

In [7]:
sentences = [
    "search_query: What are cute animals to cuddle with?",
]

get_text_embedding(sentences).shape

(768,)

## Load the Previously Generated Image Embeddings


In [8]:
dataset = load_dataset("parquet", data_files="pokemon_embeddings.parquet")["train"]

# embeddings must be numpy arrays
dataset.set_format(
    type="numpy", columns=["text_embedding", "image_embedding"], output_all_columns=True
)
dataset

Dataset({
    features: ['id', 'text_embedding', 'image_embedding', 'umap_2d_x', 'umap_2d_y'],
    num_rows: 1302
})

In [9]:
poke_ids = dataset["id"]
image_embeddings = dataset["image_embedding"]
image_embeddings.shape

(1302, 768)

## Set up Multimodal QA

First, get Pokemon Names/IDs for populating.


In [10]:
graphql_query = """
{
  pokemon_v2_pokemon(where: {id: {_lt: 10000}}, order_by: {id: asc}) {
    id
    name
  }
}
"""

r = requests.post(
    "https://beta.pokeapi.co/graphql/v1beta",
    json={
        "query": graphql_query,
    },
)

pokemon = r.json()["data"]["pokemon_v2_pokemon"]
poke_dict = {x["id"]: x["name"].title() for x in pokemon}

In [15]:
def image_lookup_from_text_query(query, n=10):
    embed = get_text_embedding("search_query: " + query)
    cossims_text_image = embed @ image_embeddings.T
    top_idx = np.argsort(cossims_text_image)[::-1]

    results = []

    count = 0
    for idx in top_idx:
        poke_id = poke_ids[idx]
        if poke_id in poke_dict.keys():
            poke_name = poke_dict[idx]
            prob = cossims_text_image[idx]

            results.append({"id": poke_id, "name": poke_name, "prob": prob})
            count += 1
            if count == n:
                break

    return results

In [18]:
query = "What looks like a ice cream cone?"

results = image_lookup_from_text_query(query)
results

[{'id': 583, 'name': 'Vanillite', 'prob': 0.085194},
 {'id': 869, 'name': 'Milcery', 'prob': 0.076221325},
 {'id': 713, 'name': 'Bergmite', 'prob': 0.06910135},
 {'id': 11, 'name': 'Caterpie', 'prob': 0.06868309},
 {'id': 582, 'name': 'Swanna', 'prob': 0.06752206},
 {'id': 703, 'name': 'Dedenne', 'prob': 0.0665906},
 {'id': 220, 'name': 'Magcargo', 'prob': 0.06607176},
 {'id': 771, 'name': 'Palossand', 'prob': 0.06571421},
 {'id': 712, 'name': 'Gourgeist-Average', 'prob': 0.06521829},
 {'id': 577, 'name': 'Gothitelle', 'prob': 0.06519396}]

In [27]:
import csv

fieldnames = ["id", "name", "prob"]

with open("q1.csv", "w") as f:
    w = csv.DictWriter(f, fieldnames=fieldnames)
    w.writeheader()
    for row in results:
        w.writerow(row)

In [28]:
query = "What monster has flames coming out of it?"

results = image_lookup_from_text_query(query)
results

[{'id': 718, 'name': 'Yveltal', 'prob': 0.095142975},
 {'id': 851, 'name': 'Sizzlipede', 'prob': 0.094870396},
 {'id': 631, 'name': 'Mandibuzz', 'prob': 0.093551904},
 {'id': 979, 'name': 'Tatsugiri', 'prob': 0.09278333},
 {'id': 218, 'name': 'Ursaring', 'prob': 0.09183963},
 {'id': 555, 'name': 'Darumaka', 'prob': 0.09001139},
 {'id': 467, 'name': 'Electivire', 'prob': 0.0898506},
 {'id': 890, 'name': 'Zamazenta', 'prob': 0.0875105},
 {'id': 850, 'name': 'Toxtricity-Amped', 'prob': 0.08745665},
 {'id': 727, 'name': 'Torracat', 'prob': 0.08664321}]

In [29]:
with open("q2.csv", "w") as f:
    w = csv.DictWriter(f, fieldnames=fieldnames)
    w.writeheader()
    for row in results:
        w.writerow(row)