Before executing this notebook, run the following script to generate all the required data: `.\shell_scripts\entropy_study.ps1`

In [1]:
# preparation of the environment
# autoreload in mode 2 re-imports modified modules before each cell execution,
# so edits to the project's .py files are picked up without restarting the kernel
%load_ext autoreload
%autoreload 2

# presumably a large finite stand-in for infinity
# NOTE(review): not referenced in the cells visible here - confirm it is still needed
INF = 1e30

import os
from os import path


# Set the working directory to the root of the git repository ("stage_4_gm").
# The root is resolved first and the process only changes directory once it is
# found: walking up with os.chdir("..") never terminates when the notebook is
# started outside the repository, because the parent of the filesystem root is
# the root itself.
repo_root = os.getcwd()
while os.path.basename(repo_root) != "stage_4_gm":
    parent_dir = os.path.dirname(repo_root)
    if parent_dir == repo_root:
        # reached the filesystem root without meeting the repository folder
        raise RuntimeError(
            "Could not find a 'stage_4_gm' directory above " + os.getcwd()
            + "; start the notebook from inside the repository."
        )
    repo_root = parent_dir
os.chdir(repo_root)
# kept for backward compatibility: the path components of the repository root
cwd = os.getcwd().split(os.path.sep)
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_theme()
from tqdm.notebook import tqdm
from torch_set_up import DEVICE
from training_bert import BertNliLight
from regularize_training_bert import SNLIDataModule

### The metrics

In [2]:
from sklearn.metrics import roc_auc_score
from sklearn.metrics import average_precision_score
from attention_algorithms.inference_metrics import *

### Special Render for the Metrics

In [3]:
from attention_algorithms.plausibility_visu import hightlight_txt # function to highlight the text
from attention_algorithms.attention_metrics import normalize_attention
from IPython.display import display, HTML

def html_render(model_outputs):
    """Render the metric values of every model as a sequence of HTML tables.

    One ``<table>`` is produced per index ``i`` of
    ``model_outputs['all_layers']['AUC']``. In table ``i`` there is one row per
    model and one cell per metric, showing ``values[i]`` of that metric.

    Args:
        model_outputs: mapping ``model name -> {metric name -> indexable or None}``.
            It must contain an ``'all_layers'`` entry holding an ``'AUC'``
            metric: the metric names of ``'all_layers'`` are used as column
            headers and the length of its ``'AUC'`` value fixes the number of
            tables.

    Returns:
        str: the concatenated HTML of all the tables (``''`` when the reference
        ``'AUC'`` value is empty). Floats are rounded to 3 decimals, a metric
        whose value is ``None`` is shown as ``N/A`` and any other value is
        rendered with ``str()``.

    Raises:
        KeyError: if ``'all_layers'`` or its ``'AUC'`` metric is missing.
    """
    reference = model_outputs['all_layers']
    table_len = len(reference['AUC'])

    # collect the fragments and join once instead of growing a string
    parts = []
    for i in range(table_len):
        parts.append('<table>')
        parts.append('<tr><th></th>')  # one extra header cell for the model's name
        for column_name in reference.keys():
            parts.append('<th>' + column_name + '</th>')
        parts.append(' </tr>')

        for name, model_content in model_outputs.items():
            parts.append('<tr>')
            parts.append('<td><b>' + name + '</b></td>')

            for output in model_content.values():
                displ = output[i] if output is not None else 'N/A'
                if isinstance(displ, float):
                    displ = round(displ, 3)
                # str() makes ints and other non-string scalars renderable;
                # concatenating them directly raised a TypeError
                parts.append('<td>' + str(displ) + '</td>')

            parts.append('</tr>')

        parts.append('</table>')
    return ''.join(parts)

### Create the figure to sum up all the metrics

