In [None]:
import re
import torch
import numpy as np
import pandas as pd
import nltk
from nltk.stem import WordNetLemmatizer
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix, classification_report
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
import torch.nn as nn
import torch.nn.functional as F
import nlpaug.augmenter.word as naw
import seaborn as sns
import matplotlib.pyplot as plt
import plotly.figure_factory as ff
import optuna 

In [None]:
# ---------------------------------------------
# Setup environment checks
# ---------------------------------------------
# Report which accelerators PyTorch can see, then choose the device used by the
# rest of the notebook: CUDA when available, otherwise the CPU.
cuda_ok = torch.cuda.is_available()
print("Is CUDA available? ", cuda_ok)
print("Number of available GPUs:", torch.cuda.device_count())
if not cuda_ok:
    print("CUDA is not available. Check your driver/environment setup.")
else:
    print("Current GPU:", torch.cuda.current_device())
    print("Device name:", torch.cuda.get_device_name(0))
device = torch.device("cuda" if cuda_ok else "cpu")
print("Using device:", device)

In [None]:
# ---------------------------------------------
# Download NLTK data
# ---------------------------------------------
# WordNet is used by the WordNetLemmatizer in clean_text and by nlpaug's
# SynonymAug(aug_src='wordnet'). omw-1.4 is fetched alongside it; some NLTK
# versions need it for WordNet lookups.
for _nltk_resource in ('wordnet', 'omw-1.4'):
    nltk.download(_nltk_resource)

In [None]:
# ---------------------------------------------
# Text Cleaning and Lemmatization
# ---------------------------------------------
lemmatizer = WordNetLemmatizer()

def clean_text(text):
    """Normalize one free-text value for modelling.

    Steps, in order:
      1. Coerce non-string input to ``str``. Spreadsheet cells can hold
         numbers or dates. Missing values (None / NaN / pd.NA) become "".
      2. Lowercase.
      3. Drop every character except a-z, 0-9, ``. , ! ? '``, whitespace and
         hyphen.
      4. Collapse whitespace runs to one space and trim the ends.
      5. Collapse runs of ``!`` and ``?`` to a single character.
      6. Lemmatize each whitespace-separated token with the module-level
         WordNet lemmatizer (its default part of speech).

    NOTE(review): punctuation stays attached to tokens (e.g. "injuries,"), so
    such tokens are usually not found by the lemmatizer and pass through
    unchanged.

    Args:
        text: the value to clean. Any type is accepted.

    Returns:
        The cleaned, space-separated string ("" for missing/empty input).
    """
    if not isinstance(text, str):
        # Without this, numeric/date cells raise AttributeError on .lower().
        text = "" if pd.api.types.is_scalar(text) and pd.isna(text) else str(text)
    text = text.lower()
    # Remove non-informative chars
    text = re.sub(r"[^a-z0-9.,!?'\s-]", '', text)
    # Replace multiple spaces with a single space
    text = re.sub(r"\s+", " ", text).strip()
    # Normalize excessive punctuation
    text = re.sub(r"!+", "!", text)
    text = re.sub(r"\?+", "?", text)
    return " ".join(lemmatizer.lemmatize(token) for token in text.split())

In [None]:
# ---------------------------------------------
# Data Loading and Preprocessing
# ---------------------------------------------
# Source spreadsheet on the network share. Kept in one constant so it can be
# repointed without touching the loading code.
DATA_PATH = '\\\\vi240c060002.woc.prod\\e$\\datasets\\Fields\\2ND Source_Of_Incident\\WCMLDataset12_23.xlsx'
data = pd.read_excel(DATA_PATH)
example_data = data.copy()

text_fields = [
    'Incident Description', 
    'Activity Engaged in During Accident', 
    'General HS Comments', 
    'Injury Description'
]
# Blank cells become "", and any non-string cell (number/date read by Excel)
# is converted to str so clean_text always receives text.
example_data[text_fields] = example_data[text_fields].fillna('').astype(str)
for field in text_fields:
    example_data[field] = example_data[field].apply(clean_text)

# Join the cleaned fields with single spaces, skipping empty ones. A plain
# "a + ' ' + b + ..." concatenation leaves doubled spaces wherever a middle
# field is blank.
example_data['Combined_Text'] = [
    ' '.join(part for part in parts if part)
    for parts in example_data[text_fields].itertuples(index=False, name=None)
]

targets = [
    'Event of Injury Desc', 
    'Source of Injury Desc', 
    'Event of Incident Desc', 
    'Source of Incident Desc',
    'EDI Cause Desc'
]

