# Fine-Tuning Llama-3.2-1B-Instruct for Customer Support Triage

This notebook fine-tunes `mlx-community/Llama-3.2-1B-Instruct-bf16` to transform customer support tickets into structured internal bug reports.

## Overview
- **Model**: Llama-3.2-1B-Instruct (~1.24B parameters)
- **Task**: Convert customer tickets → internal bug reports with severity, owner, and investigation steps
- **Method**: LoRA (Low-Rank Adaptation)
- **Data Split**: 10 train / 1 valid / 1 test per category (stratified)

## Hardware
- MacBook Pro M4 Max
- 40-core GPU
- 64GB unified memory

## 1. Setup and Configuration

In [1]:
import json
import random
from pathlib import Path
from collections import defaultdict

# Set random seed for reproducibility
random.seed(42)

print("Setup complete!")

Setup complete!


In [2]:
# Configuration
MODEL_NAME = "mlx-community/Llama-3.2-1B-Instruct-bf16"
DATA_DIR = Path("data")
ADAPTER_PATH = Path("adapters")
SOURCE_FILE = "source_data.jsonl"

# Create directories
DATA_DIR.mkdir(exist_ok=True)
ADAPTER_PATH.mkdir(exist_ok=True)

print(f"Model: {MODEL_NAME}")
print(f"Data directory: {DATA_DIR}")
print(f"Adapter path: {ADAPTER_PATH}")

Model: mlx-community/Llama-3.2-1B-Instruct-bf16
Data directory: data
Adapter path: adapters


## 2. Load and Preview Data

In [4]:
# Load source data
data = []
with open(SOURCE_FILE, "r") as f:
    for line in f:
        line = line.strip()
        if line:  # Skip empty lines
            data.append(json.loads(line))

print(f"Loaded {len(data)} records")

# Show category distribution
category_counts = defaultdict(int)
for entry in data:
    category_counts[entry["category"]] += 1

print("\nCategory distribution:")
for category, count in sorted(category_counts.items()):
    print(f"  {category}: {count}")

Loaded 72 records

Category distribution:
  auth_p1: 12
  auth_p2: 12
  auth_p3: 12
  data_platform_p1: 12
  data_platform_p2: 12
  data_platform_p3: 12


In [5]:
# Preview sample entries
print("Sample entries:")
for entry in random.sample(data, 2):
    print("=" * 60)
    print(f"Category: {entry['category']}")
    for msg in entry["messages"]:
        content_preview = msg["content"][:200] + "..." if len(msg["content"]) > 200 else msg["content"]
        print(f"{msg['role'].upper()}: {content_preview}")

Sample entries:
Category: data_platform_p2
USER: Ticket #8045
Customer: jennifer.hartley@novanexus.com
Plan: Enterprise
Issue: The filtering on our analytics dashboard is extremely laggy. Every time I change a filter dropdown it takes 20-30 seconds ...
ASSISTANT: ## Internal Bug Report

Severity: P2 (service degraded)

Owner: Data Platform Team

Summary: Dashboard filter interactions experiencing 20-30s latency since Monday

Probable Cause: Filter queries hitt...
Category: data_platform_p1
USER: Ticket #8156
Customer: tom.wright@pegasus-logistics.com
Plan: Personal
Issue: I can't export any data at all. Every export just spins forever then times out. I have a shipment manifest due to my bigge...
ASSISTANT: ## Internal Bug Report

Severity: P1 (business-critical)

Owner: Data Platform Team

Summary: Data exports timing out, blocking time-critical client deliverable

Probable Cause: Snowflake ANALYTICS_WH...


## 3. Stratified Data Split

Split data by category: 10 train / 1 validation / 1 test per category.

This ensures each split has balanced representation across:
- Teams (Data Platform, Auth)
- Severity levels (P1, P2, P3)

In [6]:
# Group by category
by_category = defaultdict(list)
for entry in data:
    by_category[entry["category"]].append(entry)

# Stratified split: 10 train, 1 valid, 1 test per category
train_data, valid_data, test_data = [], [], []

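for category, items in sorted(by_category.items()):
    # Each category must supply 10 train + 1 valid + 1 test examples
    assert len(items) >= 12, f"{category}: need 12 records, found {len(items)}"
    random.shuffle(items)
    train_data.extend(items[:10])
    valid_data.append(items[10])
    test_data.append(items[11])
    print(f"{category}: {len(items[:10])} train, 1 valid, 1 test")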

print(f"\nTotal: {len(train_data)} train, {len(valid_data)} valid, {len(test_data)} test")

auth_p1: 10 train, 1 valid, 1 test
auth_p2: 10 train, 1 valid, 1 test
auth_p3: 10 train, 1 valid, 1 test
data_platform_p1: 10 train, 1 valid, 1 test
data_platform_p2: 10 train, 1 valid, 1 test
data_platform_p3: 10 train, 1 valid, 1 test

Total: 60 train, 6 valid, 6 test


In [11]:
# Summarize the train_data, valid_data, test_data, grouped by category
def count_by_category(entries):
    counts = defaultdict(int)
    for entry in entries:
        counts[entry["category"]] += 1
    return counts

train_counts = count_by_category(train_data)
valid_counts = count_by_category(valid_data)
test_counts = count_by_category(test_data)

print("Data split summary by category:\n")
print(f"{'Category':<20} {'Train':>6} {'Valid':>6} {'Test':>6}")
print("-" * 42)
for category in sorted(set(train_counts) | set(valid_counts) | set(test_counts)):
    print(f"{category:<20} {train_counts[category]:>6} {valid_counts[category]:>6} {test_counts[category]:>6}")
print("-" * 42)
print(f"{'Total':<20} {len(train_data):>6} {len(valid_data):>6} {len(test_data):>6}")


Data split summary by category:

Category              Train  Valid   Test
------------------------------------------
auth_p1                  10      1      1
auth_p2                  10      1      1
auth_p3                  10      1      1
data_platform_p1         10      1      1
data_platform_p2         10      1      1
data_platform_p3         10      1      1
------------------------------------------
Total                    60      6      6
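
Slicing a shuffled list cannot place the same record in two splits, but duplicate tickets in the source file could still end up on both sides of the train/test boundary. The sketch below shows one way to check for that. The helper names `user_text` and `find_leaks` are hypothetical, and the toy entries stand in for the real `train_data`, `valid_data`, and `test_data` lists.

```python
def user_text(entry):
    """Return the content of the first user message in an entry."""
    return next(m["content"] for m in entry["messages"] if m["role"] == "user")


def find_leaks(train, valid, test):
    """Return user messages that appear in more than one split.

    Duplicates inside a single split are not reported.
    """
    first_seen = {}  # user text -> name of the first split containing it
    leaks = set()
    for name, split in (("train", train), ("valid", valid), ("test", test)):
        for entry in split:
            text = user_text(entry)
            if first_seen.setdefault(text, name) != name:
                leaks.add(text)
    return leaks


def make_entry(text):
    """Build a toy entry with the same shape as the source records."""
    return {
        "category": "auth_p1",
        "messages": [
            {"role": "user", "content": text},
            {"role": "assistant", "content": "## Internal Bug Report"},
        ],
    }


# Toy data: ticket B deliberately appears in both train and test
toy_train = [make_entry("ticket A"), make_entry("ticket B")]
toy_valid = [make_entry("ticket C")]
toy_test = [make_entry("ticket B")]

leaks = find_leaks(toy_train, toy_valid, toy_test)
clean = find_leaks(toy_train, toy_valid, [make_entry("ticket D")])
print(f"Leaked tickets: {sorted(leaks)}")
```

In the notebook, calling `find_leaks(train_data, valid_data, test_data)` and expecting an empty set would be the equivalent check.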


In [7]:
# Format for MLX-LM: remove category field, keep only messages
def format_for_mlx(entries):
    return [{"messages": entry["messages"]} for entry in entries]

train_formatted = format_for_mlx(train_data)
valid_formatted = format_for_mlx(valid_data)
test_formatted = format_for_mlx(test_data)

print("Formatted data for MLX-LM")
print("\nExample entry:")
print(json.dumps(train_formatted[0], indent=2)[:500] + "...")

Formatted data for MLX-LM

Example entry:
{
  "messages": [
    {
      "role": "user",
      "content": "Ticket #6643\nCustomer: support@quickship.logistics\nPlan: Enterprise\nIssue: Warehouse team reporting authentication failures across all 5 distribution centres. Drivers can't get manifests. We have 200+ shipments stuck. Error: \"session store unavailable\". CRITICAL."
    },
    {
      "role": "assistant",
      "content": "## Internal Bug Report\n\nSeverity: P1 (business-critical)\n\nOwner: Auth Team\n\nSummary: Multi-site authen...


In [8]:
# Save splits to JSONL files
splits = {
    "train.jsonl": train_formatted,
    "valid.jsonl": valid_formatted,
    "test.jsonl": test_formatted,
}

for filename, split_data in splits.items():
    filepath = DATA_DIR / filename
    with open(filepath, "w") as f:
        for entry in split_data:
            f.write(json.dumps(entry) + "\n")
    print(f"Saved {filepath} ({len(split_data)} records)")

print("\nData preprocessing complete!")

Saved data/train.jsonl (60 records)
Saved data/valid.jsonl (6 records)
Saved data/test.jsonl (6 records)

Data preprocessing complete!
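
Before training, it can help to read the saved files back and confirm that every line has the chat structure shown in the example entry above. The following is a minimal sketch: `validate_chat_jsonl` is a hypothetical helper, and the rule that each record holds exactly one user message followed by one assistant message is an assumption based on that example, not a requirement stated by MLX-LM.

```python
import json
import tempfile
from pathlib import Path


def validate_chat_jsonl(path):
    """Return the record count, raising ValueError on the first malformed line."""
    count = 0
    with open(path, "r", encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue  # tolerate blank lines
            record = json.loads(line)
            messages = record.get("messages")
            if not isinstance(messages, list) or not messages:
                raise ValueError(f"{path}:{lineno}: missing or empty 'messages'")
            roles = [m.get("role") for m in messages]
            # ASSUMPTION: one user turn followed by one assistant turn
            if roles != ["user", "assistant"]:
                raise ValueError(f"{path}:{lineno}: unexpected roles {roles}")
            if any(not m.get("content") for m in messages):
                raise ValueError(f"{path}:{lineno}: empty content")
            count += 1
    return count


# Demonstration on a throwaway file with one hypothetical record
with tempfile.TemporaryDirectory() as tmp:
    sample_path = Path(tmp) / "train.jsonl"
    sample_record = {
        "messages": [
            {"role": "user", "content": "Ticket #0001\nIssue: example"},
            {"role": "assistant", "content": "## Internal Bug Report"},
        ]
    }
    sample_path.write_text(json.dumps(sample_record) + "\n", encoding="utf-8")
    n_records = validate_chat_jsonl(sample_path)
    print(f"{sample_path.name}: {n_records} valid record(s)")
```

In the notebook, the same function could be pointed at `DATA_DIR / "train.jsonl"`, `valid.jsonl`, and `test.jsonl`.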


## 4. Fine-Tuning with LoRA

### Hyperparameters:
- **batch_size**: 6. Samples processed per training iteration. This equals the validation set size, so a single validation batch covers all 6 validation examples.
- **learning_rate**: 2e-5. Peak step size for the optimizer, reached at the end of warmup.
- **lr_schedule**: cosine_decay with 25 warmup iterations (10% of training). The learning rate ramps up linearly, then decays along a cosine curve to 0 over the remaining 225 iterations.
- **iters**: 250. Total training iterations. At 6 samples per iteration over 60 training samples, this is 25 epochs.
- **num_layers**: 16. Number of transformer layers that receive LoRA adapters (all 16 layers of Llama-3.2-1B).
- **lora_rank**: 16. Rank of the low-rank decomposition matrices.
- **lora_scale**: 2.0. Multiplier applied to the LoRA update before it is added to the frozen layer's output. It plays the role of alpha/rank in other LoRA implementations.
- **lora_dropout**: 0.1. Dropout probability for regularization.
- **val_batches**: 1. Number of batches evaluated at each validation step.
- **steps_per_eval**: 50. Validation runs every 50 iterations.
- **save_every**: 250. Checkpoint frequency in iterations. With 250 total iterations, only the final adapter is saved.
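
The arithmetic behind two of these numbers can be checked directly: the epoch count, and the trainable-parameter count that the training log reports as 11.272M. The sketch below rests on two assumptions that are not stated in this notebook: the projection shapes come from the published Llama-3.2-1B configuration (hidden size 2048, key/value width 512, MLP width 8192), and LoRA adapters are attached to all seven linear projections in each layer.

```python
# Hyperparameters taken from the list above
BATCH_SIZE = 6
ITERS = 250
TRAIN_SAMPLES = 60
LORA_RANK = 16
NUM_LAYERS = 16

# ASSUMPTION: Llama-3.2-1B projection shapes as (in_features, out_features)
HIDDEN, KV_DIM, MLP_DIM = 2048, 512, 8192
PROJECTIONS = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, MLP_DIM),
    "up_proj": (HIDDEN, MLP_DIM),
    "down_proj": (MLP_DIM, HIDDEN),
}


def lora_params(in_features, out_features, rank):
    """A LoRA adapter adds two matrices: (in x rank) and (rank x out)."""
    return rank * (in_features + out_features)


epochs = ITERS * BATCH_SIZE / TRAIN_SAMPLES
per_layer = sum(lora_params(i, o, LORA_RANK) for i, o in PROJECTIONS.values())
total_trainable = per_layer * NUM_LAYERS

print(f"Epochs: {epochs:.0f}")
print(f"LoRA parameters per layer: {per_layer:,}")
print(f"Trainable LoRA parameters: {total_trainable / 1e6:.3f}M")
```

Under these assumptions the total comes to 11,272,192 parameters, which agrees with the `11.272M` figure in the training log further down.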

In [4]:
from pathlib import Path

# Training configuration is defined in lora_config.yaml
CONFIG_PATH = Path("lora_config.yaml")

print("Training configuration (from lora_config.yaml):")
print("-" * 50)
print(CONFIG_PATH.read_text())

Training configuration (from lora_config.yaml):
--------------------------------------------------
# MLX-LM LoRA Fine-Tuning Configuration

model: mlx-community/Llama-3.2-1B-Instruct-bf16
data: data
train: true
fine_tune_type: lora
adapter_path: adapters

# Training hyperparameters
batch_size: 6
num_layers: 16
iters: 250
learning_rate: 2.0e-05
val_batches: 1
save_every: 250
steps_per_eval: 50

# Learning rate schedule with warmup
lr_schedule:
  name: cosine_decay
  warmup: 25
  arguments: [2.0e-05, 225, 0.0]

# LoRA parameters
lora_parameters:
  rank: 16
  scale: 2.0
  dropout: 0.1
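
The `lr_schedule` block above combines a linear warmup with a cosine decay. The sketch below re-implements that shape in plain Python so the learning rate at any iteration can be inspected without running training. The function name `lr_at_iter` is hypothetical, and the one-step offset at the warmup/decay boundary is an assumption inferred from the learning rates printed in the training log, not taken from the library documentation.

```python
import math

# Values from lr_config.yaml's lr_schedule block
PEAK_LR = 2.0e-05    # arguments[0]
DECAY_STEPS = 225    # arguments[1]
END_LR = 0.0         # arguments[2]
WARMUP_STEPS = 25    # warmup


def lr_at_iter(iteration):
    """Learning rate reported at a 1-based training iteration."""
    step = iteration - 1  # zero-based optimizer step
    # ASSUMPTION: the cosine phase starts at step WARMUP_STEPS + 1
    boundary = WARMUP_STEPS + 1
    if step < boundary:
        # Linear ramp from 0 to PEAK_LR, clamped once warmup is complete
        return PEAK_LR * min(step, WARMUP_STEPS) / WARMUP_STEPS
    decay_step = min(step - boundary, DECAY_STEPS)
    cosine = 0.5 * (1.0 + math.cos(math.pi * decay_step / DECAY_STEPS))
    return END_LR + (PEAK_LR - END_LR) * cosine


for iteration in (10, 20, 30, 250):
    print(f"Iter {iteration}: Learning Rate {lr_at_iter(iteration):.3e}")
```

For iterations 10, 20, 30, and 250 this gives 7.200e-06, 1.520e-05, 1.999e-05, and 3.899e-09, the same values that appear in the training log.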



In [5]:
# Run LoRA fine-tuning
import subprocess

print("Starting LoRA fine-tuning...\n")

cmd = [
    "python", "-m", "mlx_lm", "lora",
    "--config", str(CONFIG_PATH),
]

print(f"Command: {' '.join(cmd)}\n")

result = subprocess.run(cmd, capture_output=False, text=True)

if result.returncode == 0:
    print("\nFine-tuning completed successfully!")
else:
    print(f"\nFine-tuning failed with return code {result.returncode}")

Starting LoRA fine-tuning...

Command: python -m mlx_lm lora --config lora_config.yaml

Loading configuration file lora_config.yaml
Loading pretrained model


Fetching 6 files: 100%|██████████| 6/6 [00:00<00:00, 16090.68it/s]


Loading datasets
Training
Trainable parameters: 0.912% (11.272M/1235.814M)
Starting training..., iters: 250


Calculating loss...: 100%|██████████| 1/1 [00:00<00:00,  1.53it/s]


Iter 1: Val loss 4.389, Val took 0.656s
Iter 10: Train loss 4.273, Learning Rate 7.200e-06, It/sec 1.074, Tokens/sec 1392.033, Trained Tokens 12964, Peak mem 10.412 GB
Iter 20: Train loss 3.222, Learning Rate 1.520e-05, It/sec 1.424, Tokens/sec 1846.481, Trained Tokens 25928, Peak mem 10.441 GB
Iter 30: Train loss 2.272, Learning Rate 1.999e-05, It/sec 1.347, Tokens/sec 1745.815, Trained Tokens 38892, Peak mem 10.441 GB
Iter 40: Train loss 1.832, Learning Rate 1.984e-05, It/sec 1.325, Tokens/sec 1717.778, Trained Tokens 51856, Peak mem 10.441 GB


Calculating loss...: 100%|██████████| 1/1 [00:00<00:00,  2.71it/s]


Iter 50: Val loss 1.707, Val took 0.372s
Iter 50: Train loss 1.596, Learning Rate 1.949e-05, It/sec 1.317, Tokens/sec 1706.947, Trained Tokens 64820, Peak mem 10.441 GB
Iter 60: Train loss 1.418, Learning Rate 1.896e-05, It/sec 1.303, Tokens/sec 1688.755, Trained Tokens 77784, Peak mem 10.441 GB
Iter 70: Train loss 1.258, Learning Rate 1.825e-05, It/sec 1.297, Tokens/sec 1681.050, Trained Tokens 90748, Peak mem 10.441 GB
Iter 80: Train loss 1.110, Learning Rate 1.738e-05, It/sec 1.284, Tokens/sec 1665.097, Trained Tokens 103712, Peak mem 10.441 GB
Iter 90: Train loss 0.959, Learning Rate 1.637e-05, It/sec 1.313, Tokens/sec 1702.081, Trained Tokens 116676, Peak mem 10.441 GB


Calculating loss...: 100%|██████████| 1/1 [00:00<00:00,  2.46it/s]


Iter 100: Val loss 1.779, Val took 0.409s
Iter 100: Train loss 0.811, Learning Rate 1.524e-05, It/sec 1.291, Tokens/sec 1674.018, Trained Tokens 129640, Peak mem 10.441 GB
Iter 110: Train loss 0.675, Learning Rate 1.400e-05, It/sec 1.277, Tokens/sec 1655.673, Trained Tokens 142604, Peak mem 10.441 GB
Iter 120: Train loss 0.557, Learning Rate 1.269e-05, It/sec 1.253, Tokens/sec 1624.440, Trained Tokens 155568, Peak mem 10.441 GB
Iter 130: Train loss 0.442, Learning Rate 1.132e-05, It/sec 1.293, Tokens/sec 1675.905, Trained Tokens 168532, Peak mem 10.441 GB
Iter 140: Train loss 0.356, Learning Rate 9.930e-06, It/sec 1.298, Tokens/sec 1682.644, Trained Tokens 181496, Peak mem 10.441 GB


Calculating loss...: 100%|██████████| 1/1 [00:00<00:00,  2.61it/s]


Iter 150: Val loss 2.251, Val took 0.385s
Iter 150: Train loss 0.297, Learning Rate 8.539e-06, It/sec 1.291, Tokens/sec 1673.205, Trained Tokens 194460, Peak mem 10.441 GB
Iter 160: Train loss 0.250, Learning Rate 7.177e-06, It/sec 1.223, Tokens/sec 1585.141, Trained Tokens 207424, Peak mem 10.441 GB
Iter 170: Train loss 0.221, Learning Rate 5.869e-06, It/sec 1.278, Tokens/sec 1657.095, Trained Tokens 220388, Peak mem 10.441 GB
Iter 180: Train loss 0.197, Learning Rate 4.642e-06, It/sec 1.281, Tokens/sec 1660.720, Trained Tokens 233352, Peak mem 10.441 GB
Iter 190: Train loss 0.182, Learning Rate 3.519e-06, It/sec 1.250, Tokens/sec 1621.011, Trained Tokens 246316, Peak mem 10.441 GB


Calculating loss...: 100%|██████████| 1/1 [00:00<00:00,  2.46it/s]


Iter 200: Val loss 2.634, Val took 0.408s
Iter 200: Train loss 0.173, Learning Rate 2.522e-06, It/sec 1.235, Tokens/sec 1601.067, Trained Tokens 259280, Peak mem 10.441 GB
Iter 210: Train loss 0.168, Learning Rate 1.671e-06, It/sec 1.234, Tokens/sec 1599.375, Trained Tokens 272244, Peak mem 10.441 GB
Iter 220: Train loss 0.163, Learning Rate 9.817e-07, It/sec 1.207, Tokens/sec 1564.180, Trained Tokens 285208, Peak mem 10.441 GB
Iter 230: Train loss 0.162, Learning Rate 4.681e-07, It/sec 1.164, Tokens/sec 1509.533, Trained Tokens 298172, Peak mem 10.441 GB
Iter 240: Train loss 0.161, Learning Rate 1.400e-07, It/sec 1.130, Tokens/sec 1465.276, Trained Tokens 311136, Peak mem 10.441 GB


Calculating loss...: 100%|██████████| 1/1 [00:00<00:00,  2.23it/s]


Iter 250: Val loss 2.695, Val took 0.451s
Iter 250: Train loss 0.160, Learning Rate 3.899e-09, It/sec 1.105, Tokens/sec 1432.721, Trained Tokens 324100, Peak mem 10.441 GB
Iter 250: Saved adapter weights to adapters/adapters.safetensors and adapters/0000250_adapters.safetensors.
Saved final weights to adapters/adapters.safetensors.

Fine-tuning completed successfully!
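
The log reports validation loss only every 50 iterations, so it is worth pulling those values out and comparing them with the training loss at the same points. The sketch below parses an excerpt of the log above (each line shortened to its loss fields) with regular expressions; the variable names are illustrative.

```python
import re

# Excerpt of the training log above, shortened to the loss fields
LOG = """\
Iter 1: Val loss 4.389, Val took 0.656s
Iter 50: Val loss 1.707, Val took 0.372s
Iter 50: Train loss 1.596
Iter 100: Val loss 1.779, Val took 0.409s
Iter 100: Train loss 0.811
Iter 150: Val loss 2.251, Val took 0.385s
Iter 150: Train loss 0.297
Iter 200: Val loss 2.634, Val took 0.408s
Iter 200: Train loss 0.173
Iter 250: Val loss 2.695, Val took 0.451s
Iter 250: Train loss 0.160
"""

VAL_RE = re.compile(r"Iter (\d+): Val loss ([\d.]+)")
TRAIN_RE = re.compile(r"Iter (\d+): Train loss ([\d.]+)")

val_loss = {int(i): float(loss) for i, loss in VAL_RE.findall(LOG)}
train_loss = {int(i): float(loss) for i, loss in TRAIN_RE.findall(LOG)}

# Iteration with the lowest validation loss
best_iter = min(val_loss, key=val_loss.get)
final_gap = round(val_loss[250] - train_loss[250], 3)

print(f"{'Iter':>5} {'Train':>7} {'Val':>7}")
for iteration in sorted(val_loss):
    train = train_loss.get(iteration)
    train_str = f"{train:7.3f}" if train is not None else f"{'n/a':>7}"
    print(f"{iteration:>5} {train_str} {val_loss[iteration]:7.3f}")
print(f"\nLowest validation loss: {val_loss[best_iter]} at iter {best_iter}")
print(f"Val minus train loss at iter 250: {final_gap}")
```

Validation loss is lowest at iteration 50 (1.707) and rises to 2.695 by iteration 250 while training loss keeps falling to 0.160, a pattern usually associated with overfitting. With only 6 validation examples the estimate is noisy, and because `save_every` is 250 only the final adapter was kept, so a lower `iters` or `save_every` value may be worth trying.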