In [4]:
%%capture
# 8 x 3 grid of axes that the later cells fill column by column
fig, axes = plt.subplots(8, 3, figsize=(30, 40))
plt.subplots_adjust(left=0.1, bottom=0.1, right=0.9, top=0.9, wspace=0.3, hspace=0.8)

# one (min, max) y-range per row of the grid, applied to every axes of that row
y_lims = [(0.5, 1), (0, 0.3), (0, 0.5), (0, 0.6), (0, 0.3), (0, 0.75), (0, 0.3), (0, 3)]
for row_axes, (y_min, y_max) in zip(axes, y_lims):
    for row_ax in row_axes:
        row_ax.set_ylim(y_min, y_max)

### Mean Head aggregation

In [5]:
# load the data
import pickle

# NOTE: pickle.load can execute arbitrary code; only load files you produced
# yourself (the notebook header says they come from entropy_study.ps1).
# The directory is not called `dir` so the builtin of that name stays usable.
entropy_dir = os.path.join(".cache", "plots", "entropy_study")
head_mean_files = (
    "a_true_head_mean.pickle",
    "all_layers_head_mean.pickle",
    "layers_1_10_head_mean.pickle",
    "layers_4_10_head_mean.pickle",
    "layers_5_10_head_mean.pickle",
)
loaded = []
for file_name in head_mean_files:
    with open(os.path.join(entropy_dir, file_name), "rb") as f:
        loaded.append(pickle.load(f))
# same order as head_mean_files
a_true, all_layers, layers_1_10, layers_4_10, layers_5_10 = loaded

In [6]:
with torch.no_grad():
    temp = {}
    # loaded score collections, keyed by the row name shown in the rendered table
    layer_groups = {
        "all_layers": all_layers,
        "layers_1_10": layers_1_10,
        "layers_4_10": layers_4_10,
        "layers_5_10": layers_5_10,
    }
    for k in ("entailement", "neutral", "contradiction"):
        display(HTML(f'<h4>metric for the label : {k}</h4>'))
        truth = a_true[k]
        metric_output = {}

        for group_name, group in layer_groups.items():
            scores = group[k]
            # each metric is wrapped in a one-element list: html_render indexes
            # the values by table number and the plotting cells read element [0]
            metric_output[group_name] = {
                "AUC": [roc_auc_score(truth, scores)],
                "Jaccard": [scalar_jaccard(truth, scores)],
                "AUPRC": [average_precision_score(truth, scores)],
                "AU - Precision": [au_precision_curve(truth, scores)],
                "AU - Recall": [au_recall_curve(truth, scores)],
                "Precision (fixed tr)": [precision(truth, scores)],
                "Recall (fixed tr)": [recall(truth, scores)],
                # stored as loaded, not re-wrapped
                "Entropy": group["entropy"][k],
            }

        # keep the per-label results for the plotting cell
        temp[k] = dict(metric_output)

        display(HTML(html_render(metric_output)))

Unnamed: 0,AUC,Jaccard,AUPRC,AU - Precision,AU - Recall,Precision (fixed tr),Recall (fixed tr),Entropy
all_layers,0.623,0.005,0.286,0.034,0.015,0.026,0.006,0.428
layers_1_10,0.651,0.056,0.341,0.372,0.065,0.457,0.084,1.603
layers_4_10,0.661,0.059,0.361,0.373,0.067,0.494,0.087,1.712
layers_5_10,0.662,0.059,0.363,0.371,0.067,0.5,0.086,1.765


Unnamed: 0,AUC,Jaccard,AUPRC,AU - Precision,AU - Recall,Precision (fixed tr),Recall (fixed tr),Entropy
all_layers,0.692,0.014,0.148,0.042,0.029,0.041,0.025,0.401
layers_1_10,0.731,0.118,0.258,0.414,0.154,0.441,0.234,1.37
layers_4_10,0.73,0.115,0.26,0.407,0.149,0.437,0.225,1.474
layers_5_10,0.732,0.114,0.262,0.41,0.148,0.442,0.222,1.529


