# BERT's Anatomy Step by Step: Positional Embeddings (illustrated with GPT-2's learned position embeddings)

In [None]:
import matplotlib
import matplotlib.pyplot as plt
%config InlineBackend.figure_format = 'svg'

import torch
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoConfig, AutoTokenizer
from transformers import GPT2ForSequenceClassification

In [None]:
# Name of the pretrained checkpoint; the model, tokenizer and config cells all load from it.
model_checkpoint = 'gpt2'

In [None]:
# Load the pretrained model. The visible cells of this notebook only use the two
# embedding modules under `model.transformer` (`wte` and `wpe`).
model = GPT2ForSequenceClassification.from_pretrained(model_checkpoint)

In [None]:
# Tokenizer loaded from the same checkpoint as the model.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [None]:
# Configuration of the same checkpoint; read below for `hidden_size` and
# `max_position_embeddings`.
config = AutoConfig.from_pretrained(model_checkpoint)

In [None]:
# Turn one example sentence into token ids, returned as a PyTorch tensor
# (presumably of shape (1, seq_len); later cells treat the last axis as seq_len).
encoding = tokenizer.encode("let's tokenize something?", return_tensors="pt")

In [None]:
# Map the ids back to their token strings; flatten() collapses the tensor to 1-D first.
tokens = tokenizer.convert_ids_to_tokens(encoding.flatten())

In [None]:
# Display the embedding module that is applied to position indices further down.
model.transformer.wpe

In [None]:
# Display the embedding module that is applied to the token ids further down.
model.transformer.wte

In [None]:
config.hidden_size  # size (dimensionality) of each embedding vector

In [None]:
config.max_position_embeddings  # max seq_len, i.e. how many position indices are embedded below

In [None]:
# Look up one embedding vector per token id of the encoded sentence.
seq_embedding = model.transformer.wte(encoding)
seq_embedding.shape   # (batch_size, seq_len, hidden_size)

In [None]:
# Position indices 0 .. seq_len-1 for the encoded sentence, with a leading
# batch axis added so the result has shape (1, seq_len).
positions = torch.arange(encoding.shape[-1]).unsqueeze(0)
positions

In [None]:
# Look up one position embedding per index in `positions`.
# NOTE(review): the `_511` suffix does not correspond to anything in this cell;
# the name is kept because the following cell refers to it.
pos_embedding_511 = model.transformer.wpe(positions)
pos_embedding_511.shape  # (batch_size, seq_len, hidden_size)

In [None]:
# Element-wise sum of the token embeddings and the position embeddings
# (both expected to be (batch_size, seq_len, hidden_size)).
seq_embedding + pos_embedding_511

In [None]:
# Embed every position index the model supports: 0 .. max_position_embeddings - 1.
# NOTE: this rebinds `positions`, which an earlier cell set to the indices of the
# example sentence only.
positions = torch.arange(0, config.max_position_embeddings)         # (max_position_embeddings,)
positions = positions.reshape((1, config.max_position_embeddings))  # add batch axis: (1, max_position_embeddings)
# Indexing with [0] drops the batch axis again, leaving one row per position.
pos_embedding = model.transformer.wpe(positions)[0]
pos_embedding.detach().numpy()  # (max_position_embeddings, hidden_size)

In [None]:
# Pairwise cosine similarity between all rows of `pos_embedding`, drawn as a heat-map.
embedding_rows = pos_embedding.detach().numpy()
similarity_matrix = cosine_similarity(embedding_rows)

plt.imshow(similarity_matrix, cmap='Blues')
# plt.colorbar()

ax = plt.gca()
ax.set_title('Position-wise Similarity of Positional Embeddings')
ax.set_xlabel('Position')
ax.set_ylabel('Position')

# Hide the frame on all four sides of the axes.
for side in ('top', 'right', 'bottom', 'left'):
    ax.spines[side].set_visible(False)

plt.show()