In [1]:
# preparation of the environment
%load_ext autoreload
%autoreload 2

# Large finite stand-in for +infinity; multiplied by a 0/1 mask and subtracted from
# scores before a softmax so that the masked positions get (numerically) zero weight.
INF = 1e30

import os
from os import path


# Walk up from the current working directory until we stand in the root of the git
# repository, so that the project imports and the relative ".cache" paths resolve.
REPO_NAME = "BertPlausibilityStudy"
cwd = os.getcwd().split(os.path.sep)
while cwd[-1] != REPO_NAME:
    previous_dir = os.getcwd()
    os.chdir("..")
    if os.getcwd() == previous_dir:
        # ".." of the filesystem root is the root itself: without this check the
        # loop never terminates when the notebook is started outside the repository.
        raise RuntimeError(
            f"Could not find a parent directory named '{REPO_NAME}'; "
            "start the notebook from inside the repository."
        )
    cwd = os.getcwd().split(os.path.sep)
import numpy as np
import torch
from lagrange_reg_training_bert_snli import BertNliLagrange
from DataModules.SnliDM import ESNLIDataModule

from transformers import BertTokenizer
# Tokenizer of the pretrained 'bert-base-uncased' vocabulary; used further down to
# convert token ids back into word-piece strings for display.
tk = BertTokenizer.from_pretrained('bert-base-uncased')

In [2]:
%%capture
# Root folder of the study: it holds one sub-directory per multiplier value,
# each named "mul=<value>".
d = os.path.join(".cache", "logs", "snli_igrida_trained", "regu_study", "lagrange_reg")
muls = ["mul=" + str(x) for x in (0.0, 0.01, 0.014, 0.018, 0.02, 0.03, 0.04, 0.05)]

# Restore the best checkpoint of every run and keep it in evaluation mode,
# indexed by its "mul=<value>" run name.
models_dict = dict()
for mul in muls:
    ckp = path.join(d, mul, "checkpoints", "best.ckpt")
    model = BertNliLagrange.load_from_checkpoint(ckp)
    models_dict[mul] = model.eval()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.pr

In [3]:
# Location of the cached dataset, resolved from the current working directory.
data_dir = os.path.join(os.getcwd(), ".cache", "datasets", "EsnliDataSet")

# Batches of a single example, read by one worker.
# nb_data=-1 presumably selects the whole dataset — TODO confirm against ESNLIDataModule.
dm = ESNLIDataModule(cache=data_dir, batch_size=1, num_workers=1, nb_data=-1)
dm.prepare_data()
dm.setup(stage="test")

# Keep a handle on the test split and on its loader.
test_dataset, test_dataloader = dm.test_set, dm.test_dataloader()

In [4]:
from notebooks.plausibility_visu import hightlight_txt # function to highlight the text
from notebooks.plots_utils import normalize_attention
from IPython.display import display, HTML

def html_render(model_outputs):
    """Render the collected outputs as one HTML table per example.

    Parameters
    ----------
    model_outputs : dict
        Maps a row name to a dict ``{column name: values}``. ``values`` is indexed
        by the example position, or is ``None`` when the column is not available
        for that row (rendered as ``N/A``). The ``'GROUNDTRUTH'`` entry is
        mandatory: its keys provide the column headers and the length of its
        ``'Entropy'`` values provides the number of tables.

    Returns
    -------
    str
        The concatenated ``<table>`` markup. Row names, column names and cell
        values are inserted verbatim (no HTML escaping), so a cell may itself
        contain markup. Floats are rounded to 3 decimals; any other non-string
        value is converted with ``str``.

    Raises
    ------
    KeyError
        If ``model_outputs`` has no ``'GROUNDTRUTH'`` entry or that entry has no
        ``'Entropy'`` column.
    """
    reference = model_outputs['GROUNDTRUTH']
    table_len = len(reference['Entropy'])

    # Collect the fragments and join once instead of growing a string with "+=".
    parts = []
    for i in range(table_len):
        parts.append('<table>')
        parts.append('<tr><th></th>')  # one extra header cell for the row name
        for column_name in reference.keys():
            parts.append('<th>' + str(column_name) + '</th>')
        parts.append(' </tr>')

        for name, model_content in model_outputs.items():
            parts.append('<tr>')
            parts.append('<td><b>' + str(name) + '</b></td>')

            for output in model_content.values():
                displ = output[i] if output is not None else 'N/A'
                if isinstance(displ, float):
                    displ = round(displ, 3)
                # str() keeps strings untouched and makes ints and other scalars
                # safe to concatenate (they used to raise a TypeError here).
                parts.append('<td>' + str(displ) + '</td>')

            parts.append('</tr>')

        parts.append('</table>')
    return ''.join(parts)