# One LabelEncoder per target column, fitted on the full dataset. The integer
# codes are stored in "<target>_Encoded" and the encoder is kept so codes can
# be mapped back to class names.
# NOTE(review): missing target values are passed to LabelEncoder as-is; confirm
# these columns have no NaN, or decide explicitly how to treat them.
label_encoders = {}
for target in targets:
    le = LabelEncoder()
    example_data[target + '_Encoded'] = le.fit_transform(example_data[target])
    label_encoders[target] = le

focus_target = 'Source of Incident Desc'
focus_target_encoded = focus_target + '_Encoded'

In [None]:
# ---------------------------------------------
# Rare Class Identification and Augmentation
# ---------------------------------------------
# Classes of the focus target with fewer than `rare_threshold` rows get one
# synonym-augmented copy of each of their rows appended to example_data.
rare_threshold = 250
class_counts = example_data[focus_target].value_counts()
rare_classes_list = class_counts[class_counts < rare_threshold].index.tolist()

# NOTE(review): augmentation runs BEFORE the train/test split, so an augmented
# near-duplicate can land in train while its source row lands in test (or the
# reverse). That leaks information across the split and inflates evaluation
# metrics. TODO: split first, then augment only the training portion.

if rare_classes_list:
    syn_aug = naw.SynonymAug(aug_src='wordnet', aug_min=1, aug_max=3, aug_p=0.1)

    def augment_text(text, augmenter=syn_aug):
        """Return a single augmented variant of ``text`` as a plain string.

        Recent nlpaug releases return a list from ``augment`` even for one
        input string. Storing that list directly would later be stringified
        as "['...']", so the first element is unwrapped. If the augmenter
        returns nothing, the original text is kept.
        """
        augmented = augmenter.augment(text)
        if isinstance(augmented, (list, tuple)):
            return augmented[0] if augmented else text
        return augmented

    rare_class_filter = example_data[focus_target].isin(rare_classes_list)
    rare_class_data = example_data[rare_class_filter]
    # Same columns as the source rows; only Combined_Text is replaced.
    augmented_df = rare_class_data.assign(
        Combined_Text=rare_class_data['Combined_Text'].map(augment_text)
    )
    example_data = pd.concat([example_data, augmented_df], ignore_index=True)

In [None]:
# ---------------------------------------------
# Tokenization and Data Split
# ---------------------------------------------
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Stratify on the encoded focus label so every class keeps its proportion in
# both splits. A purely random split can leave a rare class with no training
# rows, which breaks inverse-frequency class weights.
# train_test_split rejects stratification when some class has fewer than 2 rows
# or when either split is smaller than the number of classes; in those cases
# fall back to the plain random split.
split_labels = example_data[focus_target_encoded]
test_fraction = 0.2
n_rows = len(split_labels)
n_classes = split_labels.nunique()
n_test_rows = int(np.ceil(test_fraction * n_rows))
can_stratify = (
    n_rows > 0
    and split_labels.value_counts().min() >= 2
    and n_classes <= n_test_rows <= n_rows - n_classes
)

X_train, X_test, y_train, y_test = train_test_split(
    example_data['Combined_Text'],
    split_labels,
    test_size=test_fraction,
    random_state=42,
    stratify=split_labels if can_stratify else None,
)

X_train = X_train.astype(str).tolist()
X_test = X_test.astype(str).tolist()

# Truncate to BERT's 512-token limit; padding=True pads each call's sequences
# to the longest one in that call.
train_encodings = tokenizer(X_train, truncation=True, padding=True, max_length=512)
test_encodings = tokenizer(X_test, truncation=True, padding=True, max_length=512)

