# Imports

In [None]:
import os
from dataclasses import dataclass
from typing import Dict, List, Optional, Any

import torch
from datasets import load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

import evaluate

import itertools
import json
from copy import deepcopy

# Configuration & Setup

In [None]:
# Checkpoint to fine-tune: a 4B-parameter open-weight causal LM.
MODEL_NAME = "Qwen/Qwen3-4B"
# Canonical AG News release on the Hub. It provides the "text" column and the
# ClassLabel "label" feature (with `.names`) that the dataset and preprocessing
# cells of this notebook read.
DATASET_NAME = "fancyzhx/ag_news"

# Turn off the tokenizers library's internal parallelism for this process.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Dataset

In [None]:
# Load every split of the configured dataset (the notebook uses "train" and "test").
dataset = load_dataset(DATASET_NAME)

# The class names are read from the label feature of the training split; their
# count is what the classification head is sized with.
label_feature = dataset["train"].features["label"]
label_names = label_feature.names
num_labels = len(label_names)

# For a small, quick run, you can optionally subsample.
# dataset["train"] = dataset["train"].shuffle(seed=42).select(range(2000))
# dataset["test"] = dataset["test"].shuffle(seed=42).select(range(500))

# Tokenizer

In [None]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Padding requires a pad token; when the tokenizer does not define one,
# reuse the end-of-sequence token for that purpose.
needs_pad_token = tokenizer.pad_token is None
if needs_pad_token:
    tokenizer.pad_token = tokenizer.eos_token

# Model

In [None]:
# Load the checkpoint with a freshly initialised classification head of
# `num_labels` outputs; bfloat16 weights on GPU, float32 on CPU.
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=num_labels,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
)

# The sequence-classification head classifies from the last non-padding token,
# so the model config has to know the pad id. Left unset, the forward pass
# refuses batches larger than one ("Cannot handle batch sizes > 1 if no
# padding token is defined").
model.config.pad_token_id = tokenizer.pad_token_id

model.to(device)

# Preprocess

In [None]:
# Token budget per example; also reused by the inference helper.
max_length = 256


def preprocess(examples):
    """Tokenize the "text" field of a batch of examples.

    Every sequence is truncated and padded to exactly ``max_length`` tokens.
    The tokenizer's output for the batch is returned unchanged.
    """
    tokenize_options = {
        "padding": "max_length",
        "truncation": True,
        "max_length": max_length,
    }
    return tokenizer(examples["text"], **tokenize_options)

# Tokenize all splits in batches.
encoded_dataset = dataset.map(preprocess, batched=True)

# Keep only the model inputs and the label; every other column is dropped.
columns_to_keep = ("input_ids", "attention_mask", "label")
columns_to_drop = [
    name
    for name in encoded_dataset["train"].column_names
    if name not in columns_to_keep
]
encoded_dataset = encoded_dataset.remove_columns(columns_to_drop)

# Have the dataset return PyTorch tensors.
encoded_dataset = encoded_dataset.with_format("torch")

train_dataset = encoded_dataset["train"]
# The test split doubles as the evaluation set; swap in your own split if needed.
eval_dataset = encoded_dataset["test"]

# Metrics

In [None]:
# Metric objects are loaded once and reused for every evaluation pass.
accuracy = evaluate.load("accuracy")
f1 = evaluate.load("f1")


def compute_metrics(eval_pred):
    """Compute accuracy and macro-averaged F1 for one evaluation pass.

    ``eval_pred`` unpacks into ``(logits, labels)``. The predicted class is
    the argmax over the last axis of the logits. The returned dict merges the
    accuracy result with the macro-F1 result.
    """
    logits, labels = eval_pred
    predicted_ids = logits.argmax(axis=-1)

    metrics = {}
    metrics.update(accuracy.compute(predictions=predicted_ids, references=labels))
    metrics.update(
        f1.compute(predictions=predicted_ids, references=labels, average="macro")
    )
    return metrics

