In [25]:
# Load the autoreload extension and use mode 2, which re-imports edited modules
# before each cell is executed.
%load_ext autoreload
%autoreload 2
# Turn off jedi in the IPython tab-completer.
%config Completer.use_jedi = False
# Render matplotlib figures inline in the cell output.
%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [26]:
import os
import pandas as pd
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt

In [3]:
from captum.attr import LayerIntegratedGradients
import torch


In [4]:

from pathlib import Path

from nlp_cmp_utils import create_data_loader
from nlp_cmp_utils import IMDBClassifier
from nlp_cmp_utils import explain_tweet_bert


In [5]:
from transformers import DistilBertTokenizerFast
from transformers import DistilBertModel

In [6]:
# Directory the fine-tuned model checkpoint is loaded from further down.
# Defaults to the original location; set the TRANSFORMERS_DATA_DIR environment
# variable to run the notebook against a different storage root without
# editing this cell.
data_output_path = Path(os.environ.get('TRANSFORMERS_DATA_DIR', '/dataStore/transformers_projects/'))

In [7]:
# Show up to 100 characters per cell so the text columns are readable when a
# frame is displayed. The fully-qualified option name is used because pandas
# resolves abbreviated names by pattern matching, which can break if another
# option with a similar name is added.
pd.set_option('display.max_colwidth', 100)

### Load the input data from file

In [9]:
# JSON file that is read into `df` in the next cell. The path is relative, so
# it is resolved against the kernel's current working directory.
filename = 'result_2020_11_15.json'

In [10]:
# Parse the JSON file into a DataFrame. Later cells read its `feature_text`,
# `label_text` and `type` columns.
df = pd.read_json(filename)

In [11]:
# Display the loaded frame for a quick visual check.
df

Unnamed: 0,id,feature_id,label_text,entry_id,feature_text,type
0,27,1,{stunned},13,RT @Jwhitbrook: With the new Picard up on Amazon I can finally get a clean screenshot and ask......,pos
1,1,2,{love},5,"RT @klaushismydaddy: netflix: i love ALL of my programs equally!!! stranger things, lucifer, tig...",pos
2,2,3,{excellent},5,Just finished Picard on Amazon Prime. Was not expecting to cry. What an excellent show.,pos
3,3,5,"{you,suggest,go,do,so.}",5,"If you haven’t seen the amazon prime show “Upload”, I suggest you go do so. I really enjoyed thi...",pos
4,25,6,"{""https://t.co/l2HIKNQ95V\n\nWorks""}",13,@brexitblog_info @boblister_poole https://t.co/l2HIKNQ95V\n\nWorks both ways,pos
...,...,...,...,...,...,...
423,400,1549,{NEVER},265,RT @Floydbirman: @LoyalDefender2K I have NEVER watch ANY award for decades. Same as BBC Question...,neg
424,328,1556,"{out,of,touch}",167,RT @annesayer6: @LozzaFox @EquityUK I would suggest actors denounce Equity UK for being complete...,neg
425,410,1557,{disgracefully},265,@JohnSimpsonNews Not to mention anti SNP bias! It’s blatant in Scotland! There is no doubt the b...,neg
426,422,1565,{hate},348,RT @Ally__Cinnamon: Boris Johnson could say I fuckin hate the blacks man and there would be a de...,neg


In [36]:
# Fast tokenizer for the pretrained 'distilbert-base-uncased' checkpoint, the
# same checkpoint name the base model is loaded from further down.
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

In [37]:
# Binary target: 1 where the `type` column equals 'pos', 0 for any other value.
# A vectorised comparison replaces the row-by-row lambda, and bracket access is
# used because `type` is also the name of a Python builtin.
df['label'] = df['type'].eq('pos').astype('int64')

In [78]:

# One example per batch.
# NOTE(review): the cells below only read element 0 of each batch
# (batch['input_ids'][0] in the attribution loop, i[0] in fix_tuple_list), so a
# larger batch size would presumably need those cells updated as well.
BATCH_SIZE=1

# full data set
# Arguments, in order: the feature_text entries, the 0/1 labels, the label_text
# entries, the tokenizer, 512 (presumably the maximum token length — confirm in
# nlp_cmp_utils.create_data_loader) and the batch size.
results_batch = create_data_loader(df.feature_text.to_list(), df.label.to_list(), df.label_text.to_list() ,tokenizer , 512, BATCH_SIZE)

### Load the DistilBERT model

In [39]:
# Pretrained DistilBERT encoder; it is passed to IMDBClassifier in the next cell.
base_model = DistilBertModel.from_pretrained("distilbert-base-uncased")

In [40]:
# Build the project's classifier around the encoder.
# NOTE(review): the first argument (2) is presumably the number of output
# classes — confirm against IMDBClassifier in nlp_cmp_utils.
model = IMDBClassifier(2, base_model)

