Notebook for visualizing vocabulary embeddings, position embeddings, and contextualized embeddings using pretrained language representation models like BERT. The notebook uses bert-large-uncased-whole-word-masking.
 
Uses HuggingFace transformers, t-SNE from sklearn, and adjustText (https://github.com/Phlya/adjustText). 
 
When visualizing the vocabulary embeddings, the notebook uses 10,000 embeddings from the vocab (selected with hardcoded indices to avoid unused entries and most single-character subword units) to compute the visualization, then plots a subset of size 4,000.
 
For contextualized embeddings, the notebook computes embeddings from the final layer of BERT when run on sentences containing the same word type. Included with this notebook is a file containing 15,000 instances of the word "values" drawn from Wikipedia and books from Project Gutenberg. After running t-SNE on all 15,000, 750 instances are plotted with their partial sentence contexts. 

Finally, the absolute position embeddings are visualized by running t-SNE on the full set. 
 
Note: we often use more instances when running t-SNE than we do for visualization. This can help t-SNE produce a better transformation of the data.

Kevin Gimpel

2020


In [None]:
%matplotlib inline
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import sys
# Raise NumPy's print threshold to the largest integer so that arrays printed
# later in the notebook are shown in full instead of being summarized.
np.set_printoptions(threshold=sys.maxsize)


In [None]:
# Default size (width, height) for figures created after this point. The value
# is deliberately very large; presumably so that thousands of text labels stay
# legible in the saved PDFs. A later cell lowers it for the position plot.
plt.rcParams['figure.figsize'] = [100, 60]

In [None]:
from adjustText import adjust_text

In [None]:
from transformers import BertTokenizer, BertModel, BertForMaskedLM

In [None]:
import logging

# Emit INFO-level log records while the pretrained files are loaded.
logging.basicConfig(level=logging.INFO)

# The model weights and the tokenizer are loaded from the same checkpoint name.
bert_checkpoint = 'bert-large-uncased-whole-word-masking'

# Load BERT and switch it to eval mode.
model = BertModel.from_pretrained(bert_checkpoint)
model.eval()
# This notebook assumes CPU execution. If you want to use GPUs, put the model on cuda and modify subsequent code blocks.
#model.to('cuda')

# Load the matching tokenizer.
tokenizer = BertTokenizer.from_pretrained(bert_checkpoint)


In [None]:
# Save the BERT vocabulary to a file in the current directory -- by default it will name this file "vocab.txt".
# The target directory is passed positionally because the name of this parameter
# differs between transformers releases (vocab_path in old ones, save_directory later).
tokenizer.save_vocabulary('.')

In [None]:
# Get BERT's vocabulary embeddings.
# wordembs is the model's input-embedding module; later cells call it on a
# tensor of vocabulary indices to look up the corresponding embedding rows.
wordembs = model.get_input_embeddings()

In [None]:
# Display the vocabulary size recorded in the model configuration.
model.config.vocab_size

In [None]:
# Look up the embedding of every vocabulary index and convert the result to numpy.
allinds = np.arange(model.config.vocab_size)
inputinds = torch.tensor(allinds, dtype=torch.long)
all_vocab_rows = wordembs(inputinds)
# detach() drops the autograd graph so the tensor can be exported to numpy.
bertwordembs = all_vocab_rows.detach().numpy()

In [None]:
# Display the shape of the full vocabulary embedding matrix (one row per vocabulary index).
bertwordembs.shape

In [None]:
def loadLines(filename):
    """Read a text file and return its lines as a numpy array of strings.

    Trailing whitespace (including the newline) is stripped from every line.
    Progress messages are printed before and after loading.

    Args:
        filename: Path of the text file to read. It is decoded as UTF-8, which
            is the encoding the tokenizer uses when writing its vocabulary file.

    Returns:
        A 1-D numpy array with one entry per line, in file order. An empty
        file yields an empty array.
    """
    print("Loading lines from file", filename)
    # The context manager closes the file even if reading fails.
    with open(filename, 'r', encoding='utf-8') as f:
        stripped = [line.rstrip() for line in f]
    # Build the array once; appending to a numpy array inside the loop would
    # re-copy all previously read lines on every iteration.
    lines = np.array(stripped)
    print("Done. ", len(lines)," lines loaded!")
    return lines

In [None]:
# Load the vocabulary file; bertwords[i] is the (right-stripped) i-th line of
# vocab.txt, so it can be indexed with the vocabulary indices used below.
bertwords = loadLines('vocab.txt')

In [None]:
# Choose which vocabulary indices feed t-SNE and which of those get plotted.
# The ranges are hard-coded, based partially on inspection of the vocabulary.
bert_char_indices_to_use = np.arange(999, 1063)
# Both selections start with the same character block, so the plotted indices
# form a prefix of the indices used for t-SNE.
bert_voc_indices_to_plot = np.concatenate((bert_char_indices_to_use, np.arange(1996, 5932)))
bert_voc_indices_to_use = np.concatenate((bert_char_indices_to_use, np.arange(1996, 11932)))

In [None]:
# Report how many vocabulary entries are plotted, then how many are used to fit t-SNE.
for selected_indices in (bert_voc_indices_to_plot, bert_voc_indices_to_use):
    print(len(selected_indices))

In [None]:
# Print the vocabulary entries selected for t-SNE so the selection can be inspected.
print(bertwords[bert_voc_indices_to_use])

In [None]:
# Look up the embeddings of the selected vocabulary indices and convert them to numpy.
bert_voc_indices_to_use_tensor = torch.tensor(bert_voc_indices_to_use, dtype=torch.long)
selected_vocab_rows = wordembs(bert_voc_indices_to_use_tensor)
bert_word_embs_to_use = selected_vocab_rows.detach().numpy()

In [None]:
# Run t-SNE on the BERT vocabulary embeddings we selected:
# Projects to 2 dimensions using cosine distance and PCA initialization.
# No random_state is set, so the layout can differ from run to run.
# NOTE(review): recent scikit-learn releases rename TSNE's n_iter argument to
# max_iter -- confirm against the installed version.
mytsne_words = TSNE(n_components=2,early_exaggeration=12,verbose=2,metric='cosine',init='pca',n_iter=2500)
# Row i of the result corresponds to bert_voc_indices_to_use[i].
bert_word_embs_to_use_tsne = mytsne_words.fit_transform(bert_word_embs_to_use)

In [None]:
# Vocabulary strings for the subset of indices that will be plotted, followed by their count.
bert_words_to_plot = bertwords[bert_voc_indices_to_plot]
print(len(bert_words_to_plot))

In [None]:
# Plot the transformed BERT vocabulary embeddings:
fig = plt.figure()
# Only the first len(bert_words_to_plot) t-SNE rows are drawn. This relies on
# the plotted indices being a prefix of the indices t-SNE was fitted on, so
# that row i of the t-SNE output belongs to bert_words_to_plot[i].
vocab_coords = bert_word_embs_to_use_tsne[:len(bert_words_to_plot)]
# Draw zero-size markers at every point in a single call (instead of one
# scatter call per word); presumably they exist only so the axes are scaled to
# the data, since the words themselves are drawn as text.
plt.scatter(vocab_coords[:, 0], vocab_coords[:, 1], s=0)
alltexts = [
    plt.text(x, y, txt, family='sans-serif')
    for (x, y), txt in zip(vocab_coords, bert_words_to_plot)
]

# Save the plot before adjusting.
plt.savefig('viz-bert-voc-tsne10k-viz4k-noadj.pdf', format='pdf')
print('now running adjust_text')
# Using autoalign often works better in my experience, but it can be very slow for this case, so it's false by default below:
#numiters = adjust_text(alltexts, autoalign=True, lim=50)
numiters = adjust_text(alltexts, autoalign=False, lim=50)
print('done adjust text, num iterations: ', numiters)
# Save again after the labels have been moved apart.
plt.savefig('viz-bert-voc-tsne10k-viz4k-adj50.pdf', format='pdf')

# plt.show must be called; the bare attribute reference does nothing.
plt.show()

Now we will visualize contextualized embeddings.

In [None]:
def loadAndTokenizeLinesAndFindKeyword(filename, keyword, maxLines):
    """Load lines from a file, tokenize them, and keep those containing keyword.

    Each line is wrapped as "[CLS] <line> [SEP]" and tokenized with the
    notebook-level `tokenizer`. A line is kept when `keyword` appears as one
    of its tokens; otherwise it is reported and counted as skipped. Reading
    stops once maxLines lines have been kept.

    Args:
        filename: Path of the text file to read, one sentence per line.
        keyword: Token to look for in each tokenized line.
        maxLines: Maximum number of lines to keep. A value of 0 or less keeps
            no lines.

    Returns:
        A tuple (lines, keywordIndices): the kept tokenized lines (lists of
        tokens) and, for each, the index of the first occurrence of keyword.
    """
    print("Loading lines from file", filename)
    lines = []
    keywordIndices = []
    numSkipped = 0
    # The context manager closes the file even when we stop early or fail.
    with open(filename,'r') as f:
        for line in f:
            # Check the limit before doing any work so that maxLines <= 0
            # keeps nothing and no line is tokenized once the limit is hit.
            if len(lines) >= maxLines:
                break
            # Tokenize input
            lineForBERT = "[CLS] " + line.rstrip() + " [SEP]"
            tokenized_text = tokenizer.tokenize(lineForBERT)
            if keyword in tokenized_text:
                # index() returns the first occurrence of the keyword.
                keywordIndex = tokenized_text.index(keyword)
                lines.append(tokenized_text)
                keywordIndices.append(keywordIndex)
            else:
                print("Keyword \"", keyword, "\" not found in line: ", tokenized_text)
                numSkipped += 1
    print("Done. ", len(lines)," lines loaded, ", numSkipped, " lines skipped.")
    return lines, keywordIndices

In [None]:
# Load up to 15,000 tokenized lines that contain the token "values", together
# with the position of that token in each tokenized line.
keywordLines, keywordIndices = loadAndTokenizeLinesAndFindKeyword("values.books-wiki.15k.txt", "values", 15000)

In [None]:
# Now we will use BERT to encode the sentences we loaded and save the embeddings from the final layer
# at the position of the keyword.
# The matrix is allocated once and filled row by row; growing it with np.append
# inside the loop would re-copy every earlier row on each iteration.
num_instances = min(len(keywordLines), len(keywordIndices))
embs = np.empty((num_instances, model.config.hidden_size), float)
# NOTE(review): lines are not truncated here, so a tokenized line longer than
# the model's maximum input length would presumably fail inside the model call.
# Go through all tokenized lines and keyword indices:
for row, (tok, ind) in enumerate(zip(keywordLines, keywordIndices)):
    # Convert tokens to vocabulary indices
    indexed_tokens = tokenizer.convert_tokens_to_ids(tok)
    # segments_ids will hold indices associated with the first and second sentences in BERT.
    # We just use sentence A indices for all tokens:
    segments_ids = [0] * len(tok)
    # Convert inputs to PyTorch tensors (batch of size 1)
    tokens_tensor = torch.tensor([indexed_tokens])
    segments_tensors = torch.tensor([segments_ids])
    # Run the model without tracking gradients:
    with torch.no_grad():
        outputs = model(tokens_tensor, token_type_ids=segments_tensors)
    # The first element of the output holds the hidden states of the last layer of BERT.
    encoded_layers = outputs[0]
    # encoded_layers has shape (batch size, sequence length, model hidden dimension)
    assert tuple(encoded_layers.shape) == (1, len(indexed_tokens), model.config.hidden_size)
    # Copy the hidden state at the keyword position into its row of the embs matrix.
    embs[row] = encoded_layers[0][ind].numpy()
        

In [None]:
# Display the shape of the contextualized embedding matrix (one row per loaded line).
embs.shape

In [None]:
# Run t-SNE on the contextualized embeddings:
# Same settings as for the vocabulary embeddings: 2-D output, cosine distance, PCA initialization.
# NOTE(review): recent scikit-learn releases rename TSNE's n_iter argument to
# max_iter -- confirm against the installed version.
mytsne_tokens = TSNE(n_components=2,early_exaggeration=12,verbose=2,metric='cosine',init='pca',n_iter=2500)
# Row i of the result corresponds to keywordLines[i].
embs_tsne = mytsne_tokens.fit_transform(embs)

In [None]:
# Build the label for every instance: the keyword plus a few tokens of context on each side.
# windowSize is the (max) number of subword units shown on either side of the keyword.
windowSize = 5
# When True, partial-word units ("##...") are glued back onto the preceding unit.
mergeSubwordUnits = True
# When True, the BERT boundary tokens [CLS] and [SEP] are dropped from the label.
removeBoundaryTokens = True

keywordWithContext = []
for txt, ind in zip(keywordLines, keywordIndices):
    # Clamp the left edge of the window at the start of the line.
    startInd = max(ind - windowSize, 0)
    window_tokens = txt[startInd:ind + windowSize + 1]
    currKeywordWithContext = " ".join(window_tokens)
    if mergeSubwordUnits:
        # First join continuation pieces to the previous token, then strip a
        # marker that is left at the very start of the window.
        currKeywordWithContext = currKeywordWithContext.replace(" ##", "").replace("##", "")
    if removeBoundaryTokens:
        currKeywordWithContext = currKeywordWithContext.replace("[CLS] ", "").replace(" [SEP]", "")
    keywordWithContext.append(currKeywordWithContext)
    

In [None]:
# Display a few sample keyword + context strings (entries 49 through 57).
keywordWithContext[49:58]

In [None]:
# Only the first 750 instances are plotted, even though t-SNE was run on all of them.
keywordWithContextToPlot = keywordWithContext[:750]
# Show how many labels were selected and what the first three look like.
print(len(keywordWithContextToPlot))
print(keywordWithContextToPlot[:3])

In [None]:
# Plot the keyword+context strings.
fig = plt.figure()
# Row i of embs_tsne belongs to keywordWithContextToPlot[i]; only the leading
# rows that have a label are drawn.
context_coords = embs_tsne[:len(keywordWithContextToPlot)]
# Draw zero-size markers at every point in a single call (instead of one
# scatter call per label); presumably they exist only so the axes are scaled to
# the data, since the labels themselves are drawn as text.
plt.scatter(context_coords[:, 0], context_coords[:, 1], s=0)
alltexts = [
    plt.text(x, y, txt, family='sans-serif')
    for (x, y), txt in zip(context_coords, keywordWithContextToPlot)
]

# Save the plot before adjusting.
plt.savefig('viz-bert-ctx-values-viz750-noadj.pdf', format='pdf')
print('now running adjust_text')
#numiters = adjust_text(alltexts, autoalign=True, lim=50)
numiters = adjust_text(alltexts, autoalign=False, lim=50)
print('done adjust text, num iterations: ', numiters)
# Save again after the labels have been moved apart.
plt.savefig('viz-bert-ctx-values-viz750-adj.pdf', format='pdf')

# plt.show must be called; the bare attribute reference does nothing.
plt.show()

Next we will visualize the position embeddings.

In [None]:
# Fetch the position embedding module from the model by its registered name.
# If no module has that name, posembs is left as the placeholder 0.
posembs = dict(model.named_modules()).get("embeddings.position_embeddings", 0)

In [None]:
# Display the position embedding module found above (0 means it was not found).
posembs

In [None]:
# Look up the embeddings for positions 0..511 and convert them to numpy.
pos_allinds = np.arange(512)
pos_inputinds = torch.tensor(pos_allinds, dtype=torch.long)
position_rows = posembs(pos_inputinds)
# detach() drops the autograd graph so the tensor can be exported to numpy.
bertposembs = position_rows.detach().numpy()

In [None]:
# Display the shape of the position embedding matrix (one row per position index).
bertposembs.shape

In [None]:
# Run t-SNE on the position embeddings.
# Same settings as the other t-SNE runs: 2-D output, cosine distance, PCA initialization.
# NOTE(review): recent scikit-learn releases rename TSNE's n_iter argument to
# max_iter -- confirm against the installed version.
mytsne_pos = TSNE(n_components=2,early_exaggeration=12,verbose=2,metric='cosine',init='pca',n_iter=2500)
# Row i of the result corresponds to position index i.
bertposembs_tsne = mytsne_pos.fit_transform(bertposembs)

In [None]:
# Text labels for the plot: the decimal string of each position index, 0 through 511.
bertpos_strings = [str(position) for position in range(512)]

In [None]:
# Using a smaller figure size, plot the position embeddings.
plt.rcParams['figure.figsize'] = [25, 15]
fig = plt.figure()
# Row i of bertposembs_tsne belongs to bertpos_strings[i]; only rows that have
# a label are drawn.
pos_coords = bertposembs_tsne[:len(bertpos_strings)]
# Draw zero-size markers at every point in a single call (instead of one
# scatter call per position); presumably they exist only so the axes are scaled
# to the data, since the positions themselves are drawn as text.
plt.scatter(pos_coords[:, 0], pos_coords[:, 1], s=0)
alltexts = [
    plt.text(x, y, txt, family='sans-serif')
    for (x, y), txt in zip(pos_coords, bertpos_strings)
]

# We don't really need to use adjustText here since the position embeddings are well-separated and there are not too many of them.
plt.savefig('viz-bert-pos.pdf', format='pdf')
# plt.show must be called; the bare attribute reference does nothing.
plt.show()