In [16]:
# preparation of the environment
%load_ext autoreload
%autoreload 2

import os
from itertools import islice
from os import path
# Move the working directory up to the root of the git repository ("stage_4_gm")
# so that the relative ".cache/..." paths used by later cells resolve.
# Parents are walked lexically and checked before any chdir: the former
# `os.chdir("..")` loop spun forever when no such ancestor existed, because
# ".." of the filesystem root is the root itself.
REPO_DIR_NAME = "stage_4_gm"
_repo_root = os.getcwd()
while path.basename(_repo_root) != REPO_DIR_NAME:
    _parent = path.dirname(_repo_root)
    if _parent == _repo_root:
        raise FileNotFoundError(
            "no parent directory named {!r} above {!r}".format(REPO_DIR_NAME, os.getcwd()))
    _repo_root = _parent
os.chdir(_repo_root)
# path components of the new working directory (same value the old loop left behind)
cwd = os.getcwd().split(os.path.sep)

import pandas as pd
import numpy as np
import torch
from tqdm import tqdm
from training_bert import BertNliLight
from regularize_training_bert import SNLIDataModule
from regularize_training_bert import BertNliRegu
from torch.utils.data import DataLoader

from transformers import BertTokenizer
# WordPiece tokenizer of the bert-base-uncased backbone; later cells use it
# to turn input ids back into readable tokens for the highlighted tables.
tk = BertTokenizer.from_pretrained(
    'bert-base-uncased'
)

# --> from this environment
from attention_algorithms.raw_attention import RawAttention
from attention_algorithms.attention_metrics import normalize_attention

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [17]:
from attention_algorithms.plausibility_visu import hightlight_txt # function to highlight the text
from attention_algorithms.attention_metrics import normalize_attention
from IPython.display import display, HTML

In [18]:
ckp = path.join(".cache", "logs", "igrida_trained", "0", "best.ckpt")
model = BertNliLight.load_from_checkpoint(ckp)
model = model.eval()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [19]:
def html_render(model_outputs, reference='GROUNDTRUTH'):
    """Render a per-model comparison as HTML tables.

    One ``<table>`` is produced per row index ``i``. Its header holds an empty
    cell (above the model names) followed by the column names of the reference
    entry. Each model then contributes one row made of the ``i``-th value of
    every header column.

    Args:
        model_outputs: mapping ``model name -> {column name -> list of values}``.
            A column may be ``None`` (rendered as ``'N/A'``). Values are inserted
            as raw HTML, so pre-rendered markup (highlighted text) and entities
            such as ``&#x3A3;a`` in column names render as intended.
        reference: key of the entry whose columns define the header and whose
            column length defines the number of tables. When the key is absent,
            the first entry of ``model_outputs`` is used instead.

    Returns:
        The concatenated HTML string, or ``''`` when ``model_outputs`` is empty.
    """
    if not model_outputs:
        return ''

    ref_content = model_outputs.get(reference)
    if ref_content is None:
        ref_content = next(iter(model_outputs.values()))
    columns = list(ref_content.keys())
    # Number of tables = length of the first non-None reference column. The
    # previous version hard-coded the 'Entropy' column and raised KeyError
    # without it.
    table_len = next((len(values) for values in ref_content.values() if values is not None), 0)

    def _cell(values, i):
        # A missing/None column or a too-short list renders as 'N/A'.
        # Floats are rounded to 3 digits; anything else is stringified
        # (previously a non-str, non-float value raised TypeError).
        if values is None or i >= len(values):
            return 'N/A'
        value = values[i]
        if isinstance(value, float):
            return str(round(value, 3))
        return str(value)

    parts = []
    for i in range(table_len):
        parts.append('<table>')
        parts.append('<tr><th></th>')  # one extra header cell for the model's name
        parts.extend('<th>' + str(col) + '</th>' for col in columns)
        parts.append(' </tr>')
        for name, model_content in model_outputs.items():
            parts.append('<tr><td><b>' + str(name) + '</b></td>')
            # Look cells up in header order so every row lines up with the
            # header, even if a model's dict was built with a different key order.
            parts.extend('<td>' + _cell(model_content.get(col), i) + '</td>' for col in columns)
            parts.append('</tr>')
        parts.append('</table>')
    return ''.join(parts)

In [20]:
data_dir = os.path.join(".cache", "raw_data", "e_snli")

dm = SNLIDataModule(cache=data_dir,
                   batch_size = 1,
                   num_workers = 4,
                   nb_data = -1 # multiple of three for the consistency
                   )

dm.prepare_data()

dm.setup(stage="test")

test_dataset = dm.test_set
test_dataloader = dm.test_dataloader()

## Aggregation over the layers

