# Get Data

In [17]:
from tqdm import tqdm

# PyTorch
import torch
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(DEVICE)

cpu


In [2]:
from datasets import load_dataset, logging
logging.set_verbosity_error()

ds = load_dataset('bookcorpus', split='train[:5000]')
ds

Dataset({
    features: ['text'],
    num_rows: 5000
})

## Indexing the Model State Array with PyTorch

hidden_states is a tuple of tensors/arrays. 

There is one tensor/array for each embedding in the network:
hidden_states[i] == hidden states at i<sup>th</sup> layer of network.

Each of these tensors/arrays has the following shape:
hidden_states[i].shape == [num_examples, sequence_length, embedding_size]  

To get the 13 different embeddings for a single token, we loop over the layers of hidden_states. We collect an example sentence isent, a token within that sentence itok, and all 768 scalar values in the embedding matrix using the colon indexer.

For this example, we have:
 - 1 sentence
 - 7 is the sequence length
 - 13 layers, which is one input embedding *x*_i_ + 12 encoder blocks
 - 768 is the embedding size for BERT embeddings
 
The result is a matrix of shpae [1 x 7 x 13 x 768]

In [16]:
from transformers import logging, AutoModel, AutoTokenizer
logging.set_verbosity_error()
MODEL_NAME = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
model = AutoModel.from_pretrained(MODEL_NAME)

In [4]:
# What is the max sequence length?
max([len(tokenizer(s)['input_ids']) for s in ds['text']])

66

## Process data

In [5]:
embed_array = torch.zeros((5000, 66, 13, 768))
embed_array.shape

torch.Size([5000, 66, 13, 768])

In [160]:
embed_list = []
tok_ids = []

for i, sample in enumerate(tqdm(ds['text'])):
    batch_idx = 0 # one sample at a time, no batching.
    
    inputs = tokenizer(sample, return_tensors='pt')
    
    with torch.no_grad():
        hidden_states = model(**inputs,
                              output_hidden_states=True)['hidden_states']
        
        embed_list.append(torch.stack(hidden_states, dim=2))
        tok_ids.append(inputs.input_ids[batch_idx].tolist())
        # seq_length = hidden_states[batch_idx].size(1)
        # embed_array[i:i+1, :seq_length, : , :] = torch.stack(hidden_states, dim=2)


100%|██████████| 5000/5000 [01:40<00:00, 49.99it/s]


In [None]:
torch.save(embed_list, '../data/bookcorpus_embeddings_0_5000.pt')

In [57]:
embed_array = torch.load('../data/bookcorpus_embeddings_0_5000.pt')

In [143]:
embed_list = torch.load('test.pt')

In [159]:
embed_array.[5099]

IndexError: index 5099 is out of bounds for dimension 0 with size 5000

## BERT lookup embeddings

In [43]:
bert_embeds = torch.Tensor(model.embeddings.word_embeddings.weight)
torch.save(bert_embeds, '../data/bert_lookup_embeddings.pt')

In [48]:
bert_embeds

tensor([[-0.0102, -0.0615, -0.0265,  ..., -0.0199, -0.0372, -0.0098],
        [-0.0117, -0.0600, -0.0323,  ..., -0.0168, -0.0401, -0.0107],
        [-0.0198, -0.0627, -0.0326,  ..., -0.0165, -0.0420, -0.0032],
        ...,
        [-0.0218, -0.0556, -0.0135,  ..., -0.0043, -0.0151, -0.0249],
        [-0.0462, -0.0565, -0.0019,  ...,  0.0157, -0.0139, -0.0095],
        [ 0.0015, -0.0821, -0.0160,  ..., -0.0081, -0.0475,  0.0753]],
       grad_fn=<AliasBackward0>)

In [98]:
def nearest_neighbor_lookup(lookup_embeds, test_vector, topk=3):
    dist = torch.norm(lookup_embeds - test_vector, dim=1, p=None)
    knn = dist.topk(topk, largest=False)
    dist_ind_pairs = list(zip(knn.values.round(decimals=2).tolist(),
                              [tokenizer.decode([x]) for x in knn.indices.tolist()]
                             ))
    return dist_ind_pairs

In [132]:
nearest_neighbor_lookup(bert_embeds, embed_array[0][5][12])

[(14.100000381469727, '[CLS]'),
 (14.119999885559082, 'be'),
 (14.149999618530273, 'is')]

In [133]:
ds['text'][1]

'but just one look at a minion sent him practically catatonic .'

In [72]:
tokenizer.decode([2022, 2042, 2108])

'be been being'

In [100]:
inputs = tokenizer('what to do', return_tensors='pt')

In [109]:
inputs.input_ids[0].tolist()

[101, 2054, 2000, 2079, 102]