#### Run relevance back-propagation (LRP) here

In [1]:
import pickle
import re
import os

import random
import numpy as np
import torch
from random import shuffle
import argparse
import pickle

import collections

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import sys
sys.path.append("..")

from model.BERT import *

from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, SequentialSampler
from tqdm import tqdm, trange

from util.optimization import BERTAdam
from util.processor import *

from util.tokenization import *

from util.evaluation import *

import logging
# Timestamped, INFO-level logging shared by the whole notebook.
# NOTE(review): the captured log lines further down are tagged "run_classifier",
# which suggests the later `from run_classifier import *` rebinds `logger` — confirm.
logging.basicConfig(
    level=logging.INFO,
    datefmt='%m/%d/%Y %H:%M:%S',
    format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
)
logger = logging.getLogger(__name__)

from sklearn.metrics import classification_report

# this imports most of the helpers needed to eval the model
from run_classifier import *

# Root folder holding the per-task results (e.g. <task>/checkpoint.bin).
lrp_data_dir = "../../results"
# BERT-base WordPiece vocabulary, used to map token ids back to strings.
vocab_data_dir = "../../data/uncased_L-12_H-768_A-12/vocab.txt"
# ".." is already on sys.path (appended near the top of this cell), so it is
# not appended a second time here.
import operator

#### Setup

In [2]:
# Note that this notebook only supports single-GPU evaluation, which is
# sufficient for most tasks if you lower the batch size.
IS_CUDA = False
if IS_CUDA:
    CUDA_DEVICE = "cuda:5"
    device = torch.device(CUDA_DEVICE)
    n_gpu = torch.cuda.device_count()
    logger.info("device %s in total n_gpu %d distributed training", device, n_gpu)
else:
    # bad luck, we are on CPU now!
    logger.info("gpu is out of the picture, let us use CPU")
    device = torch.device("cpu")
    # Define n_gpu on this path too, so later cells never hit a NameError.
    n_gpu = 0

10/30/2020 01:23:11 - INFO - run_classifier -   gpu is out of the picture, let us use CPU


#### Specify your task and folders

In [3]:
TASK_NAME = "SST5"
# Keep DATA_DIR in sync with TASK_NAME when switching tasks.
DATA_DIR = "../../data/dataset/SST5/"

# "../../data/uncased_L-12_H-768_A-12/" is for the default BERT-base pretrain
BERT_PATH = "../../data/uncased_L-12_H-768_A-12/"
MODEL_PATH = "../../results/" + TASK_NAME + "/checkpoint.bin"
EVAL_BATCH_SIZE = 24  # lower this if you run out of (GPU) memory.

# Map each supported task name to its data processor class.
processors = {
    "IMDb": IMDb_Processor,
    "SemEval": SemEval_Processor,
    "SST5": SST5_Processor,
    "SST2": SST2_Processor,
    "SST3": SST3_Processor,
    "Yelp5": Yelp5_Processor,
    "Yelp2": Yelp2_Processor,
    "AdvSA": AdvSA_Processor
}

# Fail early with a readable message instead of a bare KeyError('...').
if TASK_NAME not in processors:
    raise KeyError("Unknown TASK_NAME %r; expected one of %s"
                   % (TASK_NAME, sorted(processors)))
processor = processors[TASK_NAME]()
label_list = processor.get_labels()

In [4]:
# Build the model, optimizer and tokenizer from the fine-tuned checkpoint at
# MODEL_PATH, with LRP support enabled (init_lrp=True).
model, optimizer, tokenizer = getModelOptimizerTokenizer(
    model_type="BERTPretrain",
    vocab_file=BERT_PATH + "vocab.txt",
    bert_config_file=BERT_PATH + "bert_config.json",
    embed_file=None,
    init_checkpoint=MODEL_PATH,
    label_list=label_list,
    do_lower_case=True,
    init_lrp=True,
    # The optimizer/schedule settings below are required by the helper but
    # are not used during evaluation.
    num_train_steps=20,
    learning_rate=2e-5,
    base_learning_rate=2e-5,
    warmup_proportion=0.1,
)
model = model.to(device)  # move the parameters onto the selected device

10/30/2020 01:23:11 - INFO - run_classifier -   model = BERTPretrain


init_weight = True
init_lrp = True


In [5]:
# Every example is converted to exactly this many token positions; the eval
# loop later trims each batch back to its longest real sequence.
MAX_SEQ_LENGTH = 512

test_examples = processor.get_test_examples(DATA_DIR)
test_features = convert_examples_to_features(
    test_examples,
    label_list,
    MAX_SEQ_LENGTH,
    tokenizer)

