<a href="https://colab.research.google.com/github/namdori61/colab-playground/blob/master/XAI_Sentiment_analysis_CNN_IMDB.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Reference
This notebook is based on the Captum tutorial "Interpreting text models: IMDB sentiment analysis": https://captum.ai/tutorials/IMDB_TorchText_Interpret

In [1]:
# Mount Google Drive (Colab only); the model checkpoint and the GloVe vectors
# used further down are read from paths under /content/gdrive.
from google.colab import drive
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [2]:
!pip install captum



In [0]:
import spacy

import torch
import torchtext
import torchtext.data
import torch.nn as nn
import torch.nn.functional as F

from torchtext.vocab import Vocab

from captum.attr import LayerIntegratedGradients, TokenReferenceBase, visualization

# spaCy pipeline; only its tokenizer is used below (in interpret_sentence).
# NOTE(review): the 'en' shortcut name is version-dependent — newer spaCy
# releases presumably require the full package name (e.g. 'en_core_web_sm'); confirm.
nlp = spacy.load('en')

In [0]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [0]:
class CNN(nn.Module):
    """Convolutional text classifier over token embeddings.

    Each filter size ``fs`` gets its own ``Conv2d`` whose kernel spans ``fs``
    consecutive tokens and the full embedding width. The per-filter feature
    maps are max-pooled over the sequence, concatenated, passed through
    dropout and projected to ``output_dim`` by a linear layer.

    Args:
        vocab_size: number of rows in the embedding table.
        embedding_dim: size of each token embedding.
        n_filters: number of output channels of every convolution.
        filter_sizes: iterable of kernel heights (in tokens), one conv each.
        output_dim: size of the final linear projection.
        dropout: dropout probability applied to the pooled features.
        pad_idx: index passed to ``nn.Embedding`` as ``padding_idx``.
    """

    def __init__(self, vocab_size, embedding_dim, n_filters,
                 filter_sizes, output_dim, dropout, pad_idx):
        # `super` must be *called*; `super.__init__()` raises TypeError, which
        # only stayed hidden because unpickling a saved model skips __init__.
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim,
                                      padding_idx=pad_idx)
        self.convs = nn.ModuleList(
            [
             nn.Conv2d(
                 in_channels=1,
                 out_channels=n_filters,
                 kernel_size=(fs, embedding_dim)
             )
             for fs in filter_sizes
            ]
        )
        self.fc = nn.Linear(len(filter_sizes) * n_filters, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, text):
        """Return the raw (pre-sigmoid) scores for a batch of token indices.

        Args:
            text: integer tensor of shape ``[batch_size, sent_len]``;
                ``sent_len`` must be at least the largest filter size.

        Returns:
            Tensor of shape ``[batch_size, output_dim]``.
        """
        # text = [batch_size, sent_len]
        embedded = self.embedding(text)
        # embedded = [batch_size, sent_len, emb_dim]; add the single input
        # channel expected by Conv2d -> [batch_size, 1, sent_len, emb_dim]
        embedded = embedded.unsqueeze(1)

        # each conv output = [batch_size, n_filters, sent_len - fs + 1]
        conved = [F.relu(conv(embedded)).squeeze(3) for conv in self.convs]
        # max over the sequence axis -> [batch_size, n_filters]
        pooled = [F.max_pool1d(conv, conv.shape[2]).squeeze(2) for conv in conved]

        cat = self.dropout(torch.cat(pooled, dim=1))

        return self.fc(cat)

In [6]:
# The checkpoint is a whole pickled model object (it is used directly as a
# module below), so the CNN class has to be defined before this cell runs.
# Unpickling can execute arbitrary code: only load checkpoints you trust.
# map_location lets a checkpoint saved on a GPU load on a CPU-only runtime.
model = torch.load('/content/gdrive/My Drive/tutorials/model/imdb-model-cnn.pt',
                   map_location=device)
model.eval()
model = model.to(device)