In [21]:
INF = 1e30
with torch.no_grad():

    display(HTML('<h4>Different type of agregation (Line agregation)</h4>'))
    model_outputs = {}

    for id_batch, elem in enumerate(test_dataloader) :

        if id_batch > 5:
            # only look at 5 sentences (batch of one here)
            break

        ids = elem["input_ids"]
        masks = elem["attention_masks"]
        labels = elem["labels"]
        a_true = list(np.array(elem["annotations"][0].numpy(), dtype=float))

        special_tokens = list(range(999, 1014)) + [0, 101, 102]
        spe_tok_mask = torch.isin(ids, torch.tensor(special_tokens))[0].type(torch.uint8)

        m = masks[0].sum() # nb tokens in the sentence
        tokens = tk.convert_ids_to_tokens(ids[0])[0:m]

        it = 0

        if it == 0:
            model_outputs["GROUNDTRUTH"] = {
                '[CLS] + P + [SEP] + H + [SEP]':  [hightlight_txt(tokens = tokens,
                                                                attention = a_true[0 : m])],
                'Entropy': [0.0],
                '&#x3A3;a':[0.0]
            }

            it += 1

        output = model(input_ids = ids,
                       attention_mask = masks)
        attention_tensor = torch.stack(output["outputs"].attentions, dim=1)

        # softmax calculus
        a_hat = attention_tensor[0, :, :, 0:m, :].sum(dim=2).sum(dim=0).sum(dim=0)
        a_hat = a_hat[0:m]
        a_hat = torch.softmax(a_hat - INF * spe_tok_mask[0:m], dim=0)

        a_visu = normalize_attention(attention=a_hat, tokens=tokens)

        model_outputs["softmax"] =  {
            '[CLS] + P + [SEP] + H + [SEP]': [hightlight_txt(tokens = tokens,
                                                            attention = a_visu)],
            'Entropy': [(-a_hat * torch.log(a_hat + 1e-16)).sum().item()],
            '&#x3A3;a':[a_hat.sum().item()]
        }

        # min max calculus
        a_hat = attention_tensor[0, :, :, 0:m, :].sum(dim=2).sum(dim=0).sum(dim=0)
        a_hat = a_hat[0:m]
        a_hat = normalize_attention(attention=a_hat, tokens=tokens)
        a_visu = normalize_attention(attention=a_hat, tokens=tokens)

        model_outputs["min-max"] =  {
            '[CLS] + P + [SEP] + H + [SEP]': [hightlight_txt(tokens = tokens,
                                                            attention = a_visu)],
            'Entropy': [(-a_hat * torch.log(a_hat + 1e-16)).sum().item()],
            '&#x3A3;a':[a_hat.sum().item()]
        }

        # mean calculus
        a_hat = attention_tensor[0, :, :, 0:m, :].mean(dim=2).sum(dim=1).sum(dim=0)/144
        a_hat = a_hat[0:m]
        a_hat = torch.mul(a_hat, 1 - spe_tok_mask[0:m])
        a_visu = normalize_attention(attention=a_hat, tokens=tokens)

        model_outputs["mean"] =  {
            '[CLS] + P + [SEP] + H + [SEP]': [hightlight_txt(tokens = tokens,
                                                            attention = a_visu)],
            'Entropy': [(-a_hat * torch.log(a_hat + 1e-16)).sum().item()],
            '&#x3A3;a':[a_hat.sum().item()]
        }


        raw_attention_inst = RawAttention(model=model,
                                          input_ids=ids,
                                          attention_mask=masks,
                                          test_mod=False,
                                          test=None)

        a_hat = raw_attention_inst.attention_tensor[0, :, :, 0:m, :].sum(dim=2).sum(dim=0).sum(dim=0)
        a_hat = a_hat[0:m]
        a_hat = torch.softmax(a_hat - INF * spe_tok_mask[0:m], dim=0)
        a_visu = normalize_attention(attention=a_hat, tokens=tokens)

        model_outputs["softmax raw"] =  {
            '[CLS] + P + [SEP] + H + [SEP]': [hightlight_txt(tokens = tokens,
                                                            attention = a_visu)],
            'Entropy': [(-a_hat * torch.log(a_hat + 1e-16)).sum().item()],
            '&#x3A3;a':[a_hat.sum().item()]
        }


        display(HTML(html_render(model_outputs)))

Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Σa
GROUNDTRUTH,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],0.0,0.0
softmax,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],0.0,1.0
min-max,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],8.518,12.062
mean,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],1.666,0.407
softmax raw,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],0.0,1.0


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Σa
GROUNDTRUTH,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],0.0,0.0
softmax,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],0.0,1.0
min-max,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],6.805,15.551
mean,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],1.616,0.392
softmax raw,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],0.0,1.0


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Σa
GROUNDTRUTH,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],0.0,0.0
softmax,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],0.0,1.0
min-max,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],6.332,4.161
mean,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],1.848,0.511
softmax raw,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],0.0,1.0


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Σa
GROUNDTRUTH,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",0.0,0.0
softmax,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",0.0,1.0
min-max,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",6.258,8.413
mean,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",1.278,0.318
softmax raw,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",0.0,1.0


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Σa
GROUNDTRUTH,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",0.0,0.0
softmax,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",0.0,1.0
min-max,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",5.932,11.305
mean,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",1.307,0.319
softmax raw,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",0.0,1.0


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Σa
GROUNDTRUTH,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",0.0,0.0
softmax,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",0.0,1.0
min-max,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",6.235,10.086
mean,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",1.339,0.332
softmax raw,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",0.0,1.0


