
## What this notebook covers (you do not need a GPU to run it)
1. Extracting sentence embeddings from pretrained BERT-like models
2. Visualizing these sentence embeddings by comparing them against each other with a distance metric:
    * Calculate the distance between sentence vectors
    * Visualize the distance matrices in 2D/3D using multi-dimensional scaling (MDS)
3. Extracting word vectors from BERT-like models and using Word Mover's Distance (WMD: http://proceedings.mlr.press/v37/kusnerb15.pdf) to calculate the distance between sentences
    * Calculate the distance between sentences by applying WMD
    * Visualize the distance matrices in 2D/3D using multi-dimensional scaling
4. Loading a fine-tuned, context-specific model to see how sentence similarity changes with the corpus used to train it

**Note: The similarity of two sentences is very subjective. Two sentences can be very similar in one context and mean something very different in another. Let us pick sentiment as the yardstick for evaluating these vectors. In other words, let's see how close these sentences land relative to their sentiment.**

In [None]:
# Install libraries
# ! pip install transformers
# ! pip install plotly==4.9.0
# ! pip install wmd

In [None]:
import re
import sys
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# visualization libs
import plotly.express as px
import plotly.graph_objects as go

# imports
import torch
from scipy.spatial.distance import euclidean, pdist, squareform
from sklearn import manifold  # use this for MDS computation
from sklearn.metrics.pairwise import cosine_distances, euclidean_distances
from transformers import RobertaModel, RobertaTokenizer

%matplotlib inline

# Used to calculate Word Mover's Distance between sentences
from collections import Counter

# Library to calculate (Relaxed) Word Mover's Distance
from wmd import WMD, libwmdrelax

In [None]:
# Define some constants
PRETRAINED_MODEL = "roberta-base"  #'bert-large-uncased'
MAX_LEN = 512

In [None]:
# Initialize tokenizer
tokenizer = RobertaTokenizer.from_pretrained(PRETRAINED_MODEL)
# for BERT: BertTokenizer.from_pretrained(PRETRAINED_MODEL)

In [None]:
# Create a function to tokenize a set of texts
def preprocessing_for_bert(data, tokenizer_obj, max_len=MAX_LEN):
    """Perform the preprocessing steps required by pretrained BERT-like models.
    @param    data (list of str): Texts to be processed.
    @param    tokenizer_obj: A Hugging Face tokenizer, e.g. RobertaTokenizer.
    @param    max_len (int): Length to truncate/pad every sequence to.
    @return   input_ids (torch.Tensor): Tensor of token ids to be fed to a model.
    @return   attention_masks (torch.Tensor): Tensor of indices specifying which
                  tokens should be attended to by the model.
    @return   attention_masks_without_special_tok (torch.Tensor): Same as attention_masks,
                  but with the special tokens (CLS/SEP, i.e. <s>/</s> for RoBERTa) also set to 0
    """
    # Create empty lists to store outputs
    input_ids = []
    attention_masks = []

    # For every sentence...
    for sent in data:
        # `encode_plus` will:
        #    (1) Tokenize the sentence
        #    (2) Add the special start/end tokens ([CLS]/[SEP] for BERT, <s>/</s> for RoBERTa)
        #    (3) Truncate/pad the sentence to max length
        #    (4) Map tokens to their IDs
        #    (5) Create the attention mask
        #    (6) Return a dictionary of outputs
        encoded_sent = tokenizer_obj.encode_plus(
            text=sent,  # Preprocess sentence
            add_special_tokens=True,  # Add the special start/end tokens
            max_length=max_len,  # Max length to truncate/pad
            padding="max_length",  # Pad sentence to max length
            truncation=True,  # Truncate longer sequences to max_len
            return_attention_mask=True,  # Return attention mask
        )

        # Add the outputs to the lists
        input_ids.append(encoded_sent.get("input_ids"))
        attention_masks.append(encoded_sent.get("attention_mask"))

    # Convert lists to tensors
    input_ids = torch.tensor(input_ids)
    attention_masks = torch.tensor(attention_masks)

    # Create another mask that is useful when we average word vectors later:
    # we want to average across all word vectors in a sentence, excluding the CLS and SEP tokens.
    # Start from a copy of the attention mask
    attention_masks_without_special_tok = attention_masks.clone().detach()

    # set the CLS token index to 0 for all sentences
    attention_masks_without_special_tok[:, 0] = 0

    # After zeroing the CLS position, each row sums to (sentence length - 1),
    # which is exactly the index of the last real token, i.e. the SEP token
    sent_len = attention_masks_without_special_tok.sum(1).tolist()

    # column indices to set to zero
    col_idx = torch.LongTensor(sent_len)
    # row indices for all rows
    row_idx = torch.arange(attention_masks.size(0)).long()

    # set the SEP indices for each sentence token to zero
    attention_masks_without_special_tok[row_idx, col_idx] = 0

    return input_ids, attention_masks, attention_masks_without_special_tok
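The CLS/SEP masking above relies on one observation: after zeroing the first column, the row sum equals the index of the last real token. Below is a minimal NumPy sketch of the same indexing on a toy mask. The function name `mask_without_special_tokens` is hypothetical, and the sketch assumes right padding and at least two real tokens per row, as with `encode_plus` output:

```python
import numpy as np


def mask_without_special_tokens(attention_mask):
    """Zero out the first (CLS / <s>) and last real (SEP / </s>) position of each row.

    attention_mask: int array [batch_size x max_len], 1 for real tokens, 0 for padding.
    Assumes right padding and at least two real tokens per row (hypothetical helper).
    """
    mask = attention_mask.copy()
    # index of the last real token = (number of real tokens) - 1
    last_idx = attention_mask.sum(axis=1) - 1
    mask[:, 0] = 0
    mask[np.arange(mask.shape[0]), last_idx] = 0
    return mask


am = np.array([[1, 1, 1, 1, 0, 0],
               [1, 1, 1, 1, 1, 1]])
print(mask_without_special_tokens(am).tolist())
# [[0, 1, 1, 0, 0, 0], [0, 1, 1, 1, 1, 0]]
```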

In [None]:
# initialize model
# output_hidden_states=True gives us the hidden states of all layers
pretrained_model = RobertaModel.from_pretrained(
    PRETRAINED_MODEL, output_hidden_states=True
)
# Put the model in eval mode since we do not plan to backprop (this also disables dropout)
pretrained_model.eval();

# The data
We pick four sentences:
  * Two from the IMDB 50k movie reviews dataset (making sure we don't pick from the training set we later use for fine-tuning)
  * Two from a dataset in a completely different domain, the Amazon Fine Food Reviews dataset

The idea is to compare these four sentences against each other, first with the base pretrained model and then with a fine-tuned model. This lets us check whether fine-tuning biases the model towards sentiment, at the cost of some of its understanding of other aspects of those words/sentences.

You can download these datasets from here:

https://www.kaggle.com/snap/amazon-fine-food-reviews

https://www.kaggle.com/atulanandjha/imdb-50k-movie-reviews-test-your-bert


In [None]:
### Let's pick the sentences whose distance/similarity we will compute and visualize
# List of tuples:
# (sentence, label_id)
# label_id == 0 == negative
# label_id == 1 == positive

sents_and_labs = [
    (
        "This taffy is so good.  It is very soft and chewy.  The flavors are amazing.  I would definitely recommend you buying it.  Very satisfying!!",
        1,
    ),
    # ('This is a good film. This is very funny. Yet after this film there were no good Ernest films!', 1),
    (
        "Just love the interplay between two great characters of stage & screen - Veidt & Barrymore",
        1,
    ),
    (
        "Hated it with all my being. Worst movie ever. Mentally- scarred. Help me. It was that bad.TRUST ME!!!",
        0,
    ),
    (
        "This oatmeal is not good. Its mushy, soft, I don't like it. Quaker Oats is the way to go.",
        0,
    ),
]

sents = [s for s, l in sents_and_labs]
sents

In [None]:
def get_preds(sentences, tokenizer_obj, model_obj):
    """
    Quick function to extract hidden states and masks from the sentences and model passed
    """
    # Run the sentences through the tokenizer
    input_ids, att_msks, attention_masks_wo_special_tok = preprocessing_for_bert(
        sentences, tokenizer_obj
    )
    # Run the sentences through the model; no_grad skips building the autograd graph
    with torch.no_grad():
        outputs = model_obj(input_ids, att_msks)

    # Length of each sentence (including special tokens)
    sent_lens = att_msks.sum(1).tolist()

    # Get the tokenized version of each sentence (text form, to label things in the plot)
    tokenized_sents = [tokenizer_obj.convert_ids_to_tokens(i) for i in input_ids]
    return {
        "hidden_states": outputs[2],
        "pooled_output": outputs[1],
        "attention_masks": att_msks,
        "attention_masks_without_special_tok": attention_masks_wo_special_tok,
        "tokenized_sents": tokenized_sents,
        "sentences": sentences,
        "sent_lengths": sent_lens,
    }

In [None]:
pretrained_preds = get_preds(sents, tokenizer, pretrained_model)

## Let's get sentence embeddings and visualize them using cosine distance

###### I left the code below in, since it helped me understand how to apply a 2D mask to a 3D tensor

In [None]:
# https://stackoverflow.com/questions/61956893/how-to-mask-a-3d-tensor-with-2d-mask-and-keep-the-dimensions-of-original-vector
# example to apply a 2d mask to the 3d tensor
# X = torch.arange(24).view(4, 3, 2)
# print(X)

# mask = torch.zeros((4, 3), dtype=torch.int64)  # or dtype=torch.ByteTensor
# mask[0, 0] = 1
# mask[1, 1] = 1
# mask[3, 0] = 1
# print('Mask: ', mask)

# # Add a dimension to the mask tensor and expand it to the size of original tensor
# mask_ = mask.unsqueeze(-1).expand(X.size())
# print(mask_)

# # Select based on the new expanded mask
# Y = X * mask_
# print(Y)
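The `expand` step in the commented example is optional: broadcasting already stretches a trailing size-1 axis, so `mask.unsqueeze(-1)` alone multiplies correctly in PyTorch. Here is the same idea in NumPy, using the toy tensor and mask from the example above (a sketch for intuition, not part of the pipeline):

```python
import numpy as np

X = np.arange(24).reshape(4, 3, 2)       # [batch x seq_len x dim]
mask = np.zeros((4, 3), dtype=np.int64)  # [batch x seq_len]
mask[0, 0] = 1
mask[1, 1] = 1
mask[3, 0] = 1

# mask[..., None] has shape (4, 3, 1); broadcasting stretches its last axis to 2
Y = X * mask[..., None]
print(Y.shape)         # (4, 3, 2)
print(Y[0].tolist())   # [[0, 1], [0, 0], [0, 0]]
```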

In [None]:
def plt_dists(
    dists,
    sentences_and_labels,
    dims=2,
    title="",
    xrange=[-0.5, 0.5],
    yrange=[-0.5, 0.5],
    zrange=[-0.5, 0.5],
):
    """
    Plot distances using MDS in 2D/3D
    dists: precomputed distance matrix
    sentences_and_labels: tuples of sentence and label_ids
    dims: number of MDS components; 3 gives a 3D plot, any other value is plotted in 2D
    title: title for the plot
    xrange/yrange/zrange: axis ranges for the plot
    """
    # get the sentence text and labels to pass to the plot
    sents, color = zip(*sentences_and_labels)

    # https://community.plotly.com/t/plotly-colours-list/11730/6
    colorscale = [
        [0, "deeppink"],
        [1, "yellow"],
    ]  # , [2, 'greens'], [3, 'reds'], [4, 'blues']]

    # dists is precomputed (cosine distance, WMD or another metric) and passed in
    # calculate MDS with number of dims passed
    mds = manifold.MDS(
        n_components=dims, dissimilarity="precomputed", random_state=60, max_iter=90000
    )
    results = mds.fit(dists)

    # get coordinates for each point
    coords = results.embedding_

    # plot 3d/2d
    if dims == 3:
        fig = go.Figure(
            data=[
                go.Scatter3d(
                    x=coords[:, 0],
                    y=coords[:, 1],
                    z=coords[:, 2],
                    mode="markers+text",
                    textposition="top center",
                    text=sents,
                    marker=dict(
                        size=12, color=color, colorscale=colorscale, opacity=0.8
                    ),
                )
            ]
        )
    else:
        fig = go.Figure(
            data=[
                go.Scatter(
                    x=coords[:, 0],
                    y=coords[:, 1],
                    text=sents,
                    textposition="top center",
                    mode="markers+text",
                    marker=dict(
                        size=12, color=color, colorscale=colorscale, opacity=0.8
                    ),
                )
            ]
        )

    fig.update_layout(template="plotly_dark")
    if title != "":
        fig.update_layout(title_text=title)
        fig.update_layout(
            titlefont=dict(
                family="Courier New, monospace", size=14, color="cornflowerblue"
            )
        )

    # update the axes ranges; 3D plots keep their axes under layout.scene
    if dims == 3:
        fig.update_layout(
            scene=dict(
                xaxis=dict(range=xrange),
                yaxis=dict(range=yrange),
                zaxis=dict(range=zrange),
            )
        )
    else:
        fig.update_layout(xaxis=dict(range=xrange), yaxis=dict(range=yrange))
    fig.update_traces(textfont_size=10)
    fig.show()
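sklearn's `manifold.MDS` fits the layout iteratively (SMACOF, i.e. metric MDS), so the plot only approximately preserves the distances, and the absolute axis values carry no meaning: any rotation, reflection or shift of the points is an equally good solution. For intuition, here is a sketch of classical (Torgerson) MDS, a closed-form variant that is *not* what `plt_dists` uses; `classical_mds` is a hypothetical helper:

```python
import numpy as np


def classical_mds(dists, n_components=2):
    """Classical (Torgerson) MDS: place points so Euclidean distances approximate `dists`."""
    d2 = np.asarray(dists, dtype=float) ** 2
    n = d2.shape[0]
    # double-centre the squared distances: B = -1/2 * J D^2 J
    J = np.eye(n) - np.ones((n, n)) / n
    B = -0.5 * J @ d2 @ J
    # the top eigenpairs of the symmetric matrix B give the coordinates
    eigvals, eigvecs = np.linalg.eigh(B)
    order = np.argsort(eigvals)[::-1][:n_components]
    scale = np.sqrt(np.clip(eigvals[order], 0, None))
    return eigvecs[:, order] * scale


pts = np.array([[0.0, 0.0], [3.0, 0.0], [0.0, 4.0], [3.0, 4.0]])
D = np.linalg.norm(pts[:, None, :] - pts[None, :, :], axis=-1)
coords = classical_mds(D)
D_rec = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)
print(np.allclose(D, D_rec))  # True
```

With perfectly Euclidean 2D input the distances are recovered exactly; cosine or WMD matrices generally cannot be embedded in 2D without distortion, so small differences in the plots should be read with care.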

In [None]:
def get_word_vectors(
    hidden_layers_form_arch, token_index=None, mode="average", top_n_layers=4
):
    """
    retrieve vectors for all tokens from the top n layers and return a concatenated, averaged or summed vector
    hidden_layers_form_arch: tuple returned by the transformer library
    token_index: None/Index:
      If None: Returns all the tokens
      If Index: Returns vectors for that index in each sentence

    mode=
          'average' : avg last n layers
          'concat': concatenate last n layers
          'sum' : sum last n layers
          'last': return embeddings only from last layer
          'second_last': return embeddings only from second last layer

    top_n_layers: number of top layers to concatenate/ average / sum
    """

    vecs = None
    if mode == "concat":
        vecs = torch.cat(hidden_layers_form_arch[-top_n_layers:], dim=2)

    if mode == "average":
        vecs = torch.stack(hidden_layers_form_arch[-top_n_layers:]).mean(0)

    if mode == "sum":
        vecs = torch.stack(hidden_layers_form_arch[-top_n_layers:]).sum(0)

    if mode == "last":
        vecs = hidden_layers_form_arch[-1:][0]

    if mode == "second_last":
        vecs = hidden_layers_form_arch[-2:-1][0]

    if vecs is not None and token_index is not None:
        # if a token index is passed, return the vectors at that position for every sentence
        # ([batch_size x dim]) instead of the vectors for all positions
        return vecs.permute(1, 0, 2)[token_index]
    return vecs
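To make the `mode` options concrete, the NumPy sketch below mimics the shapes involved with a small random stand-in for `hidden_states` (for `roberta-base` the real tuple holds the embedding output plus 12 layer outputs of width 768, so `concat` over 4 layers gives 3072-dim vectors). Note that `hidden_layers_form_arch[-1:][0]` is simply the last element, `hidden_layers_form_arch[-1]`:

```python
import numpy as np

# Toy stand-in for hidden_states: embedding output + 12 layers, each [batch x seq_len x dim]
batch, seq_len, dim, n_layers = 2, 5, 8, 13
rng = np.random.default_rng(0)
hidden_states = tuple(rng.normal(size=(batch, seq_len, dim)) for _ in range(n_layers))

top_n = 4
concat = np.concatenate(hidden_states[-top_n:], axis=2)  # [2 x 5 x 32]
average = np.stack(hidden_states[-top_n:]).mean(axis=0)  # [2 x 5 x 8]
summed = np.stack(hidden_states[-top_n:]).sum(axis=0)    # [2 x 5 x 8]
second_last = hidden_states[-2]                           # [2 x 5 x 8]

# one vector per sentence at token position 0, as token_index=0 would select
first_token = concat.transpose(1, 0, 2)[0]                # [2 x 32]
print(concat.shape, average.shape, first_token.shape)  # (2, 5, 32) (2, 5, 8) (2, 32)
```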

In [None]:
def get_sent_vectors(input_states, att_mask):
    """
    Get a sentence vector by averaging over all word vectors -> these can come from any layer, or be combined across layers themselves (see the get_word_vectors function)
    input_states: [batch_size x seq_len x vector_dims] -> e.g. hidden states from a particular layer
    att_mask: the attention mask should already have the special tokens masked out, i.e. CLS/SEP/<s>/</s> set to 0 -> [batch_size x max_seq_length]
    ref: https://stackoverflow.com/questions/61956893/how-to-mask-a-3d-tensor-with-2d-mask-and-keep-the-dimensions-of-original-vector
    """

    # print(input_states.shape) #-> [batch_size x seq_len x vector_dim]

    # Sentence lengths: att_mask has a 1 for each valid token and 0 otherwise
    sent_lengths = att_mask.sum(1)

    # create a new 3rd dim and broadcast the attention mask across it -> this lets us apply the mask to the 3D tensor input_states
    att_mask_ = att_mask.unsqueeze(-1).expand(input_states.size())

    # use the mask to zero out the vectors of special tokens (CLS/SEP, <s>/</s>) and padding
    masked_states = input_states * att_mask_

    # calculate average
    sums = masked_states.sum(1)
    avg = sums / sent_lengths[:, None]
    return avg
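The pooling in `get_sent_vectors` is a masked mean: zero out the masked positions, sum over the sequence axis, and divide by the number of kept tokens. Here is a NumPy sketch with hand-checkable numbers (`masked_mean` is a hypothetical name). Unlike the function above, it clips the count at 1, so a row with no kept tokens returns zeros instead of dividing by zero:

```python
import numpy as np


def masked_mean(states, mask):
    """Average token vectors over positions where mask == 1.

    states: [batch x seq_len x dim]; mask: [batch x seq_len] of 0/1.
    """
    mask = mask.astype(states.dtype)
    summed = (states * mask[..., None]).sum(axis=1)  # [batch x dim]
    counts = mask.sum(axis=1, keepdims=True)         # [batch x 1]
    return summed / np.clip(counts, 1, None)         # guard against empty rows


states = np.array([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]],
                   [[5.0, 5.0], [7.0, 9.0], [100.0, 100.0]]])
mask = np.array([[1, 1, 0],
                 [0, 1, 0]])
print(masked_mean(states, mask).tolist())  # [[2.0, 3.0], [7.0, 9.0]]
```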

In [None]:
def eval_vectors(
    model_output,
    sentences_and_labels,
    wrd_vec_mode="concat",
    wrd_vec_top_n_layers=4,
    viz_dims=2,
    sentence_emb_mode="average_word_vectors",
    title_prefix=None,
    plt_xrange=[-0.05, 0.05],
    plt_yrange=[-0.05, 0.05],
    plt_zrange=[-0.05, 0.05],
):
    """
    Get vectors for all sentences and visualize them based on the cosine distance between them

    model_output: model output extracted as a dictionary from get_preds function
    sentences_and_labels: tuple of sentence and labels_ids
    wrd_vec_mode=
          'average' : avg last n layers
          'concat': concatenate last n layers
          'sum' : sum last n layers
          'last': return embeddings only from last layer
          'second_last': return embeddings only from second last layer
    wrd_vec_top_n_layers: number of top layers to combine with wrd_vec_mode
    viz_dims:2/3 for 2D/3D plot
    sentence_emb_mode: 'average_word_vectors' averages the word vectors (excluding special tokens);
          any other value uses the model's pooled output
    title_prefix: String to add before the descriptive title. Can be used to add model name etc.
    plt_xrange/plt_yrange/plt_zrange: axis ranges for the plot
    """
    title_wrd_emv = "{} across {} layers".format(wrd_vec_mode, wrd_vec_top_n_layers)

    # get word vectors for all words in the sentence
    if sentence_emb_mode == "average_word_vectors":
        title_sent_emb = (
            "average(word vectors in the sentence); Sentence Distance: Cosine"
        )
        word_vecs_across_sent = get_word_vectors(
            model_output["hidden_states"],
            mode=wrd_vec_mode,
            token_index=None,
            top_n_layers=wrd_vec_top_n_layers,
        )  # returns [batch_size x seq_len x vector_dim]
        sent_vecs = get_sent_vectors(
            word_vecs_across_sent, model_output["attention_masks_without_special_tok"]
        )
    else:
        title_sent_emb = "First tok (CLS) vector; Sentence Distance: Cosine"
        # Use the pooled output: the final hidden state of the first token (CLS for BERT,
        # <s> for RoBERTa) passed through the model's pooler (a dense layer with tanh)

        # Note from https://huggingface.co/transformers/model_doc/bert.html#bertmodel
        # This output is usually not a good summary of the semantic content of the
        # input, you’re often better with averaging or
        # pooling the sequence of hidden-states for the whole input sequence.
        sent_vecs = model_output["pooled_output"]  # [batch_size x hidden_size]

    if title_prefix:
        final_title = "{} Word Vec: {}; Sentence Vector: {}".format(
            title_prefix, title_wrd_emv, title_sent_emb
        )
    else:
        final_title = "Word Vec: {}; Sentence Vector: {}".format(
            title_wrd_emv, title_sent_emb
        )
    mat = sent_vecs.detach().numpy()
    plt_dists(
        cosine_distances(mat),
        sentences_and_labels=sentences_and_labels,
        dims=viz_dims,
        title=final_title,
        xrange=plt_xrange,
        yrange=plt_yrange,
        zrange=plt_zrange,
    )
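`cosine_distances` returns 1 minus the cosine similarity of every pair of rows. Here is a NumPy sketch of that computation on three toy vectors (`cosine_distance_matrix` is a hypothetical helper and assumes no all-zero rows); orthogonal vectors sit at distance 1 and vectors 45° apart at about 0.293:

```python
import numpy as np


def cosine_distance_matrix(mat):
    """1 - cosine similarity for every pair of rows in `mat` ([n x dim], no zero rows)."""
    unit = mat / np.linalg.norm(mat, axis=1, keepdims=True)
    sims = np.clip(unit @ unit.T, -1.0, 1.0)  # clip rounding noise outside [-1, 1]
    return 1.0 - sims


vecs = np.array([[1.0, 0.0],
                 [0.0, 2.0],
                 [3.0, 3.0]])
print(np.round(cosine_distance_matrix(vecs), 3).tolist())
# [[0.0, 1.0, 0.293], [1.0, 0.0, 0.293], [0.293, 0.293, 0.0]]
```

The narrow plot ranges (±0.03) chosen for the pretrained model below suggest its sentence vectors sit very close together in cosine terms.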

In [None]:
eval_vectors(
    pretrained_preds,
    sents_and_labs,
    wrd_vec_mode="concat",
    sentence_emb_mode="average_word_vectors",
    plt_xrange=[-0.03, 0.03],
    plt_yrange=[-0.03, 0.03],
    title_prefix="Pretrained model:",
)

In [None]:
eval_vectors(
    pretrained_preds,
    sents_and_labs,
    wrd_vec_mode="concat",
    sentence_emb_mode="pooled_output",
    plt_xrange=[-0.03, 0.03],
    plt_yrange=[-0.03, 0.03],
    title_prefix="Pretrained model:",
)

In [None]:
eval_vectors(
    pretrained_preds,
    sents_and_labs,
    wrd_vec_mode="average",
    sentence_emb_mode="average_word_vectors",
    title_prefix="Pretrained model:",
    plt_xrange=[-0.03, 0.03],
    plt_yrange=[-0.03, 0.03],
)

In [None]:
eval_vectors(
    pretrained_preds,
    sents_and_labs,
    wrd_vec_mode="second_last",
    sentence_emb_mode="average_word_vectors",
    plt_xrange=[-0.03, 0.04],
    plt_yrange=[-0.03, 0.03],
    title_prefix="Pretrained model:",
)

## Let's do the same, but this time with Word Mover's Distance
1. Link to the paper: http://www.cs.cornell.edu/~kilian/papers/wmd_metric.pdf
2. The implementation is modified from https://github.com/src-d/wmd-relax
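For reference, here is the WMD objective from the Kusner et al. paper linked above. Given normalized bag-of-words weights $d$ and $d'$ over a shared vocabulary of $n$ tokens with embeddings $x_i$, it finds the cheapest flow $T$ that moves all of $d$'s mass onto $d'$:

$$
\mathrm{WMD}(d, d') = \min_{T \ge 0} \sum_{i=1}^{n} \sum_{j=1}^{n} T_{ij}\, c(i, j)
\quad \text{subject to} \quad
\sum_{j=1}^{n} T_{ij} = d_i, \qquad \sum_{i=1}^{n} T_{ij} = d'_j,
\qquad c(i, j) = \lVert x_i - x_j \rVert_2
$$

In the code below, `_generate_weights` produces $d$ and $d'$, the $x_i$ are the vectors returned by `_get_normalized_item`, `_calc_euclidean_distances` produces the cost matrix $c$, and `libwmdrelax.emd(u, v, self.dists)` is the call that receives all three.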

In [None]:
def get_vector_for_each_token_position(
    hidden_layers_form_arch, token_index=0, mode="average", top_n_layers=4
):
    """
    retrieve vectors for a token_index from the top n layers and return a concatenated, averaged or summed vector
    hidden_layers_form_arch: tuple returned by the transformer library
    token_index: index of the token for which a vector is desired
    mode=
          'average' : avg last n layers
          'concat': concatenate last n layers
          'sum' : sum last n layers
          'last': return embeddings only from last layer
          'second_last': return embeddings only from second last layer

    top_n_layers: number of top layers to concatenate/ average / sum
    """
    if mode == "concat":
        # concatenate the last top_n_layers outputs -> [batch_size x seq_len x (top_n_layers * dim)]
        # permute(1, 0, 2) swaps the batch and seq_len dims, so indexing with token_index
        # returns the vectors at that position for every sentence -> [batch_size x (top_n_layers * dim)]
        return torch.cat(hidden_layers_form_arch[-top_n_layers:], dim=2).permute(
            1, 0, 2
        )[token_index]

    if mode == "average":
        # average the last top_n_layers outputs, then take token_index -> [batch_size x dim]
        return (
            torch.stack(hidden_layers_form_arch[-top_n_layers:])
            .mean(0)
            .permute(1, 0, 2)[token_index]
        )

    if mode == "sum":
        # sum the last top_n_layers outputs, then take token_index -> [batch_size x dim]
        return (
            torch.stack(hidden_layers_form_arch[-top_n_layers:])
            .sum(0)
            .permute(1, 0, 2)[token_index]
        )

    if mode == "last":
        # last layer output at token_index -> [batch_size x dim]
        return hidden_layers_form_arch[-1:][0].permute(1, 0, 2)[token_index]

    if mode == "second_last":
        # second-to-last layer output at token_index -> [batch_size x dim]
        return hidden_layers_form_arch[-2:-1][0].permute(1, 0, 2)[token_index]

    return None

In [None]:
def build_word_embedding_lookup(
    model_output, wrd_vec_mode="concat", top_n_layers=4, max_len=MAX_LEN
):
    """
    Build an embedding lookup: we need it to pull up the vector of any token while calculating WMD
    model_output: model output extracted as a dictionary from get_preds function; should include 'hidden_states', 'tokenized_sents', 'sent_lengths'
    wrd_vec_mode: concat/average/sum/last/second_last - way to extract word embeddings from the architecture
    top_n_layers: number of layers to work on to get word vectors using the wrd_vec_mode
    max_len: max length of the sentence for the architecture
    returns:
      vecs: a dict with keys as tokens and sentence number (e.g. date in sent 0 becomes date_0), and values as vectors extracted from bert like models
      documents: dictionary with sentence number as key and tokens like date_0 joined with a space as a string
    """
    vecs = dict()
    documents = dict()

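    # positions past the longest sentence are padding, so there is no need to visit them
    for token_ind in range(min(max_len, max(model_output["sent_lengths"]))):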
        if token_ind == 0:
            # ignore CLS
            continue

        vectors = get_vector_for_each_token_position(
            model_output["hidden_states"],
            token_index=token_ind,
            mode=wrd_vec_mode,
            top_n_layers=top_n_layers,
        )
        for sent_ind, sent_len in enumerate(model_output["sent_lengths"]):
            if token_ind < sent_len - 1:  # ignore SEP which will be at sent_len-1 index
                txt = (
                    model_output["tokenized_sents"][sent_ind][token_ind]
                    + "_"
                    + str(sent_ind)
                )

                # store the token and its vector -> this will be our lookup storage for vectors
                vecs[txt] = vectors[sent_ind].detach().numpy()

                # store this so that we can do comparisons
                if sent_ind not in documents:
                    documents[sent_ind] = txt
                else:
                    documents[sent_ind] += " " + txt
    return vecs, documents
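`SimilarityWMD` (next cell) turns each `documents` string into normalized bag-of-words (nBOW) weights: `_convert_document` maps each token to `log(1 + count)`, and `_generate_weights` lays these values out over the shared vocabulary and divides by their sum. Below is a standard-library sketch of those two steps with toy keys (`nbow_weights` is a hypothetical name). Because every key carries its sentence index, no token is shared across sentences. Also, since `vecs[txt]` is overwritten, a token repeated within a sentence keeps only the vector of its last occurrence, while its count still raises its weight.

```python
import math
from collections import Counter


def nbow_weights(document, vocabulary):
    """Normalized bag-of-words weights with log(1 + count) term weighting.

    document: space-separated token keys, e.g. "good_0 good_0 film_0"
    vocabulary: dict token key -> column index shared by all documents
    """
    counts = Counter(document.split())
    w = [0.0] * len(vocabulary)
    for tok, c in counts.items():
        w[vocabulary[tok]] = math.log(1 + c)
    total = sum(w)
    return [x / total for x in w]


docs = ["good_0 good_0 film_0", "bad_1 film_1"]
vocab = {t: i for i, t in enumerate(sorted({t for d in docs for t in d.split()}))}
print(vocab)  # {'bad_1': 0, 'film_0': 1, 'film_1': 2, 'good_0': 3}
print([round(x, 3) for x in nbow_weights(docs[0], vocab)])  # [0.0, 0.387, 0.0, 0.613]
```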

In [None]:
# Modified from https://github.com/src-d/wmd-relax/blob/master/wmd/__init__.py
# class to extract and calculate word movers distance using bert
class SimilarityWMD(object):
    def __init__(self, embedding_dict, sklearn_euclidean_distances=True, **kwargs):
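        """
        :param embedding_dict: dict mapping each token key (e.g. "Ġgood_0") to its vector
        :param sklearn_euclidean_distances: if True, use sklearn's euclidean_distances; \
                                            otherwise use the NumPy expansion in \
                                            _calc_euclidean_distances
        :param frequency_processor: (optional kwarg) function applied to raw \
                                    token frequencies; defaults to log(1 + f).

        :type frequency_processor: callable
        """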

        self.frequency_processor = kwargs.get(
            "frequency_processor", lambda t, f: np.log(1 + f)
        )
        self.embedding_dict = embedding_dict
        # get embed size
        self.emb_size = self.embedding_dict[next(iter(self.embedding_dict))].shape[0]
        self.sklearn_euclidean_distances = sklearn_euclidean_distances

    def _get_normalized_item(self, item):
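        """
        Look up the vector for a token key (e.g. "Ġgood_0") and normalize it
        """
        v = self.embedding_dict[item]
        # Normalizes by the sum of the components, not the L2 norm. For contextual
        # vectors with mixed-sign components a near-zero sum can inflate the result;
        # v / np.linalg.norm(v) is a more robust alternative.
        return v / v.sum()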

    def _dist_fn(self, u, v):
        return libwmdrelax.emd(u, v, self.dists)

    def _calc_euclidean_distances(self, evec):
        if self.sklearn_euclidean_distances:
            # call sklearn.metrics.pairwise.euclidean_distances
            return euclidean_distances(evec)

        evec_sqr = (evec * evec).sum(axis=1)
        dists = evec_sqr - 2 * evec.dot(evec.T) + evec_sqr[:, np.newaxis]
        dists[dists < 0] = 0
        dists = np.sqrt(dists)
        for i in range(len(dists)):
            dists[i, i] = 0
        return dists

    def compute_similarity(self, docs):
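        """
        Calculates the pairwise Word Mover's Distance between N documents, each a
        space-separated string of token keys. Extracts the nBOW of each document
        and evaluates the WMD for every pair.
        :return: NxN matrix of WMD distances (smaller means more similar).
        :rtype: numpy.ndarray
        """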

        # generate a token -> weight mapping for each doc, with
        # weight = frequency_processor(token, count), e.g. {'Ġgood_0': 0.6931471805599453, ...}
        docs_nbow = [self._convert_document(d) for d in docs]

        # build the shared vocabulary with an index for each token: {token_key: index, ...}
        vocabulary = set()
        for distribution in docs_nbow:
            vocabulary = vocabulary.union(set(distribution))

        vocabulary = {w: i for i, w in enumerate(sorted(vocabulary))}

        # generate the nBOW weights of each doc over the shared vocabulary, e.g.
        # [0.14285715 0.14285715 0.         0.         0.         0.
        #  0.14285715 0.14285715 0.         0.14285715 0.         0.14285715
        #  0.         0.14285715]
        weights = list()
        for d in docs_nbow:
            weights.append(self._generate_weights(d, vocabulary))

        evec = np.zeros((len(vocabulary), self.emb_size), dtype=np.float32)

        for w, i in vocabulary.items():
            evec[i] = self._get_normalized_item(w)

        # calculate euclidean_distances between all pairs of vectors
        self.dists = self._calc_euclidean_distances(evec)

        # calculate word movers distance for all our sentences
        wmd_dists = pdist(weights, self._dist_fn)

        # return an NxN matrix (N = number of sentences) with the distance between each pair
        # return pd.DataFrame(squareform(wmd_dists), index=docs, columns=docs)
        return squareform(wmd_dists)

    def _convert_document(self, doc):
        wrds = defaultdict(int)
        for t in doc.split():
            wrds[t] += 1
        return {t: self.frequency_processor(t, v) for t, v in wrds.items()}

    def _generate_weights(self, doc, vocabulary):
        w = np.zeros(len(vocabulary), dtype=np.float32)
        for t, v in doc.items():
            w[vocabulary[t]] = v
        w /= w.sum()
        return w
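The non-sklearn branch of `_calc_euclidean_distances` uses the expansion `||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2`, which replaces an n × n × dim difference tensor with a single matrix product. Rounding can push tiny squared distances below zero, hence the clamp and the explicit zero diagonal; for nearly identical vectors the subtraction can also cancel badly, so the shortcut trades some precision for speed. Here is a NumPy sketch (`pairwise_euclidean` is a hypothetical name) checked against the direct computation:

```python
import numpy as np


def pairwise_euclidean(evec):
    """Pairwise Euclidean distances via ||a||^2 - 2 a.b + ||b||^2."""
    sq = (evec * evec).sum(axis=1)
    d2 = sq[:, None] - 2.0 * evec @ evec.T + sq[None, :]
    d2 = np.maximum(d2, 0.0)   # rounding can make tiny values negative
    d = np.sqrt(d2)
    np.fill_diagonal(d, 0.0)   # exact zeros on the diagonal
    return d


rng = np.random.default_rng(42)
evec = rng.normal(size=(5, 16)).astype(np.float32)
direct = np.linalg.norm(evec[:, None, :] - evec[None, :, :], axis=-1)
print(np.allclose(pairwise_euclidean(evec), direct, atol=1e-4))  # True
```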

In [None]:
def eval_using_wmd(
    model_output,
    sentences_and_labels,
    wrd_vec_mode="concat",
    viz_dims=2,
    wrd_vec_top_n_layers=4,
    title_prefix=None,
    plt_xrange=[-0.03, 0.03],
    plt_yrange=[-0.03, 0.03],
    plt_zrange=[-0.05, 0.05],
):
    """
    model_output: model output extracted as a dictionary from get_preds function
    sentences_and_labels: tuple of sentence and labels_ids
    wrd_vec_top_n_layers: number of layers to use while extracting word embeddings
    wrd_vec_mode=
          'average' : avg last n layers
          'concat': concatenate last n layers
          'sum' : sum last n layers
          'last': return embeddings only from last layer
          'second_last': return embeddings only from second last layer
    viz_dims:2/3 for 2D/3D plot
    title_prefix: String to add before the descriptive title. Can be used to add model name etc.
    """
    # get all vectors for all words in each sentence
    vecs, documents = build_word_embedding_lookup(
        model_output, wrd_vec_mode=wrd_vec_mode, top_n_layers=wrd_vec_top_n_layers
    )

    # calculate the word movers distance
    dist_matrix = SimilarityWMD(vecs).compute_similarity(
        [documents[i] for i in range(len(documents))]
    )

    title_wrd_emv = "{} across {} layers".format(wrd_vec_mode, wrd_vec_top_n_layers)

    if title_prefix:
        final_title = "{} Word Vec: {}; Sentence Distance: Word Movers Distance".format(
            title_prefix, title_wrd_emv
        )
    else:
        final_title = "Word Vec: {}; Sentence Distance: Word Movers Distance".format(
            title_wrd_emv
        )

    # plot distances
    plt_dists(
        dist_matrix,
        sentences_and_labels=sentences_and_labels,
        dims=viz_dims,
        title=final_title,
        xrange=plt_xrange,
        yrange=plt_yrange,
        zrange=plt_zrange,
    )
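Kusner et al. also describe the relaxed WMD (RWMD), a cheaper lower bound. It drops one of the two flow constraints so that every word simply sends all its mass to the nearest word of the other document, and it takes the larger of the two one-sided relaxations. The sketch below is a hypothetical NumPy helper (`relaxed_wmd`), not a call into the `wmd` library, using three words embedded on a line at 0, 1 and 5. For this toy case the exact WMD is 2.5 (half the mass moves from 0 to 1 and half from 1 to 5), and RWMD gives the lower bound 2.0:

```python
import numpy as np


def relaxed_wmd(d1, d2, cost):
    """Relaxed WMD lower bound (Kusner et al., 2015), hypothetical helper.

    d1, d2: nBOW weight vectors over a shared vocabulary (each sums to 1)
    cost:   [vocab x vocab] matrix of word-to-word distances
    """
    in1, in2 = d1 > 0, d2 > 0
    # doc 1 -> doc 2: each word of doc 1 moves all its mass to its closest word of doc 2
    l1 = (d1[in1] * cost[np.ix_(in1, in2)].min(axis=1)).sum()
    # doc 2 -> doc 1: the symmetric relaxation
    l2 = (d2[in2] * cost[np.ix_(in2, in1)].min(axis=1)).sum()
    return float(max(l1, l2))


x = np.array([0.0, 1.0, 5.0])          # 1D "embeddings" of three words
cost = np.abs(x[:, None] - x[None, :])
d1 = np.array([0.5, 0.5, 0.0])
d2 = np.array([0.0, 0.5, 0.5])
print(relaxed_wmd(d1, d2, cost))  # 2.0
```

Because RWMD only needs a row-wise minimum over the cost matrix, the paper uses it to prune candidates before running the exact solver.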

In [None]:
eval_using_wmd(
    pretrained_preds,
    sents_and_labs,
    wrd_vec_mode="concat",
    plt_xrange=[-0.4, 0.4],
    plt_yrange=[-0.4, 0.4],
    title_prefix="Pretrained model:",
)

In [None]:
eval_using_wmd(
    pretrained_preds,
    sents_and_labs,
    wrd_vec_mode="average",
    plt_xrange=[-0.4, 0.4],
    plt_yrange=[-0.4, 0.4],
    title_prefix="Pretrained model:",
)

In [None]:
eval_using_wmd(
    pretrained_preds,
    sents_and_labs,
    wrd_vec_mode="second_last",
    plt_xrange=[-0.4, 0.4],
    plt_yrange=[-0.4, 0.4],
    title_prefix="Pretrained model:",
)

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
from torch import nn


class SentimentClassifier(nn.Module):
    def __init__(self, n_classes, bertmodel, dropout_p=0.3):
        super(SentimentClassifier, self).__init__()
        self.bert = bertmodel
        self.dropout = nn.Dropout(p=dropout_p)
        self.out = nn.Linear(self.bert.config.hidden_size, n_classes)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # outputs[1] is the pooled output; indexing works whether the model returns a tuple
        # or a ModelOutput, and also when output_hidden_states=True adds a third item
        pooled_output = outputs[1]
        output = self.dropout(pooled_output)
        return self.out(output)

In [None]:
# PRE_TRAINED_MODEL_NAME = 'roberta-base'
finetuned_model = RobertaModel.from_pretrained(
    PRETRAINED_MODEL, output_hidden_states=True
)
# Put the model in eval mode: we do not plan to backprop, and this also disables dropout
finetuned_model.eval()
senti_model = SentimentClassifier(len(["neg", "pos"]), bertmodel=finetuned_model)

In [None]:
import glob

MODEL_SAVE_NAME = "imdb_movie_large_roberta_state"

if torch.cuda.is_available():
    map_location = lambda storage, loc: storage.cuda()
else:
    map_location = "cpu"

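checkpoint_files = sorted(
    glob.glob(
        "/content/drive/My Drive/Datasets/IMDBMovieReviews/{}*".format(MODEL_SAVE_NAME)
    )
)
if not checkpoint_files:
    raise FileNotFoundError(
        "No checkpoint matching {}* found; mount Google Drive or update the path".format(
            MODEL_SAVE_NAME
        )
    )
# load the last checkpoint in sorted order
state_file_name = checkpoint_files[-1]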
print("Loading : {}".format(state_file_name))
state = torch.load(state_file_name, map_location=map_location)
senti_model.load_state_dict(state["model"])
state = None

In [None]:
finetuned_preds = get_preds(sents, tokenizer, senti_model.bert)

# Comparing models: pretrained (out-of-the-box) vs. fine-tuned

## 1. Run with config:
* word vectors averaged across the last 4 layers
* sentence vectors obtained by averaging all word embeddings in the sentence

### 1.1 Pretrained model

In [None]:
eval_vectors(
    pretrained_preds,
    sents_and_labs,
    wrd_vec_mode="average",
    sentence_emb_mode="average_word_vectors",
    plt_xrange=[-0.03, 0.03],
    plt_yrange=[-0.03, 0.03],
    viz_dims=2,
    title_prefix="Pretrained Model:",
)

### 1.2 Fine-tuned model

In [None]:
eval_vectors(
    finetuned_preds,
    sents_and_labs,
    wrd_vec_mode="average",
    sentence_emb_mode="average_word_vectors",
    plt_xrange=[-0.6, 0.6],
    plt_yrange=[-0.6, 0.6],
    viz_dims=2,
    title_prefix="Finetuned model:",
)

## 2. Run with config:
* word vectors averaged across the last 4 layers
* distances calculated directly with Word Mover's Distance

### 2.1 Pretrained model

In [None]:
eval_using_wmd(
    pretrained_preds,
    sents_and_labs,
    wrd_vec_mode="average",
    plt_xrange=[-0.6, 0.6],
    plt_yrange=[-0.6, 0.6],
    title_prefix="Pretrained model:",
)

### 2.2 Fine-tuned model

In [None]:
eval_using_wmd(
    finetuned_preds,
    sents_and_labs,
    wrd_vec_mode="average",
    plt_xrange=[-0.6, 0.6],
    plt_yrange=[-0.6, 0.6],
    title_prefix="Finetuned model:",
)
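The comparison above is visual, and MDS layouts can distort distances. One way to put a number on "do sentences cluster by sentiment?" is to compare the mean distance within each label against the mean distance across labels, computed from the matrix that `eval_vectors` (`cosine_distances(mat)`) or `eval_using_wmd` (`dist_matrix`) builds. Both functions currently only plot, so they would first need to return that matrix. Below is a hypothetical helper, shown on a toy matrix; with only four sentences, treat any such number as anecdotal:

```python
import numpy as np


def label_separation(dists, labels):
    """Mean same-label distance, mean cross-label distance, and their ratio.

    dists: [n x n] symmetric distance matrix; labels: length-n sequence of label ids.
    A ratio well below 1 means the sentences cluster by label.
    """
    dists = np.asarray(dists, dtype=float)
    labels = np.asarray(labels)
    same = labels[:, None] == labels[None, :]
    off_diag = ~np.eye(len(labels), dtype=bool)
    within = dists[same & off_diag].mean()
    across = dists[~same].mean()
    return float(within), float(across), float(within / across)


toy = np.array([[0.0, 0.2, 0.9, 0.8],
                [0.2, 0.0, 0.7, 0.9],
                [0.9, 0.7, 0.0, 0.1],
                [0.8, 0.9, 0.1, 0.0]])
print([round(v, 3) for v in label_separation(toy, [1, 1, 0, 0])])  # [0.15, 0.825, 0.182]
```

With real outputs, the labels would come from `[l for _, l in sents_and_labs]`.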