In [73]:
##import and set up

import pytorch_lightning as pl
import torch
from torchtext import data
import spacy
from torchtext import datasets
import os
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from models.simple_rnn import RNN, CNN
import lime
from xai.shap import xai_shap

from captum.attr import LayerIntegratedGradients, TokenReferenceBase, visualization
from xai.visualize import interpret_sentence

nlp = spacy.load('en')


torch.backends.cudnn.deterministic = True
device='cpu'
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [66]:
#seed for reproducibility

SEED = 0

torch.manual_seed(SEED)


<torch._C.Generator at 0x20717c5dc10>

In [80]:
#for visualizing and interpreting

def forward_with_sigmoid(model,input):
    # print(torch.sigmoid(model(input)))
    # print(model(input))
    return torch.sigmoid(model(input))

# accumalate couple samples in this array for visualization purposes
vis_data_records_ig = []

def interpret_sentence(model, sentence,  TEXT, Label, min_len = 7, label = 0):
    PAD_IND = TEXT.vocab.stoi['pad']
    token_reference = TokenReferenceBase(reference_token_idx=PAD_IND)
    lig = LayerIntegratedGradients(model, model.embedding)

    text = [tok.text for tok in nlp.tokenizer(sentence)]
    if len(text) < min_len:
        text += ['pad'] * (min_len - len(text))
    indexed = [TEXT.vocab.stoi[t] for t in text]

    model.zero_grad()

    input_indices = torch.tensor(indexed, device=device)
    input_indices = input_indices.unsqueeze(0)
    
    # input_indices dim: [sequence_length]
    seq_length = min_len

    # predict
    pred = forward_with_sigmoid(model,input_indices)[0,0].item()
    # print("Pred : ", pred)
    pred_ind = round(pred)

    # generate reference indices for each sample
    reference_indices = token_reference.generate_reference(seq_length, device=device).unsqueeze(0)

    # compute attributions and approximation delta using layer integrated gradients
    attributions_ig, delta = lig.attribute(input_indices, reference_indices, \
                                           n_steps=500, return_convergence_delta=True)

    print('pred: ', Label.vocab.itos[pred_ind], '(', '%.2f'%pred, ')', ', delta: ', abs(delta))

    add_attributions_to_visualizer(attributions_ig, text, pred, pred_ind, label, delta, vis_data_records_ig, TEXT, Label)
    
def add_attributions_to_visualizer(attributions, text, pred, pred_ind, label, delta, vis_data_records, TEXT, Label ):
    attributions = attributions.sum(dim=2).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    attributions = attributions.cpu().detach().numpy()

    # storing couple samples in an array for visualization purposes
    vis_data_records.append(visualization.VisualizationDataRecord(
                            attributions,
                            pred,
                            Label.vocab.itos[pred_ind],
                            Label.vocab.itos[label],
                            Label.vocab.itos[1],
                            attributions.sum(),       
                            text,
                            delta))



In [68]:
#data set up
from torchtext import vocab
import torchtext
model = torch.load('models/imdb-model-cnn.pt')
model.eval()
model = model.to(device)   

#loaded_vectors = vocab.GloVe(name='6B', dim=50)
TEXT = torchtext.data.Field(lower=True, tokenize='spacy')
Label = torchtext.data.LabelField(dtype = torch.float)
train, test = torchtext.datasets.IMDB.splits(text_field=TEXT,
                                      label_field=Label,
                                      train='train',
                                      test='test'
)
                                      #path='data/aclImdb')
test= test.split(split_ratio = 0.04)
# If you prefer to use pre-downloaded glove vectors, you can load them with the following two command line
loaded_vectors = torchtext.vocab.Vectors('data/glove.6B.50d.txt')
TEXT.build_vocab(train, vectors=loaded_vectors, max_size=len(loaded_vectors.stoi))

 
TEXT.vocab.set_vectors(stoi=loaded_vectors.stoi, vectors=loaded_vectors.vectors, dim=loaded_vectors.dim)
Label.build_vocab(train)

PAD_IND = TEXT.vocab.stoi['pad']

token_reference = TokenReferenceBase(reference_token_idx=PAD_IND)



In [81]:
#examples to interpret
interpret_sentence(model, 'It was a fantastic performance !', TEXT=TEXT, Label=Label, label=1)
interpret_sentence(model, 'Best film ever', TEXT=TEXT, Label=Label, label=1)
interpret_sentence(model, 'Such a great show!', TEXT=TEXT, Label=Label, label=1)
interpret_sentence(model, 'It was a horrible movie', TEXT=TEXT, Label=Label, label=0)
interpret_sentence(model, 'I\'ve never watched something as bad', TEXT=TEXT, Label=Label, label=0)
interpret_sentence(model, 'It is a disgusting movie!', TEXT=TEXT, Label=Label, label=0)

