In [1]:
import torch
from transformers import BertTokenizer, BertModel

import json
import io
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the BERT model and tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
dataset_path = "../data/pokemon_data.json"

In [4]:
def load_json(path):
    # Opening JSON file
    with open(path, 'r') as openfile:
        # Reading from json file
        json_object = json.load(openfile)
        return json_object
    
def save_json(object_, path):
    # Serializing json
    json_object = json.dumps(object_, indent=4)
    
    # Writing to sample.json
    with open(path, "w") as outfile:
        outfile.write(json_object)

In [5]:
all_pokemon_data = load_json(dataset_path)

In [33]:
# generate embeddings for all pokedex entries
# compare pokedex entry embeddings
all_embeddings = {}
for pid in tqdm(all_pokemon_data):
    # Tokenize and encode the strings
    input_ids = torch.tensor([tokenizer.encode(all_pokemon_data[pid]["flavor_text"], add_special_tokens=True)])
    with torch.no_grad():
        output = model(input_ids)[0]
        string_embedding = output[:, 0, :]  # Use the hidden state of the first token as the embedding
    all_embeddings[pid] = string_embedding
    # print(input_ids.shape)

100%|████████████████████████████████████████████████████████████████████████| 899/899 [00:31<00:00, 28.11it/s]


In [43]:
# compare pokedex entry embeddings
all_similarities = {}
for pid1 in tqdm(all_pokemon_data):
    string1_embedding = all_embeddings[pid1]
    
    similarities = []
    for pid2 in all_pokemon_data:
        if pid1 == pid2:
            continue
        string2_embedding = all_embeddings[pid2]
        
        # Calculate the cosine similarity between the embeddings
        cosine_similarity = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
        with torch.no_grad():
            similarity = cosine_similarity(string1_embedding, string2_embedding)
        similarities.append((pid2, similarity.item()))
    
    similarities.sort(key = lambda x: x[1], reverse=True)
    all_similarities[pid1] = similarities

100%|████████████████████████████████████████████████████████████████████████| 899/899 [00:30<00:00, 29.22it/s]


In [45]:
save_json(all_similarities, "../data/client_data/pokedex_similarities.json")

In [44]:
chosen_pid = "700"
print(all_pokemon_data[chosen_pid]["flavor_text"])
print("========")
for pid, similarity in all_similarities[chosen_pid][:20]:
    print(pid, similarity)
    print(all_pokemon_data[pid]["flavor_text"])
    print()

It emits a soothing aura from its ribbon-shaped organs. It wraps
these appendages around quarrelers to instantly restore calm to
the situation.
783 0.9107884764671326
Before attacking its enemies, it clashes its
scales together and roars. Its sharp claws
shred the opposition.

465 0.908868670463562
Draped with long vines, it resembles a shrub in appearance.
It swings bundles of vines as though they were arms, wrapping
them around prey to ensnare them.

636 0.9084455370903015
Larvesta’s body is warm all over. It spouts fire
from the tips of its horns to intimidate predators
and scare prey.

813 0.908038854598999
It has special pads on the backs of its feet, and
one on its nose. Once it’s raring to fight, these
pads radiate tremendous heat.

613 0.9078810214996338
It sniffles before performing a move, using its
frosty snot to provide an icy element to any
move that needs it.

656 0.907250165939331
It protects its skin by covering its body in
delicate bubbles. Beneath its happy-go-lucky a