# MXFP4 Training Example

## Check GPU Availability

In [None]:
# IPython shell escape: runs the `nvidia-smi` command and shows its output in the cell.
!nvidia-smi

In [None]:
from os import environ

# First GPU index to expose, and how many further consecutive GPUs to expose after it.
DEVICE_NUM = 0
ADDITIONAL_GPU = 0

# Build the comma-separated list DEVICE_NUM, DEVICE_NUM + 1, ..., DEVICE_NUM + ADDITIONAL_GPU.
# This cell runs before the cell that imports torch.
environ["CUDA_VISIBLE_DEVICES"] = ",".join(
    str(DEVICE_NUM + offset) for offset in range(ADDITIONAL_GPU + 1)
)
environ["CUDA_VISIBLE_DEVICES"]  # last expression of the cell: echoes the value just set

## Imports

In [None]:
import sys

# Apply the Amazon patch: put the override-script directory at the very front of the
# module search path so modules found there take precedence over later path entries.
# The path is relative, so it resolves against the notebook's working directory.
sys.path.insert(
    0,
    "./mxfp4_llm/patch_override_scripts/te1.5",
)

In [None]:
from os import path

import torch
from transformers import Trainer, TrainingArguments, EarlyStoppingCallback
from accelerate import Accelerator, notebook_launcher

import mx
from mx import mx_mapping
import transformer_engine.pytorch as te

import wandb
from tqdm.auto import tqdm

In [None]:
# Select the torch device used by the rest of the notebook.
# A bare "cuda" device is used whether or not additional GPUs were requested, so a
# single availability check is enough.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
    DEVICE_NUM = -1  # overwritten with -1 when no CUDA device is available

# NOTE(review): the ":<DEVICE_NUM>" suffix is printed only when ADDITIONAL_GPU is
# non-zero — confirm that this is the intended condition.
print(f"INFO: Using device - {device}" + (f":{DEVICE_NUM}" if ADDITIONAL_GPU else ""))

In [None]:
# Identity of the Weights & Biases run for this notebook.
# NOTE(review): RUN_NAME mentions "8B" while the reference checkpoint used for the
# tokenizer later in the notebook is "Qwen/Qwen3-0.6B" — confirm the name is intended.
PROJECT_NAME = "MXFP4_Example"
RUN_NAME = "Qwen3_8B_From_Scratch"

# Start the W&B run under the project and run name above.
wandb.init(
    name=RUN_NAME,
    project=PROJECT_NAME,
)

### MXFP4 Config

In [None]:
# MX quantization settings; this dictionary is passed to mx_mapping.inject_pyt_ops.
# NOTE(review): upstream microxscaling names its FP6 element formats "fp6_e3m2" and
# "fp6_e2m3" — confirm that the patched `mx` package accepts the plain "fp6" string.
mx_specs = {
    # Forward pass element formats: weights FP4, activations FP6.
    "w_elem_format": "fp4",
    "a_elem_format": "fp6",
    # Backward pass element formats: weights FP4, activations FP6.
    "w_elem_format_bp": "fp4",
    "a_elem_format_bp": "fp6",
    # Quantization behaviour.
    "quantize_backprop": True,  # also quantize during backprop
    "round": "dither_scale",  # rounding mode (author's note: stochastic rounding)
    "scale_bits": 8,
    "block_size": 32,  # MX block size
    "shared_exp_method": "max",  # shared exponent method
}

In [None]:
# Hand the MX settings to mx_mapping.
# Presumably this swaps PyTorch ops for MX-quantized versions configured by mx_specs —
# verify against the (patched) mx package.
mx_mapping.inject_pyt_ops(mx_specs)

## Define Dataset

In [None]:
# Identifier of the dataset that is passed to load_dataset in the next cell.
dataset_id = "Trelis/tiny-shakespeare"

In [None]:
from datasets import load_dataset

# Load the dataset named by dataset_id; later cells index it by split name ('train').
dataset = load_dataset(dataset_id)

In [None]:
# Display the first record of the 'train' split.
dataset['train'][0]

In [None]:
# Display the field names of the first 'train' record (the tokenization step reads "Text").
dataset['train'][0].keys()

## Load Model

In [None]:
from transformers import Qwen3ForCausalLM, AutoTokenizer

In [None]:
# Checkpoint identifier whose tokenizer is loaded with AutoTokenizer.from_pretrained below.
reference_model_id = "Qwen/Qwen3-0.6B"

In [None]:
# Load the tokenizer of the reference checkpoint.
reference_tokenizer = AutoTokenizer.from_pretrained(reference_model_id, use_fast=True)


def _tokenize_batch(batch):
    """Tokenize the "Text" field of a batch with truncation and padding enabled."""
    return reference_tokenizer(batch["Text"], truncation=True, padding=True)


# batched=True makes map() pass batches of records to the function instead of single records.
tokenized_datasets = dataset.map(_tokenize_batch, batched=True)

### MXFP4 Check

In [None]:
# TODO: Add code here that checks whether the MXFP4 conversion was applied

## Training

In [None]:
# NOTE(review): this is a 4-element tuple, not a single integer. What each slot stands
# for is not shown in this notebook — confirm before passing it to the training config.
BATCH_SIZE = (4, 4, 4, 4)

In [None]:
# Placeholder: TrainingArguments is constructed with no arguments, so only the
# library's defaults apply.
# TODO: fill in the training configuration (BATCH_SIZE, PROJECT_NAME and RUN_NAME are
# defined earlier in the notebook but are not passed here).
training_args = TrainingArguments(
    
)

In [None]:
# Placeholder: Trainer is constructed with no arguments — training_args and
# tokenized_datasets from earlier cells are not passed in, and no model is created
# anywhere in the visible notebook.
# NOTE(review): Trainer is expected to reject a call without a model — confirm and
# fill in the arguments before running this cell.
trainer = Trainer(
    
)

# Start the training loop.
trainer.train()