# Fine-Tuning ClinicalBERT on ADE Text Classification

This notebook fine-tunes a ClinicalBERT model to classify adverse drug events (ADEs) in medical text. It is trained on the ADE-Corpus-V2 dataset (Adverse Drug Reaction Data), a binary classification dataset in which each sentence is labeled as ADE-related (True=1) or not (False=0).

# Install Prerequisites

In [None]:
# Mount onto drive
# Makes Google Drive available under /content/drive (Colab-only import).
from google.colab import drive

drive.mount("/content/drive")

# IPython magic: switch the working directory to the project folder on Drive,
# so relative paths used afterwards are resolved against that folder.
%cd '/content/drive/MyDrive/GaTech/bert/'

In [None]:
# Install required packages if needed
# `requirements.txt` is looked up relative to the current working directory.
!pip install -r requirements.txt
# -U upgrades torch to the newest available release, whatever was installed above.
!pip install torch -U

In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    AutoModel,
    AdamW,
    get_linear_schedule_with_warmup
)
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt
import warnings

# Suppress every warning for the rest of the session
warnings.filterwarnings(action='ignore')

# Prefer the GPU when CUDA is usable, otherwise run on the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("Using device: {}".format(device))

# Clear cuda cache
torch.cuda.empty_cache()

# Create Tokenizer Class
Tokenization splits each string in the dataset into units called "tokens" and converts them into the numeric inputs the model consumes. The dataset class below tokenizes a sentence on demand, each time an item is requested.

In [None]:
class TokenizeDataset(Dataset):
    """Dataset that tokenizes one text per lookup and pairs it with its label.

    Args:
        texts: Indexable collection of raw texts (each item is passed through
            ``str`` before tokenization).
        labels: Indexable collection of labels, aligned with ``texts``.
        tokenizer: Callable tokenizer invoked with Hugging Face style keyword
            arguments and expected to return ``'input_ids'`` and
            ``'attention_mask'`` tensors.
        max_length: Length every sequence is padded/truncated to.
    """

    def __init__(self, texts, labels, tokenizer, max_length=128):
        self.texts, self.labels = texts, labels
        self.tokenizer, self.max_length = tokenizer, max_length

    def __len__(self):
        """Return the number of texts."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the flattened token tensors and the label tensor for ``idx``."""
        raw_text = str(self.texts[idx])
        target = self.labels[idx]

        # If tweaking of the tokenizer is required:
        # https://huggingface.co/docs/transformers/en/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__.padding
        tokenizer_options = dict(
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        encoded = self.tokenizer(raw_text, **tokenizer_options)

        # The tokenizer output carries a leading batch axis; flatten it away
        # so the DataLoader can stack samples itself.
        sample = {
            field: encoded[field].flatten()
            for field in ('input_ids', 'attention_mask')
        }
        sample['labels'] = torch.tensor(target, dtype=torch.long)
        return sample

# Define Training and Evaluation Functions

In [None]:
def plot_training_metrics(train_losses, val_losses, train_perplexities, val_perplexities, train_accuracies, val_accuracies, epochs):
    """Plot training and validation metrics.

    Draws three figures (loss, perplexity, accuracy), each with the training
    curve in blue and the validation curve in red, saves each one as a PNG in
    the current working directory, and shows it.

    Args:
        train_losses, val_losses: Per-epoch loss values.
        train_perplexities, val_perplexities: Per-epoch perplexity values.
        train_accuracies, val_accuracies: Per-epoch accuracy values.
        epochs: Number of epochs; the x-axis runs from 1 to ``epochs``, so
            every series must contain exactly ``epochs`` values.
    """
    epochs_range = range(1, epochs + 1)

    # (training series, validation series, curve/axis label, title suffix, file)
    panels = (
        (train_losses, val_losses, 'Loss', 'Loss', 'ClinicalBERTLoss.png'),
        (train_perplexities, val_perplexities, 'Perplexity', 'Perplexity',
         'ClinicalBERTPerplexity.png'),
        (train_accuracies, val_accuracies, 'Accuracy', 'Accuracies',
         'ClinicalBERTAccuracy.png'),
    )

    for train_values, val_values, metric, title_suffix, filename in panels:
        plt.figure(figsize=(10, 6))
        plt.plot(epochs_range, train_values, 'b-', label=f'Training {metric}')
        plt.plot(epochs_range, val_values, 'r-', label=f'Validation {metric}')
        plt.title(f'ClinicalBERT Training and Validation {title_suffix}')
        plt.xlabel('Epochs')
        plt.ylabel(metric)
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        # Save before show(): some backends clear the figure once it is shown.
        plt.savefig(filename)
        plt.show()
        plt.close()

In [None]:
def train_epoch(model, data_loader, optimizer, scheduler, device):
    """Run one training pass over ``data_loader`` and report epoch metrics.

    Args:
        model: Model called with ``input_ids``, ``attention_mask`` and
            ``labels``; its output must expose ``.loss`` and ``.logits``.
        data_loader: Sized iterable of batches, each a mapping with the keys
            ``'input_ids'``, ``'attention_mask'`` and ``'labels'``.
        optimizer: Optimizer stepped once per batch.
        scheduler: Learning-rate scheduler stepped once per batch.
        device: Device the batch tensors are moved to.

    Returns:
        Tuple ``(avg_loss, perplexity, accuracy)`` where ``avg_loss`` is the
        mean of the per-batch losses, ``perplexity`` is ``exp(avg_loss)`` and
        ``accuracy`` is the fraction of samples whose arg-max logit matched
        the label (measured while the weights are still being updated).

    Raises:
        ZeroDivisionError: If ``data_loader`` yields no batches.
    """
    model.train()
    total_loss = 0
    correct = 0
    seen = 0

    for batch in data_loader:
        optimizer.zero_grad()

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )

        loss = outputs.loss
        total_loss += loss.item()

        # Accuracy only needs a running hit count; building a full per-class
        # report every epoch just to read one number is wasted work.
        preds = outputs.logits.argmax(dim=1)
        correct += (preds == labels).sum().item()
        seen += labels.size(0)

        loss.backward()
        # Clip the global gradient norm to keep individual updates bounded.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()

    avg_loss = total_loss / len(data_loader)
    perplexity = torch.exp(torch.tensor(avg_loss)).item()
    accuracy = correct / seen

    return avg_loss, perplexity, accuracy

