In [None]:
import random
from itertools import product

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.metrics import (
    ConfusionMatrixDisplay,
    accuracy_score,
    classification_report,
    confusion_matrix,
    precision_recall_fscore_support,
)
from sklearn.model_selection import train_test_split
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm
from transformers import BertModel, BertTokenizer

In [None]:
def set_seed(seed=42):
    """Seed every RNG this notebook relies on and pin cuDNN to deterministic mode.

    Args:
        seed: Integer seed passed to Python's ``random``, NumPy, and PyTorch
            (CPU generator and all CUDA generators).
    """
    # The same seed is handed to each library, in the original order.
    for seeder in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seeder(seed)

    # Disable cuDNN autotuning and request deterministic kernels.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# Notebook-wide configuration: seed once up front, then pick the device.
set_seed()
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 64

In [None]:
# Read the classifier dataset and discard its first column in one step.
# NOTE(review): the dropped column is presumably a saved index - confirm against the CSV.
data_frame = pd.read_csv("/kaggle/input/dataset/classifier_data.csv").iloc[:, 1:]
data_frame.head()

###  Format Input Texts and Labels for Classifier


In [None]:
def format_input(row):
    """Build the classifier's text input from a single row.

    Args:
        row: Mapping-like object (e.g. a DataFrame row) exposing the keys
            ``original_caption``, ``generated_caption`` and
            ``perturbation_percentage``.

    Returns:
        The three field values, in that order, joined by ``" <SEP> "``.
    """
    fields = ("original_caption", "generated_caption", "perturbation_percentage")
    # Each value is rendered through an f-string so formatting matches plain interpolation.
    return " <SEP> ".join(f"{row[field]}" for field in fields)


# One concatenated input string per row.
data_frame["input_text"] = data_frame.apply(format_input, axis=1)
# Binary target: 0 for rows whose model_type is exactly "Model A (SmolVLM)", 1 for every other value.
data_frame["label"] = (data_frame["model_type"] != "Model A (SmolVLM)").astype("int64")

In [None]:
# Class-balance check: number of rows carrying each label.
print(f"SmolVLM examples: {(data_frame['label'] == 0).sum()}")
print(f"Custom Model examples: {(data_frame['label'] == 1).sum()}")

In [None]:
# Preview the frame again; when cells run in order it now also carries the
# "input_text" and "label" columns.
data_frame.head()

###  Split Data into Train, Validation, and Test Sets

In [None]:
# Stage 1: keep 70% for training and hold out 30%, preserving the label ratio.
train_texts, temp_texts, train_labels, temp_labels = train_test_split(
    data_frame["input_text"],
    data_frame["label"],
    stratify=data_frame["label"],
    test_size=0.3,
    random_state=42,
)

# Stage 2: divide the held-out 30% so that one third becomes validation and
# two thirds become test (roughly 10% / 20% of the full dataset).
val_texts, test_texts, val_labels, test_labels = train_test_split(
    temp_texts,
    temp_labels,
    stratify=temp_labels,
    test_size=2 / 3,
    random_state=42,
)

### Tokenize Texts and Prepare Labels with BERT Tokenizer

In [None]:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")


def tokenize(texts, labels):
    """Tokenize a collection of texts and convert their labels to a tensor.

    Args:
        texts: Iterable of strings; it is materialised with ``list`` before
            being passed to the tokenizer.
        labels: Labels aligned with ``texts``. May be a pandas Series, a list,
            a NumPy array, or an existing tensor.

    Returns:
        A ``(encodings, labels)`` tuple: the tokenizer output (PyTorch tensors,
        truncated to at most 512 tokens and padded to the longest sequence in
        the call) and the labels as a ``torch.Tensor``.
    """
    encodings = tokenizer(
        list(texts), truncation=True, padding=True, return_tensors="pt", max_length=512
    )
    if torch.is_tensor(labels):
        # Labels may already be tensors (e.g. when the calling cell is re-run
        # after rebinding its label names); copy instead of failing on `.values`.
        labels = labels.detach().clone()
    else:
        # np.asarray handles Series, lists and arrays uniformly; iterating the
        # array yields the same scalars that `list(series.values)` would.
        labels = torch.tensor(list(np.asarray(labels)))
    return encodings, labels


### Tokenize Train, Validation, and Test Splits