# One int64 tensor per feature field, row-aligned across fields.
all_input_ids = torch.tensor([f.input_ids for f in test_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in test_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in test_features], dtype=torch.long)
all_label_ids = torch.tensor([f.label_id for f in test_features], dtype=torch.long)
# Wrapped in a list so each row has shape (1,), as the eval loop expects.
all_seq_len = torch.tensor([[f.seq_len] for f in test_features], dtype=torch.long)

test_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                          all_label_ids, all_seq_len)
# shuffle=False keeps batches in dataset order so scores align with examples.
test_dataloader = DataLoader(test_data, batch_size=EVAL_BATCH_SIZE, shuffle=False)

 14%|█▍        | 315/2210 [00:00<00:00, 3137.66it/s]

0
guid= test-0
text_a= no movement , no yuks , not much of anything .
text_b= None
label= 1
1000
guid= test-1000
text_a= has all the poignancy of a hallmark card and all the comedy of a gallagher stand up act .
text_b= None
label= 2
2000
guid= test-2000
text_a= it 's still worth a look .
text_b= None
label= 3


100%|██████████| 2210/2210 [00:00<00:00, 2444.65it/s]


#### Run the evaluation loop to get accuracy and attribution (LRP) scores

In [None]:
# Gradients are intentionally left enabled (no torch.no_grad()): the LRP
# backward pass (model.backward_lrp) runs on every batch's forward pass.
model.eval()  # deactivates dropout
test_loss, test_accuracy = 0, 0
nb_test_steps, nb_test_examples = 0, 0
pred_logits = []
actual = []

lrp_scores = []
inputs_ids = []
seqs_lens = []

# Relevance is propagated back from this output class (the last label).
LRP_class = len(label_list) - 1

for step, batch in enumerate(tqdm(test_dataloader, desc="Iteration")):
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    input_ids, input_mask, segment_ids, label_ids, seq_lens = batch
    # Truncate to the longest sequence in this batch to save memory/compute.
    max_seq_lens = int(seq_lens.max())
    input_ids = input_ids[:, :max_seq_lens].to(device)
    input_mask = input_mask[:, :max_seq_lens].to(device)
    segment_ids = segment_ids[:, :max_seq_lens].to(device)
    label_ids = label_ids.to(device)
    seq_lens = seq_lens.to(device)

    # Forward pass, intentionally with gradient tracking.
    tmp_test_loss, logits = \
        model(input_ids, segment_ids, input_mask, seq_lens,
              device=device, labels=label_ids)

    # LRP: start from the logit of LRP_class only, zeroing every other class.
    Rout_mask = torch.zeros((input_ids.shape[0], len(label_list)), device=device)
    Rout_mask[:, LRP_class] = 1.0
    relevance_score = logits * Rout_mask
    # The last dim of the returned relevance is summed away — presumably the
    # hidden dimension, leaving one score per token (TODO confirm).
    lrp_score = model.backward_lrp(relevance_score).sum(dim=-1).cpu().data
    input_ids = input_ids.cpu().data
    seq_lens = seq_lens.cpu().data
    lrp_scores.append(lrp_score)
    inputs_ids.append(input_ids)
    seqs_lens.append(seq_lens)

    # TODO: gradient-based and attention-only attribution are not implemented.

    logits = F.softmax(logits, dim=-1)
    logits = logits.detach().cpu().numpy()
    pred_logits.append(logits)
    label_ids = label_ids.to('cpu').numpy()
    actual.append(label_ids)
    outputs = np.argmax(logits, axis=1)
    tmp_test_accuracy = np.sum(outputs == label_ids)

    test_loss += tmp_test_loss.mean().item()
    test_accuracy += tmp_test_accuracy

    nb_test_examples += input_ids.size(0)
    nb_test_steps += 1

test_loss = test_loss / nb_test_steps
test_accuracy = test_accuracy / nb_test_examples

result = collections.OrderedDict([
    ('test_loss', test_loss),
    (str(len(label_list)) + '-class test_accuracy', test_accuracy),
])
logger.info("***** Eval results *****")
for key in result.keys():
    logger.info("  %s = %s\n", key, str(result[key]))
# get predictions needed for evaluation
pred_logits = np.concatenate(pred_logits, axis=0)
actual = np.concatenate(actual, axis=0)
pred_label = np.argmax(pred_logits, axis=-1)

