Based on the previous observations about the entropy, we study some particular layers.

In [55]:
# preparation of the environment
%load_ext autoreload
%autoreload 2

INF = 1e30

import os
from os import path


# set the repository to the git repository
cwd = os.getcwd().split(os.path.sep)
while cwd[-1] != "stage_4_gm":
    os.chdir("..")
    cwd = os.getcwd().split(os.path.sep)
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_theme()
from tqdm.notebook import tqdm
from torch_set_up import DEVICE
from training_bert import BertNliLight
from regularize_training_bert import SNLIDataModule
from attention_algorithms.attention_metrics import default_plot_colormap

from transformers import BertTokenizer
tk = BertTokenizer.from_pretrained('bert-base-uncased')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [56]:
# Restore the BertNliLight model from its best checkpoint and switch it to inference mode.
ckp = os.path.join(".cache", "logs", "igrida_trained", "0", "best.ckpt")
model = BertNliLight.load_from_checkpoint(ckp).eval()

# Data module backed by .cache/raw_data/e_snli, with batches of a single example so that
# each example can be inspected on its own below.
data_dir = os.path.join(".cache", "raw_data", "e_snli")
dm = SNLIDataModule(
    cache=data_dir,
    batch_size=1,
    num_workers=1,
    nb_data=-1,
)
dm.prepare_data()
dm.setup(stage="test")

test_dataset, test_dataloader = dm.test_set, dm.test_dataloader()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## First, some visualisation

In [57]:
from attention_algorithms.plausibility_visu import hightlight_txt # function to highlight the text
from attention_algorithms.attention_metrics import normalize_attention
from IPython.display import display, HTML

In [58]:
def html_render(model_outputs, ref_model='GROUNDTRUTH', ref_column='Entropy'):
    """Render per-model results as a sequence of HTML tables.

    ``model_outputs`` maps a row name (model / aggregation name) to a dict
    ``{column_name: values}`` where ``values`` is an indexable sequence with one
    entry per table, or ``None`` (rendered as ``N/A``). One ``<table>`` is built
    per index ``i``; each table has one row per entry of ``model_outputs`` and
    one header cell per column of ``model_outputs[ref_model]``.

    Cell values are inserted verbatim (no HTML escaping) because some of them
    are already HTML snippets (highlighted text).

    Args:
        model_outputs: mapping row name -> {column name -> sequence or None}.
        ref_model: row whose column names are used for the header.
        ref_column: column of ``ref_model`` whose length gives the number of tables.

    Returns:
        The concatenated HTML, or '' when ``model_outputs`` is empty.

    Raises:
        KeyError: if ``ref_model`` or ``ref_column`` is missing.
    """
    if not model_outputs:
        return ''

    def _format_cell(value):
        # Floats (including numpy floats) are rounded to 3 decimals; anything else is
        # converted with str() so ints / numpy scalars no longer break concatenation.
        if isinstance(value, (float, np.floating)):
            return str(round(float(value), 3))
        return str(value)

    header = ''.join('<th>' + column_name + '</th>' for column_name in model_outputs[ref_model])
    table_len = len(model_outputs[ref_model][ref_column])

    tables = []
    for i in range(table_len):
        rows = []
        for name, model_content in model_outputs.items():
            cells = ''.join(
                '<td>' + (_format_cell(output[i]) if output is not None else 'N/A') + '</td>'
                for output in model_content.values()
            )
            rows.append('<tr><td><b>' + name + '</b></td>' + cells + '</tr>')
        # One extra empty header cell for the model-name column.
        tables.append('<table><tr><th></th>' + header + ' </tr>' + ''.join(rows) + '</table>')
    return ''.join(tables)

In [59]:
INF = 1e30  # additive penalty: logits shifted by -INF get ~0 probability after the softmax
LABELS = ["E", "N", "C"]  # label index -> letter shown in the "Label" column
NB_EXAMPLES = 9  # number of test examples rendered (the dataloader yields batches of one)
SPECIAL_TOKENS = torch.tensor([0, 101, 102])  # ids of the specials tokens ([PAD], [CLS], [SEP])
PUNCT_TOKENS = torch.tensor(list(range(999, 1037)))  # ids of the punctuation
PAD_TOKENS = torch.tensor([0])  # id of [PAD]
TEXT_COL = '[CLS] + P + [SEP] + H + [SEP]'
# Row name -> 0-based slice over the layer axis (the row names count layers from 1).
LAYER_SPANS = {
    "all agreg": slice(None),
    "layer 4 to 10": slice(3, 10),
    "layer 5 to 10": slice(4, 10),
    "layer 1 to 10": slice(0, 10),
}


