# Visualisation des cartes d'entropie

Nous allons ici proposer des méthodes de visualisation des cartes d'attention sur les différentes phrases.
Attention : la plupart des outils utiles pour la visualisation se trouvent dans le fichier `utils`.

# Functions for the visualization

In [4]:
%load_ext autoreload
%autoreload 2

from IPython.display import display, HTML
import os
from os import path
import sys
import torch
# Remember where the notebook was started so that a failed search can restore it.
_start_dir = os.getcwd()
cwd = _start_dir.split(os.path.sep)

# Point to the git repository: walk up the directory tree until the current
# working directory is the repository root folder.
while cwd[-1] != "ExplanationPairSentencesTasks":
    _previous_dir = os.getcwd()
    os.chdir("..")
    if os.getcwd() == _previous_dir:
        # chdir("..") no longer moves once the filesystem root is reached:
        # the repository is not an ancestor, so stop instead of looping forever.
        os.chdir(_start_dir)
        raise FileNotFoundError(
            "Could not find the 'ExplanationPairSentencesTasks' directory "
            f"above {_start_dir}"
        )
    cwd = os.getcwd().split(os.path.sep)
print(f">> current directory : {os.getcwd()}")

# Add the `src` directory of the repository to the import path
# (only once, so that re-running the cell does not pile up duplicates).
_src_dir = os.path.join(os.getcwd(), "src")
if _src_dir not in sys.path:
    sys.path.append(_src_dir)

# Cache folders, all located under `<repository root>/.cache`.
cache_path = path.join(os.getcwd(), '.cache')
dataset_path = path.join(cache_path, 'dataset')
log_path = path.join(cache_path, 'logs')
model_path = path.join(cache_path, 'models')
print(f">> cache path : {cache_path}")
print(f">> model path : {model_path}")
print(f">> dataset path : {dataset_path}")
print(f">> logs path : {log_path}")

from src.data_module.hatexplain import HateXPlainDM
from pur_attention import AttitModel
from modules import metrics
from notebooks.attention_based.utils.ckp_config import *


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
>> current directory : C:\Users\loicf\Documents\IRISA\ExplanationPairSentencesTasks
>> cache path : C:\Users\loicf\Documents\IRISA\ExplanationPairSentencesTasks\.cache
>> model path : C:\Users\loicf\Documents\IRISA\ExplanationPairSentencesTasks\.cache\models
>> dataset path : C:\Users\loicf\Documents\IRISA\ExplanationPairSentencesTasks\.cache\dataset
>> logs path : C:\Users\loicf\Documents\IRISA\ExplanationPairSentencesTasks\.cache\logs


In [None]:
def html_render(model_outputs):
    """
    Render the outputs of several models as a sequence of HTML tables.

    One table is produced per example. Each table has a header row made of the
    column names of the ``'GROUNDTRUTH'`` entry, then one row per entry of
    ``model_outputs`` (the entry name is shown in bold in the first cell).

    Args:
        model_outputs: mapping ``name -> {column name -> values}``. It must
            contain a ``'GROUNDTRUTH'`` entry with an ``'Entropy'`` column,
            whose length gives the number of tables. ``values`` is either an
            indexable with one item per example, or ``None`` (shown as 'N/A').

    Returns:
        str: the concatenated HTML of all tables ('' when there is no example).
        Cell contents are inserted verbatim (no HTML escaping), floats being
        rounded to 3 decimals.
    """
    reference = model_outputs['GROUNDTRUTH']
    table_len = len(reference['Entropy'])

    # The header is identical for every table: build it once.
    # The leading empty <th> is the extra head above the model's name.
    header = (
        '<tr><th></th>'
        + ''.join(f'<th>{column_name}</th>' for column_name in reference)
        + ' </tr>'
    )

    parts = []
    for i in range(table_len):
        parts.append('<table>')
        parts.append(header)
        for name, model_content in model_outputs.items():
            parts.append('<tr>')
            parts.append(f'<td><b>{name}</b></td>')

            for output in model_content.values():
                cell = output[i] if output is not None else 'N/A'
                if isinstance(cell, float):
                    cell = round(cell, 3)
                # str() so that non-string cells (ints, labels, ...) do not
                # break the concatenation.
                parts.append('<td>' + str(cell) + '</td>')

            parts.append('</tr>')

        parts.append('</table>')
    return ''.join(parts)