lrp_state_dict = dict()
lrp_state_dict["lrp_scores"] = lrp_scores
lrp_state_dict["inputs_ids"] = inputs_ids
lrp_state_dict["seqs_lens"] = seqs_lens
# Persist the per-batch relevance results to <lrp_data_dir>/<TASK_NAME>/lrp_state.pt
# so the aggregation section (load_lrp_states) can reload them later.
lrp_save_dir = os.path.join(lrp_data_dir, TASK_NAME)
os.makedirs(lrp_save_dir, exist_ok=True)
torch.save(lrp_state_dict, os.path.join(lrp_save_dir, "lrp_state.pt"))
logger.info("***** Finish LRP *****")

Iteration:  20%|██        | 19/93 [03:08<11:37,  9.42s/it]

#### LRP scores per token, aggregated across the dataset

In [1]:
def load_lrp_states(task, data_dir=None):
    """Load the per-batch LRP results saved for ``task``.

    Args:
        task: task name; the file read is ``<data_dir>/<task>/lrp_state.pt``.
        data_dir: root results directory; defaults to the module-level
            ``lrp_data_dir`` when None.

    Returns:
        ``(lrp_scores, inputs_ids, seqs_lens)`` exactly as stored in the saved dict.

    Note: torch.load unpickles the file — only load files you trust.
    """
    if data_dir is None:
        data_dir = lrp_data_dir
    # Join as separate components: plain concatenation dropped the separator
    # ("../../results" + "SST5" -> "../../resultsSST5").
    path = os.path.join(data_dir, task, "lrp_state.pt")
    # map_location="cpu" lets states saved during a GPU run load on CPU-only hosts.
    lrp_state_dict = torch.load(path, map_location="cpu")
    lrp_scores = lrp_state_dict["lrp_scores"]
    inputs_ids = lrp_state_dict["inputs_ids"]
    seqs_lens = lrp_state_dict["seqs_lens"]
    return lrp_scores, inputs_ids, seqs_lens

def inverse_mapping(vocab_dict):
    """Return a new dict mapping each value of ``vocab_dict`` back to its key.

    If several keys share a value, the key iterated last wins.
    """
    return {value: key for key, value in vocab_dict.items()}

def translate(token_ids, vocab):
    """Map each id in ``token_ids`` (anything with ``.tolist()``) to ``vocab[id]``.

    Returns a list of tokens in the same order; raises KeyError for ids
    missing from ``vocab``.
    """
    return [vocab[token_id] for token_id in token_ids.tolist()]

SST-5

In [96]:
# Per-batch LRP results saved for the SST-5 task.
lrp_scores, inputs_ids, seqs_lens = load_lrp_states(task="SST5")
# id -> word-piece lookup table, used to turn token ids back into strings.
vocab = inverse_mapping(
    load_vocab(vocab_data_dir, pretrain=False)
)

In [102]:
# Only tokens seen at least this many times receive an averaged score.
# 1 keeps every token; raise it (e.g. to 5) so tokens seen only once or twice
# do not dominate the averaged ranking.
MIN_TOKEN_COUNT = 1

word_lrp = {}       # token -> list of relevance scores, one per occurrence
word_lrp_list = []  # (token, score) for every token occurrence in the dataset
for batch_ids, batch_scores, batch_lens in zip(inputs_ids, lrp_scores, seqs_lens):
    for ids, scores, length_row in zip(batch_ids, batch_scores, batch_lens):
        seq_len = length_row.tolist()[0]
        # Slice before translating so padding positions are never looked up.
        tokens = translate(ids[:seq_len], vocab)
        lrp_ss = scores[:seq_len].tolist()
        for token, score in zip(tokens, lrp_ss):
            word_lrp_list.append((token, score))
            word_lrp.setdefault(token, []).append(score)

# Mean relevance per token, restricted to sufficiently frequent tokens.
filter_word_lrp = [
    (token, sum(token_scores) * 1.0 / len(token_scores))
    for token, token_scores in word_lrp.items()
    if len(token_scores) >= MIN_TOKEN_COUNT
]
# Both lists are ordered from most to least relevant.
filter_word_lrp.sort(key=lambda x: x[1], reverse=True)
word_lrp_list.sort(key=lambda x: x[1], reverse=True)

In [98]:
# The 100 individual token occurrences with the highest relevance.
word_lrp_list[0:100]

