# Fine-tuning GPT-OSS-20B (MXFP4) on AMD Strix Halo (GFX1151)

**Hardware**: AMD Strix Halo (GFX1151) – 128 GB unified memory.  
**Objective**: demonstrate LoRA fine-tuning of *gpt-oss-20b* under MXFP4 quantization, comparing a non-reasoning dataset and a reasoning dataset, each isolated in its own memory scope and adapter.

---

### About GPT-OSS-20B
`gpt-oss-20b` is an open-weight **mixture-of-experts (MoE)** model released by **OpenAI**, with roughly 21 B total parameters.  
Each token is routed to only 4 of the 32 experts in every MoE layer, so only about 3.6 B parameters are active per token, which lowers per-token compute while maintaining reasoning quality.  
The MoE expert weights are distributed in **MXFP4 quantized format** to minimize storage and bandwidth.
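To make "active parameters" concrete: an MoE layer scores every expert with a small router and runs only the top-scoring few. The plain-Python sketch below assumes a common top-k scheme (keep the k largest router logits, then softmax over just those). The model card lists 32 experts with 4 active per token for gpt-oss-20b, but the model's actual routing code may differ in details, and the router logits here are made up.

```python
import math

def route_token(router_logits: list[float], k: int = 4) -> list[tuple[int, float]]:
    """Pick the k highest-scoring experts and return (expert_index, weight) pairs.

    Sketch of top-k MoE routing: the softmax is taken over the selected logits only,
    so the k weights sum to 1 and every other expert is skipped for this token.
    """
    if not 0 < k <= len(router_logits):
        raise ValueError("k must be between 1 and the number of experts")
    top = sorted(range(len(router_logits)), key=lambda i: router_logits[i], reverse=True)[:k]
    m = max(router_logits[i] for i in top)  # subtract the max for numerical stability
    exps = [math.exp(router_logits[i] - m) for i in top]
    total = sum(exps)
    return [(i, e / total) for i, e in zip(top, exps)]

# Hypothetical router output for one token over 32 experts (distinct values)
logits = [0.1 * ((7 * i) % 32) for i in range(32)]
choice = route_token(logits, k=4)
print(choice)
print(f"experts used: {len(choice)} of {len(logits)}")
```

Only the chosen experts' MLP weights participate in the forward and backward pass for that token, which is why per-token compute tracks the active rather than the total parameter count.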

References:  
- Model card: https://huggingface.co/openai/gpt-oss-20b  
- Architecture overview: https://cdn.openai.com/pdf/419b6906-9da6-406c-a19d-1bb078ac7637/oai_gpt-oss_model_card.pdf

---

### Why MXFP4 and why dequantize on Strix Halo
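MXFP4 is the Open Compute Project (OCP) microscaling 4-bit floating-point format: each weight is stored as a 4-bit E2M1 value, and every block of 32 weights shares one 8-bit power-of-two scale, for about 4.25 bits per weight. OpenAI ships the gpt-oss MoE expert weights in this format.  
On AMD ROCm (Strix Halo) hardware, **native MXFP4 kernels are not available**, so training must run on dequantized weights.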
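The decode step that dequantization performs can be sketched in a few lines. This is a plain-Python illustration of the OCP MXFP4 element and scale encoding (E2M1 elements, E8M0 block scale, 32 elements per block); it is not the Transformers implementation and does not model how the Hugging Face checkpoint packs two 4-bit codes per byte or lays out its block and scale tensors.

```python
# FP4 (E2M1) magnitudes indexed by the low 3 bits of a 4-bit code; bit 3 is the sign
E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
BLOCK_SIZE = 32  # elements sharing one scale in MXFP4

def decode_fp4(code: int) -> float:
    """Decode one 4-bit E2M1 code (0..15) to a float."""
    sign = -1.0 if code & 0b1000 else 1.0
    return sign * E2M1_VALUES[code & 0b0111]

def dequantize_block(codes: list[int], scale_exp: int) -> list[float]:
    """Dequantize one MXFP4 block: every element times the shared power-of-two scale.

    scale_exp is the raw 8-bit E8M0 exponent: the scale is 2 ** (scale_exp - 127),
    and 0xFF is reserved for NaN.
    """
    if len(codes) != BLOCK_SIZE:
        raise ValueError(f"expected {BLOCK_SIZE} codes, got {len(codes)}")
    if scale_exp == 0xFF:
        raise ValueError("scale exponent 0xFF encodes NaN")
    scale = 2.0 ** (scale_exp - 127)
    return [decode_fp4(c) * scale for c in codes]

# Storage cost: 4 bits per element plus one 8-bit scale per block of 32
bits_per_weight = (BLOCK_SIZE * 4 + 8) / BLOCK_SIZE
print(bits_per_weight)  # 4.25

# 16 codes of +0.5 and 16 codes of -6.0, with scale 2 ** (128 - 127) = 2
block = dequantize_block([0b0001] * 16 + [0b1111] * 16, scale_exp=128)
print(block[0], block[-1])  # 1.0 -12.0
```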

Setting:
```python
from transformers import Mxfp4Config

Mxfp4Config(dequantize=True)  # passed to from_pretrained(..., quantization_config=...)
```

forces the model loader to expand MXFP4 tensors into **bf16**, ensuring stable gradients at the expense of higher memory use.
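A back-of-envelope check of "higher memory use", counting weights only: bf16 needs 16 bits per parameter, while a hypothetical all-MXFP4 copy would need about 4.25. The parameter count of 21 B is an approximation taken from the model card. The real checkpoint keeps attention and other non-expert weights in bf16, so its size sits between these two bounds; activations, LoRA gradients and optimizer state come on top of the bf16 figure during training.

```python
def weight_gb(num_params: float, bits_per_param: float) -> float:
    """Weight storage in gigabytes (1 GB = 1e9 bytes), ignoring metadata and padding."""
    return num_params * bits_per_param / 8 / 1e9

NUM_PARAMS = 21e9  # approximate total parameter count of gpt-oss-20b (assumption)

bf16_gb = weight_gb(NUM_PARAMS, 16)
mxfp4_gb = weight_gb(NUM_PARAMS, 4.25)
print(f"bf16 weights:      {bf16_gb:.1f} GB")   # 42.0 GB
print(f"all-MXFP4 weights: {mxfp4_gb:.1f} GB")  # 11.2 GB
print(f"share of 128 GB unified memory taken by bf16 weights: {bf16_gb / 128:.0%}")
```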

---

### Notebook structure

1. **Setup and utilities** – imports, helper functions, shared hyperparameters.
2. **Non-reasoning SFT** – `Abirate/english_quotes` dataset (simple instruction/response).
3. **Reasoning SFT** – `HuggingFaceH4/Multilingual-Thinking` dataset (Harmony reasoning traces).
4. **Inference sanity check** – quick validation using the merged reasoning adapter.

All training runs use `attn_implementation="eager"` to avoid FlashAttention kernel incompatibilities on ROCm.


## Setup and utilities

We import PyTorch, Hugging Face Transformers, PEFT, and TRL.  
All memory-tracking utilities are defined here for reuse across sections.  
Each model in the sections below is loaded with `Mxfp4Config(dequantize=True)` so that fine-tuning runs on bf16 weights on ROCm.

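```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, Mxfp4Config
from peft import LoraConfig, get_peft_model, PeftModel
from trl import SFTConfig, SFTTrainer
from datasets import load_dataset

# On ROCm builds of PyTorch the torch.cuda namespace is backed by HIP,
# so these helpers also work on Strix Halo.
def reset_peak_mem():
    """Reset the allocator's peak-memory counter before a training run."""
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()

def report_peak_mem(tag: str = ""):
    """Print the peak memory allocated by PyTorch tensors since the last reset."""
    if torch.cuda.is_available():
        label = f" {tag}" if tag else ""
        print(f"Peak training memory{label}: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")
```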


## Model selection and training parameters

We fine-tune the **OpenAI GPT-OSS-20B (MXFP4)** model using bf16 dequantization for compatibility with Strix Halo.

| Parameter | Meaning | Value |
|:-----------|:---------|:------|
| **MODEL** | Base pretrained model | `openai/gpt-oss-20b` |
| **LR** | Learning rate | `2e-4` |
| **EPOCHS** | Number of full dataset passes | `1` |
| **BATCH_SIZE** | Samples per device per step | `4` |
| **MAX_LEN** | Token limit per example | `2048` |

Hyperparameters follow the ranges used in TRL’s SFTTrainer examples and OpenAI’s fine-tuning guide.  
Activation memory grows linearly with batch size and, under eager attention, quadratically with sequence length, because the full attention-score matrix is materialized.
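The scaling can be made concrete by sizing one layer's attention-score tensor (batch × heads × seq × seq) in bf16. The head count of 64 is an assumption taken from the published gpt-oss-20b config (check `model.config.num_attention_heads`), and the estimate ignores intermediate copies such as the softmax output or any fp32 upcast, so treat it as an order-of-magnitude guide rather than a prediction of peak memory.

```python
def attn_scores_bytes(batch: int, heads: int, seq_len: int, bytes_per_elem: int = 2) -> int:
    """Size of one layer's attention-score tensor (batch x heads x seq x seq)."""
    return batch * heads * seq_len * seq_len * bytes_per_elem

HEADS = 64  # assumed num_attention_heads for gpt-oss-20b

ref_bytes = attn_scores_bytes(batch=4, heads=HEADS, seq_len=2048)
print(f"BATCH_SIZE=4, MAX_LEN=2048: {ref_bytes / 2**30:.1f} GiB per layer")  # 2.0 GiB
print(f"double the batch:      x{attn_scores_bytes(8, HEADS, 2048) / ref_bytes:.0f}")  # x2
print(f"double the seq length: x{attn_scores_bytes(4, HEADS, 4096) / ref_bytes:.0f}")  # x4
```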


```python
MODEL = "openai/gpt-oss-20b"
model_name = MODEL.split("/")[-1]  # "gpt-oss-20b", used in output directory names

LR = 2e-4
EPOCHS = 1
BATCH_SIZE = 4
MAX_LEN = 2048
```

## Section 1 – Non-reasoning SFT

Dataset: [`Abirate/english_quotes`](https://huggingface.co/datasets/Abirate/english_quotes)  
It provides short quotes with their authors and tags, which we turn into direct question/answer pairs without explicit reasoning or chain-of-thought.

Purpose:
- Validate the SFT training pipeline.  
- Confirm LoRA integration on dequantized weights.  
- Observe stable fine-tuning behavior before moving to reasoning datasets.

Each section loads its own dataset and model to avoid overlapping allocations and to keep the lifetime of large tensors scoped to that section.


```python
# Load and prepare the non-reasoning dataset locally to this section
quotes_ds = load_dataset("Abirate/english_quotes", split="train").shuffle(seed=42).select(range(1000))

def quotes_to_messages(ex):
    # `tags` is a list of strings; join it so the prompt does not contain a Python list repr
    topics = ", ".join(ex["tags"] or []) or "anything"
    return {
        "messages": [
            {"role": "user", "content": f"Give me a quote about: {topics}"},
            {"role": "assistant", "content": f"{ex['quote']} - {ex['author']}"},
        ]
    }

quotes_ds = quotes_ds.map(quotes_to_messages, remove_columns=quotes_ds.column_names)
quotes_ds = quotes_ds.train_test_split(test_size=0.2, seed=42)  # seeded for reproducibility
print(f"Quotes train: {len(quotes_ds['train'])}, test: {len(quotes_ds['test'])}")
```

### LoRA setup for non-reasoning SFT

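The model is loaded with bf16 dequantized weights (`Mxfp4Config(dequantize=True)`).  
LoRA adapters are attached to every linear layer (`target_modules="all-linear"`) and, through `target_parameters`, to the expert `gate_up_proj` and `down_proj` weights of layers 7, 15 and 23.  
Only about 0.07% of the parameters are trainable, which keeps gradient and optimizer-state memory small; the frozen bf16 base weights still dominate the footprint.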
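The LoRA update itself is small: a frozen weight W (d_out × d_in) is augmented with a trainable low-rank product B·A, where A is r × d_in and B is d_out × r, scaled by `lora_alpha / r` (here 16 / 8 = 2). The plain-Python sketch below shows the parameter count per adapted matrix and why training starts from the base model's behavior: with B initialized to zeros (PEFT's default), B·A·x = 0 at step 0. The matrix sizes are toy or hypothetical values, not specific gpt-oss layers.

```python
import random

def lora_param_count(d_in: int, d_out: int, r: int) -> int:
    """Trainable parameters added by one LoRA adapter: A (r x d_in) + B (d_out x r)."""
    return r * (d_in + d_out)

def matvec(m: list[list[float]], x: list[float]) -> list[float]:
    return [sum(w * v for w, v in zip(row, x)) for row in m]

def lora_forward(W, A, B, x, alpha: float, r: int) -> list[float]:
    """y = W x + (alpha / r) * B (A x); W stays frozen, only A and B would be trained."""
    base_out = matvec(W, x)
    delta = matvec(B, matvec(A, x))
    return [b + (alpha / r) * d for b, d in zip(base_out, delta)]

r, alpha, d_in, d_out = 8, 16, 6, 4  # rank/alpha as in the config; toy layer sizes
rng = random.Random(0)
W = [[rng.uniform(-1, 1) for _ in range(d_in)] for _ in range(d_out)]
A = [[rng.uniform(-1, 1) for _ in range(d_in)] for _ in range(r)]  # random init
B = [[0.0] * r for _ in range(d_out)]                               # zero init
x = [rng.uniform(-1, 1) for _ in range(d_in)]

print(lora_forward(W, A, B, x, alpha, r) == matvec(W, x))  # True: no change at init
print(lora_param_count(2880, 2880, r=8))  # hypothetical 2880 x 2880 layer: 46080
```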


```python
# Model and LoRA setup for non-reasoning SFT
quant_quotes = Mxfp4Config(dequantize=True)  # expand MXFP4 expert weights to bf16 at load time
model_quotes = AutoModelForCausalLM.from_pretrained(
    MODEL,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="eager",  # avoid FlashAttention kernels on ROCm
    use_cache=False,              # the KV cache is not needed for training
    quantization_config=quant_quotes,
)
tokenizer_quotes = AutoTokenizer.from_pretrained(MODEL)

lora_config_quotes = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules="all-linear",  # every nn.Linear layer, e.g. the attention projections
    # Expert weights are raw parameters rather than nn.Linear, so they are targeted by name
    target_parameters=[
        "7.mlp.experts.gate_up_proj",
        "7.mlp.experts.down_proj",
        "15.mlp.experts.gate_up_proj",
        "15.mlp.experts.down_proj",
        "23.mlp.experts.gate_up_proj",
        "23.mlp.experts.down_proj",
    ],
)

model_quotes = get_peft_model(model_quotes, lora_config_quotes)
model_quotes.print_trainable_parameters()
```

```python
# Train on quotes
args_quotes = SFTConfig(
    output_dir=f"out-{model_name}-lora",  # adapter for the non-reasoning run
    max_length=MAX_LEN,                   # longer examples are truncated
    packing=False,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    warmup_ratio=0.03,
    gradient_checkpointing=True,          # recompute activations to save memory
    gradient_checkpointing_kwargs={"use_reentrant": False},
    learning_rate=LR,
    lr_scheduler_type="cosine_with_min_lr",
    lr_scheduler_kwargs={"min_lr_rate": 0.1},  # decay to 10% of LR instead of zero
    logging_steps=10,
    eval_strategy="epoch",
    save_strategy="epoch",
    bf16=True,
    report_to="none",
    save_safetensors=True,
    save_total_limit=1,
)

trainer_quotes = SFTTrainer(
    model=model_quotes,
    args=args_quotes,
    train_dataset=quotes_ds["train"],
    eval_dataset=quotes_ds["test"],
    processing_class=tokenizer_quotes,
)

reset_peak_mem()
trainer_quotes.train()
report_peak_mem("lora")
trainer_quotes.save_model()  # writes the LoRA adapter to output_dir
```

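```python
# Cleanup non-reasoning objects to free memory before loading the next model
import gc

del model_quotes, trainer_quotes
gc.collect()              # collect reference cycles that may still hold tensors
torch.cuda.empty_cache()  # release unused cached memory back to the device
```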

## Section 2 – Reasoning SFT

Dataset: [`HuggingFaceH4/Multilingual-Thinking`](https://huggingface.co/datasets/HuggingFaceH4/Multilingual-Thinking)  
This dataset provides structured *reasoning traces*: each assistant message carries explicit `thinking` content alongside the final answer, and the gpt-oss chat template renders these messages in the **Harmony** format.

Goals:
- Fine-tune GPT-OSS-20B to strengthen reasoning behavior and the separation between the `analysis` and `final` channels.  
- Train a separate LoRA adapter (saved to `out-gpt-oss-20b-reasoning-lora`) to isolate reasoning behavior from standard instruction following.


```python
# Load the reasoning dataset only in this section
reason_ds = load_dataset("HuggingFaceH4/Multilingual-Thinking", split="train")
print(f"Reasoning samples: {len(reason_ds)}")
print(reason_ds[1]["messages"])
```

```python
# Sanity check: keep only samples whose messages carry a non-empty "thinking" field.
# Reuses `reason_ds` from the previous cell instead of downloading it again.
def has_thinking(sample):
    return any(m.get("thinking") not in (None, "") for m in sample["messages"])

reasoning_examples = reason_ds.filter(has_thinking)

print(f"Total reasoning examples: {len(reasoning_examples)}")
print(reasoning_examples[0]["messages"])
```


### Harmony chat format

GPT-OSS models use the **Harmony** message template for chat and reasoning tasks.  
Messages are encoded using tags such as:

```
<|start|>system<|message|>...<|end|>
<|start|>user<|message|>...<|end|>
<|start|>assistant<|channel|>analysis<|message|>...<|end|>
<|start|>assistant<|channel|>final<|message|>...<|return|>
```

The tokenizer’s `apply_chat_template()` method converts message lists into this structure automatically.  
Harmony defines *channels* (`analysis`, `commentary`, `final`) that allow explicit reasoning steps before producing an answer.  
Fine-tuning must preserve this formatting to keep reasoning and output generation aligned.

Reference: https://github.com/openai/harmony
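To see how the message fields map onto Harmony, here is a deliberately simplified, hypothetical renderer in plain Python. It only handles `user` and `assistant` turns plus an optional `thinking` field, and it does not generate the system/developer preamble that the real gpt-oss chat template adds, so use `apply_chat_template()` for actual training data.

```python
def render_harmony(messages: list[dict]) -> str:
    """Render user/assistant turns in a simplified Harmony layout (illustration only).

    Real Harmony output has no separator between messages; newlines are added
    here purely for readability.
    """
    parts = []
    last = len(messages) - 1
    for i, msg in enumerate(messages):
        if msg["role"] == "user":
            parts.append(f"<|start|>user<|message|>{msg['content']}<|end|>")
        elif msg["role"] == "assistant":
            if msg.get("thinking"):
                # Chain-of-thought goes to the analysis channel
                parts.append(
                    f"<|start|>assistant<|channel|>analysis<|message|>{msg['thinking']}<|end|>"
                )
            # The answer goes to the final channel; the last turn ends with <|return|>
            stop = "<|return|>" if i == last else "<|end|>"
            parts.append(f"<|start|>assistant<|channel|>final<|message|>{msg['content']}{stop}")
        else:
            raise ValueError(f"role not handled by this sketch: {msg['role']}")
    return "\n".join(parts)

example = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "thinking": "Simple addition: 2 + 2 = 4.", "content": "4"},
]
print(render_harmony(example))
```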


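```python
tokenizer_reason = AutoTokenizer.from_pretrained(MODEL)

# Render one sample as text to inspect the Harmony tokens and channels
sample = reasoning_examples[0]
formatted = tokenizer_reason.apply_chat_template(
    sample["messages"],
    tokenize=False,
    add_generation_prompt=False,  # the sample already ends with the assistant's answer
)
print(formatted)
```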


### LoRA setup for reasoning SFT

- We again dequantize the MXFP4 weights to bf16 and attach adapters to the same linear layers and MoE expert projections.  
- This keeps capacity comparable between the non-reasoning and reasoning adapters.  
- Peak memory is expected to be higher here because the reasoning traces produce longer sequences.
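Because longer reasoning traces raise peak memory, and the trainer truncates anything beyond `MAX_LEN`, it can be worth checking token lengths before training. The helper below is a hypothetical plain-Python sketch that summarizes a list of token counts; the commented usage lines assume the `tokenizer_reason` and `reasoning_examples` objects defined earlier and that `apply_chat_template(..., tokenize=True)` returns a list of token ids.

```python
def length_report(lengths: list[int], max_len: int) -> dict:
    """Summarize token lengths against a truncation limit."""
    if not lengths:
        raise ValueError("no lengths given")
    ordered = sorted(lengths)
    over = sum(1 for n in lengths if n > max_len)
    return {
        "count": len(lengths),
        "max": ordered[-1],
        "upper_median": ordered[len(ordered) // 2],
        "truncated": over,
        "truncated_frac": over / len(lengths),
    }

# Possible usage in the notebook (assumes objects from the cells above):
# lengths = [len(tokenizer_reason.apply_chat_template(ex["messages"], tokenize=True))
#            for ex in reasoning_examples]
# print(length_report(lengths, MAX_LEN))

print(length_report([512, 1800, 2300, 900], max_len=2048))
```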


```python
# Model and LoRA setup for reasoning SFT (same loading recipe as the quotes run)
quant_reason = Mxfp4Config(dequantize=True)  # expand MXFP4 expert weights to bf16
model_reason = AutoModelForCausalLM.from_pretrained(
    MODEL,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="eager",  # avoid FlashAttention kernels on ROCm
    use_cache=False,              # the KV cache is not needed for training
    quantization_config=quant_reason,
)
tokenizer_reason = AutoTokenizer.from_pretrained(MODEL)

# Identical adapter layout to the quotes run, so the two adapters are comparable
lora_config_reason = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules="all-linear",
    target_parameters=[
        "7.mlp.experts.gate_up_proj",
        "7.mlp.experts.down_proj",
        "15.mlp.experts.gate_up_proj",
        "15.mlp.experts.down_proj",
        "23.mlp.experts.gate_up_proj",
        "23.mlp.experts.down_proj",
    ],
)

model_reason = get_peft_model(model_reason, lora_config_reason)
model_reason.print_trainable_parameters()
```

```python
# Train on reasoning dataset
args_reason = SFTConfig(
    output_dir=f"out-{model_name}-reasoning-lora",  # "out-gpt-oss-20b-reasoning-lora"
    max_length=MAX_LEN,                             # longer traces are truncated
    packing=False,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    warmup_ratio=0.03,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    learning_rate=LR,
    lr_scheduler_type="cosine_with_min_lr",
    lr_scheduler_kwargs={"min_lr_rate": 0.1},
    logging_steps=10,
    eval_strategy="no",  # no held-out split for this dataset
    save_strategy="epoch",
    bf16=True,
    report_to="none",
    save_safetensors=True,
    save_total_limit=1,
)

trainer_reason = SFTTrainer(
    model=model_reason,
    args=args_reason,
    train_dataset=reason_ds,
    processing_class=tokenizer_reason,
)

reset_peak_mem()
trainer_reason.train()
report_peak_mem("reasoning-lora")
trainer_reason.save_model()  # writes the reasoning adapter to output_dir
```

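```python
# Cleanup reasoning objects to free memory before inference
import gc

del model_reason, trainer_reason
gc.collect()              # collect reference cycles that may still hold tensors
torch.cuda.empty_cache()  # release unused cached memory back to the device
```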

## Inference sanity check

We load the reasoning adapter from `out-gpt-oss-20b-reasoning-lora`, attach it to the base model, and merge it into the weights for validation.  
Generation runs with `attn_implementation="eager"` and dequantized bf16 weights.  
The prompt uses the Harmony chat template with an explicit reasoning language (German) and a user query in Spanish to verify that the model:
1. Generates coherent reasoning in the `analysis` channel.  
2. Produces a correct, formatted final answer in the `final` channel.


```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, Mxfp4Config
from peft import PeftModel

# Load the base model exactly as for training: MXFP4 experts dequantized to bf16
base = AutoModelForCausalLM.from_pretrained(
    MODEL,
    attn_implementation="eager",
    torch_dtype=torch.bfloat16,
    quantization_config=Mxfp4Config(dequantize=True),
    use_cache=True,  # enable the KV cache for generation
    device_map="auto",
)
# Same directory that trainer_reason.save_model() wrote to
peft_reason = PeftModel.from_pretrained(base, f"out-{model_name}-reasoning-lora")
merged = peft_reason.merge_and_unload()  # fold LoRA deltas into the base weights
tok = AutoTokenizer.from_pretrained(MODEL)

REASONING_LANGUAGE = "German"
SYSTEM_PROMPT = f"reasoning language: {REASONING_LANGUAGE}"
USER_PROMPT = "¿Cuál es la capital de Australia?"  # Spanish: "What is the capital of Australia?"

messages = [
    {"role": "system", "content": SYSTEM_PROMPT},
    {"role": "user", "content": USER_PROMPT},
]
inputs = tok.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt", return_dict=True
).to(merged.device)
out = merged.generate(**inputs, max_new_tokens=256, do_sample=True, temperature=0.6)
print(tok.batch_decode(out)[0])
```
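The decoded string contains the whole conversation in Harmony tokens. To look at the reasoning and the answer separately, a small regex-based parser like the hypothetical one below can split the text by channel. It assumes the special tokens are kept in the decoded text (as with `batch_decode` without `skip_special_tokens=True`); the Harmony library referenced above is the more robust option for real parsing.

```python
import re

# Matches "<|channel|>NAME<|message|>TEXT" up to the next stop/start token or the end of text
_CHANNEL_RE = re.compile(
    r"<\|channel\|>(\w+)<\|message\|>(.*?)(?=<\|end\|>|<\|return\|>|<\|call\|>|<\|start\|>|$)",
    re.DOTALL,
)

def split_channels(decoded: str) -> dict[str, str]:
    """Map each Harmony channel name to its text (later occurrences overwrite earlier ones)."""
    return {name: text.strip() for name, text in _CHANNEL_RE.findall(decoded)}

# Possible usage with the generation above (assumes `tok` and `out` exist):
# parts = split_channels(tok.batch_decode(out)[0])
# print(parts.get("analysis"))
# print(parts.get("final"))

demo = (
    "<|start|>assistant<|channel|>analysis<|message|>Capital of Australia -> Canberra.<|end|>"
    "<|start|>assistant<|channel|>final<|message|>Canberra.<|return|>"
)
print(split_channels(demo))
```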