Unnamed: 0,AUC,Jaccard,AUPRC,AU - Precision,AU - Recall,Precision (fixed tr),Recall (fixed tr),Entropy
all_layers,0.698,0.073,0.318,0.389,0.09,0.409,0.104,0.354
layers_1_10,0.735,0.116,0.434,0.631,0.129,0.69,0.189,1.165
layers_4_10,0.741,0.116,0.447,0.631,0.128,0.704,0.189,1.245
layers_5_10,0.743,0.116,0.452,0.635,0.128,0.713,0.189,1.274


In [7]:
%%capture
# complete the graphs: fill the first column of the grid with the metrics in `temp`
metrics = list(temp["entailement"]["all_layers"].keys())

# set the titles
cols = ["Head means", "Head Sum", "Mean EVW"]
rows = metrics.copy()

for ax, col in zip(axes[0], cols):
    ax.set_title(col)

for ax, row in zip(axes[:, 0], rows):
    ax.set_ylabel(row, rotation=60, fontsize=15, labelpad=20)

for id_m, m in enumerate(metrics):
    ax = axes[id_m, 0]
    for label in ["entailement", "neutral", "contradiction"]:
        agreg_names = list(temp[label].keys())
        buff = [temp[label][agreg][m][0] for agreg in agreg_names]
        # one x position per aggregation instead of a hard-coded [1, 2, 3, 4]
        x = list(range(1, len(agreg_names) + 1))
        ax.scatter(x, buff, label=label)

    # ticks and legend are set once per axes, after every label has been drawn
    ax.set_xticks(x)
    ax.set_xticklabels(agreg_names, fontsize=15, rotation=80)
    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5), prop={"size": 10})

# No pyplot-level plt.legend() here: it acts on pyplot's *current* axes, which
# is not one of the axes populated above, and every axes already has its legend.

### The sum aggregation

In [8]:
# load the data
import pickle

# NOTE: pickle.load can execute arbitrary code; only load files you produced
# yourself (the notebook header says they come from entropy_study.ps1).
# The directory is not called `dir` so the builtin of that name stays usable.
entropy_dir = os.path.join(".cache", "plots", "entropy_study")
sum_files = (
    "a_true_sum.pickle",
    "all_layers_sum.pickle",
    "layers_1_10_sum.pickle",
    "layers_4_10_sum.pickle",
    "layers_5_10_sum.pickle",
)
loaded = []
for file_name in sum_files:
    with open(os.path.join(entropy_dir, file_name), "rb") as f:
        loaded.append(pickle.load(f))
# same order as sum_files
a_true, all_layers, layers_1_10, layers_4_10, layers_5_10 = loaded

In [9]:
with torch.no_grad():
    temp = {}
    # loaded score collections, keyed by the row name shown in the rendered table
    layer_groups = {
        "all_layers": all_layers,
        "layers_1_10": layers_1_10,
        "layers_4_10": layers_4_10,
        "layers_5_10": layers_5_10,
    }
    for k in ("entailement", "neutral", "contradiction"):
        display(HTML(f'<h4>metric for the label : {k}</h4>'))
        truth = a_true[k]
        metric_output = {}

        for group_name, group in layer_groups.items():
            scores = group[k]
            # each metric is wrapped in a one-element list: html_render indexes
            # the values by table number and the plotting cells read element [0]
            metric_output[group_name] = {
                "AUC": [roc_auc_score(truth, scores)],
                "Jaccard": [scalar_jaccard(truth, scores)],
                "AUPRC": [average_precision_score(truth, scores)],
                "AU - Precision": [au_precision_curve(truth, scores)],
                "AU - Recall": [au_recall_curve(truth, scores)],
                "Precision (fixed tr)": [precision(truth, scores)],
                "Recall (fixed tr)": [recall(truth, scores)],
                # stored as loaded, not re-wrapped
                "Entropy": group["entropy"][k],
            }

        # keep the per-label results for the plotting cell
        temp[k] = dict(metric_output)

        display(HTML(html_render(metric_output)))