[('alright', 1.019275188446045),
 ('undertaker', 0.8830116987228394),
 ('nokia', 0.8632627725601196),
 ('undertaker', 0.8587819933891296),
 ('wwe', 0.767245888710022),
 ('buffy', 0.7479455471038818),
 ('excited', 0.7397792935371399),
 ('thankful', 0.7372829914093018),
 ('mccartney', 0.7239643931388855),
 ('wwe', 0.7167426347732544),
 ('gonna', 0.6992533206939697),
 ('excited', 0.6982369422912598),
 ('nintendo', 0.6770433783531189),
 ('boyfriend', 0.6717503070831299),
 ('?', 0.6471765041351318),
 ('halloween', 0.6394405364990234),
 ('awesome', 0.6379751563072205),
 ('week', 0.6373948454856873),
 ('?', 0.6347154378890991),
 ('grateful', 0.6337091326713562),
 ('maiden', 0.6289215087890625),
 ('colbert', 0.6279644966125488),
 ('yeah', 0.6249728798866272),
 ('springsteen', 0.6225460767745972),
 ('homecoming', 0.6137117147445679),
 ('rapper', 0.6127239465713501),
 ('aston', 0.6041433811187744),
 ('nintendo', 0.6025475859642029),
 ('rebirth', 0.5997878313064575),
 ('shawn', 0.5962246656417847

In [99]:
# The 100 individual token occurrences with the lowest relevance (most negative last).
word_lrp_list[-100:len(word_lrp_list)]

[('taylor', -0.3367699682712555),
 ('amanda', -0.3370693325996399),
 ('tomorrow', -0.33790042996406555),
 (':', -0.3379462957382202),
 ('reminds', -0.3385217785835266),
 ('anything', -0.34000062942504883),
 ('went', -0.34044402837753296),
 ('shit', -0.34194329380989075),
 ('fuck', -0.3427245020866394),
 ('yankee', -0.3431705832481384),
 ('.', -0.34351444244384766),
 ('watching', -0.34353217482566833),
 ('[CLS]', -0.34413909912109375),
 ('gonna', -0.3455921411514282),
 ('thanksgiving', -0.34575310349464417),
 ('niall', -0.34684664011001587),
 ('ready', -0.3488250970840454),
 ('mentally', -0.35005658864974976),
 ('something', -0.35091954469680786),
 ('riots', -0.35354745388031006),
 ('barely', -0.35374513268470764),
 ('divorce', -0.3550001382827759),
 ('i', -0.35557347536087036),
 ('will', -0.3558569550514221),
 ('my', -0.3571268916130066),
 ('nba', -0.3614344596862793),
 ('.', -0.3616555631160736),
 ('.', -0.36282384395599365),
 ('.', -0.36341166496276855),
 ('\\', -0.3635162115097046),

In [103]:
# The 20 tokens with the highest mean relevance across the dataset.
filter_word_lrp[0:20]

[('buffy', 0.7479455471038818),
 ('thankful', 0.7372829914093018),
 ('aston', 0.6041433811187744),
 ('rebirth', 0.5997878313064575),
 ('monopoly', 0.5807051658630371),
 ('parking', 0.5708843469619751),
 ('grossing', 0.5660892724990845),
 ('shortlisted', 0.5579737424850464),
 ('rehearsals', 0.5478571057319641),
 ('springsteen', 0.5440821647644043),
 ('analysis', 0.543275773525238),
 ('scriptures', 0.5288156270980835),
 ('concerts', 0.5242246985435486),
 ('mummy', 0.5122504830360413),
 ('albums', 0.5100706815719604),
 ('daytona', 0.49747034907341003),
 ('happier', 0.49571555852890015),
 ('sipping', 0.4943842887878418),
 ('parramatta', 0.49320822954177856),
 ('buzzing', 0.49006736278533936)]

In [104]:
# The 20 tokens with the lowest mean relevance across the dataset.
filter_word_lrp[-20:len(filter_word_lrp)]

[('recognise', -0.32276779413223267),
 ('instinct', -0.3252364993095398),
 ('bjp', -0.3279317418734233),
 ('speech', -0.3292480707168579),
 ('thru', -0.32933974266052246),
 ('varsity', -0.3318825960159302),
 ('juventus', -0.3321632742881775),
 ('ties', -0.33269569277763367),
 ('reminds', -0.3385217785835266),
 ('yankee', -0.3431705832481384),
 ('pga', -0.3462705910205841),
 ('mentally', -0.35005658864974976),
 ('barely', -0.35374513268470764),
 ('shankar', -0.3691314160823822),
 ('testified', -0.3721970021724701),
 ('attributed', -0.4401932656764984),
 ('suicide', -0.4415339529514313),
 ('pregnant', -0.44782427698373795),
 ('dodgers', -0.448966383934021),
 ('contestant', -0.46337562799453735)]