<a href="https://colab.research.google.com/github/bhussn/SecSplitLLM/blob/main/SecSplitLLM/notebooks/gpt-2/Fine_Tune_bert_base_uncased.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Installs
!pip install torch
!pip install transformers
!pip install datasets
!pip install evaluate
!pip install wandb
!pip install accelerate

Found existing installation: torch 2.6.0+cu124
Uninstalling torch-2.6.0+cu124:
  Successfully uninstalled torch-2.6.0+cu124
Collecting torch
  Downloading torch-2.7.1-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (29 kB)
Collecting sympy>=1.13.3 (from torch)
  Downloading sympy-1.14.0-py3-none-any.whl.metadata (12 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.6.77 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.6.77 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.6.80 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.5.1.17 (from torch)
  Downloading nvidia_cudnn_cu12-9.5.1.17-py3-none-manylinux_2_28_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-c

In [None]:
!pip install torchvision

Collecting torch==2.6.0 (from torchvision)
  Downloading torch-2.6.0-cp311-cp311-manylinux1_x86_64.whl.metadata (28 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch==2.6.0->torchvision)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch==2.6.0->torchvision)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch==2.6.0->torchvision)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch==2.6.0->torchvision)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch==2.6.0->torchvision)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-c

In [None]:
import os
import time
import torch
import numpy as np
import wandb
import evaluate
import csv
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    GPT2ForSequenceClassification,
    Trainer,
    TrainingArguments,
    TrainerCallback
)