In [9]:
# --- settings of the visualisation ------------------------------------------------
INF = 1e30                          # large finite value used to mask scores before the softmax
LABELS = ["E", "N", "C"]            # label index -> short class name shown in the tables
N_EXAMPLES = 9                      # number of test examples displayed (the loader yields batches of one)
PAD_ID = 0                          # token id of the padding
SPECIAL_TOKEN_IDS = [0, 101, 102]   # ids of the special tokens ([PAD], [CLS], [SEP] for bert-base-uncased)
# ids treated as punctuation and removed from the annotation
# NOTE(review): in the bert-base-uncased vocabulary this range presumably also contains
# the single-digit tokens "0".."9" — confirm this is intended.
PUNCT_IDS = list(range(999, 1037))
LAYERS = slice(3, 10)               # layers whose attention maps are aggregated (4th to 10th)
TEXT_COLUMN = '[CLS] + P + [SEP] + H + [SEP]'

with torch.no_grad():

    display(HTML('<h4>Lagrange Regularization</h4>'))
    model_outputs = {}

    for id_batch, elem in enumerate(test_dataloader):

        if id_batch >= N_EXAMPLES:
            # only look at the first N_EXAMPLES sentences (batch of one here)
            break

        ids = elem["input_ids"]
        masks = elem["attention_masks"]
        labels = elem["labels"]
        a_true = elem["annotations"]

        # 1 on the special tokens of the (single) sentence of the batch, 0 elsewhere
        spe_tok_mask = torch.isin(ids, torch.tensor(SPECIAL_TOKEN_IDS))[0].type(torch.uint8)

        # we don't want the punctuation in our annotation
        punct_pos = 1 - torch.isin(ids, torch.tensor(PUNCT_IDS)).type(torch.uint8)
        a_true = torch.mul(a_true, punct_pos)
        a_true = list(np.array(a_true[0].numpy(), dtype=float))

        m = masks[0].sum()  # nb tokens in the sentence (kept as a tensor: it is fed to torch.log below)
        tokens = tk.convert_ids_to_tokens(ids[0])[0:m]
        label = LABELS[labels[0]]

        # 1 on real tokens, 0 on padding. Shaped [b, 1, 1, T, 1] so that it broadcasts over
        # the layers, the heads and the last axis of the [b, l, h, T, T] attention tensor:
        # every row that belongs to a padding token is zeroed. It only depends on the
        # input ids, hence it is built once per example and shared by all the models.
        pad_mask = ids.ne(PAD_ID).type(torch.uint8)[:, None, None, :, None]

        # Reference row: the human annotation. Its score is log(sum of the annotation)
        # normalised by log(m - 3); it evaluates to -inf when the annotation sums to 0.
        a_s = sum(a_true[0:m])
        model_outputs["GROUNDTRUTH"] = {
            TEXT_COLUMN: [hightlight_txt(tokens=tokens, attention=a_true[0:m])],
            'Entropy': [(torch.log(torch.tensor([a_s])) / torch.log(m - 3)).item()],
            'Label': label
        }

        for mul in models_dict:

            output = models_dict[mul](input_ids=ids,
                                      attention_mask=masks)

            # process the attention_tensor
            attention_tensor = torch.stack(output["outputs"].attentions, dim=1)  # shape [b, l, h, T, T]
            n_heads = attention_tensor.shape[2]
            attention_tensor = torch.mul(attention_tensor, pad_mask)

            # layer 4 to 10 aggregation
            a_hat = attention_tensor[0, LAYERS, :, :, :]  # select only some layers
            a_hat = a_hat.sum(dim=1) / n_heads  # mean head aggregation
            a_hat = a_hat.sum(dim=0)  # sum over the selected layers
            a_hat = a_hat.sum(dim=0)  # line aggregation
            a_hat = a_hat[0:m]
            # the special tokens are pushed to -INF so that they get no weight in the softmax
            a_hat = torch.softmax(a_hat - INF * spe_tok_mask[0:m], dim=0)
            a_visu = normalize_attention(attention=a_hat, tokens=tokens)

            k = f"&lambda; = {mul.split('=')[-1]}"

            # NOTE(review): this entropy is normalised by log(m) whereas the ground-truth
            # score above uses log(m - 3) — confirm which normaliser is intended.
            model_outputs[k] = {
                TEXT_COLUMN: [hightlight_txt(tokens=tokens, attention=a_visu)],
                'Entropy': [(-a_hat * torch.log(a_hat + 1e-16)).sum().item() / (torch.log(m).item())],
                'Label': label
            }

        display(HTML(html_render(model_outputs)))

Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],0.416,N
λ = 0.0,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],0.183,N
λ = 0.01,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],0.134,N
λ = 0.014,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],0.014,N
λ = 0.018,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],0.472,N
λ = 0.02,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],0.457,N
λ = 0.03,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],0.174,N
λ = 0.04,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],0.38,N
λ = 0.05,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],0.288,N


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],0.631,E
λ = 0.0,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],0.778,E
λ = 0.01,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],0.72,E
λ = 0.014,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],0.593,E
λ = 0.018,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],0.529,E
λ = 0.02,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],0.43,E
λ = 0.03,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],0.404,E
λ = 0.04,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],0.406,E
λ = 0.05,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],0.427,E


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],0.538,C
λ = 0.0,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],0.681,C
λ = 0.01,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],0.6,C
λ = 0.014,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],0.697,C
λ = 0.018,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],0.601,C
λ = 0.02,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],0.234,C
λ = 0.03,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],0.35,C
λ = 0.04,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],0.407,C
λ = 0.05,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],0.434,C


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",0.0,N
λ = 0.0,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",0.615,N
λ = 0.01,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",0.005,N
λ = 0.014,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",0.086,N
λ = 0.018,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",0.03,N
λ = 0.02,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",0.091,N
λ = 0.03,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",0.032,N
λ = 0.04,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",0.032,N
λ = 0.05,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",0.025,N


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",0.221,E
λ = 0.0,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",0.251,E
λ = 0.01,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",0.13,E
λ = 0.014,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",0.035,E
λ = 0.018,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",0.213,E
λ = 0.02,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",0.273,E
λ = 0.03,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",0.173,E
λ = 0.04,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",0.181,E
λ = 0.05,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",0.279,E


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",0.513,C
λ = 0.0,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",0.452,C
λ = 0.01,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",0.099,C
λ = 0.014,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",0.152,C
λ = 0.018,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",0.228,C
λ = 0.02,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",0.279,C
λ = 0.03,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",0.298,C
λ = 0.04,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",0.239,C
λ = 0.05,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",0.224,C


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad . [SEP],0.224,E
λ = 0.0,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad . [SEP],0.738,E
λ = 0.01,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad . [SEP],0.501,E
λ = 0.014,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad . [SEP],0.632,E
λ = 0.018,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad . [SEP],0.559,E
λ = 0.02,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad . [SEP],0.388,E
λ = 0.03,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad . [SEP],0.392,E
λ = 0.04,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad . [SEP],0.434,E
λ = 0.05,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad . [SEP],0.419,E


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad for beer . [SEP],0.0,N
λ = 0.0,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad for beer . [SEP],0.371,N
λ = 0.01,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad for beer . [SEP],0.007,N
λ = 0.014,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad for beer . [SEP],0.006,N
λ = 0.018,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad for beer . [SEP],0.03,N
λ = 0.02,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad for beer . [SEP],0.069,N
λ = 0.03,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad for beer . [SEP],0.04,N
λ = 0.04,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad for beer . [SEP],0.012,N
λ = 0.05,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad for beer . [SEP],0.068,N


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man walks by an ad . [SEP],0.694,C
λ = 0.0,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man walks by an ad . [SEP],0.631,C
λ = 0.01,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man walks by an ad . [SEP],0.406,C
λ = 0.014,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man walks by an ad . [SEP],0.303,C
λ = 0.018,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man walks by an ad . [SEP],0.45,C
λ = 0.02,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man walks by an ad . [SEP],0.25,C
λ = 0.03,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man walks by an ad . [SEP],0.343,C
λ = 0.04,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man walks by an ad . [SEP],0.29,C
λ = 0.05,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man walks by an ad . [SEP],0.393,C
