In [1]:
import json
import os
from typing import Dict, List, Optional, Tuple, Union

import mlx.optimizers as optim
from mlx.utils import tree_flatten
from mlx_lm import generate, load
from mlx_lm.tuner import TrainingArgs, datasets, linear_to_lora_layers, train
from transformers import PreTrainedTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_path = "mlx-community/Phi-3.5-mini-instruct-4bit"
model, tokenizer = load(model_path)

Fetching 11 files: 100%|██████████| 11/11 [00:00<00:00, 196329.12it/s]


In [3]:
adapter_path = "adapters"
os.makedirs(adapter_path, exist_ok=True)
adapter_config_path = os.path.join(adapter_path, "adapter_config.json")
adapter_file_path = os.path.join(adapter_path, "adapters.safetensors")

# Lora config

In [4]:
lora_config = {
    "num_layers": 8,
    "lora_parameters": {
        "rank": 8,
        "scale": 20.0,
        "dropout": 0.0,
    },
}

In [5]:
with open(adapter_config_path, "w") as f:
    json.dump(lora_config, f, indent=4)

In [19]:
training_args = TrainingArgs(
    adapter_file=adapter_file_path,
    iters=200,
    steps_per_eval=50,
)

In [7]:
model.freeze()
linear_to_lora_layers(model, lora_config["num_layers"], lora_config["lora_parameters"])
num_train_params = sum(v.size for _, v in tree_flatten(model.trainable_parameters()))
print(f"Number of trainable parameters: {num_train_params}")
model.train()

Number of trainable parameters: 786432


