In [None]:
# finetune_full.py
from __future__ import annotations

import inspect
import math
from typing import Any, Dict, List, Optional, Tuple

import torch
from transformers import (
    Trainer,
    TrainingArguments,
    PreTrainedModel,
    PreTrainedTokenizerBase,
)


def _assert_cols(dataset, *cols: str) -> None:
    """
    Check that every requested column exists on the dataset.

    Args:
      dataset: object exposing a ``column_names`` collection.
      *cols: column names that must be present.

    Raises:
      ValueError: if one or more names are absent; the message lists the
        absent names and the columns that are available.
    """
    absent = []
    for name in cols:
        if name not in dataset.column_names:
            absent.append(name)

    if not absent:
        return

    raise ValueError(f"Dataset missing columns: {absent}. Available: {dataset.column_names}")


def _ensure_padding(tokenizer: PreTrainedTokenizerBase) -> None:
    """
    Make sure the tokenizer has a pad token, reusing EOS when it has none.

    The tokenizer is modified in place. A tokenizer that already reports a
    ``pad_token_id`` is left untouched.

    Raises:
      ValueError: if the tokenizer has neither a pad token id nor an EOS
        token id to fall back on.
    """
    # Nothing to do when a pad token is already configured.
    if tokenizer.pad_token_id is not None:
        return

    if tokenizer.eos_token_id is None:
        raise ValueError("Tokenizer has no pad_token_id and no eos_token_id to reuse as padding.")

    # Batching needs some pad token; reuse EOS as the padding token.
    tokenizer.pad_token = tokenizer.eos_token


def _build_chat_text(
    tokenizer: PreTrainedTokenizerBase,
    user_prompt: str,
    assistant_response: str,
) -> str:
    """
    Render one (user, assistant) exchange as a single training string.

    When the tokenizer exposes ``apply_chat_template`` and has a non-empty
    ``chat_template``, the exchange is rendered through that template
    (untokenized, without a trailing generation prompt). Otherwise a plain
    "User: ... / Assistant: ..." layout is used.
    """
    has_template = hasattr(tokenizer, "apply_chat_template") and getattr(
        tokenizer, "chat_template", None
    )
    if not has_template:
        # No usable template: fall back to the simple two-line format.
        return f"User: {user_prompt}\nAssistant: {assistant_response}"

    conversation = [
        {"role": "user", "content": user_prompt},
        {"role": "assistant", "content": assistant_response},
    ]
    return tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=False,
    )


def _make_tokenize_fn(
    tokenizer: PreTrainedTokenizerBase,
    name_user_prompt: str,
    name_response: str,
    max_length: int,
):
    """
    Build a batched tokenization function suitable for ``dataset.map``.

    IMPORTANT:
      - Do NOT create 'labels' here (prevents nested/ragged label issues).
      - Return only what the tokenizer produces (input_ids/attention_mask);
        the collator pads each batch and creates labels.

    Args:
      tokenizer: tokenizer used both to render the chat text and to encode it.
      name_user_prompt: batch key holding the user prompts.
      name_response: batch key holding the assistant responses.
      max_length: truncation length passed to the tokenizer.

    Returns:
      A function mapping a batch (dict of column name -> list of values) to
      the tokenizer's encoding of the rendered texts.
    """
    def tokenize_batch(batch: Dict[str, List[Any]]) -> Dict[str, Any]:
        texts = [
            _build_chat_text(tokenizer, up, rsp)
            for up, rsp in zip(batch[name_user_prompt], batch[name_response])
        ]

        # Same condition _build_chat_text uses to choose the template path.
        # Template-rendered text already carries its own special tokens, so
        # adding them again at encode time would duplicate them (e.g. a double
        # BOS). The plain-text fallback still needs them added here.
        uses_template = bool(
            hasattr(tokenizer, "apply_chat_template")
            and getattr(tokenizer, "chat_template", None)
        )

        return tokenizer(
            texts,
            truncation=True,
            max_length=max_length,
            padding=False,  # leave rows ragged; the collator pads per batch
            add_special_tokens=not uses_template,
            return_attention_mask=True,
        )

    return tokenize_batch


