In [1]:
import json

In [2]:
import re
import numpy as np
from torch.nn.functional import softmax
from transformers import AutoTokenizer, AutoModelForSequenceClassification

def pre_proccess(text):
    """Lower-case ``text`` and replace punctuation/digit characters with spaces.

    The character class below matches the quote characters, the ASCII ranges
    ``!``..``.`` and ``:``..``@``, the digits ``0``-``9`` and ``/``. Each
    matched character is replaced by a single space. The trailing ``()`` in
    the pattern is an empty group and does not change what is matched.

    Args:
        text: The raw input string.

    Returns:
        The lower-cased string with every matched character turned into ' '.
    """
    lowered = text.lower()
    return re.sub('["\',!-.:-@0-9/]()', ' ', lowered)

# Wrapper to adapt output format
class SentimentAnalisysModelWrapper:
    """Adapt a model/tokenizer pair to a simple ``predict`` style interface.

    ``tokenizer`` is called with a pre-processed string and must return a
    mapping that can be splatted into ``model(**mapping)``. ``model`` must
    return something whose element ``[0]`` is the logits tensor.
    """

    def __init__(self, model, tokenizer):
        """Store the model and tokenizer used for every prediction."""
        self.model = model
        self.tokenizer = tokenizer

    def __predict(self, text_input):
        """Classify a single text.

        Args:
            text_input: One raw input string.

        Returns:
            A ``(pred, prob)`` tuple: ``pred`` is the index of the highest
            probability and ``prob`` is the NumPy array of class
            probabilities, still carrying the leading batch axis (callers
            read ``prob[0]``).
        """
        # Local import: the notebook only imports `softmax` from torch.
        import torch

        text_preprocessed = pre_proccess(text_input)
        tokenized = self.tokenizer(text_preprocessed, padding=True, truncation=True, max_length=512,
                                   add_special_tokens=True, return_tensors="pt")

        # Inference only, so skip autograd bookkeeping for the forward pass.
        with torch.no_grad():
            tensor_logits = self.model(**tokenized)
            # Normalise over the class axis explicitly; relying on the implicit
            # dimension is deprecated in PyTorch and emits a UserWarning.
            prob = softmax(tensor_logits[0], dim=-1).detach().numpy()
        pred = np.argmax(prob)

        return pred, prob

    def predict_label(self, text_inputs):
        """Return only the array of predicted labels from :meth:`predict`."""
        return self.predict(text_inputs)[0]

    def predict_proba(self, text_inputs):
        """Return only the array of class probabilities from :meth:`predict`."""
        return self.predict(text_inputs)[1]

    def predict(self, text_inputs):
        """Classify one text or an iterable of texts, one at a time.

        Args:
            text_inputs: A single string, or an iterable of strings.

        Returns:
            A ``(preds, probs)`` tuple of NumPy arrays with one entry per
            input text, e.g. ``([0, 1], [[0.99, 0.01], [0.03, 0.97]])``.
        """
        # A bare string is treated as a batch of one, not as an iterable of characters.
        if isinstance(text_inputs, str):
            text_inputs = [text_inputs]

        preds = []
        probs = []

        for text_input in text_inputs:
            pred, prob = self.__predict(text_input)
            preds.append(pred)
            # Drop the batch axis added by tokenizing a single text.
            probs.append(prob[0])

        return np.array(preds), np.array(probs)

# Auxiliary function to load and wrap a model from Hugging Face
def load_model(model_name):
    """Load a sequence-classification model and its tokenizer, then wrap them.

    Args:
        model_name: Identifier passed unchanged to both ``from_pretrained`` calls.

    Returns:
        A ``SentimentAnalisysModelWrapper`` built from the loaded model and tokenizer.
    """
    print(f'Loading model {model_name}...')
    # Arguments are evaluated left to right, so the model is loaded before the tokenizer.
    return SentimentAnalisysModelWrapper(
        AutoModelForSequenceClassification.from_pretrained(model_name),
        AutoTokenizer.from_pretrained(model_name),
    )

In [6]:
# Load and wrap the classifier used by the cells below. The commented-out line
# is the alternative Rotten Tomatoes checkpoint; enable exactly one of the two.
model = load_model('textattack/bert-base-uncased-imdb')
# model = load_model('textattack/bert-base-uncased-rotten-tomatoes')

