based on the previous observation about the entropia, we will study some particular layers

In [9]:
# preparation of the environment
%load_ext autoreload
%autoreload 2

INF = 1e30

import os
from os import path


# set the repository to the git repository
cwd = os.getcwd().split(os.path.sep)
while cwd[-1] != "stage_4_gm":
    os.chdir("..")
    cwd = os.getcwd().split(os.path.sep)
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_theme()
from tqdm.notebook import tqdm
from torch_set_up import DEVICE
from training_bert import BertNliLight
from regularize_training_bert import SNLIDataModule
from attention_algorithms.attention_metrics import default_plot_colormap

from transformers import BertTokenizer
tk = BertTokenizer.from_pretrained('bert-base-uncased')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [10]:
ckp = path.join(".cache", "logs", "igrida_trained", "0", "best.ckpt")
model = BertNliLight.load_from_checkpoint(ckp)
model = model.eval()

data_dir = os.path.join(".cache", "raw_data", "e_snli")

dm = SNLIDataModule(cache=data_dir,
                   batch_size = 1,
                   num_workers = 1,
                   nb_data = -1)

dm.prepare_data()

dm.setup(stage="test")

test_dataset = dm.test_set
test_dataloader = dm.test_dataloader()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## First a bit of visualisation

In [11]:
from attention_algorithms.plausibility_visu import hightlight_txt # function to highlight the text
from attention_algorithms.attention_metrics import normalize_attention
from IPython.display import display, HTML

In [12]:
def html_render(model_outputs):
    html = ''
    table_len = len(model_outputs['GROUNDTRUTH']['Entropy'])
    for i in range(table_len):
        html += '<table>'
        html += '<tr><th></th>' # One xtra head for model's name
        for column_name in model_outputs['GROUNDTRUTH'].keys():
            html+= '<th>'+ column_name +'</th>'
        html += ' </tr>'
        for name, model_content in model_outputs.items():
            html += '<tr>'
            html += '<td><b>' + name + '</b></td>'

            for k, output in model_content.items():
                displ = output[i] if output is not None else 'N/A'
                if isinstance(displ, float):
                    displ = str(round(displ, 3))
                html += '<td>' + displ + '</td>'

            html += '</tr>'

        html += '</table>'
    return html

In [16]:
INF = 1e30
with torch.no_grad():

    display(HTML('<h4>Different type of agregation (Line agregation)</h4>'))
    model_outputs = {}

    for id_batch, elem in enumerate(test_dataloader) :

        if id_batch > 5:
            # only look at 5 sentences (batch of one here)
            break

        ids = elem["input_ids"]
        masks = elem["attention_masks"]
        labels = elem["labels"]
        a_true = list(np.array(elem["annotations"][0].numpy(), dtype=float))

        special_tokens = [0, 101, 102]
        spe_tok_mask = torch.isin(ids, torch.tensor(special_tokens))[0].type(torch.uint8)

        m = masks[0].sum() # nb tokens in the sentence
        tokens = tk.convert_ids_to_tokens(ids[0])[0:m]

        it = 0

        if it == 0:
            model_outputs["GROUNDTRUTH"] = {
                '[CLS] + P + [SEP] + H + [SEP]':  [hightlight_txt(tokens = tokens,
                                                                attention = a_true[0 : m])],
                'Entropy': [0.0],
                '&#x3A3;a':[0.0]
            }

            it += 1

        output = model(input_ids = ids,
                       attention_mask = masks)

        # process the attention_tensor
        attention_tensor = torch.stack(output["outputs"].attentions, dim=1) # shape [b, l, h, T, T]
        pad = torch.tensor([0])
        pad_mask = torch.logical_not(torch.isin(ids, pad)).type(torch.uint8).unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, 12, 12, 150, 1)
        pad_mask = torch.transpose(pad_mask, dim0=3, dim1=4)
        attention_tensor = torch.mul(attention_tensor, pad_mask)



        # all the layer agregation
        a_hat = attention_tensor[0, :, :, 0:m, :].sum(dim=2).sum(dim=0).sum(dim=0)
        a_hat = a_hat[0:m]
        a_hat = torch.softmax(a_hat - INF * spe_tok_mask[0:m], dim=0)

        a_visu = normalize_attention(attention=a_hat, tokens=tokens)

        model_outputs["all agreg"] =  {
            '[CLS] + P + [SEP] + H + [SEP]': [hightlight_txt(tokens = tokens,
                                                            attention = a_visu)],
            'Entropy': [(-a_hat * torch.log(a_hat + 1e-16)).sum().item()],
            '&#x3A3;a':[a_hat.sum().item()]
        }

        # layer 4 to 10 agregation
        a_hat = attention_tensor[0, 3:10, :, :, :] # select only some layers
        a_hat = a_hat.sum(dim=2)/12 # mean head agregation
        a_hat = a_hat.sum(dim=0)/7 # mean over the layers
        a_hat = a_hat.sum(dim=0)
        a_hat = a_hat[0:m]
        a_hat = torch.softmax(a_hat - INF * spe_tok_mask[0:m], dim=0)
        a_visu = normalize_attention(attention=a_hat, tokens=tokens)

        model_outputs["layer 4 to 10"] =  {
            '[CLS] + P + [SEP] + H + [SEP]': [hightlight_txt(tokens = tokens,
                                                            attention = a_visu)],
            'Entropy': [(-a_hat * torch.log(a_hat + 1e-16)).sum().item()],
            '&#x3A3;a':[a_hat.sum().item()]
        }
        # layer 5 to 10 agregation
        a_hat = attention_tensor[0, 4:10, :, :, :] # select only some layers
        a_hat = a_hat.sum(dim=2)/12 # mean head agregation
        a_hat = a_hat.sum(dim=0)/7 # mean over the layers
        a_hat = a_hat.sum(dim=0)
        a_hat = a_hat[0:m]
        a_hat = torch.softmax(a_hat - INF * spe_tok_mask[0:m], dim=0)
        a_visu = normalize_attention(attention=a_hat, tokens=tokens)

        model_outputs["layer 5 to 10"] =  {
            '[CLS] + P + [SEP] + H + [SEP]': [hightlight_txt(tokens = tokens,
                                                            attention = a_visu)],
            'Entropy': [(-a_hat * torch.log(a_hat + 1e-16)).sum().item()],
            '&#x3A3;a':[a_hat.sum().item()]
        }


        display(HTML(html_render(model_outputs)))

Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Σa
GROUNDTRUTH,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],0.0,0.0
all agreg,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],0.0,1.0
layer 4 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],3.306,1.0
layer 5 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],3.309,1.0


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Σa
GROUNDTRUTH,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],0.0,0.0
all agreg,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],0.012,1.0
layer 4 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],3.279,1.0
layer 5 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],3.284,1.0


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Σa
GROUNDTRUTH,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],0.0,0.0
all agreg,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],0.0,1.0
layer 4 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],1.402,1.0
layer 5 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],1.447,1.0


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Σa
GROUNDTRUTH,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",0.0,0.0
all agreg,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",0.0,1.0
layer 4 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",3.074,1.0
layer 5 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",3.076,1.0


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Σa
GROUNDTRUTH,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",0.0,0.0
all agreg,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",0.002,1.0
layer 4 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",3.123,1.0
layer 5 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",3.124,1.0


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Σa
GROUNDTRUTH,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",0.0,0.0
all agreg,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",0.0,1.0
layer 4 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",3.118,1.0
layer 5 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",3.12,1.0


Maybe a regularization criterion is possible to find with this last discovery we clearly see that some information is possible to find here but we need to regularize the entropia on this special

## Now a bit a metrics to have a closer look to what happens