based on the previous observation about the entropia, we will study some particular layers

In [1]:
# preparation of the environment
%load_ext autoreload
%autoreload 2

INF = 1e30

import os
from os import path


# set the repository to the git repository
cwd = os.getcwd().split(os.path.sep)
while cwd[-1] != "stage_4_gm":
    os.chdir("..")
    cwd = os.getcwd().split(os.path.sep)
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_theme()
from tqdm.notebook import tqdm
from torch_set_up import DEVICE
from training_bert import BertNliLight
from regularize_training_bert import SNLIDataModule
from attention_algorithms.attention_metrics import default_plot_colormap

from transformers import BertTokenizer
tk = BertTokenizer.from_pretrained('bert-base-uncased')

In [2]:
ckp = path.join(".cache", "logs", "igrida_trained", "0", "best.ckpt")
model = BertNliLight.load_from_checkpoint(ckp)
model = model.eval()

data_dir = os.path.join(".cache", "raw_data", "e_snli")

dm = SNLIDataModule(cache=data_dir,
                   batch_size = 1,
                   num_workers = 1,
                   nb_data = -1)

dm.prepare_data()

dm.setup(stage="test")

test_dataset = dm.test_set
test_dataloader = dm.test_dataloader()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## First a bit of visualisation

In [3]:
from attention_algorithms.plausibility_visu import hightlight_txt # function to highlight the text
from attention_algorithms.attention_metrics import normalize_attention
from IPython.display import display, HTML

In [4]:
def html_render(model_outputs):
    html = ''
    table_len = len(model_outputs['GROUNDTRUTH']['Entropy'])
    for i in range(table_len):
        html += '<table>'
        html += '<tr><th></th>' # One xtra head for model's name
        for column_name in model_outputs['GROUNDTRUTH'].keys():
            html+= '<th>'+ column_name +'</th>'
        html += ' </tr>'
        for name, model_content in model_outputs.items():
            html += '<tr>'
            html += '<td><b>' + name + '</b></td>'

            for k, output in model_content.items():
                displ = output[i] if output is not None else 'N/A'
                if isinstance(displ, float):
                    displ = str(round(displ, 3))
                html += '<td>' + displ + '</td>'

            html += '</tr>'

        html += '</table>'
    return html

