# Head Selection
**_BERT_** is a **_Multi-layer_ _Multi-Head_** Transformer architecture. As discussed by many recent researchers, different attention heads capture different linguistic patterns. To delete words effectively using the attention mechanism, we need to choose a head that **captures patterns useful for classification.**

To do this we use a brute-force search over all possible heads. For each head, we delete from every sentence the top-K words that head attends to (K is a fixed fraction of the sentence's tokens) and measure the new classification score. For sentiment, removing the sentiment-related words should make a sentence neutral, so the heads are ranked by how far their deletions push the dev-set sentences towards neutral.

In [1]:
import csv
import logging
import os
import random
import sys
import numpy as np
import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                              TensorDataset)
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange

from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_pretrained_bert.modeling import BertForSequenceClassification, BertConfig, WEIGHTS_NAME, CONFIG_NAME
#from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear

from bertviz.bertviz import attention, visualization
from bertviz.bertviz.pytorch_pretrained_bert import BertModel, BertTokenizer

# Root logging: INFO level, timestamped lines that include the logger name.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
)

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.
Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [2]:
logger = logging.getLogger(__name__)

# Fine-tuned BERT sentiment classifiers, one per dataset. Only the entry selected by
# CLASSIFIER_DATASET is used (instead of re-assigning the same variable several times,
# where only the last assignment ever took effect).
CLASSIFIER_MODEL_DIRS = {
    "tdrg": "./bert_classifier/",  # TDRG paper setup
    "yelp": "./data/yelp/bert_classifier_3epochs/",
    # Lipton, 10 epochs: eval_accuracy = 0.9102040816326531  and  eval_loss = 0.35673839559838655
    "lipton_10epochs": "./data/lipton/sentiment/orig/bert_classifier_10epochs8b_490seqlen/",
    # Lipton, 100 epochs (Apr 24): eval_accuracy = 0.8979591836734694 and eval_loss = 0.9967757850885384  # Try with this one <--
    "lipton": "./data/lipton/sentiment/orig/bert_classifier_100epochs8b_490seqlen/",
    "imagecaption": "./data/imagecaption/bert_classifier_10epochs/",
    "amazon": "./data/amazon/bert_classifier_3epochs/",
}
CLASSIFIER_DATASET = "amazon"
bert_classifier_model_dir = CLASSIFIER_MODEL_DIRS[CLASSIFIER_DATASET]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()
logger.info("device: {}, n_gpu {}".format(device, n_gpu))

# Pin work to one GPU of the shared multi-GPU machine. Only do so when that GPU exists:
# torch.cuda.set_device raises on CPU-only machines or when there are fewer GPUs.
GPU_INDEX = 3
if n_gpu > GPU_INDEX:
    torch.cuda.set_device(GPU_INDEX)

05/16/2020 23:29:28 - INFO - __main__ -   device: cuda, n_gpu 4


In [3]:
## Model for performing Classification
# Fine-tuned 2-class sentiment classifier from bert_classifier_model_dir; run_multiple_examples()
# uses it (as the global `model_cls`) to score sentences.
model_cls = BertForSequenceClassification.from_pretrained(bert_classifier_model_dir, num_labels=2)
# Uncased WordPiece tokenizer used by the helper functions below (the next cell loads it again).
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
model_cls.to(device)
# Inference mode (dropout off); as the cell's last expression it also displays the architecture.
model_cls.eval()

05/16/2020 23:29:31 - INFO - pytorch_pretrained_bert.modeling -   loading archive file ./data/amazon/bert_classifier_3epochs/
05/16/2020 23:29:31 - INFO - pytorch_pretrained_bert.modeling -   Model config {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "type_vocab_size": 2,
  "vocab_size": 30522
}

05/16/2020 23:29:34 - INFO - bertviz.bertviz.pytorch_pretrained_bert.tokenization -   loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /home/diego/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): BertLayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): BertLayerNorm()
              (dropout): Dropout(p=0.1, inplace=False)
            )
    

In [4]:
## Model to get the attention weights of all the heads
# Same checkpoint loaded as the bertviz BertModel; its forward returns a third value holding the
# per-layer attention data that get_attention_for_batch() collects.
model = BertModel.from_pretrained(bert_classifier_model_dir)
# NOTE(review): identical to the tokenizer loaded in the previous cell; reloading is redundant.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
model.to(device)
model.eval()

05/16/2020 23:29:44 - INFO - bertviz.bertviz.pytorch_pretrained_bert.modeling -   loading archive file ./data/amazon/bert_classifier_3epochs/
05/16/2020 23:29:44 - INFO - bertviz.bertviz.pytorch_pretrained_bert.modeling -   Model config {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "type_vocab_size": 2,
  "vocab_size": 30522
}

05/16/2020 23:29:45 - INFO - bertviz.bertviz.pytorch_pretrained_bert.tokenization -   loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /home/diego/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084


BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): BertLayerNorm()
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): BertLayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features

In [5]:
# Every sentence is encoded as [CLS] + WordPiece tokens + [SEP] and zero-padded to
# max_seq_len ids; the model's position embeddings cap this at 512.
max_seq_len=70 # Maximum sequence length  for TDRG / IMAGE CAPTION / AMAZON

#max_seq_len=490   # for Lipton
sm = torch.nn.Softmax(dim=-1) ## Softmax over the last dim (each example's class logits), not over the batch

