# Model analysis and heatmap

In [1]:
import numpy as np
import pandas as pd
import torch
import json
import spacy
from prettytable import PrettyTable
import os
import sys
import yaml
from yaml import Loader
from util import Dictionary


## Importing the local .py file
import plot_annotation_matrix as pa


#### Structure of the stored data

The useful data is a list of dictionaries (one per batch) with 5 fields:

| Name Of Field | percentage_correct_pred                        | label                           | prediction                                   | attention                                       | text                    |
|---------------|------------------------------------------------|---------------------------------|----------------------------------------------|-------------------------------------------------|-------------------------|
| Type          | float                                          | array of ints                   | array of ints                                | list of floats                                  | list of lists of strings |
| Desc          | Fraction of correct predictions in this batch  | The true rating of each review  | The predicted rating of each review          | The attention matrix values for the batch       | The text of each review  |


In [2]:
# Load the stored per-batch results: an object array of dicts (see the table
# above), which np.load only deserializes when allow_pickle=True.
# SECURITY NOTE(review): unpickling can execute arbitrary code, so only load
# .npy files you produced yourself.
data = np.load(file="content/useful_data.npy", allow_pickle=True)

In [3]:
# Print the fraction of correct predictions for (at most) the first 20 batches.
# Slicing instead of indexing data[i] for i in range(20) avoids an IndexError
# when the file holds fewer than 20 batches.
for i, batch_data in enumerate(data[:20]):
    print('percentage_correct_preddata', i, batch_data['percentage_correct_pred'])

percentage_correct_preddata 0 0.66
percentage_correct_preddata 1 0.56
percentage_correct_preddata 2 0.6
percentage_correct_preddata 3 0.6
percentage_correct_preddata 4 0.54
percentage_correct_preddata 5 0.66
percentage_correct_preddata 6 0.58
percentage_correct_preddata 7 0.74
percentage_correct_preddata 8 0.58
percentage_correct_preddata 9 0.6
percentage_correct_preddata 10 0.78
percentage_correct_preddata 11 0.86
percentage_correct_preddata 12 0.6
percentage_correct_preddata 13 0.58
percentage_correct_preddata 14 0.58
percentage_correct_preddata 15 0.56
percentage_correct_preddata 16 0.6
percentage_correct_preddata 17 0.76
percentage_correct_preddata 18 0.58
percentage_correct_preddata 19 0.58


In [4]:
# True ratings of every review in batch 11 (shown as the cell output).
data[11]['label']

array([4, 0, 1, 4, 3, 4, 3, 4, 4, 4, 4, 4, 3, 4, 4, 3, 4, 4, 3, 4, 2, 0,
       4, 4, 3, 4, 2, 4, 4, 4, 4, 2, 3, 3, 3, 3, 4, 4, 4, 0, 3, 2, 1, 4,
       4, 3, 4, 4, 4, 2])

In [5]:
# Predicted ratings of every review in batch 11 (shown as the cell output).
data[11]['prediction']

array([4, 4, 0, 4, 3, 4, 3, 4, 4, 4, 4, 4, 3, 4, 4, 4, 4, 4, 3, 4, 2, 0,
       4, 4, 3, 4, 2, 4, 4, 4, 4, 1, 3, 3, 3, 3, 4, 4, 4, 0, 4, 4, 2, 4,
       4, 3, 4, 4, 4, 2])

In [6]:
# How often each rating is predicted in batch 11. The predictions form a 1-D
# array, so a Series is the natural container: its value_counts is indexed by
# the rating itself, whereas DataFrame.value_counts forces a one-level
# MultiIndex.
pd.Series(data[11]['prediction']).value_counts()

4    32
3    10
2     4
0     3
1     1
dtype: int64

In [7]:
# How often each true rating occurs in batch 11. The labels form a 1-D array,
# so a Series is the natural container: its value_counts is indexed by the
# rating itself, whereas DataFrame.value_counts forces a one-level MultiIndex.
pd.Series(data[11]['label']).value_counts()

4    28
3    12
2     5
0     3
1     2
dtype: int64

In [8]:
# Render one LaTeX attention heatmap per review of the selected batch.
# Reviews with MAX_TOKENS tokens or more are skipped (NOTE(review): presumably
# because the heatmap would be too long to typeset; confirm).
# After the loop, `attention` holds the rescaled scores of the last rendered
# review, or None if every review was skipped.
BATCH_INDEX = 11
MAX_TOKENS = 500
batch = data[BATCH_INDEX]

# The output paths passed to pa.generate live in this folder; create it so the
# writes cannot fail on a fresh checkout (no-op if it already exists).
os.makedirs("heatmaps", exist_ok=True)

attention = None
# Iterate over every review actually present instead of assuming 50.
for i in range(len(batch['text'])):
    # pa.clean_word is a project helper; assumed to return the list of tokens
    # to display. TODO confirm.
    text = pa.clean_word(batch['text'][i])
    taille = len(text)  # number of tokens ("taille" = size)
    if taille >= MAX_TOKENS:
        continue
    # Keep one column per displayed token (presumably dropping padding
    # positions) and sum over the rows to get a single score per token.
    attention = pd.DataFrame(batch['attention'][i])
    attention = attention[list(range(taille))].sum(axis=0)
    attention = np.array(pa.rescale(attention))
    pa.generate(
        text,
        attention,
        "heatmaps/sample{}.tex".format(i),
        label=batch['label'][i],
        prediction=batch['prediction'][i],
    )

In [9]:
# Rescaled per-token attention scores of the last review rendered by the
# heatmap loop (None if that loop skipped every review).
attention

array([6.55447093e-06, 6.65515671e+01, 2.94320278e+01, 1.35551900e-01,
       1.47940156e-07, 3.15557719e-10, 1.00000000e+02, 9.98578873e+01,
       3.40661407e-02, 4.01927614e+00, 1.53042572e-06, 1.30460949e-08,
       4.66018605e-07, 3.78496838e-06, 1.85829041e-08, 2.33820941e-07,
       5.52758320e-06, 2.53599546e-06, 3.18948388e-08, 1.11881384e-08,
       6.64998788e-06, 6.20091285e-08, 6.81581355e-07, 6.05844974e-08,
       1.06690038e-07, 7.19324526e-05, 1.21802132e-05, 3.40697170e-07,
       1.26076301e-07, 6.34069750e-07, 5.04323168e-07, 4.74291227e-07,
       1.65380987e-09, 0.00000000e+00, 6.83726356e-08, 1.18540981e-06,
       2.70050915e-09, 5.06845943e-08, 1.34527212e-09, 3.49992852e-08,
       2.85037549e-06, 9.23593342e-03, 8.58236926e-10, 1.61679858e-09,
       7.23675129e-08, 9.99559326e+01, 5.60607838e-10, 4.00413825e-07,
       1.53217491e-06, 8.00950602e-07, 1.90215090e-07, 4.52599634e-04,
       9.91841944e-07, 2.65643780e-06, 3.06790149e-10, 3.61169214e-07,
      