# Hyperparameter Tuning

## Training Arguments

In [None]:
# Baseline training configuration; the grid search overrides individual
# fields of it for each run.
training_arguments = TrainingArguments(
    output_dir="./qwen3_agnews_grid",
    # `evaluation_strategy` was renamed to `eval_strategy`; the old keyword is
    # deprecated and no longer accepted by recent transformers releases
    # (including those that support Qwen3).
    eval_strategy="epoch",
    save_strategy="no",
    logging_strategy="steps",
    logging_steps=100,
    num_train_epochs=1,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    learning_rate=5e-5,
    weight_decay=0.01,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    gradient_accumulation_steps=4,
    # The model is loaded in bfloat16 on GPU, so train with bf16 autocast.
    # fp16 mixed precision relies on a gradient scaler that does not operate
    # on bfloat16 gradients.
    bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
    report_to="none",
)

# Keyword arguments shared by every Trainer built during the search.
trainer_kwargs = dict(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    # `processing_class` supersedes Trainer's deprecated `tokenizer` argument.
    processing_class=tokenizer,
    compute_metrics=compute_metrics,
)

## Search Space

In [None]:
# Hyperparameter grid: every combination of these values is trained once.
# Keys must be TrainingArguments field names.
search_space = dict(
    learning_rate=[5e-5, 1e-4],
    num_train_epochs=[1, 2, 3],
    per_device_train_batch_size=[4, 8],
)

## Helper Functions

In [None]:
def run_single_experiment(
    base_training_args: TrainingArguments,
    trainer_cls,
    trainer_kwargs: Dict[str, Any],
    hp_config: Dict[str, Any],
) -> Dict[str, Any]:
    """Train and evaluate once with ``hp_config`` applied over the base arguments.

    Args:
        base_training_args: TrainingArguments holding the default values.
            It is not modified; a new instance is built for the run.
        trainer_cls: Trainer class to instantiate (usually ``Trainer``).
        trainer_kwargs: Keyword arguments forwarded to ``trainer_cls``
            (model, datasets, metric function, ...). ``args`` is supplied
            here and must not be part of it.
        hp_config: TrainingArguments field name -> value overrides for this
            run (e.g. learning rate, epochs, batch size).

    Returns:
        A dict with the keys ``hp_config``, ``train_runtime`` (seconds as
        reported by the trainer, or None), ``train_samples`` (None when the
        trainer does not report it) and ``eval_metrics``.
    """
    # 1) Rebuild the arguments from a plain dict so the overrides never touch
    #    the shared base object.
    args_dict = base_training_args.to_dict()
    args_dict.update(hp_config)
    training_args = TrainingArguments(**args_dict)

    # 2) Fresh Trainer for this run; the caller decides whether the model in
    #    `trainer_kwargs` is a fresh copy.
    trainer = trainer_cls(args=training_args, **trainer_kwargs)

    # 3) Train, then evaluate on the trainer's evaluation dataset.
    train_output = trainer.train()
    eval_metrics = trainer.evaluate()

    # TrainOutput is (global_step, training_loss, metrics) and has no
    # `training_time` attribute; the wall-clock duration is reported in the
    # metrics dict under "train_runtime".
    train_metrics = train_output.metrics or {}

    return {
        "hp_config": hp_config,
        "train_runtime": train_metrics.get("train_runtime"),
        "train_samples": train_metrics.get("train_samples"),
        "eval_metrics": eval_metrics,
    }