In [6]:
def run_multiple_examples(input_sentences, bs=32):
    """
    Return classifier class probabilities for a list of sentences.

    Each sentence is tokenized with the global `tokenizer`, wrapped in
    [CLS] ... [SEP] (text truncated so the result fits `max_seq_len`) and
    zero-padded to exactly `max_seq_len` ids. Batches of `bs` sentences are
    scored by the global `model_cls` on `device`, and the softmax `sm` turns
    the logits into probabilities.

    input_sentences: list of strings
    bs : batch_size : int
    returns: one list of class probabilities per sentence, in input order
             (an empty list when input_sentences is empty)
    """
    ids = []
    segment_ids = []
    input_masks = []
    for sen in input_sentences:
        # Reserve two slots for [CLS]/[SEP]. Without truncation an over-long sentence gets
        # negative padding, the rows become ragged and torch.tensor() fails.
        text_tokens = tokenizer.tokenize(sen)[:max_seq_len - 2]
        tokens = ["[CLS]"] + text_tokens + ["[SEP]"]
        token_ids = tokenizer.convert_tokens_to_ids(tokens)
        padding = [0] * (max_seq_len - len(token_ids))
        ids.append(token_ids + padding)
        input_masks.append([1] * len(token_ids) + padding)
        segment_ids.append([0] * len(token_ids) + padding)

    if not ids:
        return []

    ## Convert input lists to Torch Tensors
    ids = torch.tensor(ids)
    segment_ids = torch.tensor(segment_ids)
    input_masks = torch.tensor(input_masks)

    pred_lt = []
    # Stepping by bs never produces an empty trailing batch, unlike a steps+1 loop
    # whenever len(ids) is an exact multiple of bs.
    for start in range(0, len(ids), bs):
        batch = slice(start, start + bs)
        with torch.no_grad():
            logits = model_cls(ids[batch].to(device),
                               segment_ids[batch].to(device),
                               input_masks[batch].to(device))
            preds = sm(logits)
        pred_lt.extend(preds.tolist())

    return pred_lt

In [7]:
def read_file(path,size):
    """
    Return the first `size` lines of the text file at `path`.

    Lines are split with str.splitlines(), so they carry no line terminators.
    """
    with open(path) as handle:
        lines = handle.read().splitlines()
    return lines[:size]

In [8]:
def get_attention_for_batch(input_sentences, bs=32):
    """
    Run the attention-returning BertModel over sentences and collect the
    attention weights of every layer and head.

    Each sentence is tokenized with the global `tokenizer`, wrapped in
    [CLS] ... [SEP] (text truncated so the result fits `max_seq_len`) and
    zero-padded to `max_seq_len`. Batches of `bs` sentences go through the
    global `model` on `device`.

    input sentence: list of strings
    bs : batch_size

    returns (att_lt, ids_for_decoding):
        att_lt: the per-layer attention list returned by `model` for the first
            batch, with each layer's 'attn_probs' replaced by the CPU tensor of
            all batches concatenated along dim 0, i.e. indexed
            att_lt[layer]['attn_probs'][sentence][head][...]. Any other entries
            of those per-layer dicts are left as returned for the first batch.
        ids_for_decoding: per sentence, the unpadded token ids
            ([CLS] + tokens + [SEP]), in input order.

    raises ValueError if input_sentences is empty.
    """
    ids = []
    segment_ids = []
    input_masks = []
    ids_for_decoding = []
    for sen in input_sentences:
        # Reserve two slots for [CLS]/[SEP]. Without truncation an over-long sentence gets
        # negative padding, the rows become ragged and torch.tensor() fails.
        text_tokens = tokenizer.tokenize(sen)[:max_seq_len - 2]
        tokens = ["[CLS]"] + text_tokens + ["[SEP]"]
        token_ids = tokenizer.convert_tokens_to_ids(tokens)
        ids_for_decoding.append(list(token_ids))
        padding = [0] * (max_seq_len - len(token_ids))
        ids.append(token_ids + padding)
        input_masks.append([1] * len(token_ids) + padding)
        segment_ids.append([0] * len(token_ids) + padding)

    if not ids:
        raise ValueError("input_sentences must contain at least one sentence")

    ## Convert the list of int ids to Torch Tensors
    ids = torch.tensor(ids)
    segment_ids = torch.tensor(segment_ids)
    input_masks = torch.tensor(input_masks)

    att_lt = None
    probs_per_layer = None
    # Stepping by bs never produces an empty trailing batch, unlike a steps+1 loop
    # whenever len(ids) is an exact multiple of bs.
    for start in trange(0, len(ids), bs):
        batch = slice(start, start + bs)
        with torch.no_grad():
            _, _, attn = model(ids[batch].to(device),
                               segment_ids[batch].to(device),
                               input_masks[batch].to(device))

        if att_lt is None:
            att_lt = attn
            probs_per_layer = [[] for _ in range(len(attn))]

        # Move each layer's weights to CPU now and concatenate once at the end,
        # instead of re-copying an ever-growing tensor with torch.cat per batch.
        for layer in range(len(attn)):
            probs_per_layer[layer].append(attn[layer]['attn_probs'].to('cpu'))

    for layer, chunks in enumerate(probs_per_layer):
        att_lt[layer]['attn_probs'] = torch.cat(chunks, 0)

    return att_lt, ids_for_decoding

