# Fine-tuning the latest Google Gemma model locally using MLX

## 1. Preparation


In [58]:
# !pip install -Uqq mlx mlx_lm transformers datasets

## 2. Using MLX to Run Inference with Gemma Model using MLX

In [1]:
from mlx_lm import generate, load

model, tokenizer = load("mlx-community/gemma-2-27b-it-4bit")
# https://huggingface.co/mlx-community/gemma-2-27b-4bit 


Fetching 9 files:   0%|          | 0/9 [00:00<?, ?it/s]

In [2]:
messages = [{"role": "user", "content": "What are the normal working hours for customs at Dar es Salaam Port?"}]
result = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(result)

<bos><start_of_turn>user
What are the normal working hours for customs at Dar es Salaam Port?<end_of_turn>
<start_of_turn>model



In [3]:
# Generating without adding a prompt template manually
prompt = """
What are the normal working hours for customs at Dar es Salaam Port?
""".strip()
response = generate(
    model,
    tokenizer,
    prompt=prompt,
    verbose=True,  # Set to True to see the prompt and response
    temp=0.0,
    max_tokens=1024,
)

Prompt: What are the normal working hours for customs at Dar es Salaam Port?


The normal working hours for customs at Dar es Salaam Port are:

**Monday to Friday:** 8:00 AM to 5:00 PM

**Saturday:** 8:00 AM to 12:00 PM

**Sunday:** Closed

**Please note:**

* These are general working hours and may vary depending on the specific customs office and workload.
* It is always best to confirm the working hours with the relevant customs office before visiting.
* Customs offices may also be open outside of normal working hours for urgent matters.

You can find contact information for the Dar es Salaam Port customs office on the Tanzania Revenue Authority (TRA) website.
<end_of_turn>

Prompt: 15 tokens, 31.024 tokens-per-sec
Generation: 144 tokens, 16.006 tokens-per-sec
Peak memory: 14.614 GB


## 3. Generating training dataset from PDF file

In [None]:
import PyPDF2
import csv, re
import random
from tqdm import tqdm
import time
from transformers import AutoTokenizer
from mlx_lm import load, generate
import subprocess

def run_command_with_live_output(command: list[str]) -> None:
    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    while True:
        output = process.stdout.readline()
        if output == '' and process.poll() is not None:
            break
        if output:
            print(output.strip())
    err_output = process.stderr.read()
    if err_output:
        print(err_output)

def extract_text_from_pdf(pdf_path):
    print("Extracting text from PDF...")
    with open(pdf_path, 'rb') as file:
        reader = PyPDF2.PdfReader(file)
        text = ""
        for page in tqdm(reader.pages, desc="Processing pages"):
            text += page.extract_text()
    print("Text extraction completed")
    return text

def create_chunks(text, min_chunk_size=200, max_chunk_size=1000):
    print("Splitting text into chunks...")
    chunks = []
    current_chunk = ""
    sentences = text.split('.')
    for sentence in tqdm(sentences, desc="Creating chunks"):
        if len(current_chunk) + len(sentence) > max_chunk_size and len(current_chunk) >= min_chunk_size:
            chunks.append(current_chunk.strip())
            current_chunk = sentence
        else:
            current_chunk += sentence + '.'
    if current_chunk:
        chunks.append(current_chunk.strip())
    print(f"Number of chunks created: {len(chunks)}")
    return chunks

def load_model_and_tokenizer(model_path):
    print("Loading model and tokenizer...")
    model, tokenizer = load(model_path)
    return model, tokenizer