## The CLS Map

In [22]:
# Same comparison as the line aggregation, restricted to the attention row of
# the [CLS] query token (index 0) in every (layer, head) map.
INF = 1e30
N_EXAMPLES = 6  # test examples to display (the dataloader yields batches of one)
TEXT_COL = '[CLS] + P + [SEP] + H + [SEP]'
# Token ids excluded from the softmax. NOTE(review): 0/101/102 are presumably
# [PAD]/[CLS]/[SEP] and 999..1013 the ASCII punctuation ids of
# bert-base-uncased — confirm against the vocabulary.
SPECIAL_TOKENS = torch.tensor(list(range(999, 1014)) + [0, 101, 102])

with torch.no_grad():

    display(HTML('<h4>Different type of agregation (CLS)</h4>'))
    model_outputs = {}

    for elem in islice(test_dataloader, N_EXAMPLES):

        ids = elem["input_ids"]
        masks = elem["attention_masks"]
        a_true = list(np.array(elem["annotations"][0].numpy(), dtype=float))

        m = int(masks[0].sum())  # number of non-padding tokens in the sentence pair
        tokens = tk.convert_ids_to_tokens(ids[0])[0:m]
        # 1 for special/punctuation tokens, 0 elsewhere (restricted to the m real tokens)
        spe_tok_mask = torch.isin(ids, SPECIAL_TOKENS)[0].type(torch.uint8)[0:m]

        # human annotation row, rebuilt for every example
        model_outputs["GROUNDTRUTH"] = {
            TEXT_COL: [hightlight_txt(tokens=tokens, attention=a_true[0:m])],
            'Entropy': [0.0],
            '&#x3A3;a': [0.0],
        }

        output = model(input_ids=ids, attention_mask=masks)
        # (batch, layers, heads, queries, keys), assuming HF-style per-layer attentions — TODO confirm
        attention_tensor = torch.stack(output["outputs"].attentions, dim=1)
        # layers x heads: 144 for BERT-base, the value that used to be hard-coded
        n_maps = attention_tensor.shape[1] * attention_tensor.shape[2]
        # attention paid by the [CLS] query to each real token, summed over all maps
        cls_summed = attention_tensor[0, :, :, 0, 0:m].sum(dim=0).sum(dim=0)

        # name -> (scores used for Entropy / sum, scores used for the highlight)
        aggregations = {}

        a_hat = torch.softmax(cls_summed - INF * spe_tok_mask, dim=0)
        aggregations["softmax"] = (a_hat, normalize_attention(attention=a_hat, tokens=tokens))

        # NOTE(review): unlike the line aggregation, special tokens are NOT masked
        # here, so the mean row sums to 1 over all tokens — confirm this is intended.
        a_hat = cls_summed / n_maps
        aggregations["mean"] = (a_hat, normalize_attention(attention=a_hat, tokens=tokens))

        a_hat = normalize_attention(attention=cls_summed, tokens=tokens)
        aggregations["min-max"] = (a_hat, normalize_attention(attention=a_hat, tokens=tokens))

        # one table row per aggregation (insertion order = display order)
        for name, (a_hat, a_visu) in aggregations.items():
            model_outputs[name] = {
                TEXT_COL: [hightlight_txt(tokens=tokens, attention=a_visu)],
                'Entropy': [(-a_hat * torch.log(a_hat + 1e-16)).sum().item()],
                '&#x3A3;a': [a_hat.sum().item()],
            }

        display(HTML(html_render(model_outputs)))

Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Σa
GROUNDTRUTH,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],0.0,0.0
softmax,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],0.098,1.0
mean,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],2.65,1.0
min-max,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],5.917,4.402


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Σa
GROUNDTRUTH,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],0.0,0.0
softmax,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],0.497,1.0
mean,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],2.734,1.0
min-max,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],7.243,6.469


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Σa
GROUNDTRUTH,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],0.0,0.0
softmax,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],0.005,1.0
mean,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],2.825,1.0
min-max,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],5.771,4.135


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Σa
GROUNDTRUTH,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",0.0,0.0
softmax,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",0.001,1.0
mean,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",2.436,1.0
min-max,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",3.566,2.733


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Σa
GROUNDTRUTH,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",0.0,0.0
softmax,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",0.026,1.0
mean,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",2.512,1.0
min-max,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",4.734,3.88


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Σa
GROUNDTRUTH,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",0.0,0.0
softmax,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",0.693,1.0
mean,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",2.528,1.0
min-max,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",4.431,4.564