In [9]:
def _drop_attended_tokens(tokens, token_ids, drop_positions):
    """Return `token_ids` without the positions in `drop_positions`.

    Dropping a position also drops the WordPiece continuation pieces ("##...")
    directly after it, so a removed word start takes the rest of its word
    along. `tokens` must be aligned with `token_ids` for len(token_ids) positions.
    """
    kept = []
    pos = 0
    n_ids = len(token_ids)
    while pos < n_ids:
        if pos in drop_positions:
            # Skip the trailing "##" pieces of the word being removed.
            while pos + 1 < n_ids and tokens[pos + 1].startswith("##"):
                pos += 1
        else:
            kept.append(token_ids[pos])
        pos += 1
    return kept


def process_sentences(input_sentences, att, decoding_ids, threshold=0.25):
    """
    For every attention head, delete from each sentence the tokens that head
    weights most in row 0 of its attention (the [CLS] position), and return
    the shortened sentences.

    For sentence k and head (layer i, head j), the len(decoding_ids[k])
    highest-weighted positions of att[i]['attn_probs'][k][j][0] are taken and
    the first int(len(decoding_ids[k]) * threshold) of them are removed (a
    removed position also removes the "##" pieces that follow it). The
    remaining ids are turned back into text with [CLS]/[SEP] stripped.

    input_sentences: list of strings
    att: per-layer attention list as returned by get_attention_for_batch()
    decoding_ids: unpadded [CLS] ... [SEP] token ids per sentence, so
        len(input_sentences) == len(decoding_ids)
    threshold: fraction (0-1) of each sentence's tokens to remove; the
        [CLS]/[SEP] positions count towards it and may be selected

    returns: one entry per head in layer-major order
        (index = layer * heads_per_layer + head), each a list of processed
        sentences in input order
    """
    # Tokenize each sentence once rather than once per (layer, head) pair; the
    # tokens are only needed to recognise "##" continuation pieces.
    sentence_tokens = [["[CLS]"] + tokenizer.tokenize(sen) + ["[SEP]"] for sen in input_sentences]

    lt = []
    for i in trange(len(att)):  # For all the layers
        layer_probs = att[i]['attn_probs']
        for j in range(len(layer_probs[0])):  # For all the heads in the ith layer
            processed_sen = []
            for k in range(len(input_sentences)):
                n_ids = len(decoding_ids[k])
                _, topi = layer_probs[k][j][0].topk(n_ids)  # Top attended positions
                drop = set(topi.tolist()[:int(n_ids * threshold)])
                kept_ids = _drop_attended_tokens(sentence_tokens[k], decoding_ids[k], drop)
                tmp = tokenizer.convert_ids_to_tokens(kept_ids)
                # Re-join word pieces and strip the special tokens.
                processed_sen.append(
                    " ".join(tmp).replace(" ##", "").replace("[CLS]", "").replace("[SEP]", "").strip()
                )
            lt.append(processed_sen)  # Sentences for this head

    return lt

In [10]:
def get_block_head(processed_sentence_list, lmbd = 0.1, num_heads=12):
    """
    Rank attention heads by how neutral the classifier finds the sentences
    produced by deleting each head's most-attended tokens.

    For each sentence with class probabilities `pred`:
        score = (min(pred) + lmbd) / (max(pred) + lmbd)
    where lmbd is a smoothing term. A perfectly neutral binary prediction
    pred = [0.5, 0.5] gives score = 1; a fully confident one gives
    lmbd / (1 + lmbd). A head's score is the mean over its sentences.

    processed_sentence_list: indexed by flat head id (layer * num_heads + head),
        each entry a list of sentences, as returned by process_sentences()
    lmbd: smoothing constant added to numerator and denominator
    num_heads: heads per layer, used to map the flat id back to (layer, head)

    returns: list of (layer, head, score) tuples sorted by score, highest first
    """
    head_scores = []
    for flat_id in trange(len(processed_sentence_list)):  # sentences produced by each head
        pred = np.array(run_multiple_examples(processed_sentence_list[flat_id]))
        if pred.size == 0:
            # No sentences for this head: same NaN np.mean([]) would give, without the warning.
            head_scores.append(np.nan)
            continue
        # Row-wise min/max equals min(x[0], x[1]) / max(x[0], x[1]) for two classes
        # and stays meaningful for more classes.
        head_scores.append(np.mean((pred.min(axis=1) + lmbd) / (pred.max(axis=1) + lmbd)))

    # sorted() is stable (also with reverse=True), so ties keep flat-id order.
    ranking = sorted(range(len(head_scores)), key=lambda flat_id: head_scores[flat_id], reverse=True)
    score_lt = []
    for flat_id in ranking:
        layer, head = divmod(flat_id, num_heads)
        score_lt.append((layer, head, head_scores[flat_id]))
    return score_lt

In [11]:
# Dev-set sentence files (one sentence per line) as (positive, negative) pairs per dataset.
# Select one with EXAMPLES_DATASET instead of re-assigning the paths several times.
#TDRG (original machine)
#pos_examples_file = "/home/ubuntu/bhargav/data/yelp/sentiment_dev_1.txt"
#neg_examples_file = "/home/ubuntu/bhargav/data/yelp/sentiment_dev_0.txt"
SENTIMENT_DEV_FILES = {
    "yelp": ("./data/yelp/sentiment_dev_1.txt", "./data/yelp/sentiment_dev_0.txt"),
    "imagecaption": ("./data/imagecaption/sentiment_dev_1.txt", "./data/imagecaption/sentiment_dev_0.txt"),
    "amazon": ("./data/amazon/sentiment_dev_1.txt", "./data/amazon/sentiment_dev_0.txt"),
}
EXAMPLES_DATASET = "amazon"
pos_examples_file, neg_examples_file = SENTIMENT_DEV_FILES[EXAMPLES_DATASET]