In [None]:
# Tokenize each split. Note that the *_labels names are rebound here: they
# enter as the pandas Series produced by the split and leave as tensors
# returned by tokenize().
train_encodings, train_labels = tokenize(train_texts, train_labels)
val_encodings, val_labels = tokenize(val_texts, val_labels)
test_encodings, test_labels = tokenize(test_texts, test_labels)

###  CaptionClassifierDataset: Dataset for BERT-based Classification

In [None]:
class CaptionClassifierDataset(Dataset):
    """Map-style dataset pairing tokenizer output with per-example labels.

    Each item is a dict with the keys ``input_ids``, ``attention_mask`` and
    ``labels``, taken from position ``idx`` of the stored encodings and labels.
    """

    # Encoding fields copied into every sample, in output order.
    _FEATURE_KEYS = ("input_ids", "attention_mask")

    def __init__(self, encodings, labels):
        """Store the encodings mapping and the label sequence as given."""
        self.labels = labels
        self.encodings = encodings

    def __len__(self):
        """Return the number of examples, defined by the label sequence."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the sample dict for index ``idx``."""
        sample = {key: self.encodings[key][idx] for key in self._FEATURE_KEYS}
        sample["labels"] = self.labels[idx]
        return sample

In [None]:
# Wrap each split's encodings and labels in a Dataset so they can be batched.
train_dataset = CaptionClassifierDataset(train_encodings, train_labels)
val_dataset = CaptionClassifierDataset(val_encodings, val_labels)
test_dataset = CaptionClassifierDataset(test_encodings, test_labels)

In [None]:
# All three loaders share the same settings; only the training loader shuffles.
train_loader, val_loader, test_loader = (
    DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=2,
        pin_memory=True,
    )
    for dataset, shuffle in (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
)

###  CaptionClassifier: BERT-based Classifier for Caption Prediction

In [None]:
class CaptionClassifier(nn.Module):
    """Frozen ``bert-base-uncased`` encoder followed by dropout and a linear head.

    Only the linear head receives gradients: the encoder's parameters have
    ``requires_grad`` disabled and its forward pass runs under ``torch.no_grad``.
    """

    def __init__(self, num_classes=2, dropout=0.3):
        """Load the pretrained encoder, freeze it, and create the head.

        Args:
            num_classes: Number of output logits.
            dropout: Dropout probability applied to the pooled encoder output.
        """
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        # Freeze every encoder parameter in one call.
        self.bert.requires_grad_(False)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return the head's logits for a batch of token ids and attention masks."""
        # The encoder never needs gradients, so skip building its graph.
        with torch.no_grad():
            pooled = self.bert(
                input_ids=input_ids, attention_mask=attention_mask
            ).pooler_output
        logits = self.classifier(self.dropout(pooled))
        return logits

    def predict(self, dataloader):
        """Switch to eval mode and return predicted class indices as a NumPy array.

        Batches are moved to the notebook-level ``DEVICE`` before the forward
        pass; the model itself is not moved here.
        """
        self.eval()
        collected = []
        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Predicting"):
                logits = self(
                    batch["input_ids"].to(DEVICE),
                    batch["attention_mask"].to(DEVICE),
                )
                collected.extend(torch.argmax(logits, dim=1).cpu().numpy())
        return np.array(collected)

###  Evaluate Classifier: Accuracy, Precision, Recall, F1 Score


