In [1]:
# preparation of the environment
%load_ext autoreload
%autoreload 2

INF = 1e30

import os
from os import path


# set the repository to the git repository
cwd = os.getcwd().split(os.path.sep)
while cwd[-1] != "stage_4_gm":
    os.chdir("..")
    cwd = os.getcwd().split(os.path.sep)
import numpy as np
import torch
from training_bert import BertNliLight
from regularize_training_bert import SNLIDataModule

from transformers import BertTokenizer
tk = BertTokenizer.from_pretrained('bert-base-uncased')

In [2]:
# Restore the trained NLI model from its checkpoint and put it in evaluation mode.
ckp = path.join(".cache", "logs", "igrida_trained", "0", "best.ckpt")
model = BertNliLight.load_from_checkpoint(ckp).eval()

# Data module reading from the local e-SNLI cache; batches hold a single example.
# NOTE(review): nb_data=-1 presumably means "use every example" - confirm in SNLIDataModule.
data_dir = os.path.join(".cache", "raw_data", "e_snli")
dm = SNLIDataModule(
    cache=data_dir,
    batch_size=1,
    num_workers=1,
    nb_data=-1,
)
dm.prepare_data()
dm.setup(stage="test")

# Handles on the test split used by the analysis cells.
test_dataset, test_dataloader = dm.test_set, dm.test_dataloader()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
from attention_algorithms.plausibility_visu import hightlight_txt # function to highlight the text
from attention_algorithms.attention_metrics import normalize_attention
from IPython.display import display, HTML

def html_render(model_outputs):
    """Render aggregated model outputs as a sequence of HTML tables.

    Parameters
    ----------
    model_outputs : dict
        Maps a row name to a dict ``{column name: values}``. It must contain
        a ``'GROUNDTRUTH'`` entry with an ``'Entropy'`` column: the column
        headers are taken from the ``'GROUNDTRUTH'`` entry and the number of
        tables is the length of its ``'Entropy'`` column. ``values`` is
        indexed by table number; ``None`` is rendered as ``'N/A'``.

    Returns
    -------
    str
        The concatenated ``<table>`` elements, one per index. Cell contents
        are inserted verbatim (no escaping) so that pre-built HTML fragments
        are displayed as markup.

    Raises
    ------
    KeyError
        If ``'GROUNDTRUTH'`` or its ``'Entropy'`` column is missing.
    """
    reference = model_outputs['GROUNDTRUTH']
    table_len = len(reference['Entropy'])

    # Collect fragments and join once instead of growing a string piecewise.
    parts = []
    for i in range(table_len):
        parts.append('<table>')
        parts.append('<tr><th></th>')  # one extra header cell for the row name
        for column_name in reference.keys():
            parts.append('<th>' + str(column_name) + '</th>')
        parts.append(' </tr>')

        for name, model_content in model_outputs.items():
            parts.append('<tr>')
            parts.append('<td><b>' + str(name) + '</b></td>')

            for output in model_content.values():
                displ = output[i] if output is not None else 'N/A'
                if isinstance(displ, float):
                    displ = round(displ, 3)
                # str() keeps ints and other non-string values from raising
                # a TypeError during concatenation.
                parts.append('<td>' + str(displ) + '</td>')

            parts.append('</tr>')

        parts.append('</table>')
    return ''.join(parts)

In [4]:
# Mean over heads, sum over layers: for the first test examples, compare the
# human annotation with attention maps aggregated over several layer ranges.
INF = 1e30  # subtracted from special-token scores so the softmax sends them to ~0
LABELS = ["E", "N", "C"]
NB_EXAMPLES = 9  # number of test examples rendered (one example per batch)
TEXT_COLUMN = '[CLS] + P + [SEP] + H + [SEP]'

# (row name, slice over the layer axis); slices are 0-based, names are 1-based.
LAYER_GROUPS = [
    ("all agreg", slice(None)),
    ("layer 4 to 10", slice(3, 10)),
    ("layer 5 to 10", slice(4, 10)),
    ("layer 1 to 10", slice(0, 10)),
    ("layer 11, 12", slice(10, 12)),
]

# Loop-invariant id sets: special tokens and the punctuation id range.
special_token_ids = torch.tensor([0, 101, 102])
punct_ids = torch.tensor(list(range(999, 1037)))
pad = torch.tensor([0])