In [None]:
# This callback keeps track of per-epoch stats: wall-clock duration, GPU memory
# currently allocated, the latest training loss, and the validation accuracy.
# Each epoch becomes one CSV row (for offline use) and one W&B log call.
class SimpleLoggerCallback(TrainerCallback):
    """Per-epoch CSV + Weights & Biases logger for the Hugging Face ``Trainer``.

    ``Trainer`` calls ``on_epoch_end`` *before* running the end-of-epoch
    evaluation and does not pass ``logs`` to it, so the validation accuracy is
    not available there. The epoch's stats are therefore collected in
    ``on_epoch_end`` and kept as a pending row. That row is completed with the
    accuracy reported to ``on_evaluate`` and written out. If no evaluation
    follows an epoch, the row is written without accuracy at the next epoch
    start or at the end of training.

    Args:
        log_path: CSV file to (re)create. Missing parent directories are created.
    """

    def __init__(self, log_path="training_log.csv"):
        self.log_path = log_path
        self.epoch_start_time = None
        # Most recent training loss seen in Trainer logs (None until first log).
        self.last_train_loss = None
        # (epoch, duration_sec, gpu_mem_gb, train_loss) of the latest finished
        # epoch that has not been written out yet.
        self._pending_epoch = None
        # The log usually lives in the Trainer's output_dir, which may not exist
        # yet: this callback is constructed before the Trainer itself.
        log_dir = os.path.dirname(self.log_path)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        # Write the CSV header right away (truncating any previous log).
        with open(self.log_path, mode="w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["epoch", "duration_sec", "gpu_mem_allocated_gb", "train_loss", "eval_accuracy"])

    def on_epoch_begin(self, args, state, control, **kwargs):
        """Flush an epoch that had no evaluation, then start the epoch timer."""
        self._flush(eval_acc=None)
        self.epoch_start_time = time.time()

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Remember the training loss whenever the Trainer logs one."""
        if logs and "loss" in logs:
            self.last_train_loss = logs["loss"]

    def on_epoch_end(self, args, state, control, logs=None, **kwargs):
        """Record duration, GPU memory and loss for the epoch that just finished."""
        if self.epoch_start_time is not None:
            duration = time.time() - self.epoch_start_time
        else:
            duration = float("nan")
        gpu_mem = torch.cuda.memory_allocated() / (1024 ** 3) if torch.cuda.is_available() else 0
        # state.epoch is a float; round() avoids 1.9999... -> 1 truncation.
        epoch = round(state.epoch) if state.epoch is not None else None
        self._pending_epoch = (epoch, duration, gpu_mem, self.last_train_loss)

        # If a caller does supply logs with an accuracy, use it immediately.
        eval_acc = logs.get("eval_accuracy") if logs else None
        if eval_acc is not None:
            self._flush(eval_acc)

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Attach the evaluation accuracy to the pending epoch row and write it."""
        if self._pending_epoch is not None:
            self._flush((metrics or {}).get("eval_accuracy"))

    def on_train_end(self, args, state, control, **kwargs):
        """Write out the last epoch if no evaluation followed it."""
        self._flush(eval_acc=None)

    def _flush(self, eval_acc):
        """Print, W&B-log and CSV-append the pending epoch row (no-op if none)."""
        if self._pending_epoch is None:
            return
        epoch, duration, gpu_mem, train_loss = self._pending_epoch
        self._pending_epoch = None

        # train_loss is None until the Trainer has logged at least once
        # (logging_steps may exceed the number of steps in an epoch).
        loss_str = f"{train_loss:.4f}" if train_loss is not None else "n/a"
        print(f"Epoch {epoch} done in {duration:.2f}s | GPU mem: {gpu_mem:.2f} GB | Loss: {loss_str} | Val Acc: {eval_acc}")

        # wandb.log raises when no run is active (e.g. report_to != "wandb").
        if wandb.run is not None:
            wandb.log({
                f"epoch_{epoch}_duration_sec": duration,
                f"epoch_{epoch}_gpu_mem_allocated_GB": gpu_mem,
                "train_loss": train_loss,
                "eval_accuracy": eval_acc,
            })

        with open(self.log_path, mode="a", newline="") as f:
            writer = csv.writer(f)
            writer.writerow([epoch, duration, gpu_mem, train_loss, eval_acc])

# Simple metric calculation — just accuracy here for SST-2
def compute_metrics(eval_pred):
    """Return classification accuracy for a ``Trainer`` evaluation pass.

    Args:
        eval_pred: ``(logits, labels)`` pair, e.g. a ``transformers.EvalPrediction``.
            ``logits`` may also be a tuple/list of arrays; its first element is
            then taken as the classification logits.

    Returns:
        ``{"accuracy": float}``: the fraction of argmax predictions equal to the
        labels (NaN for an empty batch).
    """
    logits, labels = eval_pred
    # If the model returns extra outputs, the logits can arrive wrapped in a tuple.
    if isinstance(logits, (tuple, list)):
        logits = logits[0]
    labels = np.asarray(labels)
    if labels.size == 0:
        return {"accuracy": float("nan")}
    preds = np.argmax(logits, axis=-1)
    # Computed directly instead of calling evaluate.load("accuracy") on every
    # evaluation, which rebuilt the metric module each time it was called.
    return {"accuracy": float(np.mean(preds == labels))}

# Tokenize input sentences: pad and truncate every row to exactly 128 tokens so
# all examples share one fixed length.
def tokenize_function(examples, tokenizer):
    """Tokenize the ``"sentence"`` column of a batch of examples.

    Args:
        examples: batch mapping passed by ``Dataset.map(..., batched=True)``;
            must contain a ``"sentence"`` key.
        tokenizer: callable Hugging Face tokenizer.

    Returns:
        The tokenizer's output for the batch, unchanged.
    """
    encode_options = dict(truncation=True, padding="max_length", max_length=128)
    sentences = examples["sentence"]
    return tokenizer(sentences, **encode_options)

In [None]:
def main():
    """Fine-tune GPT-2 for binary sentiment classification on GLUE SST-2.

    Steps:
    - load SST-2 (optionally subsampled) and tokenize it;
    - train with the Hugging Face ``Trainer``, logging to W&B and a per-epoch CSV;
    - run a final evaluation on the validation split;
    - save the model and tokenizer under ``args.output_dir``.
    """
    import inspect

    # Manually define arguments for Google Colab (no argparse in a notebook)
    class Args:
        sample_fraction = 0.001  # fraction of each split to keep; 1.0 = full data
        learning_rate = 2e-5
        batch_size = 8
        epochs = 2
        output_dir = "./results"

    args = Args()

    # Fixed parameters — locked in for GPT2 + SST-2 setup
    model_name = "gpt2"
    dataset_name = "glue"
    dataset_config = "sst2"
    run_name = "gpt2-sst2"

    print("Logging into WandB so we can track this run...")
    wandb.login()
    os.environ["WANDB_PROJECT"] = run_name

    print(f"Loading dataset: {dataset_name} with config {dataset_config}")
    dataset = load_dataset(dataset_name, dataset_config)

    # If you want to speed up experimentation, just use a fraction of the data
    if args.sample_fraction < 1.0:
        for split in ["train", "validation"]:
            split_len = len(dataset[split])
            # int() rounds down, so a small split times a small fraction can be
            # 0 rows; keep at least one example (but never more than exist).
            n_keep = min(split_len, max(1, int(args.sample_fraction * split_len)))
            dataset[split] = dataset[split].shuffle(seed=42).select(range(n_keep))

    print(f"Loading tokenizer and model '{model_name}'")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # GPT2 doesn't have a padding token by default, so set it to eos token for padding
    tokenizer.pad_token = tokenizer.eos_token

    model = GPT2ForSequenceClassification.from_pretrained(model_name, num_labels=2)
    model.config.pad_token_id = tokenizer.pad_token_id

    print("Tokenizing dataset... this might take a minute")
    tokenized_dataset = dataset.map(lambda x: tokenize_function(x, tokenizer), batched=True)

    # The CSV logger callback writes into output_dir as soon as it is constructed.
    os.makedirs(args.output_dir, exist_ok=True)

    # `evaluation_strategy` was renamed to `eval_strategy` in newer transformers
    # releases (the old keyword is rejected there); use whichever this install accepts.
    training_arg_names = inspect.signature(TrainingArguments.__init__).parameters
    eval_strategy_kw = "eval_strategy" if "eval_strategy" in training_arg_names else "evaluation_strategy"

    # Setup training parameters — tweak these to your liking
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        save_strategy="epoch",
        logging_dir=os.path.join(args.output_dir, "logs"),
        logging_steps=10,
        report_to="wandb",
        run_name=run_name,
        **{eval_strategy_kw: "epoch"},
    )

    # Put everything together for Trainer API
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset["validation"],
        compute_metrics=compute_metrics,
        callbacks=[SimpleLoggerCallback(log_path=os.path.join(args.output_dir, "training_log.csv"))],
    )

    # Run the trainer and print results
    trainer.train()

    print("Training finished. Running final evaluation...")
    results = trainer.evaluate()
    print("Evaluation results:", results)

    print(f"Saving the fine-tuned model and tokenizer in {args.output_dir}")
    model.save_pretrained(os.path.join(args.output_dir, "model"))
    tokenizer.save_pretrained(os.path.join(args.output_dir, "tokenizer"))


# Top-level call (previously it sat inside main() itself, recursing forever).
# A Jupyter/Colab kernel runs cells with __name__ == "__main__".
if __name__ == "__main__":
    main()