In [1]:
import itertools
import os

import pandas as pd
import torch
from torch.utils.data import DataLoader
from transformers import BertConfig, BertTokenizer

from utils import InferenceProcessor, inference
from absa_bert import BertABSATagger, ABSADataset
from settings import *

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Path to sample.csv inside the configured YELP_DIR.
test_path = os.path.join(YELP_DIR, 'sample.csv')

# 'utf-8-sig' decodes UTF-8 and drops a leading byte-order mark if the file has one.
d_test = pd.read_csv(
    test_path,
    encoding='utf-8-sig',
)

In [3]:
# Preview the loaded reviews; a bare expression on the last line renders the frame.
d_test

Unnamed: 0,user_id,business_id,text,stars
0,hKBQ-PFlcB-t5FK3HUxoyQ,w5el96z7deUQMuGGBlvf7A,wing bigger see flavor good service pretty lou...,3
1,quF5ORcDanGUIR5P3AsRQA,d28lZlSps97FHhXZazJ8zA,everytime walk misconduct tavern chant head ex...,4
2,h0Jn9rkacf3tGw5BeZMHWA,TcNZXteosegb1RO4O5hREw,place set bar boy go drop domalises star huge ...,5
3,0G-bcpdR48tfZy6koGUiTQ,t1o0DX6aEUVm9ur43kQ_Yw,friendly quick use drive thru morning line get...,4
4,dNyQX33DDCjc_MzD0XjDog,aBo4pJhFKPs5MLHobREyJA,great little donut shop time grab dozen work g...,4
...,...,...,...,...
49995,90QrD73MiLPjs4mtkmBqQQ,arKiXax3ScSM_z3O-0CIyw,love love love place make actually crave snail...,5
49996,W67nXodRWTIa-d1NJlncvA,OvC4Ecgzk2SI7R8qD0rk9Q,sky bar hippest least tawdry incarnation ever ...,4
49997,7bWZAphAjvu5PuTIKDDTow,QpyvUColtSYH5VtkcrN6Xw,cool local diner dennis say come year come try...,5
49998,eFUcc2QbaW0InWgsBsMiwA,LNHq9WxfhN2UBNOR2tnRQQ,personally feel tomato pie bite sweet scrape t...,3


In [4]:
class args:
    """Static settings namespace for the inference run.

    The class object itself is handed to ``ABSADataset`` and ``inference``;
    it is not instantiated in the cells shown here. Which attributes those
    helpers read is not visible from this notebook.
    """

    batch_size = 32   # batch size given to the test DataLoader
    latent_dim = 64   # not read by any cell shown here; presumably used inside the helpers - TODO confirm
    seed = 42         # NOTE(review): no cell shown here applies this seed - confirm the helpers do
    # Prefer the GPU when one is available, otherwise run on CPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    absa_type = 'san'          # copied onto the BERT config as config.absa_type
    semeval_dir = SEM_DIR      # not read by any cell shown here
    model_type = 'bert'
    fix_tfm = 0                # copied onto the BERT config as config.fix_tfm
    max_seq_length = 512       # presumably the tokenization length cap inside ABSADataset - TODO confirm
    model_name_or_path = 'bert-base-uncased'   # used for config, tokenizer and model from_pretrained
    # Despite the name this is a file path: it is passed straight to torch.load.
    param_dir = os.path.join(PARAM_DIR, 'best-parameters.pt')

    # Derived from the two fields above, e.g. 'bert-san'.
    output_dir = f'{model_type}-{absa_type}'

In [5]:
# Label inventory of the tagger; its size is passed to the config as num_labels.
processor = InferenceProcessor()
label_list = processor.get_labels() # ['O', 'EQ', 'B-POS',...]
num_labels = len(label_list)

config_class, model_class, tokenizer_class = BertConfig, BertABSATagger, BertTokenizer
config = config_class.from_pretrained(args.model_name_or_path, num_labels=num_labels)
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)

# Extra fields set on the config before the model is constructed from it.
config.absa_type = args.absa_type
config.fix_tfm = args.fix_tfm