def grid_search_hyperparams(
    base_training_args: TrainingArguments,
    trainer_cls,
    trainer_kwargs: Dict[str, Any],
    search_space: Dict[str, List[Any]],
    results_path: str = "grid_search_results.jsonl",
) -> List[Dict[str, Any]]:
    """Run one experiment per point of the Cartesian product of ``search_space``.

    Args:
        base_training_args: Default TrainingArguments shared by all runs.
        trainer_cls: Trainer class to instantiate for each run.
        trainer_kwargs: Keyword arguments for the trainer. When it contains a
            ``model``, each run trains its own deep copy, so the model passed
            in is left untouched and every run starts from the same weights.
        search_space: Hyperparameter name -> list of candidate values, e.g.
            {
                "learning_rate": [5e-5, 1e-4],
                "num_train_epochs": [1, 2, 3],
                "per_device_train_batch_size": [4, 8],
            }
        results_path: JSON-lines file that receives one record per finished
            run. It is overwritten, and its parent directory is created if
            missing.

    Returns:
        The per-run result dicts, in the order the runs were executed.
    """
    import gc  # local import: only needed for the clean-up between runs

    keys = list(search_space)
    all_results: List[Dict[str, Any]] = []

    os.makedirs(os.path.dirname(results_path) or ".", exist_ok=True)

    with open(results_path, "w", encoding="utf-8") as f_out:
        for combo in itertools.product(*(search_space[k] for k in keys)):
            hp_config = dict(zip(keys, combo))
            print("\n=== Running config:", hp_config, "===")

            # Only the model carries state that training mutates. Copying the
            # datasets, tokenizer and metric function on every run is wasted
            # work, so the dict is copied shallowly and just the model deeply.
            run_kwargs = dict(trainer_kwargs)
            if run_kwargs.get("model") is not None:
                run_kwargs["model"] = deepcopy(run_kwargs["model"])

            try:
                result = run_single_experiment(
                    base_training_args=base_training_args,
                    trainer_cls=trainer_cls,
                    trainer_kwargs=run_kwargs,
                    hp_config=hp_config,
                )
            finally:
                # Drop this run's model copy and return cached GPU memory
                # before the next copy is made, even if the run failed.
                del run_kwargs
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            # Persist each result as one JSON line, flushed immediately so a
            # later crash does not lose finished runs.
            f_out.write(json.dumps(result) + "\n")
            f_out.flush()

            all_results.append(result)

    return all_results

# Train

In [None]:
# Train and evaluate every configuration of the grid.
results = grid_search_hyperparams(
    training_arguments,
    Trainer,
    trainer_kwargs,
    search_space,
    results_path="grid_search_results.jsonl",
)


def _eval_accuracy(run_result):
    """Return a run's evaluation accuracy, or 0.0 when it was not reported."""
    return run_result["eval_metrics"].get("eval_accuracy", 0.0)


# Select the run with the highest evaluation accuracy.
best = max(results, key=_eval_accuracy)
print("Best config:", best["hp_config"])
print("Best metrics:", best["eval_metrics"])


# Inference

In [None]:
# Map class index -> human-readable label name.
id2label = dict(enumerate(label_names))

def infer(texts: List[str]):
    """Classify ``texts`` and return one label name per input, in order.

    Uses the module-level ``tokenizer``, ``model``, ``device`` and
    ``id2label``. Inputs are truncated to ``max_length`` tokens and padded to
    the longest sequence in the batch. An empty list returns ``[]`` without
    calling the tokenizer or the model.

    NOTE(review): the grid search trains deep copies of the model, so the
    module-level ``model`` used here does not appear to carry fine-tuned
    weights — confirm which model should serve predictions.
    """
    if not texts:
        return []

    encodings = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    ).to(device)

    # Make sure dropout and similar layers are in inference mode.
    model.eval()
    with torch.no_grad():
        logits = model(**encodings).logits

    preds = logits.argmax(dim=-1).cpu().tolist()
    return [id2label[p] for p in preds]

In [None]:
# Two short headlines used as a smoke test for the inference helper.
markets_text = "Stocks rose today as the market reacted positively to the latest earnings reports."
match_text = "The team secured a last-minute victory in the championship game."
example_texts = [markets_text, match_text]

predicted_labels = infer(example_texts)
print(predicted_labels)