In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append("../")
sys.path.append("../experiments/")

In [3]:
import os
from pathlib import Path
from typing import Any, Dict, List, Mapping, Optional, Tuple

import dill
import numpy as np

In [4]:
from xbert import InputInstance
from segtok.tokenizer import web_tokenizer

In [5]:
# Root directories (relative to the notebook's working directory) for
# experiment outputs and the GLUE datasets.
RESULTS_DIR = "../results/"
DATASET_DIR = "../data/glue_data/"

# Human-readable labels for each explanation strategy, keyed by the method
# names used as result-directory names.
STRATEGY_NAME_MAPPING = dict(
    unk="Unk",
    delete="Delete",
    resampling="OLM",
    resampling_std="OLM-S",
    grad="Grad.",
    gradxinput="Grad*Input",
    saliency="Sensitivity",
    integratedgrad="Integr. grad",
)

In [6]:
def experiment_load_relevances(experiment_dir: str,
                               relevance_filename: str = "relevances.pkl") -> Dict[str, Any]:
    """Load every pickled relevance file found below ``experiment_dir``.

    Recursively searches ``experiment_dir`` for files named
    ``relevance_filename`` and unpickles each one with ``dill``.

    Args:
        experiment_dir: Root directory of one experiment's results.
        relevance_filename: File name to look for at any depth below the root.

    Returns:
        Dict mapping the name of the directory that contains each file to the
        unpickled object. Keys are inserted in sorted path order, so iterating
        the result (e.g. to build one table row per method) is reproducible.
        If two files live in directories with the same name, the one whose
        full path sorts last wins. A missing or empty directory yields ``{}``.

    Note:
        ``dill.load`` can execute arbitrary code while unpickling -- only use
        this on result files you produced or otherwise trust.
    """
    experiment_relevances = {}
    # Path.glob yields files in filesystem order, which varies between
    # platforms and runs; sort so the dict's key order is deterministic.
    for relevance_file in sorted(Path(experiment_dir).glob(f"**/{relevance_filename}")):
        with relevance_file.open("rb") as f:
            experiment_relevances[relevance_file.parent.name] = dill.load(f)

    return experiment_relevances

In [7]:
# SST-2 inputs: the GLUE dev split and the per-method relevance results.
SST2_DATA_PATH = os.path.join(DATASET_DIR, "SST-2/dev.tsv")
SST2_RESULTS_PATH = os.path.join(RESULTS_DIR, "sst2")

In [8]:
# Relevances for every SST-2 run, keyed by the name of each run's directory.
sst2_experiment_relevances = experiment_load_relevances(experiment_dir=SST2_RESULTS_PATH)

In [9]:
def read_sst2_dataset(path: str) -> List[Tuple[str, str]]:
    """Read a GLUE SST-2 TSV split into ``(sentence, label)`` pairs.

    The first line (column header) is skipped. Every following non-blank line
    is stripped, split on tabs, and its first two columns are kept.

    Args:
        path: Path to the TSV file (e.g. ``.../SST-2/dev.tsv``).

    Returns:
        ``(sentence, label)`` tuples in file order. Both elements are the raw
        strings from the file; the label is not converted to an int.

    Raises:
        IndexError: if a non-blank line has fewer than two tab-separated columns.
    """
    dataset = []
    # Decode explicitly as UTF-8 instead of relying on the platform default.
    with open(path, encoding="utf-8") as fin:
        next(fin, None)  # skip the header row
        for line in fin:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            fields = line.split("\t")
            dataset.append((fields[0], fields[1]))

    return dataset


def dataset_to_input_instances(dataset: List[Tuple[str, str]]) -> List[InputInstance]:
    """Convert ``(sentence, label)`` pairs into tokenized ``InputInstance`` objects.

    Each sentence is tokenized with segtok's ``web_tokenizer``; labels are
    ignored. An instance's ``id_`` is its position in ``dataset``, so
    ``result[i]`` corresponds to ``dataset[i]``.

    Args:
        dataset: ``(sentence, label)`` pairs, e.g. from ``read_sst2_dataset``.

    Returns:
        One ``InputInstance`` per dataset entry, in the same order.
    """
    return [
        InputInstance(id_=idx, sent=web_tokenizer(sent))
        for idx, (sent, _label) in enumerate(dataset)
    ]

In [10]:
# Load the SST-2 dev split, then wrap each sentence as a tokenized InputInstance.
dataset = read_sst2_dataset(path=SST2_DATA_PATH)
input_instances = dataset_to_input_instances(dataset=dataset)

In [11]:
# LaTeX special characters and their escaped forms, so tokens such as "&" or
# "%" do not break the surrounding tabular / \colorbox markup.
_LATEX_ESCAPES = {
    "\\": r"\textbackslash{}",
    "&": r"\&",
    "%": r"\%",
    "$": r"\$",
    "#": r"\#",
    "_": r"\_",
    "{": r"\{",
    "}": r"\}",
    "~": r"\textasciitilde{}",
    "^": r"\textasciicircum{}",
}


def _latex_escape(text: str) -> str:
    """Return ``text`` with LaTeX special characters escaped."""
    return "".join(_LATEX_ESCAPES.get(char, char) for char in text)


