# Fine-tuning the latest Google Gemma model locally using MLX

## 1. Preparation


In [58]:
# !pip install -Uqq mlx mlx_lm transformers datasets

## 2. Using MLX to Run Inference with Gemma Model using MLX

In [13]:
from mlx_lm import generate, load

model, tokenizer = load("mlx-community/gemma-2-27b-it-4bit")
# https://huggingface.co/mlx-community/gemma-2-27b-4bit 


Fetching 9 files:   0%|          | 0/9 [00:00<?, ?it/s]

In [15]:
messages = [{"role": "user", "content": "What are the minimum size requirements for the marks and the UN packaging symbol on pressure receptacles, and how do these requirements differ based on the diameter of the pressure receptacle?"}]
result = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(result)

<bos><start_of_turn>user
What are the minimum size requirements for the marks and the UN packaging symbol on pressure receptacles, and how do these requirements differ based on the diameter of the pressure receptacle?<end_of_turn>
<start_of_turn>model



In [17]:
# Generating without adding a prompt template manually
prompt = """
What are the minimum size requirements for the marks and the UN packaging symbol on pressure receptacles, and how do these requirements differ based on the diameter of the pressure receptacle per IMDG code?
""".strip()
response = generate(
    model,
    tokenizer,
    prompt=prompt,
    verbose=True,  # Set to True to see the prompt and response
    temp=0.0,
    max_tokens=1024,
)

Prompt: What are the minimum size requirements for the marks and the UN packaging symbol on pressure receptacles, and how do these requirements differ based on the diameter of the pressure receptacle per IMDG code?


The IMDG Code (International Maritime Dangerous Goods Code) outlines the marking and labeling requirements for pressure receptacles.

**Minimum Size Requirements for Marks and UN Packaging Symbol:**

The minimum size requirements for marks and the UN packaging symbol on pressure receptacles are dependent on the diameter of the receptacle.