class Dataset(torch.utils.data.Dataset):
    """Map-style dataset pairing tokenizer output with integer labels.

    Args:
        encodings: mapping of feature name (e.g. input_ids) to a per-example
            sequence. Element ``idx`` of every entry belongs to example ``idx``.
        labels: indexable sequence of class ids, one per example.

    Each item is a dict of tensors: one per encoding key plus ``'labels'``.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        sample = {}
        for feature_name, per_example_values in self.encodings.items():
            sample[feature_name] = torch.tensor(per_example_values[idx])
        sample['labels'] = torch.tensor(self.labels[idx])
        return sample

# Wrap each split's encodings and labels for the Trainer.
train_dataset, test_dataset = (
    Dataset(train_encodings, list(y_train)),
    Dataset(test_encodings, list(y_test)),
)

# Number of classes known to the fitted focus-target encoder (fitted on the
# full dataset), which sizes the classifier head.
num_labels = label_encoders[focus_target].classes_.size

In [None]:
# ---------------------------------------------
# Compute Class Weights and Implement Focal Loss
# ---------------------------------------------
# Inverse-frequency ("balanced") class weights: n_samples / (n_classes * count_c).
# minlength=num_labels guarantees one entry per encoded class. Without it,
# bincount stops at the largest label present in y_train, so the weight vector
# is too short and alpha[targets] indexes out of range for higher label ids.
class_counts_train = np.bincount(y_train, minlength=num_labels)
total_samples = len(y_train)
# A class absent from the training split has count 0. Using a count of 1 keeps
# its weight finite instead of inf, which would turn the loss into inf/NaN.
class_weights = total_samples / (num_labels * np.maximum(class_counts_train, 1).astype(float))
class_weights = torch.tensor(class_weights, dtype=torch.float)

class FocalLoss(nn.Module):
    """Multi-class focal loss with optional per-class weighting.

    For a sample whose true class is ``y``, with softmax probability ``p_t``
    for that class, the loss is ``-alpha[y] * (1 - p_t) ** gamma * log(p_t)``.
    The ``alpha[y]`` factor is omitted when ``alpha`` is None. With
    ``gamma=0`` and no ``alpha`` this is ordinary cross-entropy.

    Args:
        alpha: optional per-class weights indexed by label id. May be a tensor
            or anything convertible to a float tensor (list, numpy array).
        reduction: ``'mean'`` or ``'sum'`` to reduce over the batch; any other
            value returns the per-sample losses.
    """

    # Lower bound for (1 - p_t) before it is raised to ``gamma``. When the
    # model is very confident, p_t rounds to exactly 1.0. For gamma < 1 the
    # derivative of x ** gamma is infinite at x = 0, and backward produces NaN
    # (inf * 0). Clamping keeps gradients finite and changes the loss value
    # only negligibly.
    _MIN_ONE_MINUS_PT = 1e-8

    def __init__(self, alpha=None, reduction='mean'):
        super().__init__()
        if alpha is not None and not isinstance(alpha, torch.Tensor):
            alpha = torch.as_tensor(alpha, dtype=torch.float)
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, inputs, targets, gamma=2.0):
        """Compute the focal loss for a batch.

        Args:
            inputs: raw logits of shape (N, C).
            targets: integer class ids of shape (N,).
            gamma: focusing parameter. Larger values down-weight examples the
                model already classifies confidently.

        Returns:
            A scalar tensor for 'mean'/'sum' reduction, else a tensor of shape (N,).
        """
        if isinstance(self.alpha, torch.Tensor):
            # Move the weights next to the logits (cached after the first move).
            self.alpha = self.alpha.to(inputs.device)

        # log p_t: log-probability of each sample's true class.
        log_pt = F.log_softmax(inputs, dim=-1).gather(dim=-1, index=targets.unsqueeze(1)).squeeze(1)
        pt = log_pt.exp()

        focal_weight = (1 - pt).clamp(min=self._MIN_ONE_MINUS_PT) ** gamma
        if self.alpha is not None:
            focal_weight = focal_weight * self.alpha[targets]

        loss = -focal_weight * log_pt
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss

    
class FocalTrainer(Trainer):
    """Hugging Face Trainer that optimizes FocalLoss instead of the model's own loss.

    Args:
        alpha: optional per-class weights passed to FocalLoss.
        *args, **kwargs: forwarded unchanged to ``Trainer``.
        gamma: keyword-only default focusing parameter. Used when neither the
            active hyperparameter-search trial nor ``self.args`` provides one.
    """

    def __init__(self, alpha=None, *args, gamma=2.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_fn = FocalLoss(alpha=alpha)
        self.focal_gamma = gamma

    def _resolve_gamma(self):
        """Return the gamma for the current run.

        ``Trainer.hyperparameter_search`` only copies sampled values onto
        attributes that ``TrainingArguments`` already defines. ``gamma`` is not
        one of them, so it is skipped (with a warning) and never appears on
        ``self.args``; reading it only from there silently ignored the search.
        The lookup order is:
          1. the active trial (Optuna trials expose ``.params``; Ray passes a
             plain dict);
          2. ``self.args``;
          3. the constructor default.
        """
        trial = getattr(self, "_trial", None)
        params = trial if isinstance(trial, dict) else getattr(trial, "params", None)
        if isinstance(params, dict) and "gamma" in params:
            return float(params["gamma"])
        return float(getattr(self.args, "gamma", self.focal_gamma))

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the focal loss for one batch.

        ``labels`` are removed from ``inputs`` so the model does not compute its
        own loss. ``**kwargs`` absorbs extra arguments that newer Trainer
        versions pass (e.g. ``num_items_in_batch``); they are unused here.
        """
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        loss = self.loss_fn(outputs.logits, labels, gamma=self._resolve_gamma())
        return (loss, outputs) if return_outputs else loss

