Import necessary libraries

In [1]:
import os

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer
# Run on the first CUDA GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
device

  from .autonotebook import tqdm as notebook_tqdm


device(type='cpu')

Global parameters

In [2]:
# Name of the pretrained BERT checkpoint used for both the encoder and the tokenizer.
PRE_TRAINED_MODEL_NAME = "bert-base-uncased"
# Every input text is truncated / padded to this many tokens.
MAX_LEN = 128
# Class labels; the list position is the class index predicted by the model.
category_names = ['programming','business','health','marketing','politics','sports']
# PostgreSQL connection string. Set the CATEGORIZER_DB_DSN environment variable
# to supply real credentials without editing the notebook; the literal below is
# only the local-development fallback kept for backward compatibility.
# NOTE(review): do not commit production credentials into this default.
conn_string = os.environ.get(
    "CATEGORIZER_DB_DSN",
    "host='localhost' dbname='postgres' user='postgres' password='12345'",
)
# Result containers filled by later cells: input texts, predicted categories
# and the department assigned to each text.
text_list = []
category_list = []
department_list = []

Global functions

In [3]:
class CategoryClassifier(nn.Module):
  """Text classifier: pretrained BERT encoder, dropout, then a linear head.

  Args:
    n_classes: number of output logits produced by the linear head.
  """

  def __init__(self, n_classes):
    super().__init__()
    # The attribute names bert / drop / out become the prefixes of the
    # state-dict keys, so they must stay as they are for saved weights to load.
    self.bert = BertModel.from_pretrained(PRE_TRAINED_MODEL_NAME)
    self.drop = nn.Dropout(p=0.3)
    encoder_width = self.bert.config.hidden_size
    self.out = nn.Linear(encoder_width, n_classes)

  def forward(self, input_ids, attention_mask):
    """Return the classification logits for the given token ids and mask."""
    encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
    # Classify from BERT's pooled output after applying dropout.
    regularized = self.drop(encoded.pooler_output)
    logits = self.out(regularized)
    return logits

class MyDataset(Dataset):
  """Dataset that tokenizes one text per item and pairs it with its label.

  Args:
    reviews: indexable collection of texts (each item is passed through str()).
    targets: indexable collection of labels, aligned with ``reviews``.
    tokenizer: object exposing ``encode_plus``.
    max_len: token length every text is truncated / padded to.
  """

  def __init__(self, reviews, targets, tokenizer, max_len):
    self.reviews, self.targets = reviews, targets
    self.tokenizer, self.max_len = tokenizer, max_len

  def __len__(self):
    """Number of texts in the dataset."""
    return len(self.reviews)

  def __getitem__(self, item):
    """Return the raw text, its flattened token tensors and the label tensor."""
    text = str(self.reviews[item])
    label = self.targets[item]

    tokenizer_options = dict(
      add_special_tokens=True,
      max_length=self.max_len,
      truncation=True,
      return_token_type_ids=False,
      padding='max_length',
      return_attention_mask=True,
      return_tensors='pt',
    )
    encoded = self.tokenizer.encode_plus(text, **tokenizer_options)

    sample = {'review_text': text}
    # The tokenizer is asked for 'pt' tensors; flatten them to 1-D per item.
    sample['input_ids'] = encoded['input_ids'].flatten()
    sample['attention_mask'] = encoded['attention_mask'].flatten()
    sample['targets'] = torch.tensor(label, dtype=torch.long)
    return sample

def create_data_loader(df, tokenizer, max_len, batch_size, num_workers=4):
  """Wrap a frame of texts and labels in a ``DataLoader`` over ``MyDataset``.

  Args:
    df: object exposing ``stemmed_text`` and ``category_id`` attributes that
      each provide ``to_numpy()`` (presumably a pandas DataFrame - confirm
      against the caller).
    tokenizer: tokenizer handed to ``MyDataset``.
    max_len: token length every text is truncated / padded to.
    batch_size: number of samples per batch.
    num_workers: number of loader worker processes. Defaults to 4, the value
      that used to be hard-coded; pass 0 to load in the main process.

  Returns:
    A ``DataLoader`` yielding batches of the dicts produced by ``MyDataset``.
  """
  dataset = MyDataset(
    reviews=df.stemmed_text.to_numpy(),
    targets=df.category_id.to_numpy(),
    tokenizer=tokenizer,
    max_len=max_len,
  )

  return DataLoader(
    dataset,
    batch_size=batch_size,
    num_workers=num_workers,
  )

def department_name(category):
    """Map a predicted category name to the responsible company department.

    Returns ``None`` when the category is not one of the known names.
    """
    business_and_politics = 'Business and Public Affairs'
    departments = {
        'programming': 'Information Technology',
        'business': business_and_politics,
        'politics': business_and_politics,
        'health': 'Healthcare Services',
        'marketing': 'Marketing',
        'sports': 'Sports Division',
    }
    return departments.get(category)

Category classifier model loading