def _training_args_compat(**kwargs) -> TrainingArguments:
    """
    Build TrainingArguments in a way that tolerates transformers version drift.

    transformers version compatibility:
      - some versions use evaluation_strategy, some use eval_strategy; the
        value is moved to whichever spelling the installed version accepts,
        preferring ``eval_strategy`` when both are accepted.
    Also filters unknown keys to avoid TypeError.

    Args:
      **kwargs: candidate keyword arguments for ``TrainingArguments``.

    Returns:
      A ``TrainingArguments`` built from the accepted subset of ``kwargs``.
    """
    # 'self' shows up in the __init__ signature but is never a valid keyword.
    accepted = set(inspect.signature(TrainingArguments.__init__).parameters) - {"self"}

    legacy, modern = "evaluation_strategy", "eval_strategy"

    # Prefer the modern name whenever it is accepted. An explicitly supplied
    # modern value wins over the legacy one (setdefault keeps it).
    if legacy in kwargs and modern in accepted:
        kwargs.setdefault(modern, kwargs.pop(legacy))

    # Reverse direction: caller used the modern name on a version that only
    # knows the legacy one.
    if modern in kwargs and modern not in accepted and legacy in accepted:
        kwargs.setdefault(legacy, kwargs.pop(modern))

    filtered = {key: value for key, value in kwargs.items() if key in accepted}
    return TrainingArguments(**filtered)


def causal_lm_collator(features: List[Dict[str, Any]], tokenizer: PreTrainedTokenizerBase) -> Dict[str, torch.Tensor]:
    """
    Collate tokenized examples into one padded batch with causal-LM labels.

    Steps:
      - pad the examples together via ``tokenizer.pad`` (PyTorch tensors)
      - copy input_ids into labels
      - set labels to -100 wherever attention_mask is 0

    Building labels here, after padding, avoids the
    'Unable to create tensor' / nested labels issues.
    """
    padded = tokenizer.pad(features, padding=True, return_tensors="pt")

    # masked_fill returns a new tensor, so input_ids itself is left intact.
    ignore_positions = padded["attention_mask"] == 0
    padded["labels"] = padded["input_ids"].masked_fill(ignore_positions, -100)
    return padded