* **Receptacles with a diameter of less than 140 mm:**

    * The UN packaging symbol must have a minimum dimension of 6 mm.
    * Other marks (e.g., manufacturer's name, serial number, test pressure, etc.) must be clearly visible and legible.

* **Receptacles with a diameter of 140 mm or more:**

    * The UN packaging symbol must have a minimum dimension of 12 mm.
    * Other marks must be clearly visible and legible, with a minimum he

## 3. Generating training dataset from PDF file

In [4]:
import PyPDF2
import csv, re
import random
from tqdm import tqdm
import time
from transformers import AutoTokenizer
from mlx_lm import load, generate
import subprocess

def run_command_with_live_output(command: list[str]) -> None:
    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    while True:
        output = process.stdout.readline()
        if output == '' and process.poll() is not None:
            break
        if output:
            print(output.strip())
    err_output = process.stderr.read()
    if err_output:
        print(err_output)

def extract_text_from_pdf(pdf_path):
    print("Extracting text from PDF...")
    with open(pdf_path, 'rb') as file:
        reader = PyPDF2.PdfReader(file)
        text = ""
        for page in tqdm(reader.pages, desc="Processing pages"):
            text += page.extract_text()
    print("Text extraction completed")
    return text

def create_chunks(text, min_chunk_size=200, max_chunk_size=1000):
    print("Splitting text into chunks...")
    chunks = []
    current_chunk = ""
    sentences = text.split('.')
    for sentence in tqdm(sentences, desc="Creating chunks"):
        if len(current_chunk) + len(sentence) > max_chunk_size and len(current_chunk) >= min_chunk_size:
            chunks.append(current_chunk.strip())
            current_chunk = sentence
        else:
            current_chunk += sentence + '.'
    if current_chunk:
        chunks.append(current_chunk.strip())
    print(f"Number of chunks created: {len(chunks)}")
    return chunks

def load_model_and_tokenizer(model_path):
    print("Loading model and tokenizer...")
    model, tokenizer = load(model_path)
    return model, tokenizer

def generate_question_response_pair(chunk, model, tokenizer, max_tokens=512):
    instruction = """
    Based on the IMDG code, create a practical question and provide a detailed, informative answer. 
    All the questions should be based on the IMDG code.
    Do not include any special markers like '**' or '<end_of_turn>'
    Format your response as follows:
    Question: [Your generated question]
    Answer: [Your generated answer]
"""

    prompt = f'''<s>[INST] {instruction}\n\nText: {chunk} [/INST]\n'''
    
    # generated_text = generate(model, tokenizer, prompt=prompt, max_tokens=max_tokens, verbose=False)

    generated_text = generate(
        model, 
        tokenizer, 
        prompt=prompt, 
        max_tokens=max_tokens, 
        temp=0.8,  # Value between 0.0 and 1.0, higher values produce more creative output
        # top_k=50,         # Number of top tokens to consider
        top_p=0.95,       # Select tokens until cumulative probability exceeds this value
        verbose=False
    )

    generated_text = generated_text.replace(prompt, "").strip()
    generated_text = re.sub(r'\*\*|<end_of_turn>', '', generated_text)
    generated_text = re.sub(r'```', '', generated_text)
    
    question_start = generated_text.find("Question:")
    answer_start = generated_text.find("Answer:")
    
    if question_start != -1 and answer_start != -1:
        question = generated_text[question_start+9:answer_start].strip()
        answer = generated_text[answer_start+7:].strip()
    else:
        question = "Unable to generate a question."
        answer = "Unable to generate a response."
    
    return question, answer

def create_fine_tuning_dataset(pdf_path, output_file, num_samples=10000):
    start_time = time.time()
    print(f"Starting fine-tuning dataset creation (Target samples: {num_samples})")

    model_path = "mlx-community/gemma-2-27b-it-4bit"
    model, tokenizer = load_model_and_tokenizer(model_path)
    print("Model and tokenizer loaded")

    text = extract_text_from_pdf(pdf_path)
    chunks = create_chunks(text)
    
    print("Generating question-response pairs and writing to CSV...")
    with open(output_file, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(['question', 'response'])
        
        for _ in tqdm(range(num_samples), desc="Generating samples"):
            chunk = random.choice(chunks)
            question, response = generate_question_response_pair(chunk, model, tokenizer)
            if question != "Unable to generate a question." and response != "Unable to generate a response.":
                writer.writerow([question, response])
            else:
                print(f"Skipping invalid sample: Q: {question}, A: {response}")
            if (_ + 1) % 100 == 0:
                print(f"Generated {_ + 1} samples...")
    
    end_time = time.time()
    print(f"Number of dataset samples generated: {num_samples}")
    print(f"Total time taken: {end_time - start_time:.2f} seconds")

# Usage example
pdf_path = "./data/IMDG.pdf"
output_file = "./data/IMDG.csv"
create_fine_tuning_dataset(pdf_path, output_file, num_samples=10000)

Starting fine-tuning dataset creation (Target samples: 10000)
Loading model and tokenizer...


Fetching 9 files:   0%|          | 0/9 [00:00<?, ?it/s]

Model and tokenizer loaded
Extracting text from PDF...


Processing pages: 100%|██████████| 1098/1098 [01:08<00:00, 15.99it/s]


Text extraction completed
Splitting text into chunks...


Creating chunks: 100%|██████████| 64387/64387 [00:00<00:00, 3784772.43it/s]


Number of chunks created: 3140
Generating question-response pairs and writing to CSV...


Generating samples:   0%|          | 3/10000 [00:59<52:24:33, 18.87s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:   0%|          | 19/10000 [05:25<42:56:21, 15.49s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:   0%|          | 31/10000 [08:53<37:44:59, 13.63s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:   1%|          | 61/10000 [17:41<50:10:46, 18.18s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:   1%|          | 66/10000 [18:37<28:43:51, 10.41s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:   1%|          | 92/10000 [26:21<46:18:52, 16.83s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:   1%|          | 100/10000 [28:49<56:26:57, 20.53s/it]

Generated 100 samples...


Generating samples:   1%|          | 106/10000 [30:16<39:23:03, 14.33s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:   2%|▏         | 167/10000 [48:11<43:29:00, 15.92s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:   2%|▏         | 200/10000 [57:58<47:18:09, 17.38s/it]

Generated 200 samples...


Generating samples:   2%|▏         | 222/10000 [1:04:36<46:24:10, 17.08s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:   3%|▎         | 261/10000 [1:16:06<55:56:15, 20.68s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:   3%|▎         | 267/10000 [1:17:29<36:17:56, 13.43s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:   3%|▎         | 284/10000 [1:21:47<28:26:15, 10.54s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:   3%|▎         | 300/10000 [1:26:17<40:56:13, 15.19s/it]

Generated 300 samples...


Generating samples:   3%|▎         | 307/10000 [1:27:59<32:49:55, 12.19s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:   3%|▎         | 309/10000 [1:28:24<32:13:56, 11.97s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:   3%|▎         | 312/10000 [1:29:26<42:23:24, 15.75s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:   3%|▎         | 316/10000 [1:30:35<47:42:15, 17.73s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:   3%|▎         | 329/10000 [1:34:31<53:46:24, 20.02s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:   4%|▍         | 399/10000 [1:55:08<43:05:50, 16.16s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:   4%|▍         | 400/10000 [1:55:44<58:47:15, 22.05s/it]

Generated 400 samples...


Generating samples:   4%|▍         | 430/10000 [2:03:12<33:52:53, 12.75s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:   4%|▍         | 440/10000 [2:05:41<30:34:35, 11.51s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:   4%|▍         | 450/10000 [2:09:06<54:56:13, 20.71s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:   5%|▍         | 488/10000 [2:21:32<50:30:03, 19.11s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:   5%|▌         | 500/10000 [2:25:08<51:20:31, 19.46s/it]

Generated 500 samples...


Generating samples:   5%|▌         | 530/10000 [2:33:15<38:12:17, 14.52s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:   6%|▌         | 562/10000 [2:42:10<39:26:09, 15.04s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:   6%|▌         | 593/10000 [2:49:43<38:03:09, 14.56s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:   6%|▌         | 600/10000 [2:51:41<40:54:09, 15.66s/it]

Generated 600 samples...


Generating samples:   6%|▌         | 610/10000 [2:53:58<28:26:09, 10.90s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:   6%|▌         | 612/10000 [2:54:25<32:03:38, 12.29s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:   7%|▋         | 700/10000 [3:18:06<47:20:18, 18.32s/it]

Generated 700 samples...


Generating samples:   7%|▋         | 714/10000 [3:22:38<53:04:09, 20.57s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:   7%|▋         | 728/10000 [3:26:45<51:03:20, 19.82s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:   8%|▊         | 772/10000 [3:39:45<47:34:50, 18.56s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:   8%|▊         | 800/10000 [3:47:43<46:48:04, 18.31s/it]

Generated 800 samples...


Generating samples:   9%|▉         | 892/10000 [4:13:51<36:46:47, 14.54s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:   9%|▉         | 900/10000 [4:16:16<40:00:50, 15.83s/it]

Generated 900 samples...


Generating samples:   9%|▉         | 937/10000 [4:27:25<54:01:53, 21.46s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  10%|▉         | 952/10000 [4:31:02<30:57:28, 12.32s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  10%|▉         | 973/10000 [4:36:25<37:55:12, 15.12s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  10%|▉         | 987/10000 [4:40:08<39:11:20, 15.65s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  10%|▉         | 992/10000 [4:41:14<30:19:16, 12.12s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  10%|█         | 1000/10000 [4:43:49<45:19:00, 18.13s/it]

Generated 1000 samples...


Generating samples:  10%|█         | 1024/10000 [4:50:00<30:39:15, 12.29s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  10%|█         | 1043/10000 [4:55:21<34:12:32, 13.75s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  11%|█         | 1061/10000 [4:59:41<29:20:03, 11.81s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  11%|█         | 1070/10000 [5:01:55<29:47:19, 12.01s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  11%|█         | 1077/10000 [5:04:03<40:01:41, 16.15s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  11%|█         | 1100/10000 [5:10:20<34:57:40, 14.14s/it]

Generated 1100 samples...


Generating samples:  12%|█▏        | 1179/10000 [5:32:16<46:14:56, 18.88s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  12%|█▏        | 1190/10000 [5:35:34<44:00:29, 17.98s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  12%|█▏        | 1200/10000 [5:38:03<37:31:21, 15.35s/it]

Generated 1200 samples...


Generating samples:  12%|█▏        | 1217/10000 [5:42:53<38:06:18, 15.62s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  12%|█▏        | 1231/10000 [5:47:01<33:18:27, 13.67s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  12%|█▏        | 1232/10000 [5:47:05<25:48:08, 10.59s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  12%|█▏        | 1249/10000 [5:52:21<41:23:00, 17.02s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  13%|█▎        | 1292/10000 [6:04:05<31:08:47, 12.88s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  13%|█▎        | 1300/10000 [6:06:13<34:13:16, 14.16s/it]

Generated 1300 samples...


Generating samples:  13%|█▎        | 1310/10000 [6:09:17<53:07:52, 22.01s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  13%|█▎        | 1327/10000 [6:14:17<35:55:52, 14.91s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  13%|█▎        | 1331/10000 [6:15:19<35:39:38, 14.81s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  14%|█▎        | 1362/10000 [6:24:31<32:39:57, 13.61s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  14%|█▍        | 1387/10000 [6:31:30<35:39:42, 14.91s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  14%|█▍        | 1400/10000 [6:35:03<45:16:21, 18.95s/it]

Generated 1400 samples...


Generating samples:  14%|█▍        | 1403/10000 [6:35:42<35:57:08, 15.06s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  14%|█▍        | 1404/10000 [6:35:52<31:57:48, 13.39s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  14%|█▍        | 1409/10000 [6:37:29<43:15:57, 18.13s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  14%|█▍        | 1428/10000 [6:42:49<39:00:50, 16.38s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  14%|█▍        | 1439/10000 [6:45:22<32:04:51, 13.49s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  14%|█▍        | 1443/10000 [6:46:21<30:49:32, 12.97s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  15%|█▍        | 1464/10000 [6:52:14<45:53:35, 19.36s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  15%|█▍        | 1467/10000 [6:53:11<43:30:40, 18.36s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  15%|█▍        | 1496/10000 [7:00:53<45:30:23, 19.26s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  15%|█▌        | 1500/10000 [7:01:53<34:31:39, 14.62s/it]

Generated 1500 samples...


Generating samples:  16%|█▌        | 1600/10000 [7:30:52<31:50:09, 13.64s/it]

Generated 1600 samples...


Generating samples:  16%|█▌        | 1604/10000 [7:32:07<39:55:22, 17.12s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  16%|█▌        | 1606/10000 [7:32:22<27:26:12, 11.77s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  16%|█▌        | 1613/10000 [7:34:44<49:01:17, 21.04s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  16%|█▌        | 1624/10000 [7:37:38<36:52:15, 15.85s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  16%|█▋        | 1643/10000 [7:43:07<29:44:54, 12.81s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  16%|█▋        | 1644/10000 [7:43:15<26:27:29, 11.40s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  16%|█▋        | 1650/10000 [7:45:04<38:44:55, 16.71s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  17%|█▋        | 1668/10000 [7:49:49<36:42:31, 15.86s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  17%|█▋        | 1700/10000 [8:00:36<53:16:43, 23.11s/it]

Generated 1700 samples...


Generating samples:  18%|█▊        | 1763/10000 [8:17:31<30:13:35, 13.21s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  18%|█▊        | 1800/10000 [8:27:05<33:15:40, 14.60s/it]

Generated 1800 samples...


Generating samples:  18%|█▊        | 1811/10000 [8:30:27<39:48:23, 17.50s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  18%|█▊        | 1822/10000 [8:33:10<37:30:59, 16.52s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  18%|█▊        | 1823/10000 [8:33:12<27:51:50, 12.27s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  18%|█▊        | 1831/10000 [8:35:00<23:48:36, 10.49s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  18%|█▊        | 1836/10000 [8:36:20<36:32:15, 16.11s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  18%|█▊        | 1841/10000 [8:37:42<32:39:57, 14.41s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  18%|█▊        | 1842/10000 [8:38:16<46:22:19, 20.46s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  19%|█▊        | 1865/10000 [8:44:50<42:02:09, 18.60s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  19%|█▊        | 1874/10000 [8:47:40<42:10:57, 18.69s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  19%|█▉        | 1878/10000 [8:48:50<41:28:10, 18.38s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  19%|█▉        | 1889/10000 [8:51:33<32:11:00, 14.28s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  19%|█▉        | 1900/10000 [8:54:46<41:11:20, 18.31s/it]

Generated 1900 samples...


Generating samples:  19%|█▉        | 1935/10000 [9:04:46<33:49:42, 15.10s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  19%|█▉        | 1940/10000 [9:06:00<28:55:51, 12.92s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  19%|█▉        | 1941/10000 [9:06:14<30:03:16, 13.43s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  19%|█▉        | 1942/10000 [9:06:17<22:48:14, 10.19s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  20%|█▉        | 1999/10000 [9:22:16<27:57:36, 12.58s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  20%|██        | 2000/10000 [9:22:41<36:36:44, 16.48s/it]

Generated 2000 samples...


Generating samples:  21%|██        | 2100/10000 [9:50:27<32:03:35, 14.61s/it]

Generated 2100 samples...


Generating samples:  21%|██▏       | 2135/10000 [10:01:15<31:34:42, 14.45s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  21%|██▏       | 2138/10000 [10:01:54<29:32:34, 13.53s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  21%|██▏       | 2139/10000 [10:01:56<22:26:44, 10.28s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  22%|██▏       | 2154/10000 [10:05:54<33:17:28, 15.28s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  22%|██▏       | 2189/10000 [10:15:02<31:40:43, 14.60s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  22%|██▏       | 2191/10000 [10:15:28<30:43:31, 14.16s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  22%|██▏       | 2200/10000 [10:18:08<35:28:09, 16.37s/it]

Generated 2200 samples...


Generating samples:  23%|██▎       | 2280/10000 [10:41:04<31:38:25, 14.75s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  23%|██▎       | 2300/10000 [10:47:16<40:19:51, 18.86s/it]

Generated 2300 samples...


Generating samples:  24%|██▍       | 2376/10000 [11:09:49<28:57:45, 13.68s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  24%|██▍       | 2399/10000 [11:16:23<36:59:05, 17.52s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  24%|██▍       | 2400/10000 [11:16:58<48:26:03, 22.94s/it]

Generated 2400 samples...


Generating samples:  24%|██▍       | 2412/10000 [11:20:39<39:12:04, 18.60s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  25%|██▍       | 2458/10000 [11:33:27<19:41:13,  9.40s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  25%|██▍       | 2465/10000 [11:35:29<34:44:35, 16.60s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  25%|██▌       | 2500/10000 [11:45:33<31:12:53, 14.98s/it]

Generated 2500 samples...


Generating samples:  25%|██▌       | 2530/10000 [11:54:11<35:21:58, 17.04s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  25%|██▌       | 2536/10000 [11:55:42<27:13:13, 13.13s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  25%|██▌       | 2542/10000 [11:57:18<25:55:32, 12.51s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  26%|██▌       | 2553/10000 [12:00:23<29:22:29, 14.20s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  26%|██▌       | 2600/10000 [12:13:38<35:14:31, 17.14s/it]

Generated 2600 samples...


Generating samples:  26%|██▌       | 2620/10000 [12:18:42<24:39:10, 12.03s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  26%|██▋       | 2629/10000 [12:21:02<30:17:58, 14.80s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  27%|██▋       | 2677/10000 [12:34:35<29:57:01, 14.72s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  27%|██▋       | 2700/10000 [12:39:51<29:15:09, 14.43s/it]

Generated 2700 samples...


Generating samples:  27%|██▋       | 2727/10000 [12:46:45<32:41:31, 16.18s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  27%|██▋       | 2733/10000 [12:48:18<31:31:13, 15.61s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  27%|██▋       | 2737/10000 [12:49:37<38:23:45, 19.03s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  28%|██▊       | 2780/10000 [13:01:28<37:56:10, 18.92s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  28%|██▊       | 2791/10000 [13:04:15<32:12:20, 16.08s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  28%|██▊       | 2800/10000 [13:06:38<29:58:32, 14.99s/it]

Generated 2800 samples...


Generating samples:  29%|██▉       | 2888/10000 [13:30:25<40:46:04, 20.64s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  29%|██▉       | 2900/10000 [13:33:57<35:34:46, 18.04s/it]

Generated 2900 samples...


Generating samples:  29%|██▉       | 2926/10000 [13:41:41<30:08:03, 15.34s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  29%|██▉       | 2929/10000 [13:42:37<32:14:32, 16.42s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  29%|██▉       | 2943/10000 [13:46:14<27:42:09, 14.13s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  29%|██▉       | 2944/10000 [13:46:37<32:59:03, 16.83s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  30%|██▉       | 2972/10000 [13:54:33<33:30:41, 17.17s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  30%|██▉       | 2995/10000 [14:02:09<45:25:39, 23.35s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  30%|███       | 3000/10000 [14:03:32<33:36:43, 17.29s/it]

Generated 3000 samples...


Generating samples:  31%|███       | 3056/10000 [14:20:05<28:54:11, 14.98s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  31%|███       | 3070/10000 [14:23:29<20:33:47, 10.68s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  31%|███       | 3081/10000 [14:26:40<31:10:51, 16.22s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  31%|███       | 3099/10000 [14:31:28<31:43:04, 16.55s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  31%|███       | 3100/10000 [14:31:50<34:26:05, 17.97s/it]

Generated 3100 samples...


Generating samples:  31%|███       | 3123/10000 [14:38:34<40:21:14, 21.12s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  32%|███▏      | 3159/10000 [14:47:33<20:06:02, 10.58s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  32%|███▏      | 3178/10000 [14:53:46<26:54:15, 14.20s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  32%|███▏      | 3200/10000 [15:00:32<32:00:43, 16.95s/it]

Generated 3200 samples...


Generating samples:  33%|███▎      | 3269/10000 [15:19:23<32:15:16, 17.25s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  33%|███▎      | 3293/10000 [15:26:08<24:01:35, 12.90s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  33%|███▎      | 3300/10000 [15:27:59<31:48:08, 17.09s/it]

Generated 3300 samples...


Generating samples:  33%|███▎      | 3335/10000 [15:37:51<27:24:12, 14.80s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  33%|███▎      | 3346/10000 [15:40:47<40:03:20, 21.67s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  34%|███▎      | 3356/10000 [15:43:44<30:41:00, 16.63s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  34%|███▍      | 3400/10000 [15:57:09<33:37:07, 18.34s/it]

Generated 3400 samples...


Generating samples:  34%|███▍      | 3407/10000 [15:59:05<36:18:32, 19.83s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  34%|███▍      | 3446/10000 [16:10:24<26:14:36, 14.42s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  35%|███▍      | 3487/10000 [16:21:57<21:42:16, 12.00s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  35%|███▌      | 3500/10000 [16:25:45<30:32:23, 16.91s/it]

Generated 3500 samples...


Generating samples:  35%|███▌      | 3527/10000 [16:33:43<28:58:22, 16.11s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  35%|███▌      | 3531/10000 [16:34:27<19:24:49, 10.80s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  36%|███▌      | 3552/10000 [16:39:47<23:57:40, 13.38s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  36%|███▌      | 3600/10000 [16:52:56<22:35:08, 12.70s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.
Generated 3600 samples...


Generating samples:  36%|███▋      | 3640/10000 [17:02:52<21:17:54, 12.06s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  37%|███▋      | 3700/10000 [17:20:51<25:09:02, 14.37s/it]

Generated 3700 samples...


Generating samples:  38%|███▊      | 3751/10000 [17:36:32<30:49:53, 17.76s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  38%|███▊      | 3763/10000 [17:40:16<28:37:54, 16.53s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  38%|███▊      | 3800/10000 [17:50:55<28:12:21, 16.38s/it]

Generated 3800 samples...


Generating samples:  38%|███▊      | 3801/10000 [17:51:24<34:42:17, 20.15s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  38%|███▊      | 3805/10000 [17:52:18<25:16:13, 14.69s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  38%|███▊      | 3827/10000 [17:58:29<29:03:29, 16.95s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  38%|███▊      | 3829/10000 [17:59:01<27:58:11, 16.32s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  38%|███▊      | 3832/10000 [17:59:33<20:59:29, 12.25s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  39%|███▉      | 3885/10000 [18:14:03<28:26:46, 16.75s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  39%|███▉      | 3888/10000 [18:15:10<31:52:07, 18.77s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  39%|███▉      | 3889/10000 [18:15:13<23:41:43, 13.96s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  39%|███▉      | 3900/10000 [18:18:29<32:34:32, 19.22s/it]

Generated 3900 samples...


Generating samples:  39%|███▉      | 3945/10000 [18:31:28<21:56:21, 13.04s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  40%|███▉      | 3958/10000 [18:35:06<31:00:19, 18.47s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  40%|███▉      | 3964/10000 [18:36:36<22:42:20, 13.54s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  40%|███▉      | 3977/10000 [18:40:48<29:59:35, 17.93s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  40%|███▉      | 3987/10000 [18:43:22<23:48:47, 14.26s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  40%|████      | 4000/10000 [18:46:39<23:10:54, 13.91s/it]

Generated 4000 samples...


Generating samples:  40%|████      | 4043/10000 [18:59:11<23:29:03, 14.19s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  41%|████      | 4055/10000 [19:02:24<26:46:04, 16.21s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  41%|████      | 4098/10000 [19:14:08<27:59:38, 17.08s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  41%|████      | 4100/10000 [19:14:37<25:21:57, 15.48s/it]

Generated 4100 samples...


Generating samples:  41%|████      | 4105/10000 [19:15:40<17:05:06, 10.43s/it]

Skipping invalid sample: Q: Unable to generate a question., A: Unable to generate a response.


Generating samples:  42%|████▏     | 4152/10000 [19:30:03<26:14:07, 16.15s/it]

## 4. Building dataframe

In [1]:
import csv
import random
import json
import pandas as pd

# load the dataset
dataset = pd.read_csv("./data/IMDG.csv")

system_prompt = """
    You are an IMDG Code specialist with a deep understanding of the International Maritime Dangerous Goods (IMDG) Code. Your task is to provide detailed and accurate information in response to user questions about IMDG code.

    Guidelines for answering:
    1. Present information in a clear, concise, and easily understandable manner, using bullet points for organization.
    2. For questions about specific UN Numbers:
       - Provide details on: UN No., Proper Shipping Name, Class, Subsidiary hazard, Packing Group, Special Provisions, Limited Quantity, Excepted Quantity, Packing Instructions, Stowage and handling, Segregation, and Properties and observations.
       - If this information is not in the provided context, state that you need to refer to the official IMDG Code for accurate details.
    3. Respond in the language of the user's question. If unable to determine the language, default to English.
    4. If you don't know the answer or if the information is not in the provided context, clearly state "I don't have enough information to answer this question accurately. Please refer to the official IMDG Code or consult with an IMDG expert for the most up-to-date and accurate information."
    5. End each response with '–IMDGGenie'
"""

dataset["system_prompt"] = system_prompt
dataset["system_prompt"] = dataset["system_prompt"].str.strip()
dataset = dataset[["system_prompt", "question", "response"]]
dataset

Unnamed: 0,system_prompt,question,response
0,You are an IMDG Code specialist with a deep un...,"According to the IMDG code, how should a subst...",Category A toxins are assigned to UN number 34...
1,You are an IMDG Code specialist with a deep un...,What are the specific packaging requirements f...,Paragraph P600 of the IMDG Code outlines speci...
2,You are an IMDG Code specialist with a deep un...,A cargo transport unit contains both dangerous...,"Yes, you would need to placard the cargo trans..."
3,You are an IMDG Code specialist with a deep un...,Based on the IMDG Code information provided fo...,"According to the IMDG Code excerpt, Nickel Car..."
4,You are an IMDG Code specialist with a deep un...,"A shipment of silvery-white, ductile, soft met...",This material aligns with the descriptions pro...
...,...,...,...
3988,You are an IMDG Code specialist with a deep un...,"According to the IMDG Code, what are the regul...",The IMDG Code stipulates that placards indicat...
3989,You are an IMDG Code specialist with a deep un...,"A chemical shipment labeled ""SG49"" is declared...","Given the hazardous nature of the ""SG49"" corro..."
3990,You are an IMDG Code specialist with a deep un...,A mixture contains two toxic components: Subst...,"To determine the packing group, we need to cal..."
3991,You are an IMDG Code specialist with a deep un...,A package design incorporates a single waterti...,"No, we cannot assume that water will not leak ..."


In [2]:
# Count records for Unable to generate a question & Unable to generate a response
filtered_df = dataset[
    (dataset["question"] == "Unable to generate a question.") & 
    (dataset["response"] == "Unable to generate a response.")
]

count = filtered_df.shape[0]
print(f"Number of rows with the specified question and response: {count}")

Number of rows with the specified question and response: 0


In [3]:
# Count records for duplicated question and response
filtered_df = dataset[
    dataset.duplicated(subset=["question", "response"], keep=False)
]

count = filtered_df.shape[0]
print(f"Number of rows with the same question and response: {count}")

Number of rows with the same question and response: 0


In [4]:
# Drop duplicates
df = dataset.drop_duplicates(subset=["question", "response"])
df

Unnamed: 0,system_prompt,question,response
0,You are an IMDG Code specialist with a deep un...,"According to the IMDG code, how should a subst...",Category A toxins are assigned to UN number 34...
1,You are an IMDG Code specialist with a deep un...,What are the specific packaging requirements f...,Paragraph P600 of the IMDG Code outlines speci...
2,You are an IMDG Code specialist with a deep un...,A cargo transport unit contains both dangerous...,"Yes, you would need to placard the cargo trans..."
3,You are an IMDG Code specialist with a deep un...,Based on the IMDG Code information provided fo...,"According to the IMDG Code excerpt, Nickel Car..."
4,You are an IMDG Code specialist with a deep un...,"A shipment of silvery-white, ductile, soft met...",This material aligns with the descriptions pro...
...,...,...,...
3988,You are an IMDG Code specialist with a deep un...,"According to the IMDG Code, what are the regul...",The IMDG Code stipulates that placards indicat...
3989,You are an IMDG Code specialist with a deep un...,"A chemical shipment labeled ""SG49"" is declared...","Given the hazardous nature of the ""SG49"" corro..."
3990,You are an IMDG Code specialist with a deep un...,A mixture contains two toxic components: Subst...,"To determine the packing group, we need to cal..."
3991,You are an IMDG Code specialist with a deep un...,A package design incorporates a single waterti...,"No, we cannot assume that water will not leak ..."


In [5]:
# Transform Gemma prompt template (https://ai.google.dev/gemma/docs/formatting)
# {"text": "<bos><start_of_turn>user\nWhat is the capital of France?<end_of_turn>\n<start_of_turn>model\nParis is the capital of France.<end_of_turn><eos>"}

def generate_prompt(row: pd.Series) -> str:
    "Format to Gemma's chat template"
    return """<bos><start_of_turn>user
## Instructions
{}
## User
{}<end_of_turn>
<start_of_turn>model
{}<end_of_turn>""".format(row["system_prompt"], row["question"], row["response"])


df["text"] = df.apply(generate_prompt, axis=1)
print(df["text"].iloc[100])

<bos><start_of_turn>user
## Instructions
You are an IMDG Code specialist with a deep understanding of the International Maritime Dangerous Goods (IMDG) Code. Your task is to provide detailed and accurate information in response to user questions about IMDG code.

    Guidelines for answering:
    1. Present information in a clear, concise, and easily understandable manner, using bullet points for organization.
    2. For questions about specific UN Numbers:
       - Provide details on: UN No., Proper Shipping Name, Class, Subsidiary hazard, Packing Group, Special Provisions, Limited Quantity, Excepted Quantity, Packing Instructions, Stowage and handling, Segregation, and Properties and observations.
       - If this information is not in the provided context, state that you need to refer to the official IMDG Code for accurate details.
    3. Respond in the language of the user's question. If unable to determine the language, default to English.
    4. If you don't know the answer or if

In [6]:
# Split dataset to train and valid 

from pathlib import Path

Path("data").mkdir(exist_ok=True)

split_ix = int(len(df) * 0.9)
# shuffle data
data = df.sample(frac=1, random_state=42)
train, valid = data[:split_ix], data[split_ix:]

# Save train and valid dataset as jsonl files
train[["text"]].to_json("data/train.jsonl", orient="records", lines=True, force_ascii=False)
valid[["text"]].to_json("data/valid.jsonl", orient="records", lines=True, force_ascii=False)

!head -n 1 data/train.jsonl

{"text":"<bos><start_of_turn>user\n## Instructions\nYou are an IMDG Code specialist with a deep understanding of the International Maritime Dangerous Goods (IMDG) Code. Your task is to provide detailed and accurate information in response to user questions about IMDG code.\n\n    Guidelines for answering:\n    1. Present information in a clear, concise, and easily understandable manner, using bullet points for organization.\n    2. For questions about specific UN Numbers:\n       - Provide details on: UN No., Proper Shipping Name, Class, Subsidiary hazard, Packing Group, Special Provisions, Limited Quantity, Excepted Quantity, Packing Instructions, Stowage and handling, Segregation, and Properties and observations.\n       - If this information is not in the provided context, state that you need to refer to the official IMDG Code for accurate details.\n    3. Respond in the language of the user's question. If unable to determine the language, default to English.\n    4. If you don't kn

## 5. LoRA fine-tuning

In [7]:
!python -m mlx_lm.lora --help

usage: lora.py [-h] [--model MODEL] [--train] [--data DATA]
               [--lora-layers LORA_LAYERS] [--batch-size BATCH_SIZE]
               [--iters ITERS] [--val-batches VAL_BATCHES]
               [--learning-rate LEARNING_RATE]
               [--steps-per-report STEPS_PER_REPORT]
               [--steps-per-eval STEPS_PER_EVAL]
               [--resume-adapter-file RESUME_ADAPTER_FILE]
               [--adapter-path ADAPTER_PATH] [--save-every SAVE_EVERY]
               [--test] [--test-batches TEST_BATCHES]
               [--max-seq-length MAX_SEQ_LENGTH] [-c CONFIG]
               [--grad-checkpoint] [--seed SEED] [--use-dora]

LoRA or QLoRA finetuning.

options:
  -h, --help            show this help message and exit
  --model MODEL         The path to the local model directory or Hugging Face
                        repo.
  --train               Do training
  --data DATA           Directory with {train, valid, test}.jsonl files
  --lora-layers LORA_LAYERS
                   

In [8]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [9]:
!python -m mlx_lm.lora \
    --model mlx-community/gemma-2-27b-it-4bit \
    --train \
    --data data \
    --iters 500 \
    --batch-size 4 \
    --learning-rate 1e-5 \
    --steps-per-report 10 \
    --steps-per-eval 10 \
    --adapter-path checkpoints/adapters \
    # --resume-adapter-file checkpoints/600_adapters.npz \
    --save-every 10 \
    --max-seq-length 2048 \
    --seed 42 \
    --lora-layers 16

Loading pretrained model
Fetching 9 files: 100%|████████████████████████| 9/9 [00:00<00:00, 77672.30it/s]
Loading datasets
Training
Trainable parameters: 0.007% (1.966M/27227.128M)
Starting training..., iters: 500
Iter 1: Val loss 2.061, Val took 400.467s
Iter 10: Val loss 1.522, Val took 398.144s
Iter 10: Train loss 1.815, Learning Rate 1.000e-05, It/sec 0.486, Tokens/sec 885.188, Trained Tokens 18218, Peak mem 58.444 GB
Iter 20: Val loss 1.172, Val took 415.262s
Iter 20: Train loss 1.326, Learning Rate 1.000e-05, It/sec 0.232, Tokens/sec 491.767, Trained Tokens 39452, Peak mem 60.797 GB
Iter 30: Val loss 0.811, Val took 407.030s
Iter 30: Train loss 0.924, Learning Rate 1.000e-05, It/sec 0.429, Tokens/sec 821.017, Trained Tokens 58606, Peak mem 60.797 GB
Iter 40: Val loss 0.578, Val took 402.604s
Iter 40: Train loss 0.682, Learning Rate 1.000e-05, It/sec 0.460, Tokens/sec 918.351, Trained Tokens 78561, Peak mem 60.797 GB
Iter 50: Val loss 0.535, Val took 396.426s
Iter 50: Train loss 0

In [None]:
# !python -m mlx_lm.lora \
#     --model mlx-community/gemma-2-27b-it-4bit \
#     --train \
#     --data data \
#     --iters 600 \
#     --batch-size 4 \
#     --learning-rate 1e-5 \
#     --steps-per-report 10 \
#     --steps-per-eval 10 \
#     --adapter-path checkpoints/adapters \
#     --resume-adapter-file checkpoints/adapters/0000100_adapters.safetensors \
#     --save-every 10 \
#     --max-seq-length 2048 \
#     --seed 42 \
#     --lora-layers 16

## 6. Inference with fine-tuned model

In [18]:
# System prompt

system_prompt = df["system_prompt"].unique()[-1]
print(system_prompt)

You are an IMDG Code specialist with a deep understanding of the International Maritime Dangerous Goods (IMDG) Code. Your task is to provide detailed and accurate information in response to user questions about IMDG code.

    Guidelines for answering:
    1. Present information in a clear, concise, and easily understandable manner, using bullet points for organization.
    2. For questions about specific UN Numbers:
       - Provide details on: UN No., Proper Shipping Name, Class, Subsidiary hazard, Packing Group, Special Provisions, Limited Quantity, Excepted Quantity, Packing Instructions, Stowage and handling, Segregation, and Properties and observations.
       - If this information is not in the provided context, state that you need to refer to the official IMDG Code for accurate details.
    3. Respond in the language of the user's question. If unable to determine the language, default to English.
    4. If you don't know the answer or if the information is not in the provided c

In [30]:
question = "What are the minimum wall thickness requirements for IBCs used for transporting liquids according to the IMDG Code?"


def format_prompt(system_prompt: str, question: str) -> str:
    "Format the question to the format of the dataset we fine-tuned to."
    return """<bos><start_of_turn>user
## Instructions
{}
## User
{}<end_of_turn>
<start_of_turn>model
""".format(
        system_prompt, question
    )


print(format_prompt(system_prompt, question))

<bos><start_of_turn>user
## Instructions
You are an IMDG Code specialist with a deep understanding of the International Maritime Dangerous Goods (IMDG) Code. Your task is to provide detailed and accurate information in response to user questions about IMDG code.

    Guidelines for answering:
    1. Present information in a clear, concise, and easily understandable manner, using bullet points for organization.
    2. For questions about specific UN Numbers:
       - Provide details on: UN No., Proper Shipping Name, Class, Subsidiary hazard, Packing Group, Special Provisions, Limited Quantity, Excepted Quantity, Packing Instructions, Stowage and handling, Segregation, and Properties and observations.
       - If this information is not in the provided context, state that you need to refer to the official IMDG Code for accurate details.
    3. Respond in the language of the user's question. If unable to determine the language, default to English.
    4. If you don't know the answer or if

In [31]:
# Load the fine-tuned model with LoRA weights
model_lora, tokenizer = load(
    "mlx-community/gemma-2-27b-it-4bit",
    adapter_path="./checkpoints/adapters",
)

Fetching 9 files:   0%|          | 0/9 [00:00<?, ?it/s]

In [32]:
response = generate(
    model_lora,
    tokenizer,
    prompt=format_prompt(system_prompt, question),
    verbose=True,
    temp=0.5,
    max_tokens=1024,
)

Prompt: <bos><start_of_turn>user
## Instructions
You are an IMDG Code specialist with a deep understanding of the International Maritime Dangerous Goods (IMDG) Code. Your task is to provide detailed and accurate information in response to user questions about IMDG code.

    Guidelines for answering:
    1. Present information in a clear, concise, and easily understandable manner, using bullet points for organization.
    2. For questions about specific UN Numbers:
       - Provide details on: UN No., Proper Shipping Name, Class, Subsidiary hazard, Packing Group, Special Provisions, Limited Quantity, Excepted Quantity, Packing Instructions, Stowage and handling, Segregation, and Properties and observations.
       - If this information is not in the provided context, state that you need to refer to the official IMDG Code for accurate details.
    3. Respond in the language of the user's question. If unable to determine the language, default to English.
    4. If you don't know the answ

In [33]:
response = generate(
    model,
    tokenizer,
    prompt=format_prompt(system_prompt, question),
    verbose=True,
    temp=0.5,
    max_tokens=1024,
)

Prompt: <bos><start_of_turn>user
## Instructions
You are an IMDG Code specialist with a deep understanding of the International Maritime Dangerous Goods (IMDG) Code. Your task is to provide detailed and accurate information in response to user questions about IMDG code.

    Guidelines for answering:
    1. Present information in a clear, concise, and easily understandable manner, using bullet points for organization.
    2. For questions about specific UN Numbers:
       - Provide details on: UN No., Proper Shipping Name, Class, Subsidiary hazard, Packing Group, Special Provisions, Limited Quantity, Excepted Quantity, Packing Instructions, Stowage and handling, Segregation, and Properties and observations.
       - If this information is not in the provided context, state that you need to refer to the official IMDG Code for accurate details.
    3. Respond in the language of the user's question. If unable to determine the language, default to English.
    4. If you don't know the answ

In [18]:
!python -m mlx_lm.generate \
    --model mlx-community/gemma-2-27b-it-4bit \
    --adapter-path checkpoints/adapters \
    --prompt "What are the normal working hours for customs at Dar es Salaam Port?" \
    --max-tokens 256 \
    --temp 0.5

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Fetching 9 files: 100%|████████████████████████| 9/9 [00:00<00:00, 26696.42it/s]
Prompt: <bos><start_of_turn>user
What are the normal working hours for customs at Dar es Salaam Port?<end_of_turn>
<start_of_turn>model

I cannot provide specific real-time information like working hours for government offices. Working hours are subject to change and are best obtained directly from the source.<end_of_turn>
Prompt: 24 tokens, 12.770 tokens-per-sec
Generation: 32 tokens, 13.584 tokens-per-sec
Peak memory: 14.642 GB


In [12]:
!python -m mlx_lm.generate --help

usage: generate.py [-h] [--model MODEL] [--adapter-path ADAPTER_PATH]
                   [--trust-remote-code] [--eos-token EOS_TOKEN]
                   [--prompt PROMPT] [--max-tokens MAX_TOKENS] [--temp TEMP]
                   [--top-p TOP_P] [--seed SEED] [--ignore-chat-template]
                   [--use-default-chat-template] [--colorize]
                   [--cache-limit-gb CACHE_LIMIT_GB]

LLM inference script

options:
  -h, --help            show this help message and exit
  --model MODEL         The path to the local model directory or Hugging Face
                        repo.
  --adapter-path ADAPTER_PATH
                        Optional path for the trained adapter weights and
                        config.
  --trust-remote-code   Enable trusting remote code for tokenizer
  --eos-token EOS_TOKEN
                        End of sequence token for tokenizer
  --prompt PROMPT       Message to be processed by the model
  --max-tokens MAX_TOKENS, -m MAX_TOKENS
               