In [24]:
import numpy as np
import datasets
from tqdm.auto import tqdm
from datasets import load_dataset

In [3]:
# Read the precomputed embeddings from the parquet file and take its "train" split.
splits = load_dataset("parquet", data_files="pokemon_embeddings.parquet")
dataset = splits["train"]

# The embedding columns must come back as numpy arrays so they can be used in
# `@` products; output_all_columns=True keeps the other columns accessible too.
embedding_columns = ["text_embedding", "image_embedding"]
dataset.set_format(type="numpy", columns=embedding_columns, output_all_columns=True)
dataset

Dataset({
    features: ['id', 'text_embedding', 'image_embedding', '2d_x', '2d_y'],
    num_rows: 1302
})

Because the embeddings are already unit-normalized, the cosine similarity between two of them is simply their dot product, which in Python is easy to compute with NumPy arrays: `A @ B`!


In [7]:
# Pikachu (row 24) and Raichu (row 25): similarity of their text embeddings
pikachu_text = dataset[24]["text_embedding"]
raichu_text = dataset[25]["text_embedding"]
pikachu_text @ raichu_text

0.94904304

In [8]:
# Rows 24 and 25 again, this time compared in image-embedding space
pikachu_image = dataset[24]["image_embedding"]
raichu_image = dataset[25]["image_embedding"]
pikachu_image @ raichu_image

0.9758484

In [15]:
# Pull each column out of the dataset once so later cells can work on whole
# matrices instead of re-reading rows.
text_embeddings, image_embeddings, poke_ids = (
    dataset["text_embedding"],
    dataset["image_embedding"],
    dataset["id"],
)
text_embeddings.shape

(1302, 768)

In [23]:
def cosine_similarities(embed, target):
    """Return the dot product of `embed` with every row of one embedding matrix.

    Args:
        embed: Query embedding; it is multiplied against the transposed matrix,
            so its last dimension must match the matrix's column count.
        target: Either "text" or "image"; selects the module-level
            `text_embeddings` or `image_embeddings` matrix respectively.

    Returns:
        The result of `embed @ matrix.T`. This equals the cosine similarity
        only when both sides are unit-normalized (as the notebook states the
        stored embeddings are).

    Raises:
        AssertionError: If `target` is neither "text" nor "image".
    """
    assert target in ("text", "image")

    if target == "text":
        return embed @ text_embeddings.T
    return embed @ image_embeddings.T


# Sanity check: Bulbasaur is at index 0, so entry 0 is its similarity with itself,
# and its evolutions at indices 1 and 2 should also score high.
bulbasaur_text = dataset[0]["text_embedding"]
test_cossims = cosine_similarities(bulbasaur_text, "text")
test_cossims[:50].round(2)

array([1.  , 0.94, 0.9 , 0.85, 0.86, 0.84, 0.84, 0.85, 0.81, 0.82, 0.79,
       0.87, 0.82, 0.82, 0.84, 0.85, 0.84, 0.83, 0.82, 0.85, 0.83, 0.86,
       0.88, 0.87, 0.87, 0.87, 0.85, 0.87, 0.83, 0.87, 0.85, 0.84, 0.83,
       0.83, 0.86, 0.84, 0.89, 0.85, 0.86, 0.86, 0.73, 0.73, 0.9 , 0.84,
       0.87, 0.88, 0.84, 0.88, 0.86, 0.89], dtype=float32)

In [35]:
# Long-format table of every pairwise similarity: one row per
# (source Pokemon, other Pokemon, embedding type).
cossims = []

for poke in tqdm(dataset):
    source_id = poke["id"]

    # For each source Pokemon, all "text" rows are emitted before the "image" rows.
    for kind in ("text", "image"):
        scores = cosine_similarities(poke[f"{kind}_embedding"], kind)
        # scores[i] is the similarity against the Pokemon whose id is poke_ids[i].
        cossims.extend(
            {
                "id_1": source_id,
                "id_2": other_id,
                "cossim_type": kind,
                "cossim": score,
            }
            for other_id, score in zip(poke_ids, scores)
        )

len(cossims)

100%|██████████| 1302/1302 [00:01<00:00, 679.24it/s]


3390408

In [36]:
# Explicit schema for the similarity table, listed in output column order.
column_dtypes = {
    "id_1": "int32",
    "id_2": "int32",
    "cossim_type": "string",
    "cossim": "float32",
}
features = datasets.Features(
    {column: datasets.Value(dtype=dtype) for column, dtype in column_dtypes.items()}
)

features

{'id_1': Value(dtype='int32', id=None),
 'id_2': Value(dtype='int32', id=None),
 'cossim_type': Value(dtype='string', id=None),
 'cossim': Value(dtype='float32', id=None)}

In [37]:
# Turn the list of row dicts into a Dataset, using the explicit `features` schema
# for the column types.
dataset_cossim = datasets.Dataset.from_list(cossims, features=features)
dataset_cossim

Dataset({
    features: ['id_1', 'id_2', 'cossim_type', 'cossim'],
    num_rows: 3390408
})

In [38]:
# Write the similarity table to a gzip-compressed parquet file.
dataset_cossim.to_parquet("pokemon_cossims.parquet", compression="gzip")

Creating parquet from Arrow format: 100%|██████████| 3391/3391 [00:01<00:00, 2039.18ba/s]


69503364