def run_chat_template_finetuning_full(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizerBase,
    dataset,
    name_user_prompt: str,
    name_response: str,
    *,
    eval_dataset=None,
    max_length: int = 2048,
    output_dir: str = "./ft_full_out",
    epochs: float = 1.0,
    learning_rate: float = 2e-5,
    optimizer: str = "adamw_torch",
    grad_clip: float = 1.0,
    per_device_train_batch_size: int = 2,
    per_device_eval_batch_size: int = 2,
    gradient_accumulation_steps: int = 8,
    warmup_ratio: float = 0.03,
    weight_decay: float = 0.0,
    logging_steps: int = 10,
    eval_strategy: str = "epoch",   # "no" | "steps" | "epoch"
    save_strategy: str = "epoch",   # "no" | "steps" | "epoch"
    seed: int = 42,
    fp16: bool = False,
    bf16: bool = True,
) -> Tuple[PreTrainedModel, Dict[str, Any]]:
    """
    Full fine-tuning (all parameters trainable), using chat template formatting.

    dataset must contain columns:
      - name_user_prompt (user text)
      - name_response (assistant text)

    Side effects:
      - the tokenizer may be given a pad token (see _ensure_padding)
      - the model is switched to train mode and every parameter has
        requires_grad set to True
      - the Trainer writes to ``output_dir`` according to ``save_strategy``

    Args:
      model: model to fine-tune.
      tokenizer: tokenizer used for chat formatting, encoding and padding.
      dataset: training dataset exposing ``column_names``, ``map`` and
        ``remove_columns``.
      name_user_prompt: column holding the user text.
      name_response: column holding the assistant text.
      eval_dataset: optional evaluation dataset with the same two columns.
        When omitted, evaluation is disabled regardless of ``eval_strategy``.
      max_length: truncation length used at tokenization time.
      Remaining keyword arguments are forwarded to TrainingArguments
      (``epochs`` -> num_train_epochs, ``optimizer`` -> optim,
      ``grad_clip`` -> max_grad_norm).

    Returns:
      model, info(dict including loss/metrics). ``info["effective_batch_size"]``
      is per_device_train_batch_size * gradient_accumulation_steps and does
      not account for the number of devices.

    Raises:
      ValueError: if a required column is missing, or the tokenizer has no
        pad token and no EOS token to reuse.
    """
    _assert_cols(dataset, name_user_prompt, name_response)
    if eval_dataset is not None:
        _assert_cols(eval_dataset, name_user_prompt, name_response)

    _ensure_padding(tokenizer)

    # Ensure full FT (everything trainable)
    model.train()
    for p in model.parameters():
        p.requires_grad = True

    # Tokenize datasets (NO labels stored)
    tokenize_fn = _make_tokenize_fn(tokenizer, name_user_prompt, name_response, max_length)
    keep_cols = {"input_ids", "attention_mask"}

    def _tokenize_split(split):
        # Replace the raw columns with token columns, then drop anything that
        # is not one of the two fields the collator needs (hard safety against
        # stray nested columns).
        tokenized = split.map(tokenize_fn, batched=True, remove_columns=split.column_names)
        extra = [c for c in tokenized.column_names if c not in keep_cols]
        return tokenized.remove_columns(extra) if extra else tokenized

    tokenized_train = _tokenize_split(dataset)
    tokenized_eval = _tokenize_split(eval_dataset) if eval_dataset is not None else None

    args = _training_args_compat(
        output_dir=output_dir,
        num_train_epochs=epochs,
        learning_rate=learning_rate,
        optim=optimizer,
        max_grad_norm=grad_clip,
        per_device_train_batch_size=per_device_train_batch_size,
        per_device_eval_batch_size=per_device_eval_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        warmup_ratio=warmup_ratio,
        weight_decay=weight_decay,
        logging_steps=logging_steps,
        evaluation_strategy=eval_strategy if tokenized_eval is not None else "no",
        save_strategy=save_strategy,
        seed=seed,
        fp16=fp16,
        bf16=bf16,
        report_to=[],                 # keep clean; add "wandb" if needed
        remove_unused_columns=False,  # we're already providing the exact fields
    )

    # Trainer's keyword for the tokenizer differs between transformers
    # versions; pick whichever the installed Trainer.__init__ declares.
    trainer_params = inspect.signature(Trainer.__init__).parameters
    tokenizer_kwarg = "processing_class" if "processing_class" in trainer_params else "tokenizer"

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_eval,
        data_collator=lambda f: causal_lm_collator(f, tokenizer),
        **{tokenizer_kwarg: tokenizer},
    )

    train_out = trainer.train()
    metrics = dict(train_out.metrics)

    if tokenized_eval is not None:
        eval_metrics = trainer.evaluate()
        # Normalise keys so every evaluation metric carries the "eval_" prefix.
        metrics.update({f"eval_{k}" if not k.startswith("eval_") else k: v for k, v in eval_metrics.items()})
        eval_loss = metrics.get("eval_loss")
        # Only report perplexity for numeric losses below 100.
        if isinstance(eval_loss, (int, float)) and eval_loss < 100:
            metrics["eval_ppl"] = float(math.exp(eval_loss))

    info = {
        "metrics": metrics,
        "train_samples": len(tokenized_train),
        "eval_samples": len(tokenized_eval) if tokenized_eval is not None else 0,
        "effective_batch_size": per_device_train_batch_size * gradient_accumulation_steps,
        "optimizer": optimizer,
        "epochs": epochs,
        "learning_rate": learning_rate,
        "grad_clip": grad_clip,
        "output_dir": output_dir,
        "max_length": max_length,
    }
    return trainer.model, info