In [4]:
# Create the model with one output logit per entry of category_names
# (this equals the 6 that was previously hard-coded here).
model = CategoryClassifier(len(category_names))

# Load the saved state dictionary. map_location=device remaps the stored
# tensors onto whichever device was selected above, so a separate CPU / GPU
# branch is not needed.
# NOTE(review): torch.load unpickles the file - only load checkpoints you trust.
state_dict = torch.load('prod/best_model_state.bin', map_location=device)
model.load_state_dict(state_dict)

# move the model to the device
model = model.to(device)

# Switch to evaluation mode. Without this the Dropout(p=0.3) layer stays active
# during inference and the same text can receive different predictions.
model.eval()

tokenizer = BertTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Connect to the input text database and fetch the input data

In [5]:
import psycopg2

# Open the connection and make sure both the cursor and the connection are
# closed even when the query fails.
conn = psycopg2.connect(conn_string)
try:
    cur = conn.cursor()
    try:
        # execute a query
        cur.execute("SELECT text FROM calls_data")

        # fetch the query results
        rows = cur.fetchall()
    finally:
        cur.close()
finally:
    conn.close()

# Rebuild the list instead of appending to it, so running this cell again
# does not duplicate every input text.
text_list = [row[0] for row in rows]


Assign predicted category and responsible company department to each element of input text data

In [6]:
# Start from empty result lists so that re-running this cell never mixes in
# predictions left over from an earlier run.
category_list = []
department_list = []

# Inference only: evaluation mode disables dropout (keeping predictions
# deterministic) and no_grad skips autograd bookkeeping.
model.eval()
with torch.no_grad():
    for text in text_list:
        encoded_text = tokenizer.encode_plus(
          text,
          add_special_tokens=True,
          max_length=MAX_LEN,
          truncation=True,
          return_token_type_ids=False,
          padding='max_length',
          return_attention_mask=True,
          return_tensors='pt',
        )
        input_ids = encoded_text['input_ids'].to(device)
        attention_mask = encoded_text['attention_mask'].to(device)

        output = model(input_ids, attention_mask)
        _, prediction = torch.max(output, dim=1)
        # One text is scored per call, so the index tensor holds a single
        # element; convert it to a plain int before indexing the label list.
        category = category_names[prediction.item()]
        category_list.append(category)
        department_list.append(department_name(category))
        print(f"  Input text: {text}")
        print(f"  Predicted category: {category}")
        print("")
    

  Input text: As the global economy continues to recover from the impact of the COVID-19 pandemic, businesses are looking for new ways to grow and thrive in the post-pandemic world. With the rise of digital technologies and changing consumer preferences, companies must be agile and innovative to stay competitive. From leveraging data analytics to exploring new markets and partnerships, businesses must adapt quickly to meet the evolving needs of their customers. Those that can successfully navigate these challenges and seize new opportunities will be well-positioned for success in the years ahead.
  Predicted category: business

  Input text: Sports are an important part of many peoples lives and offer numerous benefits to those who participate in them. Engaging in sports can improve physical fitness, promote teamwork and social interaction, and provide a healthy outlet for stress relief.There are many different types of sports, each with its own unique set of rules and challenges. From

Push each input text, together with its assigned category and company department, to the output database

In [7]:
# The three lists are consumed in parallel; refuse to write misaligned rows
# (for example when the prediction cell was not re-run after reloading input).
if not (len(text_list) == len(category_list) == len(department_list)):
    raise ValueError(
        "text_list, category_list and department_list must have the same length"
    )

# SQL statement for a single row; the values are passed as query parameters.
query = "INSERT INTO categorized_data (text, category, department) VALUES (%s, %s, %s)"

# create a connection object
conn = psycopg2.connect(conn_string)
try:
    # create a cursor object
    cur = conn.cursor()
    try:
        print("Inserted values: ")
        for text, category, department in zip(text_list, category_list, department_list):
            # execute the SQL statement with the text, category and department values
            cur.execute(query, (text, category, department))
            print(f"    Text: {text}")
            print(f"    Predicted category: {category}")
            print(f"    Assigned department: {department}")
            print("")

        # commit the changes to the database
        conn.commit()
    finally:
        cur.close()
finally:
    # Always release the connection. If an insert failed before commit(),
    # closing the connection discards the uncommitted rows.
    conn.close()

Inserted values: 
    Text: As the global economy continues to recover from the impact of the COVID-19 pandemic, businesses are looking for new ways to grow and thrive in the post-pandemic world. With the rise of digital technologies and changing consumer preferences, companies must be agile and innovative to stay competitive. From leveraging data analytics to exploring new markets and partnerships, businesses must adapt quickly to meet the evolving needs of their customers. Those that can successfully navigate these challenges and seize new opportunities will be well-positioned for success in the years ahead.
    Predicted category: business
    Assigned department: Business and Public Affairs

    Text: Sports are an important part of many peoples lives and offer numerous benefits to those who participate in them. Engaging in sports can improve physical fitness, promote teamwork and social interaction, and provide a healthy outlet for stress relief.There are many different types of s