In [None]:
def evaluate_model(model, data_loader, device):
    """Score ``model`` on ``data_loader`` without computing gradients.

    Args:
        model: Model called with ``input_ids``, ``attention_mask`` and
            ``labels``; its output must expose ``.loss`` and ``.logits``.
        data_loader: Sized iterable of batches, each a mapping with the keys
            ``'input_ids'``, ``'attention_mask'`` and ``'labels'``.
        device: Device the batch tensors are moved to.

    Returns:
        Tuple ``(report, accuracy, avg_loss, perplexity)``: the
        ``classification_report`` dictionary, its ``'accuracy'`` entry, the
        mean of the per-batch losses, and ``exp`` of that mean.
    """
    model.eval()

    running_loss = 0
    y_pred, y_true = [], []
    batch_fields = ('input_ids', 'attention_mask', 'labels')

    with torch.no_grad():
        for batch in data_loader:
            ids, mask, targets = (batch[name].to(device) for name in batch_fields)
            outputs = model(input_ids=ids, attention_mask=mask, labels=targets)

            running_loss += outputs.loss.item()

            # Predicted class = index of the largest logit in each row.
            winners = torch.max(outputs.logits, dim=1).indices
            y_pred.extend(winners.cpu().tolist())
            y_true.extend(targets.cpu().tolist())

    mean_loss = running_loss / len(data_loader)
    perplexity = torch.exp(torch.tensor(mean_loss)).item()

    # Per-class metrics as a dictionary; overall accuracy is one of its keys.
    report = classification_report(y_true, y_pred, output_dict=True)

    return report, report['accuracy'], mean_loss, perplexity

