In [1]:
import sys

sys.path.append('../')

In [2]:
import torch
from transformers import BertForSequenceClassification
from training_params import TOKENIZER, MAX_LEN, DEVICE, PRE_TRAINED_MODEL_NAME, LABEL_DICT
from inference_params import CHECKPOINT_PATH

import numpy as np

In [3]:
model = BertForSequenceClassification.from_pretrained(
                                                        PRE_TRAINED_MODEL_NAME,
                                                        num_labels=len(LABEL_DICT),
                                                        output_attentions=False,
                                                        output_hidden_states=False
                                                    )

# model = nn.DataParallel(model)
checkpoint = torch.load(CHECKPOINT_PATH)
model.load_state_dict(checkpoint['state_dict'])

model.eval()
model.cuda()

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at b

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [4]:
review_text = "I love Boat bassheads! best earphones ever!!!"

In [10]:
def infer(text):
    encoding = TOKENIZER.encode_plus(
            text,
            add_special_tokens=True,
            max_length=MAX_LEN,
            return_token_type_ids=False,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
        )
    
    input_ids = encoding['input_ids'].to(DEVICE)
    attention_mask = encoding['attention_mask'].to(DEVICE)
    output = model(input_ids, attention_mask)
    logits = output['logits'].detach().cpu().numpy()
    pred = np.argmax(logits, axis=1)
    pred_name = [LABEL_DICT[str(pr)] for pr in pred]
    return pred_name[0]

In [6]:
infer(review_text)

'POSITIVE'

In [8]:
texts = ["these earphones are not good",
         "worst purchase ever",
         "i bought these three months ago and they stopped working",
         "love these",
         "these are ok",
        "nice for the price",
        "amazing product",
        "i am neeraj",
        "this is made in china",
        "i don't mind these",
        "i kind of like the bass",
        "good not great"]

In [14]:
for text in texts:
    print(text)
    print("- Prediced as: ", infer(text))
    print('.')

these earphones are not good
- Prediced as:  NEGATIVE
.
worst purchase ever
- Prediced as:  NEGATIVE
.
i bought these three months ago and they stopped working
- Prediced as:  NEUTRAL
.
love these
- Prediced as:  POSITIVE
.
these are ok
- Prediced as:  POSITIVE
.
nice for the price
- Prediced as:  POSITIVE
.
amazing product
- Prediced as:  POSITIVE
.
i am neeraj
- Prediced as:  NEUTRAL
.
this is made in china
- Prediced as:  NEUTRAL
.
i don't mind these
- Prediced as:  NEUTRAL
.
i kind of like the bass
- Prediced as:  NEUTRAL
.
good not great
- Prediced as:  NEUTRAL
.
