<a href="https://colab.research.google.com/github/nidharap/Notebooks/blob/master/Word_Embeddings_BERT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


The objective of this notebook is to - 
1. Extract word embeddings from BERT-like models
3. Visualize these words vectors, stacking them against each other using a similarity metric i.e.
    * Calculate similarity on word vectors
    * Visualize in 2D/3D the similarity matrics using multi-dimensional scaling

In [1]:
#Install libraries
!pip install transformers
!pip install plotly==4.9.0
!pip install wmd

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/27/3c/91ed8f5c4e7ef3227b4119200fc0ed4b4fd965b1f0172021c25701087825/transformers-3.0.2-py3-none-any.whl (769kB)
[K     |████████████████████████████████| 778kB 2.8MB/s 
Collecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)
[K     |████████████████████████████████| 890kB 15.6MB/s 
Collecting sentencepiece!=0.1.92
[?25l  Downloading https://files.pythonhosted.org/packages/d4/a4/d0a884c4300004a78cca907a6ff9a5e9fe4f090f5d95ab341c53d28cbc58/sentencepiece-0.1.91-cp36-cp36m-manylinux1_x86_64.whl (1.1MB)
[K     |████████████████████████████████| 1.1MB 15.0MB/s 
[?25hCollecting tokenizers==0.8.1.rc1
[?25l  Downloading https://files.pythonhosted.org/packages/40/d0/30d5f8d221a0ed981a186c8eb986ce1c94e3a6e87f994eae9f4aa5250217/tokenizers-0.8.1rc1-cp36-cp36m-manylinux1_x86_64.whl (3.0MB

In [2]:
#imports
import torch
from transformers import BertTokenizer, BertModel  #RobertaModel, RobertaTokenizer 
import sys
import re
from collections import defaultdict
from sklearn.metrics.pairwise import euclidean_distances, cosine_distances
from scipy.spatial.distance import euclidean, pdist, squareform
from sklearn import manifold          #use this for MDS computation
import pandas as pd
import numpy as np

#visualization libs
import plotly.express as px
import plotly.graph_objects as go
import matplotlib.pyplot as plt
% matplotlib inline

#Used to calculation of word movers distance between sentence
from collections import Counter

#Library to calculate Relaxed-Word Movers distance
from wmd import WMD
from wmd import libwmdrelax

In [3]:
#Define some constants
PRETRAINED_MODEL = 'bert-large-uncased' #'roberta-large'
MAX_LEN = 15

In [4]:
#define some example sentences to look at word vectors
#I picked these sentences to see how if I really get different word vectors for "date" when I use them in different contenxt 
texts = [
"Joe took Alexandria out on a date.",
"What is your date of birth?",
]

#this defines what I would like highlighted when I visualize the word vectors
WORDS_OF_INTEREST = ['date']

In [5]:
#Initialize tokenizer
tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL) #RobertaTokenizer.from_pretrained(PRETRAINED_MODEL)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




In [6]:
# Create a function to tokenize a set of texts
def preprocessing_for_bert(data, tokenizer_obj):
    """Perform required preprocessing steps for pretrained BERT.
    @param    data (np.array): Array of texts to be processed.
    @return   input_ids (torch.Tensor): Tensor of token ids to be fed to a model.
    @return   attention_masks (torch.Tensor): Tensor of indices specifying which
                  tokens should be attended to by the model.
    @return   attention_masks_without_special_tok (torch.Tensor): Tensor of indices specifying which
                  tokens should be attended to by the model excluding the special tokens (CLS/SEP)
    """
    # Create empty lists to store outputs
    input_ids = []
    attention_masks = []

    # For every sentence...
    for sent in data:
        # `encode_plus` will:
        #    (1) Tokenize the sentence
        #    (2) Add the `[CLS]` and `[SEP]` token to the start and end
        #    (3) Truncate/Pad sentence to max length
        #    (4) Map tokens to their IDs
        #    (5) Create attention mask
        #    (6) Return a dictionary of outputs
        encoded_sent = tokenizer_obj.encode_plus(
            text=sent,  # Preprocess sentence
            add_special_tokens=True,        # Add `[CLS]` and `[SEP]`
            max_length=MAX_LEN,                  # Max length to truncate/pad
            pad_to_max_length=True,         # Pad sentence to max length
            truncation=True,              #Truncate longer seq to max_len
            return_attention_mask=True      # Return attention mask
            )
        
        # Add the outputs to the lists
        input_ids.append(encoded_sent.get('input_ids'))
        attention_masks.append(encoded_sent.get('attention_mask'))

    # Convert lists to tensors
    input_ids = torch.tensor(input_ids)
    attention_masks = torch.tensor(attention_masks)
    
    #lets create another mask that will be useful when we want to average all word vectors later
    #we would like to average across all word vectors in a sentence, but excluding the CLS and SEP token
    #create a copy
    attention_masks_without_special_tok = attention_masks.clone().detach()
    
    #set the CLS token index to 0 for all sentences 
    attention_masks_without_special_tok[:,0] = 0

    #get sentence lengths and use that to set those indices to 0 for each length
    #essentially, the last index for each sentence, which is the SEP token
    sent_len = attention_masks_without_special_tok.sum(1).tolist()

    #column indices to set to zero
    col_idx = torch.LongTensor(sent_len)
    #row indices for all rows
    row_idx = torch.arange(attention_masks.size(0)).long()
    
    #set the SEP indices for each sentence token to zero
    attention_masks_without_special_tok[row_idx, col_idx] = 0

    return input_ids, attention_masks, attention_masks_without_special_tok

In [7]:
#initialize model
#output_hidden_states = True will give us all hiddent states for all layers
model = BertModel.from_pretrained(PRETRAINED_MODEL, output_hidden_states=True)      #RobertaModel.from_pretrained(PRETRAINED_MODEL ,output_hidden_states = True)

#put this in eval mode so since we do not plan to do backprop and also any other special handling that it needs to do like dropout
model.eval();

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=434.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1344997306.0, style=ProgressStyle(descr…




In [8]:
#run sentences through the tokenizer
input_ids, attention_masks, attention_masks_without_special_tok = preprocessing_for_bert(texts, tokenizer)

In [9]:
#let's take a look at the attention masks. notice that there are less number of 1s in the second case
attention_masks, attention_masks_without_special_tok

(tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]]),
 tensor([[0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
         [0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]]))

In [10]:
#call the model on the sentences
outputs = model(input_ids, attention_masks) #(tokenized_tensor, sent_tensor)
hidden_states = outputs[2]

print("Total hidden layers:", len(hidden_states))
print("First layer : hidden_states[0].shape ", hidden_states[0].shape)     # [batch_size x seq_length x vector_dim]

Total hidden layers: 25
First layer : hidden_states[0].shape  torch.Size([2, 15, 1024])


### Let's experiment with how to get tensors from different layers and stack them as needed

In [11]:
#get last 4 layers
torch.stack(hidden_states[-4:]).shape

torch.Size([4, 2, 15, 1024])

In [12]:
#concatenate last 4 layer outputs
torch.cat(hidden_states[-4:], dim=2).shape

torch.Size([2, 15, 4096])

In [13]:
#avg last 4 layer outputs
torch.stack(hidden_states[-4:]).mean(0).shape

torch.Size([2, 15, 1024])

In [14]:
#find mean across th 4 layers, and swap the batch_size and seq_len dim to access any token
torch.stack(hidden_states[-4:]).sum(0).permute(1,0,2).shape

torch.Size([15, 2, 1024])

In [15]:
def get_vector(hidden_layers_form_arch, token_index=0, mode='average', top_n_layers=4):
  '''
  retrieve vectors for a token_index from the top n layers and return a concatenated, averaged or summed vector 
  hidden_layers_form_arch: tuple returned by the transformer library
  token_index: index of the token for which a vector is desired
  mode=
        'average' : avg last n layers
        'concat': concatenate last n layers
        'sum' : sum last n layers
        'last': return embeddings only from last layer
        'second_last': return embeddings only from second last layer

  top_n_layers: number of top layers to concatenate/ average / sum
  '''
  if mode == 'concat':
    #concatenate last 4 layer outputs -> returns [batch_size x seq_len x dim]
    #permute(1,0,2) swaps the the batch and seq_len dim , making it easy to return all the vectors for a particular token position
    return torch.cat(hidden_layers_form_arch[-top_n_layers:], dim=2).permute(1,0,2)[token_index]
  
  if mode == 'average':
    #avg last 4 layer outputs -> returns [batch_size x seq_len x dim]
    return torch.stack(hidden_layers_form_arch[-top_n_layers:]).mean(0).permute(1,0,2)[token_index]


  if mode == 'sum':
    #sum last 4 layer outputs -> returns [batch_size x seq_len x dim]
    return torch.stack(hidden_layers_form_arch[-top_n_layers:]).sum(0).permute(1,0,2)[token_index]


  if mode == 'last':
    #last layer output -> returns [batch_size x seq_len x dim]
    return hidden_layers_form_arch[-1:][0].permute(1,0,2)[token_index]

  if mode == 'second_last':
    #last layer output -> returns [batch_size x seq_len x dim]
    return hidden_layers_form_arch[-2:-1][0].permute(1,0,2)[token_index]

  return None

### Let's test our function

In [16]:
get_vector(hidden_states, token_index=0, mode='concat', top_n_layers=4).shape

torch.Size([2, 4096])

In [17]:
get_vector(hidden_states, token_index=0, mode='sum', top_n_layers=4).shape

torch.Size([2, 1024])

In [18]:
#Lengths of each sentence
sent_lengths = attention_masks.sum(1).tolist()
sent_lengths

[10, 9]

In [19]:
#get the tokenized version of each sentence (text form, to label things in the plot)
tokenized_sents = [tokenizer.convert_ids_to_tokens(i) for i in input_ids]
tokenized_sents[0]

['[CLS]',
 'joe',
 'took',
 'alexandria',
 'out',
 'on',
 'a',
 'date',
 '.',
 '[SEP]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]']

In [20]:
def plt_dists(dists, labels, dims=2, words_of_interest=[], title=""):
  '''
  Plot distances using MDS in 2D/3D 
  dists: precomputed distance matrix
  labels: labels to display on the plot
  dims: 2/3 for 2 or 3 dimensional plot, defaults to 2 for any other value passed
  words_of_interest: list of words to highlight with a different color
  title: title for the plot
  '''
  cnt_dict = dict()
  color = list()

  #separate colors for words that are in words_of_interest vs other
  #each word will have a _SentenceNumber at the end to differentiate the words coming in from different sentences
  for v in labels:
    found = False
    for wrd_int in words_of_interest:
      if wrd_int in v:
        found = True
        break
      
    if found:
      color.append(1)
    else:
      color.append(0)

  #https://community.plotly.com/t/plotly-colours-list/11730/6
  colorscale = [[0, 'darkcyan'], [1, 'white']]

  #dists is precomputed using cosine similarity and passed
  #calculate MDS with number of dims passed
  mds = manifold.MDS(n_components=dims, dissimilarity="precomputed", random_state=60, max_iter=90000)
  results = mds.fit(dists)

  #get coodinates for each point
  coords = results.embedding_

  #plot
  if dims == 3:
    fig = go.Figure(data=[go.Scatter3d(
        x=coords[:, 0],
        y=coords[:, 1],
        z=coords[:, 2],
        mode='markers+text',
        textposition="top center",
        text=labels,
        marker=dict(
            size=10,
            color=color,
            colorscale=colorscale,
            opacity=0.8,
            
        )
    )])
  else:
    fig = go.Figure(data=[go.Scatter(
        x=coords[:, 0],
        y=coords[:, 1],
        mode='markers+text',
        text=labels,
        textposition="top center",
        marker=dict(
            size=12,
            color=color,
            colorscale=colorscale,
            opacity=0.8,
            
        )
    )])

  fig.update_layout(template="plotly_dark")
  if title!="":
    fig.update_layout(title_text=title)
  fig.show()

In [21]:
def eval_vecs(input_hidden_states, input_tokenized_sents, mode='concat', top_n_layers=4, viz_dims=2, words_with_diff_color=WORDS_OF_INTEREST):
  '''
  function to get a vectors for each word in each sentence, add the sentence number to the end of each word
  calculate cosine distance between each pair of words and then pass it to the visualization function

  inputs:
  input_hidden_states: hiddent states retrieved from a BERT-like model
  input_tokenized_sents: tokenized sentences, used to assign labels for each point on the plot
  model:  'average' : avg last n layers
          'concat': concatenate last n layers
          'sum' : sum last n layers
          'last':  embeddings only from last layer
          'second_last':  embeddings only from second last layer
  top_n_layers: top n layers to use for concat/sum etc.
  viz_dims: 2/3 for 2D or 3D plot
  words_with_diff_color: words that should be highlighed with different color on the plot
  '''
  vecs = list()
  labels = list()
  for token_ind in range(MAX_LEN):
    if token_ind == 0:
      #ignore CLS
      continue
    vectors = get_vector(input_hidden_states, token_index=token_ind, mode=mode, top_n_layers=top_n_layers)
    for sent_ind, sent_len in enumerate(sent_lengths):
      if token_ind < sent_len-1:
        #ignore SEP which will be at the last index of each sentence
        vecs.append(vectors[sent_ind])
        labels.append(input_tokenized_sents[sent_ind][token_ind]+"_"+str(sent_ind))
    
  #create a numpy matrix to pass to cosine distance
  mat = torch.stack(vecs).detach().numpy()
  #call the plot function on the cosine distance matrix
  plt_dists(cosine_distances(mat), labels=labels, dims=viz_dims, words_of_interest=words_with_diff_color, title='Method: {}'.format(mode))

In [22]:
#check if sum and average are the same
sm = get_vector(hidden_states, token_index=0, mode='sum', top_n_layers=4)
av = get_vector(hidden_states, token_index=0, mode='average', top_n_layers=4)

torch.eq(sm, av)

tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]])

### Let's look at the word vectors 

In [23]:
MODE = 'concat'
eval_vecs(hidden_states, tokenized_sents, mode='concat')

In [24]:
#we can look at this using a 3D plot too
eval_vecs(hidden_states, tokenized_sents, mode='concat', viz_dims=3)

In [25]:
MODE = 'sum'
eval_vecs(hidden_states, tokenized_sents, mode=MODE)

In [26]:
# MODE = 'sum'
# eval_vecs(hidden_states, tokenized_sents, mode=MODE, viz_dims=3)

In [27]:
MODE = 'average'
eval_vecs(hidden_states, tokenized_sents, mode=MODE)

In [28]:
MODE = 'last'
eval_vecs(hidden_states, tokenized_sents, mode=MODE)

In [29]:
MODE = 'second_last'
eval_vecs(hidden_states, tokenized_sents, mode=MODE)