In [6]:
INF = 1e30
with torch.no_grad():

    display(HTML('<h4>Different type of agregation (Line agregation)</h4>'))
    model_outputs = {}

    for id_batch, elem in enumerate(test_dataloader) :

        if id_batch > 5:
            # only look at 5 sentences (batch of one here)
            break

        ids = elem["input_ids"]
        masks = elem["attention_masks"]
        labels = elem["labels"]
        a_true = list(np.array(elem["annotations"][0].numpy(), dtype=float))

        special_tokens = [0, 101, 102]
        spe_tok_mask = torch.isin(ids, torch.tensor(special_tokens))[0].type(torch.uint8)

        m = masks[0].sum() # nb tokens in the sentence
        tokens = tk.convert_ids_to_tokens(ids[0])[0:m]

        it = 0

        if it == 0:
            model_outputs["GROUNDTRUTH"] = {
                '[CLS] + P + [SEP] + H + [SEP]':  [hightlight_txt(tokens = tokens,
                                                                attention = a_true[0 : m])],
                'Entropy': [0.0],
                '&#x3A3;a':[0.0]
            }

            it += 1

        output = model(input_ids = ids,
                       attention_mask = masks)

        # process the attention_tensor
        attention_tensor = torch.stack(output["outputs"].attentions, dim=1) # shape [b, l, h, T, T]
        pad = torch.tensor([0])
        pad_mask = torch.logical_not(torch.isin(ids, pad)).type(torch.uint8).unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, 12, 12, 150, 1)
        pad_mask = torch.transpose(pad_mask, dim0=3, dim1=4)
        attention_tensor = torch.mul(attention_tensor, pad_mask)

        # all the layer agregation
        a_hat = attention_tensor[0, :, :, :, :]
        a_hat = a_hat.sum(dim=1)/12 # mean over the heads
        a_hat = a_hat.sum(dim=0) # mean over the layers
        a_hat = a_hat.sum(dim=0) # line agregation
        a_hat = a_hat[0:m]
        a_hat = torch.softmax(a_hat - INF * spe_tok_mask[0:m], dim=0)

        a_visu = normalize_attention(attention=a_hat, tokens=tokens)

        model_outputs["all agreg"] =  {
            '[CLS] + P + [SEP] + H + [SEP]': [hightlight_txt(tokens = tokens,
                                                            attention = a_hat)],
            'Entropy': [(-a_hat * torch.log(a_hat + 1e-16)).sum().item()],
            '&#x3A3;a':[a_hat.sum().item()]
        }

        # layer 4 to 10 agregation
        a_hat = attention_tensor[0, 3:10, :, :, :] # select only some layers
        a_hat = a_hat.sum(dim=1)/12 # mean head agregation
        a_hat = a_hat.sum(dim=0) # mean over the layers
        a_hat = a_hat.sum(dim=0) # line agregation
        a_hat = a_hat[0:m]
        a_hat = torch.softmax(a_hat - INF * spe_tok_mask[0:m], dim=0)
        a_visu = normalize_attention(attention=a_hat, tokens=tokens)

        model_outputs["layer 4 to 10"] =  {
            '[CLS] + P + [SEP] + H + [SEP]': [hightlight_txt(tokens = tokens,
                                                            attention = a_visu)],
            'Entropy': [(-a_hat * torch.log(a_hat + 1e-16)).sum().item()],
            '&#x3A3;a':[a_hat.sum().item()]
        }
        # layer 5 to 10 agregation
        a_hat = attention_tensor[0, 4:10, :, :, :] # select only some layers
        a_hat = a_hat.sum(dim=1)/12 # mean head agregation
        a_hat = a_hat.sum(dim=0) # mean over the layers
        a_hat = a_hat.sum(dim=0) # line agregation
        a_hat = a_hat[0:m]
        a_hat = torch.softmax(a_hat - INF * spe_tok_mask[0:m], dim=-1)
        a_visu = normalize_attention(attention=a_hat, tokens=tokens)

        model_outputs["layer 5 to 10"] =  {
            '[CLS] + P + [SEP] + H + [SEP]': [hightlight_txt(tokens = tokens,
                                                            attention = a_visu)],
            'Entropy': [(-a_hat * torch.log(a_hat + 1e-16)).sum().item()],
            '&#x3A3;a':[a_hat.sum().item()]
        }

        # layer 0 to 10 agregation
        a_hat = attention_tensor[0, 0:10, :, :, :] # select only some layers
        a_hat = a_hat.sum(dim=1)/12 # mean head agregation
        a_hat = a_hat.sum(dim=0) # mean over the layers
        a_hat = a_hat.sum(dim=0)
        a_hat = a_hat[0:m]
        a_hat = torch.softmax(a_hat - INF * spe_tok_mask[0:m], dim=-1)
        a_visu = normalize_attention(attention=a_hat, tokens=tokens)

        model_outputs["layer 0 to 10"] =  {
            '[CLS] + P + [SEP] + H + [SEP]': [hightlight_txt(tokens = tokens,
                                                            attention = a_visu)],
            'Entropy': [(-a_hat * torch.log(a_hat + 1e-16)).sum().item()],
            '&#x3A3;a':[a_hat.sum().item()]
        }


        display(HTML(html_render(model_outputs)))

Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Σa
GROUNDTRUTH,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],0.0,0.0
all agreg,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],0.004,1.0
layer 4 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],1.917,1.0
layer 5 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],1.872,1.0
layer 0 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],1.747,1.0


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Σa
GROUNDTRUTH,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],0.0,0.0
all agreg,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],0.66,1.0
layer 4 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],2.347,1.0
layer 5 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],2.634,1.0
layer 0 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],2.15,1.0


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Σa
GROUNDTRUTH,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],0.0,0.0
all agreg,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],0.0,1.0
layer 4 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],0.0,1.0
layer 5 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],0.0,1.0
layer 0 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],0.0,1.0


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Σa
GROUNDTRUTH,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",0.0,0.0
all agreg,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",0.556,1.0
layer 4 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",2.021,1.0
layer 5 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",2.019,1.0
layer 0 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",1.975,1.0


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Σa
GROUNDTRUTH,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",0.0,0.0
all agreg,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",0.639,1.0
layer 4 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",2.387,1.0
layer 5 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",2.289,1.0
layer 0 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",2.262,1.0


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Σa
GROUNDTRUTH,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",0.0,0.0
all agreg,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",0.383,1.0
layer 4 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",2.087,1.0
layer 5 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",2.127,1.0
layer 0 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",1.892,1.0


Maybe a regularization criterion is possible to find with this last discovery we clearly see that some information is possible to find here but we need to regularize the entropia on this special

## Now a bit a metrics to have a closer look to what happens

before running the following cells we need to execute `entropy_study_layers_study.py` in the folder `inference_scripts`
the script was executed with the following command line :
```{command line}
python .\inference_scripts\entropy_study_layers_study.py --batch_size 4
```