# 100 examples from each class worked well. The bottleneck is run_multiple_examples(): with more
# memory (CPU or GPU) the processing time drops by increasing its batch_size.
# With batch_size 32 it takes around 24 minutes for 100 examples on CPU.
N_EXAMPLES_PER_CLASS = 100
pos_data = read_file(pos_examples_file, N_EXAMPLES_PER_CLASS)
neg_data = read_file(neg_examples_file, N_EXAMPLES_PER_CLASS)
# Positive examples first, then negative ones.
data = pos_data + neg_data

In [11]:
#for LIPTON
#get 100 pos and 100 neg examples at random from dev file
import pandas as pd
datadir = "data/lipton/sentiment/orig/"
devfile = os.path.join(datadir,"dev.tsv")
dev = pd.read_table(devfile,sep="\t")
pos_data, neg_data = [], []
inds = [i for i in range(dev.shape[0])]
random.shuffle(inds)
for i in inds:
    line = dev.Text.values[i]
    cur_label = 0 if dev.Sentiment.values[i] == "Negative" else 1
    if cur_label == 0:
        if len(neg_data) < 100:
            neg_data.append(line)
    else:
        if len(pos_data) < 100:
            pos_data.append(line)
            
    if len(neg_data) == 100 and len(pos_data) == 100:
        break
data = pos_data + neg_data

In [12]:
# Sanity check: number of positive, negative and combined dev examples loaded.
print(len(pos_data), len(neg_data), len(data))

100 100 200


In [13]:
# get_attention_for_batch() defaults to bs=32 as in the reference setup; our classifiers were
# trained with half that batch size, so 16 is passed explicitly.
#att, decoding_ids = get_attention_for_batch(data)       #tdrg yelp

#att, decoding_ids = get_attention_for_batch(data)       #imagecaption 

att, decoding_ids = get_attention_for_batch(data, 16)  #lipton / amazon


100%|██████████| 13/13 [00:01<00:00, 10.47it/s]


In [15]:
# they use a threshold of 0.25, i.e. 25% of each sentence's tokens (a fraction, not 0.25%)
# threshold: fraction of the top-attended tokens to remove   <--- THIS MIGHT BE TOO HIGH

# FOR YELP/ IMAGE CAPTION / AMAZON

sen_list = process_sentences(data, att, decoding_ids,threshold=0.25)  #yelp, image caption and amazon
scores = get_block_head(sen_list)

100%|██████████| 12/12 [00:06<00:00,  1.89it/s]
100%|██████████| 144/144 [01:10<00:00,  2.03it/s]


In [28]:
# they use a threshold of 0.25, i.e. 25% of each sentence's tokens (a fraction, not 0.25%)
# threshold: fraction of the top-attended tokens to remove   <--- THIS MIGHT BE TOO HIGH

#FOR LIPTON

#sen_list = process_sentences(data, att, decoding_ids,threshold=0.25)  #yelp

# NOTE(review): the 0.10 run is commented out, so this cell does not define sen_list_10.
#sen_list_10 = process_sentences(data, att, decoding_ids,threshold=0.10)
sen_list_15 = process_sentences(data, att, decoding_ids,threshold=0.15)
sen_list_20 = process_sentences(data, att, decoding_ids,threshold=0.20)
sen_list_30 = process_sentences(data, att, decoding_ids,threshold=0.30)

#try with different thresholds !  .1, .15, .2, .25, .3

100%|██████████| 12/12 [01:16<00:00,  6.38s/it]
100%|██████████| 12/12 [01:17<00:00,  6.42s/it]
100%|██████████| 12/12 [01:18<00:00,  6.54s/it]


In [39]:
# Compare how one head rewrites one dev sentence at each deletion threshold.
# Every sen_list* is indexed [layer * 12 + head][sentence].
# Thresholds whose sentences were not computed in this session (sen_list_10's cell is
# commented out) are reported and skipped; referencing them directly raised NameError.
print(len(sen_list),len(sen_list[0]))     # (12 layers x 12 heads) x 200 sentences
sent_idx = 100  # dev sentence to inspect (data = pos_data + neg_data, 100 each)
head_idx = 100  # flat head index -> layer 8, head 4 (the same index the old code used for both)
print("Original")
print(len(data[sent_idx]), data[sent_idx])  # character length, original text
for header, var_name in [("\nThreshold 10", "sen_list_10"),
                         ("\n Threshold 15", "sen_list_15"),
                         ("\n Threshold 20", "sen_list_20"),
                         ("\n Threshold 25", "sen_list"),
                         ("\n Threshold 30", "sen_list_30")]:
    print(header)
    processed = globals().get(var_name)
    if processed is None:
        print("{} not computed - skipped".format(var_name))
        continue
    rewritten = processed[head_idx][sent_idx]
    print(len(rewritten), rewritten)  # character length, sentence after deletion

#print(len(sen_list[143][0]),sen_list[143][0]) #processed last layer and last head of first sentence

144 200
Original
1069 First of all, the reason I'm giving this film 2 stars instead of 1 is because at least Peter Falk gave his usual fantastic performance as Lieutenant Columbo. He alone can get 10 stars for trying to save this otherwise utterly worthless attempt at making a movie.<br /><br />I was initially all fired up at reading one poster's comment that Andrew Stevens in this movie gave "the performance of his career." To me, it was the abysmal performance by Stevens that absolutely ruined this movie, and so I was all prepared to hurl all sorts of insults at the person who made the aforementioned comment. Then I thought to myself, what else has Stevens done? So I checked and, you know, that person was absolutely right. In the 17 years since this Columbo movie was made, apparently every one of the 33 projects that Stevens has been in since then has been utter crap, so it is doubtful that anybody has even seen the rest of his career.<br /><br />If you like Columbo, see every other 

