# Visualize attention layers for BERT

In [1]:
import torch
from transformers import *
import sys
import os
import regex
!git clone https://github.com/jessevig/bertviz bertviz_repo
if not 'bertviz_repo' in sys.path:
    sys.path += ['bertviz_repo']
    
os.listdir()

fatal: destination path 'bertviz_repo' already exists and is not an empty directory.


['.ipynb_checkpoints',
 'bertviz_repo',
 'corpus_embeddings.txt',
 'corpus_labels.txt',
 'Sentence to Embeddings.ipynb',
 'Transformer probe.ipynb']

In [2]:
from bertviz import head_view, model_view, neuron_view

## Helper functions

In [3]:
def call_html():
    """Display the require.js bootstrap snippet in the current cell output.

    The snippet loads require.js from the notebook server's static files and
    maps the ``base``, ``d3`` and ``jquery`` module names to their locations.
    In this notebook it is called before rendering a bertviz view.
    """
    # Import from the public IPython.display module instead of relying on the
    # `display` name that is only injected into interactive notebook namespaces.
    from IPython.display import HTML, display
    display(HTML('''
        <script src="/static/components/requirejs/require.js"></script>
        <script>
          requirejs.config({
            paths: {
              base: '/static/base',
              "d3": "https://cdnjs.cloudflare.com/ajax/libs/d3/3.5.8/d3.min",
              jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
            },
          });
        </script>
        '''))
    
def show_model_view(model, tokenizer, sentence_a, sentence_b=None, hide_delimiter_attn=False):
    """Render the bertviz model view for a sentence or a sentence pair.

    Args:
        model: Callable model whose output's last element holds the per-layer
            attention tensors (e.g. a model loaded with ``output_attentions=True``).
        tokenizer: Tokenizer providing ``encode_plus`` and ``convert_ids_to_tokens``.
        sentence_a: First (or only) sentence to encode.
        sentence_b: Optional second sentence. When given, token type ids are
            passed to the model and the start index of the second sentence is
            forwarded to the view.
        hide_delimiter_attn: When true, attention to and from the ``[SEP]`` and
            ``[CLS]`` tokens is zeroed before rendering.

    Returns:
        None. The visualization is produced by ``model_view``.
    """
    inputs = tokenizer.encode_plus(sentence_a, sentence_b, return_tensors='pt', add_special_tokens=True)
    input_ids = inputs['input_ids']
    # Visualization is inference only: skip autograd bookkeeping so no graph is
    # kept alive and the in-place masking below acts on plain tensors.
    with torch.no_grad():
        if sentence_b:
            token_type_ids = inputs['token_type_ids']
            attention = model(input_ids, token_type_ids=token_type_ids)[-1]
            # The first token with type id 1 marks where sentence B begins.
            sentence_b_start = token_type_ids[0].tolist().index(1)
        else:
            attention = model(input_ids)[-1]
            sentence_b_start = None
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())  # Batch index 0
    if hide_delimiter_attn:
        delimiter_positions = [i for i, t in enumerate(tokens) if t in ("[SEP]", "[CLS]")]
        for layer_attn in attention:
            for i in delimiter_positions:
                # Tensors are indexed as [batch, head, position, position];
                # presumably the last two axes are (from-token, to-token) — confirm.
                layer_attn[0, :, i, :] = 0
                layer_attn[0, :, :, i] = 0
    model_view(attention, tokens, sentence_b_start)

## Models available with the transformers library

In [4]:
# One entry per architecture: (model class, tokenizer class, pretrained weights shortcut)
MODELS = [
    (BertModel, BertTokenizer, 'bert-base-uncased'),
    (OpenAIGPTModel, OpenAIGPTTokenizer, 'openai-gpt'),
    (GPT2Model, GPT2Tokenizer, 'gpt2'),
    (CTRLModel, CTRLTokenizer, 'ctrl'),
    (TransfoXLModel, TransfoXLTokenizer, 'transfo-xl-wt103'),
    (XLNetModel, XLNetTokenizer, 'xlnet-base-cased'),
    (XLMModel, XLMTokenizer, 'xlm-mlm-enfr-1024'),
    (DistilBertModel, DistilBertTokenizer, 'distilbert-base-cased'),
    (RobertaModel, RobertaTokenizer, 'roberta-base'),
    (XLMRobertaModel, XLMRobertaTokenizer, 'xlm-roberta-base'),
]