def generate_question_response_pair(chunk, model, tokenizer, max_tokens=512):
    instruction = """
    Based on the compliance requirements by CHERRY shipping line, 
    create a concise question and provide a brief, informative answer. 
    All the questions should be based on a country if it is required per provided documentation.
    Do not include any special markers like '**' or '<end_of_turn>'
    Format your response as follows:
    Question: [Your generated question]
    Answer: [Your generated answer]
"""

    prompt = f'''<s>[INST] {instruction}\n\nText: {chunk} [/INST]\n'''
    
    # generated_text = generate(model, tokenizer, prompt=prompt, max_tokens=max_tokens, verbose=False)

    generated_text = generate(
        model, 
        tokenizer, 
        prompt=prompt, 
        max_tokens=max_tokens, 
        temp=0.8,  # Value between 0.0 and 1.0, higher values produce more creative output
        # top_k=50,         # Number of top tokens to consider
        top_p=0.95,       # Select tokens until cumulative probability exceeds this value
        verbose=False
    )

    generated_text = generated_text.replace(prompt, "").strip()
    generated_text = re.sub(r'\*\*|<end_of_turn>', '', generated_text)
    generated_text = re.sub(r'```', '', generated_text)
    
    question_start = generated_text.find("Question:")
    answer_start = generated_text.find("Answer:")
    
    if question_start != -1 and answer_start != -1:
        question = generated_text[question_start+9:answer_start].strip()
        answer = generated_text[answer_start+7:].strip()
    else:
        question = "Unable to generate a question."
        answer = "Unable to generate a response."
    
    return question, answer

