# Document classification with BERT

[BERT](https://en.wikipedia.org/wiki/BERT_(language_model))
(Bidirectional Encoder Representations from Transformers) is a Transformer-based language model pre-trained on large unlabeled text corpora.

It can be loaded from the [Hugging Face](https://huggingface.co/) `transformers`
library and fine-tuned for a downstream task — here, text classification.

In this notebook we fine-tune the BERT base (uncased) model to classify documents
from four categories of the 20 Newsgroups dataset.

In [7]:
# load dataset
from sklearn.datasets import fetch_20newsgroups

categories = ["alt.atheism", "soc.religion.christian", "comp.graphics", "sci.med"]

# Options shared by both splits; only ``subset`` differs.
fetch_kwargs = dict(
    categories=categories,
    shuffle=True,
    random_state=42,
    remove=("headers", "footers", "quotes"),
)
twenty_train = fetch_20newsgroups(subset="train", **fetch_kwargs)
twenty_test = fetch_20newsgroups(subset="test", **fetch_kwargs)

# fetch_20newsgroups re-orders the requested categories (it reports them in
# ``target_names``), and the integer targets index into ``target_names`` --
# NOT into the ``categories`` list above. Build the index -> name map from the
# dataset itself so predicted indices map to the right category names.
idx_to_categories = dict(enumerate(twenty_train.target_names))
assert twenty_train.target_names == twenty_test.target_names, "train/test label order differs"
assert sorted(idx_to_categories.values()) == sorted(categories), "unexpected category set"

In [16]:
import numpy as np
from typing import List
import torch
from torch.utils.data import DataLoader, Dataset

class TextClassificationDataset(Dataset):
    """Map-style dataset that tokenizes one document per ``__getitem__`` call.

    Args:
        corpus: raw document strings.
        targets: integer class index for each document, aligned with ``corpus``.
        tokenizer: callable tokenizer (e.g. a Hugging Face tokenizer) accepting
            ``return_tensors``, ``max_length``, ``padding`` and ``truncation``
            and returning ``input_ids`` / ``attention_mask`` tensors.
        max_length: every document is padded / truncated to this many tokens.

    Raises:
        ValueError: if ``corpus`` and ``targets`` have different lengths.
    """

    def __init__(self, corpus: List[str], targets: np.ndarray, tokenizer, max_length: int):
        if len(corpus) != len(targets):
            raise ValueError(
                f"corpus and targets must have the same length, got {len(corpus)} and {len(targets)}"
            )
        self.corpus = corpus
        self.targets = targets
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.corpus)

    def __getitem__(self, idx):
        """Return ``{'input_ids', 'attention_mask', 'label'}`` for document ``idx``."""
        encoding = self.tokenizer(
            self.corpus[idx],
            return_tensors='pt',
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
        )
        return {
            # The tokenizer returns a batch of one; flatten to a 1-D sequence so
            # the DataLoader can stack items into (batch, max_length).
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            # CrossEntropyLoss requires int64 class indices; force the dtype so
            # int32 (or other integer) target arrays also work.
            'label': torch.tensor(self.targets[idx], dtype=torch.long),
        }


In [3]:
from transformers import BertModel
class BERTClassifier(torch.nn.Module):
    """Pre-trained BERT encoder followed by dropout and a linear classification head.

    Args:
        bert_model_name: model name or path passed to ``BertModel.from_pretrained``.
        num_classes: number of output logits produced by the head.
        dropout: dropout probability applied to BERT's pooled output before the
            head (default 0.1, the value previously hard-coded).
    """

    def __init__(self, bert_model_name, num_classes, dropout=0.1):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.dropout = torch.nn.Dropout(dropout)
        self.fc = torch.nn.Linear(self.bert.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return unnormalized class logits (last dimension is ``num_classes``).

        Uses BERT's ``pooler_output`` as the document representation;
        presumably (batch, hidden_size) -- see the transformers BertModel docs.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        return self.fc(self.dropout(pooled_output))

In [27]:
# train function
from tqdm.notebook import tqdm
def train(model, data_loader, optimizer, scheduler, device):
    """Run one training epoch over ``data_loader``.

    Performs one optimizer step and one scheduler step per batch, so the
    scheduler's total step count should equal ``len(data_loader) * num_epochs``.

    Args:
        model: classifier called as ``model(input_ids=..., attention_mask=...)``
            and returning logits.
        data_loader: yields dicts with ``input_ids``, ``attention_mask``, ``label``.
        optimizer: torch optimizer over ``model``'s parameters.
        scheduler: learning-rate scheduler stepped once per batch.
        device: device the batch tensors are moved to.

    Returns:
        float: mean cross-entropy loss over the epoch's batches
        (``nan`` if the loader yielded no batches).
    """
    model.train()
    # Build the loss module once instead of once per batch.
    criterion = torch.nn.CrossEntropyLoss()
    total_loss = 0.0
    num_batches = 0
    for batch in tqdm(data_loader):
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)
        logits = model(input_ids=input_ids, attention_mask=attention_mask)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()
        total_loss += loss.item()
        num_batches += 1
    return total_loss / num_batches if num_batches else float('nan')

In [29]:
from sklearn.metrics import accuracy_score, classification_report
# evaluate function
def evaluate(model, data_loader, device, target_names=None):
    """Evaluate ``model`` on ``data_loader`` without gradient tracking.

    Args:
        model: classifier called as ``model(input_ids=..., attention_mask=...)``
            and returning logits.
        data_loader: yields dicts with ``input_ids``, ``attention_mask``, ``label``.
        device: device the model inputs are moved to.
        target_names: optional class names, indexed by class id, used to label
            the rows of the classification report.

    Returns:
        tuple: ``(accuracy, report)``, where ``report`` is the text from
        ``sklearn.metrics.classification_report``.
    """
    model.eval()
    predictions = []
    actual_labels = []
    with torch.no_grad():
        for batch in tqdm(data_loader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            logits = model(input_ids=input_ids, attention_mask=attention_mask)
            predictions.extend(logits.argmax(dim=1).cpu().tolist())
            # Labels are only compared on the host; no need to ship them to the device.
            actual_labels.extend(batch['label'].tolist())
    # With explicit names, also pin the label set so classes absent from this
    # split don't cause a names/classes size mismatch in the report.
    report_labels = list(range(len(target_names))) if target_names is not None else None
    report = classification_report(
        actual_labels, predictions, labels=report_labels, target_names=target_names
    )
    return accuracy_score(actual_labels, predictions), report

In [8]:

def predict_sentiment(doc, model, tokenizer, device, max_length=128, idx_to_label=None):
    """Classify a single document and return its predicted category name.

    Despite the name (kept for backward compatibility), this predicts the
    document's newsgroup category, not a sentiment.

    Args:
        doc: raw document text.
        model: classifier called as ``model(input_ids=..., attention_mask=...)``
            and returning logits.
        tokenizer: tokenizer used during training.
        device: device the encoded inputs are moved to.
        max_length: pad / truncate length; should match the training setting.
        idx_to_label: mapping (or sequence) from class index to label name.
            Defaults to the notebook-level ``idx_to_categories``.

    Returns:
        The label name for the highest-scoring class.
    """
    if idx_to_label is None:
        idx_to_label = idx_to_categories
    model.eval()
    encoding = tokenizer(doc, return_tensors='pt', max_length=max_length, padding='max_length', truncation=True)
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask)
    predicted_idx = int(logits.argmax(dim=1).item())
    return idx_to_label[predicted_idx]

In [9]:
# model params
bert_model_name = 'bert-base-uncased'
num_classes = len(categories)
max_length = 128      # tokens per document after truncation / padding
batch_size = 16
num_epochs = 1
learning_rate = 2e-5
seed = 42

# Seed torch before the classifier head is initialised and before the training
# DataLoader shuffles, so re-running the notebook repeats the same run
# (bit-exact GPU determinism may additionally need deterministic cuDNN settings).
torch.manual_seed(seed)

In [10]:
# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = BERTClassifier(bert_model_name, num_classes)
model = model.to(device)

Downloading (…)lve/main/config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

In [23]:
# dataloaders
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained(bert_model_name)
# Both splits share the same tokenizer and max_length.
# NOTE: the "validation" data here is the 20 Newsgroups *test* split.
train_dataset, val_dataset = (
    TextClassificationDataset(split.data, split.target, tokenizer, max_length)
    for split in (twenty_train, twenty_test)
)
# Only the training loader shuffles; evaluation order stays fixed.
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size)

In [24]:
# optimizer and lr scheduler
from transformers import get_linear_schedule_with_warmup
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
# train() steps the scheduler once per batch, so the schedule spans
# every batch of every epoch.
total_steps = num_epochs * len(train_dataloader)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps,
)

In [28]:
# main train loop: one training pass, then a full evaluation, per epoch
for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")
    train(model, train_dataloader, optimizer, scheduler, device)

    accuracy, report = evaluate(model, val_dataloader, device)
    print(f"Validation Accuracy: {accuracy:.4f}")
    print(report)

Epoch 1/1


  0%|          | 0/3 [00:00<?, ?it/s]

Validation Accuracy: 0.2250
              precision    recall  f1-score   support

           0       0.00      0.00      0.00        10
           1       0.25      0.33      0.29         9
           2       0.00      0.00      0.00        13
           3       0.23      0.75      0.35         8

    accuracy                           0.23        40
   macro avg       0.12      0.27      0.16        40
weighted avg       0.10      0.23      0.13        40

