# Contextualized embedding with transformer models illustrated

In this notebook, we begin to peek under the hood of a BERT transformer model to understand how contextualized embeddings work.
We then also introduce a couple of potential use cases that leverage contextualized embeddings.

<br>
<a target="_blank" href="https://colab.research.google.com/github/haukelicht/dia_cta_course/blob/main/notebooks/block2/day1/contextualized_embedding_explained.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

### Setup

#### Colab

In [1]:
# check if on colab
COLAB = True
try:
    import google.colab
except ImportError:
    COLAB = False

if COLAB:
    # # shallow clone of current state of main branch 
    # !git clone --branch main --single-branch --depth 1 --filter=blob:none https://github.com/haukelicht/dia_cta_course.git

    !pip install -q umap-learn~=0.5.9.post2 bertviz==1.4.1


#### Required packages

In [3]:
from pathlib import Path
import numpy as np
import pandas as pd

import torch
from transformers import BertForMaskedLM, BertModel, BertTokenizer

from bertviz import head_view
from bertviz.transformers_neuron_view import BertModel as BertVizModel 
from bertviz.transformers_neuron_view import BertTokenizer as BertVizTokenizer
from bertviz.neuron_view import show

import umap
from sklearn.cluster import KMeans

from sklearn.metrics import accuracy_score
from sklearn.metrics.pairwise import cosine_similarity

import matplotlib.pyplot as plt

### Some terminology

Here are the most important pieces of the transformer model we will work with:

- `model()` -- this _is_ the model (it has been pre-trained; we only use it to compute contextualized embeddings)
- `tokenizer()` -- this is the tokenizer associated with the model; it converts raw texts into sequences of token IDs that point to the corresponding tokens' locations in the model's input embedding matrix
- `outputs` -- this is what the model _outputs_ when we process an input sentence through it
-  `outputs.hidden_states` -- these are the embeddings at each stage of the model: the input embeddings followed by the contextualized embeddings produced at each layer
-  `outputs.last_hidden_state` -- these are the contextualized embeddings at the _final_ layer


## Intro to the `transformers` library

In Python, the standard library for working with transformer models is `transformers`.
It provides access to pre-trained transformer models through its [model hub]().
The `transformers` library is developed and maintained by Hugging Face Inc.

### Pre-trained models and tokenizers

To use a pre-trained model for embedding texts, we need two things:

1. the model's tokenizer
1. and of course the model itself

We use the model to process a text through its **layers** to obtain the text's **embedding**.
But to be able to do this, we first need to **tokenize** the text to convert it into numbers, because deep neural networks can only process numbers, not raw text.
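
As a minimal sketch of this idea, the toy example below maps tokens to integer IDs using a hand-written vocabulary. The vocabulary `toy_vocab` and the function `toy_tokenize` are hypothetical and only illustrate the principle; BERT's actual tokenizer has a much larger vocabulary and more elaborate splitting rules.

```python
# Hypothetical toy vocabulary: each known token gets an integer ID.
toy_vocab = {"[UNK]": 0, "i": 1, "open": 2, "a": 3, "bank": 4, "account": 5}

def toy_tokenize(text):
    """Lowercase, split on whitespace, and look up each token's ID."""
    tokens = text.lower().split()
    # unknown tokens fall back to the ID of the special "[UNK]" token
    return [toy_vocab.get(tok, toy_vocab["[UNK]"]) for tok in tokens]

token_ids = toy_tokenize("I open a bank account")
print(token_ids)  # [1, 2, 3, 4, 5]
```

A word that is missing from the toy vocabulary (e.g., "savings") would be mapped to ID 0 here.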

Below we load a pre-trained BERT model, specifically "bert-base-uncased", which is a smallish version of BERT (hence 'base' instead of 'large') that does not distinguish between upper- and lowercase letters (hence 'uncased'). 

In [5]:
# define the name of the model we want to load
model_id = 'bert-base-uncased'

# load the pre-trained model and tokenizer 
model = BertModel.from_pretrained(model_id)
tokenizer = BertTokenizer.from_pretrained(model_id)
# NOTE: this will trigger downloading the model and tokenizer if you haven't done so before

Let's get some information about the model by looking at its configuration attribute (`config`):

In [6]:
# let's get some important information about the model
print('embedding dimensionality:', model.config.hidden_size)
print('number of layers:', model.config.num_hidden_layers)
print('vocabulary size:', model.config.vocab_size)

embedding dimensionality: 768
number of layers: 12
vocabulary size: 30522


In [7]:
# let's have a look at the model architecture
print(model)

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False

- the model's first component is a `BertEmbeddings` module that contains
    1. the initial word embedding layer
    2. the positional embedding layer
    3. the token type embedding layer
- after this we have the `BertEncoder` module that consists of 12 `BertLayer`s

If we just want to get the initial word embeddings, we can access them like this.

In [8]:
model.embeddings.word_embeddings.weight.shape

torch.Size([30522, 768])

In [9]:
print(model.embeddings.word_embeddings)

# let's get the first five values of the first embedding
model.embeddings.word_embeddings.weight[0][:5].detach().numpy()

Embedding(30522, 768, padding_idx=0)


array([-0.01018257, -0.06154883, -0.02649689, -0.0420608 ,  0.00116716],
      dtype=float32)

Notes:

- the layers are attributes of the `model`, and they are organized and nested as shown by `print(model)`
- we get the actual parameters of the model from a layer's "weights" (weights is just the machine learning term for parameters)
- weights are $n$-dimensional arrays (called "tensors" in `pytorch` etc.) and we can index them just like numpy arrays
- we use `detach()` because the model's weights (parameters) are tracked for gradient computation, which is needed to optimize them during training but not when we only want to look at the weight values
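
To make the idea of an embedding matrix as a lookup table concrete, here is a small sketch with NumPy. The matrix `toy_embeddings` is filled with random numbers and its size (10 tokens, 4 dimensions) is made up for illustration; it is not BERT's actual embedding matrix.

```python
import numpy as np

rng = np.random.default_rng(0)
# hypothetical embedding matrix: 10 tokens in the vocabulary, 4 dimensions each
toy_embeddings = rng.normal(size=(10, 4))

# a token ID is simply a row index into the matrix
token_id = 7
vector = toy_embeddings[token_id]
print(vector.shape)  # (4,)

# looking up several token IDs at once returns one row per token
token_ids = [1, 7, 7, 3]
vectors = toy_embeddings[token_ids]
print(vectors.shape)  # (4, 4)
```

Note that the same token ID always returns the same row, regardless of the tokens around it. This is why these initial embeddings are not yet contextualized.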

But the main reason we use BERT & Co. is to obtain contextualized embeddings.

## Contextualized embedding

To illustrate how contextualized embedding works in transformers, we will first look at how embeddings of the same word differ if their context differs.

Let's take two sentences that contain the word "bank" but use it with different meanings:

In [38]:
sentences = [
    "Today, I will hike along the bank of a river.",
    "Today, I will open a new bank account and deposit some money.",
]

To get the transformer embedding of the word "bank" in these two sentences, we need to follow three steps:

1. tokenize the texts and convert tokens into token IDs (to look up their input embeddings)
2. process these inputs through the model
3. locate the embedding of the focal word in the two sentences.

#### 1) tokenize

The tokenizer converts the text into tokens and maps the tokens to token IDs.

Token IDs indicate tokens' locations in the tokenizer's vocabulary and hence their rows in the model's input embedding matrix.

In [39]:
inputs = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True)

In [40]:
inputs['input_ids']

tensor([[  101,  2651,  1010,  1045,  2097, 21857,  2247,  1996,  2924,  1997,
          1037,  2314,  1012,   102,     0,     0],
        [  101,  2651,  1010,  1045,  2097,  2330,  1037,  2047,  2924,  4070,
          1998, 12816,  2070,  2769,  1012,   102]])

We can "decode" these token IDs into their tokens:

In [41]:
tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

['[CLS]',
 'today',
 ',',
 'i',
 'will',
 'hike',
 'along',
 'the',
 'bank',
 'of',
 'a',
 'river',
 '.',
 '[SEP]',
 '[PAD]',
 '[PAD]']

Notes: 

- the `[CLS]` token is a special token used to summarize the information in a sequence (e.g., for classification tasks)
- the `[SEP]` token is the special "separator" token that indicates sequence boundaries
- the `[PAD]` token is the special "padding" token that is appended to sequences that are shorter than the other sequences in a batch to make the input rectangular (i.e., all rows have an equal number of columns)
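
The padding step can be sketched in a few lines of plain Python. The helper `pad_batch` below is hypothetical and not part of `transformers`; it pads the two token ID sequences shown above to the same length with the `[PAD]` ID (0) and builds a mask that marks real tokens with 1 and padding with 0, similar in spirit to the `attention_mask` the tokenizer returns.

```python
# token IDs of the two example sentences without padding (copied from the output above)
seq_1 = [101, 2651, 1010, 1045, 2097, 21857, 2247, 1996, 2924, 1997, 1037, 2314, 1012, 102]
seq_2 = [101, 2651, 1010, 1045, 2097, 2330, 1037, 2047, 2924, 4070, 1998, 12816, 2070, 2769, 1012, 102]

def pad_batch(sequences, pad_id=0):
    """Right-pad all sequences to the length of the longest one."""
    max_len = max(len(seq) for seq in sequences)
    padded = [seq + [pad_id] * (max_len - len(seq)) for seq in sequences]
    # 1 marks a real token, 0 marks padding
    mask = [[1] * len(seq) + [0] * (max_len - len(seq)) for seq in sequences]
    return padded, mask

padded, attention_mask = pad_batch([seq_1, seq_2])
print([len(row) for row in padded])  # [16, 16]
print(attention_mask[0])
```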

In [42]:
# let's use the tokenizer to get the token ID of the focal word 
focal_word_id = tokenizer.convert_tokens_to_ids('bank')
focal_word_id

2924

In [46]:
# create mask that is True where input ID == focal word ID
mask = inputs['input_ids'] == focal_word_id

# show the positions of the focal word in the input IDs
np.where(mask)[1] # NOTE: "bank" has the same position in the sequence of words in both sentences

array([8, 8])

In [47]:
# put tokens and mask side by side
for i, (iids, msk) in enumerate(zip(inputs['input_ids'], mask)):
    print("Sentence", i+1)
    print("-" * 13)
    for t, m in zip(tokenizer.convert_ids_to_tokens(iids), msk):
        print(t, bool(m), sep='\t')
        if t == '[SEP]':
            break
    print()

Sentence 1
-------------
[CLS]	False
today	False
,	False
i	False
will	False
hike	False
along	False
the	False
bank	True
of	False
a	False
river	False
.	False
[SEP]	False

Sentence 2
-------------
[CLS]	False
today	False
,	False
i	False
will	False
open	False
a	False
new	False
bank	True
account	False
and	False
deposit	False
some	False
money	False
.	False
[SEP]	False



#### 2) embed (process through model)

In [48]:
# get the initial embedding of the focal word ("bank")
model.embeddings.word_embeddings.weight[focal_word_id].shape

torch.Size([768])

In [49]:
model.embeddings.word_embeddings.weight[focal_word_id][:5]

tensor([-0.0191, -0.0646, -0.0913, -0.0776, -0.0253], grad_fn=<SliceBackward0>)

In [50]:
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)

**_Note:_** We use `torch.no_grad()` to disable gradient tracking, which is only needed for "backpropagation", the method used to optimize deep neural networks' parameters.
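
The effect of `torch.no_grad()` can be checked on a toy tensor. In the sketch below, `w` stands in for a model parameter (it is a made-up tensor, not a BERT weight): a result computed normally remains attached to the gradient-tracking graph, while a result computed inside `torch.no_grad()` does not.

```python
import torch

# toy "parameter" that is tracked for gradient computation
w = torch.ones(3, requires_grad=True)
x = torch.tensor([1.0, 2.0, 3.0])

# computed normally: the result is part of the gradient-tracking graph
y_tracked = (w * x).sum()

# computed without gradient tracking
with torch.no_grad():
    y_untracked = (w * x).sum()

print(y_tracked.requires_grad)    # True
print(y_untracked.requires_grad)  # False
print(w.detach().requires_grad)   # False
```

Both results hold the same value; only the bookkeeping for gradients differs, which saves memory and computation when we merely want to compute embeddings.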

In [51]:
print(type(outputs))
# list the object's attributes
list(dict(outputs).keys())

<class 'transformers.modeling_outputs.BaseModelOutputWithPoolingAndCrossAttentions'>


['last_hidden_state', 'pooler_output', 'hidden_states']

In [52]:
outputs.hidden_states[3].shape

torch.Size([2, 16, 768])

In [53]:
# hidden states are the embeddings after each layer (plus the input embeddings at index 0)
len(outputs.hidden_states)

13

In [54]:
# the final-layer embedding of the first token ([CLS]) in the first sentence can be accessed like this:
outputs.hidden_states[-1][0][0].shape

torch.Size([768])

In [None]:
# let's look at the shape:
outputs.hidden_states[-1].shape

torch.Size([2, 16, 768])

#### 3) get the words' contextualized embeddings 

In [56]:
# final transformer embeddings of bank in different contexts
embeddings = outputs.last_hidden_state[mask]

In [57]:
embeddings.shape

torch.Size([2, 768])
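
The boolean-mask indexing used here may look surprising at first. The sketch below reproduces it with NumPy on a made-up array of the same shape (2 sentences, 16 tokens, 768 dimensions); the array `toy_hidden` holds random numbers, not real model outputs. Indexing with a mask of shape `(2, 16)` keeps one 768-dimensional vector for every position where the mask is `True`.

```python
import numpy as np

rng = np.random.default_rng(0)
toy_hidden = rng.normal(size=(2, 16, 768))  # (sentences, tokens, dimensions)

# mark position 8 in both sentences, mirroring the position of "bank" above
toy_mask = np.zeros((2, 16), dtype=bool)
toy_mask[:, 8] = True

# boolean indexing over the first two axes selects the marked token vectors
selected = toy_hidden[toy_mask]
print(selected.shape)  # (2, 768)
```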

In [58]:
embeddings[:3]

tensor([[-0.0908, -0.6298,  0.1233,  ..., -0.5037, -1.1361,  0.4891],
        [ 1.2851,  0.0947,  0.7091,  ..., -0.5602, -0.4731,  0.0220]])

In [59]:
# compute cosine similarity between the two embeddings
cosine_similarity(embeddings[0].reshape(1, -1), embeddings[1].reshape(1, -1))

array([[0.47205228]], dtype=float32)
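
For reference, cosine similarity is the dot product of two vectors divided by the product of their lengths, so it ranges from -1 (opposite directions) to 1 (same direction). The helper `cosine` below is a minimal NumPy sketch of that definition, written for illustration; the notebook itself uses `cosine_similarity` from scikit-learn.

```python
import numpy as np

def cosine(a, b):
    """Cosine similarity: dot product divided by the product of the vector norms."""
    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

print(cosine([1, 0], [1, 0]))   # same direction
print(cosine([1, 0], [0, 1]))   # orthogonal
print(cosine([1, 0], [-1, 0]))  # opposite direction
```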

Below you can see that the similarity of "bank"'s transformer embedding depends on the model layer we look at.

In [69]:
# iterate over all layers
for i, layer in enumerate(outputs.hidden_states):
    embeddings = layer[mask]
    similarity = cosine_similarity(embeddings[0].reshape(1, -1), embeddings[1].reshape(1, -1))[0][0]
    print(f'    layer {str(i).rjust(2)} :' if i>0 else 'input embed. :', f"{similarity:0.3f}")

input embed. : 1.000
    layer  1 : 0.711
    layer  2 : 0.632
    layer  3 : 0.508
    layer  4 : 0.450
    layer  5 : 0.426
    layer  6 : 0.418
    layer  7 : 0.409
    layer  8 : 0.395
    layer  9 : 0.374
    layer 10 : 0.410
    layer 11 : 0.467
    layer 12 : 0.472


**_Notes_:** 

- If the word "bank" were _not_ in the same position in both sentences, we would already get a similarity < 1.0 at the input step because the word's [positional embedding](https://www.kaggle.com/code/lorentzyeung/positional-embeddings-clearly-explained) would differ.
- If the focal word were _not_ as polysemous (i.e., if its meaning depended less on context), these similarity values would generally be higher.
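
The first note can be illustrated with a toy calculation. In the sketch below, the input embedding of a token is the sum of a word vector and a position vector; both `word_vec` and `pos_vecs` are random, made-up values, and the sketch leaves out the token type embeddings and layer normalization that BERT's `BertEmbeddings` module also applies.

```python
import numpy as np

rng = np.random.default_rng(42)
dim = 768
word_vec = rng.normal(size=dim)        # hypothetical word embedding of "bank"
pos_vecs = rng.normal(size=(16, dim))  # hypothetical position embeddings

def cosine(a, b):
    """Cosine similarity between two 1-D vectors."""
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

# same word at the same position: identical input embeddings
same_position = cosine(word_vec + pos_vecs[8], word_vec + pos_vecs[8])
# same word at different positions: the position vectors differ
different_position = cosine(word_vec + pos_vecs[8], word_vec + pos_vecs[5])

print(round(same_position, 3))   # 1.0
print(different_position < 1.0)  # True
```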