In [41]:
# Restore the saved weights into the classifier.
# map_location='cpu' lets a checkpoint that was saved from a GPU load on a
# machine without CUDA; the model is moved to the CPU a few cells later anyway.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
model.load_state_dict(torch.load(data_output_path / 'distilbert_IMDB', map_location='cpu'))

<All keys matched successfully>

In [42]:
# Switch the module to evaluation mode (disables training-only behaviour such as dropout).
model.eval()
# Clear any gradients currently stored on the parameters.
model.zero_grad()
# Keep the model on the CPU.
model.cpu()
# Word-embedding layer reached through `model.model` (presumably the wrapped
# DistilBERT encoder — confirm in IMDBClassifier). It is passed to
# LayerIntegratedGradients below as the layer to attribute to.
embeddings = model.model.embeddings.word_embeddings

### Compare with Captum

In [43]:
def ml(input_ids, additional_forward_args):
    """Forward function handed to Captum: softmax-normalised model output.

    Both arguments are passed positionally to the module-level ``model``; a
    softmax over the last dimension is then applied to whatever it returns.
    """
    raw_output = model(input_ids, additional_forward_args)
    return torch.softmax(raw_output, dim=-1)

In [44]:
# Layer Integrated Gradients explainer: `ml` is the forward function and
# `embeddings` is the layer the attributions are computed for.
lig = LayerIntegratedGradients(ml, embeddings,)

In [79]:
results = []
for batch in tqdm(results_batch):
    # Attribution scores for this batch; n_steps=25 is passed through to the helper.
    attribs = explain_tweet_bert(batch, lig, n_steps=25)
    # Decode each token id of the first sequence in the batch separately.
    word_parts = tokenizer.batch_decode(batch['input_ids'][0])

    # Pair each score with its token and expose the row number as a
    # `position` column (chained instead of an in-place reset_index).
    attr_data_df = (
        pd.DataFrame(data=zip(attribs, word_parts), columns=('score', 'word'))
        .rename_axis('position')
        .reset_index()
    )

    # remove ['PAD'] tags - bert only
    word_dict = attr_data_df.loc[attr_data_df['word'] != '[PAD]'].to_dict(orient='records')

    results.append({
        'model_dict': word_dict,
        'human_words': batch['found_label'],
        'sentiment': batch['label'],
    })
    

100%|██████████| 428/428 [1:14:57<00:00, 10.51s/it]


In [80]:
# One row per processed batch, with columns model_dict, human_words and sentiment.
res_df_cap = pd.DataFrame(results)


In [67]:
def fix_tuple_list(inp):
    """Undo the tuple-wrapping the PyTorch DataLoader applies to lists.

    Each tuple (or list) element of ``inp`` is replaced by its first item.
    Elements that are not a tuple or list — for example strings that were
    already unwrapped — are passed through unchanged. Applying the function a
    second time (e.g. when the cell that calls it is re-run) is therefore a
    no-op instead of cutting every word down to its first character.

    Parameters
    ----------
    inp : iterable
        Typically a list of 1-tuples, as produced with a batch size of 1.

    Returns
    -------
    list
        The unwrapped items, in their original order.

    Raises
    ------
    IndexError
        If a tuple or list element is empty.
    """
    return [i[0] if isinstance(i, (tuple, list)) else i for i in inp]

In [81]:
# Unwrap the tuples the DataLoader put around each entry of human_words.
# The column is overwritten in place, so re-running this cell feeds the
# already-unwrapped values through fix_tuple_list a second time.
res_df_cap['human_words']=res_df_cap.human_words.apply(fix_tuple_list)

In [83]:
# Preview the first rows of the result frame.
res_df_cap.head()

Unnamed: 0,model_dict,human_words,sentiment
0,"[{'position': 0, 'score': 0.15727868574772647, 'word': '[CLS]'}, {'position': 1, 'score': 0.1249...",[stunned],pos
1,"[{'position': 0, 'score': 0.21714102995484544, 'word': '[CLS]'}, {'position': 1, 'score': 0.0128...",[love],pos
2,"[{'position': 0, 'score': 0.1665978808950008, 'word': '[CLS]'}, {'position': 1, 'score': -0.0337...",[excellent],pos
3,"[{'position': 0, 'score': 0.43073863908028953, 'word': '[CLS]'}, {'position': 1, 'score': 0.0644...","[you, suggest, go, do, so.]",pos
4,"[{'position': 0, 'score': 0.1545824975521331, 'word': '[CLS]'}, {'position': 1, 'score': 0.03480...","[""https://t.co/l2HIKNQ95V\n\nWorks""]",pos


In [84]:
# Persist the result frame as a pickle, relative to the working directory.
# "25steps" in the file name matches n_steps=25 used in the attribution loop.
res_df_cap.to_pickle('all_scores_bert_25steps.pkl')