### Visualising word embeddings for different PLMs

This notebook creates 2D or 3D visualisations of the embeddings produced by transformer-based pretrained language models (PLMs) such as BERT and RoBERTa.


In [None]:
# imports
import re

# RobertaModel, RobertaTokenizer
import sys
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# visualization libs
import plotly.express as px
import plotly.graph_objects as go
import torch
from plotly.subplots import make_subplots
from scipy.spatial.distance import euclidean, pdist, squareform
from sklearn import manifold  # use this for MDS computation
from sklearn.metrics.pairwise import cosine_distances, euclidean_distances
from transformers import AutoModel, AutoTokenizer, BertModel, BertTokenizer

% matplotlib inline

import os

# Used to calculation of word movers distance between sentence
from collections import Counter

# Library to calculate Relaxed-Word Movers distance
from wmd import WMD, libwmdrelax

In [None]:
# Create a function to tokenize a set of texts
def preprocessing_for_bert(data, tokenizer_obj, max_length):
    """Tokenize a batch of texts for a BERT-like model and build attention masks.

    @param    data (iterable of str): Texts to be processed.
    @param    tokenizer_obj: Tokenizer exposing `encode_plus` (e.g. a Hugging Face tokenizer).
    @param    max_length (int): Every sequence is truncated/padded to exactly this length.
    @return   input_ids (torch.Tensor): [n_texts x max_length] token ids to be fed to a model.
    @return   attention_masks (torch.Tensor): [n_texts x max_length] mask with 1 for every
                  real token (special tokens included) and 0 for padding.
    @return   attention_masks_without_special_tok (torch.Tensor): Same as `attention_masks`
                  but with the first token (CLS / <s>) and the last real token (SEP / </s>)
                  set to 0, useful for averaging word vectors over the actual words only.
    """
    input_ids = []
    attention_masks = []

    for sent in data:
        # encode_plus tokenizes, adds the special start/end tokens, truncates/pads to
        # max_length, maps tokens to ids and builds the attention mask.
        encoded_sent = tokenizer_obj.encode_plus(
            text=sent,
            add_special_tokens=True,
            max_length=max_length,
            # padding="max_length" replaces the deprecated pad_to_max_length=True
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
        )
        input_ids.append(encoded_sent.get("input_ids"))
        attention_masks.append(encoded_sent.get("attention_mask"))

    if not input_ids:
        # torch.tensor([]) would be 1-D and break the 2-D indexing below
        empty = torch.zeros((0, max_length), dtype=torch.long)
        return empty, empty.clone(), empty.clone()

    input_ids = torch.tensor(input_ids)
    attention_masks = torch.tensor(attention_masks)

    attention_masks_without_special_tok = attention_masks.clone()
    # the start token (CLS / <s>) is always at position 0
    attention_masks_without_special_tok[:, 0] = 0
    # after zeroing position 0, each row sums to (unpadded length - 1), which is the index
    # of the last unpadded token, i.e. the end token (assumes right-side padding)
    sep_positions = attention_masks_without_special_tok.sum(dim=1).long()
    rows = torch.arange(attention_masks.size(0))
    attention_masks_without_special_tok[rows, sep_positions] = 0

    return input_ids, attention_masks, attention_masks_without_special_tok