In [0]:
def forward_with_sigmoid(input):
    """Run the module-level ``model`` on ``input`` and apply a sigmoid to its output."""
    raw_scores = model(input)
    return torch.sigmoid(raw_scores)

In [0]:
# Field describing the review text: lower-cased and tokenized with spaCy.
TEXT = torchtext.data.Field(lower=True, tokenize='spacy')
# Field describing the sentiment label, stored as a float tensor.
Label = torchtext.data.LabelField(dtype = torch.float)

In [10]:
# Load the IMDB train/test splits (downloads the archive on first use).
train, test = torchtext.datasets.IMDB.splits(text_field=TEXT,
                                      label_field=Label,
                                      train='train',
                                      test='test')
# Keep only a small subset of the test split and discard the remainder.
# NOTE(review): presumably ~4% of the test examples are kept, and the split
# is unseeded so the subset likely differs between runs — confirm.
test, _ = test.split(split_ratio = 0.04)

downloading aclImdb_v1.tar.gz


aclImdb_v1.tar.gz: 100%|██████████| 84.1M/84.1M [00:06<00:00, 12.2MB/s]


In [11]:
from torchtext import vocab

# Pre-trained word vectors read from Drive (file name suggests GloVe 6B, 50-d).
loaded_vectors = torchtext.vocab.Vectors('/content/gdrive/My Drive/tutorials/data/pretrained/glove.6B.50d.txt')
# Build the text vocabulary from the training split, capped at the number of
# tokens present in the pre-trained vectors.
TEXT.build_vocab(train, vectors=loaded_vectors, max_size=len(loaded_vectors.stoi))
    
# Attach the pre-trained vectors to the vocabulary, then build the label vocabulary.
TEXT.vocab.set_vectors(stoi=loaded_vectors.stoi, vectors=loaded_vectors.vectors, dim=loaded_vectors.dim)
Label.build_vocab(train)

100%|█████████▉| 399683/400000 [00:10<00:00, 40467.42it/s]

In [12]:
# Sanity check: number of entries in the text vocabulary built above.
print('Vocabulary Size: ', len(TEXT.vocab))

Vocabulary Size:  101520


In [0]:
# Vocabulary index of the token used to build the attribution baseline.
# NOTE(review): this looks up the literal string 'pad', not TEXT.pad_token;
# presumably the model was trained with the field's own padding token —
# confirm which index is intended here.
PAD_IND = TEXT.vocab.stoi['pad']

In [0]:
# Helper that generates reference (baseline) sequences filled with PAD_IND.
token_reference = TokenReferenceBase(reference_token_idx=PAD_IND)

In [0]:
# Integrated Gradients computed with respect to the embedding layer.
# Note that it wraps `model` directly, i.e. the output before the sigmoid.
lig = LayerIntegratedGradients(model, model.embedding)

In [0]:
# Accumulates one visualization record per interpreted sentence.
vis_data_records_ig = []

def interpret_sentence(model, sentence, min_len = 7, label = 0):
    """Predict the sentiment of ``sentence`` and record its token attributions.

    The sentence is tokenized with the spaCy tokenizer, right-padded with the
    string 'pad' up to ``min_len`` tokens, scored with ``forward_with_sigmoid``
    and attributed with the module-level ``lig`` against a baseline produced
    by ``token_reference``. The prediction and convergence delta are printed
    and a record is appended to ``vis_data_records_ig``.

    Args:
        model: module whose gradients are zeroed before attribution. The
            prediction and the attribution themselves go through the
            module-level ``forward_with_sigmoid`` and ``lig``.
        sentence: raw input text.
        min_len: minimum number of tokens; shorter inputs are padded
            (presumably so the input is at least as long as the largest
            convolution filter — confirm).
        label: index of the true label in ``Label.vocab.itos``.

    Returns:
        None.
    """
    text = [tok.text for tok in nlp.tokenizer(sentence)]
    if len(text) < min_len:
        text += ['pad'] * (min_len - len(text))
    indexed = [TEXT.vocab.stoi[t] for t in text]

    model.zero_grad()

    # Add a leading batch dimension of size 1.
    input_indices = torch.tensor(indexed, device=device)
    input_indices = input_indices.unsqueeze(0)

    # The baseline must be exactly as long as the (possibly padded) input.
    # Using min_len here breaks for sentences longer than min_len tokens,
    # because the input and baseline shapes then differ.
    seq_length = input_indices.shape[1]

    pred = forward_with_sigmoid(input_indices).item()
    pred_ind = round(pred)

    reference_indices = token_reference.generate_reference(seq_length, device=device).unsqueeze(0)

    attributions_ig, delta = lig.attribute(input_indices, reference_indices, \
                                           n_steps=500, return_convergence_delta=True)

    print('pred: ', Label.vocab.itos[pred_ind], '(', '%.2f'%pred, ')', ', delta: ', abs(delta))

    add_attributions_to_visualizer(attributions_ig, text, pred, pred_ind, label, delta, vis_data_records_ig)
    
