In this notebook, you'll use the method of integrated gradients to identify the tokens in the input that are most responsible for the predictions that a bigram CNN model is making.  Before running, install the captum library:

```
pip install captum
```

In [2]:
!pip install captum

Collecting captum
  Downloading captum-0.4.1-py3-none-any.whl (1.4 MB)
[K     |████████████████████████████████| 1.4 MB 2.4 MB/s eta 0:00:01
Installing collected packages: captum
Successfully installed captum-0.4.1


In [3]:
import torch
import numpy as np
import torch.nn as nn
import nltk
import random
from captum.attr import LayerIntegratedGradients, visualization

In [4]:
def get_batches(x, y, batch_size=12, pad_index=0):
    """
    Split parallel sequences of examples and labels into padded tensor batches.

    Arguments:
    - x:           sequence of token-index lists (one list per example)
    - y:           sequence of integer labels, parallel to `x`
    - batch_size:  number of examples per batch (the last batch may be smaller)
    - pad_index:   index appended to shorter examples so every row in a batch
                   has that batch's maximum length (default 0, the padding index)

    Returns
    - batches_x:   list of torch.LongTensor, each (n x max_len_in_batch), n <= batch_size
    - batches_y:   list of torch.LongTensor, each (n,)

    The caller's lists in `x` are NOT modified; padded copies are built instead.

    Raises ValueError if batch_size < 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    batches_x = []
    batches_y = []
    for start in range(0, len(x), batch_size):
        xbatch = x[start:start + batch_size]
        ybatch = y[start:start + batch_size]

        maxlen = max(len(sent) for sent in xbatch)

        # Build padded copies instead of extending the caller's lists in place:
        # in-place padding leaks into `x` and accumulates if the data is re-batched.
        padded = [list(sent) + [pad_index] * (maxlen - len(sent)) for sent in xbatch]

        batches_x.append(torch.LongTensor(padded))
        batches_y.append(torch.LongTensor(ybatch))

    return batches_x, batches_y

In [5]:
# Indices 0 and 1 are reserved: 0 pads short sequences, 1 stands in for
# words that have no pretrained embedding.
PAD_INDEX, UNKNOWN_INDEX = 0, 1

# NOTE(review): not referenced by any cell visible here; kept in case other cells use it.
data_lens = list()

def read_embeddings(filename, vocab_size=100000):
    """
    Utility function, loads the first `vocab_size - 2` embeddings from `filename`
    (presumably the most frequent words, if the file is frequency-sorted like GloVe).

    Each line of the file is expected to be "<word> <v1> <v2> ... <vd>" separated by
    single spaces. Rows 0 and 1 of the matrix are reserved for "[PAD]" and "[UNK]"
    and stay all-zero; the file's i-th line goes to row i + 2.

    Arguments:
    - filename:     path to file
                    the embedding dimension is inferred from the first line
    - vocab_size:   number of rows in the returned matrix (including the 2 reserved
                    rows), i.e. at most `vocab_size - 2` words are read

    Returns
    - embeddings:   torch.FloatTensor matrix of size (vocab_size x word_embedding_dim)
    - vocab:        dictionary mapping word (str) to index (int) in embedding matrix

    Raises ValueError if vocab_size < 2, if the file is empty, or if a line has a
    different number of values than the first line.
    """
    if vocab_size < 2:
        raise ValueError("vocab_size must be at least 2 to hold [PAD] and [UNK]")

    # get the embedding size from the first embedding; strip trailing whitespace
    # exactly as the main loop does, so both agree on the vector width
    with open(filename, encoding="utf-8") as file:
        first_line = file.readline().rstrip()
    if not first_line:
        raise ValueError("%s is empty; cannot infer the embedding dimension" % filename)
    word_embedding_dim = len(first_line.split(" ")) - 1

    vocab = {"[PAD]": 0, "[UNK]": 1}

    embeddings = np.zeros((vocab_size, word_embedding_dim))
    with open(filename, encoding="utf-8") as file:
        for idx, line in enumerate(file):
            row = idx + 2  # rows 0/1 are [PAD]/[UNK]
            if row >= vocab_size:
                break

            cols = line.rstrip().split(" ")
            if len(cols) - 1 != word_embedding_dim:
                raise ValueError("line %d of %s has %d values, expected %d"
                                 % (idx + 1, filename, len(cols) - 1, word_embedding_dim))
            # parse explicitly as floats rather than relying on str->float casting
            embeddings[row] = np.asarray(cols[1:], dtype=np.float64)
            vocab[cols[0]] = row

    return torch.FloatTensor(embeddings), vocab

In [6]:
embeddings, vocab = read_embeddings("../data/glove.6B.100d.100K.txt")
# inverse lookup: embedding-row index -> word
rev_vocab = {index: word for word, index in vocab.items()}

In [7]:
def read_labels(filename):
    """
    Map each distinct label in a tab-separated data file to an integer index.

    The label is the first tab-separated column of each line. Indices are assigned
    in order of first appearance (the first label seen gets 0). Blank lines are
    skipped, so a trailing newline cannot register a spurious label.

    Arguments:
    - filename:  path to a file of "<label>\t<text>" lines (read as UTF-8)

    Returns
    - dict mapping label string -> integer index
    """
    labels = {}
    with open(filename, encoding="utf-8") as file:
        for line in file:
            if not line.strip():
                continue
            label = line.split("\t")[0]
            # len(labels) is evaluated before insertion, so new labels get 0, 1, 2, ...
            labels.setdefault(label, len(labels))
    return labels

In [8]:
def read_data(filename, vocab, labels, max_data_points=1000, seed=None):
    """
    Read a tab-separated file of labeled texts into token-index lists, shuffled.

    :param filename: path to a file whose lines are "<label>\t<text>" (read as UTF-8);
                     blank lines are skipped
    :param vocab: mapping from lowercased NLTK token to embedding index; tokens not
                  in it map to UNKNOWN_INDEX
    :param labels: mapping from label string to integer class (e.g. from read_labels)
    :param max_data_points: keep at most this many examples after shuffling
                            (None keeps all)
    :param seed: optional int; when given, the shuffle uses a private
                 random.Random(seed) so the sample is reproducible; when None the
                 global `random` module state is used
    :return: pair (data, data_labels) of equal-length tuples, where data[i] is the
             list of token indices for example i and data_labels[i] its class index
             (both empty when the file has no examples)
    """
    data = []
    data_labels = []
    with open(filename, encoding="utf-8") as file:
        for line in file:
            if not line.strip():
                continue
            cols = line.split("\t")
            label = cols[0]
            text = cols[1]
            w_int = [vocab[w] if w in vocab else UNKNOWN_INDEX
                     for w in nltk.word_tokenize(text.lower())]
            data.append(w_int)
            data_labels.append(labels[label])

    # zip(*[]) cannot be unpacked into two names, so handle an empty file explicitly
    if not data:
        return (), ()

    # shuffle examples and labels together so they stay aligned
    pairs = list(zip(data, data_labels))
    rng = random if seed is None else random.Random(seed)
    rng.shuffle(pairs)
    data, data_labels = zip(*pairs)

    if max_data_points is None:
        return data, data_labels

    return data[:max_data_points], data_labels[:max_data_points]

In [9]:
def transform_data(text):
    """
    Convert raw `text` into a list of embedding indices.

    The text is lowercased and tokenized with nltk.word_tokenize. Each token is
    looked up in the global `vocab`; tokens missing from it become UNKNOWN_INDEX.
    """
    tokens = nltk.word_tokenize(text.lower())
    return [vocab[token] if token in vocab else UNKNOWN_INDEX for token in tokens]

In [10]:
# Point this at the data directory prepared in the CheckData_TODO.ipynb exercise;
# it must contain train.tsv, dev.tsv and test.tsv.
directory = "../data/lmrd"

In [11]:
labels = read_labels("{}/train.tsv".format(directory))
# inverse lookup: integer class -> label string
rev_labels = {index: label for label, index in labels.items()}

In [12]:
trainX, trainY = read_data(filename="{}/train.tsv".format(directory), vocab=vocab, labels=labels, max_data_points=10000)

In [13]:
devX, devY = read_data(filename="{}/dev.tsv".format(directory), vocab=vocab, labels=labels, max_data_points=100)

In [14]:
batch_trainX, batch_trainY = get_batches(x=trainX, y=trainY)
batch_devX, batch_devY = get_batches(x=devX, y=devY)

In [15]:
class CNNClassifier_bigram(nn.Module):

    """
    CNN over token bigrams.

    Architecture: frozen pretrained embeddings, then one 1-D convolution with a
    window of 2 tokens (`num_filters` filters, 96 by default), then tanh, then global
    max-pooling over the sequence, then a linear layer producing `num_labels` scores
    (2 by default).

    Inputs shorter than the 2-token window are right-padded with zero vectors
    before the convolution instead of crashing.
    """
    def __init__(self, pretrained_embeddings, num_filters=96, num_labels=2):
        super().__init__()

        self.num_filters = num_filters

        self.num_labels = num_labels

        # number of consecutive tokens each filter spans (2 = bigrams)
        self.window_size = 2

        _, embedding_dim = pretrained_embeddings.shape

        self.embeddings = nn.Embedding.from_pretrained(pretrained_embeddings, freeze=True)

        # convolution over 2 words
        self.conv_2 = nn.Conv1d(embedding_dim, self.num_filters, self.window_size, 1)

        self.fc = nn.Linear(self.num_filters, self.num_labels)

    def forward(self, input):

        # batch_size x max_seq_length x embeddings_size
        x0 = self.embeddings(input)

        # batch_size x embeddings_size x max_seq_length
        # (the input order expected by nn.Conv1d)
        x0 = x0.permute(0, 2, 1)

        # A sequence shorter than the window gives an empty convolution output, and
        # max-pooling over an empty dimension raises. Right-pad with zero vectors;
        # these equal the embedding of index 0 whenever row 0 of the embedding
        # matrix is all zeros (as read_embeddings builds it).
        short_by = self.window_size - x0.size(2)
        if short_by > 0:
            x0 = nn.functional.pad(x0, (0, short_by))

        # convolution
        x2 = self.conv_2(x0)
        # non-linearity
        x2 = torch.tanh(x2)
        # global max-pooling over the entire sequence
        x2 = torch.max(x2, 2)[0]

        out = self.fc(x2)

        return out

In [16]:
def evaluate(model, x, y):
    """
    Compute the classification accuracy of `model` over pre-batched data.

    Arguments:
    - model:  module whose forward maps an input batch to per-class scores
              (batch_size x num_labels); the predicted class is the argmax
    - x:      sequence of input batches (e.g. from get_batches)
    - y:      sequence of label LongTensors, parallel to `x`

    Returns the fraction (float) of examples whose predicted class equals the label.
    Leaves the model in eval mode.

    Raises ValueError when there are no examples, since accuracy is undefined then.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        # distinct loop names so the `x`/`y` parameters are not shadowed
        for batch_x, batch_y in zip(x, y):
            predictions = torch.argmax(model.forward(batch_x), dim=1)
            correct += int((predictions == batch_y).sum())
            total += predictions.size(0)
    if total == 0:
        raise ValueError("cannot compute accuracy on an empty dataset")
    return correct / total

In [17]:
def predict(model, x):
    """
    Return the predicted class index for every example in the batches `x`.

    Each batch is scored with `model.forward` (in eval mode, without gradient
    tracking). Each example's prediction is the argmax of its score row, as
    returned by numpy. Predictions from all batches are concatenated in order.
    """
    model.eval()
    predictions = []

    with torch.no_grad():
        for batch in x:
            scores = model.forward(batch).numpy()
            predictions.extend(np.argmax(row) for row in scores)

    return predictions

In [18]:
def train(model, model_filename, train_batches_x, train_batches_y, dev_batches_x, dev_batches_y,
          num_epochs=5, lr=0.001):
    """
    Train `model` with Adam and cross-entropy, keeping the checkpoint with the best dev accuracy.

    After every epoch the model is scored on the dev batches with `evaluate`. Whenever
    the dev accuracy beats the best so far, the state_dict is written to
    `model_filename`; the first epoch always counts as an improvement. At the end
    the best checkpoint is loaded back into `model` in place.

    Arguments:
    - model:            nn.Module returning (batch_size x num_labels) scores
    - model_filename:   path where the best state_dict is saved (overwritten)
    - train_batches_x / train_batches_y:  parallel lists of training input / label batches
    - dev_batches_x / dev_batches_y:      parallel lists of dev input / label batches
    - num_epochs:       passes over the training batches (must be >= 1)
    - lr:               Adam learning rate

    Raises ValueError if num_epochs < 1.
    """
    if num_epochs < 1:
        raise ValueError("num_epochs must be at least 1")

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
    cross_entropy = nn.CrossEntropyLoss()

    # None until the first checkpoint is written, so the first epoch always saves and
    # the final load never reads a missing file or a stale one from an earlier run
    best_dev_acc = None

    for epoch in range(num_epochs):
        model.train()

        for x, y in zip(train_batches_x, train_batches_y):
            y_pred = model.forward(x)
            # flatten to (examples x num_labels) using the model's own label count
            loss = cross_entropy(y_pred.view(-1, y_pred.size(-1)), y.view(-1))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        dev_accuracy = evaluate(model, dev_batches_x, dev_batches_y)

        # we're going to save the model that performs the best on *dev* data
        if best_dev_acc is None or dev_accuracy > best_dev_acc:
            torch.save(model.state_dict(), model_filename)
            previous_best = 0. if best_dev_acc is None else best_dev_acc
            print("%.3f is better than %.3f, saving model ..." % (dev_accuracy, previous_best))
            best_dev_acc = dev_accuracy
        print("Epoch %s, dev accuracy: %.3f" % (epoch, dev_accuracy))

    model.load_state_dict(torch.load(model_filename))
    print("\nBest Performing Model achieves dev accuracy of : %.3f" % (best_dev_acc))

First, let's train our model.

In [None]:
cnn_model = CNNClassifier_bigram(embeddings)
train(
    cnn_model,
    "cnn.bigram.model",
    batch_trainX,
    batch_trainY,
    batch_devX,
    batch_devY,
)

In [None]:
def interpret(x, y, model, vocab, rev_labels, rev_vocab, target_class=1):
    '''
    Visualize layer integrated-gradients attributions for one batch of examples.
    Adapted from https://captum.ai/tutorials/IMDB_TorchText_Interpret

    Arguments:
    - x:             LongTensor (batch_size x max_seq_length) of token indices
    - y:             LongTensor of gold label indices, one per row of `x`
    - model:         classifier with an `embeddings` layer; attributions are taken
                     with respect to that layer's output
    - vocab:         word -> index dict (unused; kept for call compatibility)
    - rev_labels:    label index -> label string
    - rev_vocab:     token index -> token string
    - target_class:  class index whose score is attributed (default 1)

    Renders one row per example with captum's visualization.visualize_text.
    Returns None.
    '''
    model.eval()
    _, maxlen = x.shape

    # baseline is uninformative sequence of padding tokens
    baseline = torch.LongTensor([[PAD_INDEX] * maxlen])

    # class probabilities for display; no gradients are needed for this pass
    with torch.no_grad():
        probs = torch.nn.functional.softmax(model.forward(x), dim=1).numpy()
    preds = [int(np.argmax(row)) for row in probs]

    # attribute with respect to the model passed in (not a global), at its embedding layer
    ig = LayerIntegratedGradients(model, model.embeddings)
    attributions, delta = ig.attribute(x, baseline, target=target_class, return_convergence_delta=True)

    # sum over the embedding dimension -> one score per token, (batch_size x max_seq_length).
    # No squeeze: with a single-example batch it would drop the batch axis and
    # attributions[0] would become one token's score.
    attributions = attributions.sum(dim=2).detach()
    # normalize each example by its own norm (as the captum tutorial does per example);
    # leave all-zero rows unscaled to avoid dividing by zero
    norms = torch.norm(attributions, dim=1, keepdim=True)
    norms[norms == 0] = 1.0
    attributions = (attributions / norms).numpy()
    delta = delta.detach().numpy()

    orig = [[rev_vocab[int(tok)] for tok in sent] for sent in x]

    y = y.numpy()

    records = []
    for idx, pred in enumerate(preds):
        records.append(visualization.VisualizationDataRecord(
                                    attributions[idx],
                                    # probability of the *predicted* class, shown next to its label
                                    probs[idx][pred],
                                    rev_labels[pred],
                                    rev_labels[y[idx]],
                                    rev_labels[target_class],
                                    attributions[idx].sum(),
                                    orig[idx],
                                    # this example's own convergence delta
                                    delta[idx]))
    visualization.visualize_text(records, legend=None)

**Q1**. Create a small set of toy examples here so you can examine this method's attributions on relatively short texts. How does this accord with your understanding of what a bigram CNN should be paying attention to?

In [None]:
x = [
    "The writing was amazing and Daniel Day-Lewis was terrific!",
    "Terrible!",
    "This movie is not bad",
    "Exactly what I was looking for.",
    "Outrageously good.",
]
y = ["pos", "neg", "pos", "pos", "pos"]
batch_x, batch_y = get_batches(
    [transform_data(sentence) for sentence in x],
    [labels[label] for label in y],
)
interpret(batch_x[0], batch_y[0], cnn_model, vocab, rev_labels, rev_vocab)

**Q2**. Read in a batch of your development data and examine the terms that are identified as being most important in the input.  Are they what you would expect?

In [None]:
interpret(x=batch_devX[0], y=batch_devY[0], model=cnn_model, vocab=vocab, rev_labels=rev_labels, rev_vocab=rev_vocab)