# Build the tagger from the base checkpoint, then overwrite its weights with
# the saved parameters.
model = model_class.from_pretrained(args.model_name_or_path, config=config)
# map_location remaps the stored tensors onto the selected device, so a
# checkpoint written from a CUDA run still loads when args.device is 'cpu'.
model.load_state_dict(torch.load(args.param_dir, map_location=args.device))
model = model.to(args.device)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertABSATagger: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertABSATagger from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertABSATagger from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertABSATagger were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['tagger.self_

In [6]:
# Wrap the review frame in the project's dataset class.
test_set = ABSADataset(args, d_test, tokenizer)

# shuffle=False keeps batches in dataset order.
test_loader = DataLoader(
    dataset=test_set,
    batch_size=args.batch_size,
    shuffle=False,
)

In [7]:
# Tag string -> integer id; ids follow the order in which the tags are listed.
absa_label_vocab = {
    tag: index
    for index, tag in enumerate((
        'O', 'EQ',
        'B-POS', 'I-POS', 'E-POS', 'S-POS',
        'B-NEG', 'I-NEG', 'E-NEG', 'S-NEG',
        'B-NEU', 'I-NEU', 'E-NEU', 'S-NEU',
    ))
}

In [13]:
# Run the project's inference helper over the test loader. It returns four
# values; pos_words / neg_words are attached to d_test as columns in a later
# cell, while pos_vocabs / neg_vocabs are not used by any cell shown here.
pos_words, neg_words, pos_vocabs, neg_vocabs = inference(args, test_loader, model, tokenizer)

extract opinion in reviews, total iteration: 50,000: 36789it [00:58, 621.14it/s]

In [None]:
# Attach the extracted word lists to the review frame, one column per polarity.
for column, extracted in (('pos_words', pos_words), ('neg_words', neg_words)):
    d_test.loc[:, column] = extracted

# Replace empty lists with an empty string; non-empty lists are left unchanged.
for column in ('pos_words', 'neg_words'):
    d_test[column] = d_test[column].apply(lambda words: '' if words == [] else words)

In [12]:
# Display the frame again to inspect the pos_words / neg_words columns.
d_test

Unnamed: 0,user_id,business_id,text,stars,pos_words,neg_words
0,hKBQ-PFlcB-t5FK3HUxoyQ,w5el96z7deUQMuGGBlvf7A,wing bigger see flavor good service pretty lou...,3,[service],
1,quF5ORcDanGUIR5P3AsRQA,d28lZlSps97FHhXZazJ8zA,everytime walk misconduct tavern chant head ex...,4,"[food, drink, bar, food, mac, cheese, platt, m...",
2,h0Jn9rkacf3tGw5BeZMHWA,TcNZXteosegb1RO4O5hREw,place set bar boy go drop domalises star huge ...,5,"[place, bar, dom, shrimp, shrimp, bread, consi...",
3,0G-bcpdR48tfZy6koGUiTQ,t1o0DX6aEUVm9ur43kQ_Yw,friendly quick use drive thru morning line get...,4,"[line, mcdonald, order]",
4,dNyQX33DDCjc_MzD0XjDog,aBo4pJhFKPs5MLHobREyJA,great little donut shop time grab dozen work g...,4,"[don, apple, fr]",
...,...,...,...,...,...,...
49995,90QrD73MiLPjs4mtkmBqQQ,arKiXax3ScSM_z3O-0CIyw,love love love place make actually crave snail...,5,[place],
49996,W67nXodRWTIa-d1NJlncvA,OvC4Ecgzk2SI7R8qD0rk9Q,sky bar hippest least tawdry incarnation ever ...,4,"[bar, patio, solar, panel, wall, window, bar, ...",
49997,7bWZAphAjvu5PuTIKDDTow,QpyvUColtSYH5VtkcrN6Xw,cool local diner dennis say come year come try...,5,"[diner, breakfast, server, food, food]",
49998,eFUcc2QbaW0InWgsBsMiwA,LNHq9WxfhN2UBNOR2tnRQQ,personally feel tomato pie bite sweet scrape t...,3,"[pie, pie, tomato, pie]",
