In [1]:
import json
import warnings
from itertools import product
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.metrics import f1_score, precision_score, recall_score, average_precision_score
from sklearn.model_selection import StratifiedKFold
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup

warnings.filterwarnings('ignore')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# device: prefer CUDA, then Apple MPS, otherwise fall back to CPU
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
# device.type is the bare backend name ("cuda" / "mps" / "cpu")
print(f"Using {device.type.upper()}")

Using MPS


In [3]:
# config
# Target columns, in the order used for the model's output logits.
LABEL_COLS = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']

# data
TRAIN_PATH = '../data/train.csv'
SAMPLE_SIZE = 50000  # target number of rows in the stratified sample
RANDOM_SEED = 42
N_FOLDS = 4

# model (fixed)
MODEL_NAME = 'distilroberta-base'
MAX_LENGTH = 128  # tokens per example after padding/truncation
HIDDEN_SIZE = 768  # input width of the classifier head; must match the encoder's hidden size
CLASSIFIER_HIDDEN = 256
DROPOUT = 0.1
# Derived from LABEL_COLS so the head width cannot drift from the label list.
NUM_LABELS = len(LABEL_COLS)

# grid search parameters
LEARNING_RATES = [1e-5, 2e-5, 5e-5]
BATCH_SIZES = [64, 128]

# training (fixed)
EPOCHS = 1
MAX_GRAD_NORM = 1.0  # gradient-norm clipping threshold
THRESHOLD = 0.5  # probability cut-off for a positive prediction
WARMUP_RATIO = 0.1  # fraction of total optimizer steps spent in LR warmup
WEIGHT_DECAY = 0.01

# output files
WARMUP_RESULTS_PATH = '../results/bert_warmup_comparison.json'
RESULTS_PATH = '../results/bert_tuning_results.json'
BEST_PATH = '../results/bert_tuning_best.json'

print("Compare warmup/decay")
print(f"Grid search: {len(LEARNING_RATES)} LRs × {len(BATCH_SIZES)} batch sizes = {len(LEARNING_RATES) * len(BATCH_SIZES)} configs")
print(f"Each config: {N_FOLDS}-fold CV × {EPOCHS} epoch")


Compare warmup/decay
Grid search: 3 LRs × 2 batch sizes = 6 configs
Each config: 4-fold CV × 1 epoch


In [4]:
# load data
train_df = pd.read_csv(TRAIN_PATH)

# take stratified sample: draw the same fraction from each 'toxic' class so the
# sample keeps the original toxic / non-toxic ratio.
# Each group is sampled explicitly rather than through groupby(...).apply(...):
# newer pandas releases no longer pass the grouping column to apply(), which
# would silently drop 'toxic' from the sampled frame.
# frac is clamped to 1.0 because sampling without replacement rejects frac > 1.
sample_frac = min(1.0, SAMPLE_SIZE / len(train_df))
sample_df = pd.concat(
    [
        # a fresh RandomState per group (same seed) matches the per-group apply() behaviour
        group.sample(frac=sample_frac, random_state=RANDOM_SEED)
        for _, group in train_df.groupby('toxic')
    ]
).reset_index(drop=True)

# compare label prevalence before and after sampling
print(f"Sampled {len(sample_df):,} from {len(train_df):,}\n")
print(f"{'Label':<15} {'Original %':<12} {'Sampled %':<12}")
print("-" * 40)
for col in LABEL_COLS:
    orig_pct = train_df[col].mean() * 100
    samp_pct = sample_df[col].mean() * 100
    print(f"{col:<15} {orig_pct:>10.2f}% {samp_pct:>10.2f}%")


Sampled 50,000 from 159,571

Label           Original %   Sampled %   
----------------------------------------
toxic                 9.58%       9.58%
severe_toxic          1.00%       1.03%
obscene               5.29%       5.33%
threat                0.30%       0.31%
insult                4.94%       4.91%
identity_hate         0.88%       0.87%