In [8]:
def html_render(model_outputs):
    html = ''

    table_len = len(model_outputs['all_layers']['AUC'])
    for i in range(table_len):
        html += '<table>'
        html += '<tr><th></th>' # One xtra head for model's name
        for column_name in model_outputs['all_layers'].keys():
            html+= '<th>'+ column_name +'</th>'
        html += ' </tr>'
        for name, model_content in model_outputs.items():
            html += '<tr>'
            html += '<td><b>' + name + '</b></td>'

            for k, output in model_content.items():
                displ = output[i] if output is not None else 'N/A'
                if isinstance(displ, float):
                    displ = str(round(displ, 3))
                html += '<td>' + displ + '</td>'

            html += '</tr>'

        html += '</table>'
    return html

In [12]:
# load the data
import pickle
dir = os.path.join(".cache", "plots", "entropy_study")
with open(os.path.join(dir, "a_true.pickle"), "rb") as f:
    a_true = pickle.load(f)
with open(os.path.join(dir, "all_layers.pickle"), "rb") as f:
    all_layers = pickle.load(f)
with open(os.path.join(dir, "layers_1_10.pickle"), "rb") as f:
    layers_1_10 = pickle.load(f)
with open(os.path.join(dir, "layers_4_10.pickle"), "rb") as f:
    layers_4_10 = pickle.load(f)
with open(os.path.join(dir, "layers_5_10.pickle"), "rb") as f:
    layers_5_10 = pickle.load(f)


In [13]:
def scalar_jaccard(y_true, y_pred):
    num = np.dot(y_true, y_pred)
    den = sum(y_true)+sum(y_pred)-np.dot(y_true, y_pred)
    return num/den

In [None]:
from sklearn.metrics import precision_score
from sklearn.metrics import auc

def au_precision_curve(y_true, y_pred):
    tr = np.linspace(0, 1, 150)
    curve = []
    for t in tr:
        preds = 1*(y_pred>=tr)
        curve.append(y_true, preds)
        
    return auc(x=tr, y=curve)
        


In [14]:
from sklearn.metrics import roc_auc_score
from sklearn.metrics import average_precision_score

with torch.no_grad():
    for k in ["entailement", "neutral", "contradiction"]:
        display(HTML(f'<h4>metric for the label : {k}</h4>'))
        metric_output = {}
        metric_output["all_layers"] = {
            "AUC": [roc_auc_score(a_true[k],all_layers[k])],
            "Jaccard": [scalar_jaccard(a_true[k], all_layers[k])],
            "AUPRC" : [average_precision_score(a_true[k], all_layers[k])],
            "AU - Precision" : [au_precision_curve(a_true[k], all_layers[k])]
        }

        metric_output["layers_1_10"] = {
            "AUC": [roc_auc_score(a_true[k],layers_1_10[k])],
            "Jaccard": [scalar_jaccard(a_true[k], layers_1_10[k])],
            "AUPRC" : [average_precision_score(a_true[k], layers_1_10[k])],
            "AU - Precision" : [au_precision_curve(a_true[k], layers_1_10[k])]
        }

        metric_output["layers_4_10"] = {
            "AUC": [roc_auc_score(a_true[k],layers_4_10[k])],
            "Jaccard": [scalar_jaccard(a_true[k], layers_4_10[k])],
            "AUPRC" : [average_precision_score(a_true[k], layers_4_10[k])],
            "AU - Precision" : [au_precision_curve(a_true[k], layers_1_10[k])]
        }

        metric_output["layers_5_10"] = {
            "AUC": [roc_auc_score(a_true[k],layers_5_10[k])],
            "Jaccard": [scalar_jaccard(a_true[k], layers_5_10[k])],
            "AUPRC" : [average_precision_score(a_true[k], layers_5_10[k])]
        }

        display(HTML(html_render(metric_output)))

Unnamed: 0,AUC,Jaccard,AUPRC
all_layers,0.956,0.061,0.416
layers_1_10,0.956,0.056,0.412
layers_4_10,0.956,0.055,0.409
layers_5_10,0.955,0.055,0.408


Unnamed: 0,AUC,Jaccard,AUPRC
all_layers,0.955,0.087,0.251
layers_1_10,0.953,0.104,0.265
layers_4_10,0.951,0.1,0.256
layers_5_10,0.95,0.099,0.257


Unnamed: 0,AUC,Jaccard,AUPRC
all_layers,0.959,0.102,0.424
layers_1_10,0.961,0.102,0.453
layers_4_10,0.959,0.102,0.451
layers_5_10,0.959,0.102,0.452
