In [1]:
import os
import pickle
import argparse
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from bimpm import BIMPM
from dataset import SNLI
from util import prepare_output_dir

In [2]:
from args import conf, rawr_conf

In [7]:
class Batch:
    """Minimal mutable container with the same three fields as the torchtext
    SNLI batches iterated in this notebook: ``premise``, ``hypothesis`` and
    ``label``.
    """

    def __init__(self, premise=None, hypothesis=None, label=None):
        # Every field defaults to None; callers set only what they need.
        self.premise, self.hypothesis, self.label = premise, hypothesis, label
        
def to_text(x, vocab, pad_idx=1):
    """Convert a sequence of token ids into a space-separated string.

    Args:
        x: token ids, either as a tensor (CPU or CUDA, any integer dtype;
            legacy ``Variable`` objects are Tensors in modern PyTorch) or as
            any iterable of ints.
        vocab: index -> token lookup, e.g. ``data.TEXT.vocab.itos``.
        pad_idx: id treated as padding and dropped from the output. Defaults
            to 1, the padding id hard-coded throughout this notebook.

    Returns:
        The detokenized string with padding tokens removed.
    """
    if torch.is_tensor(x):
        # One generic path for every device/dtype instead of special-casing
        # LongTensor; other dtypes previously fell through to slow
        # element-by-element tensor iteration.
        x = x.detach().cpu().tolist()
    return ' '.join(vocab[w] for w in x if w != pad_idx)

def real_length(x, pad_idx=1):
    """Count the non-padding tokens in a 1-D sequence of token ids.

    Args:
        x: 1-D tensor (any device), numpy array, or list of token ids.
        pad_idx: id treated as padding. Defaults to 1, the padding id
            hard-coded throughout this notebook.

    Returns:
        Number of entries different from ``pad_idx`` as a Python ``int``, so
        it can be used directly in slices and arithmetic.
    """
    if torch.is_tensor(x):
        # Vectorised count. The builtin sum() iterated element by element in
        # Python and returned a 0-dim tensor instead of an int.
        return int((x.detach() != pad_idx).sum().item())
    # np.asarray makes plain lists work too (list != int is just a bool).
    return int(np.count_nonzero(np.asarray(x) != pad_idx))


[torchtext.data.batch.Batch of size 32 from SNLI]
	[.premise]:[torch.cuda.LongTensor of size 32x7 (GPU 0)]
	[.hypothesis]:[torch.cuda.LongTensor of size 32x7 (GPU 0)]
	[.label]:[torch.cuda.LongTensor of size 32 (GPU 0)]

In [13]:
def get_onehot_grad(model, batch, p_not_h=False):
    """Gradient-times-embedding saliency for each token on one side of a batch.

    Runs one forward/backward pass. The cross-entropy loss is taken against
    the model's own predicted label. The gradient w.r.t. the word embeddings
    is captured through the model's ``embed_grad_hook`` argument, and the
    per-token dot product ``grad . embedding`` is returned.

    Args:
        model: model exposing ``word_emb`` whose forward accepts
            ``embed_grad_hook`` and ``p_not_h`` keyword arguments.
        batch: object with ``premise`` and ``hypothesis`` token-id tensors of
            shape (batch, length).
        p_not_h: score premise tokens if True, hypothesis tokens otherwise.

    Returns:
        Detached tensor of shape (batch, length).
    """
    criterion = nn.CrossEntropyLoss()
    extracted_grads = {}

    def extract_grad_hook(name):
        def hook(grad):
            extracted_grads[name] = grad
        return hook

    tokens = batch.premise if p_not_h else batch.hypothesis
    batch_size, length = tokens.shape

    # Train mode is presumably needed because cuDNN RNNs refuse to run
    # backward in eval mode (TODO confirm for BIMPM). It also enables dropout,
    # so scores are stochastic. Remember the caller's mode and restore it.
    was_training = model.training
    model.train()
    try:
        output = model(batch.premise, batch.hypothesis,
                       embed_grad_hook=extract_grad_hook('embed'),
                       p_not_h=p_not_h)
        # Loss against the model's own prediction, not the gold label.
        label = torch.max(output, 1)[1]
        loss = criterion(output, label)
        loss.backward()
        embed_grad = extracted_grads['embed']
        # Re-embed the tokens without building another autograd graph.
        with torch.no_grad():
            embed = model.word_emb(tokens)
        # reshape() follows the same row-major order as the original
        # contiguous().view() flattening, so elements pair up identically.
        onehot_grad = (embed.reshape(batch_size, length, -1)
                       * embed_grad.detach().reshape(batch_size, length, -1)).sum(-1)
    finally:
        # backward() accumulated into parameter .grad; clear it so repeated
        # calls don't pile up, and hand the model back in the caller's mode.
        model.zero_grad()
        model.train(was_training)
    return onehot_grad

