<a href="https://colab.research.google.com/github/calvinli2024/CS614-genai/blob/main/main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Packages

In [None]:
%pip install evaluate peft

In [None]:
import os
from dataclasses import dataclass
from typing import Dict, List, Optional, Any

import torch
from datasets import load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    AutoModelForCausalLM,
    pipeline
)

import evaluate
from evaluate import evaluator

import itertools
import json
from copy import deepcopy

from peft import LoraConfig, TaskType, get_peft_model

In [None]:
# Disable the tokenizers library's internal parallelism for this process
# (presumably to avoid its fork warning once DataLoader workers are spawned).
os.environ.update(TOKENIZERS_PARALLELISM="false")

# Dataset

In [None]:
raw_dataset = load_dataset("sh0416/ag_news")

# Shuffle the original train split once and carve disjoint train/valid
# subsets out of the same shuffled order, so the two never share examples.
shuffled_train = raw_dataset["train"].shuffle(seed=42)

dataset = {
    "train": shuffled_train.select(range(5000)),
    "valid": shuffled_train.select(range(5000, 6000)),
    "test": raw_dataset["test"].shuffle(seed=42).select(range(2000)),
}

# The classifier needs one output per distinct class, not one per example:
# collect the distinct raw label values seen in any split.
label_names = sorted(set().union(*(split["label"] for split in dataset.values())))
num_labels = len(label_names)

# Remap raw label values onto contiguous ids 0..num_labels-1, which the
# classification loss requires (the raw values may not start at 0).
label2id = {label: idx for idx, label in enumerate(label_names)}
dataset = {
    name: split.map(lambda row: {"label": label2id[row["label"]]})
    for name, split in dataset.items()
}

# Tokenizer

In [None]:
tokenizer = AutoTokenizer.from_pretrained(
    "Qwen/Qwen3-0.6B",
)

# Fixed-length, padded batches need a padding token. Qwen is a causal LM whose
# tokenizer may not define one, so reuse the end-of-sequence token in that case.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Model

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# LoRA adapter settings: rank-8 update matrices scaled by alpha=32, with
# dropout on the adapter path; inference_mode=False keeps them trainable.
lora_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    inference_mode=False,
)

# Sequence-classification head on top of Qwen3, loaded in bfloat16 on GPU and
# float32 on CPU; pad_token_id is passed so padded batches can be handled.
model = get_peft_model(
    AutoModelForSequenceClassification.from_pretrained(
        "Qwen/Qwen3-0.6B",
        num_labels=num_labels,
        torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
        pad_token_id=tokenizer.pad_token_id,
    ),
    lora_config,
)

model.to(device)

# Preprocess

In [None]:
max_length = 256


def preprocess(dataset_split, select_columns: List[str]):
    """Build the model-ready view of one AG News split.

    Concatenates ``title`` and ``description`` into a ``text`` column,
    tokenizes it (padded/truncated to ``max_length``) when any tokenizer
    output column is requested, keeps only ``select_columns`` and returns
    the split formatted for PyTorch.

    Args:
        dataset_split: a ``datasets.Dataset`` with ``title`` and
            ``description`` columns.
        select_columns: names of the columns to keep in the result.

    Returns:
        The filtered split with ``with_format("torch")`` applied.
    """

    def add_text(row):
        # Model input is the headline followed by the article description.
        return {"text": row["title"] + " " + row["description"]}

    def run_tokenizer(batch):
        return tokenizer(
            batch["text"],
            padding="max_length",
            truncation=True,
            max_length=max_length,
        )

    processed = dataset_split.map(add_text)

    # Tokenizing is only worth its cost when a tokenizer output is kept; a
    # split selected for raw text (e.g. the test split) skips it entirely.
    if set(tokenizer.model_input_names).intersection(select_columns):
        processed = processed.map(run_tokenizer, batched=True)

    processed = processed.remove_columns(
        [col for col in processed.column_names if col not in select_columns]
    )

    return processed.with_format("torch")


train_dataset = preprocess(dataset["train"], ["input_ids", "attention_mask", "label"])
valid_dataset = preprocess(dataset["valid"], ["input_ids", "attention_mask", "label"])
test_dataset = preprocess(dataset["test"], ["label", "text"])

# Metrics

In [None]:
accuracy = evaluate.load("accuracy")
f1 = evaluate.load("f1")


