# Fine-tuning GPT-OSS-20B on AMD Strix Halo (Unsloth Benchmarks)

**This notebook adapts the standard Hugging Face fine-tuning pipeline for a 1-to-1 comparison, with training accelerated by [Unsloth](https://github.com/unslothai/unsloth).**

### What is Unsloth?
Unsloth speeds up LLM fine-tuning and reduces VRAM usage, which lets this notebook load the 20B model directly in 4-bit.

This notebook mirrors the exact configurations found in the standard HF notebook, using `unsloth.FastLanguageModel` as the backend engine.

## Setup and utilities

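We import Unsloth first, so that its patches apply to Hugging Face Transformers, PEFT, and TRL, followed by PyTorch and those libraries.  
All memory-tracking utilities are defined here for reuse across sections.  
Where the baseline notebook loads the model with `Mxfp4Config(dequantize=True)` for stable fine-tuning on ROCm, this version loads it through `FastLanguageModel.from_pretrained(..., load_in_4bit=True)`.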

In [None]:
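import os
# Must be set before `import unsloth`: skips Unsloth's torchvision check
os.environ["UNSLOTH_SKIP_TORCHVISION_CHECK"] = "1"

# Import Unsloth before transformers/peft/trl so its patches take effect
import unsloth
from unsloth import FastLanguageModel
import torch
from transformers import AutoTokenizer

from peft import LoraConfig, get_peft_model, PeftModel
from trl import SFTConfig, SFTTrainer
from datasets import load_dataset

def reset_peak_mem():
    """Reset PyTorch's peak-memory counter (torch.cuda maps to HIP on ROCm builds)."""
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()

def report_peak_mem(tag: str = ""):
    """Print peak memory allocated by PyTorch since the last reset, in GB (1e9 bytes)."""
    if torch.cuda.is_available():
        print(f"Peak training memory{(' ' + tag) if tag else ''}: {torch.cuda.max_memory_allocated()/1e9:.2f} GB")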


## Model selection and training parameters

We fine-tune the **OpenAI GPT-OSS-20B (MXFP4)** model. The baseline notebook dequantizes it to bf16 for compatibility with Strix Halo; this Unsloth version loads it in 4-bit instead.

| Parameter | Meaning | Value |
|:-----------|:---------|:------|
| **MODEL** | Base pretrained model | `openai/gpt-oss-20b` |
| **LR** | Learning rate | `2e-4` |
| **EPOCHS** | Number of full dataset passes | `1` |
| **BATCH_SIZE** | Samples per device per step | `4` |
| **MAX_LEN** | Token limit per example | `2048` |

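Hyperparameters follow the ranges used in TRL’s SFTTrainer examples and OpenAI’s fine-tuning guide.  
Under eager attention, activation memory grows linearly with batch size and quadratically with sequence length, because every layer materializes a `seq_len × seq_len` score matrix per attention head.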
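A back-of-envelope sketch of that scaling: the hypothetical helper below sizes only the attention-score tensor that eager attention materializes in one layer (bf16, 2 bytes per element). The head count is an assumed placeholder, not read from the model config; the ratios, not the absolute number, are the point.

```python
def attn_score_bytes(batch: int, heads: int, seq_len: int, bytes_per_elem: int = 2) -> int:
    """Bytes for one layer's eager-attention score tensor (batch x heads x seq_len x seq_len)."""
    return batch * heads * seq_len * seq_len * bytes_per_elem

HEADS = 64  # assumed placeholder, not read from the gpt-oss-20b config

ref = attn_score_bytes(batch=4, heads=HEADS, seq_len=2048)
print(f"{ref / 1e9:.2f} GB per layer at batch=4, seq_len=2048")  # 2.15 GB
print(attn_score_bytes(8, HEADS, 2048) / ref)  # 2.0: doubling batch size doubles it
print(attn_score_bytes(4, HEADS, 4096) / ref)  # 4.0: doubling seq_len quadruples it
```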


In [None]:
MODEL = "openai/gpt-oss-20b"
model_name = MODEL.split("/")[-1]

LR = 2e-4
EPOCHS = 1
BATCH_SIZE = 4
MAX_LEN = 2048

## Section 1 – Non-reasoning SFT

Dataset: [`Abirate/english_quotes`](https://huggingface.co/datasets/Abirate/english_quotes)  
This dataset yields short, direct completions with no explicit reasoning or chain-of-thought.

Purpose:
- Validate the SFT training pipeline.  
- Confirm LoRA integration on dequantized weights.  
- Observe stable fine-tuning behavior before moving to reasoning datasets.

Each section loads its own dataset to avoid overlapping memory use and to keep the lifetime of large tensors scoped to that section.


### Unsloth Advantage: Native 4-bit loading
Instead of dequantizing the `MXFP4` weights to `bf16` in Strix Halo's GPU memory, which needs about four times the memory of a 4-bit copy, Unsloth's `FastLanguageModel` loads the model directly in optimized 4-bit.

> **Note:** The SFT parameters are deliberately kept identical to the baseline notebook so that the comparison is 1-to-1.
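Rough weight-only arithmetic for the two loading paths, assuming about 21B total parameters for gpt-oss-20b. Quantization scales, LoRA weights, optimizer state, and activations all come on top, so treat these as lower bounds.

```python
def weight_gb(num_params: float, bits_per_param: float) -> float:
    """Approximate weight storage in GB (1e9 bytes); ignores quantization scales and metadata."""
    return num_params * bits_per_param / 8 / 1e9

PARAMS = 21e9  # assumed total parameter count for gpt-oss-20b

for label, bits in [("bf16", 16), ("4-bit", 4)]:
    print(f"{label}: ~{weight_gb(PARAMS, bits):.1f} GB")
# bf16: ~42.0 GB
# 4-bit: ~10.5 GB
```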


In [None]:
# Load and prepare the non-reasoning dataset locally to this section
quotes_ds = load_dataset("Abirate/english_quotes", split="train").shuffle(seed=42).select(range(1000))

def quotes_to_messages(ex):
    # `tags` is a list of strings; join it so the prompt does not contain a Python list literal
    return {
        "messages": [
            {"role": "user", "content": f"Give me a quote about: {', '.join(ex['tags'])}"},
            {"role": "assistant", "content": f"{ex['quote']} - {ex['author']}"}
        ]
    }

# 80/20 train/test split with a fixed seed for reproducibility
quotes_ds = quotes_ds.map(quotes_to_messages, remove_columns=quotes_ds.column_names).train_test_split(test_size=0.2, seed=42)
print(f"Quotes train: {len(quotes_ds['train'])}, test: {len(quotes_ds['test'])}")
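The mapping can be checked offline on a single record. The record below is made up, and the sketch assumes `tags` is a list of strings (as the dataset's `tags` column is); the function is a standalone copy named `quotes_to_messages_sketch` so it does not shadow the notebook's version.

```python
def quotes_to_messages_sketch(ex: dict) -> dict:
    """Convert one english_quotes-style record into a two-turn chat."""
    topic = ", ".join(ex["tags"])  # list of tags -> "tag1, tag2"
    return {
        "messages": [
            {"role": "user", "content": f"Give me a quote about: {topic}"},
            {"role": "assistant", "content": f"{ex['quote']} - {ex['author']}"},
        ]
    }

# Hypothetical record, for illustration only
toy = {"quote": "Be yourself.", "author": "A. Person", "tags": ["identity", "life"]}
for msg in quotes_to_messages_sketch(toy)["messages"]:
    print(f"{msg['role']}: {msg['content']}")
# user: Give me a quote about: identity, life
# assistant: Be yourself. - A. Person
```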

### LoRA setup for non-reasoning SFT

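The model is loaded in 4-bit through `FastLanguageModel` (the baseline uses bf16 weights dequantized with `Mxfp4Config(dequantize=True)`).  
LoRA adapters are attached to every linear layer (`target_modules="all-linear"`) and, through `target_parameters`, to the MoE expert projections (`gate_up_proj`, `down_proj`) of layers 7, 15, and 23.  
Only about 0.07% of the parameters are trainable (the exact count is printed by `print_trainable_parameters()`), which sharply reduces the memory footprint while keeping enough capacity to adapt.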
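Why the trainable fraction is so small: for a `d_out × d_in` weight, LoRA trains two low-rank factors `A` (`r × d_in`) and `B` (`d_out × r`), i.e. `r · (d_in + d_out)` parameters instead of `d_in · d_out`. A minimal sketch with the notebook's `r=8`; the 2880 × 4096 shape is an illustrative assumption, not read from the model.

```python
def lora_params(d_in: int, d_out: int, r: int) -> int:
    """Trainable parameters LoRA adds to one d_out x d_in linear layer."""
    return r * (d_in + d_out)

d_in, d_out, r = 2880, 4096, 8  # illustrative layer shape
added = lora_params(d_in, d_out, r)
full = d_in * d_out
print(f"LoRA adds {added:,} params vs {full:,} frozen ({100 * added / full:.2f}%)")
# LoRA adds 55,808 params vs 11,796,480 frozen (0.47%)
```

Layers that are not targeted contribute no trainable parameters at all, which pushes the model-wide fraction far below the per-layer ratio.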


In [None]:
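# Model and LoRA setup for non-reasoning SFT
# Load the base model in 4-bit through Unsloth (dtype=None lets Unsloth pick the compute dtype)
model_quotes, tokenizer_quotes = FastLanguageModel.from_pretrained(
    model_name=MODEL,
    max_seq_length=MAX_LEN,
    dtype=None,
    load_in_4bit=True,
)

# Same LoRA config as the baseline: all linear layers plus the MoE expert
# projections of layers 7, 15 and 23 (target_parameters needs PEFT >= 0.17)
lora_config_quotes = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules="all-linear",
    target_parameters=[
        "7.mlp.experts.gate_up_proj",
        "7.mlp.experts.down_proj",
        "15.mlp.experts.gate_up_proj",
        "15.mlp.experts.down_proj",
        "23.mlp.experts.gate_up_proj",
        "23.mlp.experts.down_proj"
    ]
)

# PEFT's get_peft_model (rather than FastLanguageModel.get_peft_model) mirrors the baseline setup
model_quotes = get_peft_model(model_quotes, lora_config_quotes)
model_quotes.print_trainable_parameters()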

In [None]:
# Train on quotes
args_quotes = SFTConfig(
    output_dir=f"out-unsloth-{model_name}-lora",
    max_length=MAX_LEN,
    packing=False,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    warmup_ratio=0.03,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    learning_rate=LR,
    lr_scheduler_type="cosine_with_min_lr",
    lr_scheduler_kwargs={"min_lr_rate": 0.1},
    logging_steps=10,
    eval_strategy="epoch",
    save_strategy="epoch",
    bf16=True,
    report_to="none",
    save_safetensors=True,
    save_total_limit=1
)

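def quotes_formatting_func(example):
    # Render conversations with the chat template. Handles a single example
    # (one conversation -> one string) and a batch (list of conversations -> list of strings).
    messages = example['messages']
    if messages and isinstance(messages[0], list):
        return [tokenizer_quotes.apply_chat_template(m, tokenize=False, add_generation_prompt=False) for m in messages]
    return tokenizer_quotes.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)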

trainer_quotes = SFTTrainer(
    model=model_quotes,
    args=args_quotes,
    train_dataset=quotes_ds['train'],
    eval_dataset=quotes_ds['test'],
    formatting_func=quotes_formatting_func,
    processing_class=tokenizer_quotes
)

reset_peak_mem()
trainer_quotes.train()
report_peak_mem("lora")
trainer_quotes.save_model()
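The optimizer schedule configured above can be worked out by hand. A minimal sketch, assuming a single device, no gradient accumulation, and that warmup steps are rounded up; `lr_at` re-implements the shape of the `cosine_with_min_lr` schedule (linear warmup, then cosine decay to `learning_rate × min_lr_rate`) for illustration and is not the Transformers source.

```python
import math

def lr_at(step: int, total: int, warmup: int, peak: float, min_lr_rate: float) -> float:
    """Linear warmup to `peak`, then half-cosine decay down to `peak * min_lr_rate`."""
    if step < warmup:
        return peak * step / max(1, warmup)
    progress = (step - warmup) / max(1, total - warmup)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return peak * (cosine * (1 - min_lr_rate) + min_lr_rate)

# Values from this notebook: 1000 samples with 20% held out, BATCH_SIZE=4, EPOCHS=1
n_train = 800
total_steps = math.ceil(n_train / 4) * 1      # single device, no gradient accumulation
warmup_steps = math.ceil(total_steps * 0.03)  # warmup_ratio=0.03
print(total_steps, warmup_steps)  # 200 6

for s in (0, warmup_steps, total_steps // 2, total_steps):
    print(s, f"{lr_at(s, total_steps, warmup_steps, peak=2e-4, min_lr_rate=0.1):.2e}")
```

With `min_lr_rate=0.1`, the learning rate never decays below `2e-5`, one tenth of `LR`.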

In [None]:
# Cleanup non-reasoning objects to free memory
import gc
del model_quotes, trainer_quotes
gc.collect()  # drop lingering Python references before releasing cached allocator blocks
torch.cuda.empty_cache()

## Section 2 – Reasoning SFT

Dataset: [`HuggingFaceH4/Multilingual-Thinking`](https://huggingface.co/datasets/HuggingFaceH4/Multilingual-Thinking)  
This dataset provides structured *reasoning traces*: messages with explicit “thinking” content and final answers, which the GPT-OSS chat template renders in the **Harmony** format.

Goals:
- Fine-tune GPT-OSS-20B to strengthen reasoning behavior and channel separation.  
- Use separate LoRA adapters (saved to `out-unsloth-gpt-oss-20b-reasoning-lora`) to isolate reasoning behavior from standard instruction following.


In [None]:
# Load the reasoning dataset only in this section
reason_ds = load_dataset("HuggingFaceH4/Multilingual-Thinking", split="train")
print(f"Reasoning samples: {len(reason_ds)}")
print(reason_ds[1]["messages"])

In [None]:
# Reuse reason_ds from the previous cell; keep only examples whose
# messages actually include non-empty 'thinking' content
reasoning_examples = [
    sample for sample in reason_ds
    if any(m.get("thinking") not in (None, "") for m in sample["messages"])
]

print(f"Total reasoning examples: {len(reasoning_examples)}")
print(reasoning_examples[0]["messages"])


### Harmony chat format

GPT-OSS models use the **Harmony** message template for chat and reasoning tasks.  
Messages are encoded using tags such as:

```
<|start|>system<|message|>...<|end|>
<|start|>user<|message|>...<|end|>
<|start|>assistant<|channel|>analysis<|message|>...<|end|>
<|start|>assistant<|channel|>final<|message|>...<|return|>

```

The tokenizer’s `apply_chat_template()` method converts message lists into this structure automatically.  
Harmony defines *channels* that separate the steps of a response: `analysis` for chain-of-thought reasoning, `commentary` for tool calls, and `final` for the user-facing answer.  
Fine-tuning must preserve this formatting to keep reasoning and output generation aligned.

Reference: https://github.com/openai/harmony
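To make the layout concrete, here is a toy renderer that reproduces only the tag structure shown above. It is a hypothetical illustration, not the gpt-oss chat template: it skips the system preamble the real template generates, developer messages, and tool calls, and the newlines between blocks are only for readability. It assumes assistant reasoning sits in a `thinking` key, as in the dataset above.

```python
def render_harmony(messages: list[dict]) -> str:
    """Toy Harmony renderer: one tagged block per message, newline-separated."""
    blocks = []
    for i, msg in enumerate(messages):
        if msg["role"] != "assistant":
            blocks.append(f"<|start|>{msg['role']}<|message|>{msg['content']}<|end|>")
            continue
        if msg.get("thinking"):  # reasoning goes to the analysis channel
            blocks.append(f"<|start|>assistant<|channel|>analysis<|message|>{msg['thinking']}<|end|>")
        # The last final answer ends with <|return|>; earlier ones end with <|end|>
        stop = "<|return|>" if i == len(messages) - 1 else "<|end|>"
        blocks.append(f"<|start|>assistant<|channel|>final<|message|>{msg['content']}{stop}")
    return "\n".join(blocks)

example = [
    {"role": "system", "content": "reasoning language: French"},
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "thinking": "Add the numbers.", "content": "4"},
]
print(render_harmony(example))
```

Comparing this against the `formatted` output in the next cell shows which extra pieces the real template adds.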


In [None]:
tokenizer_reason = AutoTokenizer.from_pretrained(MODEL)

sample = reasoning_examples[0]
formatted = tokenizer_reason.apply_chat_template(
    sample["messages"],
    tokenize=False,
    add_generation_prompt=False
)
print(formatted)


### LoRA setup for reasoning SFT

- The model is again loaded in 4-bit with Unsloth (the baseline dequantizes MXFP4 weights to bf16), and adapters are attached to the same linear layers and MoE expert projections.  
- This keeps the capacity of the non-reasoning and reasoning adapters comparable.  
- Peak memory is expected to be higher here because the reasoning traces produce longer sequences.


In [None]:
# Model and LoRA setup for reasoning SFT
model_reason, tokenizer_reason = FastLanguageModel.from_pretrained(
    model_name=MODEL,
    max_seq_length=MAX_LEN,
    dtype=None,
    load_in_4bit=True,
)

lora_config_reason = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules="all-linear",
    target_parameters=[
        "7.mlp.experts.gate_up_proj",
        "7.mlp.experts.down_proj",
        "15.mlp.experts.gate_up_proj",
        "15.mlp.experts.down_proj",
        "23.mlp.experts.gate_up_proj",
        "23.mlp.experts.down_proj"
    ]
)

model_reason = get_peft_model(model_reason, lora_config_reason)
model_reason.print_trainable_parameters()

In [None]:
# Train on reasoning dataset
args_reason = SFTConfig(
    output_dir=f"out-unsloth-{model_name}-reasoning-lora",
    max_length=MAX_LEN,
    packing=False,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    warmup_ratio=0.03,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    learning_rate=LR,
    lr_scheduler_type="cosine_with_min_lr",
    lr_scheduler_kwargs={"min_lr_rate": 0.1},
    logging_steps=10,
    eval_strategy="no",
    save_strategy="epoch",
    bf16=True,
    report_to="none",
    save_safetensors=True,
    save_total_limit=1
)

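def reason_formatting_func(example):
    # Render conversations with the chat template. Handles a single example
    # (one conversation -> one string) and a batch (list of conversations -> list of strings).
    messages = example['messages']
    if messages and isinstance(messages[0], list):
        return [tokenizer_reason.apply_chat_template(m, tokenize=False, add_generation_prompt=False) for m in messages]
    return tokenizer_reason.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)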

trainer_reason = SFTTrainer(
    model=model_reason,
    args=args_reason,
    train_dataset=reason_ds,
    processing_class=tokenizer_reason,
    formatting_func=reason_formatting_func,
)

reset_peak_mem()
trainer_reason.train()
report_peak_mem("reasoning-lora")
trainer_reason.save_model()

In [None]:
# Cleanup reasoning objects to free memory
import gc
del model_reason, trainer_reason
gc.collect()  # drop lingering Python references before releasing cached allocator blocks
torch.cuda.empty_cache()

## Inference sanity check

We reload the 4-bit base model and attach the reasoning adapter saved in Section 2 (`out-unsloth-gpt-oss-20b-reasoning-lora`) for validation; `FastLanguageModel.for_inference()` then prepares the model for faster generation.  
The prompt uses the Harmony chat template with an explicit reasoning language (German) and a user query in Spanish to verify that the model:
1. Generates coherent reasoning in the `analysis` channel.  
2. Produces a correct, formatted final answer in the `final` channel.
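To check these two points programmatically rather than by eye, the generated text can be split by channel. A minimal sketch, assuming the output was decoded with the Harmony special tokens kept (for example `tok.decode(output_ids, skip_special_tokens=False)`, where `output_ids` is hypothetical); it does not handle tool-call headers such as `to=` recipients, and the sample string is hand-written, not real model output.

```python
import re

# Channel name, then the message body up to the next stop token (or end of text)
CHANNEL_RE = re.compile(
    r"<\|channel\|>(\w+)<\|message\|>(.*?)(?:<\|end\|>|<\|return\|>|<\|call\|>|$)",
    re.DOTALL,
)

def split_channels(text: str) -> dict[str, list[str]]:
    """Group Harmony message bodies by channel name."""
    channels: dict[str, list[str]] = {}
    for name, body in CHANNEL_RE.findall(text):
        channels.setdefault(name, []).append(body.strip())
    return channels

# Hand-written sample: German analysis, Spanish final answer
# (both say "The capital of Australia is Canberra.")
decoded = (
    "<|start|>assistant<|channel|>analysis<|message|>Die Hauptstadt Australiens ist Canberra.<|end|>"
    "<|start|>assistant<|channel|>final<|message|>La capital de Australia es Canberra.<|return|>"
)
print(split_channels(decoded))
```

A missing `final` key, or an `analysis` body in the wrong language, flags a formatting or reasoning-language regression.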


In [None]:
from peft import PeftModel
from transformers import TextStreamer
from unsloth import FastLanguageModel

# Reload the 4-bit base model, then attach the reasoning adapter saved by trainer_reason.save_model()
base, tok = FastLanguageModel.from_pretrained(
    model_name=MODEL,
    max_seq_length=MAX_LEN,
    dtype=None,
    load_in_4bit=True
)
model = PeftModel.from_pretrained(base, f"out-unsloth-{model_name}-reasoning-lora")
FastLanguageModel.for_inference(model)

REASONING_LANGUAGE = "German"
SYSTEM_PROMPT = f"reasoning language: {REASONING_LANGUAGE}"
USER_PROMPT = "¿Cuál es la capital de Australia?"  # Spanish: "What is the capital of Australia?"

messages = [
    {"role": "system", "content": SYSTEM_PROMPT},
    {"role": "user", "content": USER_PROMPT}
]
# reasoning_effort is forwarded to the gpt-oss chat template
inputs = tok.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
    reasoning_effort="high",
).to(model.device)

# Stream sampled tokens to stdout as they are generated
_ = model.generate(**inputs, max_new_tokens=256, do_sample=True, temperature=0.6, streamer=TextStreamer(tok))


## Save and Export Options (Unsloth Native)
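Unsloth can merge LoRA adapters back into the base model and export to formats such as GGUF or 16-bit without external scripts. The cleanup cell in Section 2 deletes `model_reason`, so run these exports before that cleanup, or on a model with the adapter reattached.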
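What merging means: a standard LoRA layer in PEFT (no rsLoRA or DoRA) adds `(lora_alpha / r) · B A x` to the frozen layer's output, so folding the adapter in replaces `W` with `W + (lora_alpha / r) · B A`. A plain-Python sketch on tiny made-up matrices; real merges operate on framework tensors. With the notebook's `lora_alpha=16` and `r=8` the scale is 2, the same as the toy `alpha=2, r=1` used here.

```python
def matmul(x, y):
    """Plain-Python product of nested-list matrices x (n x k) and y (k x m)."""
    return [[sum(x[i][t] * y[t][j] for t in range(len(y))) for j in range(len(y[0]))]
            for i in range(len(x))]

def merge_lora(w, a, b, alpha, r):
    """Return W + (alpha / r) * (B @ A): the LoRA update folded into the frozen weight."""
    scale = alpha / r
    delta = matmul(b, a)
    return [[w[i][j] + scale * delta[i][j] for j in range(len(w[0]))] for i in range(len(w))]

W = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]  # frozen 2 x 3 weight (d_out=2, d_in=3)
A = [[1.0, 2.0, 3.0]]                   # r x d_in, with r=1
B = [[0.5], [0.0]]                      # d_out x r
print(merge_lora(W, A, B, alpha=2, r=1))  # [[2.0, 2.0, 3.0], [0.0, 1.0, 0.0]]
```

After merging, inference needs no adapter code, which is what makes the merged 16-bit and GGUF exports self-contained.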


In [None]:
# Merge the LoRA adapters into a 16-bit Hugging Face model
# model_reason.save_pretrained_merged("gpt-unsloth-merged-16bit", tokenizer_reason, save_method="merged_16bit")


In [None]:
# Export directly to GGUF format for llama.cpp/Ollama (Q8_0 for 8-bit, F16 for 16-bit)
# model_reason.save_pretrained_gguf("gpt-unsloth-gguf", tokenizer_reason, quantization_method="Q8_0")