In [None]:
# this will retrieve vector representations or embeddings at the word level
def get_vector(hidden_layers_form_arch, token_index=0, mode="average", top_n_layers=4):
    """
    Retrieve the vectors at position `token_index` for every sentence in the batch,
    combining the top n layers by concatenation, averaging or summing.

    hidden_layers_form_arch: sequence of per-layer hidden states as returned by the
        transformers library, each indexed as [batch_size x seq_len x dim]
    token_index: index of the token for which a vector is desired
    mode=
          'average' : avg last n layers            -> [batch_size x dim]
          'concat': concatenate last n layers      -> [batch_size x (n * dim)]
          'sum' : sum last n layers                -> [batch_size x dim]
          'last': embeddings only from last layer  -> [batch_size x dim]
          'second_last': embeddings only from second last layer -> [batch_size x dim]
    top_n_layers: number of top layers to concatenate / average / sum

    raises: ValueError for an unknown `mode` (previously None was returned, which only
        failed later with an obscure error when the result was indexed)
    """
    if mode == "last":
        return hidden_layers_form_arch[-1][:, token_index]

    if mode == "second_last":
        return hidden_layers_form_arch[-2][:, token_index]

    top_layers = hidden_layers_form_arch[-top_n_layers:]

    if mode == "concat":
        # join along the hidden dimension, then pick the token column for every sentence
        return torch.cat(top_layers, dim=2)[:, token_index]

    if mode == "average":
        # stack -> [n_layers x batch x seq x dim], reduce over the layer axis
        return torch.stack(top_layers).mean(0)[:, token_index]

    if mode == "sum":
        return torch.stack(top_layers).sum(0)[:, token_index]

    raise ValueError(
        f"Unknown mode {mode!r}; expected one of "
        "'average', 'concat', 'sum', 'last', 'second_last'"
    )

In [None]:
def plt_dists(
    dists, labels, dims=2, reducer=None, words_of_interest=None, title="", save_dir="./"
):
    """
    Plot a precomputed distance matrix in 2D/3D (via MDS) and display it.

    This is `plt_dists_subplots` followed by `fig.show()`; the figure-building logic
    lives there so the two functions cannot drift apart.

    dists: precomputed distance matrix
    labels: labels to display on the plot
    dims: 2/3 for 2 or 3 dimensional plot; any value other than 3 is drawn as a 2D scatter
    reducer: dimensionality reduction method; only "MDS" is implemented
    words_of_interest: words to highlight with a different color (None means none)
    title: title for the plot
    save_dir: HTML file path to save the plot to, a directory (a default file name is
        used inside it), or None to skip saving
    """
    save_path = save_dir
    if save_path is not None:
        save_path = os.fspath(save_path)
        # the "./" default is a directory; writing HTML straight to it would fail
        if os.path.isdir(save_path) or save_path.endswith(("/", "\\")):
            os.makedirs(save_path, exist_ok=True)
            save_path = os.path.join(save_path, f"plot_{dims}D.html")

    fig = plt_dists_subplots(
        dists,
        labels,
        dims=dims,
        reducer=reducer,
        words_of_interest=list(words_of_interest or []),
        title=title,
        save_dir=save_path,
    )
    fig.show()