In [40]:
# Rank all 144 heads at the extra thresholds; each call took about 11 minutes in the recorded run.
#scores = get_block_head(sen_list)
#scores_10 = get_block_head(sen_list_10)
scores_15 = get_block_head(sen_list_15)
scores_20 = get_block_head(sen_list_20)
scores_30 = get_block_head(sen_list_30)

100%|██████████| 144/144 [10:52<00:00,  4.53s/it]
100%|██████████| 144/144 [10:56<00:00,  4.56s/it]
100%|██████████| 144/144 [10:47<00:00,  4.50s/it]


In [27]:
print(scores_10[0:10])  #10 thresh
#  [(8, 11, 0.17688744312860485), (8, 3, 0.16646452595741665), (8, 4, 0.16536730077420828), (8, 7, 0.16113962145781294), (8, 0, 0.15593628422752848), (8, 2, 0.15470411915568807), (9, 0, 0.15140050671408953), (7, 5, 0.15072199877638995), (8, 8, 0.14986440507830787), (8, 9, 0.14957368786096775)]

print(scores_15[0:10])  #15 thresh

print(scores_20[0:10])  #20 thresh

print(scores[0:10])  #25 thresh
#  [(8, 8, 0.20559701026283966), (8, 11, 0.19127006203151573), (8, 9, 0.18814623826875007), (11, 8, 0.18692809688409256), (7, 5, 0.18353356225877854), (9, 9, 0.18059253591562302), (7, 9, 0.17818772485683795), (6, 5, 0.1774358329795011), (9, 4, 0.17600542364521263), (9, 6, 0.17376864976801762)]

print(scores_30[0:10])  #30 thresh


[(8, 8, 0.20559701026283966), (8, 11, 0.19127006203151573), (8, 9, 0.18814623826875007), (11, 8, 0.18692809688409256), (7, 5, 0.18353356225877854), (9, 9, 0.18059253591562302), (7, 9, 0.17818772485683795), (6, 5, 0.1774358329795011), (9, 4, 0.17600542364521263), (9, 6, 0.17376864976801762)]
[(8, 11, 0.17688744312860485), (8, 3, 0.16646452595741665), (8, 4, 0.16536730077420828), (8, 7, 0.16113962145781294), (8, 0, 0.15593628422752848), (8, 2, 0.15470411915568807), (9, 0, 0.15140050671408953), (7, 5, 0.15072199877638995), (8, 8, 0.14986440507830787), (8, 9, 0.14957368786096775)]


In [17]:
#for Yelp TDRG   <-- done to show its weakness   # I got layer 8, head 1 (but the attribution is still pretty uniform)

