In [4]:
from datasets import Dataset
import pandas as pd 

In [52]:
# Read the merged table from disk and wrap it as a Hugging Face Dataset.
# Later cells read its "inputs" and "labels" columns.
df = pd.read_csv(filepath_or_buffer="merged.csv")
dataset = Dataset.from_pandas(df=df)

In [53]:
# Hold out 40% of the rows; that hold-out is later divided into eval and test.
# The fixed seed makes the shuffled split reproducible across kernel restarts,
# so the same rows always land in the same split.
ds = dataset.train_test_split(test_size=0.4, shuffle=True, seed=42)

In [54]:
# 60% of the rows are used for training.
train_data = ds["train"]

# Divide the 40% hold-out again: 70% of it becomes the eval split and 30% the
# test split. The fixed seed keeps this shuffled split reproducible, so eval
# and test rows do not change between runs.
eval_test_data = ds["test"].train_test_split(test_size=0.3, shuffle=True, seed=42)
eval_data = eval_test_data["train"]
test_data = eval_test_data["test"]

In [55]:
# System-style instruction (German) that is prepended to every example's
# "inputs" text. Two defects in the literal are fixed here: the duplicated
# "in in" produced by the string concatenation, and the typo "Falsl" -> "Falls".
# NOTE(review): an adapter trained with this prompt should be queried with the
# exact same prompt text at inference time.
INIT_PROMPT = (
    "Du bist ein hochentwickelter KI-Schreibassistent, der professionelle, präzise und technische Texte "
    "in Deutsch und Englisch verfasst. Falls der Benutzer keine Sprache angibt, nutze die Sprache der Eingabe.\n\n"
)

def _prepend_init_prompt(example):
    """Return the example's "inputs" text with INIT_PROMPT in front of it.

    The cell overwrites train_data / eval_data / test_data with their mapped
    versions, so running it a second time would otherwise prepend the prompt
    twice. Text that already starts with INIT_PROMPT is returned unchanged,
    which makes the cell safe to re-run.
    """
    text = example["inputs"]
    if text.startswith(INIT_PROMPT):
        return {"inputs": text}
    return {"inputs": INIT_PROMPT + text}


# One shared helper instead of three copies of the same lambda.
train_data = train_data.map(_prepend_init_prompt)
eval_data = eval_data.map(_prepend_init_prompt)
test_data = test_data.map(_prepend_init_prompt)

Map: 100%|██████████| 2674/2674 [00:00<00:00, 36995.54 examples/s]
Map: 100%|██████████| 1248/1248 [00:00<00:00, 48436.12 examples/s]
Map: 100%|██████████| 536/536 [00:00<00:00, 40225.21 examples/s]


In [61]:
from pathlib import Path
import json
from mlx_lm.tuner import TrainingArgs

# Directory that receives the adapter config and, after training, the weights.
adapter_path = Path("adapter")
adapter_path.mkdir(exist_ok=True)

# LoRA settings. "num_layers" and "lora_parameters" are the two values passed
# to linear_to_lora_layers() in a later cell.
lora_config = dict(
    num_layers=8,
    lora_parameters=dict(rank=8, scale=16, dropout=0.05),
)

# Persist the settings as JSON alongside the adapter weights.
(adapter_path / "adapter_config.json").write_text(json.dumps(lora_config))


# Training schedule: 50 iterations, with validation every 5 iterations.
# The adapter weights are written into the adapter directory.
training_args = TrainingArgs(
    iters=50,
    steps_per_eval=5,
    adapter_file=adapter_path.joinpath("adapters.safetensors"),
)


In [62]:
# Hugging Face repo id of the checkpoint to fine-tune. The previous value
# ("mlx/mistral-7b-quantized") was overwritten in the next cell before any use,
# so it is aligned here with the checkpoint that is actually loaded.
model_name = "mlx-community/Mistral-7B-Instruct-v0.3-4bit"

In [63]:
from mlx_lm import load

# Repo id of the 4-bit Mistral-7B-Instruct checkpoint used as the base model.
model_name = "mlx-community/Mistral-7B-Instruct-v0.3-4bit"
# load() returns the model together with its tokenizer; both are used by the
# dataset wrapper and the training call in later cells.
model, tokenizer = load(model_name)

Fetching 7 files: 100%|██████████| 7/7 [00:00<00:00, 88434.12it/s]


In [64]:
# Freeze the base model's parameters before the LoRA layers are attached in a
# later cell. The call is the cell's last expression, so the returned module's
# repr is echoed as the cell output.
model.freeze()

