In [9]:
# The bertviz package must be installed first (`pip install bertviz`)

References:

Vig, J. A Multiscale Visualization of Attention in the Transformer Model. *arXiv preprint arXiv:1906.05714*, 2019, URL: https://arxiv.org/abs/1906.05714

Vig, J. bertviz. URL: https://github.com/jessevig/bertviz

In [10]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
import torch 
from torch import nn
from torch.utils.data import DataLoader, Dataset, TensorDataset, SequentialSampler
import tqdm
from transformers import BertForSequenceClassification, AdamW, RobertaTokenizer, RobertaModel
from bertviz import head_view

In [11]:
# Tokenizer for the pretrained 'roberta-base' checkpoint.
# NOTE(review): roberta-base is a cased model and RobertaTokenizer does not
# list do_lower_case among its documented options; the flag is left unchanged
# so tokenisation matches whatever the fine-tuned model was trained with —
# confirm against the training code before removing it.
tokenizer = RobertaTokenizer.from_pretrained(
    'roberta-base',
    do_lower_case=True,
)

In [12]:
# Pretrained encoder configured to also return per-layer attention weights.
# NOTE(review): `roberta` is not referenced by the cells shown here (the
# fine-tuned model is unpickled further down); presumably kept for rebuilding
# RoBERTaClassifier — confirm it is still needed, since it loads the
# pretrained weights a second time.
roberta = RobertaModel.from_pretrained(
    'roberta-base',
    output_attentions=True,
)

In [13]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [14]:
# Unpickle the saved test DataLoader. torch.load uses pickle, so only load
# files from a trusted source.
# NOTE(review): this loader is not used by the attention-visualisation cells
# below, and its `_bert` suffix does not match the RoBERTa file it is read from.
test_loader_bert = torch.load(
    'test_dataloader_roberta.pth'
)

In [15]:
# Encode the first and second halves of the sequence separately with the same
# RoBERTa encoder, then concatenate their first-token hidden states for classification.

class RoBERTaClassifier(nn.Module):
    """Two-segment RoBERTa classifier.

    The same encoder is run over two token sequences. The hidden state at
    position 0 of each is taken, the two vectors are concatenated, and the
    result is passed to a single linear layer.

    Args:
        roberta: encoder module exposing ``config.hidden_size`` whose forward
            returns the last hidden state as element 0 (a plain tuple or a
            transformers ``ModelOutput``).
        num_classes: number of output logits.
    """

    def __init__(self, roberta, num_classes):
        super().__init__()
        self.roberta = roberta
        # Two first-token vectors are concatenated, hence 2 * hidden_size inputs.
        self.linear = nn.Linear(roberta.config.hidden_size * 2, num_classes)
        self.num_classes = num_classes

    def _cls_embedding(self, input_ids, attention_mask):
        """Return the encoder's hidden state at position 0 for every sequence in the batch."""
        outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
        # Index instead of unpacking a fixed number of values: element 0 is the
        # last hidden state whether the encoder returns a tuple or a ModelOutput,
        # and regardless of how many optional outputs (e.g. attentions) it adds.
        return outputs[0][:, 0]

    def forward(self, input_ids_first, attention_masks_first, input_ids_second, attention_masks_second):
        """Compute class logits for a pair of (ids, mask) segments.

        Returns:
            Tensor of shape (batch, num_classes).
        """
        h1_cls = self._cls_embedding(input_ids_first, attention_masks_first)
        h2_cls = self._cls_embedding(input_ids_second, attention_masks_second)
        h_cls = torch.cat((h1_cls, h2_cls), dim=-1)
        return self.linear(h_cls)

In [16]:
# Load the fine-tuned classifier. torch.load unpickles the whole module, so the
# RoBERTaClassifier class must already be defined (presumably under the same
# module name it was saved from), and only trusted files should be loaded.
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
# eval() disables dropout, which would otherwise randomise the attention
# weights inspected below.
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which
# rejects full-module pickles; pass weights_only=False there.
model_roberta = torch.load('model_roberta.pt', map_location=device).to(device).eval()

In [17]:
%%javascript
require.config({
  paths: {
      d3: '//cdnjs.cloudflare.com/ajax/libs/d3/3.4.8/d3.min',
      jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
  }
});

<IPython.core.display.Javascript object>

In [18]:
sentence = 'A preliminary diagnosis of pyelonephritis was established. Other causes of fever were possible but less likely.'

# Encode the single example sentence without padding. The attention maps then
# cover only the real tokens (including the special start/end tokens) instead
# of 512 positions, so no 512x512 attention maps are computed per head.
# max_length=512 matches RoBERTa-base's maximum input length.
encoded_dict_first = tokenizer.encode_plus(
    sentence,
    add_special_tokens=True,
    max_length=512,
    return_attention_mask=True,
    return_tensors='pt',
)

In [23]:
# Run only the encoder of the fine-tuned model on the example sentence.
# no_grad: this is inference only, so no autograd graph is needed.
with torch.no_grad():
    roberta_outputs = model_roberta.roberta(
        input_ids=encoded_dict_first['input_ids'].to(device),
        attention_mask=encoded_dict_first['attention_mask'].to(device),
    )
# The per-layer attentions are the last output both for plain tuples and for
# ModelOutput objects, so index instead of unpacking a fixed number of values.
# Assumes the encoder was built with output_attentions=True (as in the
# `roberta` cell above) — the check below fails loudly if it was not.
attention = roberta_outputs[-1]
assert isinstance(attention, tuple), 'encoder did not return attentions; build it with output_attentions=True'

In [24]:
# Map the encoded ids of the (only) sequence back to token strings; these label the attention view.
input_tokens = tokenizer.convert_ids_to_tokens(encoded_dict_first['input_ids'][0].tolist())

In [25]:
# Keep only the attention between real (non-padding) tokens. The count comes
# from the attention mask rather than a hard-coded length, so the slice stays
# correct for any sentence, padded or not.
# Assumes any padding is on the right (presumably the tokenizer default) and
# that each layer's attention is (batch, heads, seq, seq) — TODO confirm if
# padding_side is changed.
n_real_tokens = int(encoded_dict_first['attention_mask'][0].sum())
attention_partial = tuple(att[:, :, :n_real_tokens, :n_real_tokens] for att in attention)

In [26]:
# Render bertviz's head view. The token list is cut to the same length as the
# (already sliced) attention maps so labels and attention always line up.
head_view(attention_partial, input_tokens[:attention_partial[0].shape[-1]], sentence_b_start=None)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>