#scores   
#[(8, 1, 0.23903657674356474), (3, 2, 0.22334244144017074), (11, 5, 0.22290891025170304), (6, 0, 0.21972539246421388), (4, 2, 0.21848756769808247), (10, 11, 0.21640071453619267), (5, 4, 0.20899156231361693), (4, 1, 0.2067413901469147), (9, 7, 0.20085241894907047), (4, 9, 0.1969363758990871), (6, 5, 0.19436958617237984), (10, 0, 0.1941946590678377), (5, 7, 0.19313810872086903), (8, 4, 0.19215605317168943), (8, 8, 0.18770924721348628), (6, 8, 0.18629729823007607), (8, 9, 0.18483354081234135), (9, 8, 0.18213737216825984), (10, 10, 0.18066124945832143), (6, 2, 0.17955855985333163), (9, 5, 0.1782620783663914), (7, 5, 0.17782650826477636), (8, 5, 0.1750809159697408), (8, 6, 0.17116546321875653), (8, 0, 0.17101999169715307), (6, 7, 0.17075152030632992), (9, 10, 0.16983285936185574), (9, 3, 0.16933621140176697), (6, 4, 0.1656612731118689), (8, 11, 0.16559293441524148), (4, 11, 0.16244136493667177), (7, 0, 0.15778741111059338), (11, 2, 0.15771811491892582), (11, 9, 0.1541693440259734), (10, 4, 0.15182641855438994), (8, 10, 0.1494333893267871), (6, 6, 0.1486183062584987), (4, 6, 0.1485121350529015), (6, 11, 0.14616055669578518), (4, 4, 0.14595315745616766), (5, 3, 0.1452600371389312), (8, 2, 0.14207932213296157), (11, 0, 0.14124012884155004), (10, 6, 0.14113836559037662), (11, 11, 0.14067368798987784), (4, 7, 0.14035772913007338), (8, 3, 0.1401644358154404), (11, 7, 0.14006396820906786), (11, 1, 0.13950213094019337), (11, 3, 0.13851649567255905), (6, 1, 0.1370440473100292), (10, 2, 0.1366707844055332), (6, 10, 0.13609105266020555), (5, 2, 0.1358202019107529), (11, 8, 0.1349294833107715), (10, 1, 0.13485114328135256), (10, 7, 0.13469741679122532), (3, 6, 0.13424493865590464), (10, 3, 0.1340317730167978), (4, 3, 0.13366829219672652), (7, 9, 0.1327338832168681), (5, 8, 0.1320738977571235), (7, 4, 0.1318912914343732), (5, 5, 0.13175194669565946), (10, 8, 0.13168917493753318), (11, 4, 0.13045016237876791), (7, 2, 0.12974020398731256), (9, 0, 0.12820001400939154), (5, 10, 
0.12779472037339562), (9, 2, 0.12630340980264781), (9, 4, 0.12614358037223602), (7, 10, 0.126004764279385), (7, 7, 0.1256395774033541), (5, 11, 0.1251331822420115), (9, 1, 0.1247822916093963), (2, 7, 0.12417690977862067), (7, 8, 0.12323587304796586), (10, 5, 0.12297736332054889), (9, 9, 0.12256917930857014), (5, 0, 0.12246831494896987), (3, 4, 0.12190025795034218), (7, 1, 0.12138096268780839), (0, 4, 0.1210144647472587), (3, 5, 0.12096072160621732), (2, 3, 0.12092496623638248), (11, 10, 0.12090430123780756), (1, 1, 0.12058078959421911), (0, 0, 0.11987742761091895), (9, 11, 0.11983251557584347), (7, 6, 0.11910646620921422), (4, 0, 0.11902995272997739), (3, 10, 0.11876368716113682), (2, 2, 0.1185443317630272), (4, 8, 0.11833797790837855), (6, 9, 0.11749092266515027), (3, 8, 0.11709475629864158), (1, 9, 0.11686831682050483), (2, 5, 0.11683294793499945), (8, 7, 0.11662213699303678), (7, 11, 0.1164310975125378), (0, 9, 0.11583271543459653), (7, 3, 0.11561447322148423), (5, 1, 0.11541527527890719), (6, 3, 0.11539918556801507), (3, 7, 0.11492908912959962), (2, 9, 0.11457794292701412), (4, 5, 0.11406652893591902), (0, 3, 0.11347259059401094), (1, 6, 0.11326190369935557), (10, 9, 0.1128458962146832), (4, 10, 0.1128009999274843), (1, 2, 0.11249556041494384), (3, 11, 0.11240006745644657), (3, 9, 0.11187086557314938), (0, 8, 0.11185795788603305), (2, 1, 0.11184227421854402), (5, 9, 0.1117949252098748), (0, 10, 0.11125839925549841), (2, 6, 0.11120353018078827), (0, 1, 0.11110620261502906), (1, 10, 0.11078926495242485), (2, 0, 0.11056452072223104), (1, 11, 0.11055677233685371), (0, 11, 0.11050952970741806), (3, 3, 0.11047849879055233), (0, 5, 0.11021734344122748), (3, 1, 0.11018946886356826), (2, 10, 0.10962099298939502), (0, 7, 0.10878052612959854), (3, 0, 0.10862221266936933), (1, 0, 0.10836133631894736), (2, 11, 0.10828972428341828), (1, 8, 0.10799836644179461), (0, 2, 0.10793480473266001), (9, 6, 0.10747256422945496), (1, 7, 0.10742426305743508), (2, 4, 0.10717004468036893), 
(1, 4, 0.10713829798493762), (5, 6, 0.10712936552946001), (1, 3, 0.10662285031537275), (2, 8, 0.10644658045626329), (1, 5, 0.10581179695692038), (11, 6, 0.10552712113280183), (0, 6, 0.10443320219533686)]
print("")  # the Yelp/TDRG ranking above is kept only as a comment from an earlier run; nothing is recomputed here

