In [1]:
import os
import csv

import numpy as np
import requests
import torch
import torch.nn.functional as F
from datasets import load_dataset
from transformers import AutoModel, AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


## Load the text model (the vision model is not needed: image embeddings were generated previously)


In [2]:
gpu = False  # set to True to move the text model onto CUDA

MODEL_NAME = "nomic-ai/nomic-embed-text-v1.5"

# Cap tokenized inputs at 8192 tokens; get_text_embedding truncates anything longer.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, model_max_length=8192)

# trust_remote_code=True executes modelling code downloaded from the model repo,
# so only use it with repositories you trust. rotary_scaling_factor is passed
# through to that custom code (presumably rotary-position scaling for the long
# context window -- confirm against the model card).
text_model = AutoModel.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    rotary_scaling_factor=2,
)
if gpu:
    text_model.to("cuda")

# Inference only: eval() switches off training-time layers such as dropout.
# The returned module is discarded (assigned to `_` to keep the cell quiet).
_ = text_model.eval()

<All keys matched successfully>


In [6]:
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the non-padding positions.

    Args:
        model_output: Model output whose first element is the per-token
            hidden states, indexed as ``[batch, token, feature]``.
        attention_mask: Mask with the same ``[batch, token]`` layout; nonzero
            entries mark tokens that should be included in the average.

    Returns:
        A ``[batch, feature]`` tensor of masked means. Rows whose mask is all
        zero yield zeros, because the divisor is clamped to a tiny positive value.
    """
    hidden = model_output[0]
    # Broadcast the mask across the feature axis so padding contributes nothing.
    mask = attention_mask.unsqueeze(-1).expand(hidden.size()).float()
    summed = (hidden * mask).sum(dim=1)
    token_counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / token_counts


def get_text_embedding(text: "str | list[str]") -> np.ndarray:
    """Embed text with the nomic text model and return the first embedding.

    The input is tokenized (padded and truncated), run through ``text_model``,
    mean-pooled over real tokens, layer-normalised (no learned affine
    parameters), and finally L2-normalised. As a result, dot products between
    outputs are cosine similarities.

    Args:
        text: A single string, or a list of strings. In this notebook, callers
            prepend the task prefix themselves (e.g. ``"search_query: "``).

    Returns:
        A 1-D numpy array holding the embedding of ``text`` if it is a string,
        or of ``text[0]`` if it is a list. Only the first row of the batch is
        returned, so any further list elements are embedded but discarded.

    Raises:
        ValueError: if ``text`` is an empty list.
    """
    if not isinstance(text, str) and len(text) == 0:
        raise ValueError("get_text_embedding() needs at least one input string")

    encoded_input = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
    if gpu:
        encoded_input = encoded_input.to("cuda")

    # Keep all tensor work under no_grad so nothing builds an autograd graph.
    with torch.no_grad():
        model_output = text_model(**encoded_input)
        embeddings = mean_pooling(model_output, encoded_input["attention_mask"])
        embeddings = F.layer_norm(embeddings, normalized_shape=(embeddings.shape[1],))
        embeddings = F.normalize(embeddings, p=2, dim=1)

    return embeddings[0].cpu().numpy()

In [7]:
# Sanity check: embed one query (same "search_query: " prefix used for
# retrieval below) and look at the resulting vector's shape.
sentences = ["search_query: " + "What are cute animals to cuddle with?"]
get_text_embedding(sentences).shape

(768,)

## Load the Previously Generated Image Embeddings


In [8]:
EMBEDDINGS_PATH = "pokemon_embeddings.parquet"
EMBEDDING_COLUMNS = ["text_embedding", "image_embedding"]

dataset = load_dataset("parquet", data_files=EMBEDDINGS_PATH)["train"]

# The similarity search does matrix maths on the embedding columns, so expose
# those as numpy arrays; output_all_columns=True still returns the remaining
# columns (id, umap_2d_x, umap_2d_y) alongside them.
dataset.set_format(type="numpy", columns=EMBEDDING_COLUMNS, output_all_columns=True)
dataset

Dataset({
    features: ['id', 'text_embedding', 'image_embedding', 'umap_2d_x', 'umap_2d_y'],
    num_rows: 1302
})

In [9]:
# Row-aligned: poke_ids[i] is the id for image_embeddings[i].
poke_ids, image_embeddings = dataset["id"], dataset["image_embedding"]
image_embeddings.shape

(1302, 768)

## Set up Multimodal QA

First, fetch Pokemon IDs and names from the PokéAPI GraphQL endpoint so that search results can be labelled by name.


In [10]:
# Query PokéAPI for (id, name) pairs. The `_lt: 10000` filter drops ids of
# 10000 and above (presumably alternate forms -- verify against PokéAPI).
graphql_query = """
{
  pokemon_v2_pokemon(where: {id: {_lt: 10000}}, order_by: {id: asc}) {
    id
    name
  }
}
"""

POKEAPI_GRAPHQL_URL = "https://beta.pokeapi.co/graphql/v1beta"
REQUEST_TIMEOUT_S = 30  # without a timeout, requests can block forever

r = requests.post(
    POKEAPI_GRAPHQL_URL,
    json={"query": graphql_query},
    timeout=REQUEST_TIMEOUT_S,
)
# Surface HTTP failures directly instead of as a confusing KeyError below.
r.raise_for_status()
payload = r.json()
# GraphQL reports query failures in a top-level "errors" field, possibly
# alongside a successful HTTP status.
if "errors" in payload:
    raise RuntimeError(f"PokeAPI GraphQL error: {payload['errors']}")

pokemon = payload["data"]["pokemon_v2_pokemon"]
poke_dict = {x["id"]: x["name"].title() for x in pokemon}

In [30]:
# Peek at the first ~100 characters of the id -> name mapping.
repr(poke_dict)[:100]

"{1: 'Bulbasaur', 2: 'Ivysaur', 3: 'Venusaur', 4: 'Charmander', 5: 'Charmeleon', 6: 'Charizard', 7: '"

In [87]:
# CSV column order used by similarity_to_csv; must match the keys of the
# result dicts built by image_lookup_from_text_query.
fieldnames = ["id", "name", "prob"]


def image_lookup_from_text_query(query, n=10):
    """Return the ``n`` Pokemon images most similar to a text query.

    The query is embedded with the ``"search_query: "`` prefix and scored
    against every row of ``image_embeddings`` by dot product. The text
    embedding is L2-normalised, so this score is a cosine similarity provided
    the image embeddings are unit-norm too (assumed; not checked here).

    Args:
        query: Natural-language query, without the task prefix.
        n: Maximum number of results; ``n <= 0`` returns an empty list.

    Returns:
        At most ``n`` dicts ``{"id", "name", "prob"}``, best match first.
        ``"prob"`` holds the raw similarity score, not a probability.
        Rows whose id has no entry in ``poke_dict`` are skipped.
    """
    query_embedding = get_text_embedding("search_query: " + query)
    scores = query_embedding @ image_embeddings.T
    ranked = np.argsort(scores)[::-1]  # highest score first

    results = []
    for idx in ranked:
        if len(results) >= n:
            break
        poke_id = poke_ids[idx]
        # Look the name up by this row's own id, not by its position, so labels
        # stay correct whatever order the parquet rows are stored in.
        poke_name = poke_dict.get(poke_id)
        if poke_name is None:
            continue
        results.append({"id": poke_id, "name": poke_name, "prob": scores[idx]})

    return results


def similarity_to_csv(query, csv_index, n=10):
    """Run a text-to-image lookup and save the hits to ``q{csv_index}.csv``.

    Args:
        query: Natural-language query; image_lookup_from_text_query adds the
            ``"search_query: "`` prefix.
        csv_index: Used to build the output filename ``q{csv_index}.csv`` in the
            current working directory. An existing file with that name is
            overwritten.
        n: Number of results to keep, forwarded to image_lookup_from_text_query.

    Returns:
        The list of result dicts that was written, so the notebook displays it.
    """
    results = image_lookup_from_text_query(query, n=n)

    # newline="" lets the csv module control line endings itself (without it,
    # Windows gets blank rows); explicit UTF-8 keeps the file portable.
    with open(f"q{csv_index}.csv", "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(results)

    return results

In [91]:
# Writes q1.csv and displays the top matches.
query = "What looks like an ice cream cone?"
similarity_to_csv(query, csv_index=1)

[{'id': 583, 'name': 'Vanillish', 'prob': 0.08493237},
 {'id': 869, 'name': 'Alcremie', 'prob': 0.075992346},
 {'id': 713, 'name': 'Avalugg', 'prob': 0.06907362},
 {'id': 11, 'name': 'Metapod', 'prob': 0.06854135},
 {'id': 582, 'name': 'Vanillite', 'prob': 0.06713396},
 {'id': 703, 'name': 'Carbink', 'prob': 0.066326044},
 {'id': 220, 'name': 'Swinub', 'prob': 0.06581724},
 {'id': 771, 'name': 'Pyukumuku', 'prob': 0.065558955},
 {'id': 577, 'name': 'Solosis', 'prob': 0.06508828},
 {'id': 712, 'name': 'Bergmite', 'prob': 0.06482702}]

In [92]:
# Writes q2.csv and displays the top matches.
query = "What looks like an orange cat?"
similarity_to_csv(query, csv_index=2)

[{'id': 53, 'name': 'Persian', 'prob': 0.07966652},
 {'id': 726, 'name': 'Torracat', 'prob': 0.075021714},
 {'id': 431, 'name': 'Glameow', 'prob': 0.07453124},
 {'id': 844, 'name': 'Sandaconda', 'prob': 0.07432804},
 {'id': 1014, 'name': 'Okidogi', 'prob': 0.07432228},
 {'id': 196, 'name': 'Espeon', 'prob': 0.07034817},
 {'id': 39, 'name': 'Jigglypuff', 'prob': 0.07002451},
 {'id': 432, 'name': 'Purugly', 'prob': 0.06959056},
 {'id': 509, 'name': 'Purrloin', 'prob': 0.069366224},
 {'id': 725, 'name': 'Litten', 'prob': 0.068936065}]

In [93]:
# Writes q3.csv and displays the top matches.
query = "What has only one eye?"
similarity_to_csv(query, csv_index=3)

[{'id': 201, 'name': 'Unown', 'prob': 0.0817163},
 {'id': 114, 'name': 'Tangela', 'prob': 0.07049845},
 {'id': 44, 'name': 'Gloom', 'prob': 0.07007043},
 {'id': 808, 'name': 'Meltan', 'prob': 0.06915918},
 {'id': 355, 'name': 'Duskull', 'prob': 0.06906678},
 {'id': 101, 'name': 'Electrode', 'prob': 0.068546444},
 {'id': 205, 'name': 'Forretress', 'prob': 0.06788925},
 {'id': 960, 'name': 'Wiglett', 'prob': 0.06769621},
 {'id': 455, 'name': 'Carnivine', 'prob': 0.066377126},
 {'id': 524, 'name': 'Roggenrola', 'prob': 0.06605223}]

In [95]:
# Writes q4.csv and displays the top matches.
query = "What is a cute bug?"
similarity_to_csv(query, csv_index=4)

[{'id': 742, 'name': 'Cutiefly', 'prob': 0.10017799},
 {'id': 165, 'name': 'Ledyba', 'prob': 0.09996718},
 {'id': 840, 'name': 'Applin', 'prob': 0.09831544},
 {'id': 743, 'name': 'Ribombee', 'prob': 0.09511077},
 {'id': 664, 'name': 'Scatterbug', 'prob': 0.0947648},
 {'id': 953, 'name': 'Rellor', 'prob': 0.094578125},
 {'id': 267, 'name': 'Beautifly', 'prob': 0.09351358},
 {'id': 401, 'name': 'Kricketot', 'prob': 0.09324673},
 {'id': 139, 'name': 'Omastar', 'prob': 0.09290971},
 {'id': 48, 'name': 'Venonat', 'prob': 0.091740824}]