def _pool_attention(attention_tensor, layers, spe_tok_mask, m):
    """Pool a [b, l, h, T, T] attention tensor into one distribution over the first ``m`` tokens.

    Heads are averaged, the selected layers are summed, the rows are summed
    (line aggregation), and special tokens are masked out before a softmax.
    NOTE(review): because layers are *summed*, wider spans produce larger logits and
    therefore a sharper softmax, so entropies of different spans are not directly comparable.
    """
    a_hat = attention_tensor[0, layers]        # [l', h, T, T] for the single example
    a_hat = a_hat.sum(dim=1) / a_hat.shape[1]  # mean over the heads
    a_hat = a_hat.sum(dim=0)                   # sum over the selected layers
    a_hat = a_hat.sum(dim=0)[0:m]              # line agregation, real tokens only
    return torch.softmax(a_hat - INF * spe_tok_mask[0:m], dim=-1)


def _entropy(p):
    """Entropy (natural log) of a probability vector; 1e-16 guards against log(0)."""
    return (-p * torch.log(p + 1e-16)).sum().item()


with torch.no_grad():

    display(HTML('<h4>Different type of agregation (Line agregation)</h4>'))
    model_outputs = {}

    for id_batch, elem in enumerate(test_dataloader):
        if id_batch >= NB_EXAMPLES:
            break

        ids = elem["input_ids"]
        masks = elem["attention_masks"]
        label = LABELS[int(elem["labels"][0])]

        # 1 on special tokens: they are excluded from every predicted distribution
        spe_tok_mask = torch.isin(ids, SPECIAL_TOKENS)[0].type(torch.uint8)

        # we don't want the punctuation in our annotation
        punct_pos = 1 - torch.isin(ids, PUNCT_TOKENS).type(torch.uint8)
        a_true = torch.mul(elem["annotations"], punct_pos)
        a_true = list(np.array(a_true[0].numpy(), dtype=float))

        m = int(masks[0].sum())  # nb tokens in the sentence
        tokens = tk.convert_ids_to_tokens(ids[0])[0:m]

        # Every column is a one-element list: html_render indexes each column by table index.
        model_outputs = {
            "GROUNDTRUTH": {
                TEXT_COL: [hightlight_txt(tokens=tokens, attention=a_true[0:m])],
                'Entropy': [0.0],
                'Label': [label],
            }
        }

        output = model(input_ids=ids, attention_mask=masks)

        # process the attention_tensor
        attention_tensor = torch.stack(output["outputs"].attentions, dim=1)  # shape [b, l, h, T, T]
        # Zero the rows (dim 3) indexed by padding positions. Broadcasting replaces the former
        # repeat(1, 12, 12, 150, 1), which only worked when the sequence length was exactly 150.
        not_pad = torch.logical_not(torch.isin(ids, PAD_TOKENS))
        attention_tensor = attention_tensor * not_pad[:, None, None, :, None].to(attention_tensor.dtype)

        for name, layers in LAYER_SPANS.items():
            a_hat = _pool_attention(attention_tensor, layers, spe_tok_mask, m)
            a_visu = normalize_attention(attention=a_hat, tokens=tokens)
            model_outputs[name] = {
                TEXT_COL: [hightlight_txt(tokens=tokens, attention=a_visu)],
                'Entropy': [_entropy(a_hat)],
                'Label': [label],
            }

        display(HTML(html_render(model_outputs)))

Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],0.0,N
all agreg,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],0.004,N
layer 4 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],1.917,N
layer 5 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],1.872,N
layer 1 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],1.747,N


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],0.0,E
all agreg,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],0.66,E
layer 4 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],2.347,E
layer 5 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],2.634,E
layer 1 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],2.15,E


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],0.0,C
all agreg,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],0.0,C
layer 4 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],0.0,C
layer 5 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],0.0,C
layer 1 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],0.0,C


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",0.0,N
all agreg,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",0.556,N
layer 4 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",2.021,N
layer 5 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",2.019,N
layer 1 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",1.975,N


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",0.0,E
all agreg,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",0.639,E
layer 4 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",2.387,E
layer 5 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",2.289,E
layer 1 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",2.262,E


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",0.0,C
all agreg,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",0.383,C
layer 4 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",2.087,C
layer 5 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",2.127,C
layer 1 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",1.892,C


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad . [SEP],0.0,E
all agreg,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad . [SEP],0.683,E
layer 4 to 10,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad . [SEP],2.532,E
layer 5 to 10,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad . [SEP],2.662,E
layer 1 to 10,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad . [SEP],2.304,E


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad for beer . [SEP],0.0,N
all agreg,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad for beer . [SEP],0.552,N
layer 4 to 10,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad for beer . [SEP],1.74,N
layer 5 to 10,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad for beer . [SEP],1.688,N
layer 1 to 10,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad for beer . [SEP],1.841,N


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man walks by an ad . [SEP],0.0,C
all agreg,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man walks by an ad . [SEP],0.062,C
layer 4 to 10,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man walks by an ad . [SEP],2.052,C
layer 5 to 10,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man walks by an ad . [SEP],2.013,C
layer 1 to 10,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man walks by an ad . [SEP],2.014,C


