<a href="https://colab.research.google.com/github/mgfrantz/CalTech-CTME-AramCo-2025/blob/main/notebooks/02_fine_tune_lora.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Part 1: Dataset Preparation

Previously, we created a dataset of question/SQL query pairs and validated them against actual SQLite databases.
Now that we have this data, our goal is to fine-tune a model that takes a database schema and a question and outputs a SQLite query.

The first step in this process is preparing our dataset.
While there are many [dataset formats](https://docs.axolotl.ai/docs/dataset-formats/), it's important to use the **same dataset format your base model was trained on.**
For example, if you have a chat model, make sure your dataset is formatted for chat.
This ensures the model isn't learning a new format along with your task.
A good way to find the right format is to look at the Axolotl [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples): there are sample configs with sensible defaults for almost any model family you'd want to fine-tune.

In our case, we will output a `.jsonl` file where each line contains the following information:

```json
{
    "conversations": [
        {
            "role": <system, user, assistant>,
            "content": "some content"
        },
        ...
    ]
}
```

We will follow this general format:

- Message 1: our system prompt with general instructions
- Message 2: our user message with the database schema and the user's question
- Message 3: our validated SQL query (what we want to tune the model to output)
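To make this layout concrete, here is a minimal sketch that builds one record in this shape and checks it before it is written as a JSONL line. The `make_record` and `validate_record` helpers, the sample schema, and the sample query are hypothetical and only use Python's standard library; `validate_record` deliberately enforces exactly the three-message system/user/assistant layout used in this notebook.

```python
import json

EXPECTED_ROLES = ("system", "user", "assistant")


def make_record(system: str, user: str, assistant: str) -> dict:
    """Build one training example: system prompt, user message, target SQL."""
    return {
        "conversations": [
            {"role": "system", "content": system},
            {"role": "user", "content": user},
            {"role": "assistant", "content": assistant},
        ]
    }


def validate_record(line: str) -> dict:
    """Parse one JSONL line and check it follows the system/user/assistant layout."""
    record = json.loads(line)
    messages = record.get("conversations")
    if not isinstance(messages, list):
        raise ValueError("'conversations' must be a list")
    roles = tuple(m.get("role") for m in messages)
    if roles != EXPECTED_ROLES:
        raise ValueError(f"expected roles {EXPECTED_ROLES}, got {roles}")
    for m in messages:
        if not isinstance(m.get("content"), str):
            raise ValueError("every 'content' must be a string")
    return record


# Hypothetical example; json.dumps escapes newlines, so each record stays on one line.
line = json.dumps(
    make_record(
        "You are an expert SQL developer specializing in SQLite.",
        "Database Schema:\nTable: wells\n  - id (INTEGER, PRIMARY KEY)\n\nQuestion: How many wells are there?",
        "SELECT COUNT(*) FROM wells;",
    )
)
record = validate_record(line)
print(record["conversations"][2]["content"])
```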

This format works with almost any chat model, including proprietary models that offer fine-tuning ([OpenAI docs](https://platform.openai.com/docs/guides/supervised-fine-tuning?formatting=jsonl)).
This means you can prepare your data once, and fine-tune and compare many different models.

Let's start by loading and observing our dataset.

In [None]:
import json

import duckdb
import pandas as pd
from IPython.display import display
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm

In [None]:
# If running this on Colab, make sure to upload the parquet file and update the path.
df = pd.read_parquet("../output/validated_dataset.parquet")

In [None]:
df.head()

Now, let's create some utilities to get the schema of the database as a string so we can pass it to the model.

In [None]:
# Utilities
duckdb.execute("""
INSTALL sqlite;
LOAD sqlite;
""")


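def query_sqlite(query: str, db_path: str) -> pd.DataFrame:
    """Run a query against a SQLite file through DuckDB and return a DataFrame."""
    # Use a context manager so each connection is closed after the query.
    with duckdb.connect(db_path) as conn:
        return conn.execute(query).fetch_df()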

In [None]:
path = df.db_path.iloc[0]
print(path)

In [None]:
tables = query_sqlite("SHOW TABLES", path)
tables

In [None]:
for table in tables.name:
    print(table)
    display(query_sqlite(f"DESCRIBE TABLE {table}", path).dropna(how="all", axis=1))

In [None]:
def table_metadata(db_path: str) -> dict[str, pd.DataFrame]:
    """
    Get metadata for all tables in the database.
    """
    tables = query_sqlite("SHOW TABLES", db_path)
    metadata = {}
    for table in tables.name:
        metadata[table] = query_sqlite(f"DESCRIBE TABLE {table}", db_path).dropna(
            how="all", axis=1
        )
    return metadata
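If DuckDB isn't available, roughly the same schema text can be produced with Python's built-in `sqlite3` module by reading `sqlite_master` and `PRAGMA table_info`. This is a swapped-in technique, not what this notebook uses, and its output may differ slightly from the DuckDB-based version (for example, in how NOT NULL is reported for primary-key columns). The `sqlite_schema_text` helper and the demo tables are hypothetical.

```python
import sqlite3


def sqlite_schema_text(conn: sqlite3.Connection) -> str:
    """Describe every user table as 'Table: name' followed by one line per column."""
    schema_lines = []
    tables = [
        row[0]
        for row in conn.execute(
            "SELECT name FROM sqlite_master "
            "WHERE type = 'table' AND name NOT LIKE 'sqlite_%' ORDER BY name"
        )
    ]
    for table in tables:
        schema_lines.append(f"Table: {table}")
        # Each PRAGMA table_info row is (cid, name, type, notnull, dflt_value, pk).
        for _cid, name, col_type, notnull, _default, pk in conn.execute(
            f'PRAGMA table_info("{table}")'
        ):
            col_info = f"  - {name} ({col_type}"
            if pk:
                col_info += ", PRIMARY KEY"
            if notnull:
                col_info += ", NOT NULL"
            col_info += ")"
            schema_lines.append(col_info)
        schema_lines.append("")  # blank line between tables
    return "\n".join(schema_lines).strip()


# Hypothetical two-table database, kept in memory for the demo.
conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE customers (id INTEGER PRIMARY KEY, name TEXT NOT NULL);
    CREATE TABLE orders (
        id INTEGER PRIMARY KEY,
        customer_id INTEGER REFERENCES customers(id),
        total REAL
    );
    """
)
schema = sqlite_schema_text(conn)
print(schema)
```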

In [None]:
def format_schema_for_chat(metadata: dict[str, pd.DataFrame]) -> str:
    """
    Format table metadata into a readable string for chat training.
    """
    schema_lines = []

    for table_name, table_df in metadata.items():
        schema_lines.append(f"Table: {table_name}")

        for _, row in table_df.iterrows():
            col_info = f"  - {row['column_name']} ({row['column_type']}"
            # dropna(how="all", axis=1) in table_metadata can drop the "key" column
            # entirely when no column in a table has a key, so read it with .get().
            key = row.get("key")
            if key == "PRI":
                col_info += ", PRIMARY KEY"
            elif pd.notna(key) and key != "None":
                col_info += f", {key}"
            if row["null"] == "NO":
                col_info += ", NOT NULL"
            col_info += ")"
            schema_lines.append(col_info)

        schema_lines.append("")  # Add blank line between tables

    return "\n".join(schema_lines).strip()


def create_chatml_dataset(df: pd.DataFrame) -> list[dict]:
    """
    Create ChatML conversations format dataset directly from DataFrame.
    Includes system prompt with SQLite constraints.
    """
    # SQLite-focused system prompt based on constraints from prompts.py
    system_prompt = """You are an expert SQL developer specializing in SQLite. Generate accurate SQL queries that follow these requirements:

SQLITE REQUIREMENTS:
1. Use only standard SQLite-compatible SQL syntax
2. Avoid PostgreSQL-specific functions like INTERVAL - use date arithmetic instead
3. Use SQLite date functions: date(), datetime(), julianday() for date calculations
4. For "recent month" use: WHERE date_column >= date((SELECT MAX(date_column) FROM table_name), '-1 month')
5. For "past 30 days" use: WHERE date_column >= date((SELECT MAX(date_column) FROM table_name), '-30 days')
6. For rolling averages, use window functions or subqueries
7. Use MAX(appropriate_date_column) to find the most recent date
8. Use proper SQLite column types: INTEGER, TEXT, REAL, BLOB
9. Handle foreign key relationships correctly with proper JOIN syntax

QUERY STANDARDS:
- Write clear, efficient queries
- Use appropriate aggregation functions (COUNT, SUM, AVG, MAX, MIN)
- Include proper GROUP BY and ORDER BY clauses when needed
- Use table aliases for readability in complex queries"""

    chatml_data = []

    for _, row in tqdm(df.iterrows(), total=len(df), desc="Creating ChatML dataset"):
        # Skip invalid entries
        if not row["is_valid"]:
            continue

        # Get table metadata for this database
        metadata = table_metadata(row["db_path"])
        schema_text = format_schema_for_chat(metadata)

        # Create user message with schema and question
        user_content = f"Database Schema:\n{schema_text}\n\nQuestion: {row['question']}"

        # Create ChatML conversation with system prompt
        chatml_entry = {
            "conversations": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_content},
                {"role": "assistant", "content": row["query"]},
            ]
        }

        chatml_data.append(chatml_entry)

    return chatml_data
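Rules 4 and 5 in the system prompt anchor relative date windows to the latest date in the data rather than to today. A minimal check with the standard-library `sqlite3` module shows what those expressions evaluate to; the `sales` table and its dates are made up for illustration.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE sales (sale_date TEXT, amount REAL);
    INSERT INTO sales VALUES
        ('2024-01-15', 10.0),
        ('2024-02-20', 20.0),
        ('2024-03-10', 30.0),
        ('2024-03-25', 40.0);
    """
)

# Cutoffs relative to the most recent date in the table (rules 4 and 5).
month_cutoff, days_cutoff = conn.execute(
    """
    SELECT date((SELECT MAX(sale_date) FROM sales), '-1 month'),
           date((SELECT MAX(sale_date) FROM sales), '-30 days')
    """
).fetchone()

# ISO-8601 date strings compare correctly as text, so >= works on TEXT columns.
recent_month_total = conn.execute(
    """
    SELECT SUM(amount) FROM sales
    WHERE sale_date >= date((SELECT MAX(sale_date) FROM sales), '-1 month')
    """
).fetchone()[0]

print(month_cutoff, days_cutoff, recent_month_total)
```

Because the window is computed from `MAX(sale_date)`, the generated queries keep returning sensible results on historical databases whose data stops well before the current date.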

Now that we've created our utilities, let's split the validated examples into training and evaluation sets and build the ChatML datasets.

In [None]:
df_filtered = df[df.is_valid]
train_df, test_df = train_test_split(
    df_filtered, test_size=0.1, random_state=42, stratify=df_filtered.db_path
)
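Passing `stratify=df_filtered.db_path` keeps each database's share of examples roughly the same in both splits, so every schema shows up in evaluation. The following is a simplified, hypothetical re-implementation to show the idea; scikit-learn's `train_test_split` allocates test rows across classes differently, so exact counts can differ from this sketch.

```python
import random
from collections import defaultdict


def stratified_split(rows, key, test_size=0.1, seed=42):
    """Split rows so each group (e.g. each db_path) contributes ~test_size of its rows to test."""
    groups = defaultdict(list)
    for row in rows:
        groups[key(row)].append(row)

    rng = random.Random(seed)
    train, test = [], []
    for group_key in sorted(groups):
        members = groups[group_key][:]
        if len(members) < 2:
            # scikit-learn also refuses to stratify a class with a single member.
            raise ValueError(f"group {group_key!r} has fewer than 2 rows")
        rng.shuffle(members)
        # At least one test row per group, but always keep one for training.
        n_test = min(len(members) - 1, max(1, round(len(members) * test_size)))
        test.extend(members[:n_test])
        train.extend(members[n_test:])
    return train, test


# Hypothetical rows: 30 examples spread evenly over three databases.
rows = [{"id": i, "db_path": f"db_{i % 3}.sqlite"} for i in range(30)]
train_rows, test_rows = stratified_split(rows, key=lambda r: r["db_path"])
print(len(train_rows), len(test_rows))
```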

In [None]:
# Create the ChatML format dataset for Axolotl
print("Creating ChatML format dataset for Axolotl...")
train_chatml_dataset = create_chatml_dataset(train_df)
eval_chatml_dataset = create_chatml_dataset(test_df)

print(f"Created {len(train_chatml_dataset)} ChatML training examples")
print(f"Created {len(eval_chatml_dataset)} ChatML evaluation examples")

# Save to JSONL file for Axolotl
train_chatml_output_file = "chatml_training_data.jsonl"
eval_chatml_output_file = "chatml_evaluation_data.jsonl"

with open(train_chatml_output_file, "w") as f:
    for entry in train_chatml_dataset:
        f.write(json.dumps(entry) + "\n")

print(f"Saved ChatML training data to {train_chatml_output_file}")

with open(eval_chatml_output_file, "w") as f:
    for entry in eval_chatml_dataset:
        f.write(json.dumps(entry) + "\n")


print(f"Saved ChatML evaluation data to {eval_chatml_output_file}")

# Display a sample ChatML entry
print("\nSample ChatML training entry:")
print(json.dumps(train_chatml_dataset[0], indent=2))
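The split is done by row, so if the source data contains the same question twice for the same database, it can land in both files. A quick leakage check is to compare the user messages (schema plus question) across the two JSONL files. This is a hypothetical helper demonstrated on two tiny temporary files; on the real outputs you would call `overlapping_user_messages("chatml_training_data.jsonl", "chatml_evaluation_data.jsonl")`.

```python
import json
import os
import tempfile


def load_jsonl(path: str) -> list[dict]:
    """Read one JSON object per non-empty line."""
    with open(path) as f:
        return [json.loads(line) for line in f if line.strip()]


def overlapping_user_messages(train_path: str, eval_path: str) -> set[str]:
    """Return user messages (schema + question) that appear in both files."""

    def user_texts(records):
        return {
            m["content"]
            for r in records
            for m in r["conversations"]
            if m["role"] == "user"
        }

    return user_texts(load_jsonl(train_path)) & user_texts(load_jsonl(eval_path))


def make_demo_record(question: str, sql: str) -> dict:
    # Hypothetical record in the same three-message layout as the notebook.
    return {
        "conversations": [
            {"role": "system", "content": "You are an expert SQL developer."},
            {"role": "user", "content": f"Question: {question}"},
            {"role": "assistant", "content": sql},
        ]
    }


with tempfile.TemporaryDirectory() as tmp:
    train_path = os.path.join(tmp, "train.jsonl")
    eval_path = os.path.join(tmp, "eval.jsonl")
    with open(train_path, "w") as f:
        for r in [
            make_demo_record("How many wells?", "SELECT COUNT(*) FROM wells;"),
            make_demo_record("Max depth?", "SELECT MAX(depth) FROM wells;"),
        ]:
            f.write(json.dumps(r) + "\n")
    with open(eval_path, "w") as f:
        f.write(json.dumps(make_demo_record("Max depth?", "SELECT MAX(depth) FROM wells;")) + "\n")
    overlap = overlapping_user_messages(train_path, eval_path)

print(overlap)
```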

In [1]:
# Note: must be connected to at least an L4 GPU (Colab or other) to use flash-attn
!uv pip install -Uqqq packaging setuptools wheel ninja
!uv pip install -qqq --no-build-isolation "axolotl[flash-attn,deepspeed]"

In [2]:
import os

from google.colab import userdata

os.environ["HF_TOKEN"] = userdata.get("HF_TOKEN")

In [3]:
import os

# Define file URLs and local filenames
urls = [
    "https://raw.githubusercontent.com/mgfrantz/CalTech-CTME-AramCo-2025/refs/heads/main/notebooks/chatml_evaluation_data.jsonl",
    "https://raw.githubusercontent.com/mgfrantz/CalTech-CTME-AramCo-2025/refs/heads/main/notebooks/chatml_training_data.jsonl",
]
filenames = [
    "chatml_evaluation_data.jsonl",
    "chatml_training_data.jsonl",
]

# Download files if they don't exist
for url, filename in zip(urls, filenames):
    if not os.path.exists(filename):
        print(f"Downloading {filename}...")
        !wget {url} -O {filename}
    else:
        print(f"{filename} already exists.")

Downloading chatml_evaluation_data.jsonl...
--2025-06-25 21:39:08--  https://raw.githubusercontent.com/mgfrantz/CalTech-CTME-AramCo-2025/refs/heads/main/notebooks/chatml_evaluation_data.jsonl
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 34558 (34K) [text/plain]
Saving to: ‘chatml_evaluation_data.jsonl’


2025-06-25 21:39:09 (34.7 MB/s) - ‘chatml_evaluation_data.jsonl’ saved [34558/34558]

Downloading chatml_training_data.jsonl...
--2025-06-25 21:39:09--  https://raw.githubusercontent.com/mgfrantz/CalTech-CTME-AramCo-2025/refs/heads/main/notebooks/chatml_training_data.jsonl
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.

In [9]:
%%writefile axolotl.yaml
base_model: NousResearch/Llama-3.2-1B

load_in_8bit: true
load_in_4bit: false
strict: false
adapter: lora
# Added for bitsandbytes configuration
# bnb_4bit_use_double_quant: true


# Data config
chat_template: llama3
datasets:
  - path: chatml_training_data.jsonl
    type: chat_template
    field_messages: conversations
dataset_prepared_path: last_run_prepared

test_datasets:
  - path: chatml_evaluation_data.jsonl
    type: chat_template
    field_messages: conversations
    split: train

output_dir: ./outputs/lora-out

sequence_len: 2048
sample_packing: true
eval_sample_packing: false # with a larger eval dataset, we would do this, but we don't have a large enough one today.
pad_to_sequence_len: true

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_target_modules:
  - gate_proj
  - down_proj
  - up_proj
  - q_proj
  - v_proj
  - k_proj
  - o_proj

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3

warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  pad_token: <|finetune_right_pad_id|>
  eos_token: <|eot_id|>

# Added for saving best checkpoint and pushing to Hugging Face Hub
save_only_k_checkpoints: 1
save_total_limit: 1
load_best_model_at_end: true
metric_for_best_model: eval_loss # Or any other metric you want to track
greater_is_better: false # True if the metric should be maximized, False if minimized

# Push to Hugging Face Hub
# push_to_hub: false
# hub_model_id: your_huggingface_username/your_model_name # Replace with your desired repo ID
# hub_private_repo: false # Set to true if you want a private repo
# hub_always_push: false
# hub_strategy: every_save # "end" to push only at the end of training

Overwriting axolotl.yaml
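A few quantities implied by this config are worth checking by hand. The values below are copied from `axolotl.yaml`; the single-GPU count is an assumption for a Colab runtime, and the scaling factor assumes standard LoRA, which scales the low-rank update by `lora_alpha / lora_r` (rank-stabilized LoRA uses `alpha / sqrt(r)` instead).

```python
# Values copied from axolotl.yaml above.
micro_batch_size = 2
gradient_accumulation_steps = 4
sequence_len = 2048
lora_r = 32
lora_alpha = 16

num_gpus = 1  # assumption: a single Colab GPU

# Packed sequences contributing to each optimizer step.
effective_batch_size = micro_batch_size * gradient_accumulation_steps * num_gpus

# With sample_packing, several examples are packed into each sequence of up to
# sequence_len tokens, so this is an upper bound on tokens per optimizer step.
max_tokens_per_step = effective_batch_size * sequence_len

# Standard LoRA scales the low-rank update B @ A by alpha / r.
lora_scaling = lora_alpha / lora_r

print(effective_batch_size, max_tokens_per_step, lora_scaling)
```

With `lora_alpha: 16` and `lora_r: 32`, the adapter's update is multiplied by 0.5 before it is added to the frozen weights, which interacts with the chosen `learning_rate`.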


In [10]:
# Prepare the datasets
!axolotl preprocess axolotl.yaml

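`axolotl preprocess` tokenizes each conversation with the `llama3` chat template and, because `train_on_inputs: false`, excludes everything except the assistant turns from the loss. The sketch below approximates that at the character level, assuming the standard Llama 3 special-token layout. It is not Axolotl's implementation, which works on token IDs and may differ in details such as whether the end-of-turn token is trained on; `render_llama3` is a hypothetical helper.

```python
def render_llama3(conversation: list[dict]) -> tuple[str, list[tuple[int, int]]]:
    """Render messages in the Llama 3 chat layout.

    Returns the full text and the (start, end) character spans of assistant
    replies, approximating what train_on_inputs: false leaves in the loss.
    """
    text = "<|begin_of_text|>"
    assistant_spans = []
    for message in conversation:
        text += f"<|start_header_id|>{message['role']}<|end_header_id|>\n\n"
        start = len(text)
        text += message["content"] + "<|eot_id|>"
        if message["role"] == "assistant":
            # Include <|eot_id|> so the model learns to stop after the query.
            assistant_spans.append((start, len(text)))
    return text, assistant_spans


conversation = [
    {"role": "system", "content": "You are an expert SQL developer specializing in SQLite."},
    {"role": "user", "content": "Question: How many wells are there?"},
    {"role": "assistant", "content": "SELECT COUNT(*) FROM wells;"},
]
text, spans = render_llama3(conversation)
print(text)
for start, end in spans:
    print("trained on:", repr(text[start:end]))
```

This is also why the config sets `eos_token: <|eot_id|>`: the end-of-turn marker is what tells the fine-tuned model to stop after emitting the SQL.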

In [11]:
!axolotl train axolotl.yaml


In [None]:
# [optional] merge weights
# !python -m axolotl.cli.merge_lora axolotl.yaml

In [12]:
# [optional] upload to hf hub
import os

username = "mgfrantz"
repo_name = "NousResearch-Llama-3.2-1B-ctme-sql-demo"
if not username or not repo_name:
    username = input("Username: ")
    repo_name = input("Repo name: ")
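os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"  # requires the hf_transfer package
# Remove intermediate checkpoints so only the final adapter is uploaded.
!rm -r outputs/lora-out/checkpoint-*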
!huggingface-cli upload {username}/{repo_name} ./outputs/lora-out/ .

Start hashing 8 files.
Finished hashing 8 files.
  0% 0/2 [00:00<?, ?it/s]
adapter_model.safetensors:   0% 0.00/90.2M [00:00<?, ?B/s]
adapter_model.safetensors:  18% 16.0M/90.2M [00:00<00:03, 20.5MB/s]
adapter_model.safetensors:  53% 48.0M/90.2M [00:00<00:00, 63.0MB/s]
adapter_model.safetensors:  71% 64.0M/90.2M [00:01<00:00, 55.8MB/s]
adapter_model.safetensors:  89% 80.0M/90.2M [00:01<00:00, 56.5MB/s]
adapter_model.safetensors: 96.0MB [00:02, 39.5MB/s]
 50% 1/2 [00:02<00:02,  2.81s/it]
tokenizer.json:   0% 0.00/17.2M [00:00<?, ?B/s]
tokenizer.json:  93% 16.0M/17.2M [00:00<00:00, 27.8MB/s]
tokenizer.json: 32.0MB [00:02, 13.8MB/s]
100% 2/2 [00:05<00:00,  2.73s/it]
https://huggingface.co/mgfrantz/NousResearch-Llama-3.2-1B-ctme-sql-demo/tree/main/.
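As a sanity check on the upload, the size of `adapter_model.safetensors` can be estimated from the LoRA rank and the model's layer shapes. The shapes below are assumptions taken from Llama-3.2-1B's published configuration, and the estimate also assumes the adapter is stored in fp32 and that `lm_head` is not among the targeted layers. Under those assumptions the count lands close to the ~90.2M size shown in the upload log.

```python
# Llama-3.2-1B shapes (assumed from the model's published config).
hidden_size = 2048
intermediate_size = 8192
num_layers = 16
num_heads = 32
num_kv_heads = 8
head_dim = hidden_size // num_heads  # 64
kv_dim = num_kv_heads * head_dim     # 512 (grouped-query attention)

lora_r = 32  # from axolotl.yaml

# (in_features, out_features) of each targeted projection in one decoder layer.
projections = {
    "q_proj": (hidden_size, hidden_size),
    "k_proj": (hidden_size, kv_dim),
    "v_proj": (hidden_size, kv_dim),
    "o_proj": (hidden_size, hidden_size),
    "gate_proj": (hidden_size, intermediate_size),
    "up_proj": (hidden_size, intermediate_size),
    "down_proj": (intermediate_size, hidden_size),
}

# LoRA adds A (r x in) and B (out x r) to each targeted layer: r * (in + out) params.
per_layer = sum(lora_r * (d_in + d_out) for d_in, d_out in projections.values())
total_params = per_layer * num_layers
fp32_megabytes = total_params * 4 / 1e6  # 4 bytes per fp32 parameter

print(f"{total_params:,} LoRA parameters, ~{fp32_megabytes:.1f} MB in fp32")
```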