In [None]:
def plt_dists_subplots(
    dists, labels, dims=2, reducer=None, words_of_interest=None, title="", save_dir="./"
):
    """
    Reduce a precomputed distance matrix with MDS and build a 2D/3D plotly scatter.

    dists: precomputed distance matrix (e.g. cosine distances)
    labels: labels to display on the plot; each word carries a `_<sentence number>` suffix
    dims: 2/3 for 2 or 3 dimensional plot; MDS is fit with n_components=dims and any
        value other than 3 is drawn as a 2D scatter of the first two components
    reducer: dimensionality reduction method; only "MDS" is implemented
    words_of_interest: words to highlight with a different color (None means none);
        a label is highlighted when it contains any of these words as a substring
    title: title for the plot
    save_dir: HTML file path to save the plot to, a directory (a default file name is
        used inside it), or None to skip saving

    returns: the plotly figure (not shown)
    raises: NotImplementedError for any reducer other than "MDS"
    """
    focus_words = list(words_of_interest or [])
    # 1 = contains a word of interest, 0 = everything else
    color = [1 if any(word in label for word in focus_words) else 0 for label in labels]

    # https://community.plotly.com/t/plotly-colours-list/11730/6
    colorscale = [[0, "darkcyan"], [1, "white"]]

    if reducer != "MDS":
        raise NotImplementedError(
            f"Unsupported reducer {reducer!r}; only 'MDS' is implemented"
        )

    # dists is precomputed, so MDS works on it directly
    mds = manifold.MDS(
        n_components=dims,
        dissimilarity="precomputed",
        random_state=60,
        max_iter=90000,
    )
    coords = mds.fit(dists).embedding_

    marker_style = dict(color=color, colorscale=colorscale, opacity=0.8)
    if dims == 3:
        trace = go.Scatter3d(
            x=coords[:, 0],
            y=coords[:, 1],
            z=coords[:, 2],
            mode="markers+text",
            textposition="top center",
            text=labels,
            marker=dict(size=10, **marker_style),
        )
    else:
        trace = go.Scatter(
            x=coords[:, 0],
            y=coords[:, 1],
            mode="markers+text",
            text=labels,
            textposition="top center",
            marker=dict(size=12, **marker_style),
        )

    fig = go.Figure(data=[trace])
    fig.update_layout(template="plotly_dark")
    if title:
        fig.update_layout(title_text=title)

    if save_dir is not None:
        save_path = os.fspath(save_dir)
        # the "./" default is a directory; writing HTML straight to it would fail
        if os.path.isdir(save_path) or save_path.endswith(("/", "\\")):
            save_path = os.path.join(save_path, f"plot_{dims}D.html")
        parent = os.path.dirname(save_path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        fig.write_html(save_path)

    return fig

In [None]:
def eval_context_vecs(
    input_hidden_states,
    input_tokenized_sents,
    mode="concat",
    model_name="bert-base-uncased",
    reducer="MDS",
    top_n_layers=4,
    dims=2,
    max_length=15,
    sent_lengths=None,
    words_with_diff_color=None,
    save_dir="./",
):
    """
    Collect a contextual vector for every word piece of every sentence (skipping the
    start and end special tokens), label each as `<token>_<sentence index>`, compute the
    pairwise cosine distances and pass them to the visualization function.

    inputs:
    input_hidden_states: per-layer hidden states retrieved from a BERT-like model
    input_tokenized_sents: tokenized sentences, used to assign labels for each point on the plot
    mode:   'average' : avg last n layers
            'concat': concatenate last n layers
            'sum' : sum last n layers
            'last':  embeddings only from last layer
            'second_last':  embeddings only from second last layer
    model_name: shown in the plot title and used in the saved file name
    reducer: dimensionality reduction method passed on to the plot function
    top_n_layers: top n layers to use for concat/sum etc.
    dims: 2/3 for 2D or 3D plot
    max_length: padded sequence length the hidden states were computed with
    sent_lengths: unpadded length (special tokens included) of each sentence; required
    words_with_diff_color: words that should be highlighted with a different color on the plot
    save_dir: directory the HTML plot is written into

    returns: the plotly figure built by `plt_dists_subplots`
    raises: ValueError when `sent_lengths` is missing or no word tokens were found
    """
    if sent_lengths is None:
        raise ValueError("sent_lengths is required (one unpadded length per sentence)")

    vecs = []
    labels = []
    # position 0 is the CLS / <s> token, so start at 1
    for token_ind in range(1, max_length):
        vectors = get_vector(
            input_hidden_states,
            token_index=token_ind,
            mode=mode,
            top_n_layers=top_n_layers,
        )
        for sent_ind, sent_len in enumerate(sent_lengths):
            # the end token (SEP / </s>) sits at index sent_len - 1; padding comes after it
            if token_ind < sent_len - 1:
                vecs.append(vectors[sent_ind])
                labels.append(
                    input_tokenized_sents[sent_ind][token_ind] + "_" + str(sent_ind)
                )

    if not vecs:
        raise ValueError("no word tokens found between the special tokens to plot")

    # .cpu() so this also works when the model ran on a GPU
    mat = torch.stack(vecs).detach().cpu().numpy()

    # slashes in a model name (e.g. "org/model") would otherwise be read as directories
    safe_model_name = model_name.replace("/", "_")

    return plt_dists_subplots(
        cosine_distances(mat),
        reducer=reducer,
        labels=labels,
        dims=dims,
        words_of_interest=list(words_with_diff_color or []),
        title=f"Model: {model_name} and Method: {mode}",
        save_dir=f"{save_dir}/{safe_model_name}_{mode}_plotly_{dims}D.html",
    )

In [None]:
def _embed_and_plot_model(
    model_name,
    texts,
    focus_word,
    cache_dir,
    max_length,
    reducer,
    dims,
    mode,
    save_dir,
    verbose=False,
):
    """Load one model, embed `texts`, and return (sanitised model name, figure)."""
    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
    # output_hidden_states=True makes the model return the hidden states of every layer
    model = AutoModel.from_pretrained(
        model_name, output_hidden_states=True, cache_dir=cache_dir
    )
    # eval() disables dropout so repeated runs give the same embeddings
    model.eval()

    input_ids, attention_masks, _ = preprocessing_for_bert(
        texts, tokenizer, max_length=max_length
    )

    # Use the standard attention mask for every model, so that a model gives the same
    # vectors whether it is plotted alone or compared with others. Masking CLS/SEP out of
    # attention is not how these models were trained.
    # no_grad: inference only, no autograd graph needed
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_masks)
    hidden_states = getattr(outputs, "hidden_states", None)
    if hidden_states is None:
        # tuple-style outputs (return_dict=False): hidden states follow last_hidden_state and pooler_output
        hidden_states = outputs[2]

    if verbose:
        print("Total hidden layers:", len(hidden_states))
        print(
            "First layer : hidden_states[0].shape ", hidden_states[0].shape
        )  # [batch_size x seq_length x vector_dim]

    # unpadded length of each sentence (special tokens included)
    sent_lengths = attention_masks.sum(1).tolist()
    tokenized_sents = [tokenizer.convert_ids_to_tokens(ids.tolist()) for ids in input_ids]

    # replace slashes so the name can be used as part of a file name
    safe_name = model_name.replace("/", "_")

    fig = eval_context_vecs(
        hidden_states,
        tokenized_sents,
        model_name=safe_name,
        reducer=reducer,
        mode=mode,
        max_length=max_length,
        sent_lengths=sent_lengths,
        words_with_diff_color=focus_word,
        dims=dims,
        save_dir=save_dir,
    )
    return safe_name, fig