This last observation suggests that a regularization criterion may be found: some information clearly appears in these layers, but we need to regularize the entropy on these specific layers.

## Now some metrics to take a closer look at what happens

Before running the following cells, execute `entropy_study_layers_study.py` from the `inference_scripts` folder.
The script was executed with the following command line:
```shell
python .\inference_scripts\entropy_study_layers_study.py --batch_size 4
```

In [50]:
def html_render(model_outputs, ref_model='all_layers', ref_column='AUC'):
    """Render per-aggregation metric results as a sequence of HTML tables.

    ``model_outputs`` maps a row name (aggregation name) to a dict
    ``{column_name: values}`` where ``values`` is an indexable sequence with one
    entry per table, or ``None`` (rendered as ``N/A``). One ``<table>`` is built
    per index ``i``; each table has one row per entry of ``model_outputs`` and
    one header cell per column of ``model_outputs[ref_model]``.

    Cell values are inserted verbatim (no HTML escaping).

    Args:
        model_outputs: mapping row name -> {column name -> sequence or None}.
        ref_model: row whose column names are used for the header.
        ref_column: column of ``ref_model`` whose length gives the number of tables.

    Returns:
        The concatenated HTML, or '' when ``model_outputs`` is empty.

    Raises:
        KeyError: if ``ref_model`` or ``ref_column`` is missing.
    """
    if not model_outputs:
        return ''

    def _format_cell(value):
        # Floats (including numpy floats) are rounded to 3 decimals; anything else is
        # converted with str() so ints / numpy scalars no longer break concatenation.
        if isinstance(value, (float, np.floating)):
            return str(round(float(value), 3))
        return str(value)

    header = ''.join('<th>' + column_name + '</th>' for column_name in model_outputs[ref_model])
    table_len = len(model_outputs[ref_model][ref_column])

    tables = []
    for i in range(table_len):
        rows = []
        for name, model_content in model_outputs.items():
            cells = ''.join(
                '<td>' + (_format_cell(output[i]) if output is not None else 'N/A') + '</td>'
                for output in model_content.values()
            )
            rows.append('<tr><td><b>' + name + '</b></td>' + cells + '</tr>')
        # One extra empty header cell for the aggregation-name column.
        tables.append('<table><tr><th></th>' + header + ' </tr>' + ''.join(rows) + '</table>')
    return ''.join(tables)

In [51]:
# load the data
import pickle
# Per-label token scores written by inference_scripts/entropy_study_layers_study.py.
# Presumably each object maps a label name ("entailement", "neutral", "contradiction")
# to a flat array of per-token scores — TODO confirm against the script.
# NOTE: pickle.load can execute arbitrary code; only load files this project produced itself.
results_dir = os.path.join(".cache", "plots", "entropy_study")  # no longer shadows builtin dir()


def _load_pickle(name):
    """Unpickle ``<results_dir>/<name>.pickle`` and return the stored object."""
    with open(os.path.join(results_dir, name + ".pickle"), "rb") as f:
        return pickle.load(f)


a_true = _load_pickle("a_true")
all_layers = _load_pickle("all_layers")
layers_1_10 = _load_pickle("layers_1_10")
layers_4_10 = _load_pickle("layers_4_10")
layers_5_10 = _load_pickle("layers_5_10")


In [52]:
def scalar_jaccard(y_true, y_pred, zero_division=float("nan")):
    """Soft (continuous) Jaccard index between two score vectors.

    Computes ``<y_true, y_pred> / (sum(y_true) + sum(y_pred) - <y_true, y_pred>)``,
    which reduces to |A ∩ B| / |A ∪ B| when both inputs are 0/1 vectors.

    Args:
        y_true: 1-D array-like of reference scores (e.g. 0/1 annotations).
        y_pred: 1-D array-like of predicted scores, same length as ``y_true``.
        zero_division: value returned when the denominator is 0 (both vectors
            all-zero); defaults to nan, the value previously produced (with a warning).

    Returns:
        The soft Jaccard index as a float.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    intersection = np.dot(y_true, y_pred)
    # Vectorized sums instead of Python's builtin sum(), which iterates element by element.
    union = y_true.sum() + y_pred.sum() - intersection
    if union == 0:
        return zero_division
    return intersection / union

In [53]:
from sklearn.metrics import precision_score, recall_score
from sklearn.metrics import auc

def _threshold_curve_auc(metric, y_true, y_pred, n_thresholds):
    """Area (sklearn ``auc``) under ``metric(y_true, y_pred >= t)`` for t on a uniform grid of [0, 1]."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)  # also accept plain lists, where `>=` would raise
    thresholds = np.linspace(0, 1, n_thresholds)
    # zero_division=0 gives the same value sklearn's default uses when nothing is
    # predicted positive, without flooding the output with UndefinedMetricWarning.
    curve = [metric(y_true, (y_pred >= t).astype(int), zero_division=0) for t in thresholds]
    return auc(x=thresholds, y=curve)