In [142]:
def get_grad_rank(batch, p_not_h=False):
    """Rank tokens from least to most important using gradient saliency.

    Uses the notebook-global ``model``, like the other helpers here.

    Args:
        batch: batch with ``premise`` / ``hypothesis`` token-id tensors.
        p_not_h: rank premise tokens if True, hypothesis tokens otherwise
            (the default, matching the original behaviour).

    Returns:
        One list per example of token positions inside the non-padding prefix
        (assumes padding is trailing). Each list is ordered so the first entry
        is the least important token and the last the most important.
    """
    one_hot_grad = get_onehot_grad(model, batch, p_not_h=p_not_h).detach().cpu().numpy()
    tokens = batch.premise if p_not_h else batch.hypothesis
    real_lengths = [real_length(x) for x in tokens]
    # s = grad(loss) . embedding is a first-order estimate of how the loss
    # (cross-entropy against the model's own prediction) changes when the
    # embedding is scaled up. Zeroing the embedding ("removing" the word)
    # changes the loss by roughly -s. So a large s means removing the word
    # lowers the loss, i.e. raises confidence in the prediction: the word is
    # unimportant. Sorting by -s puts the least important word first.
    rank = [np.argsort(-x[:l]).tolist() for x, l in zip(one_hot_grad, real_lengths)]
    return rank

In [143]:
def get_l1o_rank(batch):
    """Rank hypothesis tokens from least to most important by leave-one-out.

    For each position i, the hypothesis is re-scored with column i deleted.
    The drop in softmax confidence on the model's original prediction is
    recorded. Uses the notebook-global ``model`` and puts it in eval mode.

    Returns:
        One list per example of token positions inside the non-padding prefix
        (assumes padding is trailing). Each list is sorted by ascending
        confidence drop: the first entry is the least important token, the
        last the most important.
    """
    model.eval()
    x = batch.hypothesis
    seq_len = x.shape[1]
    real_lengths = [real_length(row) for row in x]

    if seq_len < 2:
        # Deleting the only column would leave an empty hypothesis (and
        # torch.cat([]) raises); the ranking is trivially the identity.
        return [list(range(l)) for l in real_lengths]

    # Pure inference: skip autograd bookkeeping (word_emb may require grad).
    with torch.no_grad():
        # Original prediction and its confidence.
        output = F.softmax(model(batch.premise, x), 1)
        target_scores, target = torch.max(output, 1)

        drops = []
        for i in range(seq_len):
            # Hypothesis with the i-th column removed (empty slices at either
            # end concatenate harmlessly).
            reduced = torch.cat([x[:, :i], x[:, i + 1:]], dim=1)
            new_output = F.softmax(model(batch.premise, reduced), 1)
            # New confidence on the original prediction.
            new_scores = new_output.gather(1, target.unsqueeze(1)).squeeze(1)
            # Small drop means the i-th word is unimportant.
            drops.append((target_scores - new_scores).cpu().numpy())

    drops = np.stack(drops, axis=1)  # (batch, seq_len)
    # Ascending drop: the least important word comes first.
    return [np.argsort(d[:l]).tolist() for d, l in zip(drops, real_lengths)]

In [147]:
# Load the dataset first: BIMPM(conf, data) needs `data`. The size fields
# written onto `conf` here are presumably read by the model constructor
# (TODO confirm in bimpm.py). The original cell order only worked because of
# hidden kernel state.
data = SNLI(conf)
conf.char_vocab_size = len(data.char_vocab)
conf.word_vocab_size = len(data.TEXT.vocab)
conf.class_size = len(data.LABEL.vocab)
conf.max_word_len = data.max_word_len
q_vocab = data.TEXT.vocab.itos  # token id -> word
a_vocab = data.LABEL.vocab.itos  # label id -> label name

# Restore the trained baseline. map_location lets a checkpoint saved on
# another device load onto conf.device.
model = BIMPM(conf, data)
model.load_state_dict(torch.load('results/baseline.pt', map_location=conf.device))
# The gradient hook used by get_onehot_grad needs the embedding output to
# require grad; presumably word_emb is frozen by default (TODO confirm).
model.word_emb.weight.requires_grad = True
model = model.to(conf.device).eval()

loading data from data/
loading vector cache from .vector_cache/vocab.vectors.pt


In [176]:
# Compare gradient-based vs. leave-one-out token rankings on the dev set.
# Each rank list is ordered least -> most important: [0] is the least and
# [-1] the most important token.
em_important, em_unimportant = [], []
top3_important, top3_unimportant = [], []
for batch in tqdm(data.dev_iter):
    grad_ranks = get_grad_rank(batch)
    l1o_ranks = get_l1o_rank(batch)
    for rank_grad, rank_l1o in zip(grad_ranks, l1o_ranks):
        if not rank_grad or not rank_l1o:
            # No non-padding tokens: nothing to compare ([0]/[-1] would fail).
            continue
        em_important.append(rank_grad[-1] == rank_l1o[-1])
        em_unimportant.append(rank_grad[0] == rank_l1o[0])
        # Is the gradient's pick among leave-one-out's 3 most/least important?
        top3_important.append(rank_grad[-1] in rank_l1o[-3:])
        top3_unimportant.append(rank_grad[0] in rank_l1o[:3])
print('exact match important', np.mean(em_important))
print('exact match unimportant', np.mean(em_unimportant))
print('top 3 important', np.mean(top3_important))
print('top 3 unimportant', np.mean(top3_unimportant))

100%|██████████| 308/308 [01:28<00:00,  3.47it/s]

exact match important 0.47561471245681775
exact match unimportant 0.37014834383255435
top 3 important 0.7985165616744564
top 3 unimportant 0.6394025604551921