Model(
  (model): LlamaModel(
    (embed_tokens): QuantizedEmbedding(32768, 4096, group_size=64, bits=4)
    (layers.0): TransformerBlock(
      (self_attn): Attention(
        (q_proj): QuantizedLinear(input_dims=4096, output_dims=4096, bias=False, group_size=64, bits=4)
        (k_proj): QuantizedLinear(input_dims=4096, output_dims=1024, bias=False, group_size=64, bits=4)
        (v_proj): QuantizedLinear(input_dims=4096, output_dims=1024, bias=False, group_size=64, bits=4)
        (o_proj): QuantizedLinear(input_dims=4096, output_dims=4096, bias=False, group_size=64, bits=4)
        (rope): RoPE(128, traditional=False)
      )
      (mlp): MLP(
        (gate_proj): QuantizedLinear(input_dims=4096, output_dims=14336, bias=False, group_size=64, bits=4)
        (down_proj): QuantizedLinear(input_dims=14336, output_dims=4096, bias=False, group_size=64, bits=4)
        (up_proj): QuantizedLinear(input_dims=4096, output_dims=14336, bias=False, group_size=64, bits=4)
      )
      (input_l

In [65]:
from mlx_lm.tuner.datasets import CompletionsDataset


def make_dataset(dataset):
    """Wrap one data split in an mlx_lm ``CompletionsDataset``.

    Args:
        dataset: A split whose records carry an "inputs" field (the prompt)
            and a "labels" field (the completion).

    Returns:
        A ``CompletionsDataset`` built with the notebook-level ``tokenizer``
        and with ``mask_prompt`` set to False.
    """
    options = {
        "prompt_key": "inputs",
        "completion_key": "labels",
        "mask_prompt": False,
    }
    return CompletionsDataset(dataset, tokenizer, **options)


# Build the three training-ready datasets in train / eval / test order.
train_dataset, eval_dataset, test_dataset = (
    make_dataset(split) for split in (train_data, eval_data, test_data)
)

In [66]:
from mlx_lm.tuner import train, evaluate, linear_to_lora_layers
from mlx_lm.utils import tree_flatten


# Attach LoRA layers to the model, using the layer count and LoRA parameters
# from lora_config.
linear_to_lora_layers(
    model,
    lora_config["num_layers"],
    lora_config["lora_parameters"],
)

# Total number of scalar values across all trainable parameter arrays.
num_train_params = sum(
    param.size for _name, param in tree_flatten(model.trainable_parameters())
)

print(f"Num of trainable parameters: {num_train_params}")

Num of trainable parameters: 851968


In [67]:
class Metrics:
    """Training callback that records the losses reported during training.

    Attributes:
        train_losses: ``(iteration, train_loss)`` tuples in reporting order.
        val_losses: ``(iteration, val_loss)`` tuples in reporting order.
    """

    def __init__(self):
        # Per-instance lists. As class-level attributes the lists were shared
        # by every Metrics object, so a newly created callback would still
        # carry the history recorded by an earlier one.
        self.train_losses = []
        self.val_losses = []

    def on_train_loss_report(self, info):
        """Store the iteration number and training loss from ``info``."""
        self.train_losses.append((info["iteration"], info["train_loss"]))

    def on_val_loss_report(self, info):
        """Store the iteration number and validation loss from ``info``."""
        self.val_losses.append((info["iteration"], info["val_loss"]))

# Callback instance for loss tracking.
# NOTE(review): `metrics` is assigned again in a later cell before train() is
# called, so this instance is not the one handed to the training loop.
metrics = Metrics()

In [68]:
# Put the model into training mode before the training loop runs (presumably
# needed so the LoRA dropout is active — confirm against the MLX docs).
# The returned module's repr is echoed as the cell output.
model.train()

Model(
  (model): LlamaModel(
    (embed_tokens): QuantizedEmbedding(32768, 4096, group_size=64, bits=4)
    (layers.0): TransformerBlock(
      (self_attn): Attention(
        (q_proj): QuantizedLinear(input_dims=4096, output_dims=4096, bias=False, group_size=64, bits=4)
        (k_proj): QuantizedLinear(input_dims=4096, output_dims=1024, bias=False, group_size=64, bits=4)
        (v_proj): QuantizedLinear(input_dims=4096, output_dims=1024, bias=False, group_size=64, bits=4)
        (o_proj): QuantizedLinear(input_dims=4096, output_dims=4096, bias=False, group_size=64, bits=4)
        (rope): RoPE(128, traditional=False)
      )
      (mlp): MLP(
        (gate_proj): QuantizedLinear(input_dims=4096, output_dims=14336, bias=False, group_size=64, bits=4)
        (down_proj): QuantizedLinear(input_dims=14336, output_dims=4096, bias=False, group_size=64, bits=4)
        (up_proj): QuantizedLinear(input_dims=4096, output_dims=14336, bias=False, group_size=64, bits=4)
      )
      (input_l

In [70]:
from mlx_lm.tuner import train, evaluate
from mlx.optimizers import Adam

# Callback object that is handed to train() in the next cell.
metrics = Metrics()

# Adam with a constant learning rate of 1e-4.
optimizer = Adam(learning_rate=1e-4)

In [71]:
# Run LoRA fine-tuning. The schedule and adapter output file come from
# training_args; the losses reported during training go to the metrics callback.
train(
    model=model,
    tokenizer=tokenizer,
    optimizer=optimizer,
    train_dataset=train_dataset,
    val_dataset=eval_dataset,
    args=training_args,
    training_callback=metrics,
)

Starting training..., iters: 50
Iter 1: Val loss 2.989, Val took 78.536s
Iter 5: Val loss 1.936, Val took 87.504s
Iter 10: Val loss 1.660, Val took 79.379s
Iter 10: Train loss 2.187, Learning Rate 1.000e-04, It/sec 0.254, Tokens/sec 217.030, Trained Tokens 8558, Peak mem 11.071 GB
Iter 15: Val loss 1.747, Val took 94.007s
Iter 20: Val loss 1.570, Val took 76.138s
Iter 20: Train loss 1.532, Learning Rate 1.000e-04, It/sec 0.244, Tokens/sec 215.769, Trained Tokens 17386, Peak mem 11.287 GB
Iter 25: Val loss 1.624, Val took 85.701s
Iter 30: Val loss 1.580, Val took 80.137s
Iter 30: Train loss 1.483, Learning Rate 1.000e-04, It/sec 0.260, Tokens/sec 217.682, Trained Tokens 25757, Peak mem 11.764 GB
Iter 35: Val loss 1.606, Val took 88.758s
Iter 40: Val loss 1.554, Val took 81.440s
Iter 40: Train loss 1.540, Learning Rate 1.000e-04, It/sec 0.218, Tokens/sec 218.162, Trained Tokens 35772, Peak mem 11.961 GB
Iter 45: Val loss 1.555, Val took 82.588s
Iter 50: Val loss 1.572, Val took 78.597s
I