# Main Training Function

In [None]:
def train_bert(model, tokenizer, text_dict, hyperparam_dict):
    """Fine-tune ``model`` on the training split and evaluate after each epoch.

    Args:
        model: Sequence-classification model to train in place. It is expected
            to already live on the module-level ``device``.
        tokenizer: Tokenizer handed to ``TokenizeDataset``.
        text_dict: Mapping with the keys ``'train_texts'``, ``'val_texts'``,
            ``'train_labels'`` and ``'val_labels'``.
        hyperparam_dict: Mapping with the required keys ``'num_epochs'``,
            ``'batch_size'``, ``'learning_rate'``, ``'warmup_steps'`` and
            ``'max_length'``, plus the optional key ``'weight_decay'``
            (defaults to ``0.0`` when absent).

    Returns:
        Tuple ``(model, tokenizer)``: the trained model and the tokenizer that
        was passed in.

    Side effects:
        Prints per-epoch metrics and, through ``plot_training_metrics``, shows
        and saves the loss, perplexity and accuracy plots.
    """

    # Load in hyperparams
    num_epochs = hyperparam_dict['num_epochs']
    batch_size = hyperparam_dict['batch_size']
    learning_rate = hyperparam_dict['learning_rate']
    warmup_steps = hyperparam_dict['warmup_steps']
    max_length = hyperparam_dict['max_length']
    # Optional so that hyperparameter dicts without this key keep working.
    weight_decay = hyperparam_dict.get('weight_decay', 0.0)

    # Load in text
    train_texts = text_dict['train_texts']
    val_texts = text_dict['val_texts']
    train_labels = text_dict['train_labels']
    val_labels = text_dict['val_labels']

    # Wrap both splits; tokenization happens lazily, one item at a time
    train_set = TokenizeDataset(train_texts, train_labels, tokenizer, max_length)
    val_set = TokenizeDataset(val_texts, val_labels, tokenizer, max_length)

    # Utilize PyTorch's DataLoader to build minibatches. Only the training
    # data is reshuffled at every epoch; validation order does not matter.
    # `num_workers` is left at its default, so no worker processes are used.
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size)

    # Initialize optimizer and scheduler. The scheduler is stepped once per
    # batch, so its horizon is the total number of batches over all epochs.
    optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    total_steps = len(train_loader) * num_epochs
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps,
                                                num_training_steps=total_steps)

    train_losses = []
    train_perplexities = []
    train_accuracies = []
    val_losses = []
    val_perplexities = []
    val_accuracies = []

    # Training loop
    for epoch in range(num_epochs):
        print(f'Epoch {epoch + 1}/{num_epochs}')
        train_loss, train_perplexity, train_acc = train_epoch(model, train_loader, optimizer, scheduler, device)

        train_losses.append(train_loss)
        train_perplexities.append(train_perplexity)
        train_accuracies.append(train_acc)

        print(f'Average train loss: {train_loss:.4f}')
        print(f'Train perplexity: {train_perplexity:.4f}')

        # Evaluation: must go through the DataLoader so the model receives
        # batched 2-D tensors rather than single unbatched samples.
        print('\nValidation Results:')
        val_report, val_acc, val_loss, val_perplexity = evaluate_model(model, val_loader, device)
        print(val_report)
        print('-' * 60)
        print('Avg Acc:', val_acc)

        val_losses.append(val_loss)
        val_perplexities.append(val_perplexity)
        val_accuracies.append(val_acc)

    # Plot
    plot_training_metrics(train_losses, val_losses,
                          train_perplexities, val_perplexities,
                          train_accuracies, val_accuracies, num_epochs)

    return model, tokenizer

# Load Data

In [None]:
# Read the ADE classification split (a single parquet file) from the Hugging Face hub
ade_cl_df = pd.read_parquet(
    "hf://datasets/ade-benchmark-corpus/ade_corpus_v2/"
    "Ade_corpus_v2_classification/train-00000-of-00001.parquet"
)

