# Attention maps visualization

The objective is to visualize the attention maps for the different values of the regularization parameter.

In [25]:
# preparation of the environment
%load_ext autoreload
%autoreload 2

import os
from os import path
# Move the working directory up to the root of the git repository
# ("stage_4_gm") so that the relative ".cache/..." paths used below resolve.
start_dir = os.getcwd()
cwd = start_dir.split(os.path.sep)
while cwd[-1] != "stage_4_gm":
    previous_dir = os.getcwd()
    os.chdir("..")
    if os.getcwd() == previous_dir:
        # chdir("..") made no progress: the filesystem root was reached without
        # meeting the repository directory. Restore the starting directory and
        # fail loudly instead of looping forever.
        os.chdir(start_dir)
        raise RuntimeError(
            "Could not find a parent directory named 'stage_4_gm' starting from "
            + start_dir
        )
    cwd = os.getcwd().split(os.path.sep)

import pandas as pd
import numpy as np
import torch
from tqdm import tqdm
from training_bert import BertNliLight
from regularize_training_bert import SNLIDataModule
from regularize_training_bert import BertNliRegu
from torch.utils.data import DataLoader

# --> from this environment
from attention_algorithms.raw_attention import RawAttention
from attention_algorithms.attention_metrics import normalize_attention

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## The models for the study

In [26]:
# Load one trained checkpoint per regularization multiplier and keep every
# model in evaluation mode, keyed by "reg_mul=<value>".
models_dict = {}
for r in [0, 0.001, 0.002, 0.003, 0.004, 0.005, 0.08, 0.1, 0.4]:
    # r == 0 is stored under the "0" run directory and restored with
    # BertNliLight; every other value lives under "reg_mul=<r>" and is
    # restored with BertNliRegu.
    if r != 0:
        model_cls, run_dir = BertNliRegu, f"reg_mul={r}"
    else:
        model_cls, run_dir = BertNliLight, "0"

    ckp = path.join(".cache", "logs", "igrida_trained", run_dir, "best.ckpt")
    model = model_cls.load_from_checkpoint(ckp)
    models_dict[f"reg_mul={r}"] = model.eval()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.de

## The data

In [27]:
# Location of the cached e_snli data, relative to the repository root.
data_dir = path.join(".cache", "raw_data", "e_snli")

# Sentences are served one at a time (batch_size=1): the cells below treat
# every batch as a single sentence pair.
# NOTE(review): nb_data=-1 presumably means "use the whole split" - confirm
# against SNLIDataModule.
dm = SNLIDataModule(
    cache=data_dir,
    batch_size=1,
    num_workers=4,
    nb_data=-1,
)

# Prepare the data, then set up the test stage only.
dm.prepare_data()
dm.setup(stage="test")

# Handles on the test split used by the visualization cells.
test_dataset = dm.test_set
test_dataloader = dm.test_dataloader()

## The HTML table for the visualization

In [28]:
from attention_algorithms.plausibility_visu import hightlight_txt # function to highlight the text
from attention_algorithms.attention_metrics import normalize_attention
from IPython.display import display, HTML

In [29]:
def html_render(model_outputs):
    """Render one HTML comparison table per example.

    Parameters
    ----------
    model_outputs : dict
        Maps a row name to a dict ``{column name: list of values}`` holding one
        value per example. It must contain a ``'GROUNDTRUTH'`` entry: its keys
        give the column headers and the length of its ``'Entropy'`` list gives
        the number of tables. A column set to ``None`` is rendered as ``'N/A'``.

    Returns
    -------
    str
        The concatenated ``<table>`` markup. Row names, column names and cell
        values are inserted verbatim (no escaping), so they may carry HTML.

    Raises
    ------
    KeyError
        If ``'GROUNDTRUTH'`` or its ``'Entropy'`` column is missing.
    """
    reference = model_outputs['GROUNDTRUTH']
    table_len = len(reference['Entropy'])

    # The header row does not depend on the example: build it once.
    # The leading empty <th> sits above the column of row names.
    header = '<tr><th></th>'
    header += ''.join('<th>' + column_name + '</th>' for column_name in reference.keys())
    header += ' </tr>'

    parts = []
    for i in range(table_len):
        parts.append('<table>')
        parts.append(header)
        for name, model_content in model_outputs.items():
            parts.append('<tr>')
            parts.append('<td><b>' + name + '</b></td>')

            for output in model_content.values():
                displ = output[i] if output is not None else 'N/A'
                if isinstance(displ, float):
                    displ = round(displ, 3)
                # str() keeps non-float, non-str values (e.g. int) from
                # raising a TypeError in the concatenation.
                parts.append('<td>' + str(displ) + '</td>')

            parts.append('</tr>')

        parts.append('</table>')
    return ''.join(parts)

