# Qualitative inspection of BERT attention maps

In [4]:
%load_ext autoreload
%autoreload 2

from IPython.display import display, HTML
import os
from os import path

# path for the code execution in the README file
cache = path.join(os.getcwd(), '.cache' , 'logs' ,'bert','0','checkpoints')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
from training_bert import BertNliLight

In [5]:
# Load the best checkpoint and switch it to eval mode so dropout is disabled and any
# attention maps computed from the model are deterministic.
# NOTE(review): relies on BertNliLight being a torch.nn.Module (e.g. a LightningModule),
# whose .eval() returns the module itself.
model = BertNliLight.load_from_checkpoint(checkpoint_path=path.join(cache, 'best.ckpt')).eval()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [6]:
def hightlight_txt(tokens, attention, show_pad=False):
    """
    Build an HTML snippet highlighting each token in proportion to its attention weight.

    Weights are min-max normalised to [0, 1] and then scaled so the most attended
    token gets a background opacity of ``MAX_ALPHA`` (the text stays readable).

    Args:
        tokens: list of token strings.
        attention: one weight per token (list, numpy array or torch tensor).
        show_pad: whether to keep padding tokens (``[PAD]``, case-insensitive) in the output.

    Returns:
        str: space-separated ``<span>`` elements, one per displayed token
        (empty string for empty input).

    Raises:
        AssertionError: if ``tokens`` and ``attention`` differ in length.
    """
    # Local imports keep this cell independent of the later cell that imports torch.
    import html
    import torch

    assert len(tokens) == len(attention), f'Length mismatch: {len(tokens)} vs {len(attention)}'

    MAX_ALPHA = 0.8  # maximum background opacity
    PAD_TOKEN = '[pad]'  # compared lower-cased; BERT tokenizers emit '[PAD]'

    if len(tokens) == 0:
        return ''

    # Accept lists/arrays as well as tensors; float64 avoids float32 rounding noise in the CSS.
    weights = torch.as_tensor(attention, dtype=torch.float64).detach()

    # Normalisation uses every weight, including those of padding tokens hidden below.
    w_min, w_max = weights.min(), weights.max()
    if w_max > w_min:
        w_norm = ((weights - w_min) / (w_max - w_min)).tolist()
    else:
        # Uniform weights (including all zeros): highlight all text instead of dividing by zero.
        w_norm = [1.0] * len(tokens)

    pairs = [(tk, w * MAX_ALPHA) for tk, w in zip(tokens, w_norm)]

    if not show_pad:
        pairs = [(tk, alpha) for tk, alpha in pairs if str(tk).lower() != PAD_TOKEN]

    highlighted_text = [
        f'<span style="background-color:rgba(135,206,250, {alpha:.3f});">{html.escape(str(tk))}</span>'
        for tk, alpha in pairs
    ]

    return ' '.join(highlighted_text)

In [7]:
import numpy as np
import torch

# Fixed seed so the random demo weights (and the rendered example) are reproducible on re-run.
torch.manual_seed(0)

tokens = 'An older and younger man smiling.'.split(' ')
# Random softmax-normalised weights standing in for real attention scores;
# sized from the tokens so editing the sentence cannot cause a length mismatch.
attentions = torch.softmax(torch.rand(len(tokens)), dim=-1)
visual = hightlight_txt(tokens, attentions)

display(HTML('<h3>Example of attention</h3>'))
display(HTML(visual))