with torch.no_grad():

    display(HTML('<h4>Mean Head agregation</h4>'))
    model_outputs = {}

    for id_batch, elem in enumerate(test_dataloader):

        if id_batch >= NB_EXAMPLES:
            # only look at the first NB_EXAMPLES sentences (batch of one here)
            break

        ids = elem["input_ids"]
        masks = elem["attention_masks"]
        labels = elem["labels"]
        a_true = elem["annotations"]

        # 1 where the token is a special token, 0 elsewhere
        spe_tok_mask = torch.isin(ids, special_token_ids)[0].type(torch.uint8)

        # we don't want the punctuation in our annotation
        punct_pos = 1 - torch.isin(ids, punct_ids).type(torch.uint8)
        a_true = torch.mul(a_true, punct_pos)
        a_true = list(np.array(a_true[0].numpy(), dtype=float))

        m = int(masks[0].sum())  # nb tokens in the sentence
        tokens = tk.convert_ids_to_tokens(ids[0])[0:m]
        label = [LABELS[labels[0]]]  # one-element list, indexed like the other columns

        model_outputs["GROUNDTRUTH"] = {
            TEXT_COLUMN: [hightlight_txt(tokens=tokens, attention=a_true[0:m])],
            'Entropy': [0.0],
            'Label': label,
        }

        output = model(input_ids=ids, attention_mask=masks)

        # process the attention_tensor
        attention_tensor = torch.stack(output["outputs"].attentions, dim=1)  # shape [b, l, h, T, T]
        # Zero the rows (first token axis) that belong to padding. Broadcasting
        # replaces a hard-coded repeat to (12 layers, 12 heads, 150 tokens).
        pad_mask = torch.logical_not(torch.isin(ids, pad))[:, None, None, :, None]
        attention_tensor = torch.mul(attention_tensor, pad_mask)
        n_heads = attention_tensor.shape[2]

        for row_name, layer_slice in LAYER_GROUPS:
            a_hat = attention_tensor[0, layer_slice]  # select only some layers
            a_hat = a_hat.sum(dim=1) / n_heads        # mean over the heads
            a_hat = a_hat.sum(dim=0)                  # layer agregation (sum)
            a_hat = a_hat.sum(dim=0)                  # line agregation
            a_hat = a_hat[0:m]
            # special tokens are pushed to -INF before the softmax
            a_hat = torch.softmax(a_hat - INF * spe_tok_mask[0:m], dim=0)
            a_visu = normalize_attention(attention=a_hat, tokens=tokens)

            model_outputs[row_name] = {
                TEXT_COLUMN: [hightlight_txt(tokens=tokens, attention=a_visu)],
                # 1e-16 avoids log(0) on the masked positions
                'Entropy': [(-a_hat * torch.log(a_hat + 1e-16)).sum().item()],
                'Label': label,
            }

        display(HTML(html_render(model_outputs)))

Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],0.0,N
all agreg,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],0.004,N
layer 4 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],1.917,N
layer 5 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],1.872,N
layer 1 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],1.747,N
"layer 11, 12",[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],0.001,N


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],0.0,E
all agreg,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],0.66,E
layer 4 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],2.347,E
layer 5 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],2.634,E
layer 1 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],2.15,E
"layer 11, 12",[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],0.64,E


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],0.0,C
all agreg,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],0.0,C
layer 4 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],0.0,C
layer 5 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],0.0,C
layer 1 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],0.0,C
"layer 11, 12",[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],1.076,C


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",0.0,N
all agreg,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",0.556,N
layer 4 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",2.021,N
layer 5 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",2.019,N
layer 1 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",1.975,N
"layer 11, 12","[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",0.361,N


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",0.0,E
all agreg,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",0.639,E
layer 4 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",2.387,E
layer 5 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",2.289,E
layer 1 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",2.262,E
"layer 11, 12","[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",0.409,E


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",0.0,C
all agreg,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",0.383,C
layer 4 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",2.087,C
layer 5 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",2.127,C
layer 1 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",1.892,C
"layer 11, 12","[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",0.2,C


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad . [SEP],0.0,E
all agreg,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad . [SEP],0.683,E
layer 4 to 10,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad . [SEP],2.532,E
layer 5 to 10,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad . [SEP],2.662,E
layer 1 to 10,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad . [SEP],2.304,E
"layer 11, 12",[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad . [SEP],0.654,E


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad for beer . [SEP],0.0,N
all agreg,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad for beer . [SEP],0.552,N
layer 4 to 10,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad for beer . [SEP],1.74,N
layer 5 to 10,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad for beer . [SEP],1.688,N
layer 1 to 10,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad for beer . [SEP],1.841,N
"layer 11, 12",[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad for beer . [SEP],0.323,N


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man walks by an ad . [SEP],0.0,C
all agreg,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man walks by an ad . [SEP],0.062,C
layer 4 to 10,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man walks by an ad . [SEP],2.052,C
layer 5 to 10,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man walks by an ad . [SEP],2.013,C
layer 1 to 10,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man walks by an ad . [SEP],2.014,C
"layer 11, 12",[CLS] an old man with a package poses in front of an advertisement . [SEP] a man walks by an ad . [SEP],0.035,C


In [7]:
# Sum over heads, sum over layers: for the first test examples, compare the
# human annotation with attention maps aggregated over several layer ranges.
INF = 1e30  # subtracted from special-token scores so the softmax sends them to ~0
LABELS = ["E", "N", "C"]
NB_EXAMPLES = 9  # number of test examples rendered (one example per batch)
TEXT_COLUMN = '[CLS] + P + [SEP] + H + [SEP]'

# (row name, slice over the layer axis); slices are 0-based, names are 1-based.
LAYER_GROUPS = [
    ("all agreg", slice(None)),
    ("layer 4 to 10", slice(3, 10)),
    ("layer 5 to 10", slice(4, 10)),
    ("layer 1 to 10", slice(0, 10)),
]

# Loop-invariant id sets: special tokens and the punctuation id range.
special_token_ids = torch.tensor([0, 101, 102])
punct_ids = torch.tensor(list(range(999, 1037)))
pad = torch.tensor([0])

with torch.no_grad():

    display(HTML('<h4>Sum Head agregation</h4>'))
    model_outputs = {}

    for id_batch, elem in enumerate(test_dataloader):

        if id_batch >= NB_EXAMPLES:
            # only look at the first NB_EXAMPLES sentences (batch of one here)
            break

        ids = elem["input_ids"]
        masks = elem["attention_masks"]
        labels = elem["labels"]
        a_true = elem["annotations"]

        # 1 where the token is a special token, 0 elsewhere
        spe_tok_mask = torch.isin(ids, special_token_ids)[0].type(torch.uint8)

        # we don't want the punctuation in our annotation
        punct_pos = 1 - torch.isin(ids, punct_ids).type(torch.uint8)
        a_true = torch.mul(a_true, punct_pos)
        a_true = list(np.array(a_true[0].numpy(), dtype=float))

        m = int(masks[0].sum())  # nb tokens in the sentence
        tokens = tk.convert_ids_to_tokens(ids[0])[0:m]
        label = [LABELS[labels[0]]]  # one-element list, indexed like the other columns

        model_outputs["GROUNDTRUTH"] = {
            TEXT_COLUMN: [hightlight_txt(tokens=tokens, attention=a_true[0:m])],
            'Entropy': [0.0],
            'Label': label,
        }

        output = model(input_ids=ids, attention_mask=masks)

        # process the attention_tensor
        attention_tensor = torch.stack(output["outputs"].attentions, dim=1)  # shape [b, l, h, T, T]
        # Zero the rows (first token axis) that belong to padding. Broadcasting
        # replaces a hard-coded repeat to (12 layers, 12 heads, 150 tokens).
        pad_mask = torch.logical_not(torch.isin(ids, pad))[:, None, None, :, None]
        attention_tensor = torch.mul(attention_tensor, pad_mask)

        for row_name, layer_slice in LAYER_GROUPS:
            a_hat = attention_tensor[0, layer_slice]  # select only some layers
            a_hat = a_hat.sum(dim=1)                  # sum over the heads
            a_hat = a_hat.sum(dim=0)                  # layer agregation (sum)
            a_hat = a_hat.sum(dim=0)                  # line agregation
            a_hat = a_hat[0:m]
            # special tokens are pushed to -INF before the softmax
            a_hat = torch.softmax(a_hat - INF * spe_tok_mask[0:m], dim=0)
            a_visu = normalize_attention(attention=a_hat, tokens=tokens)

            model_outputs[row_name] = {
                TEXT_COLUMN: [hightlight_txt(tokens=tokens, attention=a_visu)],
                # 1e-16 avoids log(0) on the masked positions
                'Entropy': [(-a_hat * torch.log(a_hat + 1e-16)).sum().item()],
                'Label': label,
            }

        display(HTML(html_render(model_outputs)))

Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],0.0,N
all agreg,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],0.0,N
layer 4 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],0.0,N
layer 5 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],0.0,N
layer 1 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],0.0,N


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],0.0,E
all agreg,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],0.012,E
layer 4 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],0.017,E
layer 5 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],0.002,E
layer 1 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],0.0,E


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],0.0,C
all agreg,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],0.0,C
layer 4 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],0.0,C
layer 5 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],0.0,C
layer 1 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],0.0,C


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",0.0,N
all agreg,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",0.0,N
layer 4 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",0.006,N
layer 5 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",0.0,N
layer 1 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",0.074,N


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",0.0,E
all agreg,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",0.002,E
layer 4 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",0.001,E
layer 5 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",0.0,E
layer 1 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",0.128,E


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",0.0,C
all agreg,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",0.0,C
layer 4 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",0.192,C
layer 5 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",0.001,C
layer 1 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",0.014,C


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad . [SEP],0.0,E
all agreg,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad . [SEP],0.114,E
layer 4 to 10,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad . [SEP],0.05,E
layer 5 to 10,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad . [SEP],0.118,E
layer 1 to 10,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad . [SEP],0.678,E


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad for beer . [SEP],0.0,N
all agreg,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad for beer . [SEP],0.0,N
layer 4 to 10,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad for beer . [SEP],0.0,N
layer 5 to 10,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad for beer . [SEP],0.0,N
layer 1 to 10,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad for beer . [SEP],0.0,N


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man walks by an ad . [SEP],0.0,C
all agreg,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man walks by an ad . [SEP],0.0,C
layer 4 to 10,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man walks by an ad . [SEP],0.595,C
layer 5 to 10,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man walks by an ad . [SEP],0.534,C
layer 1 to 10,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man walks by an ad . [SEP],0.023,C