## Look at different metrics

In [30]:
# Large finite offset: the aggregation cells subtract INF * (special-token mask)
# from the scores before torch.softmax, which pushes the masked positions far
# below the others.
INF = 1e30

### Aggregation over the entire model

In [37]:
# Aggregate the attention of the "reg_mul=0" model over every leading axis and
# every query position into one score per token, using four aggregation
# schemes, and display one comparison table per sentence.
with torch.no_grad():

    display(HTML('<h4>ALL layer sum agreg</h4>'))
    model_outputs = {}
    text_col = '[CLS] + P + [SEP] + H + [SEP]'

    for id_batch, elem in enumerate(test_dataloader):

        if id_batch > 5:
            # stop after the first 6 sentences (ids 0..5, batches of one here)
            break

        ids = elem["input_ids"]
        masks = elem["attention_masks"]
        a_true = list(np.array(elem["annotations"][0].numpy(), dtype=float))

        # 1 where the token id is 0, 101 or 102, 0 elsewhere
        # (presumably [PAD], [CLS], [SEP] of bert-base-uncased - TODO confirm)
        spe_tok_mask = torch.isin(ids, torch.tensor([0, 101, 102])).type(torch.uint8)[0]

        # the attention_inst
        raw_attention_inst = RawAttention(model=models_dict["reg_mul=0"],
                                          input_ids=ids,
                                          attention_mask=masks,
                                          test_mod=False)
        tokens = raw_attention_inst.tokens
        nb_tokens = len(tokens)
        tok_mask = spe_tok_mask[0:nb_tokens]

        # attention of the single sentence of the batch
        # (presumably laid out as layer x head x query x key - TODO confirm)
        attention = raw_attention_inst.attention_tensor[0, :, :, :, :]
        # number of maps that are averaged: product of the two leading axes
        # (144 with the checkpoints used here)
        nb_maps = attention.shape[0] * attention.shape[1]

        # The ground-truth row is rebuilt for every sentence; it has no
        # entropy / sum, hence the 0.0 placeholders.
        model_outputs["GROUNDTRUTH"] = {
            text_col: [hightlight_txt(tokens=tokens,
                                      attention=a_true[0:nb_tokens])],
            'Entropy': [0.0],
            '&#x3A3;a': [0.0],
        }

        # row name -> (scores used for the metrics, scores used for the colors)
        aggregations = {}

        # softmax calculus: sum everything, then softmax with the special
        # tokens pushed to -INF
        a_hat = attention.sum(dim=0).sum(dim=0).sum(dim=0)
        a_hat = torch.softmax(a_hat - INF * tok_mask, dim=0)
        aggregations["softmax"] = (a_hat, normalize_attention(attention=a_hat, tokens=tokens))

        # mean calculus: average over axis 2, zero the special tokens, then
        # average over the two leading axes
        a_hat = attention.sum(dim=2) / nb_tokens
        a_hat = torch.mul(a_hat, 1 - tok_mask)
        a_hat = a_hat.sum(dim=0).sum(dim=0) / nb_maps
        aggregations["mean"] = (a_hat, normalize_attention(attention=a_hat, tokens=tokens))

        # mean & softmax calculus
        # NOTE: the softmax step is currently disabled, so this row is the plain
        # mean WITHOUT the special-token mask; the label is kept unchanged.
        a_hat = attention.sum(dim=2) / nb_tokens
        a_hat = a_hat.sum(dim=0).sum(dim=0) / nb_maps
        aggregations["mean & softmax"] = (a_hat, normalize_attention(attention=a_hat, tokens=tokens))

        ## the other method to evaluate the attention map (good before the visualization)
        # min - max scaler
        # NOTE: the scaled scores do not sum to 1 (see the displayed sum), so
        # the "Entropy" of this row is not the entropy of a distribution.
        a_hat = attention.sum(dim=0).sum(dim=0).sum(dim=0)
        a_hat = normalize_attention(attention=a_hat, tokens=tokens)
        aggregations["min max scaler"] = (a_hat, a_hat)

        for row_name, (a_hat, a_visu) in aggregations.items():
            model_outputs[row_name] = {
                text_col: [hightlight_txt(tokens=tokens, attention=a_visu)],
                'Entropy': [(-a_hat * torch.log(a_hat + 1e-16)).sum().item()],
                '&#x3A3;a': [a_hat.sum().item()],
            }

        display(HTML(html_render(model_outputs)))

Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Σa
GROUNDTRUTH,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],0.0,0.0
softmax,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],0.0,1.0
mean,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],1.94,0.496
mean & softmax,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],2.79,1.0
min max scaler,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],9.12,8.786


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Σa
GROUNDTRUTH,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],0.0,0.0
softmax,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],0.012,1.0
mean,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],1.924,0.497
mean & softmax,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],2.778,1.0
min max scaler,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],8.765,9.391


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Σa
GROUNDTRUTH,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],0.0,0.0
softmax,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],0.0,1.0
mean,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],2.107,0.591
mean & softmax,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],2.898,1.0
min max scaler,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],7.063,4.817


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Σa
GROUNDTRUTH,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",0.0,0.0
softmax,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",0.0,1.0
mean,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",1.658,0.441
mean & softmax,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",2.548,1.0
min max scaler,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",6.879,7.863


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Σa
GROUNDTRUTH,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",0.0,0.0
softmax,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",0.002,1.0
mean,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",1.688,0.444
mean & softmax,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",2.577,1.0
min max scaler,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",7.28,7.904


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Σa
GROUNDTRUTH,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",0.0,0.0
softmax,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",0.0,1.0
mean,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",1.717,0.456
mean & softmax,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",2.597,1.0
min max scaler,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",7.234,7.897