[(8, 1, 0.23903657674356474),
 (3, 2, 0.22334244144017074),
 (11, 5, 0.22290891025170304),
 (6, 0, 0.21972539246421388),
 (4, 2, 0.21848756769808247),
 (10, 11, 0.21640071453619267),
 (5, 4, 0.20899156231361693),
 (4, 1, 0.2067413901469147),
 (9, 7, 0.20085241894907047),
 (4, 9, 0.1969363758990871),
 (6, 5, 0.19436958617237984),
 (10, 0, 0.1941946590678377),
 (5, 7, 0.19313810872086903),
 (8, 4, 0.19215605317168943),
 (8, 8, 0.18770924721348628),
 (6, 8, 0.18629729823007607),
 (8, 9, 0.18483354081234135),
 (9, 8, 0.18213737216825984),
 (10, 10, 0.18066124945832143),
 (6, 2, 0.17955855985333163),
 (9, 5, 0.1782620783663914),
 (7, 5, 0.17782650826477636),
 (8, 5, 0.1750809159697408),
 (8, 6, 0.17116546321875653),
 (8, 0, 0.17101999169715307),
 (6, 7, 0.17075152030632992),
 (9, 10, 0.16983285936185574),
 (9, 3, 0.16933621140176697),
 (6, 4, 0.1656612731118689),
 (8, 11, 0.16559293441524148),
 (4, 11, 0.16244136493667177),
 (7, 0, 0.15778741111059338),
 (11, 2, 0.15771811491892582),
 (11, 

In [21]:
# FOR IMAGE CAPTION
# Full head ranking, one (layer, head, score) tuple per line.
# NOTE(review): prints whatever `scores` currently holds (the latest get_block_head run),
# so a saved output may belong to a different dataset run.
for ranked_head in scores:
    print(ranked_head)

(8, 4, 0.15872124043746288)
(8, 7, 0.15313005978987224)
(8, 11, 0.151654679255837)
(10, 4, 0.15127377859204544)
(10, 9, 0.1511029737080658)
(10, 11, 0.15107000490937236)
(6, 2, 0.14955530722068505)
(4, 1, 0.14928748884398232)
(7, 0, 0.14906258472175327)
(7, 11, 0.1475579435228811)
(9, 7, 0.14500589475632758)
(2, 2, 0.14450277782390422)
(8, 8, 0.14346186616448636)
(10, 1, 0.1426556151733003)
(6, 1, 0.14251394857728616)
(6, 0, 0.14246976886282084)
(4, 2, 0.1424644795317995)
(7, 10, 0.14204145834056617)
(9, 5, 0.14183040851051623)
(6, 7, 0.14145037448127487)
(9, 10, 0.14062761350387756)
(9, 2, 0.13944930105164943)
(8, 10, 0.1394157092380312)
(4, 0, 0.1393841124805446)
(9, 0, 0.1393034407142085)
(10, 5, 0.13876747507266696)
(6, 10, 0.13792932115423973)
(8, 6, 0.1378217878827729)
(6, 11, 0.13779677020917627)
(5, 1, 0.13614893997016472)
(5, 8, 0.13605227816702284)
(9, 8, 0.13602680569528341)
(6, 4, 0.13581095497155543)
(8, 3, 0.13579601390585622)
(8, 0, 0.13576425923300012)
(8, 9, 0.13516717

In [16]:
# FOR AMAZON  
# Full head ranking, one (layer, head, score) tuple per line.
# NOTE(review): same `scores` variable as the image-caption cell; whichever dataset ran last is shown.
for ranked_head in scores:
    print(ranked_head)

(9, 5, 0.2821267870197217)
(9, 11, 0.2804659528610392)
(4, 1, 0.2702826391652081)
(4, 2, 0.2686453297203014)
(7, 5, 0.2558982372686172)
(10, 1, 0.2548711599639146)
(3, 6, 0.2547133216759923)
(6, 4, 0.25457741595840155)
(10, 6, 0.2540983151499425)
(8, 10, 0.2530348236594129)
(9, 0, 0.25153850235432024)
(9, 7, 0.25086725930265563)
(9, 8, 0.2497266541711812)
(11, 11, 0.24868830116149063)
(5, 2, 0.24625265297750812)
(11, 0, 0.24587316106601706)
(8, 1, 0.2458382475522489)
(6, 7, 0.2434262283835387)
(6, 0, 0.24253808384613174)
(8, 2, 0.24121870636911416)
(11, 7, 0.24070947092925374)
(3, 4, 0.2403634657335317)
(11, 5, 0.2393964068587743)
(11, 3, 0.23590736514637933)
(4, 6, 0.23542582390969882)
(7, 4, 0.23472813429487474)
(10, 8, 0.2344309689666082)
(11, 1, 0.23331967464094852)
(2, 5, 0.2328445136874893)
(3, 2, 0.23235173020481859)
(10, 7, 0.23229496479943998)
(10, 2, 0.2318160123320329)
(8, 4, 0.23155748020731803)
(7, 6, 0.23106784687559856)
(10, 11, 0.22818983501086143)
(8, 11, 0.22798889215

In [18]:
# for LIPTON  .. this took around 11 minutes using ./data/lipton/sentiment/orig/bert_classifier_10epochs8b_490seqlen/
#scores    # according to this, layer 9 head 5 scores highest, but the ranking is pretty uniform .. not super convincing (why not an ensemble then?)

#[(9, 5, 0.22910906945266027), (9, 4, 0.20670482156851616), (11, 5, 0.20384664328140217), (9, 9, 0.2027791205407322), (8, 4, 0.19366057264682351), (8, 3, 0.1929708452634522), (10, 11, 0.1924844757337584), (8, 0, 0.1912139770748914), (10, 8, 0.19096766338766102), (8, 1, 0.19064305221106376), (11, 7, 0.18917525337332117), (11, 0, 0.18796913047231342), (10, 2, 0.1857773846771663), (8, 8, 0.18555147862865973), (10, 0, 0.18530059177637234), (11, 2, 0.18495542770861484), (9, 2, 0.18475037539261013), (9, 0, 0.1841438644026269), (11, 4, 0.18140669805587734), (10, 9, 0.18016958380296674), (9, 1, 0.1799556444681705), (8, 11, 0.1797203777121738), (10, 5, 0.17958857143178505), (8, 7, 0.17886423133516335), (11, 3, 0.17881788652643082), (10, 4, 0.17843035955545758), (10, 6, 0.17813508280376536), (11, 6, 0.17802176577302084), (9, 10, 0.17787153222013147), (8, 5, 0.17732895140520186), (9, 3, 0.17613587979399356), (11, 1, 0.17594829998328604), (10, 10, 0.17578660223763148), (8, 9, 0.17564796681179257), (9, 7, 0.17563120922211783), (8, 2, 0.1753101101461102), (8, 10, 0.1747312199274393), (7, 4, 0.16945213119226252), (5, 7, 0.16936174418069228), (6, 10, 0.16795158412774336), (7, 6, 0.16649419325832185), (5, 4, 0.1651518529581891), (7, 9, 0.16469694778258695), (7, 3, 0.16452084200970687), (9, 6, 0.16354265678158442), (10, 3, 0.16331819279395884), (8, 6, 0.16331625466434685), (9, 8, 0.16327127028055513), (10, 7, 0.163098795384727), (7, 2, 0.1626278604999319), (11, 10, 0.16233153977410752), (4, 6, 0.16214737843926777), (11, 11, 0.16211695020543535), (7, 0, 0.16089281622059912), (7, 7, 0.16081335698132243), (7, 8, 0.1589754294115524), (9, 11, 0.1589265489047186), (11, 8, 0.15852928487757295), (7, 10, 0.15707820738638756), (5, 1, 0.15687621447472652), (10, 1, 0.15515437741163898), (2, 3, 0.1551460585797901), (6, 0, 0.1548807110553279), (7, 1, 0.15395799252569156), (7, 5, 0.15229415791014433), (4, 2, 0.15186431742637152), (5, 0, 0.15157253037355953), (6, 5, 0.15076005645689922), (6, 7, 
0.15074430937346742), (5, 10, 0.15010913817566668), (5, 11, 0.15007268340070756), (3, 8, 0.15007106649191967), (5, 8, 0.14975564791320012), (4, 4, 0.1490504683894608), (5, 5, 0.149014187066919), (2, 5, 0.14814437405718062), (5, 3, 0.1477559440078062), (4, 3, 0.14737676824214005), (6, 2, 0.14630223368638817), (6, 4, 0.14562702254610183), (4, 0, 0.14541923635179313), (2, 11, 0.14486951594498368), (11, 9, 0.14390854882222753), (2, 2, 0.14186954191469386), (6, 8, 0.1415233051499205), (3, 6, 0.1412151891662137), (6, 11, 0.13976853577873255), (6, 9, 0.13964447988182316), (1, 2, 0.13864583210237966), (5, 2, 0.1377782747208492), (2, 10, 0.13716651926941828), (6, 1, 0.13680782823896354), (7, 11, 0.13541384084776464), (5, 6, 0.1332300926002088), (4, 11, 0.13282151345939952), (6, 6, 0.13196132172407055), (4, 8, 0.1315829140990315), (3, 10, 0.13112511141269376), (1, 7, 0.13011743578817267), (6, 3, 0.12996468854288415), (5, 9, 0.12973969064601398), (2, 7, 0.12912978677195777), (0, 4, 0.12901492519384994), (0, 8, 0.12812594403523836), (3, 2, 0.12743444680710783), (1, 9, 0.12738812246961215), (1, 6, 0.12680959161172786), (4, 1, 0.125215670769727), (4, 5, 0.12498888481834258), (2, 6, 0.12449183497329691), (0, 9, 0.12325136333246975), (0, 6, 0.12264510875982099), (1, 11, 0.1223869126832301), (4, 9, 0.12207874539677677), (4, 10, 0.12204815966692621), (4, 7, 0.12167764101560813), (0, 5, 0.1216509180446421), (2, 1, 0.12154367079065047), (0, 10, 0.12070691094214922), (3, 1, 0.12063199760152898), (1, 8, 0.11968806324305246), (3, 7, 0.1195478440793774), (3, 4, 0.11938728747427772), (1, 1, 0.11895614866267316), (0, 7, 0.11866641294412643), (2, 8, 0.11832455905426274), (1, 5, 0.11776213346842508), (0, 3, 0.11548711856593791), (1, 3, 0.11514066073132391), (0, 11, 0.11501069082440221), (3, 3, 0.11458264421366839), (2, 4, 0.11449397341419225), (0, 0, 0.11299689002266738), (3, 9, 0.11285334549500341), (2, 0, 0.11282576791440299), (3, 5, 0.1126289646281459), (2, 9, 0.11251620759882218), (1, 4, 
0.1122227909962318), (1, 10, 0.11147971908048575), (0, 2, 0.11145458944377395), (1, 0, 0.11133206961487419), (0, 1, 0.10845245556169843), (3, 11, 0.10791545942834053), (3, 0, 0.10664399820341712)]
print("")  # the Lipton ranking above is kept only as a comment from an earlier run; nothing is recomputed here

[(9, 5, 0.22910906945266027),
 (9, 4, 0.20670482156851616),
 (11, 5, 0.20384664328140217),
 (9, 9, 0.2027791205407322),
 (8, 4, 0.19366057264682351),
 (8, 3, 0.1929708452634522),
 (10, 11, 0.1924844757337584),
 (8, 0, 0.1912139770748914),
 (10, 8, 0.19096766338766102),
 (8, 1, 0.19064305221106376),
 (11, 7, 0.18917525337332117),
 (11, 0, 0.18796913047231342),
 (10, 2, 0.1857773846771663),
 (8, 8, 0.18555147862865973),
 (10, 0, 0.18530059177637234),
 (11, 2, 0.18495542770861484),
 (9, 2, 0.18475037539261013),
 (9, 0, 0.1841438644026269),
 (11, 4, 0.18140669805587734),
 (10, 9, 0.18016958380296674),
 (9, 1, 0.1799556444681705),
 (8, 11, 0.1797203777121738),
 (10, 5, 0.17958857143178505),
 (8, 7, 0.17886423133516335),
 (11, 3, 0.17881788652643082),
 (10, 4, 0.17843035955545758),
 (10, 6, 0.17813508280376536),
 (11, 6, 0.17802176577302084),
 (9, 10, 0.17787153222013147),
 (8, 5, 0.17732895140520186),
 (9, 3, 0.17613587979399356),
 (11, 1, 0.17594829998328604),
 (10, 10, 0.17578660223763148

In [None]:
#TODO integrated gradients / expected gradients / integrated hessians

# IG using captum
# pip install captum
# https://captum.ai/tutorials/IMDB_TorchText_Interpret  