In [1]:
import transformers
import torch
from transformers import BertModel, BertTokenizer
from torch import nn
from torch.utils.data import Dataset, DataLoader
# Run on the first CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
# Bare last expression so the notebook displays which device was picked.
device

device(type='cuda', index=0)

In [2]:
# Checkpoint name handed to both BertModel.from_pretrained and BertTokenizer.from_pretrained.
PRE_TRAINED_MODEL_NAME = 'bert-base-uncased'

# Length every tokenized input is padded / truncated to.
MAX_LEN = 300

# Human-readable labels; the list is indexed by the predicted class index, so the
# order matters. NOTE(review): presumably this matches the category_id encoding
# used at training time - confirm against the training notebook.
category_names = [
    'programming',
    'business',
    'health',
    'marketing',
    'politics',
    'sports',
]

In [3]:
class CategoryClassifier(nn.Module):
  """Text classifier: pretrained BERT encoder -> dropout -> linear head.

  The head produces one raw score (logit) per class; no softmax is applied.
  """

  def __init__(self, n_classes):
    """Build the encoder and a classification head with `n_classes` outputs.

    Args:
      n_classes: Number of output units of the final linear layer.
    """
    super().__init__()
    # The attribute names (bert / drop / out) become the parameter-key prefixes
    # of the state_dict, so they must stay as they are for saved weights to load.
    self.bert = BertModel.from_pretrained(PRE_TRAINED_MODEL_NAME)
    self.drop = nn.Dropout(p=0.3)
    hidden_size = self.bert.config.hidden_size
    self.out = nn.Linear(hidden_size, n_classes)

  def forward(self, input_ids, attention_mask):
    """Return the class scores for a batch of tokenized inputs.

    Args:
      input_ids: Token-id tensor forwarded to the BERT encoder.
      attention_mask: Mask tensor forwarded to the BERT encoder.

    Returns:
      The output of the linear head applied to the (dropout-regularised)
      pooled encoder output.
    """
    encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
    regularised = self.drop(encoded.pooler_output)
    logits = self.out(regularised)
    return logits

class MyDataset(Dataset):
  """Dataset that pairs each text with its label and tokenizes it on access."""

  def __init__(self, reviews, targets, tokenizer, max_len):
    """Store the raw inputs; nothing is tokenized until an item is requested.

    Args:
      reviews: Indexable collection of texts.
      targets: Indexable collection of labels, aligned with `reviews`.
      tokenizer: Object exposing an `encode_plus` method.
      max_len: Length each text is padded / truncated to.
    """
    self.reviews, self.targets = reviews, targets
    self.tokenizer, self.max_len = tokenizer, max_len

  def __len__(self):
    """Number of texts in the dataset."""
    return len(self.reviews)

  def __getitem__(self, item):
    """Tokenize the text at position `item` and bundle it with its label.

    Returns:
      A dict with the original text ('review_text'), the flattened
      'input_ids' and 'attention_mask' tensors produced by the tokenizer,
      and the label as a long tensor ('targets').
    """
    text = str(self.reviews[item])
    label = self.targets[item]

    tokenizer_options = dict(
      add_special_tokens=True,
      max_length=self.max_len,
      truncation=True,
      return_token_type_ids=False,
      padding='max_length',
      return_attention_mask=True,
      return_tensors='pt',
    )
    encoded = self.tokenizer.encode_plus(text, **tokenizer_options)

    # flatten() turns the tokenizer's tensors into 1-D tensors
    # (presumably dropping a leading batch axis of size 1 - TODO confirm).
    sample = {'review_text': text}
    sample['input_ids'] = encoded['input_ids'].flatten()
    sample['attention_mask'] = encoded['attention_mask'].flatten()
    sample['targets'] = torch.tensor(label, dtype=torch.long)
    return sample

def create_data_loader(df, tokenizer, max_len, batch_size):
  """Build a DataLoader over the text and label columns of `df`.

  Args:
    df: Frame exposing `stemmed_text` (texts) and `category_id` (labels).
    tokenizer: Tokenizer passed through to MyDataset.
    max_len: Padding / truncation length passed through to MyDataset.
    batch_size: Number of samples per batch.

  Returns:
    A DataLoader over a MyDataset, configured with 4 worker processes.
  """
  texts = df.stemmed_text.to_numpy()
  labels = df.category_id.to_numpy()
  dataset = MyDataset(
    reviews=texts,
    targets=labels,
    tokenizer=tokenizer,
    max_len=max_len,
  )
  loader = DataLoader(dataset, batch_size=batch_size, num_workers=4)
  return loader

In [8]:
# Sample passage to classify. Adjacent string literals are concatenated by the
# parser, so this is a single string; each piece keeps its own trailing space.
input_text = (
    "As the global economy continues to recover from the impact of the COVID-19 pandemic, businesses are looking for new ways to grow and thrive in the post-pandemic world. "
    "With the rise of digital technologies and changing consumer preferences, companies must be agile and innovative to stay competitive. "
    "From leveraging data analytics to exploring new markets and partnerships, businesses must adapt quickly to meet the evolving needs of their customers. "
    "Those that can successfully navigate these challenges and seize new opportunities will be well-positioned for success in the years ahead."
)

In [9]:
# Rebuild the architecture with one output unit per known category, so the
# head size is tied to the label list instead of a hard-coded number.
model = CategoryClassifier(len(category_names))

# Restore the fine-tuned weights. map_location remaps the stored tensors onto
# the device selected above, so a checkpoint written on a GPU machine can still
# be loaded when this notebook falls back to the CPU.
state_dict = torch.load('best_model_state.bin', map_location=device)
model.load_state_dict(state_dict)

# Move the model to the device and switch it to evaluation mode: modules are
# created in training mode, which would keep Dropout(p=0.3) active and make
# predictions vary from run to run.
model = model.to(device)
model.eval()

# Tokenizer matching the pretrained checkpoint the encoder was initialised from.
tokenizer = BertTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [10]:
# Tokenize the sample with the same options MyDataset.__getitem__ uses, so the
# model sees input prepared the same way as the dataset items.
encoded_text = tokenizer.encode_plus(
    input_text,
    # fixed-length input: cut longer texts, pad shorter ones up to MAX_LEN
    max_length=MAX_LEN,
    truncation=True,
    padding='max_length',
    # what the encoding should contain
    add_special_tokens=True,
    return_attention_mask=True,
    return_token_type_ids=False,
    # return PyTorch tensors
    return_tensors='pt',
)

In [11]:
# Place the encoded sample on the same device as the model.
input_ids = encoded_text['input_ids'].to(device)
attention_mask = encoded_text['attention_mask'].to(device)

# Inference only: make sure dropout is disabled (eval mode) and skip autograd
# bookkeeping, so the prediction is deterministic and no graph is retained.
model.eval()
with torch.no_grad():
    output = model(input_ids, attention_mask)

# torch.max over dim=1 returns (values, indices); the index of the largest
# score is the predicted class.
_, prediction = torch.max(output, dim=1)

print(f'Input text: {input_text}')
# prediction is a one-element tensor here (a single text was encoded), so
# convert it to a plain int before using it as a list index.
print(f'Category  : {category_names[prediction.item()]}')

Input text: As the global economy continues to recover from the impact of the COVID-19 pandemic, businesses are looking for new ways to grow and thrive in the post-pandemic world. With the rise of digital technologies and changing consumer preferences, companies must be agile and innovative to stay competitive. From leveraging data analytics to exploring new markets and partnerships, businesses must adapt quickly to meet the evolving needs of their customers. Those that can successfully navigate these challenges and seize new opportunities will be well-positioned for success in the years ahead.
Category  : business
