In [1]:
# preparation of the environment
# Load the IPython `autoreload` extension; mode 2 re-imports modules before
# each cell execution so that edits to the project's .py files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Large finite constant: further down it is multiplied by a 0/1 mask and
# subtracted from the scores before a softmax, so that masked positions get
# a (numerically) zero probability.
INF = 1e30

import os
from os import path


# Move the working directory up to the root of the git repository (the folder
# named "stage_4_gm") so that the project imports and the relative ".cache"
# paths used below resolve correctly.
start_dir = os.getcwd()
cwd = start_dir.split(os.path.sep)
while cwd[-1] != "stage_4_gm":
    previous_dir = os.getcwd()
    os.chdir("..")
    if os.getcwd() == previous_dir:
        # chdir("..") is a no-op at the filesystem root: the repository is not
        # among the parents. Restore the starting directory and fail loudly
        # instead of looping forever.
        os.chdir(start_dir)
        raise FileNotFoundError(
            "could not find a 'stage_4_gm' directory among the parents of " + start_dir
        )
    cwd = os.getcwd().split(os.path.sep)
import numpy as np
import torch
from regularize_training_bert import BertNliRegu
from regularize_training_bert import SNLIDataModule

from transformers import BertTokenizer
tk = BertTokenizer.from_pretrained('bert-base-uncased')

In [2]:
%%capture
d = os.path.join(".cache", "logs", "igrida_trained", "regu_study", "layer_4_10")
muls = os.listdir(d)
models_dict = {}
for mul in muls:
    ckp = path.join(d, mul, "checkpoints", "best.ckpt")
    model = BertNliRegu.load_from_checkpoint(ckp)
    models_dict[mul] = model.eval() 

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predic

In [3]:
# Folder holding the cached e-SNLI data, relative to the repository root.
data_dir = path.join(".cache", "raw_data", "e_snli")

# One example per batch: the visualisation cell below only reads index 0 of
# every batch tensor.
# NOTE(review): nb_data = -1 presumably means "use the whole split" - confirm
# against SNLIDataModule.
dm = SNLIDataModule(
    cache=data_dir,
    batch_size=1,
    num_workers=1,
    nb_data=-1,
)

# Prepare the data, then build the test split only.
dm.prepare_data()
dm.setup(stage="test")

# Handles on the test split used by the following cells.
test_dataset = dm.test_set
test_dataloader = dm.test_dataloader()

In [4]:
from attention_algorithms.plausibility_visu import hightlight_txt # function to highlight the text
from attention_algorithms.attention_metrics import normalize_attention
from IPython.display import display, HTML

def _format_cell(output, i):
    """Return the text shown in one table cell for the `i`-th table."""
    if output is None:
        return 'N/A'
    if isinstance(output, (str, bytes, int, float, bool)):
        # A scalar (e.g. the 'Label' string) is shared by every table; it must
        # not be indexed, otherwise a string would be cut down to one character.
        value = output
    else:
        value = output[i]
    if isinstance(value, float):
        value = round(value, 3)
    # str() so that ints or other objects do not break the concatenation.
    return str(value)


def html_render(model_outputs):
    """Render the outputs of several models as a sequence of HTML tables.

    Args:
        model_outputs: mapping ``row name -> {column name -> value}``. It must
            contain a ``'GROUNDTRUTH'`` row whose keys give the column headers
            and whose ``'Entropy'`` entry is a sequence; its length is the
            number of tables produced. A column value is either a sequence
            (entry ``i`` is shown in table ``i``), a scalar shown as-is in
            every table, or ``None`` (shown as ``'N/A'``). Floats are rounded
            to 3 decimals. Cell contents are inserted without escaping, so
            they may themselves contain HTML markup.

    Returns:
        str: the concatenated ``<table>`` elements.

    Raises:
        KeyError: if ``'GROUNDTRUTH'`` or its ``'Entropy'`` entry is missing.
    """
    reference = model_outputs['GROUNDTRUTH']
    table_len = len(reference['Entropy'])

    # Collect the fragments and join once instead of growing a string.
    parts = []
    for i in range(table_len):
        parts.append('<table>')
        parts.append('<tr><th></th>')  # one extra header cell for the model's name
        for column_name in reference.keys():
            parts.append('<th>' + str(column_name) + '</th>')
        parts.append(' </tr>')

        for name, model_content in model_outputs.items():
            parts.append('<tr>')
            parts.append('<td><b>' + str(name) + '</b></td>')
            for output in model_content.values():
                parts.append('<td>' + _format_cell(output, i) + '</td>')
            parts.append('</tr>')

        parts.append('</table>')
    return ''.join(parts)