def au_precision_curve(y_true, y_pred, n_thresholds=50):
    """Area under precision as a function of the binarisation threshold in [0, 1].

    Args:
        y_true: 1-D array-like of binary reference labels.
        y_pred: 1-D array-like of scores, binarised with ``y_pred >= t``.
        n_thresholds: number of evenly spaced thresholds in [0, 1].
    """
    return _threshold_curve_auc(precision_score, y_true, y_pred, n_thresholds)


def au_recall_curve(y_true, y_pred, n_thresholds=50):
    """Area under recall as a function of the binarisation threshold in [0, 1].

    Args:
        y_true: 1-D array-like of binary reference labels.
        y_pred: 1-D array-like of scores, binarised with ``y_pred >= t``.
        n_thresholds: number of evenly spaced thresholds in [0, 1].
    """
    return _threshold_curve_auc(recall_score, y_true, y_pred, n_thresholds)


def precision(y_true, y_pred, threshold=0.2):
    """Precision of ``y_pred >= threshold`` against ``y_true`` (0.0 if nothing is predicted positive)."""
    preds = (np.asarray(y_pred) >= threshold).astype(int)
    return precision_score(y_true, preds, zero_division=0)


def recall(y_true, y_pred, threshold=0.2):
    """Recall of ``y_pred >= threshold`` against ``y_true`` (0.0 if ``y_true`` has no positive)."""
    preds = (np.asarray(y_pred) >= threshold).astype(int)
    return recall_score(y_true, preds, zero_division=0)
        


In [54]:
from sklearn.metrics import roc_auc_score
from sklearn.metrics import average_precision_score

# Metric name -> scorer(y_true, y_score); insertion order fixes the column order of the tables.
METRICS = {
    "AUC": roc_auc_score,
    "Jaccard": scalar_jaccard,
    "AUPRC": average_precision_score,
    "AU - Precision": au_precision_curve,
    "AU - Recall": au_recall_curve,
    "Precision (fixed tr)": precision,
    "Recall (fixed tr)": recall,
}
# Row name -> per-label scores loaded from the entropy-study pickles.
AGGREGATIONS = {
    "all_layers": all_layers,
    "layers_1_10": layers_1_10,
    "layers_4_10": layers_4_10,
    "layers_5_10": layers_5_10,
}

# The label keys (including the "entailement" spelling) must match the keys stored in the pickles.
for k in ["entailement", "neutral", "contradiction"]:
    display(HTML(f'<h4>metric for the label : {k}</h4>'))
    # Each metric value is wrapped in a one-element list: html_render indexes columns by table.
    metric_output = {
        agg_name: {
            metric_name: [metric(a_true[k], scores[k])]
            for metric_name, metric in METRICS.items()
        }
        for agg_name, scores in AGGREGATIONS.items()
    }
    display(HTML(html_render(metric_output)))

Unnamed: 0,AUC,Jaccard,AUPRC,AU - Precision,AU - Recall,Precision (fixed tr),Recall (fixed tr)
all_layers,0.623,0.005,0.286,0.034,0.015,0.026,0.006
layers_1_10,0.651,0.056,0.341,0.372,0.065,0.457,0.084
layers_4_10,0.661,0.059,0.361,0.373,0.067,0.494,0.087
layers_5_10,0.662,0.059,0.363,0.371,0.067,0.5,0.086


Unnamed: 0,AUC,Jaccard,AUPRC,AU - Precision,AU - Recall,Precision (fixed tr),Recall (fixed tr)
all_layers,0.692,0.014,0.148,0.042,0.029,0.041,0.025
layers_1_10,0.731,0.118,0.258,0.414,0.154,0.441,0.234
layers_4_10,0.73,0.115,0.26,0.407,0.149,0.437,0.225
layers_5_10,0.732,0.114,0.262,0.41,0.148,0.442,0.222


Unnamed: 0,AUC,Jaccard,AUPRC,AU - Precision,AU - Recall,Precision (fixed tr),Recall (fixed tr)
all_layers,0.698,0.073,0.318,0.389,0.09,0.409,0.104
layers_1_10,0.735,0.116,0.434,0.631,0.129,0.69,0.189
layers_4_10,0.741,0.116,0.447,0.631,0.128,0.704,0.189
layers_5_10,0.743,0.116,0.452,0.635,0.128,0.713,0.189
