In [None]:
import math
import os
import re
import types
from pathlib import Path

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np

from mlx_lm.tuner.datasets import load_dataset
from mlx_lm.tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
from mlx_lm.tuner.utils import (
    build_schedule,
    linear_to_lora_layers,
    load_adapters,
    print_trainable_parameters,
)
from mlx_lm.utils import load, save_config


In [None]:
# Run configuration for the fine-tuning cells below. It is converted to an
# attribute namespace (`args`) and also saved verbatim to
# `<adapter_path>/adapter_config.json` via `save_config(vars(args), ...)`.
options = {
    # Base model location, passed to `load(...)`.
    "model": "./models/gemma3-1b-it/transformers",
    # This notebook always runs training, so record that honestly. The flag is
    # also handed to `load_dataset(args, ...)` (presumably enabling its
    # train/valid split checks — confirm against the installed mlx_lm).
    "train": True,
    # One of "full", "lora", "dora".
    "fine_tune_type": "lora",
    # One of "adam", "adamw"; extra constructor kwargs per optimizer below.
    "optimizer": "adam",
    "optimizer_config": {
        "adam": {},
        "adamw": {},
    },
    # Dataset location, consumed by `load_dataset(args, ...)`.
    "data": "./training/",
    "seed": 0,
    # Number of transformer layers to fine-tune (must not exceed the model's).
    "num_layers": 16,
    "batch_size": 1,
    "iters": 1000,
    "val_batches": 25,
    # Used as a constant learning rate unless `lr_schedule` is set.
    "learning_rate": 1e-5,
    "steps_per_report": 10,
    "steps_per_eval": 200,
    # Optional adapter weights to load before training starts.
    "resume_adapter_file": None,
    # Directory that receives adapters.safetensors and adapter_config.json.
    "adapter_path": "adapters",
    "save_every": 100,
    # Test-split evaluation switch and batch count.
    "test": False,
    "test_batches": 500,
    "max_seq_length": 2048,
    "config": None,
    "grad_checkpoint": False,
    "lr_schedule": None,
    "lora_parameters": {"rank": 8, "dropout": 0.0, "scale": 10.0},
    "mask_prompt": False,
}

In [None]:
# Give the options dict attribute-style access (args.model, args.iters, ...)
# for the cells below. Values are shared, not copied: nested dicts such as
# `lora_parameters` are the same objects as in `options`.
args = types.SimpleNamespace()
vars(args).update(options)

In [None]:
# Load the base model + tokenizer, then the train/valid/test splits.
print("Loading pretrained model")
model, tokenizer = load(args.model)

print("Loading datasets")
train_set, valid_set, test_set = load_dataset(args, tokenizer)

# The fine-tuning cell always trains and periodically validates, so fail fast
# on empty splits here, before the model is converted and the adapter config is
# written, instead of deep inside the training loop.
if len(train_set) == 0:
    raise ValueError(
        f"Training set not found or empty under {args.data!r}. "
        "Must provide training set for fine-tuning."
    )
if len(valid_set) == 0:
    raise ValueError(
        f"Validation set not found or empty under {args.data!r}. "
        "Must provide validation set for fine-tuning."
    )

In [None]:
# Prepare the model for fine-tuning, build the optimizer, train, and optionally
# evaluate on the test split.
#
# NOTE(review): this cell mutates `model` in place (freeze + layer conversion).
# Re-running it without first re-running the loading cell presumably applies the
# conversion a second time; reload the model before re-running.
training_callback = None

# Seed numpy as well as MLX. Without the numpy seed, any numpy-based randomness
# in the training pipeline (presumably batch shuffling) differs between runs
# even though `args.seed` is fixed.
mx.random.seed(args.seed)
np.random.seed(args.seed)

model.freeze()
if args.num_layers > len(model.layers):
    raise ValueError(
        f"Requested to train {args.num_layers} layers "
        f"but the model only has {len(model.layers)} layers."
    )

if args.fine_tune_type == "full":
    # num_layers <= 0 selects every layer: the slice start becomes -0 == 0.
    for layer in model.layers[-max(args.num_layers, 0) :]:
        layer.unfreeze()
elif args.fine_tune_type in ["lora", "dora"]:
    # Convert linear layers to lora/dora layers and unfreeze in the process
    linear_to_lora_layers(
        model,
        args.num_layers,
        args.lora_parameters,
        use_dora=(args.fine_tune_type == "dora"),
    )
else:
    raise ValueError(f"Received unknown fine-tune-type {args.fine_tune_type}")

# Resume from weights if provided (non-strict: only adapter tensors expected).
if args.resume_adapter_file is not None:
    print(f"Loading fine-tuned weights from {args.resume_adapter_file}")
    model.load_weights(args.resume_adapter_file, strict=False)

print_trainable_parameters(model)

adapter_path = Path(args.adapter_path)
adapter_path.mkdir(parents=True, exist_ok=True)

adapter_file = adapter_path / "adapters.safetensors"
save_config(vars(args), adapter_path / "adapter_config.json")

# init training args
training_args = TrainingArgs(
    batch_size=args.batch_size,
    iters=args.iters,
    val_batches=args.val_batches,
    steps_per_report=args.steps_per_report,
    steps_per_eval=args.steps_per_eval,
    steps_per_save=args.save_every,
    adapter_file=adapter_file,
    max_seq_length=args.max_seq_length,
    grad_checkpoint=args.grad_checkpoint,
)

# Initialize the selected optimizer; a configured schedule overrides the
# constant learning rate.
lr = build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate

optimizer_name = args.optimizer.lower()
optimizer_config = args.optimizer_config.get(optimizer_name, {})

if optimizer_name == "adam":
    opt_class = optim.Adam
elif optimizer_name == "adamw":
    opt_class = optim.AdamW
else:
    raise ValueError(f"Unsupported optimizer: {optimizer_name}")

opt = opt_class(learning_rate=lr, **optimizer_config)

# Train model
train(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    optimizer=opt,
    train_dataset=train_set,
    val_dataset=valid_set,
    training_callback=training_callback,
)

# Optional held-out evaluation, driven by the `test` / `test_batches` options
# (previously defined but never used). Keyword names follow the mlx_lm version
# whose `train` accepts `tokenizer`; confirm if upgrading mlx_lm.
if args.test:
    print("Testing")
    model.eval()
    test_loss = evaluate(
        model=model,
        dataset=test_set,
        tokenizer=tokenizer,
        batch_size=args.batch_size,
        num_batches=args.test_batches,
        max_seq_length=args.max_seq_length,
    )
    print(f"Test loss {test_loss:.3f}, Test ppl {math.exp(test_loss):.3f}.")