In [1]:
from lora import LORA, FUSE
from load_data import LoadData
from mlx_lm import generate, utils

None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [None]:
# Configuration

root_folder = "model_a_1"

data_folder = f"./{root_folder}/data"
dataset_name = "n4jiDX/Math-Problems"
n = 1000
test_split_ratio = 0.2
valid_split_ratio = 0.2

model_path = "mistralai/Mistral-7B-Instruct-v0.2"
adapter_file = f"./{root_folder}/adapters.npz"
save_model_path = f"./{root_folder}/model"

In [None]:
# Create Train, Test and Validation Data from the Dataset

system_message = """
You are a math problem solver. Given a math problem, you will provide a step-by-step solution.
Use the following format:
Problem: <The math problem>
Solution: <Step-by-step solution>
"""

def create_conversation(input: dict) -> dict:
    return {
        "messages": [
            {"role": "system", "content": system_message.strip()},
            {"role": "user", "content": input["Problem"]},
            {"role": "assistant", "content": input["Solution"]}
        ]
    }

data_loader = LoadData(folder=data_folder, dataset_name=dataset_name)
data_loader.save(function=create_conversation, n=n, test_split_ratio=test_split_ratio, valid_split_ratio=valid_split_ratio, write_files=True)

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

{'train': './model_a_1/data/train.jsonl',
 'test': './model_a_1/data/test.jsonl',
 'valid': './model_a_1/data/valid.jsonl'}

In [None]:
# Fine-Tuning with LoRA

lora = LORA(config={"train": True, "batch_size": 1, "lora_layers": 4, "adapter_file": adapter_file})
lora.invoke(model_path=model_path, data=data_folder)

Fetching 11 files:   0%|          | 0/11 [00:00<?, ?it/s]

Total parameters 7242.158M
Trainable parameters 0.426M
Loading datasets
Training
Iter 1: Val loss 0.998, Val took 37.477s
Iter 10: Train loss 0.960, It/sec 0.471, Tokens/sec 443.263
Iter 20: Train loss 0.818, It/sec 0.512, Tokens/sec 389.278
Iter 30: Train loss 0.850, It/sec 0.314, Tokens/sec 149.099
Iter 40: Train loss 0.823, It/sec 0.226, Tokens/sec 136.215
Iter 50: Train loss 0.881, It/sec 0.195, Tokens/sec 88.031
Iter 60: Train loss 0.709, It/sec 0.163, Tokens/sec 83.023
Iter 70: Train loss 0.802, It/sec 0.218, Tokens/sec 124.225
Iter 80: Train loss 0.725, It/sec 0.217, Tokens/sec 154.478
Iter 90: Train loss 0.771, It/sec 0.196, Tokens/sec 195.234
Iter 100: Train loss 0.741, It/sec 0.287, Tokens/sec 180.082
Iter 100: Saved adapter weights to ./model_a_1/adapters.npz.
Iter 110: Train loss 0.716, It/sec 0.270, Tokens/sec 131.337
Iter 120: Train loss 0.876, It/sec 0.322, Tokens/sec 182.692
Iter 130: Train loss 0.814, It/sec 0.389, Tokens/sec 233.204
Iter 140: Train loss 0.858, It/sec 

In [None]:
# Fuse the LoRA adapters with the base model and save the fine-tuned model

fuse = FUSE(config={"adapter_file": adapter_file})
fuse.invoke(model_path=model_path, save_path=save_model_path)

Fetching 11 files:   0%|          | 0/11 [00:00<?, ?it/s]

In [None]:
# Load the fine-tuned model and generate a response

model, tokenizer = utils.load(save_model_path)
generate(model=model, tokenizer=tokenizer, prompt="You are a math problem solver. Given a math problem, you will provide a step-by-step solution.\nUse the following format:\nProblem: <The math problem>\nSolution: <Step-by-step solution>\nUser:Find the product of the solutions of the equation: $|y| = 3(|y| - 2)$.\nAssistant:")

'To solve this problem, we first need to find the solutions of the equation $|y| = 3(|y| - 2)$.\n\nSolving the equation $|y| = 3(|y| - 2)$ means finding the values of $y$ that satisfy the equation. Since $|y|$ is the absolute value of $y$, the equation is satisfied when $y$ is either 3 or $-3$ (the positive and negative solutions).\n\nNow, we need to find the product of these solutions:\n\n$$\n\\text{Product of solutions} = y_1 \\times y_2 \\\\\n= 3 \\times (-3) \\\\\n= -3 \\times 3 \\\\\n= -9\n$$\n\nSo, the product of the solutions of the equation $|y| = 3(|y| - 2)$ is $\\boxed{-9}$.'