# Quick look at the data: size, first rows, and how the labels are distributed
print("Dataset Shape:", ade_cl_df.shape)

print("\nSample of the data:")
display(ade_cl_df.head())

print("\nClass distribution:")
label_counts = ade_cl_df['label'].value_counts()
display(label_counts)

# Train the Model

In [None]:
# Define which pretrained checkpoint we are fine-tuning.
model_path = "medicalai/ClinicalBERT"
tokenizer_path = "medicalai/ClinicalBERT"

# Assumes we are using LLMs that have already been pretrained.
# Each object is loaded from its own path variable, so the tokenizer and the
# model checkpoint can be changed independently of each other.
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path, num_labels=2).to(device)

In [None]:
# Hold out 20% of the sentences for validation; the fixed seed keeps the split reproducible
train_texts, val_texts, train_labels, val_labels = train_test_split(
    ade_cl_df['text'].values,
    ade_cl_df['label'].values,
    test_size=0.2,
    random_state=42,
)

# Bundle the four arrays under the keys the training function looks up
text_dict = dict(
    train_texts=train_texts,
    val_texts=val_texts,
    train_labels=train_labels,
    val_labels=val_labels,
)

# Train

In [None]:
# Training configuration, collected in one place
hyperparam_dict = dict(
    max_length=128,       # Maximum sequence length
    num_labels=2,         # Binary classification (ADE vs non-ADE)
    num_epochs=10,        # Number of training epochs
    batch_size=64,        # Batch size for training
    learning_rate=5e-6,   # Learning rate for optimizer
    warmup_steps=100,     # Number of warmup steps for scheduler
    weight_decay=0.01,    # Weight decay coefficient
)

In [None]:
# Run the fine-tuning; the function hands back the model and tokenizer it was given
trained_model, trained_tokenizer = train_bert(
    model=model,
    tokenizer=tokenizer,
    text_dict=text_dict,
    hyperparam_dict=hyperparam_dict,
)

# Write the model first, then the tokenizer, into the same output directory.
# NOTE(review): the directory is named "tiny_bert_..." although the checkpoint
# loaded in this notebook is ClinicalBERT; confirm the name is intentional.
for artifact in (trained_model, trained_tokenizer):
    artifact.save_pretrained("tiny_bert_ade_classifier")
print("Model and tokenizer saved successfully!")

# Trained Model Demonstration

In [None]:
def predict_ade(text, model, tokenizer, max_length=128):
    """Classify a single text as ``"ADE"`` or ``"Not ADE"``.

    Args:
        text: The sentence to classify (one text per call).
        model: Classification model whose output exposes ``.logits`` with one
            row per input and class index 1 meaning "ADE".
        tokenizer: Tokenizer used to encode ``text``.
        max_length: Length the encoded text is padded/truncated to. Should
            match the value used during training (128 by default).

    Returns:
        ``"ADE"`` when class 1 has the largest logit, otherwise ``"Not ADE"``.

    Notes:
        The model is switched to evaluation mode for the forward pass so that
        dropout cannot make the prediction random. If it was in training mode
        beforehand, it is put back into training mode afterwards.
    """
    # Prepare the text
    encoding = tokenizer(
        text,
        add_special_tokens=True,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )

    # Run on whatever device the model's weights live on instead of relying
    # on a module-level variable that may not match the model.
    try:
        target_device = next(model.parameters()).device
    except StopIteration:
        # Parameter-free model: nothing to match, stay on the CPU.
        target_device = torch.device('cpu')

    input_ids = encoding['input_ids'].to(target_device)
    attention_mask = encoding['attention_mask'].to(target_device)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    finally:
        # Leave the caller's model in the mode we found it in.
        model.train(was_training)

    predicted_class = outputs.logits.argmax(dim=1).item()
    return "ADE" if predicted_class == 1 else "Not ADE"

In [None]:
# Example usage
example_text = "The patient experienced severe headache after taking aspirin."
prediction = predict_ade(example_text, trained_model, trained_tokenizer)
print(f"Text: {example_text}")
print(f"Prediction: {prediction}")