def run_visualise_pipeline(
    texts,
    focus_word,
    model_names,
    cache_dir,
    max_length,
    reducer=None,
    dims=2,
    mode="concat",
    save_dir="./",
):
    """
    Load each model, tokenize `texts`, extract contextual embeddings, reduce their
    dimensions and plot them.

    texts: sentences to embed
    focus_word: words highlighted with a different color on the plot
    model_names: list of Hugging Face model names (a single string is also accepted)
    cache_dir: cache directory for downloaded models/tokenizers
    max_length: sequence length every sentence is truncated/padded to
    reducer: dimensionality reduction method ("MDS")
    dims: 2 or 3 dimensional plot
    mode: how layers are combined into a word vector (see `get_vector`)
    save_dir: directory HTML plots are written into (created if missing)

    returns: the figure for a single model; a side-by-side subplot figure (also saved
        to HTML) for several models; None when `model_names` is empty
    """
    if isinstance(model_names, str):
        model_names = [model_names]
    if not model_names:
        return None

    # create the output directory before any plot is written into it
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    shared_kwargs = dict(
        texts=texts,
        focus_word=focus_word,
        cache_dir=cache_dir,
        max_length=max_length,
        reducer=reducer,
        dims=dims,
        mode=mode,
        save_dir=save_dir,
    )

    if len(model_names) == 1:
        _, fig = _embed_and_plot_model(model_names[0], verbose=True, **shared_kwargs)
        return fig

    plots = {}
    for model_name in model_names:
        print(f"Working on: {model_name}")
        safe_name, fig = _embed_and_plot_model(model_name, **shared_kwargs)
        plots[safe_name] = fig

    fig_model_names = list(plots)
    # 3D traces need "scene" cells, 2D traces need "xy" cells; one column per model
    subplot_type = "scene" if dims == 3 else "xy"
    all_plots = make_subplots(
        rows=1,
        cols=len(fig_model_names),
        subplot_titles=fig_model_names,
        specs=[[{"type": subplot_type} for _ in fig_model_names]],
    )

    for col, figure in enumerate(plots.values(), start=1):
        for trace in figure.data:
            all_plots.add_trace(trace, row=1, col=col)

    all_plots.write_html(
        os.path.join(
            save_dir, f"multi_{'_'.join(fig_model_names)}_compare_{dims}D.html"
        )
    )

    return all_plots