def relevance_to_colored_text(relevance_dict, input_instances, method, idx) -> Tuple[str, float]:
    r"""Render one sentence as LaTeX ``\colorbox`` tokens shaded by relevance.

    Relevances are divided by the largest absolute value in the sentence, so
    scores lie in [-1, 1]. Positive scores shade towards red, negative scores
    towards blue, and zero is white.

    Args:
        relevance_dict: ``relevance_dict[method][idx]`` is a dict whose values
            are the per-token relevance scores of sentence ``idx``.
        input_instances: Sequence indexed like the relevances; tokens are read
            from ``input_instances[idx].token_fields['sent']._tokens``.
        method: Key of the explanation method in ``relevance_dict``.
        idx: Index of the sentence to render.

    Returns:
        ``(latex, max_value)``: space-separated ``\colorbox`` commands (with a
        trailing space) and the maximum absolute relevance used to normalise.
        An all-zero sentence yields white boxes and ``max_value == 0``; an
        empty relevance dict yields ``("", 0.0)``.
    """
    scores = list(relevance_dict[method][idx].values())
    if not scores:
        return "", 0.0

    scores = np.asarray(scores)
    max_value = np.abs(scores).max()
    # Guard against 0/0: NaN scores would make int() below raise ValueError.
    normalized = scores / max_value if max_value != 0 else np.zeros(len(scores))

    tokens = input_instances[idx].token_fields['sent']._tokens
    boxes = []
    # zip pairs tokens and scores positionally; this assumes the relevance
    # dict's values are ordered like the tokens (TODO confirm upstream).
    for word, score in zip(tokens, normalized):
        red = int(255 * min(1, 1 + score))
        green = int(255 * (1 - abs(score)))
        blue = int(255 * min(1, 1 - score))
        # Raw string: "\c" and "\s" are invalid escapes in a normal literal.
        boxes.append(r"\colorbox[RGB]{%d,%d,%d}{\strut %s} " % (red, green, blue, _latex_escape(word)))

    return "".join(boxes), max_value

In [12]:
def colored_text_to_table(relevance_dict, input_instances, idx,
                          caption: str = "Example explanations for SST-2",
                          label: str = "tab:example_explanations",
                          name_mapping: Optional[Mapping[str, str]] = None) -> str:
    r"""Build a LaTeX ``table*`` comparing every method's relevances for one sentence.

    One row is emitted per key of ``relevance_dict``, in its iteration order.
    Each row holds the method name, the coloured sentence from
    ``relevance_to_colored_text``, and that method's normalisation maximum
    formatted with ``%.2g``.

    Args:
        relevance_dict: Mapping method name -> per-sentence relevances.
        input_instances: Tokenized sentences, indexed like the relevances.
        idx: Index of the sentence to render.
        caption: Text placed in ``\caption{...}``. The default reproduces the
            original SST-2 caption.
        label: Text placed in ``\label{...}``. Pass a distinct value when
            several tables go into one document, to avoid duplicate labels.
        name_mapping: Optional mapping from method key to display name
            (e.g. ``STRATEGY_NAME_MAPPING``). Keys missing from it are shown
            verbatim. ``None`` (the default) shows the raw keys.

    Returns:
        The table as a newline-joined string, ready to paste into LaTeX.
    """
    # Raw strings: "\ " and "\h" are invalid escapes in normal literals
    # (SyntaxWarning on Python 3.12+).
    lines = [
        r"\begin{table*}[h]",
        r"  \centering",
        r"  \begin{tabular}{l|l|l}",
        r"    method&relevances&maximum value \\ \hline",
    ]
    for method in relevance_dict:
        text, max_value = relevance_to_colored_text(relevance_dict, input_instances, method, idx)
        display_name = method if name_mapping is None else name_mapping.get(method, method)
        lines.append("    %s&%s&%.2g" % (display_name, text, max_value) + r"\\")
    lines.extend([
        r"  \end{tabular}",
        r"  \caption{" + caption + "}",
        r"  \label{" + label + "}",
        r"\end{table*}",
    ])

    return "\n".join(lines)

In [13]:
# Dev-set sentence shown in the example table ("good film , but very glum .").
EXAMPLE_SENTENCE_IDX = 667
# print (not a bare expression) so the LaTeX source is shown with real newlines.
print(colored_text_to_table(sst2_experiment_relevances, input_instances, EXAMPLE_SENTENCE_IDX))

\begin{table*}[h]
  \centering
  \begin{tabular}{l|l|l}
    method&relevances&maximum value \\ \hline
    delete&\colorbox[RGB]{255,0,0}{\strut good} \colorbox[RGB]{255,235,235}{\strut film} \colorbox[RGB]{255,253,253}{\strut ,} \colorbox[RGB]{250,250,255}{\strut but} \colorbox[RGB]{255,250,250}{\strut very} \colorbox[RGB]{250,250,255}{\strut glum} \colorbox[RGB]{251,251,255}{\strut .} &0.98\\
    grad&\colorbox[RGB]{255,160,160}{\strut good} \colorbox[RGB]{255,210,210}{\strut film} \colorbox[RGB]{187,187,255}{\strut ,} \colorbox[RGB]{166,166,255}{\strut but} \colorbox[RGB]{255,130,130}{\strut very} \colorbox[RGB]{255,0,0}{\strut glum} \colorbox[RGB]{232,232,255}{\strut .} &3.4e-07\\
    gradxinput&\colorbox[RGB]{255,31,31}{\strut good} \colorbox[RGB]{255,0,0}{\strut film} \colorbox[RGB]{255,5,5}{\strut ,} \colorbox[RGB]{214,214,255}{\strut but} \colorbox[RGB]{27,27,255}{\strut very} \colorbox[RGB]{101,101,255}{\strut glum} \colorbox[RGB]{255,236,236}{\strut .} &0.041\\
    integratedg