In [1]:
# Prerequisite: the bertviz package must be downloaded first (https://github.com/jessevig/bertviz)

Reference:

Vig, J. A Multiscale Visualization of Attention in the Transformer Model. *arXiv preprint arXiv:1906.05714*, 2019. URL: https://arxiv.org/abs/1906.05714

Vig, J. bertviz. URL: https://github.com/jessevig/bertviz

In [3]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
import torch 
from torch import nn
from torch.utils.data import DataLoader, Dataset, TensorDataset, SequentialSampler
import tqdm
from transformers import BertForSequenceClassification, AdamW, RobertaTokenizer, RobertaModel
from bertviz import head_view

In [4]:
# Tokenizer for the 'roberta-base' checkpoint; used below to map token ids back to tokens.
# NOTE(review): do_lower_case is kept as in the original setup -- confirm it matches
# the preprocessing that produced the saved test dataloader.
tokenizer = RobertaTokenizer.from_pretrained(
    'roberta-base',
    do_lower_case=True,
)

In [5]:
# Pretrained 'roberta-base' encoder configured to also return its attention weights.
roberta = RobertaModel.from_pretrained(
    'roberta-base',
    output_attentions=True,
)

In [6]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [7]:
# Restore the saved test loader from disk (presumably a DataLoader; it is iterated
# as five-field batches further down -- verify against the script that saved it).
# NOTE: torch.load unpickles arbitrary objects, so only load files you trust.
test_loader_path = 'test_dataloader_roberta.pth'
test_loader_bert = torch.load(test_loader_path)

In [8]:
# Train the first and second halves of the sequence separately, then concatenate the hidden-state outputs

class RoBERTaClassifier(nn.Module):
    """Classifier over two sequence halves that share a single encoder.

    Each half is passed through the same encoder. The hidden state at
    position 0 of each half is taken, the two vectors are concatenated, and
    one linear layer maps the result to class logits.

    Args:
        roberta: Encoder module. It must expose ``config.hidden_size`` and,
            when called with ``input_ids`` / ``attention_mask`` keywords,
            return an output whose element ``[0]`` holds the per-token
            hidden states.
        num_classes: Number of output classes (size of each logit vector).
    """

    def __init__(self, roberta, num_classes):
        super().__init__()
        self.roberta = roberta
        # Two position-0 vectors are concatenated, hence hidden_size * 2 inputs.
        self.linear = nn.Linear(roberta.config.hidden_size * 2, num_classes)
        self.num_classes = num_classes

    def _first_token_state(self, input_ids, attention_mask):
        """Encode one half and return its hidden state at position 0."""
        outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
        # Index the output instead of unpacking a fixed 3-tuple: the number of
        # returned elements depends on the encoder configuration (e.g. whether
        # attentions are returned), and dict-like output objects unpack to their
        # keys rather than to tensors. Element [0] is the hidden-state tensor.
        hidden_states = outputs[0]
        # assumes hidden_states is (batch, seq_len, hidden_size) -- TODO confirm
        return hidden_states[:, 0]

    def forward(self, input_ids_first, attention_masks_first, input_ids_second, attention_masks_second):
        """Return class logits for a batch of (first half, second half) pairs.

        Args:
            input_ids_first: Token ids of the first half.
            attention_masks_first: Attention mask for the first half.
            input_ids_second: Token ids of the second half.
            attention_masks_second: Attention mask for the second half.

        Returns:
            The output of the linear layer applied to the concatenated
            position-0 hidden states of both halves.
        """
        h1_cls = self._first_token_state(input_ids_first, attention_masks_first)
        h2_cls = self._first_token_state(input_ids_second, attention_masks_second)
        h_cls = torch.cat((h1_cls, h2_cls), dim=-1)
        return self.linear(h_cls)

In [9]:
# Load the saved fine-tuned classifier object and move it to the selected device.
# map_location remaps tensors that were saved on a CUDA device, so the load also
# succeeds on a CPU-only machine (the fallback chosen when `device` was created).
# eval() switches off dropout so the attention weights visualized below are not
# randomly perturbed, whatever mode the model was saved in.
# NOTE: torch.load unpickles arbitrary objects, so only load a checkpoint you trust.
model_roberta = torch.load('model_roberta.pt', map_location=device).to(device).eval()

In [10]:
%%javascript
// Register the locations of the d3 and jquery modules with the notebook's
// RequireJS loader. The URLs are protocol-relative (they start with "//"), so
// they are fetched with the same scheme as the page serving the notebook.
// NOTE(review): presumably needed by the bertviz head_view rendering further
// down -- confirm against the installed bertviz version.
require.config({
  paths: {
      d3: '//cdnjs.cloudflare.com/ajax/libs/d3/3.4.8/d3.min',
      jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
  }
});

<IPython.core.display.Javascript object>

In [11]:
# Take the first test instance for example, and record token ids and attention outputs

# Only the first batch is needed, so fetch it directly instead of looping and breaking.
(input_ids_first, attention_masks_first,
 input_ids_second, attention_masks_second, label) = next(iter(test_loader_bert))

# Keep only the first instance of the batch, restoring a leading dimension of size 1.
input_ids_first = input_ids_first[0].to(device).unsqueeze(0)
attention_masks_first = attention_masks_first[0].to(device).unsqueeze(0)
input_ids_second = input_ids_second[0].to(device).unsqueeze(0)
attention_masks_second = attention_masks_second[0].to(device).unsqueeze(0)

# Visualization needs no gradients, so skip building the autograd graph.
with torch.no_grad():
    # The attentions are taken as the last output element rather than by unpacking
    # exactly three values, so this does not depend on how many elements the
    # encoder returns or on whether it returns a plain tuple.
    # NOTE(review): assumes the loaded encoder was configured to return attentions.
    attention1 = model_roberta.roberta(input_ids_first, attention_mask=attention_masks_first)[-1]
    attention2 = model_roberta.roberta(input_ids_second, attention_mask=attention_masks_second)[-1]

In [12]:
# Convert id to tokens

# Map the id sequence of each half (row 0 of its batch) back to token strings.
input_tokens_first, input_tokens_second = (
    tokenizer.convert_ids_to_tokens(half_ids[0].tolist())
    for half_ids in (input_ids_first, input_ids_second)
)

In [13]:
# Take first 25 tokens of the second half only, for visualization convenience

# Crop every entry of attention2 to its first 25 positions along the last two axes.
attention2_partial = tuple(
    layer_attention[:, :, :25, :25] for layer_attention in attention2
)

In [14]:
# Render the head view for the cropped attention, paired with the matching first 25 tokens.
tokens_shown = input_tokens_second[:25]
head_view(attention2_partial, tokens_shown, sentence_b_start=None)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>