In [None]:
# set up constants
# Define some constants
CACHE_DIR = (
    ".cache"  # set to whichever cache dir you want for downloaded transformer models
)
MODEL_NAME = ["bert-base-uncased"]  #'roberta-large'
MAX_LENGTH = 15

# sentences picked to check whether "heart" really gets different word vectors in
# different contexts (idiomatic vs. medical usage)
TEXTS = [
    "He did not have the heart to tell them",
    "The patient had recently experienced a heart attack",
    "felt chest pain during the night",
]

# words to highlight when visualizing the word vectors
FOCUS_WORD = ["heart", "chest"]

# the dimensionality reduction algorithm to use
REDUCER = "MDS"

# the word embedding extraction method (see get_vector)
MODE = "concat"

SAVE_DIR = "./plots/"

DIMS = 3

# bert-base-uncased

In [None]:
# contextual embeddings of TEXTS from the model(s) in MODEL_NAME
bert_plot = run_visualise_pipeline(
    model_names=MODEL_NAME,
    texts=TEXTS,
    focus_word=FOCUS_WORD,
    mode=MODE,
    reducer=REDUCER,
    dims=DIMS,
    max_length=MAX_LENGTH,
    cache_dir=CACHE_DIR,
    save_dir=SAVE_DIR,
)

In [None]:
# display the figure returned by run_visualise_pipeline in the cell above
bert_plot

# roberta-base

In [None]:
# same texts and settings as above, different model
MODEL_NAME = ["roberta-base"]  #'roberta-large'
roberta_plot = run_visualise_pipeline(
    model_names=MODEL_NAME,
    texts=TEXTS,
    focus_word=FOCUS_WORD,
    mode=MODE,
    reducer=REDUCER,
    dims=DIMS,
    max_length=MAX_LENGTH,
    cache_dir=CACHE_DIR,
    save_dir=SAVE_DIR,
)

# biomed roberta

In [None]:
# same texts and settings as above, biomedical RoBERTa; the figure is the cell's
# last expression, so the notebook displays it directly
MODEL_NAME = ["allenai/biomed_roberta_base"]  #'roberta-large'
run_visualise_pipeline(
    model_names=MODEL_NAME,
    texts=TEXTS,
    focus_word=FOCUS_WORD,
    mode=MODE,
    reducer=REDUCER,
    dims=DIMS,
    max_length=MAX_LENGTH,
    cache_dir=CACHE_DIR,
    save_dir=SAVE_DIR,
)

# Compare models side by side (subplots)


In [None]:
# NOTE: manual version of the side-by-side comparison; run_visualise_pipeline builds
# this subplot figure itself when it is given more than one model name.
# # combine plots into list
# figures = [bert_plot, roberta_plot]

# # make subplot
# fig = make_subplots(rows=1, cols=2, subplot_titles=["Bert-base-uncased", "roberta-base"])

# # need to pull out traces per figure
# for i, figure in enumerate(figures):
#     for trace in range(len(figure["data"])):
#         fig.append_trace(figure["data"][trace], row = 1, col = i+1)

# fig

In [None]:
# several models at once -> one side-by-side comparison figure
MODEL_NAMES = ["allenai/biomed_roberta_base", "roberta-base"]
multi_plots = run_visualise_pipeline(
    model_names=MODEL_NAMES,
    texts=TEXTS,
    focus_word=FOCUS_WORD,
    mode=MODE,
    reducer=REDUCER,
    dims=DIMS,
    max_length=MAX_LENGTH,
    cache_dir=CACHE_DIR,
    save_dir=SAVE_DIR,
)

In [None]:
# display the combined figure returned by run_visualise_pipeline in the cell above
multi_plots

# Static word embeddings

