In [56]:
# Install the `bertviz` package (e.g. `pip install bertviz`) before running this notebook

References:

Vig, J. A Multiscale Visualization of Attention in the Transformer Model. *arXiv preprint arXiv:1906.05714*, 2019. URL: https://arxiv.org/abs/1906.05714

Vig, J. bertviz. URL: https://github.com/jessevig/bertviz

In [57]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
import torch 
from torch import nn
from torch.utils.data import DataLoader, Dataset, TensorDataset, SequentialSampler
import tqdm
from transformers import BertForSequenceClassification, AdamW, BertTokenizer, BertModel
from bertviz import head_view

In [58]:
# WordPiece tokenizer of the uncased BERT checkpoint; lower-casing is requested explicitly.
tokenizer = BertTokenizer.from_pretrained(
    'bert-base-uncased',
    do_lower_case=True,
)

In [59]:
# NOTE(review): this encoder checkpoint is *cased* while `tokenizer` is *uncased*;
# the following cell replaces the input embedding table to cope with the
# different vocabularies. In the cells shown here the fine-tuned model is later
# unpickled from disk, so this object appears to matter only for retraining — confirm.
bert = BertModel.from_pretrained(
    pretrained_model_name_or_path='bert-base-cased',
    output_attentions=True,  # also return the per-layer attention weights
)

In [60]:
# The encoder checkpoint and the tokenizer use different vocabularies, so the
# input embedding table is swapped for a freshly (randomly) initialised one
# sized to the tokenizer's vocabulary; it only becomes meaningful after fine-tuning.
# Width and padding index come from the model config / tokenizer instead of
# hard-coded 768 / 0 (same values for bert-base + BertTokenizer).
bert.embeddings.word_embeddings = nn.Embedding(tokenizer.vocab_size, bert.config.hidden_size,
                                               padding_idx=tokenizer.pad_token_id)

In [61]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [62]:
# Encode the first and second halves of a document separately with the same
# (shared-weight) BERT encoder, then classify the concatenated first-token states.

class BERTClassifier(nn.Module):
    """Two-segment BERT classifier.

    Both segments are passed through the *same* ``bert`` module. The first-token
    ([CLS]) vector of each segment's last hidden state is taken, and the two
    vectors are concatenated and fed to a single linear layer.

    Args:
        bert: encoder whose forward output has the last hidden state at index 0
            and which exposes ``config.hidden_size``.
        num_classes: number of output logits.
    """

    def __init__(self, bert, num_classes):
        super().__init__()
        self.bert = bert
        # Two first-token vectors are concatenated, hence 2 * hidden_size inputs.
        self.linear = nn.Linear(bert.config.hidden_size * 2, num_classes)
        self.num_classes = num_classes

    def _cls_embedding(self, input_ids, attention_mask):
        """Return the first-token vector of the encoder's last hidden state."""
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Index rather than tuple-unpack: the number of returned items depends on
        # output_attentions / output_hidden_states (and newer transformers return a
        # ModelOutput), but position 0 is always the last hidden state.
        return outputs[0][:, 0]

    def forward(self, input_ids_first, attention_masks_first, input_ids_second, attention_masks_second):
        """Compute class logits for a (first segment, second segment) pair.

        Args:
            input_ids_first / attention_masks_first: token ids and mask of segment 1.
            input_ids_second / attention_masks_second: token ids and mask of segment 2.

        Returns:
            Logits of shape (batch, num_classes).
        """
        h1_cls = self._cls_embedding(input_ids_first, attention_masks_first)
        h2_cls = self._cls_embedding(input_ids_second, attention_masks_second)
        h_cls = torch.cat((h1_cls, h2_cls), dim=-1)
        return self.linear(h_cls)

In [63]:
# Load the fully pickled fine-tuned classifier (architecture + weights); unpickling
# needs the BERTClassifier class to be defined in this session.
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
# torch.load unpickles arbitrary objects — only load files you trust.
# NOTE(review): torch>=2.6 defaults to weights_only=True, which rejects whole-module
# pickles; there, pass weights_only=False (trusted file only).
# eval() turns dropout off, since this notebook only runs inference.
model_bert = torch.load('model_retrain.pt', map_location=device).to(device).eval()

In [64]:
%%javascript
require.config({
  paths: {
      d3: '//cdnjs.cloudflare.com/ajax/libs/d3/3.4.8/d3.min',
      jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
  }
});

<IPython.core.display.Javascript object>

In [65]:
sentence = 'A preliminary diagnosis of pyelonephritis was established. Other causes of fever were possible but less likely.'

# Encode one sequence as [CLS] ... [SEP] and return PyTorch tensors.
# No padding is requested (the deprecated pad_to_max_length=True padded to 512):
# for a single sentence it only adds [PAD] positions that the forward pass must
# process and the visualization must slice away again. max_length=512 is kept
# to respect BERT's position limit.
encoded_dict_first = tokenizer.encode_plus(sentence, add_special_tokens=True, max_length=512,
                                           return_attention_mask=True, return_tensors='pt')

In [66]:
# Run only the shared BERT encoder on the first segment to get per-layer attentions.
# eval(): the encoder reports attention probabilities after dropout, so dropout
# must be off for deterministic maps (torch.load keeps whatever train/eval flag
# was pickled). no_grad(): no autograd graph is needed for visualization.
model_bert.eval()
with torch.no_grad():
    bert_outputs = model_bert.bert(encoded_dict_first['input_ids'].to(device),
                                   encoded_dict_first['attention_mask'].to(device))
# The attentions are the last returned item. Indexing works for both the tuple
# (older transformers) and ModelOutput (newer) return types, unlike 3-way unpacking.
attention = bert_outputs[-1]

In [67]:
# Map the encoded ids of the (single) sequence back to WordPiece strings,
# which label the rows/columns of the attention view.
input_tokens = tokenizer.convert_ids_to_tokens(encoded_dict_first['input_ids'][0].tolist())

In [68]:
# Keep only the real (non-[PAD]) positions so the head view is not cluttered by
# padding. The number of real tokens is read from the attention mask
# (1 = real token, 0 = padding) instead of a hard-coded 24, so the slice stays
# correct when `sentence` or the padding settings change.
# Assumes right-side padding (BertTokenizer's default), i.e. real tokens come first.
n_real_tokens = int(encoded_dict_first['attention_mask'][0].sum())
attention_partial = tuple(att[:, :, :n_real_tokens, :n_real_tokens] for att in attention)

In [69]:
# Render the interactive per-head attention view for a single sentence
# (sentence_b_start=None). The token labels are cut to the width of the sliced
# attention matrices rather than a second hard-coded 24, so labels and matrix
# rows/columns always line up.
head_view(attention_partial, input_tokens[:attention_partial[0].shape[-1]], sentence_b_start=None)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>