In [5]:
# For the first test examples, display the annotated ground truth next to the
# token-level attention map of every loaded model (mean over heads, summed
# over layers 4 to 10), together with the entropy of that map.

INF = 1e30                # subtracted from masked scores so the softmax gives them ~0
LABELS = ["E", "N", "C"]  # label id -> displayed label
NB_EXAMPLES = 9           # number of examples shown (the loader yields batches of one)
LAYERS = slice(3, 10)     # 0-based indices 3..9, i.e. layers 4 to 10

# ids of the special tokens (0 is the padding id; 101 / 102 are presumably
# [CLS] / [SEP] of the bert-base-uncased vocabulary)
special_tokens = [0, 101, 102]
# ids treated as punctuation
# NOTE(review): this id range is taken as-is; confirm it matches exactly the
# punctuation entries of the tokenizer's vocabulary.
punct = list(range(999, 1037))

# Loop-invariant lookup tensors, built once.
special_tokens_t = torch.tensor(special_tokens)
punct_t = torch.tensor(punct)

with torch.no_grad():

    display(HTML('<h4>Mean Head agregation</h4>'))
    model_outputs = {}

    for id_batch, elem in enumerate(test_dataloader):

        if id_batch >= NB_EXAMPLES:
            # only look at the first NB_EXAMPLES sentences (batch of one here)
            break

        ids = elem["input_ids"]
        masks = elem["attention_masks"]
        labels = elem["labels"]
        a_true = elem["annotations"]

        # 1 where the token is a special token, 0 elsewhere (first example of the batch)
        spe_tok_mask = torch.isin(ids, special_tokens_t)[0].type(torch.uint8)

        # 0 where the token is punctuation, 1 elsewhere
        punct_pos = 1 - torch.isin(ids, punct_t).type(torch.uint8)
        a_true = torch.mul(a_true, punct_pos)  # we don't want the punctuation in our annotation
        a_true = list(np.array(a_true[0].numpy(), dtype=float))

        m = int(masks[0].sum())  # nb tokens in the sentence
        tokens = tk.convert_ids_to_tokens(ids[0])[0:m]

        # Reference row: the annotation highlighted on the tokens.
        model_outputs["GROUNDTRUTH"] = {
            '[CLS] + P + [SEP] + H + [SEP]': [hightlight_txt(tokens=tokens,
                                                             attention=a_true[0:m])],
            'Entropy': [0.0],
            'Label': LABELS[labels[0]]
        }

        # 0 on the rows (dim 3) that belong to padding tokens, 1 elsewhere.
        # Shape [b, 1, 1, T, 1]: broadcasting applies it to every layer, head
        # and column, whatever the number of layers / heads / tokens.
        pad_mask = (ids != 0).type(torch.uint8)[:, None, None, :, None]

        for mul in models_dict:

            output = models_dict[mul](input_ids=ids,
                                      attention_mask=masks)

            # process the attention_tensor
            attention_tensor = torch.stack(output["outputs"].attentions, dim=1)  # shape [b, l, h, T, T]
            attention_tensor = torch.mul(attention_tensor, pad_mask)

            # layer 4 to 10 aggregation
            a_hat = attention_tensor[0, LAYERS, :, :, :]  # select only some layers
            a_hat = a_hat.mean(dim=1)   # mean head aggregation
            a_hat = a_hat.sum(dim=0)    # sum over the selected layers
            a_hat = a_hat.sum(dim=0)    # line aggregation: one score per token
            a_hat = a_hat[0:m]
            # special tokens are pushed to -INF so they get ~0 weight
            a_hat = torch.softmax(a_hat - INF * spe_tok_mask[0:m], dim=0)
            a_visu = normalize_attention(attention=a_hat, tokens=tokens)

            model_outputs[mul] = {
                '[CLS] + P + [SEP] + H + [SEP]': [hightlight_txt(tokens=tokens,
                                                                 attention=a_visu)],
                # entropy of the attention map (1e-16 avoids log(0))
                'Entropy': [(-a_hat * torch.log(a_hat + 1e-16)).sum().item()],
                'Label': LABELS[labels[0]]
            }

        display(HTML(html_render(model_outputs)))

Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],0.0,N
modif_mul=0.001,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],1.534,N
mul=0,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],1.331,N
mul=0.0,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],1.331,N
mul=0.0005,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],0.076,N
mul=0.001,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],0.004,N
mul=0.0015,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],0.0,N
mul=0.002,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],3.332,N
mul=0.0025,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],0.0,N
mul=0.003,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],0.0,N


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],0.0,E
modif_mul=0.001,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],2.474,E
mul=0,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],2.745,E
mul=0.0,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],2.745,E
mul=0.0005,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],1.833,E
mul=0.001,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],0.001,E
mul=0.0015,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],0.049,E
mul=0.002,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],3.296,E
mul=0.0025,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],0.0,E
mul=0.003,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],1.313,E


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],0.0,C
modif_mul=0.001,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],2.668,C
mul=0,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],2.665,C
mul=0.0,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],2.665,C
mul=0.0005,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],0.624,C
mul=0.001,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],0.0,C
mul=0.0015,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],0.05,C
mul=0.002,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],3.332,C
mul=0.0025,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],0.0,C
mul=0.003,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],0.0,C


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",0.0,N
modif_mul=0.001,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",1.224,N
mul=0,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",2.079,N
mul=0.0,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",2.079,N
mul=0.0005,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",0.513,N
mul=0.001,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",0.047,N
mul=0.0015,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",0.0,N
mul=0.002,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",3.091,N
mul=0.0025,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",0.58,N
mul=0.003,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",0.0,N


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",0.0,E
modif_mul=0.001,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",0.892,E
mul=0,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",1.712,E
mul=0.0,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",1.712,E
mul=0.0005,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",0.147,E
mul=0.001,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",0.014,E
mul=0.0015,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",0.0,E
mul=0.002,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",3.135,E
mul=0.0025,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",0.002,E
mul=0.003,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",0.0,E


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",0.0,C
modif_mul=0.001,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",1.094,C
mul=0,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",1.931,C
mul=0.0,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",1.931,C
mul=0.0005,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",0.038,C
mul=0.001,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",0.164,C
mul=0.0015,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",0.0,C
mul=0.002,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",3.135,C
mul=0.0025,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",0.087,C
mul=0.003,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",0.0,C


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad . [SEP],0.0,E
modif_mul=0.001,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad . [SEP],2.287,E
mul=0,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad . [SEP],2.653,E
mul=0.0,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad . [SEP],2.653,E
mul=0.0005,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad . [SEP],2.152,E
mul=0.001,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad . [SEP],0.01,E
mul=0.0015,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad . [SEP],0.087,E
mul=0.002,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad . [SEP],3.091,E
mul=0.0025,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad . [SEP],0.0,E
mul=0.003,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad . [SEP],0.002,E


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad for beer . [SEP],0.0,N
modif_mul=0.001,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad for beer . [SEP],2.322,N
mul=0,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad for beer . [SEP],1.657,N
mul=0.0,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad for beer . [SEP],1.657,N
mul=0.0005,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad for beer . [SEP],0.057,N
mul=0.001,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad for beer . [SEP],0.521,N
mul=0.0015,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad for beer . [SEP],0.0,N
mul=0.002,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad for beer . [SEP],3.178,N
mul=0.0025,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad for beer . [SEP],0.001,N
mul=0.003,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad for beer . [SEP],0.056,N


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man walks by an ad . [SEP],0.0,C
modif_mul=0.001,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man walks by an ad . [SEP],2.017,C
mul=0,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man walks by an ad . [SEP],1.898,C
mul=0.0,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man walks by an ad . [SEP],1.898,C
mul=0.0005,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man walks by an ad . [SEP],0.177,C
mul=0.001,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man walks by an ad . [SEP],0.476,C
mul=0.0015,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man walks by an ad . [SEP],0.0,C
mul=0.002,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man walks by an ad . [SEP],2.996,C
mul=0.0025,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man walks by an ad . [SEP],0.0,C
mul=0.003,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man walks by an ad . [SEP],0.0,C