In [4]:
INF = 1e30
LABELS = ["E", "N", "C"]
with torch.no_grad():

    display(HTML('<h4>Mean Head agregation</h4>'))
    model_outputs = {}

    for id_batch, elem in enumerate(test_dataloader) :

        if id_batch > 8:
            # only look at 5 sentences (batch of one here)
            break

        ids = elem["input_ids"]
        masks = elem["attention_masks"]
        labels = elem["labels"]
        a_true = elem["annotations"]
        
        
        # ids of the specials tokens
        special_tokens = [0, 101, 102]
        spe_tok_mask = torch.isin(ids, torch.tensor(special_tokens))[0].type(torch.uint8)
        
        # ids of the punctuation
        punct = list(range(999, 1037))
        punct_pos = 1 - torch.isin(ids, torch.tensor(punct)).type(torch.uint8)
        a_true = torch.mul(a_true, punct_pos) # we don't want the punctuation in our annotation
        a_true = list(np.array(a_true[0].numpy(), dtype=float))

        m = masks[0].sum() # nb tokens in the sentence
        tokens = tk.convert_ids_to_tokens(ids[0])[0:m]

        it = 0

        if it == 0:
            model_outputs["GROUNDTRUTH"] = {
                '[CLS] + P + [SEP] + H + [SEP]':  [hightlight_txt(tokens = tokens,
                                                                attention = a_true[0 : m])],
                'Entropy': [0.0],
                'Label':LABELS[labels[0]]
            }

            it += 1

        output = model(input_ids = ids,
                       attention_mask = masks)

        # process the attention_tensor
        attention_tensor = torch.stack(output["outputs"].attentions, dim=1) # shape [b, l, h, T, T]
        pad = torch.tensor([0])
        pad_mask = torch.logical_not(torch.isin(ids, pad)).type(torch.uint8).unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, 12, 12, 150, 1)
        pad_mask = torch.transpose(pad_mask, dim0=3, dim1=4)
        attention_tensor = torch.mul(attention_tensor, pad_mask)

        # Attention aggregations to compare: row name -> slice over the layer axis.
        # Slices are 0-indexed while the names count layers from 1
        # (e.g. slice(3, 10) selects layers 4 to 10).
        # NOTE(review): this mapping is loop-invariant and could be defined once
        # before the enclosing loop.
        layer_ranges = {
            "all agreg": slice(None),
            "layer 4 to 10": slice(3, 10),
            "layer 5 to 10": slice(4, 10),
            "layer 1 to 10": slice(0, 10),
        }

        for agreg_name, layer_slice in layer_ranges.items():
            # first sentence of the batch, selected layers only: [l', h, T, T]
            a_hat = attention_tensor[0, layer_slice]
            nb_layers, nb_heads = a_hat.shape[0], a_hat.shape[1]

            a_hat = a_hat.sum(dim=1) / nb_heads   # mean over the heads
            a_hat = a_hat.sum(dim=0) / nb_layers  # mean over the selected layers
            a_hat = a_hat.sum(dim=0)              # line agregation: sum over the rows
            a_hat = a_hat[0:m]                    # keep the first m positions only

            # Positions flagged in spe_tok_mask (presumably the special tokens) are
            # shifted by -INF so that the softmax gives them a weight of ~0.
            a_hat = torch.softmax(a_hat - INF * spe_tok_mask[0:m], dim=0)
            a_visu = normalize_attention(attention=a_hat, tokens=tokens)

            model_outputs[agreg_name] = {
                '[CLS] + P + [SEP] + H + [SEP]': [hightlight_txt(tokens = tokens,
                                                                attention = a_visu)],
                # Shannon entropy of the attention distribution; 1e-16 avoids log(0)
                'Entropy': [(-a_hat * torch.log(a_hat + 1e-16)).sum().item()],
                'Label':LABELS[labels[0]]
            }

        display(HTML(html_render(model_outputs)))

Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],0.0,N
all agreg,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],3.257,N
layer 4 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],3.306,N
layer 5 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],3.3,N
layer 1 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP],3.317,N


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],0.0,E
all agreg,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],3.217,E
layer 4 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],3.279,E
layer 5 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],3.279,E
layer 1 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP],3.285,E


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],0.0,C
all agreg,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],2.374,C
layer 4 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],1.402,C
layer 5 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],0.915,C
layer 1 to 10,[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP],2.544,C


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",0.0,N
all agreg,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",3.022,N
layer 4 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",3.074,N
layer 5 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",3.07,N
layer 1 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP]",3.079,N


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",0.0,E
all agreg,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",3.061,E
layer 4 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",3.123,E
layer 5 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",3.12,E
layer 1 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP]",3.126,E


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",0.0,C
all agreg,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",3.06,C
layer 4 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",3.118,C
layer 5 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",3.114,C
layer 1 to 10,"[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman has been shot . [SEP]",3.123,C


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad . [SEP],0.0,E
all agreg,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad . [SEP],3.013,E
layer 4 to 10,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad . [SEP],3.077,E
layer 5 to 10,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad . [SEP],3.077,E
layer 1 to 10,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad . [SEP],3.079,E


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad for beer . [SEP],0.0,N
all agreg,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad for beer . [SEP],3.115,N
layer 4 to 10,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad for beer . [SEP],3.157,N
layer 5 to 10,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad for beer . [SEP],3.153,N
layer 1 to 10,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man poses in front of an ad for beer . [SEP],3.163,N


Unnamed: 0,[CLS] + P + [SEP] + H + [SEP],Entropy,Label
GROUNDTRUTH,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man walks by an ad . [SEP],0.0,C
all agreg,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man walks by an ad . [SEP],2.946,C
layer 4 to 10,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man walks by an ad . [SEP],2.974,C
layer 5 to 10,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man walks by an ad . [SEP],2.969,C
layer 1 to 10,[CLS] an old man with a package poses in front of an advertisement . [SEP] a man walks by an ad . [SEP],2.981,C