pred:  pos ( 0.99 ) , delta:  tensor([2.2198e-05], dtype=torch.float64)
pred:  pos ( 1.00 ) , delta:  tensor([6.6302e-05], dtype=torch.float64)
pred:  pos ( 1.00 ) , delta:  tensor([0.0003], dtype=torch.float64)
pred:  pos ( 0.69 ) , delta:  tensor([0.0003], dtype=torch.float64)
pred:  neg ( 0.22 ) , delta:  tensor([0.0011], dtype=torch.float64)
pred:  pos ( 0.80 ) , delta:  tensor([0.0008], dtype=torch.float64)


In [82]:
print('Visualize attributions based on Integrated Gradients')
visualization.visualize_text(vis_data_records_ig)
vis_data_records_ig = [] #to clear for future references

Visualize attributions based on Integrated Gradients


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
pos,pos (0.99),pos,-0.61,It was a fantastic performance ! pad
,,,,
pos,pos (1.00),pos,-1.1,Best film ever pad pad pad pad
,,,,
pos,pos (1.00),pos,-0.59,Such a great show ! pad pad
,,,,
neg,pos (0.69),pos,-1.82,It was a horrible movie pad pad
,,,,
neg,neg (0.22),pos,-2.45,I 've never watched something as bad
,,,,


In [75]:

##running with our basic 50% rnn
#call model
model_rnn = RNN()

#get data
text, labels, train_data, test_data, valid_data, train_iterator, valid_iterator, test_iterator = model_rnn.preprocess_data(device, BATCH_SIZE=64)

# init model
INPUT_DIM = len(text.vocab)
EMBEDDING_DIM = 100
HIDDEN_DIM = 256
OUTPUT_DIM = 1

model_rnn.create_model(INPUT_DIM, EMBEDDING_DIM, HIDDEN_DIM, OUTPUT_DIM)

model_rnn.load_state_dict(torch.load("./models/model_imdb.pt"))

model_rnn.eval()



RNN(
  (embedding): Embedding(20002, 100)
  (rnn): LSTM(100, 256)
  (fc): Linear(in_features=256, out_features=1, bias=True)
  (output): Softmax(dim=None)
)

In [83]:
interpret_sentence(model_rnn, 'It was a fantastic performance !', TEXT=text, Label=labels, label=1)
interpret_sentence(model_rnn, 'Best film ever', TEXT=text, Label=labels, label=1)
interpret_sentence(model_rnn, 'Such a great show!', TEXT=text, Label=labels, label=1)
interpret_sentence(model_rnn, 'It was a horrible movie', TEXT=text, Label=labels, label=0)
interpret_sentence(model_rnn, 'I\'ve never watched something as bad', TEXT=text, Label=labels, label=0)
interpret_sentence(model_rnn, 'It is a disgusting movie!', TEXT=text, Label=labels, label=0)

pred:  neg ( 0.73 ) , delta:  tensor([0., 0., 0., 0., 0., 0., 0.], dtype=torch.float64)
pred:  neg ( 0.73 ) , delta:  tensor([0., 0., 0., 0., 0., 0., 0.], dtype=torch.float64)
pred:  neg ( 0.73 ) , delta:  tensor([0., 0., 0., 0., 0., 0., 0.], dtype=torch.float64)
pred:  neg ( 0.73 ) , delta:  tensor([0., 0., 0., 0., 0., 0., 0.], dtype=torch.float64)
pred:  neg ( 0.73 ) , delta:  tensor([0., 0., 0., 0., 0., 0., 0.], dtype=torch.float64)
pred:  neg ( 0.73 ) , delta:  tensor([0., 0., 0., 0., 0., 0., 0.], dtype=torch.float64)


In [79]:
print('Visualize attributions based on Integrated Gradients')
visualization.visualize_text(vis_data_records_ig)
vis_data_records_ig = [] #to clear for future references

Visualize attributions based on Integrated Gradients


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
neg,neg (0.73),neg,,It was a fantastic performance ! pad
,,,,
neg,neg (0.73),neg,,Best film ever pad pad pad pad
,,,,
neg,neg (0.73),neg,,Such a great show ! pad pad
,,,,
pos,neg (0.73),neg,,It was a horrible movie pad pad
,,,,
pos,neg (0.73),neg,,I 've never watched something as bad
,,,,