### CLS map

In [38]:
# Same comparison as the whole-model cell, restricted to query position 0
# (the [CLS] row of every attention map) of the "reg_mul=0" model.
with torch.no_grad():

    display(HTML('<h4>CLS agreg</h4>'))
    model_outputs = {}
    text_col = '[CLS] + P + [SEP] + H + [SEP]'

    for id_batch, elem in enumerate(test_dataloader):

        if id_batch > 5:
            # stop after the first 6 sentences (ids 0..5, batches of one here)
            break

        ids = elem["input_ids"]
        masks = elem["attention_masks"]
        a_true = list(np.array(elem["annotations"][0].numpy(), dtype=float))

        # 1 where the token id is 0, 101 or 102, 0 elsewhere
        # (presumably [PAD], [CLS], [SEP] of bert-base-uncased - TODO confirm)
        spe_tok_mask = torch.isin(ids, torch.tensor([0, 101, 102])).type(torch.uint8)[0]

        # the attention_inst
        raw_attention_inst = RawAttention(model=models_dict["reg_mul=0"],
                                          input_ids=ids,
                                          attention_mask=masks,
                                          test_mod=False)
        tokens = raw_attention_inst.tokens
        nb_tokens = len(tokens)
        tok_mask = spe_tok_mask[0:nb_tokens]

        # attention row of query position 0 for the single sentence of the
        # batch (presumably laid out as layer x head x key - TODO confirm)
        cls_attention = raw_attention_inst.attention_tensor[0, :, :, 0, :]
        # number of maps that are averaged: product of the two leading axes
        # (144 with the checkpoints used here)
        nb_maps = cls_attention.shape[0] * cls_attention.shape[1]

        # The ground-truth row is rebuilt for every sentence; it has no
        # entropy / sum, hence the 0.0 placeholders.
        model_outputs["GROUNDTRUTH"] = {
            text_col: [hightlight_txt(tokens=tokens,
                                      attention=a_true[0:nb_tokens])],
            'Entropy': [0.0],
            '&#x3A3;a': [0.0],
        }

        # row name -> (scores used for the metrics, scores used for the colors)
        aggregations = {}

        # softmax calculus: sum over the two leading axes, then softmax with
        # the special tokens pushed to -INF
        a_hat = cls_attention.sum(dim=0).sum(dim=0)
        a_hat = torch.softmax(a_hat - INF * tok_mask, dim=0)
        aggregations["softmax"] = (a_hat, normalize_attention(attention=a_hat, tokens=tokens))

        # mean calculus: zero the special tokens, then average over the two
        # leading axes
        a_hat = torch.mul(cls_attention, 1 - tok_mask)
        a_hat = a_hat.sum(dim=0).sum(dim=0) / nb_maps
        aggregations["mean"] = (a_hat, normalize_attention(attention=a_hat, tokens=tokens))

        # mean & softmax calculus
        # NOTE: the softmax step is currently disabled, so this row is the plain
        # mean WITHOUT the special-token mask; the label is kept unchanged.
        a_hat = cls_attention.sum(dim=0).sum(dim=0) / nb_maps
        aggregations["mean & softmax"] = (a_hat, normalize_attention(attention=a_hat, tokens=tokens))

        ## the other method to evaluate the attention map (good before the visualization)
        # min - max scaler
        # NOTE: the scaled scores do not sum to 1 (see the displayed sum), so
        # the "Entropy" of this row is not the entropy of a distribution.
        a_hat = cls_attention.sum(dim=0).sum(dim=0)
        a_hat = normalize_attention(attention=a_hat, tokens=tokens)
        aggregations["min max scaler"] = (a_hat, a_hat)

        for row_name, (a_hat, a_visu) in aggregations.items():
            model_outputs[row_name] = {
                text_col: [hightlight_txt(tokens=tokens, attention=a_visu)],
                'Entropy': [(-a_hat * torch.log(a_hat + 1e-16)).sum().item()],
                '&#x3A3;a': [a_hat.sum().item()],
            }

        display(HTML(html_render(model_outputs)))

Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Σa
GROUNDTRUTH,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],0.0,0.0
softmax,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],0.411,1.0
mean,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],1.77,0.495
mean & softmax,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],2.65,1.0
min max scaler,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],6.441,5.574


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Σa
GROUNDTRUTH,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],0.0,0.0
softmax,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],0.479,1.0
mean,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],1.865,0.514
mean & softmax,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],2.734,1.0
min max scaler,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],7.113,6.704


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Σa
GROUNDTRUTH,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],0.0,0.0
softmax,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],0.018,1.0
mean,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],2.014,0.589
mean & softmax,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],2.825,1.0
min max scaler,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],6.437,5.154


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Σa
GROUNDTRUTH,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",0.0,0.0
softmax,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",0.02,1.0
mean,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",1.535,0.473
mean & softmax,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",2.436,1.0
min max scaler,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",4.425,3.924


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Σa
GROUNDTRUTH,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",0.0,0.0
softmax,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",0.706,1.0
mean,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",1.609,0.468
mean & softmax,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",2.512,1.0
min max scaler,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",5.339,5.611


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Σa
GROUNDTRUTH,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",0.0,0.0
softmax,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",0.981,1.0
mean,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",1.639,0.491
mean & softmax,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",2.528,1.0
min max scaler,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",5.007,5.979