In [None]:
def get_static_transformer_embedding(model, tokenizer, words):
    """
    Return the static (input embedding layer, context-free) vectors for a list of words.

    model: model whose input embedding matrix is read via `get_input_embeddings()`;
        unlike `model.embeddings.word_embeddings`, this works for any architecture
        implementing that method, not only BERT/RoBERTa-style models
    tokenizer: tokenizer matching `model`; `words` are passed as pre-split words, so a
        single word may map to several word pieces
        NOTE(review): RoBERTa-style fast tokenizers require `add_prefix_space=True`
        for pre-split input — confirm for the models you use
    words: list of words

    returns: (word-piece tokens, their token ids, embedding rows for those ids)
    """
    word_embeddings = model.get_input_embeddings().weight

    # ids of the word pieces, without any CLS/SEP special tokens
    token_ids = tokenizer.encode(
        words, add_special_tokens=False, is_split_into_words=True
    )

    # matching word-piece strings, used as plot labels
    tokenized_words = tokenizer.tokenize(words, is_split_into_words=True)

    return tokenized_words, token_ids, word_embeddings[token_ids, :]

In [None]:
def plot_static_vecs(
    model_names,
    words,
    cache_dir="<CACHE_DIR>",
    reducer="MDS",
    dims=2,
    save_dir="./plots/",
):
    """
    Plot static (input-layer) word-piece embeddings of `words` for a single model.

    model_names: list holding exactly one Hugging Face model name (a single string is
        also accepted); any other number of models returns None (not implemented)
    words: list of words to embed and plot
    cache_dir: cache directory for downloaded models/tokenizers
    reducer: dimensionality reduction method; only "MDS" is implemented
    dims: 2 or 3 dimensional plot (any value other than 3 gives a 2D scatter)
    save_dir: directory the HTML plot is written into (created if missing), or None
        to skip saving

    returns: the plotly figure, or None when not exactly one model is given
    raises: NotImplementedError for any reducer other than "MDS"
    """
    if isinstance(model_names, str):
        model_names = [model_names]
    if len(model_names) != 1:
        return None

    # validate before the (slow) model download/load
    if reducer != "MDS":
        raise NotImplementedError(
            f"Unsupported reducer {reducer!r}; only 'MDS' is implemented"
        )

    model_name = model_names[0]
    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
    model = AutoModel.from_pretrained(model_name, cache_dir=cache_dir)
    # eval mode: no training-time behavior such as dropout
    model.eval()

    tokenized_words, _, word_embs = get_static_transformer_embedding(
        model, tokenizer, words
    )

    # .cpu() so this also works when the model lives on a GPU
    cosine_dists = cosine_distances(word_embs.detach().cpu().numpy())

    mds = manifold.MDS(
        n_components=dims,
        dissimilarity="precomputed",
        random_state=60,
        max_iter=90000,
    )
    coords = mds.fit(cosine_dists).embedding_

    # one color per label, so every word piece gets its own color
    labels = tokenized_words
    if dims == 3:
        fig = px.scatter_3d(
            x=coords[:, 0],
            y=coords[:, 1],
            z=coords[:, 2],
            text=labels,
            color=labels,
        )
    else:
        fig = px.scatter(x=coords[:, 0], y=coords[:, 1], text=labels, color=labels)

    fig.update_traces(textposition="top center")
    fig.update_layout(
        showlegend=False, title_text=f"Model: {model_name} and Method: static"
    )

    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)
        # replace slashes so "org/model" names don't become subdirectories
        model_name_save = model_name.replace("/", "_")
        fig.write_html(
            os.path.join(save_dir, f"{model_name_save}_static_plotly_test_{dims}D.html")
        )

    return fig

In [None]:
# static (context-free) embeddings of a mix of number, royalty and medical words
plots = plot_static_vecs(
    cache_dir=".cache",
    dims=3,
    model_names=["bert-base-uncased"],
    words=[
        "one", "two", "three",
        "queen", "king",
        "amputate", "hospital", "ambulance",
        "cheese",
        "accident", "medical",
        "prince", "princess", "royal",
    ],
)

In [None]:
# display the figure returned by plot_static_vecs in the cell above
plots