In [None]:
# ---------------------------------------------
# Compute Metrics Function
# ---------------------------------------------
def compute_metrics(eval_pred):
    """Turn Trainer evaluation output into classification metrics.

    Args:
        eval_pred: a ``(logits, labels)`` pair; logits are (N, C) class scores.

    Returns:
        dict with accuracy plus support-weighted precision, recall and F1.
        Classes with no predictions score 0 instead of warning.

    NOTE(review): weighted averages are dominated by frequent classes. Given the
    rare-class focus of this notebook, macro-averaged F1 may be the better
    search objective.
    """
    logits, labels = eval_pred
    predicted = np.argmax(logits, axis=1)
    scores = precision_recall_fscore_support(labels, predicted, average='weighted', zero_division=0)
    precision, recall, f1 = scores[0], scores[1], scores[2]
    return {
        'accuracy': accuracy_score(labels, predicted),
        'f1': f1,
        'precision': precision,
        'recall': recall,
    }

In [None]:
def my_objective(metrics):
    """Objective value for the hyperparameter search: the ``eval_f1`` metric."""
    objective_value = metrics["eval_f1"]
    return objective_value

In [None]:
# ---------------------------------------------
# model_init for hyperparameter search
# ---------------------------------------------
def model_init(hp):
    """Build a fresh pretrained BERT classifier with ``num_labels`` outputs on ``device``.

    Passed to Trainer as ``model_init`` so each hyperparameter-search run can
    start from the pretrained checkpoint. ``hp`` is accepted for Trainer's
    calling convention and is not used here.
    """
    return BertForSequenceClassification.from_pretrained(
        'bert-base-uncased', num_labels=num_labels
    ).to(device)

In [None]:
# ---------------------------------------------
# Optuna hyperparameter search
# ---------------------------------------------
def hp_space_optuna(trial):
    """Sample one point of the search space from an Optuna ``trial``.

    Returns a dict with learning_rate (log-uniform 1e-5..5e-5), gamma
    (0.5..3.5), weight_decay (0.001..0.2) and warmup_steps (integer 300..1500).

    NOTE(review): 'gamma' is not a TrainingArguments field, so Trainer does not
    copy it onto ``args``; whatever consumes it must read it from the trial.
    """
    learning_rate = trial.suggest_float("learning_rate", 1e-5, 5e-5, log=True)
    gamma = trial.suggest_float("gamma", 0.5, 3.5)
    weight_decay = trial.suggest_float("weight_decay", 0.001, 0.2)
    warmup_steps = trial.suggest_int("warmup_steps", 300, 1500)
    return dict(
        learning_rate=learning_rate,
        gamma=gamma,
        weight_decay=weight_decay,
        warmup_steps=warmup_steps,
    )

In [None]:
# ---------------------------------------------
# Trainer for hyperparameter search
# ---------------------------------------------
# Newer transformers releases renamed `evaluation_strategy` -> `eval_strategy`
# and `no_cuda` -> `use_cpu`, and later removed the old names (passing them
# raises TypeError). Use whichever field the installed TrainingArguments
# defines, so this cell runs on both older and newer versions.
_ta_fields = getattr(TrainingArguments, "__dataclass_fields__", {})
_eval_strategy_key = "eval_strategy" if "eval_strategy" in _ta_fields else "evaluation_strategy"
_cpu_only_key = "use_cpu" if "use_cpu" in _ta_fields else "no_cuda"

# Evaluate every epoch; no checkpoints are written during the search.
search_args = TrainingArguments(
    output_dir='\\\\vi240c060002.woc.prod\\e$\\hyperparameters\\hp_search_results_sourceofincident',
    save_strategy="no",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    num_train_epochs=15,
    logging_dir='./logs_hp_search',
    logging_steps=50,
    load_best_model_at_end=False,
    **{
        _eval_strategy_key: "epoch",
        _cpu_only_key: not torch.cuda.is_available(),
    },
)

# NOTE(review): eval_dataset is the held-out test split, so hyperparameters are
# selected on the same data later used to report results, which biases those
# metrics upward. Consider carving a separate validation split for the search.
search_trainer = FocalTrainer(
    alpha=class_weights,
    model_init=model_init,
    args=search_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics,
)

In [None]:
# ---------------------------------------------
# Hyperparameter Search
# ---------------------------------------------
# Run 50 Optuna trials, maximizing the value returned by my_objective.
best_run = search_trainer.hyperparameter_search(
    hp_space=hp_space_optuna,
    backend="optuna",
    n_trials=50,
    direction="maximize",  # my_objective returns eval_f1: higher is better
    compute_objective=my_objective,
)

print("Best Run:", best_run)
print("Best Hyperparams:", best_run.hyperparameters)