Unnamed: 0,AUC,Jaccard,AUPRC,AU - Precision,AU - Recall,Precision (fixed tr),Recall (fixed tr),Entropy
all_layers,0.516,0.004,0.23,0.032,0.015,0.029,0.005,0.068
layers_1_10,0.636,0.063,0.333,0.432,0.077,0.427,0.073,0.113
layers_4_10,0.644,0.067,0.349,0.46,0.082,0.463,0.079,0.121
layers_5_10,0.644,0.068,0.35,0.464,0.082,0.469,0.081,0.124


Unnamed: 0,AUC,Jaccard,AUPRC,AU - Precision,AU - Recall,Precision (fixed tr),Recall (fixed tr),Entropy
all_layers,0.617,0.012,0.112,0.041,0.027,0.039,0.017,0.057
layers_1_10,0.72,0.158,0.261,0.454,0.201,0.452,0.204,0.082
layers_4_10,0.717,0.155,0.261,0.449,0.199,0.446,0.202,0.086
layers_5_10,0.718,0.156,0.263,0.45,0.199,0.446,0.203,0.086


Unnamed: 0,AUC,Jaccard,AUPRC,AU - Precision,AU - Recall,Precision (fixed tr),Recall (fixed tr),Entropy
all_layers,0.63,0.075,0.269,0.404,0.093,0.408,0.086,0.026
layers_1_10,0.72,0.133,0.425,0.682,0.15,0.683,0.149,0.073
layers_4_10,0.725,0.135,0.436,0.691,0.151,0.697,0.152,0.074
layers_5_10,0.726,0.137,0.439,0.699,0.153,0.704,0.153,0.074


In [10]:
%%capture
# complete the graphs: second column of the grid
metrics = list(temp["entailement"]["all_layers"].keys())
for row_idx, metric_name in enumerate(metrics):
    ax = axes[row_idx, 1]
    for label in ("entailement", "neutral", "contradiction"):
        per_agreg = temp[label]
        # first stored value of the metric for every aggregation, in dict order
        values = [per_agreg[agreg][metric_name][0] for agreg in per_agreg]

        positions = [1, 2, 3, 4]
        ax.scatter(positions, values, label=label)
        ax.set_xticks(positions)
        ax.set_xticklabels(list(per_agreg.keys()), fontsize=15, rotation=80)
        ax.legend(loc='center left', bbox_to_anchor=(1, 0.5), prop={"size": 10})

### Mean everywhere

In [11]:
# load the data
import pickle

# NOTE: pickle.load can execute arbitrary code; only load files you produced
# yourself (the notebook header says they come from entropy_study.ps1).
# The directory is not called `dir` so the builtin of that name stays usable.
entropy_dir = os.path.join(".cache", "plots", "entropy_study")
mean_files = (
    "a_true_mean.pickle",
    "all_layers_mean.pickle",
    "layers_1_10_mean.pickle",
    "layers_4_10_mean.pickle",
    "layers_5_10_mean.pickle",
)
loaded = []
for file_name in mean_files:
    with open(os.path.join(entropy_dir, file_name), "rb") as f:
        loaded.append(pickle.load(f))
# same order as mean_files
a_true, all_layers, layers_1_10, layers_4_10, layers_5_10 = loaded