def add_attributions_to_visualizer(attributions, text, pred, pred_ind, label, delta, vis_data_records):
    """Turn raw attributions into a visualization record and store it.

    The attributions are summed over dimension 2, the leading dimension is
    squeezed away, and the result is divided by its norm before being
    converted to a NumPy array. A ``VisualizationDataRecord`` built from
    these scores is appended to ``vis_data_records``.

    Args:
        attributions: attribution tensor with at least three dimensions.
        text: tokens shown next to the scores.
        pred: predicted probability.
        pred_ind: index of the predicted label in ``Label.vocab.itos``.
        label: index of the true label in ``Label.vocab.itos``.
        delta: convergence delta reported by the attribution method.
        vis_data_records: list that receives the new record (mutated in place).
    """
    token_scores = attributions.sum(dim=2).squeeze(0)
    token_scores = (token_scores / torch.norm(token_scores)).cpu().detach().numpy()

    itos = Label.vocab.itos
    record = visualization.VisualizationDataRecord(
        token_scores,
        pred,
        itos[pred_ind],
        itos[label],
        itos[1],
        token_scores.sum(),
        text,
        delta,
    )
    vis_data_records.append(record)

In [17]:
# (sentence, true label index) pairs, interpreted in order.
examples = [
    ('It was a fantastic performance !', 1),
    ('Best film ever', 1),
    ('Such a great show!', 1),
    ('It was a horrible movie', 0),
    ('I\'ve never watched something as bad', 0),
    ('It is a disgusting movie!', 0),
]
for sentence, true_label in examples:
    interpret_sentence(model, sentence, label=true_label)

pred:  pos ( 0.99 ) , delta:  tensor([2.1500e-05], device='cuda:0', dtype=torch.float64)
pred:  pos ( 1.00 ) , delta:  tensor([6.6486e-05], device='cuda:0', dtype=torch.float64)
pred:  pos ( 1.00 ) , delta:  tensor([0.0003], device='cuda:0', dtype=torch.float64)
pred:  pos ( 0.69 ) , delta:  tensor([0.0003], device='cuda:0', dtype=torch.float64)
pred:  neg ( 0.22 ) , delta:  tensor([0.0011], device='cuda:0', dtype=torch.float64)
pred:  pos ( 0.80 ) , delta:  tensor([0.0008], device='cuda:0', dtype=torch.float64)


In [18]:
# Render every record collected by interpret_sentence as a word-importance table.
print('Visualize attributions based on Integrated Gradients')
visualization.visualize_text(vis_data_records_ig)

Visualize attributions based on Integrated Gradients


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
pos,pos (0.99),pos,-0.61,It was a fantastic performance ! pad
,,,,
pos,pos (1.00),pos,-1.1,Best film ever pad pad pad pad
,,,,
pos,pos (1.00),pos,-0.59,Such a great show ! pad pad
,,,,
neg,pos (0.69),pos,-1.82,It was a horrible movie pad pad
,,,,
neg,neg (0.22),pos,-2.45,I 've never watched something as bad
,,,,
