In [1]:
import os
import sys

# Put the parent directory on the module search path. Presumably this is so the
# local `bertviz` package imported in the next cell resolves when the notebook
# is run from a sub-directory of the repo -- confirm against the repo layout.
sys.path.append("..")

In [2]:
from bertviz import head_view, head_view_raw
from transformers import BertTokenizer, BertModel
# IPython.display is the public import location for these names. Importing
# `display` from IPython.core.display has been deprecated since IPython 7.14
# (and is no longer provided by later major releases), so use the public module.
from IPython.display import display, HTML, Javascript

In [3]:
def show_head_view(model, tokenizer, sentence_a, sentence_b=None, layer=None, heads=None):
    """Tokenize one sentence (or a sentence pair) and pass the attention to ``head_view``.

    Args:
        model: Called as ``model(input_ids)`` or
            ``model(input_ids, token_type_ids=...)``. The last element of its
            output is used as the attention (assumes the model was created with
            ``output_attentions=True``, as this notebook does).
        tokenizer: Object providing ``encode_plus`` and ``convert_ids_to_tokens``.
        sentence_a: First text to encode.
        sentence_b: Optional second text. When falsy, only ``sentence_a`` is
            fed to the model and no sentence-B offset is passed on.
        layer: Forwarded unchanged to ``head_view``.
        heads: Forwarded unchanged to ``head_view``.

    Returns:
        None. The visualisation is produced by ``head_view``.

    Raises:
        ValueError: If ``sentence_b`` is given but no token has type id 1.
    """
    encoded = tokenizer.encode_plus(
        sentence_a,
        sentence_b,
        return_tensors='pt',
        add_special_tokens=True,
    )
    ids = encoded['input_ids']

    if not sentence_b:
        # Single-sentence input: no segment ids needed, no B offset to report.
        attn = model(ids)[-1]
        b_offset = None
    else:
        segment_ids = encoded['token_type_ids']
        attn = model(ids, token_type_ids=segment_ids)[-1]
        # Position of the first token whose type id is 1 in batch item 0.
        b_offset = segment_ids[0].tolist().index(1)

    # Only batch item 0 is visualised.
    token_strings = tokenizer.convert_ids_to_tokens(ids[0].tolist())
    head_view(attn, token_strings, b_offset, layer=layer, heads=heads)

In [4]:
# Checkpoint name shared by the model and the tokenizer below.
model_version = 'bert-base-uncased'
# Forwarded to the tokenizer; kept next to the checkpoint name, which is an "uncased" one.
do_lower_case = True
# output_attentions=True is requested because show_head_view / get_head_view
# read the attention as the last element of the model output.
model = BertModel.from_pretrained(model_version, output_attentions=True)
tokenizer = BertTokenizer.from_pretrained(model_version, do_lower_case=do_lower_case)
# Sentence pair to visualise; these globals are reused by the later cells.
sentence_a = "the rabbit quickly hopped"
sentence_b = "The turtle slowly crawled"
# Display the head view for the pair (all layers/heads: layer and heads keep their None defaults).
show_head_view(model, tokenizer, sentence_a, sentence_b)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


<IPython.core.display.Javascript object>

In [5]:
def get_head_view(model, tokenizer, sentence_a, sentence_b=None, layer=None, heads=None, require_prefix=""):
    """Tokenize one sentence (or a pair) and return the result of ``head_view_raw``.

    Performs the same preparation as ``show_head_view`` but hands the attention
    to ``head_view_raw`` and returns whatever that call returns.

    Args:
        model: Called as ``model(input_ids)`` or
            ``model(input_ids, token_type_ids=...)``. The last element of its
            output is used as the attention (assumes the model was created with
            ``output_attentions=True``, as this notebook does).
        tokenizer: Object providing ``encode_plus`` and ``convert_ids_to_tokens``.
        sentence_a: First text to encode.
        sentence_b: Optional second text. When falsy, only ``sentence_a`` is
            fed to the model and no sentence-B offset is passed on.
        layer: Forwarded unchanged to ``head_view_raw``.
        heads: Forwarded unchanged to ``head_view_raw``.
        require_prefix: Forwarded unchanged to ``head_view_raw``; its meaning
            is defined there.

    Returns:
        The return value of ``head_view_raw`` (the caller in this notebook
        unpacks it as an ``(html, js)`` pair).

    Raises:
        ValueError: If ``sentence_b`` is given but no token has type id 1.
    """
    encoded = tokenizer.encode_plus(
        sentence_a,
        sentence_b,
        return_tensors='pt',
        add_special_tokens=True,
    )
    ids = encoded['input_ids']

    if not sentence_b:
        # Single-sentence input: no segment ids needed, no B offset to report.
        attn = model(ids)[-1]
        b_offset = None
    else:
        segment_ids = encoded['token_type_ids']
        attn = model(ids, token_type_ids=segment_ids)[-1]
        # Position of the first token whose type id is 1 in batch item 0.
        b_offset = segment_ids[0].tolist().index(1)

    # Only batch item 0 is visualised.
    token_strings = tokenizer.convert_ids_to_tokens(ids[0].tolist())
    return head_view_raw(
        attn,
        token_strings,
        b_offset,
        layer=layer,
        heads=heads,
        require_prefix=require_prefix,
    )

In [6]:
# Build the view for the same model and sentence pair, but keep the result of
# head_view_raw instead of rendering it; the result is unpacked as two parts,
# used below as HTML markup and JavaScript source.
# require_prefix is passed straight through to head_view_raw -- presumably a
# URL scheme prepended to script sources so they load outside the notebook;
# confirm against head_view_raw.
vis_html, vis_js = get_head_view(model, tokenizer, sentence_a, sentence_b, require_prefix="https:")

In [7]:
# The generated script is loaded through RequireJS, so pull it in from the CDN first.
cdn_js = '<script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.5/require.min.js"></script>'

# Wrap the generated JavaScript in its own <script> tag, placed after the loader.
js = f'{cdn_js}\n<script type="text/javascript">{vis_js}</script>'
# Full page: visualisation markup followed by the scripts.
html = f"<html>{vis_html}\n{js}</html>"
display(HTML(html))

# Also save the same page to disk so it can be opened outside the notebook.
with open("out_head.html", mode="w", encoding="utf-8") as out_file:
    out_file.write(html)