In [2]:
import torch
from transformers import BertTokenizer, BertModel

import json
import io
from tqdm import tqdm

In [3]:
# Fetch the pretrained uncased BERT tokenizer and encoder; both come from the
# same checkpoint, so the checkpoint id is spelled out once and shared.
PRETRAINED_CHECKPOINT = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(PRETRAINED_CHECKPOINT)
model = BertModel.from_pretrained(PRETRAINED_CHECKPOINT)

Downloading: 100%|███████████████████████████████████████████████████████████| 232k/232k [00:00<00:00, 466kB/s]
Downloading: 100%|██████████████████████████████████████████████████████████| 28.0/28.0 [00:00<00:00, 14.2kB/s]
Downloading: 100%|█████████████████████████████████████████████████████████████| 570/570 [00:00<00:00, 288kB/s]
Downloading: 100%|██████████████████████████████████████████████████████████| 440M/440M [00:58<00:00, 7.58MB/s]
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a Be

In [4]:
# Relative path of the JSON dataset; it is handed to load_json() further down
# to populate all_pokemon_data.
dataset_path = "../data/pokemon_data.json"

In [5]:
def load_json(path):
    """Read the JSON document stored at ``path`` and return the decoded object.

    The file is decoded explicitly as UTF-8 instead of relying on the
    platform's default text encoding, so non-ASCII characters in the data
    (for example typographic apostrophes) load the same way on every OS.

    Args:
        path: Location of the JSON file to read.

    Returns:
        The Python object (dict, list, ...) produced by ``json.load``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    with open(path, 'r', encoding='utf-8') as openfile:
        return json.load(openfile)
    
def save_json(object_, path):
    """Write ``object_`` to ``path`` as JSON indented by four spaces.

    Args:
        object_: Any value that ``json.dumps`` can serialize.
        path: Destination file; it is created or overwritten.
    """
    # Build the full text before touching the file, so a serialization
    # error is raised while the destination is still untouched.
    serialized = json.dumps(object_, indent=4)

    with open(path, "w") as sink:
        sink.write(serialized)

In [6]:
# Parse the dataset into memory. Later cells iterate over its keys (string ids
# such as "1") and read the "flavor_text" field of every entry.
all_pokemon_data = load_json(dataset_path)

In [9]:
# Compare Pokédex flavor texts: for every entry, rank all other entries by the
# cosine similarity of their BERT embeddings (hidden state of the first token).

# Build the similarity module once instead of once per pair.
cosine_similarity = torch.nn.CosineSimilarity(dim=1, eps=1e-6)

# Encode every flavor text exactly once. Running BERT inside the pairwise loop
# costs O(n^2) forward passes; caching the embeddings reduces that to O(n).
# no_grad() skips autograd bookkeeping, so the stored scores do not keep a
# computation graph (and its activations) alive.
flavor_embeddings = {}
with torch.no_grad():
    for pid in tqdm(all_pokemon_data, desc="embedding flavor texts"):
        # Tokenize and encode the string
        input_ids = torch.tensor([tokenizer.encode(all_pokemon_data[pid]["flavor_text"], add_special_tokens=True)])
        output = model(input_ids)[0]
        flavor_embeddings[pid] = output[:, 0, :]  # Use the hidden state of the first token as the embedding

all_similarities = {}
for pid1, string1_embedding in tqdm(flavor_embeddings.items(), desc="ranking similarities"):
    # Each score stays a one-element tensor so downstream cells can keep
    # calling .item() on it.
    similarities = [
        (pid2, cosine_similarity(string1_embedding, string2_embedding))
        for pid2, string2_embedding in flavor_embeddings.items()
        if pid1 != pid2
    ]

    # Most similar entries first.
    similarities.sort(key=lambda pair: pair[1].item(), reverse=True)
    all_similarities[pid1] = similarities

  0%|                                                                                  | 0/899 [01:10<?, ?it/s]


In [None]:
# The similarity scores are torch tensors, which the json module cannot
# serialize (json.dumps raises TypeError). Convert each score to a plain float
# before writing; float() also accepts values that are already floats.
serializable_similarities = {
    pid: [(other_pid, float(score)) for other_pid, score in ranked]
    for pid, ranked in all_similarities.items()
}
save_json(serializable_similarities, "../data/client_data/pokedex_similarities.json")

In [17]:
# Spot check: print one entry's flavor text, then its 20 nearest neighbours
# with their similarity scores.
chosen_pid = "1"
print(all_pokemon_data[chosen_pid]["flavor_text"])
print("========")
nearest_matches = all_similarities[chosen_pid][:20]
for other_pid, score in nearest_matches:
    print(other_pid, score.item())
    # The doubled newline reproduces the blank separator line after each entry.
    print(all_pokemon_data[other_pid]["flavor_text"], end="\n\n")

While it is young, it uses the nutrients that are
stored in the seed on its back in order to grow.
290 0.9082520008087158
It can sometimes live underground for more than
10 years. It absorbs nutrients from the roots
of trees.

775 0.9057706594467163
It remains asleep from birth to death as a result
of the sedative properties of the leaves that
form its diet.

194 0.9049892425537109
When walking on land, it covers its body with a
poisonous film that keeps its skin from dehydrating.

2 0.9044811725616455
Exposure to sunlight adds to its strength.
Sunlight also makes the bud on its back
grow larger.

732 0.9043053984642029
From its mouth, it fires the seeds of berries
it has eaten. The scattered seeds give rise
to new plants.

496 0.8899728655815125
When it gets dirty, its leaves can’t be used in
photosynthesis, so it always keeps itself clean.

865 0.8863638043403625
After deflecting attacks with its hard leaf shield,
it strikes back with its sharp leek stalk. The leek
stalk is both weap