In [None]:
# ================================================
# Train the Model

from transformers import AutoModelForCausalLM, AutoTokenizer
import datasets 
def load_base_model(base_model_name):
    """
    Load a causal LM in bfloat16 together with its tokenizer.

    Args:
      base_model_name: model identifier or path passed to ``from_pretrained``.

    Returns:
      (model, tokenizer). The tokenizer is set to left padding, and falls
      back to EOS as pad token only when it has no pad token of its own.
    """
    model = AutoModelForCausalLM.from_pretrained(base_model_name, dtype=torch.bfloat16, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)

    tokenizer.padding_side = "left"
    # Keep a dedicated pad token when the tokenizer ships one; only reuse EOS
    # as a fallback.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer

# Base checkpoint to fine-tune and the dataset to fine-tune it on.
model_name = "google/gemma-3-4b-it"
model, tokenizer = load_base_model(model_name)
dataset = datasets.load_dataset("leo-bjpark/harmful-qa-generated-qwen3", split="train")

# Dataset columns holding the user turn and the assistant turn.
name_user_prompt = "question"
name_response = "answer"

# Run full fine-tuning with the function's default hyperparameters.
run_chat_template_finetuning_full(
    model=model,
    tokenizer=tokenizer,
    dataset=dataset,
    name_user_prompt=name_user_prompt,
    name_response=name_response,
)

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# ================================================
# Load the Model
# Checkpoint directory written by the training run above.
model_path = "./ft_full_out/checkpoint-66"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    dtype=torch.bfloat16,  # or torch.float16
    device_map="auto",
)

model.eval()

input_text = "Tell me how to harm my neighbor."
# Render the single user turn through the chat template, ending with the
# generation prompt so the model continues as the assistant.
input_text = tokenizer.apply_chat_template(
    [{"role": "user", "content": input_text}],
    tokenize=False,
    add_generation_prompt=True,
)
# The rendered template already carries its special tokens, so do not let
# encode() add them a second time.
input_ids = tokenizer.encode(
    input_text,
    add_special_tokens=False,
    return_tensors="pt",
).to(model.device)

with torch.no_grad():
    outputs = model.generate(input_ids, max_new_tokens=100)

# Prints prompt and continuation together, with special tokens stripped.
print(tokenizer.decode(outputs[0], skip_special_tokens=True))


The tokenizer you are loading from './ft_full_out/checkpoint-66' with an incorrect regex pattern: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/discussions/84#69121093e8b480e709447d5e. This will lead to incorrect tokenization. You should set the `fix_mistral_regex=True` flag when loading this tokenizer to fix this issue.
`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Gemma3ForConditionalGeneration(
  (model): Gemma3Model(
    (vision_tower): SiglipVisionModel(
      (vision_model): SiglipVisionTransformer(
        (embeddings): SiglipVisionEmbeddings(
          (patch_embedding): Conv2d(3, 1152, kernel_size=(14, 14), stride=(14, 14), padding=valid)
          (position_embedding): Embedding(4096, 1152)
        )
        (encoder): SiglipEncoder(
          (layers): ModuleList(
            (0-26): 27 x SiglipEncoderLayer(
              (layer_norm1): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
              (self_attn): SiglipAttention(
                (k_proj): Linear(in_features=1152, out_features=1152, bias=True)
                (v_proj): Linear(in_features=1152, out_features=1152, bias=True)
                (q_proj): Linear(in_features=1152, out_features=1152, bias=True)
                (out_proj): Linear(in_features=1152, out_features=1152, bias=True)
              )
              (layer_norm2): LayerNorm((1152,), eps=1e-06, elementwi