https://captum.ai/tutorials/IMDB_TorchText_Interpret

In [91]:
import spacy

import torch
import torchtext
import torchtext.data
import torch.nn as nn
import torch.nn.functional as F

from torchtext.vocab import Vocab

from captum.attr import LayerIntegratedGradients, TokenReferenceBase, visualization
# from interpretation_package.flair_model_wrapper import ModelWrapper
nlp = spacy.load('en')

In [92]:
nlp.vocab.__len__()

489

In [93]:
len(nlp.vocab.strings)

1160

In [64]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [65]:
class CNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim, n_filters, filter_sizes, output_dim, 
                 dropout, pad_idx):
        
        super().__init__()
                
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx = pad_idx)
        
        self.convs = nn.ModuleList([
                                    nn.Conv2d(in_channels = 1, 
                                              out_channels = n_filters, 
                                              kernel_size = (fs, embedding_dim)) 
                                    for fs in filter_sizes
                                    ])
        
        self.fc = nn.Linear(len(filter_sizes) * n_filters, output_dim)

        self.dropout = nn.Dropout(dropout)
        
    def forward(self, text):
        
        #text = [sent len, batch size]
        
        #text = text.permute(1, 0)
                
        #text = [batch size, sent len]
        
        embedded = self.embedding(text)

        #embedded = [batch size, sent len, emb dim]
        
        embedded = embedded.unsqueeze(1)
        
        #embedded = [batch size, 1, sent len, emb dim]
        
        conved = [F.relu(conv(embedded)).squeeze(3) for conv in self.convs]
            
        #conved_n = [batch size, n_filters, sent len - filter_sizes[n] + 1]
                
        pooled = [F.max_pool1d(conv, conv.shape[2]).squeeze(2) for conv in conved]
        
        #pooled_n = [batch size, n_filters]
        
        cat = self.dropout(torch.cat(pooled, dim = 1))

        #cat = [batch size, n_filters * len(filter_sizes)]
            
        return self.fc(cat)

In [66]:
model = torch.load('model/imdb-model-cnn.pt')
model.eval()
model = model.to(device)



In [67]:
def forward_with_sigmoid(input):
    return torch.sigmoid(model(input))

In [68]:
TEXT = torchtext.data.Field(lower=True, tokenize='spacy')
Label = torchtext.data.LabelField(dtype = torch.float)



In [70]:
train, test = torchtext.datasets.IMDB.splits(text_field=TEXT,
                                      label_field=Label,
                                      train='train',
                                      test='test',
                                      path='data/aclImdb')
test, _ = test.split(split_ratio = 0.04)




In [71]:
from torchtext import vocab

loaded_vectors = vocab.GloVe(name='6B', dim=50)

In [100]:
len(loaded_vectors)

400000

In [72]:
TEXT.build_vocab(train, vectors=loaded_vectors, max_size=len(loaded_vectors.stoi))
TEXT.vocab.set_vectors(stoi=loaded_vectors.stoi, vectors=loaded_vectors.vectors, dim=loaded_vectors.dim)
Label.build_vocab(train)

In [73]:
print('Vocabulary Size: ', len(TEXT.vocab))

Vocabulary Size:  101513


In [55]:
PAD_IND = TEXT.vocab.stoi['pad']

In [None]:
# need training again

In [74]:
model.train()

CNN(
  (embedding): Embedding(101982, 50, padding_idx=1)
  (convs): ModuleList(
    (0): Conv2d(1, 100, kernel_size=(3, 50), stride=(1, 1))
    (1): Conv2d(1, 100, kernel_size=(4, 50), stride=(1, 1))
    (2): Conv2d(1, 100, kernel_size=(5, 50), stride=(1, 1))
  )
  (fc): Linear(in_features=300, out_features=1, bias=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.5, inplace=False)
)

In [56]:
token_reference = TokenReferenceBase(reference_token_idx=PAD_IND)

In [57]:
lig = LayerIntegratedGradients(model, model.embedding)

In [58]:
Label.vocab.itos = ['neg','pos']

In [59]:
vis_data_records_ig = []

def interpret_sentence(model, sentence, min_len = 7, label = 0):
    text = [tok.text for tok in nlp.tokenizer(sentence)]
    if len(text) < min_len:
        text += ['pad'] * (min_len - len(text))
    indexed = [TEXT.vocab.stoi[t] for t in text]

    model.zero_grad()

    input_indices = torch.tensor(indexed, device=device)
    input_indices = input_indices.unsqueeze(0)
    
    # input_indices dim: [sequence_length]
    seq_length = min_len

    # predict
    pred = forward_with_sigmoid(input_indices).item()
    pred_ind = round(pred)

    # generate reference indices for each sample
    reference_indices = token_reference.generate_reference(seq_length, device=device).unsqueeze(0)

    # compute attributions and approximation delta using layer integrated gradients
    attributions_ig, delta = lig.attribute(input_indices, reference_indices, \
                                           n_steps=500, return_convergence_delta=True)
    print(pred_ind, pred)
    print('pred: ', Label.vocab.itos[pred_ind], '(', '%.2f'%pred, ')', ', delta: ', abs(delta))
    
    add_attributions_to_visualizer(attributions_ig, text, pred, pred_ind, label, delta, vis_data_records_ig)
    
def add_attributions_to_visualizer(attributions, text, pred, pred_ind, label, delta, vis_data_records):
    attributions = attributions.sum(dim=2).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    attributions = attributions.cpu().detach().numpy()

    # storing couple samples in an array for visualization purposes
    vis_data_records.append(visualization.VisualizationDataRecord(
                            attributions,
                            pred,
                            Label.vocab.itos[pred_ind],
                            Label.vocab.itos[label],
                            Label.vocab.itos[1],
                            attributions.sum(),       
                            text,
                            delta))

In [60]:
interpret_sentence(model, 'It was a fantastic performance !', label=1)
interpret_sentence(model, 'Best film ever', label=1)
interpret_sentence(model, 'Such a great show!', label=1)
interpret_sentence(model, 'It was a horrible movie', label=0)
interpret_sentence(model, 'I\'ve never watched something as bad', label=0)
interpret_sentence(model, 'It is a disgusting movie!', label=0)

1 0.989221453666687
pred:  pos ( 0.99 ) , delta:  tensor([2.1484e-05], device='cuda:0', dtype=torch.float64)
1 0.9983250498771667
pred:  pos ( 1.00 ) , delta:  tensor([6.6464e-05], device='cuda:0', dtype=torch.float64)
1 0.9967617392539978
pred:  pos ( 1.00 ) , delta:  tensor([0.0003], device='cuda:0', dtype=torch.float64)
1 0.6885504722595215
pred:  pos ( 0.69 ) , delta:  tensor([0.0003], device='cuda:0', dtype=torch.float64)
0 0.21652160584926605
pred:  neg ( 0.22 ) , delta:  tensor([0.0011], device='cuda:0', dtype=torch.float64)
1 0.8049034476280212
pred:  pos ( 0.80 ) , delta:  tensor([0.0008], device='cuda:0', dtype=torch.float64)


In [61]:
visualization.visualize_text(vis_data_records_ig)

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
pos,pos (0.99),pos,-0.61,It was a fantastic performance ! pad
,,,,
pos,pos (1.00),pos,-1.1,Best film ever pad pad pad pad
,,,,
pos,pos (1.00),pos,-0.59,Such a great show ! pad pad
,,,,
neg,pos (0.69),pos,-1.82,It was a horrible movie pad pad
,,,,
neg,neg (0.22),pos,-2.45,I 've never watched something as bad
,,,,