# Each architecture also comes with several classes for fine-tuning on
# downstream tasks; the BERT ones are listed here as an example.
BERT_MODEL_CLASSES = [
    BertModel,
    BertForPreTraining,
    BertForMaskedLM,
    BertForNextSentencePrediction,
    BertForSequenceClassification,
    BertForTokenClassification,
    BertForQuestionAnswering,
]

## Loading BERT base as an example
### Tokenizer

In [5]:
# Every class of an architecture can be instantiated from that architecture's
# pretrained weights. Weights that exist only for a fine-tuning head are merely
# initialized and still have to be trained on the downstream task.
pretrained_weights = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(
    pretrained_weights,
    do_lower_case=True,
)

In [6]:
# First example sentence; the displayed tokens show it lowercased, with
# "Pikachu" split into several word pieces ('pi', '##ka', '##chu').
sentence_a = 'This is a very complex word: Pikachu'
tokenizer.tokenize(sentence_a)

['this', 'is', 'a', 'very', 'complex', 'word', ':', 'pi', '##ka', '##chu']

In [7]:
# Second example sentence; in the displayed tokens "Pokémon" appears
# lowercased and without its accent ('pokemon').
sentence_b = 'Pikachu is a Pokémon'
tokenizer.tokenize(sentence_b)

['pi', '##ka', '##chu', 'is', 'a', 'pokemon']

### Model (pretrained weights)

In [8]:
# Load the pretrained BERT weights with output_attentions=True; later cells read
# the attention tensors from the last element of the model output.
model = BertModel.from_pretrained(pretrained_weights, output_attentions=True)

In [9]:
# Encode both sentences as a single pair and return PyTorch tensors; the
# displayed dict holds input_ids, token_type_ids and attention_mask.
inputs = tokenizer.encode_plus(sentence_a, sentence_b, return_tensors='pt', add_special_tokens=True)
inputs

{'input_ids': tensor([[  101,  2023,  2003,  1037,  2200,  3375,  2773,  1024, 14255,  2912,
          20760,   102, 14255,  2912, 20760,  2003,  1037, 20421,   102]]),
 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]]),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}

In [11]:
# Map the ids of batch entry 0 back to token strings; the printed list shows
# the [CLS]/[SEP] special tokens around the two sentences.
tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0].tolist())
print(tokens)

['[CLS]', 'this', 'is', 'a', 'very', 'complex', 'word', ':', 'pi', '##ka', '##chu', '[SEP]', 'pi', '##ka', '##chu', 'is', 'a', 'pokemon', '[SEP]']


### Attention weights

In [13]:
# Inference only: no_grad avoids recording an autograd graph for the forward pass.
with torch.no_grad():
    # The attention tensors are the last element of the model output.
    attention = model(inputs['input_ids'], token_type_ids=inputs['token_type_ids'])[-1]
# Inspect the shape of the first attention tensor.
attention[0].size()

torch.Size([1, 12, 19, 19])

In [14]:
# Emit the require.js configuration (d3, jquery paths) into this cell's output
# before rendering the view.
call_html()

# Render the bertviz head view for the attention tensors and tokens computed above.
head_view(attention, tokens)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [19]:
# Second example: a longer sentence pair, visualized the same way as the first.
sentence_a = 'attention seems to be directed to other words that are predictive of the source word, excluding the source word itself'
sentence_b = 'much of the attention is directed to a delimiter token'

inputs = tokenizer.encode_plus(sentence_a, sentence_b, return_tensors='pt', add_special_tokens=True)
tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0].tolist())
# Inference only: no_grad avoids recording an autograd graph for the forward pass.
with torch.no_grad():
    attention = model(inputs['input_ids'], token_type_ids=inputs['token_type_ids'])[-1]

head_view(attention, tokens)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>