Loading model textattack/bert-base-uncased-imdb...


In [7]:
# Read the lexicon file that matches the loaded model. JSON text is UTF-8, so
# the encoding is stated explicitly instead of relying on the platform default.
# with open("rotten_tomatoes_all_lexicons.json", mode="rt", encoding="utf-8") as f:
with open("imdb_all_lexicons.json", mode="rt", encoding="utf-8") as f:
    lexicons = json.load(f)

# Replace every word list with (word, (predicted_label, [class probabilities])) tuples.
# Only existing keys are reassigned, so updating the dict while iterating is safe.
for k, words in lexicons.items():
    new_list = []
    for lex in words:
        # predict() returns one-element arrays for a single string input.
        labels, probabilities = model.predict(lex)
        new_list.append((lex, (labels[0], list(probabilities[0]))))
    lexicons[k] = new_list

  prob = softmax(tensor_logits[0]).detach().numpy()


## Rotten Tomatoes

In [5]:
# Display the scored lexicons.
# NOTE(review): execution count [5] predates the IMDB load in cells [6]/[7], so this
# output presumably comes from an earlier run with the Rotten Tomatoes lines enabled —
# confirm by re-running from a fresh kernel.
lexicons

{'neg_verbs': [('avoids', (0, [0.99845815, 0.0015417968])),
  ('given', (0, [0.6998323, 0.30016768])),
  ('spare', (0, [0.86081946, 0.13918054])),
  ('have', (0, [0.7627887, 0.2372113])),
  ('modeled', (0, [0.9195984, 0.0804016])),
  ('tells', (0, [0.80778116, 0.1922188])),
  ('italicized', (0, [0.9853703, 0.014629735])),
  ('withered', (0, [0.99837935, 0.0016205859])),
  ('sum', (0, [0.6032982, 0.39670184])),
  ('undermines', (0, [0.99832374, 0.0016762016])),
  ('dismissed', (0, [0.9866006, 0.013399359])),
  ('simpering', (0, [0.9993192, 0.0006807732])),
  ('squanders', (0, [0.9990922, 0.00090783427])),
  ('removed', (0, [0.9905794, 0.009420529]))],
 'pos_verbs': [('talking', (1, [0.013159547, 0.9868405])),
  ('manages', (1, [0.029663043, 0.970337])),
  ('discover', (1, [0.32330358, 0.6766965])),
  ('combines', (1, [0.11192099, 0.88807905])),
  ('awakens', (1, [0.007863645, 0.99213636])),
  ('coming', (1, [0.08797607, 0.9120239])),
  ('knows', (1, [0.0042919596, 0.995708])),
  ('gives

## IMDB

In [8]:
# Display the lexicons scored by the cell that reads "imdb_all_lexicons.json".
lexicons

{'neg_verbs': [('replaces', (0, [0.9206565, 0.07934349])),
  ('evaluate', (0, [0.70827526, 0.2917247])),
  ('looked', (0, [0.7049822, 0.29501778])),
  ('had', (0, [0.82382256, 0.17617746])),
  ('have', (0, [0.5899664, 0.41003358])),
  ('bored', (0, [0.88253695, 0.11746309])),
  ('thought', (0, [0.66970134, 0.33029863])),
  ('lying', (0, [0.80019397, 0.199806])),
  ('pretending', (0, [0.8348354, 0.16516465])),
  ('crawls', (0, [0.78746206, 0.21253793])),
  ('died', (0, [0.7514838, 0.24851613])),
  ('mutilating', (0, [0.97070193, 0.029298099])),
  ('saw', (0, [0.62543684, 0.3745632])),
  ('say', (0, [0.70253146, 0.2974685])),
  ('appeared', (0, [0.6304042, 0.36959583])),
  ('takes', (0, [0.56704277, 0.4329572])),
  ('mocking', (0, [0.65667796, 0.343322])),
  ('expected', (0, [0.69823587, 0.3017641])),
  ('wastes', (0, [0.9926779, 0.007322012])),
  ('play', (0, [0.5247547, 0.47524533])),
  ('see', (0, [0.5913498, 0.40865022])),
  ('think', (0, [0.58346474, 0.41653526])),
  ('Wish', (0, [0