def create_fine_tuning_dataset(pdf_path, output_file, num_samples=10000):
    start_time = time.time()
    print(f"Starting fine-tuning dataset creation (Target samples: {num_samples})")

    model_path = "mlx-community/gemma-2-27b-it-4bit"
    model, tokenizer = load_model_and_tokenizer(model_path)
    print("Model and tokenizer loaded")

    text = extract_text_from_pdf(pdf_path)
    chunks = create_chunks(text)
    
    print("Generating question-response pairs and writing to CSV...")
    with open(output_file, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(['question', 'response'])
        
        for _ in tqdm(range(num_samples), desc="Generating samples"):
            chunk = random.choice(chunks)
            question, response = generate_question_response_pair(chunk, model, tokenizer)
            if question != "Unable to generate a question." and response != "Unable to generate a response.":
                writer.writerow([question, response])
            else:
                print(f"Skipping invalid sample: Q: {question}, A: {response}")
            if (_ + 1) % 100 == 0:
                print(f"Generated {_ + 1} samples...")
    
    end_time = time.time()
    print(f"Number of dataset samples generated: {num_samples}")
    print(f"Total time taken: {end_time - start_time:.2f} seconds")

# Usage example
pdf_path = "./data/CHERRY_compliance.pdf"
output_file = "./data/CHERRY_compliance.csv"
create_fine_tuning_dataset(pdf_path, output_file, num_samples=10000)

## 4. Building dataframe

In [8]:
import csv
import random
import json
import pandas as pd

# load the dataset
dataset = pd.read_csv("./data/CHERRY_compliance.csv")

system_prompt = """
You are CHERRY_Compliance, a virtual expert for CHERRY Shipping Company. Provide accurate information on:

- Documentation Requirements
- Operational Requirements
- Restrictions and Prohibited Items 
- Additional Requirements
- Special Handling Requirements
- Contact Information

Key behaviors:

Offer current information on CHERRY's policies and maritime regulations
Use clear language, adapting to the user's expertise level
Emphasize safety, compliance, and environmental protection
Provide practical advice and clarify misunderstandings
Reference the CHERRY Compliance document for detailed information
For country-specific queries, use the relevant country's section in the document

When responding:

Be concise yet thorough
If information is unavailable or unknown, clearly state "No Information Available"
Highlight recent regulatory changes when relevant
End each response with '–CHERRY_Compliance'
"""

dataset["system_prompt"] = system_prompt
dataset["system_prompt"] = dataset["system_prompt"].str.strip()
dataset = dataset[["system_prompt", "question", "response"]]
dataset

Unnamed: 0,system_prompt,question,response
0,"You are CHERRY_Compliance, a virtual expert fo...",What is the maximum cargo weight allowed for a...,The maximum cargo weight allowed for a 40-foot...
1,"You are CHERRY_Compliance, a virtual expert fo...",What are the maximum gross weight limitations ...,The maximum gross weight for a 20' container i...
2,"You are CHERRY_Compliance, a virtual expert fo...",What type of certificates are required for agr...,Agricultural products may require phytosanitar...
3,"You are CHERRY_Compliance, a virtual expert fo...",What Indonesian law governs shipments handled ...,Shipments must comply with Indonesian Customs ...
4,"You are CHERRY_Compliance, a virtual expert fo...",What contact information is mandatory for both...,"For both the shipper and consignee, an email a..."
...,...,...,...
5953,"You are CHERRY_Compliance, a virtual expert fo...",What are the required packaging details for sh...,Both outer and inner packaging details must be...
5954,"You are CHERRY_Compliance, a virtual expert fo...",What is the required advance notice for bookin...,At least 7 days prior to vessel arrival.
5955,"You are CHERRY_Compliance, a virtual expert fo...",What import taxes are applicable to shipments ...,Shipments to Chile are subject to Import VAT a...
5956,"You are CHERRY_Compliance, a virtual expert fo...",What are the required documentation for shippi...,A Dangerous Goods Declaration in Chinese and S...


In [9]:
# Count records for Unable to generate a question & Unable to generate a response
filtered_df = dataset[
    (dataset["question"] == "Unable to generate a question.") & 
    (dataset["response"] == "Unable to generate a response.")
]

count = filtered_df.shape[0]
print(f"Number of rows with the specified question and response: {count}")

Number of rows with the specified question and response: 0


In [10]:
# Count records for duplicated question and response
filtered_df = dataset[
    dataset.duplicated(subset=["question", "response"], keep=False)
]

count = filtered_df.shape[0]
print(f"Number of rows with the same question and response: {count}")

Number of rows with the same question and response: 1167


In [19]:
# Drop duplicates
df = dataset.drop_duplicates(subset=["question", "response"])
df

Unnamed: 0,system_prompt,question,response
0,"You are CHERRY_Compliance, a virtual expert fo...",What is the maximum cargo weight allowed for a...,The maximum cargo weight allowed for a 40-foot...
1,"You are CHERRY_Compliance, a virtual expert fo...",What are the maximum gross weight limitations ...,The maximum gross weight for a 20' container i...
2,"You are CHERRY_Compliance, a virtual expert fo...",What type of certificates are required for agr...,Agricultural products may require phytosanitar...
3,"You are CHERRY_Compliance, a virtual expert fo...",What Indonesian law governs shipments handled ...,Shipments must comply with Indonesian Customs ...
4,"You are CHERRY_Compliance, a virtual expert fo...",What contact information is mandatory for both...,"For both the shipper and consignee, an email a..."
...,...,...,...
5952,"You are CHERRY_Compliance, a virtual expert fo...",What contact information should be used for qu...,You can reach CHERRY Shipping Line's India off...
5954,"You are CHERRY_Compliance, a virtual expert fo...",What is the required advance notice for bookin...,At least 7 days prior to vessel arrival.
5955,"You are CHERRY_Compliance, a virtual expert fo...",What import taxes are applicable to shipments ...,Shipments to Chile are subject to Import VAT a...
5956,"You are CHERRY_Compliance, a virtual expert fo...",What are the required documentation for shippi...,A Dangerous Goods Declaration in Chinese and S...


In [12]:
# Transform Gemma prompt template (https://ai.google.dev/gemma/docs/formatting)
# {"text": "<bos><start_of_turn>user\nWhat is the capital of France?<end_of_turn>\n<start_of_turn>model\nParis is the capital of France.<end_of_turn><eos>"}

def generate_prompt(row: pd.Series) -> str:
    "Format to Gemma's chat template"
    return """<bos><start_of_turn>user
## Instructions
{}
## User
{}<end_of_turn>
<start_of_turn>model
{}<end_of_turn>""".format(row["system_prompt"], row["question"], row["response"])


df["text"] = df.apply(generate_prompt, axis=1)
print(df["text"].iloc[100])

<bos><start_of_turn>user
## Instructions
You are CHERRY_Compliance, a virtual expert for CHERRY Shipping Company. Provide accurate information on:

- Documentation Requirements
- Operational Requirements
- Restrictions and Prohibited Items 
- Additional Requirements
- Special Handling Requirements
- Contact Information

Key behaviors:

Offer current information on CHERRY's policies and maritime regulations
Use clear language, adapting to the user's expertise level
Emphasize safety, compliance, and environmental protection
Provide practical advice and clarify misunderstandings
Reference the CHERRY Compliance document for detailed information
For country-specific queries, use the relevant country's section in the document

When responding:

Be concise yet thorough
If information is unavailable or unknown, clearly state "No Information Available"
Highlight recent regulatory changes when relevant
End each response with '–CHERRY_Compliance'
## User
What is the required timeframe for electro

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df["text"] = df.apply(generate_prompt, axis=1)


In [6]:
# Split dataset to train and valid 

from pathlib import Path

Path("data").mkdir(exist_ok=True)

split_ix = int(len(df) * 0.9)
# shuffle data
data = df.sample(frac=1, random_state=42)
train, valid = data[:split_ix], data[split_ix:]

# Save train and valid dataset as jsonl files
train[["text"]].to_json("data/train.jsonl", orient="records", lines=True, force_ascii=False)
valid[["text"]].to_json("data/valid.jsonl", orient="records", lines=True, force_ascii=False)

!head -n 1 data/train.jsonl

{"text":"<bos><start_of_turn>user\n## Instructions\nYou are CHERRY_Compliance, a virtual expert for CHERRY Shipping Company. Provide accurate information on:\n\n- Documentation Requirements\n- Operational Requirements\n- Restrictions and Prohibited Items \n- Additional Requirements\n- Special Handling Requirements\n- Contact Information\n\nKey behaviors:\n\nOffer current information on CHERRY's policies and maritime regulations\nUse clear language, adapting to the user's expertise level\nEmphasize safety, compliance, and environmental protection\nProvide practical advice and clarify misunderstandings\nReference the CHERRY Compliance document for detailed information\nFor country-specific queries, use the relevant country's section in the document\n\nWhen responding:\n\nBe concise yet thorough\nIf information is unavailable or unknown, clearly state \"No Information Available\"\nHighlight recent regulatory changes when relevant\nEnd each response with '–CHERRY_Compliance'\n## User\nWhat

## 5. LoRA fine-tuning

In [7]:
!python -m mlx_lm.lora --help

usage: lora.py [-h] [--model MODEL] [--train] [--data DATA]
               [--lora-layers LORA_LAYERS] [--batch-size BATCH_SIZE]
               [--iters ITERS] [--val-batches VAL_BATCHES]
               [--learning-rate LEARNING_RATE]
               [--steps-per-report STEPS_PER_REPORT]
               [--steps-per-eval STEPS_PER_EVAL]
               [--resume-adapter-file RESUME_ADAPTER_FILE]
               [--adapter-path ADAPTER_PATH] [--save-every SAVE_EVERY]
               [--test] [--test-batches TEST_BATCHES]
               [--max-seq-length MAX_SEQ_LENGTH] [-c CONFIG]
               [--grad-checkpoint] [--seed SEED] [--use-dora]

LoRA or QLoRA finetuning.

options:
  -h, --help            show this help message and exit
  --model MODEL         The path to the local model directory or Hugging Face
                        repo.
  --train               Do training
  --data DATA           Directory with {train, valid, test}.jsonl files
  --lora-layers LORA_LAYERS
                   

In [8]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [9]:
!python -m mlx_lm.lora \
    --model mlx-community/gemma-2-27b-it-4bit \
    --train \
    --data data \
    --iters 300 \
    --batch-size 4 \
    --learning-rate 1e-5 \
    --steps-per-report 10 \
    --steps-per-eval 10 \
    --adapter-path checkpoints/adapters \
    # --resume-adapter-file checkpoints/600_adapters.npz \
    --save-every 10 \
    --max-seq-length 2048 \
    --seed 42 \
    --lora-layers 16

Loading pretrained model
Fetching 9 files: 100%|███████████████████████| 9/9 [00:00<00:00, 108162.57it/s]
Loading datasets
Training
Trainable parameters: 0.007% (1.966M/27227.128M)
Starting training..., iters: 200
Iter 1: Val loss 3.598, Val took 206.026s
Iter 10: Val loss 2.665, Val took 238.179s
Iter 10: Train loss 3.075, Learning Rate 1.000e-05, It/sec 0.612, Tokens/sec 582.463, Trained Tokens 9513, Peak mem 31.318 GB
Iter 20: Val loss 1.794, Val took 245.594s
Iter 20: Train loss 2.168, Learning Rate 1.000e-05, It/sec 0.670, Tokens/sec 643.753, Trained Tokens 19116, Peak mem 32.516 GB
Iter 30: Val loss 0.849, Val took 215.245s
Iter 30: Train loss 1.217, Learning Rate 1.000e-05, It/sec 0.664, Tokens/sec 633.840, Trained Tokens 28657, Peak mem 35.784 GB
Iter 40: Val loss 0.504, Val took 221.943s
Iter 40: Train loss 0.600, Learning Rate 1.000e-05, It/sec 0.607, Tokens/sec 582.921, Trained Tokens 38253, Peak mem 35.784 GB
Iter 50: Val loss 0.374, Val took 200.716s
Iter 50: Train loss 0.

In [None]:
# !python -m mlx_lm.lora \
#     --model mlx-community/gemma-2-27b-it-4bit \
#     --train \
#     --data data \
#     --iters 600 \
#     --batch-size 4 \
#     --learning-rate 1e-5 \
#     --steps-per-report 10 \
#     --steps-per-eval 10 \
#     --adapter-path checkpoints/adapters \
#     --resume-adapter-file checkpoints/adapters/0000100_adapters.safetensors \
#     --save-every 10 \
#     --max-seq-length 2048 \
#     --seed 42 \
#     --lora-layers 16

## 6. Inference with fine-tuned model

In [13]:
# System prompt

system_prompt = df["system_prompt"].unique()[-1]
print(system_prompt)

You are CHERRY_Compliance, a virtual expert for CHERRY Shipping Company. Provide accurate information on:

- Documentation Requirements
- Operational Requirements
- Restrictions and Prohibited Items 
- Additional Requirements
- Special Handling Requirements
- Contact Information

Key behaviors:

Offer current information on CHERRY's policies and maritime regulations
Use clear language, adapting to the user's expertise level
Emphasize safety, compliance, and environmental protection
Provide practical advice and clarify misunderstandings
Reference the CHERRY Compliance document for detailed information
For country-specific queries, use the relevant country's section in the document

When responding:

Be concise yet thorough
If information is unavailable or unknown, clearly state "No Information Available"
Highlight recent regulatory changes when relevant
End each response with '–CHERRY_Compliance'


In [14]:
question = "What are the normal working hours for customs at Dar es Salaam Port? Let me know as detail as possible."


def format_prompt(system_prompt: str, question: str) -> str:
    "Format the question to the format of the dataset we fine-tuned to."
    return """<bos><start_of_turn>user
## Instructions
{}
## User
{}<end_of_turn>
<start_of_turn>model
""".format(
        system_prompt, question
    )


print(format_prompt(system_prompt, question))

<bos><start_of_turn>user
## Instructions
You are CHERRY_Compliance, a virtual expert for CHERRY Shipping Company. Provide accurate information on:

- Documentation Requirements
- Operational Requirements
- Restrictions and Prohibited Items 
- Additional Requirements
- Special Handling Requirements
- Contact Information

Key behaviors:

Offer current information on CHERRY's policies and maritime regulations
Use clear language, adapting to the user's expertise level
Emphasize safety, compliance, and environmental protection
Provide practical advice and clarify misunderstandings
Reference the CHERRY Compliance document for detailed information
For country-specific queries, use the relevant country's section in the document

When responding:

Be concise yet thorough
If information is unavailable or unknown, clearly state "No Information Available"
Highlight recent regulatory changes when relevant
End each response with '–CHERRY_Compliance'
## User
What are the normal working hours for cust

In [15]:
# Load the fine-tuned model with LoRA weights
model_lora, tokenizer = load(
    "mlx-community/gemma-2-27b-it-4bit",
    adapter_path="./checkpoints/adapters",
)

Fetching 9 files:   0%|          | 0/9 [00:00<?, ?it/s]

In [16]:
response = generate(
    model_lora,
    tokenizer,
    prompt=format_prompt(system_prompt, question),
    verbose=True,
    temp=0.5,
    max_tokens=1024,
)

Prompt: <bos><start_of_turn>user
## Instructions
You are CHERRY_Compliance, a virtual expert for CHERRY Shipping Company. Provide accurate information on:

- Documentation Requirements
- Operational Requirements
- Restrictions and Prohibited Items 
- Additional Requirements
- Special Handling Requirements
- Contact Information

Key behaviors:

Offer current information on CHERRY's policies and maritime regulations
Use clear language, adapting to the user's expertise level
Emphasize safety, compliance, and environmental protection
Provide practical advice and clarify misunderstandings
Reference the CHERRY Compliance document for detailed information
For country-specific queries, use the relevant country's section in the document

When responding:

Be concise yet thorough
If information is unavailable or unknown, clearly state "No Information Available"
Highlight recent regulatory changes when relevant
End each response with '–CHERRY_Compliance'
## User
What are the normal working hours 

In [17]:
response = generate(
    model,
    tokenizer,
    prompt=format_prompt(system_prompt, question),
    verbose=True,
    temp=0.5,
    max_tokens=1024,
)

Prompt: <bos><start_of_turn>user
## Instructions
You are CHERRY_Compliance, a virtual expert for CHERRY Shipping Company. Provide accurate information on:

- Documentation Requirements
- Operational Requirements
- Restrictions and Prohibited Items 
- Additional Requirements
- Special Handling Requirements
- Contact Information

Key behaviors:

Offer current information on CHERRY's policies and maritime regulations
Use clear language, adapting to the user's expertise level
Emphasize safety, compliance, and environmental protection
Provide practical advice and clarify misunderstandings
Reference the CHERRY Compliance document for detailed information
For country-specific queries, use the relevant country's section in the document

When responding:

Be concise yet thorough
If information is unavailable or unknown, clearly state "No Information Available"
Highlight recent regulatory changes when relevant
End each response with '–CHERRY_Compliance'
## User
What are the normal working hours 

In [18]:
!python -m mlx_lm.generate \
    --model mlx-community/gemma-2-27b-it-4bit \
    --adapter-path checkpoints/adapters \
    --prompt "What are the normal working hours for customs at Dar es Salaam Port?" \
    --max-tokens 256 \
    --temp 0.5

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Fetching 9 files: 100%|████████████████████████| 9/9 [00:00<00:00, 26696.42it/s]
Prompt: <bos><start_of_turn>user
What are the normal working hours for customs at Dar es Salaam Port?<end_of_turn>
<start_of_turn>model

I cannot provide specific real-time information like working hours for government offices. Working hours are subject to change and are best obtained directly from the source.<end_of_turn>
Prompt: 24 tokens, 12.770 tokens-per-sec
Generation: 32 tokens, 13.584 tokens-per-sec
Peak memory: 14.642 GB


In [12]:
!python -m mlx_lm.generate --help

usage: generate.py [-h] [--model MODEL] [--adapter-path ADAPTER_PATH]
                   [--trust-remote-code] [--eos-token EOS_TOKEN]
                   [--prompt PROMPT] [--max-tokens MAX_TOKENS] [--temp TEMP]
                   [--top-p TOP_P] [--seed SEED] [--ignore-chat-template]
                   [--use-default-chat-template] [--colorize]
                   [--cache-limit-gb CACHE_LIMIT_GB]

LLM inference script

options:
  -h, --help            show this help message and exit
  --model MODEL         The path to the local model directory or Hugging Face
                        repo.
  --adapter-path ADAPTER_PATH
                        Optional path for the trained adapter weights and
                        config.
  --trust-remote-code   Enable trusting remote code for tokenizer
  --eos-token EOS_TOKEN
                        End of sequence token for tokenizer
  --prompt PROMPT       Message to be processed by the model
  --max-tokens MAX_TOKENS, -m MAX_TOKENS
               