In [12]:
with torch.no_grad():
    temp = {}
    # loaded score collections, keyed by the row name shown in the rendered table
    layer_groups = {
        "all_layers": all_layers,
        "layers_1_10": layers_1_10,
        "layers_4_10": layers_4_10,
        "layers_5_10": layers_5_10,
    }
    for k in ("entailement", "neutral", "contradiction"):
        display(HTML(f'<h4>metric for the label : {k}</h4>'))
        truth = a_true[k]
        metric_output = {}

        for group_name, group in layer_groups.items():
            scores = group[k]
            # each metric is wrapped in a one-element list: html_render indexes
            # the values by table number and the plotting cells read element [0]
            metric_output[group_name] = {
                "AUC": [roc_auc_score(truth, scores)],
                "Jaccard": [scalar_jaccard(truth, scores)],
                "AUPRC": [average_precision_score(truth, scores)],
                "AU - Precision": [au_precision_curve(truth, scores)],
                "AU - Recall": [au_recall_curve(truth, scores)],
                "Precision (fixed tr)": [precision(truth, scores)],
                "Recall (fixed tr)": [recall(truth, scores)],
                # stored as loaded, not re-wrapped
                "Entropy": group["entropy"][k],
            }

        # keep the per-label results for the plotting cell
        temp[k] = dict(metric_output)

        display(HTML(html_render(metric_output)))

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Unnamed: 0,AUC,Jaccard,AUPRC,AU - Precision,AU - Recall,Precision (fixed tr),Recall (fixed tr),Entropy
all_layers,0.622,0.036,0.283,0.114,0.041,0.072,0.001,2.891
layers_1_10,0.649,0.04,0.338,0.097,0.045,0.078,0.001,2.972
layers_4_10,0.656,0.041,0.355,0.111,0.046,0.077,0.001,2.86
layers_5_10,0.658,0.041,0.359,0.114,0.046,0.08,0.001,2.821


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Unnamed: 0,AUC,Jaccard,AUPRC,AU - Precision,AU - Recall,Precision (fixed tr),Recall (fixed tr),Entropy
all_layers,0.694,0.034,0.143,0.016,0.047,0.003,0.0,2.974
layers_1_10,0.694,0.035,0.184,0.028,0.048,0.004,0.0,3.045
layers_4_10,0.702,0.036,0.209,0.043,0.051,0.051,0.002,2.942
layers_5_10,0.708,0.038,0.22,0.053,0.052,0.066,0.003,2.905


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Unnamed: 0,AUC,Jaccard,AUPRC,AU - Precision,AU - Recall,Precision (fixed tr),Recall (fixed tr),Entropy
all_layers,0.708,0.044,0.325,0.08,0.051,0.071,0.001,2.955
layers_1_10,0.714,0.043,0.373,0.071,0.05,0.032,0.001,3.015
layers_4_10,0.727,0.045,0.412,0.093,0.053,0.039,0.001,2.914
layers_5_10,0.732,0.047,0.427,0.106,0.054,0.063,0.001,2.876


In [13]:
%%capture
# complete the graphs: third column of the grid, then save the whole figure
metrics = list(temp["entailement"]["all_layers"].keys())
for id_m, m in enumerate(metrics):
    ax = axes[id_m, 2]
    for label in ["entailement", "neutral", "contradiction"]:
        agreg_names = list(temp[label].keys())
        buff = [temp[label][agreg][m][0] for agreg in agreg_names]
        # one x position per aggregation instead of a hard-coded [1, 2, 3, 4]
        x = list(range(1, len(agreg_names) + 1))
        ax.scatter(x, buff, label=label)

    # ticks and legend are set once per axes, after every label has been drawn
    ax.set_xticks(x)
    ax.set_xticklabels(agreg_names, fontsize=15, rotation=80)
    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5), prop={"size": 10})

# Save through the Figure object: plt.legend()/plt.savefig() act on pyplot's
# *current* figure, which is not guaranteed to still be `fig` several cells
# after it was created (the inline backend may close figures after each cell),
# so they could decorate and save a different, empty figure.
fig.savefig(os.path.join(os.getcwd(), ".cache", "plots", "entropy_study", "metrics_graph.png"))

In [14]:
# write the summary figure next to the cached data it was built from
metrics_graph_path = os.path.join(os.getcwd(), ".cache", "plots", "entropy_study", "metrics_graph.png")
fig.savefig(metrics_graph_path)