<a href="https://colab.research.google.com/github/mgfrantz/CalTech-CTME-AramCo-2025/blob/main/notebooks/02_fine_tune_lora.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Part 1: Dataset Preparation

Previously, we created a dataset of question/SQL query pairs and validated them against actual SQLite databases.
Now that we have this data, our goal is to fine-tune a model that accepts a database schema and a question and outputs a SQLite query.

The first step in this process is preparing our dataset.
While there are many [dataset formats](https://docs.axolotl.ai/docs/dataset-formats/), it's important to use the **same dataset format your base model was trained on.**
For example, if you are fine-tuning a chat model, format your dataset as chat conversations.
This way the model doesn't have to learn a new input format on top of your task.
A good way to find the right format is to look at the Axolotl [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples): there are samples with sensible defaults for almost any model family you'd want to fine-tune.

In our case, we will output a `.jsonl` file where each line contains the following information:

```json
{
    "conversations": [
        {
            "role": <system, user, assistant>,
            "content": "some content"
        },
        ...
    ]
}
```

We will follow this general format:

- Message 1: our system prompt with general instructions
- Message 2: our user message with the database schema and the user's question
- Message 3: our validated SQL query (the output we want the model to learn to produce)
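Before writing many records, it can help to check each one against this three-message layout. A minimal sketch; `validate_conversation` and `EXPECTED_ROLES` are hypothetical helpers for illustration, not part of Axolotl:

```python
import json

# Expected role order for every training example (system -> user -> assistant).
EXPECTED_ROLES = ["system", "user", "assistant"]


def validate_conversation(record: dict) -> list[str]:
    """Return a list of problems found in one JSONL record (empty list means OK)."""
    messages = record.get("conversations")
    if not isinstance(messages, list):
        return ["missing or non-list 'conversations' field"]

    problems = []
    roles = [m.get("role") for m in messages]
    if roles != EXPECTED_ROLES:
        problems.append(f"unexpected role order: {roles}")
    for i, m in enumerate(messages):
        content = m.get("content")
        if not isinstance(content, str) or not content.strip():
            problems.append(f"message {i} has empty or non-string content")
    return problems


# Usage: check a single JSONL line
line = json.dumps({
    "conversations": [
        {"role": "system", "content": "You are an expert SQL developer..."},
        {"role": "user", "content": "Database Schema:\n...\n\nQuestion: How many orders?"},
        {"role": "assistant", "content": "SELECT COUNT(*) FROM orders"},
    ]
})
print(validate_conversation(json.loads(line)))  # -> []
```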

This format works with almost any chat model, including proprietary models that offer fine-tuning (see the [OpenAI docs](https://platform.openai.com/docs/guides/supervised-fine-tuning?formatting=jsonl)).
This means you can prepare your data once, then fine-tune and compare many different models.
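As we read the linked OpenAI docs, their chat fine-tuning JSONL uses a top-level `messages` key instead of `conversations`, with the same `role`/`content` entries, so converting is a small transformation. A sketch under that assumption; `to_openai_format` and `convert_jsonl` are hypothetical helper names:

```python
import json


def to_openai_format(record: dict) -> dict:
    """Rename 'conversations' to 'messages', keeping only role/content per message."""
    return {
        "messages": [
            {"role": m["role"], "content": m["content"]}
            for m in record["conversations"]
        ]
    }


def convert_jsonl(src_path: str, dst_path: str) -> int:
    """Convert every non-empty line of src_path, write it to dst_path, and return the count."""
    n = 0
    with open(src_path, encoding="utf-8") as src, open(dst_path, "w", encoding="utf-8") as dst:
        for line in src:
            if line.strip():
                dst.write(json.dumps(to_openai_format(json.loads(line))) + "\n")
                n += 1
    return n
```

Once the training file from this notebook exists, something like `convert_jsonl("chatml_training_data.jsonl", "openai_training_data.jsonl")` would produce the second file (the output filename is just an example).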

Let's start by loading and observing our dataset.

In [None]:
import json

import duckdb
import pandas as pd
from IPython.display import display
from tqdm.auto import tqdm

In [None]:
# If running this on Colab, upload the parquet file first and update the path below.
df = pd.read_parquet("../output/validated_dataset.parquet")

In [None]:
df.head()

| | db_path | question | query | answer | is_valid | error |
|---|---|---|---|---|---|---|
| 0 | /Users/michaelfrantz/dev/CalTech-CTME-AramCo-2... | How many orders were placed in the most recent... | SELECT COUNT(*) FROM orders WHERE order_date >... | [[20]] | True | |
| 1 | /Users/michaelfrantz/dev/CalTech-CTME-AramCo-2... | What is the total amount spent by customers in... | SELECT SUM(total_amount) FROM orders WHERE ord... | [[875.9686307710034]] | True | |
| 2 | /Users/michaelfrantz/dev/CalTech-CTME-AramCo-2... | Who are the top 5 customers by the number of o... | SELECT c.first_name, c.last_name, COUNT(o.id) ... | [["Michael", "Brown", 3], ["Timothy", "Wilson"... | True | |
| 3 | /Users/michaelfrantz/dev/CalTech-CTME-AramCo-2... | Which restaurant received the highest number o... | SELECT r.name, COUNT(o.id) as order_count FROM... | [["Baker-Martin", 2]] | True | |
| 4 | /Users/michaelfrantz/dev/CalTech-CTME-AramCo-2... | What is the average order amount for completed... | SELECT AVG(total_amount) FROM orders WHERE sta... | [[51.165547081963425]] | True | |


Now, let's create some utilities to get the schema of the database as a string so we can pass it to the model.

In [None]:
# Utilities
duckdb.execute("""
INSTALL sqlite;
LOAD sqlite;
""")

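def query_sqlite(query: str, db_path: str) -> pd.DataFrame:
    """Run a query against a SQLite file through DuckDB and return a DataFrame."""
    conn = duckdb.connect(db_path)
    try:
        return conn.execute(query).fetch_df()
    finally:
        conn.close()  # release the file handle after each query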

In [None]:
path = df.db_path.iloc[0]
print(path)

/Users/michaelfrantz/dev/CalTech-CTME-AramCo-2025/output/db/database_0.db


In [None]:
tables = query_sqlite('SHOW TABLES', path)
tables

| | name |
|---|---|
| 0 | customers |
| 1 | drivers |
| 2 | orders |
| 3 | restaurants |


In [None]:
for table in tables.name:
    print(table)
    # Drop columns that are entirely empty (e.g. default, extra) to keep the output compact
    display(query_sqlite(f'DESCRIBE TABLE {table}', path).dropna(how='all', axis=1))

customers

| | column_name | column_type | null | key |
|---|---|---|---|---|
| 0 | id | BIGINT | NO | PRI |
| 1 | first_name | VARCHAR | NO | |
| 2 | last_name | VARCHAR | NO | |
| 3 | email | VARCHAR | NO | |
| 4 | phone_number | VARCHAR | NO | |

drivers

| | column_name | column_type | null | key |
|---|---|---|---|---|
| 0 | id | BIGINT | NO | PRI |
| 1 | first_name | VARCHAR | NO | |
| 2 | last_name | VARCHAR | NO | |
| 3 | phone_number | VARCHAR | NO | |
| 4 | license_number | VARCHAR | NO | |

orders

| | column_name | column_type | null | key |
|---|---|---|---|---|
| 0 | id | BIGINT | NO | PRI |
| 1 | customer_id | BIGINT | NO | |
| 2 | restaurant_id | BIGINT | NO | |
| 3 | order_date | TIMESTAMP | NO | |
| 4 | total_amount | DOUBLE | NO | |
| 5 | status | VARCHAR | NO | |

restaurants

| | column_name | column_type | null | key |
|---|---|---|---|---|
| 0 | id | BIGINT | NO | PRI |
| 1 | name | VARCHAR | NO | |
| 2 | address | VARCHAR | NO | |
| 3 | phone_number | VARCHAR | NO | |


In [None]:
def table_metadata(db_path:str) -> dict[str, pd.DataFrame]:
    """
    Get metadata for all tables in the database.
    """
    tables = query_sqlite('SHOW TABLES', db_path)
    metadata = {}
    for table in tables.name:
        metadata[table] = query_sqlite(f'DESCRIBE TABLE {table}', db_path).dropna(how='all', axis=1)
    return metadata

In [None]:
def format_schema_for_chat(metadata: dict[str, pd.DataFrame]) -> str:
    """
    Format table metadata into a readable string for chat training.
    """
    schema_lines = []

    for table_name, table_df in metadata.items():  # avoid shadowing the global `df`
        schema_lines.append(f"Table: {table_name}")

        for _, row in table_df.iterrows():
            col_info = f"  - {row['column_name']} ({row['column_type']}"
            if row['key'] == 'PRI':
                col_info += ", PRIMARY KEY"
            elif pd.notna(row['key']) and row['key'] != 'None':
                col_info += f", {row['key']}"
            if row['null'] == 'NO':
                col_info += ", NOT NULL"
            col_info += ")"
            schema_lines.append(col_info)

        schema_lines.append("")  # Add blank line between tables

    return "\n".join(schema_lines).strip()

def create_chatml_dataset(df: pd.DataFrame) -> list[dict]:
    """
    Create ChatML conversations format dataset directly from DataFrame.
    Includes system prompt with SQLite constraints.
    """
    # SQLite-focused system prompt based on constraints from prompts.py
    system_prompt = """You are an expert SQL developer specializing in SQLite. Generate accurate SQL queries that follow these requirements:

SQLITE REQUIREMENTS:
1. Use only standard SQLite-compatible SQL syntax
2. Avoid PostgreSQL-specific functions like INTERVAL - use date arithmetic instead
3. Use SQLite date functions: date(), datetime(), julianday() for date calculations
4. For "recent month" use: WHERE date_column >= date((SELECT MAX(date_column) FROM table_name), '-1 month')
5. For "past 30 days" use: WHERE date_column >= date((SELECT MAX(date_column) FROM table_name), '-30 days')
6. For rolling averages, use window functions or subqueries
7. Use MAX(appropriate_date_column) to find the most recent date
8. Use proper SQLite column types: INTEGER, TEXT, REAL, BLOB
9. Handle foreign key relationships correctly with proper JOIN syntax

QUERY STANDARDS:
- Write clear, efficient queries
- Use appropriate aggregation functions (COUNT, SUM, AVG, MAX, MIN)
- Include proper GROUP BY and ORDER BY clauses when needed
- Use table aliases for readability in complex queries"""

    chatml_data = []

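    schema_cache: dict[str, str] = {}  # db_path -> formatted schema, so each database is introspected once

    for _, row in tqdm(df.iterrows(), total=len(df), desc="Creating ChatML dataset"):
        # Skip entries whose query failed validation
        if not row['is_valid']:
            continue

        # Get the formatted schema for this database (cached across rows)
        if row['db_path'] not in schema_cache:
            schema_cache[row['db_path']] = format_schema_for_chat(table_metadata(row['db_path']))
        schema_text = schema_cache[row['db_path']]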

        # Create user message with schema and question
        user_content = f"Database Schema:\n{schema_text}\n\nQuestion: {row['question']}"

        # Create ChatML conversation with system prompt
        chatml_entry = {
            "conversations": [
                {
                    "role": "system",
                    "content": system_prompt
                },
                {
                    "role": "user",
                    "content": user_content
                },
                {
                    "role": "assistant",
                    "content": row['query']
                }
            ]
        }

        chatml_data.append(chatml_entry)

    return chatml_data
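Rules 2 and 4 of the system prompt are easy to sanity-check against SQLite directly. A small sketch using Python's built-in `sqlite3` with a made-up three-row `orders` table (the rows are illustrative, not taken from the generated databases):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE orders (id INTEGER PRIMARY KEY, order_date TEXT, total_amount REAL)")
conn.executemany(
    "INSERT INTO orders (order_date, total_amount) VALUES (?, ?)",
    [
        ("2024-01-10 12:00:00", 10.0),  # more than a month before the latest order
        ("2024-02-20 09:30:00", 20.0),
        ("2024-03-05 18:45:00", 30.0),  # latest order
    ],
)

# Rule 4: anchor "recent month" on the latest date in the data rather than today's date,
# so the answer stays stable on a static dataset. Here the cutoff is date('2024-03-05 ...', '-1 month').
recent_month_count = conn.execute(
    """
    SELECT COUNT(*) FROM orders
    WHERE order_date >= date((SELECT MAX(order_date) FROM orders), '-1 month')
    """
).fetchone()[0]
print(recent_month_count)  # -> 2

# Rule 2: PostgreSQL-style INTERVAL arithmetic is not valid SQLite and fails to parse
try:
    conn.execute("SELECT COUNT(*) FROM orders WHERE order_date >= '2024-03-05' - INTERVAL '1 month'")
    interval_supported = True
except sqlite3.OperationalError:
    interval_supported = False
print(interval_supported)  # -> False

conn.close()
```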

Now that we've created our utilities, let's build the ChatML dataset and write it to a JSONL file.

In [None]:
# Create the ChatML format dataset for Axolotl
print("Creating ChatML format dataset for Axolotl...")
chatml_dataset = create_chatml_dataset(df)

print(f"Created {len(chatml_dataset)} ChatML training examples")

# Save to JSONL file for Axolotl
chatml_output_file = "chatml_training_data.jsonl"
with open(chatml_output_file, 'w') as f:
    for entry in chatml_dataset:
        f.write(json.dumps(entry) + '\n')

print(f"Saved ChatML training data to {chatml_output_file}")

# Display a sample ChatML entry
print("\nSample ChatML training entry:")
print(json.dumps(chatml_dataset[0], indent=2))

Creating ChatML format dataset for Axolotl...


Creating ChatML dataset:   0%|          | 0/151 [00:00<?, ?it/s]

Created 150 ChatML training examples
Saved ChatML training data to chatml_training_data.jsonl

Sample ChatML training entry:
{
  "conversations": [
    {
      "role": "system",
      "content": "You are an expert SQL developer specializing in SQLite. Generate accurate SQL queries that follow these requirements:\n\nSQLITE REQUIREMENTS:\n1. Use only standard SQLite-compatible SQL syntax\n2. Avoid PostgreSQL-specific functions like INTERVAL - use date arithmetic instead\n3. Use SQLite date functions: date(), datetime(), julianday() for date calculations\n4. For \"recent month\" use: WHERE date_column >= date((SELECT MAX(date_column) FROM table_name), '-1 month')\n5. For \"past 30 days\" use: WHERE date_column >= date((SELECT MAX(date_column) FROM table_name), '-30 days')\n6. For rolling averages, use window functions or subqueries\n7. Use MAX(appropriate_date_column) to find the most recent date\n8. Use proper SQLite column types: INTEGER, TEXT, REAL, BLOB\n9. Handle foreign key relation

In [1]:
# Note: flash-attn needs an Ampere-or-newer GPU (e.g. an L4 or A100 runtime on Colab); it will not run on a T4
!uv pip install -Uqqq packaging setuptools wheel ninja
!uv pip install -qqq --no-build-isolation "axolotl[flash-attn,deepspeed]"
# !uv pip install -qqq "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@bad6f7b49c75fdec69471abb71b4cddd0f0c6438"
# !uv pip install -qqq "git+https://github.com/linkedin/Liger-Kernel.git"

In [2]:
import os
from google.colab import userdata

# Llama 3.2 is a gated model: the token must belong to an account that has accepted its license
os.environ['HF_TOKEN'] = userdata.get('HF_TOKEN')

In [7]:
%%writefile axolotl.yaml
base_model: meta-llama/Llama-3.2-3B

load_in_8bit: true
load_in_4bit: false
strict: false
adapter: lora
# Added for bitsandbytes configuration
# bnb_4bit_use_double_quant: true


# Data config
chat_template: llama3
datasets:
  - path: chatml_training_data.jsonl
    type: chat_template
    field_messages: conversations
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out

sequence_len: 2048
sample_packing: true
eval_sample_packing: false # with a larger eval set we would enable packing here too, but ours is too small today
pad_to_sequence_len: true

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_target_modules:
  - gate_proj
  - down_proj
  - up_proj
  - q_proj
  - v_proj
  - k_proj
  - o_proj

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3

warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  pad_token: <|finetune_right_pad_id|>
  eos_token: <|eot_id|>

# Added for saving best checkpoint and pushing to Hugging Face Hub
save_only_k_checkpoints: 1
save_total_limit: 1
load_best_model_at_end: true
metric_for_best_model: eval_loss # Or any other metric you want to track
greater_is_better: false # True if the metric should be maximized, False if minimized

# Push to Hugging Face Hub
# push_to_hub: false
# hub_model_id: your_huggingface_username/your_model_name # Replace with your desired repo ID
# hub_private_repo: false # Set to true if you want a private repo
# hub_always_push: false
# hub_strategy: every_save # "end" to push only at the end of training

Overwriting axolotl.yaml
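The LoRA settings in this config are easier to reason about with a little arithmetic. The sketch below computes the LoRA scaling factor (`lora_alpha / lora_r`, the standard LoRA scaling), the effective batch size per optimizer step, and the number of trainable adapter parameters. The architecture sizes are our assumption for `meta-llama/Llama-3.2-3B` (verify them against the model's `config.json`), and we assume the seven listed projections in every decoder block are exactly the modules adapted:

```python
# Values copied from axolotl.yaml above
lora_r = 32
lora_alpha = 16
micro_batch_size = 2
gradient_accumulation_steps = 4

# ASSUMPTION: Llama-3.2-3B architecture sizes; check the model's config.json
hidden_size = 3072
intermediate_size = 8192
num_layers = 28
num_attention_heads = 24
num_kv_heads = 8
head_dim = hidden_size // num_attention_heads  # 128
kv_dim = num_kv_heads * head_dim               # 1024 (grouped-query attention)

# (in_features, out_features) of each targeted linear layer in one decoder block
targets = {
    "q_proj": (hidden_size, hidden_size),
    "k_proj": (hidden_size, kv_dim),
    "v_proj": (hidden_size, kv_dim),
    "o_proj": (hidden_size, hidden_size),
    "gate_proj": (hidden_size, intermediate_size),
    "up_proj": (hidden_size, intermediate_size),
    "down_proj": (intermediate_size, hidden_size),
}

# LoRA adds A (r x in) and B (out x r) to each targeted layer: r * (in + out) parameters
per_block = sum(lora_r * (fin + fout) for fin, fout in targets.values())
lora_params = per_block * num_layers
scaling = lora_alpha / lora_r  # the low-rank update B @ A is multiplied by this factor
effective_batch = micro_batch_size * gradient_accumulation_steps  # per GPU

print(f"LoRA params: {lora_params:,}")                # -> LoRA params: 48,627,712
print(f"fp32 size: {lora_params * 4 / 1e6:.1f} MB")   # -> fp32 size: 194.5 MB
print(f"scaling: {scaling}")                          # -> scaling: 0.5
print(f"effective batch size: {effective_batch}")     # -> effective batch size: 8
```

Under these assumptions the adapter trains about 48.6M parameters, roughly 1.5% of the base model's ~3.2B. If the adapter is saved in fp32 (also an assumption), that is about 195 MB, in line with the size of `adapter_model.safetensors` in the upload log further down. With `sample_packing: true`, each of the 8 rows per step is a packed sequence of up to 2048 tokens that may contain several examples.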


In [8]:
!axolotl train axolotl.yaml

2025-06-25 18:57:17.341228: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-06-25 18:57:17.358318: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1750877837.380312    2302 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1750877837.387172    2302 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-06-25 18:57:17.408995: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr
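When you later query the fine-tuned adapter, the prompt should match the training format exactly: the same system prompt and the same "Database Schema: ... Question: ..." user message, but without the assistant turn, which the model is expected to generate. A minimal sketch; `build_inference_messages` is a hypothetical helper. The resulting list would then be rendered with the model's chat template (for example `tokenizer.apply_chat_template(messages, add_generation_prompt=True)` in Hugging Face `transformers`) before generation:

```python
def build_inference_messages(system_prompt: str, schema_text: str, question: str) -> list[dict]:
    """Build an inference prompt in the training layout, minus the assistant turn."""
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"Database Schema:\n{schema_text}\n\nQuestion: {question}"},
    ]


messages = build_inference_messages(
    system_prompt="You are an expert SQL developer specializing in SQLite. ...",  # use the full training prompt
    schema_text="Table: orders\n  - id (BIGINT, PRIMARY KEY, NOT NULL)",
    question="How many orders are there?",
)
print(messages[1]["content"])
```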

In [20]:
# [optional] merge the LoRA weights into the base model
# !axolotl merge-lora axolotl.yaml

In [9]:
# Remove intermediate checkpoints so only the final adapter files in outputs/lora-out are uploaded
!rm -r outputs/lora-out/checkpoint-*

In [10]:
# [optional] upload the LoRA adapter to the Hugging Face Hub
import os
username = 'mgfrantz'   # set to '' to be prompted instead
repo_name = 'cmte-demo'
if not username or not repo_name:
    username = input("Username: ")
    repo_name = input("Repo name: ")
# Faster uploads; requires the `hf_transfer` package to be installed
os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'
!huggingface-cli upload {username}/{repo_name} ./outputs/lora-out/ .

Start hashing 8 files.
Finished hashing 8 files.
adapter_model.safetensors: 208MB [00:03, 63.2MB/s]
 50% 1/2 [00:03<00:03,  3.66s/it]
tokenizer.json:   0% 0.00/17.2M [00:00<?, ?B/s][A
tokenizer.json:  93% 16.0M/17.2M