In [5]:
# N_FOLDS-fold cross-validation, stratified on the 'toxic' label
skf = StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=RANDOM_SEED)
# materialise the splits once; later cells iterate over this same list
folds = list(skf.split(sample_df, sample_df['toxic']))

print(f"{N_FOLDS}-Fold CV Split:")
for fold_no, (fold_train, fold_val) in enumerate(folds, start=1):
    print(f"Fold {fold_no}: Train={len(fold_train):,} | Val={len(fold_val):,}")


4-Fold CV Split:
Fold 1: Train=37,500 | Val=12,500
Fold 2: Train=37,500 | Val=12,500
Fold 3: Train=37,500 | Val=12,500
Fold 4: Train=37,500 | Val=12,500


In [6]:
# tokenizer matching the pretrained checkpoint named by MODEL_NAME
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
print("Tokenizer loaded:", MODEL_NAME)

Tokenizer loaded: distilroberta-base


In [7]:
# dataset class (reused from bert.ipynb)
class BertDataset(Dataset):
    """Map-style dataset yielding one tokenized comment plus its label vector.

    Text is read from the ``comment_text`` column and targets from the
    ``LABEL_COLS`` columns. Tokenization happens lazily in ``__getitem__``;
    every example is padded or truncated to ``max_length`` tokens.
    """

    def __init__(self, df, tokenizer, max_length=MAX_LENGTH):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.texts = df['comment_text'].values
        # float32 targets, one row per example and one column per label
        self.labels = df[LABEL_COLS].values.astype(np.float32)

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        encoded = self.tokenizer(
            str(self.texts[idx]),  # str() so non-string cells still tokenize
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # squeeze(0) drops the leading axis that return_tensors='pt' adds for a single text
        item = {
            field: encoded[field].squeeze(0)
            for field in ('input_ids', 'attention_mask')
        }
        item['labels'] = torch.tensor(self.labels[idx], dtype=torch.float32)
        return item


In [8]:
# model class (reused from bert.ipynb)
class ToxicClassifier(nn.Module):
    """Pretrained encoder (``MODEL_NAME``) followed by a small MLP head.

    The head maps the hidden state at sequence position 0 to ``NUM_LABELS``
    raw logits; no sigmoid is applied here.
    """

    def __init__(self):
        super().__init__()
        self.roberta = AutoModel.from_pretrained(MODEL_NAME)
        head_layers = [
            nn.Linear(HIDDEN_SIZE, CLASSIFIER_HIDDEN),
            nn.ReLU(),
            nn.Dropout(DROPOUT),
            nn.Linear(CLASSIFIER_HIDDEN, NUM_LABELS),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, input_ids, attention_mask):
        """Return logits for a batch of token ids and their attention mask."""
        encoder_out = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
        # hidden state of the first token, used as the pooled sequence representation
        first_token = encoder_out.last_hidden_state[:, 0, :]
        return self.classifier(first_token)


In [9]:
# evaluation function (reused from bert.ipynb)
def evaluate(model, dataloader, criterion, device):
    """Score ``model`` on every batch of ``dataloader``.

    Args:
        model: module called as ``model(input_ids, attention_mask)`` that
            returns one logit per label.
        dataloader: iterable of dict batches with ``'input_ids'``,
            ``'attention_mask'`` and ``'labels'`` tensors.
        criterion: unused; kept so existing call sites keep working.
        device: device the model inputs are moved to before the forward pass.

    Returns:
        dict with ``macro_precision``, ``macro_recall`` and ``macro_f1``
        (predictions binarized at ``THRESHOLD``), ``macro_auc_pr`` (mean of the
        per-label average precision) and ``per_label_auc_pr`` (label -> score).

    Raises:
        ValueError: if ``dataloader`` yields no batches.

    Note:
        Leaves ``model`` in eval mode.
    """
    model.eval()
    batch_probs = []
    batch_labels = []

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            logits = model(input_ids, attention_mask)
            batch_probs.append(torch.sigmoid(logits).cpu().numpy())
            # labels are only consumed on the host by sklearn, so skip the device round trip
            batch_labels.append(batch['labels'].cpu().numpy())

    if not batch_probs:
        raise ValueError("evaluate() received a dataloader with no batches")

    all_preds = np.vstack(batch_probs)
    all_labels = np.vstack(batch_labels)
    preds_binary = (all_preds >= THRESHOLD).astype(int)

    # macro metrics
    macro_precision = precision_score(all_labels, preds_binary, average='macro', zero_division=0)
    macro_recall = recall_score(all_labels, preds_binary, average='macro', zero_division=0)
    macro_f1 = f1_score(all_labels, preds_binary, average='macro', zero_division=0)

    # per-label AUC-PR
    per_label_auc_pr = {}
    for i, label in enumerate(LABEL_COLS):
        try:
            per_label_auc_pr[label] = average_precision_score(all_labels[:, i], all_preds[:, i])
        except ValueError:
            # sklearn rejected this label's inputs; score it 0.0 instead of
            # aborting the fold. Any other error (including KeyboardInterrupt)
            # now propagates instead of being silently swallowed.
            per_label_auc_pr[label] = 0.0

    macro_auc_pr = np.mean(list(per_label_auc_pr.values()))

    return {
        'macro_precision': macro_precision,
        'macro_recall': macro_recall,
        'macro_f1': macro_f1,
        'macro_auc_pr': macro_auc_pr,
        'per_label_auc_pr': per_label_auc_pr
    }


In [None]:
# Warmup/decay ablation.
# Fixed setting: lr=5e-5, bs=128, trained once without and once with
# LR warmup + weight decay.

# Same values as the config cell; repeated so this cell states its own settings.
WEIGHT_DECAY = 0.01
WARMUP_RATIO = 0.1

def train_fold_with_warmup(fold_idx, train_idx, val_idx, lr, batch_size, use_warmup_decay=False):
    """Train a fresh ToxicClassifier on one CV fold and return validation metrics.

    Args:
        fold_idx: zero-based fold number (progress-bar label and seed offset).
        train_idx: positional row indices of ``sample_df`` used for training.
        val_idx: positional row indices of ``sample_df`` used for validation.
        lr: AdamW learning rate.
        batch_size: batch size for both the train and validation loaders.
        use_warmup_decay: if True, use ``WEIGHT_DECAY`` and a linear
            warmup/decay LR schedule; if False, use a constant LR and no
            weight decay.

    Returns:
        The metrics dict produced by ``evaluate`` on the validation split.

    Side effects:
        Reseeds torch's global RNG with ``RANDOM_SEED + fold_idx``.
    """
    # Identical seed for both arms of the comparison, so they start from the
    # same head initialisation and see the same shuffle order.
    torch.manual_seed(RANDOM_SEED + fold_idx)

    fold_train_df = sample_df.iloc[train_idx]
    fold_val_df = sample_df.iloc[val_idx]

    train_dataset = BertDataset(fold_train_df, tokenizer)
    val_dataset = BertDataset(fold_val_df, tokenizer)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    model = ToxicClassifier().to(device)
    criterion = nn.BCEWithLogitsLoss()

    if use_warmup_decay:
        optimizer = AdamW(model.parameters(), lr=lr, weight_decay=WEIGHT_DECAY)
        total_steps = len(train_loader) * EPOCHS
        warmup_steps = int(total_steps * WARMUP_RATIO)
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
    else:
        # torch's AdamW defaults to weight_decay=0.01, so the "no decay"
        # baseline has to switch it off explicitly.
        optimizer = AdamW(model.parameters(), lr=lr, weight_decay=0.0)
        scheduler = None

    # train for EPOCHS passes, the same number of steps the schedule is sized for
    model.train()
    for _ in range(EPOCHS):
        for batch in tqdm(train_loader, desc=f"Fold {fold_idx+1}", leave=False):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            optimizer.zero_grad()
            logits = model(input_ids, attention_mask)
            loss = criterion(logits, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
            optimizer.step()
            if scheduler is not None:
                scheduler.step()

    # evaluate
    metrics = evaluate(model, val_loader, criterion, device)

    # cleanup: drop every reference to model/optimizer state (the scheduler
    # holds the optimizer) before asking the allocator to release its cache
    del model, optimizer, scheduler, train_loader, val_loader
    if device.type == 'cuda':
        torch.cuda.empty_cache()
    elif device.type == 'mps' and hasattr(torch, 'mps'):
        torch.mps.empty_cache()

    return metrics

# run comparison
test_lr = 5e-5
test_bs = 128

# Create the output directory before training, so a missing folder fails
# immediately instead of discarding the results after every fold has run.
Path(WARMUP_RESULTS_PATH).parent.mkdir(parents=True, exist_ok=True)

print(f"Testing warmup/decay comparison: lr={test_lr}, bs={test_bs}\n")

warmup_comparison_results = []

for use_warmup in [False, True]:
    config_name = "with_warmup_and_decay" if use_warmup else "without_warmup_and_decay"
    print(f"Running: {config_name}...")

    # one training run per CV fold, all with the same lr / batch size
    fold_metrics = []
    for fold_idx, (train_idx, val_idx) in enumerate(folds):
        metrics = train_fold_with_warmup(
            fold_idx, train_idx, val_idx, test_lr, test_bs, use_warmup_decay=use_warmup
        )
        fold_metrics.append(metrics)

    # average across folds (rounded to 3 decimals for the JSON report)
    avg_result = {
        'config': config_name,
        'macro_precision': round(np.mean([m['macro_precision'] for m in fold_metrics]), 3),
        'macro_recall': round(np.mean([m['macro_recall'] for m in fold_metrics]), 3),
        'macro_f1': round(np.mean([m['macro_f1'] for m in fold_metrics]), 3),
        'macro_auc_pr': round(np.mean([m['macro_auc_pr'] for m in fold_metrics]), 3),
        'per_label_auc_pr': {
            label: round(np.mean([m['per_label_auc_pr'][label] for m in fold_metrics]), 3)
            for label in LABEL_COLS
        }
    }
    warmup_comparison_results.append(avg_result)
    print(f"Done. Macro F1: {avg_result['macro_f1']:.3f}")

# save results
with open(WARMUP_RESULTS_PATH, 'w') as f:
    json.dump(warmup_comparison_results, f, indent=2)

print(f"\nResults saved to {WARMUP_RESULTS_PATH}")


Testing warmup/decay comparison: lr=5e-05, bs=128

Running: without_warmup_and_decay...


Loading weights: 100%|██████████| 103/103 [00:00<00:00, 2122.09it/s, Materializing param=pooler.dense.weight]                             
RobertaModel LOAD REPORT from: distilroberta-base
Key                       | Status     |  | 
--------------------------+------------+--+-
lm_head.dense.bias        | UNEXPECTED |  | 
lm_head.layer_norm.bias   | UNEXPECTED |  | 
lm_head.layer_norm.weight | UNEXPECTED |  | 
lm_head.dense.weight      | UNEXPECTED |  | 
lm_head.bias              | UNEXPECTED |  | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
Fold 1:   0%|          | 0/293 [00:00<?, ?it/s]

In [None]:
# train one fold (with warmup and weight decay)
def train_fold(fold_idx, train_idx, val_idx, lr, batch_size):
    """Train a fresh ToxicClassifier on one CV fold and return validation metrics.

    Uses AdamW with ``WEIGHT_DECAY`` and a linear warmup/decay LR schedule
    whose warmup covers ``WARMUP_RATIO`` of all optimizer steps.

    Args:
        fold_idx: zero-based fold number (progress-bar label and seed offset).
        train_idx: positional row indices of ``sample_df`` used for training.
        val_idx: positional row indices of ``sample_df`` used for validation.
        lr: peak AdamW learning rate.
        batch_size: batch size for both the train and validation loaders.

    Returns:
        The metrics dict produced by ``evaluate`` on the validation split.

    Side effects:
        Reseeds torch's global RNG with ``RANDOM_SEED + fold_idx``.
    """
    # Same seed for a given fold in every grid config, so configs differ only
    # in their hyperparameters.
    torch.manual_seed(RANDOM_SEED + fold_idx)

    fold_train_df = sample_df.iloc[train_idx]
    fold_val_df = sample_df.iloc[val_idx]

    train_dataset = BertDataset(fold_train_df, tokenizer)
    val_dataset = BertDataset(fold_val_df, tokenizer)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    model = ToxicClassifier().to(device)
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=WEIGHT_DECAY)
    criterion = nn.BCEWithLogitsLoss()

    # warmup scheduler
    total_steps = len(train_loader) * EPOCHS
    warmup_steps = int(total_steps * WARMUP_RATIO)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)

    # train for EPOCHS passes, the same number of steps the schedule is sized for
    model.train()
    for _ in range(EPOCHS):
        for batch in tqdm(train_loader, desc=f"Fold {fold_idx+1}", leave=False):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            optimizer.zero_grad()
            logits = model(input_ids, attention_mask)
            loss = criterion(logits, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
            optimizer.step()
            scheduler.step()

    # evaluate
    metrics = evaluate(model, val_loader, criterion, device)

    # cleanup: drop references first, then release the cache of whichever
    # accelerator backend is actually in use
    del model, optimizer, scheduler, train_loader, val_loader
    if device.type == 'cuda':
        torch.cuda.empty_cache()
    elif device.type == 'mps' and hasattr(torch, 'mps'):
        torch.mps.empty_cache()

    return metrics


In [None]:
# grid search over every (learning rate, batch size) pair
grid = list(product(LEARNING_RATES, BATCH_SIZES))
all_results = []

print(f"Starting grid search: {len(grid)} configs × {N_FOLDS} folds\n")

for config_idx, (lr, batch_size) in enumerate(grid):
    print(f"Config {config_idx+1}/{len(grid)}: lr={lr}, batch_size={batch_size}")

    # one training run per CV fold for this config
    fold_metrics = []
    for fold_idx, (train_idx, val_idx) in enumerate(folds):
        metrics = train_fold(fold_idx, train_idx, val_idx, lr, batch_size)
        fold_metrics.append(metrics)

    # average across folds, rounded to 3 decimals
    avg_metrics = {'lr': lr, 'batch_size': batch_size}
    for metric_name in ('macro_precision', 'macro_recall', 'macro_f1', 'macro_auc_pr'):
        avg_metrics[metric_name] = round(np.mean([m[metric_name] for m in fold_metrics]), 3)
    avg_metrics['per_label_auc_pr'] = {
        label: round(np.mean([m['per_label_auc_pr'][label] for m in fold_metrics]), 3)
        for label in LABEL_COLS
    }
    all_results.append(avg_metrics)

    print(f"  → Macro F1: {avg_metrics['macro_f1']:.3f} | AUC-PR: {avg_metrics['macro_auc_pr']:.3f}\n")

print("Grid search complete!")


In [None]:
# save results
if not all_results:
    raise ValueError("all_results is empty - run the grid search cell before saving")

# make sure the output folders exist; open(..., 'w') does not create them
for out_path in (RESULTS_PATH, BEST_PATH):
    Path(out_path).parent.mkdir(parents=True, exist_ok=True)

with open(RESULTS_PATH, 'w') as f:
    json.dump(all_results, f, indent=2)

print(f"Results saved to {RESULTS_PATH}")

# save best config separately (same format as bert_results.json)
# Best = highest cross-validated macro F1; sorted() is stable, so ties keep grid order.
sorted_results = sorted(all_results, key=lambda x: x['macro_f1'], reverse=True)
best = sorted_results[0]
best_results = {
    key: best[key]
    for key in (
        'lr',
        'batch_size',
        'macro_precision',
        'macro_recall',
        'macro_f1',
        'macro_auc_pr',
        'per_label_auc_pr',
    )
}

with open(BEST_PATH, 'w') as f:
    json.dump(best_results, f, indent=2)

print(f"Best config saved to {BEST_PATH}")