def compute_metrics(eval_pred):
    """Compute accuracy and macro-averaged F1 for a Trainer evaluation pass.

    Args:
        eval_pred: ``(predictions, label_ids)`` pair supplied by ``Trainer``;
            predictions hold per-class scores with classes on the last axis.

    Returns:
        Dict merging the ``accuracy`` and macro ``f1`` metric results.
    """
    logits, labels = eval_pred
    # Models that return extra outputs next to the logits hand them over as a
    # tuple; the logits are its first element.
    if isinstance(logits, tuple):
        logits = logits[0]

    preds = logits.argmax(axis=-1)
    results = accuracy.compute(predictions=preds, references=labels)
    # Macro averaging weights every class equally (the default "binary"
    # average would reject multi-class labels).
    results.update(
        f1.compute(predictions=preds, references=labels, average="macro")
    )

    return results

# Hyperparameter Tuning

## Training Arguments

In [None]:
# Shared defaults for every grid-search run; individual runs override fields
# such as learning_rate and num_train_epochs.
training_arguments = TrainingArguments(
    # Output / reporting
    output_dir="./qwen3_agnews_grid",
    report_to="none",
    # Evaluate and checkpoint once per epoch, log every 100 steps
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="steps",
    logging_steps=100,
    # Batching: 8 per device, 4 accumulation steps per optimizer update
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=4,
    # Optimization schedule
    num_train_epochs=1,
    learning_rate=5e-5,
    weight_decay=0.01,
    lr_scheduler_type="cosine",
    warmup_steps=0,  # used instead of the deprecated warmup_ratio
    # fp16 AMP stays off: enabling it clashed with the bfloat16 model weights
    fp16=False,
)

trainer_kwargs = {
    "model": model,
    "train_dataset": train_dataset,
    "eval_dataset": valid_dataset,
    "compute_metrics": compute_metrics,
}

## Search Space

In [None]:
# Full grid: every learning rate paired with every epoch count, indexed
# 0..N-1 in learning-rate-major order.
hyperparameters = {
    idx: {"learning_rate": learning_rate, "num_train_epochs": num_train_epochs}
    for idx, (learning_rate, num_train_epochs) in enumerate(
        itertools.product([1e-4, 1e-3], [1, 2, 3])
    )
}


## Helper Functions

In [None]:
def run_single_experiment(
    base_training_args: TrainingArguments,
    trainer_cls,
    trainer_kwargs: Dict[str, Any],
    hp_config: Dict[str, float|int],
) -> Dict[str, Any]:
    """Train and evaluate one model with ``hp_config`` overriding the base arguments.

    Args:
        base_training_args: training-arguments object holding the shared
            defaults; it is not modified.
        trainer_cls: trainer class to instantiate, usually ``Trainer``.
        trainer_kwargs: keyword arguments forwarded to ``trainer_cls`` besides
            ``args`` (e.g. model, train_dataset, eval_dataset, compute_metrics).
        hp_config: field overrides for this run, e.g.
            ``{"learning_rate": 1e-4, "num_train_epochs": 2}``.

    Returns:
        Dict with the run's ``hp_config``, the number of training samples,
        the training metrics and the final evaluation metrics.
    """
    # 1) Clone the arguments through their dict form and apply the overrides.
    args_dict = base_training_args.to_dict()
    # to_dict() masks *_token fields (e.g. hub_token) with placeholder
    # strings; restore the real values so they are not fed back as tokens.
    for key in args_dict:
        if key.endswith("_token"):
            args_dict[key] = getattr(base_training_args, key)
    args_dict.update(hp_config)

    # Rebuild with the base object's own class so subclasses keep their
    # extra fields.
    training_args = type(base_training_args)(**args_dict)

    # 2) Fresh Trainer per run.
    trainer = trainer_cls(
        args=training_args,
        **trainer_kwargs,
    )

    # 3) Train, then evaluate the final weights.
    train_output = trainer.train()
    eval_metrics = trainer.evaluate()

    # Trainer's training metrics carry no "train_samples" entry, so fall back
    # to the size of the training set when it has one.
    train_samples = train_output.metrics.get("train_samples")
    train_dataset = trainer_kwargs.get("train_dataset")
    if train_samples is None and hasattr(train_dataset, "__len__"):
        train_samples = len(train_dataset)

    return {
        "hp_config": hp_config,
        "train_samples": train_samples,
        "train_metrics": train_output.metrics,
        "eval_metrics": eval_metrics,
    }

