In [1]:
import pandas as pd
import numpy as np
import torch
from transformers import BertTokenizer, BertModel

# Read the card table from disk into the notebook's working DataFrame.
cards_csv = 'data/cards_clean.csv'
df = pd.read_csv(cards_csv)

In [2]:
from scripts_and_functions.functions import preprocess_text

# Derive a 'text' column from 'oracle_text' with the project helper, then
# keep only the rows where that derived text is present.
df = (
    df
    .assign(text=preprocess_text(df['oracle_text']))
    .dropna(subset=['text'])
)

In [5]:
# The tokenizer and the encoder are loaded from the same pretrained checkpoint.
checkpoint = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(checkpoint)
model = BertModel.from_pretrained(checkpoint)

# Put the model in evaluation mode; as the cell's last expression this also
# displays the module structure.
model.eval()

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
  

In [7]:
#from scripts_and_functions.functions import bert_embedding

def bert_embedding(text, max_length=512):
    '''
    Return the BERT [CLS] embedding of a string as a 1-D NumPy array.

    Relies on the notebook-level `tokenizer` and `model` objects.

    Parameters
    ----------
    text : str
        The text to embed.
    max_length : int, optional
        Upper bound on the number of token ids (special tokens included) that
        are fed to the model. Longer inputs are truncated instead of failing
        inside the model, whose position-embedding table has 512 entries.

    Returns
    -------
    numpy.ndarray
        The last-layer hidden state of the first token ([CLS]).
    '''
    # add_special_tokens inserts [CLS]/[SEP]; truncation keeps the sequence
    # within the number of positions the model can address.
    input_ids = tokenizer.encode(
        text,
        add_special_tokens=True,
        truncation=True,
        max_length=max_length,
    )
    input_ids = torch.tensor([input_ids])  # batch of size 1

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model(input_ids)

    # outputs[0] is the last layer's hidden states; position 0 of the single
    # batch item is the [CLS] token. Move to CPU before converting to NumPy.
    return outputs[0][0, 0, :].cpu().numpy()

# Embed every row's text with bert_embedding (one call per row) and store
# the returned vector in the 'bert' column.
df['bert'] = df['text'].apply(bert_embedding)

In [8]:
# Write the frame, 'bert' column included, to CSV without the index.
# NOTE(review): the 'bert' cells are NumPy arrays, which a CSV can only hold
# as text — confirm that whatever reads this file parses them back, or store
# the vectors separately in a binary format.
df.to_csv('data/cards_bert.csv', index=False)

In [10]:
from sklearn.neighbors import NearestNeighbors

# Build an exact nearest-neighbour index over the embeddings (default metric).
# algorithm='auto' lets scikit-learn pick the search strategy from the fitted
# data; forcing a KD-tree is a poor fit for vectors with hundreds of
# dimensions, where tree-based search degrades.
nn = NearestNeighbors(n_neighbors=5, algorithm='auto')

# Stack the per-row vectors into one (n_rows, n_features) array; row order is
# kept, so neighbour indices are positions into df.
nn.fit(np.vstack(df['bert'].to_numpy()))

In [11]:
# Take the row at position 10 as the query card and show its name and text.
query_position = 10
test = df.iloc[query_position]
test.loc[['name', 'text']]

name                                         Bronze Horse
text    trample as long as you control another creatur...
Name: 11, dtype: object

In [12]:
# kneighbors takes a 2-D (n_queries, n_features) input, so present the single
# query vector as one row.
query = test['bert'].reshape(1, -1)
distances, indices = nn.kneighbors(query)

In [13]:
# Report the neighbours of the single query, nearest first, with 1-based ranks.
for rank, (index, distance) in enumerate(zip(indices[0], distances[0]), start=1):
    print(f"Rank: {rank}, Index: {index}, Distance: {distance}")

Rank: 1, Index: 10, Distance: 0.0
Rank: 2, Index: 1455, Distance: 5.608117714346787
Rank: 3, Index: 7914, Distance: 5.751406641513333
Rank: 4, Index: 25676, Distance: 5.831063048514878
Rank: 5, Index: 11420, Distance: 5.876155131957004


In [14]:
# Show the query card's complete text (the earlier Series display truncates it).
test['text']

'trample as long as you control another creature, prevent all damage that would be dealt to bronze horse by spells that target it.'

In [17]:
# Text of each neighbour of the query; the neighbour indices are row
# positions, so they are looked up with iloc rather than by label.
df['text'].iloc[indices[0]]

11       trample as long as you control another creatur...
1510     cast this spell only during the declare attack...
8210     if you control a commander, you may cast this ...
26656    choose any number of target creature and/or pl...
11850    prevent all combat damage that would be dealt ...
Name: text, dtype: object
