In [1]:
# The bertviz package must be installed first

Reference:

Vig, J. A Multiscale Visualization of Attention in the Transformer Model. *arXiv preprint arXiv:1906.05714*, 2019, URL: https://arxiv.org/abs/1906.05714

Vig, J. bertviz. URL: https://github.com/jessevig/bertviz

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
import torch 
from torch import nn
from torch.utils.data import DataLoader, Dataset, TensorDataset, SequentialSampler
import tqdm
from transformers import BertForSequenceClassification, AdamW, BertTokenizer, BertModel
from bertviz import head_view

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])


In [3]:
# Load the uncased BERT tokenizer; do_lower_case lower-cases text before tokenizing.
tokenizer = BertTokenizer.from_pretrained(
    'bert-base-uncased',
    do_lower_case=True,
)

In [4]:
# Load the BERT encoder so that it also returns per-layer attention weights.
# Use the same uncased checkpoint as the tokenizer above. Token ids then index
# the vocabulary the pretrained weights were trained on. A cased checkpoint has
# a different (smaller) vocabulary, so the ids would not line up.
bert = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)

In [5]:
# The word-embedding table must have one row per tokenizer id. Swap in a freshly
# initialised table only when the sizes differ, e.g. a cased checkpoint paired
# with the uncased tokenizer. Otherwise keep the pretrained embeddings rather
# than discarding them.
if bert.config.vocab_size != tokenizer.vocab_size:
    bert.embeddings.word_embeddings = nn.Embedding(
        tokenizer.vocab_size,
        bert.config.hidden_size,           # was hard-coded as 768
        padding_idx=tokenizer.pad_token_id,  # was hard-coded as 0
    )
    # Keep the config consistent with the new embedding table.
    bert.config.vocab_size = tokenizer.vocab_size

In [6]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [15]:
# Load the pickled test data loader; iterating it yields 5-tuples (see below).
# NOTE(review): torch.load unpickles arbitrary objects, so only load trusted files.
# On PyTorch >= 2.6, loading a non-tensor object like this needs weights_only=False.
test_loader_bert = torch.load(f='test_dataloader.pth')

In [16]:
# Encode the first and second halves of each sequence separately with the same BERT encoder, then concatenate the two position-0 hidden states for classification

class BERTClassifier(nn.Module):
    """Classify a sequence that has been split into two halves.

    Both halves go through the same ``bert`` encoder. The final-layer hidden
    state at position 0 of each half is taken (the [CLS] token for standard
    BERT inputs). The two vectors are concatenated, and a linear layer maps
    them to class logits.

    Args:
        bert: encoder module exposing ``config.hidden_size``; the first element
            of its output must be the last-layer hidden states.
        num_classes: number of output logits.
    """

    def __init__(self, bert, num_classes):
        super().__init__()
        self.bert = bert
        # Two concatenated position-0 vectors feed the classification head.
        self.linear = nn.Linear(bert.config.hidden_size * 2, num_classes)
        self.num_classes = num_classes

    def _cls_embedding(self, input_ids, attention_mask):
        """Return the position-0 final-layer hidden state for one half of the input."""
        # Index [0] instead of unpacking a fixed-length tuple. How many outputs
        # BERT returns depends on output_attentions / output_hidden_states.
        hidden_states = self.bert(input_ids=input_ids, attention_mask=attention_mask)[0]
        return hidden_states[:, 0]

    def forward(self, input_ids_first, attention_masks_first, input_ids_second, attention_masks_second):
        """Return logits of shape (batch, num_classes) for the two-halves input."""
        h1_cls = self._cls_embedding(input_ids_first, attention_masks_first)
        h2_cls = self._cls_embedding(input_ids_second, attention_masks_second)
        h_cls = torch.cat((h1_cls, h2_cls), dim=-1)
        return self.linear(h_cls)

In [17]:
# Load the fine-tuned classifier. It is presumably a pickled BERTClassifier,
# so that class must be defined before this cell runs.
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
# eval() disables dropout, which would otherwise perturb the attention
# probabilities extracted below.
# NOTE(review): torch.load unpickles arbitrary objects, so only load trusted files.
model_bert = torch.load('model_retrain.pt', map_location=device).to(device).eval()

In [18]:
%%javascript
require.config({
  paths: {
      d3: '//cdnjs.cloudflare.com/ajax/libs/d3/3.4.8/d3.min',
      jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
  }
});

<IPython.core.display.Javascript object>

In [19]:
# Take the first test instance as an example, and record its token ids and attention weights

model_bert.eval()  # disable dropout so the recorded attention weights are deterministic

# Only the first batch is needed; unpack its five components.
(input_ids_first, attention_masks_first,
 input_ids_second, attention_masks_second, label) = next(iter(test_loader_bert))

# Keep just the first instance of the batch, preserving a leading batch dimension of 1.
input_ids_first = input_ids_first[:1].to(device)
attention_masks_first = attention_masks_first[:1].to(device)
input_ids_second = input_ids_second[:1].to(device)
attention_masks_second = attention_masks_second[:1].to(device)

# Inference only, so no autograd graph is needed. With output_attentions=True
# the attentions are the last element of the encoder output. Indexing [-1]
# avoids depending on the exact number of outputs returned.
with torch.no_grad():
    attention1 = model_bert.bert(input_ids=input_ids_first, attention_mask=attention_masks_first)[-1]
    attention2 = model_bert.bert(input_ids=input_ids_second, attention_mask=attention_masks_second)[-1]

# Sanity check: one attention entry per encoder layer. This fails if attentions were not returned.
assert len(attention1) == len(attention2) == model_bert.bert.config.num_hidden_layers, \
    'model_bert.bert must be configured with output_attentions=True'

In [20]:
# Convert token ids back to token strings

# Map each half's token ids (first instance only) back to token strings.
input_tokens_first, input_tokens_second = (
    tokenizer.convert_ids_to_tokens(ids[0].tolist())
    for ids in (input_ids_first, input_ids_second)
)

In [40]:
# Take first 25 tokens of the second half only, for visualization convenience

# Keep the first 25 positions on the last two axes of every layer's attention.
attention2_partial = tuple(
    layer_attention[:, :, :25, :25]
    for layer_attention in attention2
)

In [45]:
# Visualise per-head attention for the truncated second half. The token list is
# cut to the same length as the sliced attention matrices. That way the token
# count always agrees with whatever slice the previous cell used, instead of
# repeating the number 25 here.
head_view(
    attention2_partial,
    input_tokens_second[:attention2_partial[0].shape[-1]],
    sentence_b_start=None,  # a single sequence: no second-sentence boundary
)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>