In [None]:
def highlight_txt(tokens, attention, padding_filter=None):
    """
    Build an HTML of text along its weights.

    Each token is wrapped in a ``<span>`` whose background opacity reflects its
    weight. For the visualization only, weights are min-max normalized over the
    whole sequence (padding positions included) and scaled so that the largest
    weight gets an opacity of ``MAX_ALPHA``. When all weights are equal, every
    token is highlighted with ``MAX_ALPHA``.

    Args:
        tokens: list of tokens
        attention: attention weights, one per token (1-D torch tensor or any
            sequence of numbers)
        padding_filter: padding filter to be hidden from visual; tokens equal
            to this value are dropped. ``None`` keeps every token.

    Returns:
        str: the ``<span>`` elements joined by a single space ('' for an empty
        input). Token text is inserted verbatim (no HTML escaping).

    Raises:
        AssertionError: if ``tokens`` and ``attention`` differ in length.
    """
    assert len(tokens) == len(attention), f'Length mismatch: {len(tokens)} vs {len(attention)}'

    MAX_ALPHA = 0.8  # opacity given to the most attended token

    if len(tokens) == 0:
        return ''

    # Work on plain Python floats so that lists and tensors are both accepted.
    if isinstance(attention, torch.Tensor):
        weights = attention.detach().tolist()
    else:
        weights = [float(w) for w in attention]

    w_min, w_max = min(weights), max(weights)
    spread = w_max - w_min

    if spread == 0:
        # In case of uniform: highlight all text (also avoids a 0/0 division).
        alphas = [MAX_ALPHA] * len(weights)
    else:
        # Multiply (not divide) by MAX_ALPHA so that opacity never exceeds it.
        alphas = [MAX_ALPHA * (w - w_min) / spread for w in weights]

    if padding_filter is not None:
        kept = [(a, tk) for a, tk in zip(alphas, tokens) if tk != padding_filter]
    else:
        kept = list(zip(alphas, tokens))

    highlighted_text = [
        f'<span style="background-color:rgba(135,206,250, {alpha:.3f});">{text}</span>'
        for alpha, text in kept
    ]

    return ' '.join(highlighted_text)

# One head models

## HatexPlain

In [None]:
# --- Data: the HateXplain data module, prepared for the test stage ---
dm_kwargs = {
    "cache_path": dataset_path,
    "batch_size": 32,
    "num_workers": 0,
    "n_data": 999,
    "pur_attention": True,
}
dm = HateXPlainDM(**dm_kwargs)
dm.prepare_data()
dm.setup(stage="test")

# dataloader over the test dataset
test_dataloader = dm.test_dataloader()

# --- Models: one pure-attention model per number of layers ---
# The same hparams file (the one of the 1-layer run) is passed for every model.
hparams_path = path.join(log_path, "PurAttention", "n_layer_1_htx_adadel", 'hparams.yaml')
model_args = {
    "cache_path": model_path,
    "mode": "exp",
    "vocab": dm.vocab,
    "lambda_entrop": 0,
    "lambda_supervise": 0,
    "lambda_lagrange": 0,
    "pretrained_vectors": "glove.840B.300d",
    "num_layers": 1,
    "num_heads": 1,
    "d_embedding": 300,
    "data": "hatexplain",
    "num_class": dm.num_class,
    "opt": "adadelta",
}

# Placeholders for "n_layer=1" ... "n_layer=6", all initialised to None.
models_dict = dict.fromkeys([f"n_layer={i+1}" for i in range(6)])

# NOTE(review): only 5 checkpoints are loaded, so "n_layer=6" stays None.
# Confirm whether this is intended.
for l in range(5):
    # the number of layers is the only argument changing between models
    model_args["num_layers"] = l + 1
    ckp = gen_ckp(
        name="PurAttention",
        num_layers=l + 1,
        num_heads=1,
        log_path=log_path,
        dataset="hatexplain",
    )
    # load the checkpoint and switch it to evaluation mode
    model = AttitModel.load_from_checkpoint(ckp, hparams_file=hparams_path, **model_args).eval()
    models_dict[f"n_layer={l+1}"] = model

## E-SNLI