# BERT Baseline for Figurative Language Understanding

In this notebook, we build a BERT-based baseline model for the BESSTIE dataset.
We use the preprocessed, encoder-ready data produced earlier and cover:

- Loading preprocessed tensors
- Understanding their structure
- Preparing datasets and dataloaders
- Verifying that everything is ready for training
- Training with a weighted cross-entropy loss
- Evaluating the trained model per English variety

## Imports and environment setup

We import the libraries required for:
- loading PyTorch tensors
- building datasets and dataloaders
- defining a BERT classification model
- basic evaluation utilities

No training is performed yet.

In [None]:
# Core libraries
import torch
import numpy as np
import random

# PyTorch utilities
from torch.utils.data import Dataset, DataLoader

# Hugging Face model
from transformers import BertForSequenceClassification

# Metrics
from sklearn.metrics import accuracy_score, f1_score

In [None]:
print("Torch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())

## Reproducibility

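We fix random seeds so that runs are repeatable.
This is important for fair comparison with other models and experiments.
Seeding alone does not guarantee bit-identical results on a GPU, because some CUDA operations are nondeterministic.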
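
As a small illustration of what seeding provides, the sketch below uses only Python's standard random module: re-seeding with the same value replays the same sequence of draws. The helper name `draw_three` is hypothetical and not part of this notebook.

```python
import random

def draw_three(seed):
    """Seed the generator, then return three pseudo-random floats."""
    random.seed(seed)
    return [random.random() for _ in range(3)]

first_run = draw_three(50)
second_run = draw_three(50)   # same seed, so the same sequence is replayed
other_run = draw_three(51)    # a different seed gives a different sequence

print("Same seed, same draws:", first_run == second_run)
print("Different seed, same draws:", first_run == other_run)
```

The same idea applies to NumPy and PyTorch, each of which keeps its own generator state and therefore needs its own seeding call.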

In [None]:
# Fix random seeds
SEED = 50

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print("Random seeds fixed.")

In [None]:
# Simple check: with the seed fixed above, these two values should match across fresh runs
print(np.random.rand())
print(np.random.rand())

## Loading preprocessed encoder data

We load the encoder-ready tensors created during preprocessing.
These tensors already include:
- input_ids
- attention_mask
- labels

The data is split into train, validation, and test sets.

In [None]:
from google.colab import drive
drive.mount('/content/drive')

print("Google Drive mounted.")

In [None]:
# Path to preprocessed data
DATA_PATH = "/content/drive/MyDrive/DNLP/data/processed_data/sarcasm_bert_encoder.pt"

# Load the data
encoder_data = torch.load(DATA_PATH)

print("Keys in loaded data:", encoder_data.keys())

In [None]:
# Check split sizes
for split in encoder_data:
    print(
        split,
        encoder_data[split]["input_ids"].shape,
        encoder_data[split]["labels"].shape
    )

## Inspecting tensor structure

Before modeling, we inspect the shape and content of tensors
to confirm that preprocessing was applied correctly.

We check:
- sequence length (max_length)
- label format
- attention masks
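
To make the relationship between padding and attention masks concrete, here is a dependency-free sketch. It assumes right-padding with token id 0 (the [PAD] id in bert-base-uncased) and masks that mark real tokens with 1. The helper `pad_and_mask` and the token ids are illustrative only; this is not the preprocessing code that produced the tensors loaded above.

```python
PAD_ID = 0  # assumption: id of the [PAD] token

def pad_and_mask(token_ids, max_length):
    """Truncate or right-pad token ids and build the matching attention mask."""
    # Naive truncation for illustration; a real tokenizer keeps the final [SEP]
    ids = token_ids[:max_length]
    num_real = len(ids)
    padding = max_length - num_real
    input_ids = ids + [PAD_ID] * padding
    attention_mask = [1] * num_real + [0] * padding
    return input_ids, attention_mask

input_ids, attention_mask = pad_and_mask([101, 2023, 2003, 102], max_length=8)
print(input_ids)
print(attention_mask)
print("Real tokens:", sum(attention_mask))
```

Under these assumptions, the sum of the mask equals the number of real tokens, and every position where the mask is 0 holds the padding id.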

In [None]:
# Inspect one training example
sample_input_ids = encoder_data["train"]["input_ids"][0]
sample_attention_mask = encoder_data["train"]["attention_mask"][0]
sample_label = encoder_data["train"]["labels"][0]

print("Input IDs shape:", sample_input_ids.shape)
print("Attention mask shape:", sample_attention_mask.shape)
print("Label:", sample_label)

In [None]:
sample_input_ids

In [None]:
sample_attention_mask

In [None]:
sample_label

In [None]:
# Verify attention mask consistency
print("Number of real tokens:", sample_attention_mask.sum().item())
print("Total sequence length:", sample_attention_mask.shape[0])

## Dataset class

We define a simple PyTorch Dataset to wrap the preprocessed tensors.
This allows us to use PyTorch DataLoader for batching.

The dataset does NOT perform tokenization.
It only returns already-prepared tensors.

In [None]:
class BertDataset(Dataset):
    def __init__(self, split_data):
        self.input_ids = split_data["input_ids"]
        self.attention_mask = split_data["attention_mask"]
        self.labels = split_data["labels"]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.attention_mask[idx],
            "labels": self.labels[idx]
        }

In [None]:
train_dataset = BertDataset(encoder_data["train"])
print("Number of training samples:", len(train_dataset))
print("Keys returned by dataset:", train_dataset[0].keys())

## DataLoaders

We create DataLoaders for batching.
At this stage, we only test small batches to verify correctness.

In [None]:
BATCH_SIZE = 8

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True
)

print("Train DataLoader ready.")

In [None]:
# Fetch one batch
batch = next(iter(train_loader))

for key in batch:
    print(key, batch[key].shape)

## Loading BERT for sequence classification

We load a pretrained BERT model with a classification head.
The model is configured for binary classification.

No training is performed in this cell.

In [None]:
model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=2
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

print("Model loaded on device:", device)

In [None]:
# Count trainable parameters
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {num_params:,}")

## Sanity check: forward pass

Before training, we perform a single forward pass on one batch.
This ensures:
- no shape mismatch
- no device errors
- correct output format

In [None]:
model.eval()

batch = next(iter(train_loader))

with torch.no_grad():
    outputs = model(
        input_ids=batch["input_ids"].to(device),
        attention_mask=batch["attention_mask"].to(device),
        labels=batch["labels"].to(device)
    )

print("Loss:", outputs.loss.item())
print("Logits shape:", outputs.logits.shape)

In [None]:
# Check that logits match expected shape
assert outputs.logits.shape[1] == 2
print("Forward pass successful. Shapes are correct.")

## Status

At this point:
- Data loading is correct
- Tensor shapes are consistent
- DataLoaders work as expected
- BERT forward pass runs without errors

The notebook is ready for:
- defining loss functions (weighted or focal)
- training loop
- evaluation

Up to this point, no heavy computation has been performed.

## Loading precomputed class weights

Class weights were computed during preprocessing and saved to ensure
consistent handling of class imbalance across experiments.

We load and reuse these weights here.

In [None]:
# Load class weights saved during preprocessing
WEIGHTS_PATH = "/content/drive/MyDrive/DNLP/data/processed_data/sarcasm_weights.pt"
class_weights = torch.load(WEIGHTS_PATH).to(device)

print("Loaded class weights:", class_weights)

In [None]:
assert class_weights.shape[0] == 2
assert class_weights.min() > 0
print("Precomputed class weights verified.")

## Weighted Cross-Entropy Loss

Weighted cross-entropy is a common way to handle class imbalance.
Each sample's loss is scaled by the weight of its true class, so misclassifying minority-class samples incurs a higher penalty.
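
The computation can be sketched in plain Python. This is a hand-written illustration, not PyTorch's implementation. It follows the documented behaviour of `CrossEntropyLoss` with a `weight` argument and the default mean reduction, where the weighted sum of per-sample losses is divided by the sum of the weights of the true classes. The weights `[0.6, 2.5]` are made up for the example.

```python
import math

def weighted_cross_entropy(logits, labels, class_weights):
    """Weighted cross-entropy with a weighted-mean reduction over the batch."""
    weighted_sum = 0.0
    weight_total = 0.0
    for row, label in zip(logits, labels):
        # Numerically stable log-sum-exp for the softmax normalizer
        m = max(row)
        log_norm = m + math.log(sum(math.exp(x - m) for x in row))
        nll = log_norm - row[label]          # negative log-likelihood of the true class
        w = class_weights[label]
        weighted_sum += w * nll
        weight_total += w
    return weighted_sum / weight_total

weights = [0.6, 2.5]            # hypothetical class weights
uniform = [1.0, 1.0]

# Batch of one: the weight appears in numerator and denominator and cancels
single_weighted = weighted_cross_entropy([[2.0, 0.5]], [1], weights)
single_uniform = weighted_cross_entropy([[2.0, 0.5]], [1], uniform)

# Batch of two with different true classes: the weighting now matters
batch_logits = [[2.0, 0.5], [2.0, 0.5]]
batch_labels = [0, 1]
batch_weighted = weighted_cross_entropy(batch_logits, batch_labels, weights)
batch_uniform = weighted_cross_entropy(batch_logits, batch_labels, uniform)

print(single_weighted, single_uniform)
print(batch_weighted, batch_uniform)
```

One consequence is that, with this reduction, a single-sample check cannot reveal the effect of the class weights. The effect appears only once a batch mixes classes.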

In [None]:
weighted_loss_fn = torch.nn.CrossEntropyLoss(
    weight=class_weights.to(device)
)

print("Weighted loss function initialized.")

In [None]:
# Test loss on dummy logits
dummy_logits = torch.tensor([[2.0, 0.5]], device=device)
dummy_label = torch.tensor([1], device=device)

loss_value = weighted_loss_fn(dummy_logits, dummy_label)
print("Dummy weighted loss:", loss_value.item())

## Optimizer setup

We define the optimizer following the reference paper settings.
The optimizer is defined but not used yet.

In [None]:
from torch.optim import AdamW

optimizer = AdamW(
    model.parameters(),
    lr=2e-5,
    weight_decay=0.01
)

print("Optimizer initialized.")

In [None]:
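# Check that the optimizer received the model parameters
num_param_tensors = sum(len(g["params"]) for g in optimizer.param_groups)
print("Number of parameter groups:", len(optimizer.param_groups))
print("Number of parameter tensors:", num_param_tensors)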

## Evaluation metrics

We define helper functions to compute accuracy and macro F1.
These metrics are used consistently across experiments.
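
For reference, macro F1 is the unweighted mean of the per-class F1 scores, so the minority class counts as much as the majority class. The sketch below computes it by hand in plain Python. The function name `macro_f1` is illustrative, and this version always averages over `num_classes` classes, which may differ from `f1_score` on a batch where one class is entirely absent.

```python
def macro_f1(labels, preds, num_classes=2):
    """Unweighted mean of per-class F1 scores."""
    scores = []
    for c in range(num_classes):
        tp = sum(1 for y, p in zip(labels, preds) if y == c and p == c)
        fp = sum(1 for y, p in zip(labels, preds) if y != c and p == c)
        fn = sum(1 for y, p in zip(labels, preds) if y == c and p != c)
        denom = 2 * tp + fp + fn
        # F1 = 2TP / (2TP + FP + FN); treated as 0 when the class never occurs
        scores.append(2 * tp / denom if denom else 0.0)
    return sum(scores) / num_classes

toy_labels = [1, 0, 1, 1]
toy_preds = [1, 0, 0, 1]

accuracy = sum(y == p for y, p in zip(toy_labels, toy_preds)) / len(toy_labels)
print("Accuracy:", accuracy)
print("Macro F1:", macro_f1(toy_labels, toy_preds))
```

In this toy case class 1 has an F1 of 0.8 and class 0 an F1 of 2/3, so macro F1 is their plain average, regardless of how many samples each class has.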

In [None]:
def compute_metrics(logits, labels):
    preds = torch.argmax(logits, dim=1).cpu().numpy()
    labels = labels.cpu().numpy()

    acc = accuracy_score(labels, preds)
    f1 = f1_score(labels, preds, average="macro")

    return acc, f1

In [None]:
dummy_logits = torch.tensor([[0.2, 0.8], [0.6, 0.4]])
dummy_labels = torch.tensor([1, 0])

acc, f1 = compute_metrics(dummy_logits, dummy_labels)
print("Accuracy:", acc)
print("Macro F1:", f1)

## Sanity check: metrics on one batch

Before full training, we test metric computation on a single batch.
This ensures that logits, labels, and metrics are compatible.

In [None]:
model.eval()
batch = next(iter(train_loader))

with torch.no_grad():
    outputs = model(
        input_ids=batch["input_ids"].to(device),
        attention_mask=batch["attention_mask"].to(device)
    )

acc, f1 = compute_metrics(
    outputs.logits,
    batch["labels"].to(device)
)

print("Batch accuracy:", acc)
print("Batch macro F1:", f1)

In [None]:
assert 0.0 <= acc <= 1.0
assert 0.0 <= f1 <= 1.0
print("Metric computation successful.")

## Training configuration

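We collect the training hyperparameters in one place.
This makes experiments easy to reproduce and modify.
Note that the optimizer above was already created with lr=2e-5, so LEARNING_RATE records that value rather than setting it.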
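
One possible way to keep such settings together is a small frozen dataclass. This is a sketch only: the class name `TrainConfig` and its field names are hypothetical, while the default values mirror the ones used in this notebook.

```python
from dataclasses import dataclass, asdict, replace

@dataclass(frozen=True)
class TrainConfig:
    """Hypothetical container for the hyperparameters used in this notebook."""
    num_epochs: int = 1          # sanity run
    learning_rate: float = 2e-5
    weight_decay: float = 0.01
    batch_size: int = 8
    log_interval: int = 10
    seed: int = 50

sanity_cfg = TrainConfig()
# Derive the full-training settings without mutating the sanity-run config
full_cfg = replace(sanity_cfg, num_epochs=30, log_interval=100)

print(asdict(sanity_cfg))
print(asdict(full_cfg))
```

Because the dataclass is frozen, a later cell cannot silently overwrite a value, and `asdict` gives a dictionary that can be stored next to the training history.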

In [None]:
# Training configuration
NUM_EPOCHS = 1          # sanity run (will increase later)
LEARNING_RATE = 2e-5
BATCH_SIZE = 8
LOG_INTERVAL = 10       # how often to print loss

print("Training configuration set.")

In [None]:
assert NUM_EPOCHS >= 1
assert LEARNING_RATE > 0
print("Training config verified.")

## Training loop

We define a simple training loop for BERT.
At this stage, we only support:
- weighted cross-entropy loss
- single-GPU or CPU training

This loop will be reused for all experiments.

In [None]:
def train_one_epoch(model, dataloader, optimizer, loss_fn, device):
    model.train()
    total_loss = 0.0

    for step, batch in enumerate(dataloader):
        optimizer.zero_grad()

        # Move data to device
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        # Forward pass
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        # Compute loss
        loss = loss_fn(outputs.logits, labels)

        # Backward pass
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # Log progress
        if step % LOG_INTERVAL == 0:
            print(f"Step {step} - Loss: {loss.item():.4f}")

    avg_loss = total_loss / len(dataloader)
    return avg_loss

In [None]:
print("Training loop function defined.")

## Validation loop

The validation loop evaluates the model without updating weights.
We compute accuracy and macro F1.

In [None]:
def evaluate(model, dataloader, device):
    model.eval()

    all_logits = []
    all_labels = []

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )

            all_logits.append(outputs.logits)
            all_labels.append(labels)

    all_logits = torch.cat(all_logits, dim=0)
    all_labels = torch.cat(all_labels, dim=0)

    acc, f1 = compute_metrics(all_logits, all_labels)
    return acc, f1

In [None]:
print("Evaluation loop function defined.")

## Validation DataLoader

We prepare a DataLoader for the validation split.

In [None]:
val_dataset = BertDataset(encoder_data["val"])

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False
)

print("Validation DataLoader ready.")

In [None]:
batch = next(iter(val_loader))
for k in batch:
    print(k, batch[k].shape)

## Time logging utilities

We define helper functions to log elapsed time during training.
This helps us understand computational cost and compare experiments.

In [None]:
import time

def format_time(seconds):
    """Convert seconds to hh:mm:ss format."""
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    secs = int(seconds % 60)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}"

## Full training configuration

We now switch from a sanity run to full training.
These settings follow the reference paper as closely as possible.

In [None]:
# Full training settings
NUM_EPOCHS = 30          # paper-aligned
LOG_INTERVAL = 100       # less verbose during full training

print(f"Training for {NUM_EPOCHS} epochs.")

In [None]:
# ================================
# Full Training Script with tqdm
# ================================

from tqdm.auto import tqdm
import time

# ----------------
# Training setup
# ----------------
NUM_EPOCHS = 30

history = {
    "train_loss": [],
    "val_acc": [],
    "val_f1": []
}

print("Starting full training...\n")

start_time = time.time()

# tqdm progress bar over epochs
epoch_bar = tqdm(
    range(1, NUM_EPOCHS + 1),
    desc="Training BERT baseline",
    unit="epoch",
    dynamic_ncols=True
)

# ----------------
# Training loop
# ----------------
for epoch in epoch_bar:
    epoch_start = time.time()

    # ---- Train ----
    train_loss = train_one_epoch(
        model=model,
        dataloader=train_loader,
        optimizer=optimizer,
        loss_fn=weighted_loss_fn,
        device=device
    )

    # ---- Validate ----
    val_acc, val_f1 = evaluate(
        model=model,
        dataloader=val_loader,
        device=device
    )

    epoch_time = time.time() - epoch_start

    # ---- Save metrics ----
    history["train_loss"].append(train_loss)
    history["val_acc"].append(val_acc)
    history["val_f1"].append(val_f1)

    # ---- Update progress bar ----
    epoch_bar.set_postfix({
        "loss": f"{train_loss:.4f}",
        "val_f1": f"{val_f1:.4f}",
        "epoch_time": format_time(epoch_time)
    })

# ----------------
# Final time log
# ----------------
total_time = time.time() - start_time

print("\nTraining completed successfully.")
print("Total training time:", format_time(total_time))

## Saving final model and training logs

We save:
- trained model weights
- training history (loss and metrics)

This lets us reload the model without retraining and analyze the training curves later.
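
As an example of the kind of later analysis the saved history supports, the sketch below finds the epoch with the best validation macro F1. The helper `best_epoch` is hypothetical, and the numbers in `toy_history` are made up for illustration; they are not results from this notebook.

```python
def best_epoch(history, metric="val_f1"):
    """Return (1-based epoch, value) of the highest score for `metric`."""
    values = history[metric]
    idx = max(range(len(values)), key=values.__getitem__)
    return idx + 1, values[idx]

# Made-up numbers with the same keys as the history dictionary
toy_history = {
    "train_loss": [0.62, 0.48, 0.39, 0.33],
    "val_acc": [0.70, 0.74, 0.73, 0.72],
    "val_f1": [0.61, 0.66, 0.68, 0.67],
}

epoch, score = best_epoch(toy_history)
print(f"Best val_f1 {score:.2f} at epoch {epoch}")
```

In the toy data the training loss keeps falling after validation F1 has peaked, which is the pattern such a check is meant to surface.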

In [None]:
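import os

FINAL_MODEL_PATH = "/content/drive/MyDrive/DNLP/checkpoints/bert_full_training.pt"
HISTORY_PATH = "/content/drive/MyDrive/DNLP/checkpoints/bert_training_history.pt"

# Create the checkpoint directory if it does not exist yet
os.makedirs(os.path.dirname(FINAL_MODEL_PATH), exist_ok=True)

torch.save(model.state_dict(), FINAL_MODEL_PATH)
torch.save(history, HISTORY_PATH)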

print("Final model saved to:", FINAL_MODEL_PATH)
print("Training history saved to:", HISTORY_PATH)

In [None]:
import os

assert os.path.exists(FINAL_MODEL_PATH)
assert os.path.exists(HISTORY_PATH)
print("Saved files verified.")

## Per-variety evaluation

We evaluate the trained model separately for each English variety.
The model is trained once on all data; only evaluation is grouped by variety.
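
The grouping idea can be shown on toy data in plain Python. The helper `accuracy_by_group` is illustrative, and the variety tags and predictions below are placeholders rather than outputs of the model.

```python
from collections import defaultdict

def accuracy_by_group(preds, labels, groups):
    """Bucket per-sample correctness by group tag and average within each bucket."""
    buckets = defaultdict(list)
    for pred, label, group in zip(preds, labels, groups):
        buckets[group].append(pred == label)
    return {g: sum(hits) / len(hits) for g, hits in buckets.items()}

# Toy inputs; the tags are placeholders for the varieties in the data
toy_preds = [1, 0, 1, 1, 0, 0]
toy_labels = [1, 0, 0, 1, 1, 0]
toy_varieties = ["en-AU", "en-AU", "en-IN", "en-IN", "en-UK", "en-UK"]

per_group = accuracy_by_group(toy_preds, toy_labels, toy_varieties)
pooled = sum(p == y for p, y in zip(toy_preds, toy_labels)) / len(toy_labels)

print("Per group:", per_group)
print("Pooled:", pooled)
```

The pooled accuracy hides the fact that the groups differ, which is why evaluation is reported per variety.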

In [None]:
from collections import defaultdict

def evaluate_per_variety(model, dataloader, device):
    model.eval()

    # Store logits, labels, and varieties
    all_logits = []
    all_labels = []
    all_varieties = []

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            varieties = batch["variety"]  # list of strings

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )

            all_logits.append(outputs.logits.cpu())
            all_labels.append(labels.cpu())
            all_varieties.extend(varieties)

    all_logits = torch.cat(all_logits, dim=0)
    all_labels = torch.cat(all_labels, dim=0)

    # Group indices by variety
    variety_indices = defaultdict(list)
    for idx, var in enumerate(all_varieties):
        variety_indices[var].append(idx)

    # Compute metrics per variety
    results = {}
    for var, indices in variety_indices.items():
        var_logits = all_logits[indices]
        var_labels = all_labels[indices]

        acc, f1 = compute_metrics(var_logits, var_labels)
        results[var] = {
            "accuracy": acc,
            "macro_f1": f1,
            "num_samples": len(indices)
        }

    return results

In [None]:
# Check that varieties exist and align
test_varieties = encoder_data["test"]["variety"]

print("Number of test samples:", len(test_varieties))
print("Unique varieties:", set(test_varieties))

## Running per-variety evaluation on the test set

In [None]:
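class BertVarietyDataset(BertDataset):
    """BertDataset that also returns each sample's variety string."""

    def __init__(self, split_data):
        super().__init__(split_data)
        self.variety = split_data["variety"]

    def __getitem__(self, idx):
        item = super().__getitem__(idx)
        item["variety"] = self.variety[idx]  # required by evaluate_per_variety
        return item

test_dataset = BertVarietyDataset(encoder_data["test"])
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False
)

per_variety_results = evaluate_per_variety(
    model=model,
    dataloader=test_loader,
    device=device
)

per_variety_results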

## Per-variety results summary

For each variety, we print the sample count, accuracy, and macro F1.

In [None]:
for var, metrics in per_variety_results.items():
    print(f"\nVariety: {var}")
    print(f"  Samples: {metrics['num_samples']}")
    print(f"  Accuracy: {metrics['accuracy']:.4f}")
    print(f"  Macro F1: {metrics['macro_f1']:.4f}")