def _save_trained_model(trained_model, save_dir: str) -> None:
    """Save ``trained_model`` to ``save_dir``, merging LoRA adapters first if present."""
    # A PEFT model saved as-is stores only adapter weights; merging produces a
    # standalone checkpoint that AutoModelForSequenceClassification can load.
    if hasattr(trained_model, "merge_and_unload"):
        trained_model = trained_model.merge_and_unload()
    trained_model.save_pretrained(save_dir)


def grid_search_hyperparams(
    base_training_args: TrainingArguments,
    trainer_cls,
    trainer_kwargs: Dict[str, Any],
    hyperparameters: Dict[int, Dict[str, int|float]],
    results_path: str = "grid_search_results.jsonl",
) -> Dict[int, Dict[str, Any]]:
    """Train and evaluate one model per hyperparameter configuration.

    For every ``idx -> config`` entry, trains on a fresh deep copy of
    ``trainer_kwargs``, so every run starts from the same initial model. The
    trained model is saved to ``<base_training_args.output_dir>/grid-config-<idx>``.
    The run's result, with that path under ``"model_path"``, is appended as
    one JSON line to ``results_path``.

    Returns:
        Mapping from configuration index to that run's result dict.
    """
    import gc

    all_results: Dict[int, Dict[str, Any]] = {}

    os.makedirs(os.path.dirname(results_path) or ".", exist_ok=True)

    with open(results_path, "w", encoding="utf-8") as f_out:
        for idx, combo in hyperparameters.items():
            print("\n=== Running config:", combo, "===")

            run_kwargs = deepcopy(trainer_kwargs)

            result = run_single_experiment(
                base_training_args=base_training_args,
                trainer_cls=trainer_cls,
                trainer_kwargs=run_kwargs,
                hp_config=combo,
            )

            # The trainer trained the model object it was handed, so
            # run_kwargs["model"] now holds this configuration's weights.
            save_dir = os.path.join(base_training_args.output_dir, f"grid-config-{idx}")
            _save_trained_model(run_kwargs["model"], save_dir)
            result["model_path"] = save_dir

            # Persist each result as one JSON line
            f_out.write(json.dumps(result) + "\n")
            f_out.flush()

            all_results[idx] = result

            # Release this run's model copy before the next configuration.
            del run_kwargs
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    return all_results

# Train

In [None]:
results = grid_search_hyperparams(
    base_training_args=training_arguments,
    trainer_cls=Trainer,
    trainer_kwargs=trainer_kwargs,
    hyperparameters=hyperparameters,
    results_path="grid_search_results.jsonl",
)

# Pick the configuration with the highest validation accuracy; `best` is an
# (index, result) pair from `results`.
best = max(
    results.items(),
    key=lambda item: item[1]["eval_metrics"].get("eval_accuracy", 0.0),
)

print("Best index:", best[0])
print("Best config:", best[1]["hp_config"])
print("Best metrics:", best[1]["eval_metrics"])

# Inference

In [None]:
# Load the merged model saved for the best grid-search configuration.
best_model_dir = os.path.join(training_arguments.output_dir, f"grid-config-{best[0]}")

model = AutoModelForSequenceClassification.from_pretrained(
    best_model_dir,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    pad_token_id=tokenizer.pad_token_id,
)

# A pipeline built from a model object cannot guess its tokenizer, so pass it
# explicitly, and run on the same device used for training.
classifier = pipeline("text-classification", model=model, tokenizer=tokenizer, device=device)

results = evaluator("text-classification").compute(
    model_or_pipeline=classifier,
    # Plain Python values rather than torch tensors for the metric inputs.
    data=test_dataset.with_format(None),
    # "f1" defaults to average="binary", which rejects multi-class labels, and
    # the evaluator cannot forward an `average` argument; accuracy works with
    # bootstrap resampling.
    metric="accuracy",
    input_column="text",
    label_column="label",
    # The pipeline predicts label names (e.g. "LABEL_0"); map them to ids.
    label_mapping=model.config.label2id,
    strategy="bootstrap",
    n_resamples=200,
)

print(results)