In [None]:
def evaluate_classifier(model, dataloader, DEVICE, metric="all"):
    """Run the model over a dataloader and score its predictions.

    Args:
        model: Callable as ``model(input_ids, attention_mask)`` returning logits.
        dataloader: Yields dict batches with ``input_ids``, ``attention_mask``
            and ``labels``.
        DEVICE: Device the batch tensors are moved to before the forward pass.
        metric: ``"accuracy"`` to return only the accuracy; any other value
            prints a classification report and returns the full metric dict.

    Returns:
        A float accuracy when ``metric == "accuracy"``; otherwise a dict with
        macro-averaged ``precision``, ``recall`` and ``f1`` plus ``accuracy``.
    """
    model.eval()
    predictions, targets = [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            logits = model(
                batch["input_ids"].to(DEVICE),
                batch["attention_mask"].to(DEVICE),
            )
            predictions.extend(torch.argmax(logits, dim=1).cpu().numpy())
            targets.extend(batch["labels"].cpu().numpy())

    accuracy = accuracy_score(targets, predictions)
    if metric == "accuracy":
        return accuracy

    precision, recall, f1, _ = precision_recall_fscore_support(
        targets, predictions, average="macro"
    )
    report = classification_report(
        targets, predictions, target_names=["SmolVLM", "Custom Model"]
    )
    print("Classification Report:\n", report)
    return {
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "accuracy": accuracy,
    }

###  Train Classifier: Training with Early Stopping and Validation Accuracy


In [None]:
def train_classifier(
    model,
    train_loader,
    val_loader,
    optimizer,
    criterion,
    scheduler,
    DEVICE,
    epochs=20,
    patience=4,
):
    """Train ``model`` with early stopping on validation accuracy.

    Args:
        model: Module called as ``model(input_ids, attention_mask)``.
        train_loader: Yields dict batches with ``input_ids``,
            ``attention_mask`` and ``labels`` for training.
        val_loader: Loader scored with ``evaluate_classifier`` after each epoch.
        optimizer: Optimizer stepped once per training batch.
        criterion: Loss called as ``criterion(outputs, labels)``.
        scheduler: Stepped once per epoch with the validation accuracy.
        DEVICE: Device the model and the batches are moved to.
        epochs: Maximum number of epochs.
        patience: Consecutive non-improving epochs tolerated before stopping.

    Returns:
        ``(best_val_acc, best_model_state)``: the highest validation accuracy
        seen and an independent copy of the model's state dict at that epoch.
        ``best_model_state`` is ``None`` only when no epoch ran.
    """
    model.to(DEVICE)
    best_val_acc = 0
    no_improve = 0
    best_model_state = None

    # Per-epoch curves; kept locally only, since callers unpack a 2-tuple.
    history = {"train_loss": [], "val_accuracy": []}

    epoch_pbar = tqdm(total=epochs, desc="Training epochs", position=0)
    for epoch in range(epochs):
        model.train()
        total_loss = 0

        batch_pbar = tqdm(
            total=len(train_loader),
            desc=f"Epoch {epoch + 1}/{epochs}",
            position=1,
            leave=False,
        )
        for batch in train_loader:
            optimizer.zero_grad()
            input_ids = batch["input_ids"].to(DEVICE)
            attention_mask = batch["attention_mask"].to(DEVICE)
            labels = batch["labels"].to(DEVICE)

            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Read the scalar once; it is used for both the total and the bar.
            batch_loss = loss.item()
            total_loss += batch_loss

            batch_pbar.update(1)
            batch_pbar.set_description(
                f"Epoch {epoch + 1}/{epochs} (loss: {batch_loss:.4f})"
            )
        batch_pbar.close()

        avg_train_loss = total_loss / len(train_loader)
        history["train_loss"].append(avg_train_loss)
        val_acc = evaluate_classifier(model, val_loader, DEVICE, metric="accuracy")
        history["val_accuracy"].append(val_acc)
        scheduler.step(val_acc)

        epoch_pbar.set_description(
            f"Training epochs | Loss: {avg_train_loss:.4f} | Val Acc: {val_acc:.4f}"
        )
        epoch_pbar.update(1)
        # Always snapshot the first epoch so a state is returned even when the
        # validation accuracy never rises above the initial 0.
        if best_model_state is None or val_acc > best_val_acc:
            best_val_acc = val_acc
            # state_dict() hands back references to the live parameters and a
            # shallow dict copy keeps those references, so later optimizer
            # steps would overwrite the "best" weights. Clone every tensor.
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
            no_improve = 0
        else:
            no_improve += 1
            if no_improve >= patience:
                print("Early stopping triggered.")
                break
    epoch_pbar.close()
    return best_val_acc, best_model_state

###  Hyperparameter Tuning: Search for Optimal Learning Rate and Dropout


In [None]:
def hyper_parameter_tuning():
    """Grid-search learning rate and dropout for the caption classifier.

    For every (learning rate, dropout) pair a fresh ``CaptionClassifier`` is
    trained on the notebook-level ``train_loader`` / ``val_loader`` with AdamW
    over the head's parameters, cross-entropy loss, and a
    ``ReduceLROnPlateau`` scheduler driven by validation accuracy. Weight
    decay is fixed at 0.001.

    Returns:
        ``(best_model_state, best_hyperparams, best_val_acc)`` for the
        configuration with the highest validation accuracy. The state is
        ``None`` and the hyperparameter dict empty if no configuration scored
        above 0.
    """
    learning_rates = [1e-5, 2e-5, 5e-5]
    dropouts = [0.3, 0.5]
    wd = 0.001
    # Materialise the grid once so its size and iteration order come from one place.
    grid = list(product(learning_rates, dropouts))

    results = []
    best_val_acc = 0
    best_model_state = None
    best_hyperparams = {}

    hp_pbar = tqdm(total=len(grid), desc="Hyperparameter tuning", position=0)
    for lr, dropout in grid:
        hp_pbar.set_description(
            f"Hyperparameter tuning | lr={lr}, dropout={dropout}, weight_decay={wd}"
        )

        model = CaptionClassifier(dropout=dropout)
        # Only the linear head is handed to the optimizer.
        optimizer = AdamW(model.classifier.parameters(), lr=lr, weight_decay=wd)
        criterion = nn.CrossEntropyLoss()
        scheduler = ReduceLROnPlateau(
            optimizer, mode="max", factor=0.5, patience=1, verbose=True
        )

        val_acc, model_state = train_classifier(
            model,
            train_loader,
            val_loader,
            optimizer,
            criterion,
            scheduler,
            DEVICE=DEVICE,
            epochs=20,
            patience=4,
        )

        trial = {"lr": lr, "dropout": dropout, "weight_decay": wd}
        results.append({**trial, "val_acc": val_acc})

        if val_acc > best_val_acc:
            best_val_acc, best_model_state, best_hyperparams = (
                val_acc,
                model_state,
                trial,
            )
        hp_pbar.update(1)
    hp_pbar.close()

    results_df = pd.DataFrame(results).sort_values(by="val_acc", ascending=False)
    print("\nAll hyperparameter results:")
    print(results_df)
    return best_model_state, best_hyperparams, best_val_acc

###  Hyperparameter Tuning: Identify Best Hyperparameters and Validation Accuracy


In [None]:
# Run the grid search and report the winning configuration.
best_model_state, best_hyperparams, best_val_acc = hyper_parameter_tuning()
# The f-prefix is required; without it the placeholder text is printed literally.
print(f"\nBest hyperparameters with best validation accuracy of {best_val_acc:.4f}:")
print(best_hyperparams)

### Final Model: CaptionClassifier with Best Hyperparameters


In [None]:
# Rebuild the architecture with the winning dropout, then restore the best
# weights found during tuning and move the model to the target device.
# NOTE(review): best_model_state is None when no configuration scored above 0
# validation accuracy; load_state_dict would fail in that case.
final_model = CaptionClassifier(dropout=best_hyperparams["dropout"])
final_model.load_state_dict(best_model_state)
final_model.to(DEVICE)

In [None]:
print("\nEvaluating on test set...")
test_metrics = evaluate_classifier(final_model, test_loader, DEVICE)
# Report each headline metric on its own line, in a fixed order.
for metric_title, metric_key in (
    ("Accuracy", "accuracy"),
    ("Precision", "precision"),
    ("Recall", "recall"),
    ("F1 Score", "f1"),
):
    print(f"Test {metric_title}: {test_metrics[metric_key]:.4f}")

### Plotting Confusion Matrix

In [None]:
def plot_confusion_matrix(y_true, y_pred, classes):
    """Draw the confusion matrix for the given labels and predictions.

    Args:
        y_true: Ground-truth class indices.
        y_pred: Predicted class indices, aligned with ``y_true``.
        classes: Display names for the classes, in label order.
    """
    matrix = confusion_matrix(y_true, y_pred)
    # Create the axes explicitly so the title is attached to the matrix plot.
    fig, ax = plt.subplots()
    ConfusionMatrixDisplay(confusion_matrix=matrix, display_labels=classes).plot(
        cmap="Blues", ax=ax
    )
    ax.set_title("Confusion Matrix")
    plt.show()


# Ground truth is the test split's label tensor; predictions come back from
# the model as a NumPy array in the same (unshuffled) loader order.
y_true = test_labels
y_pred = final_model.predict(test_loader)
plot_confusion_matrix(y_true, y_pred, classes=["SmolVLM", "Custom Model"])