## Token Embedding Analysis

One of the key innovations of the Transformer architecture is that token embeddings are _affected_ by nearby tokens (e.g., other words in the sentence). This notebook demonstrates this effect: we pick a single keyword, run several sentences that contain it through the model, and compare the keyword's output embedding vector across those sentences.
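To make the idea concrete, here is a toy NumPy sketch of single-head scaled dot-product self-attention. It assumes identity query/key/value projections and uses random vectors as stand-ins for learned embeddings; DistilBERT itself stacks several layers of multi-head attention with learned projections. Even in this stripped-down form, the same input vector for the keyword produces a different output vector once its neighbors change.

```python
import numpy as np

def toy_self_attention(x: np.ndarray) -> np.ndarray:
    """Single-head scaled dot-product self-attention with identity
    query/key/value projections (a simplification: real models learn these)."""
    d = x.shape[-1]
    scores = x @ x.T / np.sqrt(d)                   # (seq_len, seq_len) similarities
    scores -= scores.max(axis=1, keepdims=True)     # numerical stability for the softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)   # softmax over each row
    return weights @ x                              # each output row mixes all input rows

rng = np.random.default_rng(0)
keyword_vec = rng.normal(size=4)                    # the same static vector for the keyword
context_a = rng.normal(size=(2, 4))                 # two arbitrary neighbor vectors
context_b = rng.normal(size=(2, 4))                 # two different neighbor vectors

# the keyword is the last row in both toy "sentences"
out_a = toy_self_attention(np.vstack([context_a, keyword_vec]))[-1]
out_b = toy_self_attention(np.vstack([context_b, keyword_vec]))[-1]
print(np.allclose(out_a, out_b))  # the keyword's output depends on its neighbors
```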

In [98]:
import torch 
import numpy as np 
import torch.nn as F  # note: this aliases torch.nn (not torch.nn.functional); F.CosineSimilarity is used below
from typing import List, Tuple
from transformers import DistilBertTokenizer, DistilBertModel

### Utility Functions

In [99]:
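def show_decoding(tokenizer, encoding: torch.Tensor) -> List[Tuple[int, str]]:
    """Show (token id, decoded token) pairs for one encoded example sentence."""
    # wrap each id in a list: decoding a bare scalar id can come back as
    # space-separated characters (e.g. 'p i l o t' in the recorded output below)
    return [(_enc.item(), tokenizer.decode([_enc.item()])) for _enc in encoding]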

### Load Tokenizer and Model

In [100]:
model_checkpoint = 'distilbert-base-uncased'

tokenizer = DistilBertTokenizer.from_pretrained(model_checkpoint)
model = DistilBertModel.from_pretrained(model_checkpoint)

### Keywords and Example Text

We set the keyword and create example sentences that contain it. Notice the different _ways_ in which the `keyword` is used in the example sentences below. At the end of this notebook, we compare the embeddings of the `keyword` from each example sentence. We expect two embeddings to be more _similar_ when the `keyword` is used in a similar way in both sentences. For instance, keep an eye on `examples[0]` and `examples[3]`, which both use "pilot" in the aviation sense.

In [101]:
# set keyword
keyword = "pilot"
# index 1 skips the leading [CLS] token; this assumes the keyword is a single wordpiece
keyword_id = tokenizer.encode(keyword)[1]
print('keyword id:', keyword_id)

keyword id: 4405
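`tokenizer.encode(keyword)[1]` quietly assumes that the keyword is a single wordpiece sitting between `[CLS]` and `[SEP]`. If it were split (as "bumpy" is split into `bump` + `##y` in example 3), index 1 would pick only the first piece. A small guard, written as a hypothetical helper over a plain list of ids, could look like this; the defaults 101 and 102 are the `[CLS]`/`[SEP]` ids visible in the encoded outputs of this notebook.

```python
from typing import List

def single_token_id(ids: List[int], cls_id: int = 101, sep_id: int = 102) -> int:
    """Hypothetical helper: return the only wordpiece id between [CLS] and [SEP],
    or raise if the keyword was split into several wordpieces."""
    inner = [i for i in ids if i not in (cls_id, sep_id)]
    if len(inner) != 1:
        raise ValueError(f"keyword maps to {len(inner)} wordpieces: {inner}")
    return inner[0]

print(single_token_id([101, 4405, 102]))  # 4405, the id printed above for "pilot"
```

In the notebook it could be called as `single_token_id(tokenizer.encode(keyword), tokenizer.cls_token_id, tokenizer.sep_token_id)`; with ids like `[101, 16906, 2100, 102]` ("bumpy") it raises instead of silently returning `16906`.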


In [102]:
# create examples where the above keyword is used in different senses
examples = [
    "Attention passengers, this is the pilot speaking. Please prepare for landing.",             # - *flight-related
    "The tv show was funny but it didn't get approved after the pilot.",                         # - tv show/pilot
    "So, are you happy with your honda pilot? How does it handle the rough roads around here?",  # - honda pilot
    "Even though the flight was bumpy, I trused the pilot had everything under control.",        # - *flight-related
    "She was the best pilot the commander had ever seen."                                        # - *flight-related (but less so)
]

# note: after running through this notebook once, come back and change `trused -> trusted` and `bumpy -> turbulent`
# (one at a time) and see how the keyword cosine similarity scores change. :O

### Text Encoding

This is the preprocessing step before the text is passed to the model: each token is mapped to its integer id in the `tokenizer` vocabulary.
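As a simplified illustration of this step, the sketch below maps already-split tokens to ids with a tiny hand-written vocabulary, wraps them in `[CLS]`/`[SEP]`, and right-pads the batch with id 0, as `padding=True` does here. The vocabulary is an assumption: it holds only a few ids copied from the encoded output of `examples[0]`, whereas the real tokenizer has a full WordPiece vocabulary and also splits words it does not know.

```python
from typing import Dict, List

# Toy vocabulary: a handful of ids copied from the encoded output of examples[0].
TOY_VOCAB: Dict[str, int] = {
    "[PAD]": 0, "[CLS]": 101, "[SEP]": 102,
    "attention": 3086, "passengers": 5467, ",": 1010, "this": 2023,
    "is": 2003, "the": 1996, "pilot": 4405, "speaking": 4092, ".": 1012,
}

def encode(tokens: List[str]) -> List[int]:
    """Map pre-split tokens to ids and wrap them in [CLS] ... [SEP]."""
    return [TOY_VOCAB["[CLS]"]] + [TOY_VOCAB[t] for t in tokens] + [TOY_VOCAB["[SEP]"]]

def pad_batch(batch: List[List[int]], pad_id: int = 0) -> List[List[int]]:
    """Right-pad every sequence to the length of the longest one."""
    max_len = max(len(ids) for ids in batch)
    return [ids + [pad_id] * (max_len - len(ids)) for ids in batch]

batch = pad_batch([
    encode(["attention", "passengers", ",", "this", "is", "the", "pilot", "speaking", "."]),
    encode(["the", "pilot", "."]),
])
print(batch[1])  # [101, 1996, 4405, 1012, 102, 0, 0, 0, 0, 0, 0]
```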

In [103]:
encodings = tokenizer(examples, padding=True, return_tensors='pt')
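# keep only the token ids; note that without the `attention_mask`, the model
# also attends to the [PAD] tokens of shorter examples, which can shift their embeddings
encodings = encodings['input_ids']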
print(encodings.size())

torch.Size([5, 22])
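Because only `input_ids` are passed to the model, DistilBERT falls back to an all-ones attention mask, so the `[PAD]` positions of shorter examples are attended to like real tokens. The NumPy sketch below is a simplified, single-row view of what a mask does inside attention (it is not DistilBERT's actual code): masked positions receive zero softmax weight.

```python
import numpy as np

def attention_weights(scores: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Softmax over the last axis that gives zero weight where mask == 0
    (a simplified sketch of how an attention mask is typically applied)."""
    masked = np.where(mask.astype(bool), scores, -np.inf)   # hide padded key positions
    masked = masked - masked.max(axis=-1, keepdims=True)    # numerical stability
    weights = np.exp(masked)                                # exp(-inf) == 0
    return weights / weights.sum(axis=-1, keepdims=True)

scores = np.zeros((1, 5))                  # uniform raw scores over 5 key positions
mask = np.array([1, 1, 1, 0, 0])           # the last two positions are [PAD]
print(attention_weights(scores, np.ones(5)))  # weight 0.2 on every position, padding included
print(attention_weights(scores, mask))        # weight 1/3 on real tokens, 0 on padding
```

To pass the mask in this notebook, one option is `model(**tokenizer(examples, padding=True, return_tensors='pt'))`, which supplies both `input_ids` and `attention_mask`; the recorded scores further down would likely shift somewhat for the padded examples.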


In [104]:
# view encoded tensors from example text (padding makes all examples of equal length)
encodings[0]

tensor([ 101, 3086, 5467, 1010, 2023, 2003, 1996, 4405, 4092, 1012, 3531, 7374,
        2005, 4899, 1012,  102,    0,    0,    0,    0,    0,    0])

In [105]:
# notice how the tokenizer handles example `3`,
# particularly the misspelling "trused" (for "trusted") and the word "bumpy"
show_decoding(tokenizer, encodings[3])

[(101, '[ C L S ]'),
 (2130, 'e v e n'),
 (2295, 't h o u g h'),
 (1996, 't h e'),
 (3462, 'f l i g h t'),
 (2001, 'w a s'),
 (16906, 'b u m p'),
 (2100, '# # y'),
 (1010, ','),
 (1045, 'i'),
 (19817, 't r'),
 (13901, '# # u s e d'),
 (1996, 't h e'),
 (4405, 'p i l o t'),
 (2018, 'h a d'),
 (2673, 'e v e r y t h i n g'),
 (2104, 'u n d e r'),
 (2491, 'c o n t r o l'),
 (1012, '.'),
 (102, '[ S E P ]'),
 (0, '[ P A D ]'),
 (0, '[ P A D ]')]

In [106]:
# get the position of `keyword` in each encoded example
# (np.where(...)[1] keeps only column indices, so this assumes exactly one occurrence per example)
keyword_enc_idx = np.where(encodings.numpy() == keyword_id)[1]
keyword_enc_idx

array([ 7, 15,  9, 13,  5])
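`np.where(...)[1]` keeps only the column indices, so the result lines up with the examples only if every example contains the keyword exactly once. A slightly more defensive sketch (the helper name `keyword_positions` is hypothetical, not part of the notebook) checks that assumption explicitly:

```python
import numpy as np

def keyword_positions(input_ids: np.ndarray, keyword_id: int) -> np.ndarray:
    """Hypothetical helper: return one keyword position per row, failing loudly
    if any row contains the keyword zero times or more than once."""
    rows, cols = np.nonzero(input_ids == keyword_id)
    counts = np.bincount(rows, minlength=input_ids.shape[0])
    if not np.all(counts == 1):
        raise ValueError(f"expected exactly one keyword per row, got counts {counts.tolist()}")
    return cols  # np.nonzero yields rows in ascending order, so cols[i] belongs to row i

ids = np.array([
    [101, 1996, 4405, 1012, 102],
    [101, 4405, 1012, 102, 0],
])
print(keyword_positions(ids, 4405))  # [2 1]
```

In the notebook, `keyword_positions(encodings.numpy(), keyword_id)` would replace the `np.where` call.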

### Model Outputs

Run a forward pass of the model on the encoded examples. The model outputs an embedding for every token in every example. We then extract the `keyword` embedding from each sentence's output to compare the `keyword` representations and see how they are affected by their _context_ (i.e., neighboring words).

In [107]:
# first dim size should equal len(examples)
outputs = model(encodings)
print('output shape:', outputs[0].size())

output shape: torch.Size([5, 22, 768])


In [109]:
# first 10 token embeddings of examples[0]; row 7 holds the `keyword` token
outputs[0][0][:10]

tensor([[-0.2736, -0.1543,  0.0635,  ...,  0.0185,  0.4955,  0.3776],
        [ 0.3556,  0.3899,  0.3734,  ..., -0.0779,  0.2457, -0.2127],
        [-0.0223,  0.0153,  0.3945,  ...,  0.0511,  0.2932,  0.0463],
        ...,
        [ 0.0386, -0.3837,  0.0909,  ..., -0.3653,  0.6322, -0.0575],
        [ 0.0066,  0.2584,  0.1227,  ...,  0.2462,  0.1527,  0.0671],
        [ 0.6849,  0.1757, -0.3934,  ...,  0.1302, -0.3551, -0.3675]],
       grad_fn=<SliceBackward0>)

In [110]:
# extract embeddings for `keyword` in each of the sentence outputs
keyword_embeddings = torch.stack([outputs[0][i][keyword_enc_idx[i]] for i in range(len(examples))])
keyword_embeddings.size()

torch.Size([5, 768])
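The list comprehension above picks row `keyword_enc_idx[i]` out of example `i`. The same selection can be written as a single advanced-indexing step. The sketch below uses NumPy and a random array as a stand-in for `outputs[0]` (an assumption so that it runs without the model); PyTorch supports the same integer-array indexing, e.g. `outputs[0][torch.arange(len(examples)), torch.as_tensor(keyword_enc_idx)]`.

```python
import numpy as np

# Stand-in for outputs[0]: (batch, seq_len, hidden) filled with random numbers.
rng = np.random.default_rng(0)
hidden = rng.normal(size=(5, 22, 768))
keyword_idx = np.array([7, 15, 9, 13, 5])  # the keyword positions found above

# Loop version, mirroring the list comprehension + torch.stack in the cell above.
looped = np.stack([hidden[i][keyword_idx[i]] for i in range(hidden.shape[0])])

# Vectorized version: pair row i with column keyword_idx[i] via advanced indexing.
gathered = hidden[np.arange(hidden.shape[0]), keyword_idx]

print(gathered.shape)                    # (5, 768)
print(np.array_equal(looped, gathered))  # True
```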

In [111]:
# `keyword` embedding from examples[0] (first 10 of 768 values, shortened for print-out)
keyword_embeddings[0][:10]

tensor([ 0.0386, -0.3837,  0.0909,  0.0236,  0.1591,  0.0252, -0.3695,  0.3892,
         0.3110, -0.6666], grad_fn=<SliceBackward0>)

In [112]:
# `keyword` embedding from examples[1], the tv-show sense (first 10 values)
keyword_embeddings[1][:10]

tensor([ 0.3484, -0.5333,  0.0284,  0.2435, -0.2398, -0.2844,  0.2530,  0.2528,
         0.2653, -0.0074], grad_fn=<SliceBackward0>)

### Generate Embedding for Keyword Only

In [113]:
# create a single-token "sentence"; its embedding serves as a rough _control_ vector
base_keyword = f"{keyword}"
base_keyword

'pilot'

In [114]:
# tokenize: the first and last ids are the [CLS] and [SEP] special tokens
base_encoding = tokenizer(base_keyword, return_tensors='pt')
base_encoding = base_encoding['input_ids']

In [115]:
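# forward pass; base_output[0] is the last hidden state, [0] selects the only example,
# and [1] selects the position right after [CLS], i.e. the keyword token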
base_output = model(base_encoding)
base_embedding = base_output[0][0][1]

In [116]:
base_embedding[:10]

tensor([ 0.0967, -0.0604, -0.1340, -0.0524,  0.2880,  0.0868, -0.0932, -0.0259,
         0.6041, -0.8347], grad_fn=<SliceBackward0>)

### Cosine Similarity Between Keyword Embeddings 

This comparison also includes the `base_embedding`.

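Based on the examples, we expect the keyword embeddings from sentences where "pilot" refers to someone who flies an aircraft (`examples[0]`, `examples[3]`, and, to a lesser degree, `examples[4]`) to be more similar to each other than to the keyword embeddings from the other sentences.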
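Cosine similarity compares the direction of two vectors while ignoring their length: cos(u, v) = (u · v) / (‖u‖ ‖v‖), which is 1 for vectors pointing the same way and 0 for orthogonal ones. The NumPy sketch below computes it for every pair of rows at once; `cosine_similarity_matrix` is a hypothetical helper, and it only roughly mirrors how `torch.nn.CosineSimilarity` guards against zero norms with a small `eps`.

```python
import numpy as np

def cosine_similarity_matrix(x: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Pairwise cosine similarity between the rows of x: (u . v) / (|u| |v|)."""
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    unit = x / np.maximum(norms, eps)   # normalize each row; eps guards against zero vectors
    return unit @ unit.T                # entry [i, j] compares row i with row j

vecs = np.array([
    [1.0, 0.0],
    [1.0, 1.0],
    [0.0, 2.0],
])
print(np.round(cosine_similarity_matrix(vecs), 3))
```

Applied to `keyword_embeddings.detach().numpy()`, row 0 of this matrix should match the scores computed for `examples[0]` in the next cells, up to floating-point rounding.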

In [117]:
cos = F.CosineSimilarity(dim=1)

In [118]:
# compare the keyword embedding from examples[0] with every example (including itself);
# unsqueeze gives shape (1, 768), which broadcasts against (5, 768)
sim_scores = cos(keyword_embeddings[0].unsqueeze(dim=0), keyword_embeddings)
print(sim_scores)

tensor([1.0000, 0.6127, 0.7506, 0.8833, 0.8510], grad_fn=<SumBackward1>)


In [119]:
# print the similarity scores beside each example text for readability
print(f"target example: {examples[0]}\n")

for example, score in zip(examples, sim_scores.tolist()):
    print(f"example: '{example}'")
    print(f"score: {score:.3f}")
    print( "-" * 50, "\n")

target example: Attention passengers, this is the pilot speaking. Please prepare for landing.

example: 'Attention passengers, this is the pilot speaking. Please prepare for landing.'
score: 1.000
-------------------------------------------------- 

example: 'The tv show was funny but it didn't get approved after the pilot.'
score: 0.613
-------------------------------------------------- 

example: 'So, are you happy with your honda pilot? How does it handle the rough roads around here?'
score: 0.751
-------------------------------------------------- 

example: 'Even though the flight was bumpy, I trused the pilot had everything under control.'
score: 0.883
-------------------------------------------------- 

example: 'She was the best pilot the commander had ever seen.'
score: 0.851
-------------------------------------------------- 
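
Sorting by score makes the ordering easier to read than scanning the printout. A minimal sketch follows; the helper `rank_by_similarity` and the short labels are hypothetical shorthand for the five examples, and the scores are the ones recorded in cell 118.

```python
from typing import List, Tuple

def rank_by_similarity(texts: List[str], scores: List[float]) -> List[Tuple[float, str]]:
    """Pair each text with its score and sort from most to least similar."""
    return sorted(zip(scores, texts), key=lambda pair: pair[0], reverse=True)

labels = ["flight (target)", "tv show", "honda", "flight (bumpy)", "best pilot"]
scores = [1.0000, 0.6127, 0.7506, 0.8833, 0.8510]   # values recorded in cell 118
for score, label in rank_by_similarity(labels, scores):
    print(f"{score:.3f}  {label}")
```

In the notebook, `rank_by_similarity(examples, sim_scores.tolist())` would do the same with the full sentences. With the recorded scores, the two other aviation-sense sentences rank right below the target and the tv-show sense ranks last.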



In [120]:
# compare each sentence-based keyword embedding with the `base_embedding`
# the scores are lower overall (about 0.51-0.68) and less spread out than the pairwise scores above
cos(keyword_embeddings, base_embedding.unsqueeze(dim=0))

tensor([0.6686, 0.5095, 0.5919, 0.6764, 0.6742], grad_fn=<SumBackward1>)
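To put a rough number on this difference in separation, the sketch below compares the mean and the spread (max minus min) of the two sets of recorded scores. This is an illustrative summary, not a standard metric; note that the first set excludes the self-score of 1.0, while the second covers all five examples.

```python
import numpy as np

# Scores copied from the recorded outputs of cells 118 and 120.
pairwise_vs_example0 = np.array([0.6127, 0.7506, 0.8833, 0.8510])  # self-score 1.0 excluded
vs_base_embedding = np.array([0.6686, 0.5095, 0.5919, 0.6764, 0.6742])

for name, s in [("vs examples[0]", pairwise_vs_example0), ("vs base_embedding", vs_base_embedding)]:
    spread = s.max() - s.min()   # simple range as a separation measure
    print(f"{name:>18}: mean={s.mean():.3f}  spread(max-min)={spread:.3f}")
```

With the recorded values, the mean drops from about 0.774 to about 0.624 and the spread narrows from about 0.271 to about 0.167.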