Model(
  (model): Phi3Model(
    (embed_tokens): QuantizedEmbedding(32064, 3072, group_size=64, bits=4)
    (layers.0): TransformerBlock(
      (self_attn): Attention(
        (qkv_proj): QuantizedLinear(input_dims=3072, output_dims=9216, bias=False, group_size=64, bits=4)
        (o_proj): QuantizedLinear(input_dims=3072, output_dims=3072, bias=False, group_size=64, bits=4)
        (rope): SuScaledRoPE()
      )
      (mlp): MLP(
        (gate_up_proj): QuantizedLinear(input_dims=3072, output_dims=16384, bias=False, group_size=64, bits=4)
        (down_proj): QuantizedLinear(input_dims=8192, output_dims=3072, bias=False, group_size=64, bits=4)
      )
      (input_layernorm): RMSNorm(3072, eps=1e-05)
      (post_attention_layernorm): RMSNorm(3072, eps=1e-05)
    )
    (layers.1): TransformerBlock(
      (self_attn): Attention(
        (qkv_proj): QuantizedLinear(input_dims=3072, output_dims=9216, bias=False, group_size=64, bits=4)
        (o_proj): QuantizedLinear(input_dims=3072, out

In [8]:
class Metrics:
    def __init__(self) -> None:
        self.train_losses: List[Tuple[int, float]] = []
        self.val_losses: List[Tuple[int, float]] = []

    def on_train_loss_report(self, info: Dict[str, Union[float, int]]) -> None:
        self.train_losses.append((info["iteration"], info["train_loss"]))

    def on_val_loss_report(self, info: Dict[str, Union[float, int]]) -> None:
        self.val_losses.append((info["iteration"], info["val_loss"]))

In [9]:
metrics = Metrics()

# load data

In [22]:
from mlx_lm.tuner.datasets import load_hf_dataset
#config = {
#        'prompt_feature': 'instruction',
#        'completion_feature': 'response',
#}
config = { }
train_set, val_set, test_set = load_hf_dataset(
    data_id="mlx-community/wikisql",
    tokenizer=tokenizer,
    config=config,
)

Generating train split: 100%|██████████| 1000/1000 [00:00<00:00, 15145.56 examples/s]
Generating valid split: 100%|██████████| 100/100 [00:00<00:00, 21597.86 examples/s]
Generating test split: 100%|██████████| 100/100 [00:00<00:00, 73908.44 examples/s]


In [23]:
print(f"Test set size: {len(test_set)}")
print(f"Validation set size: {len(val_set)}")
print(f"Training set size: {len(train_set)}")
print(f"test set: {test_set[:2]}")

Test set size: 100
Validation set size: 100
Training set size: 1000
test set: {'text': ["table: 1-10015132-16\ncolumns: Player, No., Nationality, Position, Years in Toronto, School/Club Team\nQ: What is terrence ross' nationality\nA: SELECT Nationality FROM 1-10015132-16 WHERE Player = 'Terrence Ross'", "table: 1-10015132-16\ncolumns: Player, No., Nationality, Position, Years in Toronto, School/Club Team\nQ: What clu was in toronto 1995-96\nA: SELECT School/Club Team FROM 1-10015132-16 WHERE Years in Toronto = '1995-96'"]}


In [None]:
from mlx_lm.tuner.datasets import CacheDataset

train_dataset = CacheDataset(train_set)
val_dataset = CacheDataset(val_set)

train(
    model,
    optim.Adam(learning_rate=1e-5),
    train_dataset,
    val_dataset,
    args=training_args,
    training_callback=metrics
)


Starting training..., iters: 200


Calculating loss...: 100%|██████████| 25/25 [00:54<00:00,  2.18s/it]


Iter 1: Val loss 3.253, Val took 54.847s
Iter 10: Train loss 3.341, Learning Rate 1.000e-05, It/sec 0.047, Tokens/sec 17.872, Trained Tokens 3771, Peak mem 3.362 GB
Iter 20: Train loss 2.551, Learning Rate 1.000e-05, It/sec 0.439, Tokens/sec 166.693, Trained Tokens 7566, Peak mem 3.362 GB
Iter 30: Train loss 2.096, Learning Rate 1.000e-05, It/sec 0.227, Tokens/sec 94.214, Trained Tokens 11721, Peak mem 3.548 GB
Iter 40: Train loss 1.886, Learning Rate 1.000e-05, It/sec 0.167, Tokens/sec 67.639, Trained Tokens 15765, Peak mem 3.548 GB


Calculating loss...: 100%|██████████| 25/25 [01:52<00:00,  4.52s/it]

Iter 50: Val loss 1.675, Val took 113.331s





Iter 50: Train loss 1.686, Learning Rate 1.000e-05, It/sec 0.190, Tokens/sec 74.876, Trained Tokens 19701, Peak mem 3.548 GB
Iter 60: Train loss 1.609, Learning Rate 1.000e-05, It/sec 0.242, Tokens/sec 86.037, Trained Tokens 23258, Peak mem 3.548 GB
Iter 70: Train loss 1.637, Learning Rate 1.000e-05, It/sec 0.216, Tokens/sec 83.225, Trained Tokens 27106, Peak mem 3.548 GB
Iter 80: Train loss 1.447, Learning Rate 1.000e-05, It/sec 0.228, Tokens/sec 88.971, Trained Tokens 31007, Peak mem 3.548 GB
Iter 90: Train loss 1.350, Learning Rate 1.000e-05, It/sec 0.226, Tokens/sec 87.549, Trained Tokens 34877, Peak mem 3.548 GB


Calculating loss...: 100%|██████████| 25/25 [01:51<00:00,  4.46s/it]


Iter 100: Val loss 1.430, Val took 112.322s
Iter 100: Train loss 1.548, Learning Rate 1.000e-05, It/sec 0.209, Tokens/sec 81.556, Trained Tokens 38773, Peak mem 3.548 GB
Iter 100: Saved adapter weights to adapters/adapters.safetensors and adapters/0000100_adapters.safetensors.
Iter 110: Train loss 1.460, Learning Rate 1.000e-05, It/sec 0.154, Tokens/sec 67.827, Trained Tokens 43174, Peak mem 4.034 GB
Iter 120: Train loss 1.462, Learning Rate 1.000e-05, It/sec 0.196, Tokens/sec 73.759, Trained Tokens 46939, Peak mem 4.034 GB
Iter 130: Train loss 1.293, Learning Rate 1.000e-05, It/sec 0.178, Tokens/sec 71.431, Trained Tokens 50959, Peak mem 4.034 GB
Iter 140: Train loss 1.378, Learning Rate 1.000e-05, It/sec 0.191, Tokens/sec 72.459, Trained Tokens 54749, Peak mem 4.034 GB


Calculating loss...: 100%|██████████| 25/25 [03:05<00:00,  7.43s/it]


Iter 150: Val loss 1.352, Val took 186.999s
Iter 150: Train loss 1.187, Learning Rate 1.000e-05, It/sec 0.111, Tokens/sec 41.553, Trained Tokens 58493, Peak mem 4.034 GB
