<a href="https://colab.research.google.com/github/mgfrantz/CalTech-CTME-AramCo-2025/blob/main/notebooks/02_fine_tune_lora.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Part 1: Dataset Preparation

## What We're Doing

Previously, we created a dataset of question/SQL query pairs and validated them against actual SQLite databases. Now we'll **fine-tune a language model** to accept a SQL schema and a natural language question, then output a correct SQLite query.

## Why This Approach Works

Fine-tuning allows us to:
- Teach the model our specific database schemas and patterns
- Improve accuracy on domain-specific SQL queries
- Reduce hallucinations by training on validated examples

## Dataset Format Strategy

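The first step is preparing our dataset in the **correct format**. We'll use a ChatML-style conversation format (a list of `role`/`content` messages) because:

1. **Consistency**: The data mirrors how the model will be prompted at inference time. Chat-tuned models already expect this structure; for a base model like ours, the chat template applied during training (`chat_template: llama3` in the Axolotl config) teaches it.
2. **Portability**: The same message structure is accepted by multiple platforms (Axolotl, OpenAI fine-tuning, etc.), although the top-level key can differ (OpenAI's chat fine-tuning files use `messages`).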
3. **Structure**: Clear separation of system instructions, user input, and expected output

Our dataset structure:
```json
{
    "conversations": [
        {"role": "system", "content": "Expert SQL developer instructions..."},
        {"role": "user", "content": "Schema: [tables] Question: [user query]"},
        {"role": "assistant", "content": "SELECT * FROM table WHERE..."}
    ]
}
```

This 3-message pattern teaches the model:
- **System**: How to behave (SQL expert with SQLite constraints)
- **User**: What information it receives (schema + question)  
- **Assistant**: What we want it to output (correct SQL query)
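Before writing many of these records, it can help to check that each one has exactly this shape. The sketch below is a minimal validator, assuming each JSONL line is one object with a `conversations` list; `validate_conversation` is a hypothetical helper, not part of Axolotl.

```python
EXPECTED_ROLES = ["system", "user", "assistant"]


def validate_conversation(entry: dict) -> list[str]:
    """Return a list of problems found in one ChatML-style record (empty if valid).

    Hypothetical helper: checks the 3-message system/user/assistant pattern.
    """
    messages = entry.get("conversations")
    if not isinstance(messages, list):
        return ["missing or non-list 'conversations' key"]

    problems = []
    roles = [message.get("role") for message in messages]
    if roles != EXPECTED_ROLES:
        problems.append(f"unexpected role order: {roles}")
    for i, message in enumerate(messages):
        content = message.get("content")
        if not isinstance(content, str) or not content.strip():
            problems.append(f"message {i} has empty content")
    return problems


# Usage: one well-formed record and one missing the assistant turn
good = {
    "conversations": [
        {"role": "system", "content": "You are an expert SQL developer..."},
        {"role": "user", "content": "Database Schema:\n...\n\nQuestion: How many orders?"},
        {"role": "assistant", "content": "SELECT COUNT(*) FROM orders;"},
    ]
}
bad = {"conversations": good["conversations"][:2]}
print(validate_conversation(good))  # []
print(validate_conversation(bad))   # ["unexpected role order: ['system', 'user']"]
```

To check a saved file, call `validate_conversation(json.loads(line))` on each line of the JSONL output.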

In [None]:
import json

import duckdb
import pandas as pd
from IPython.display import display
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm

In [None]:
# If running on Colab, upload the parquet file first and update the path below.
df = pd.read_parquet("../output/validated_dataset.parquet")

In [None]:
df.head()

## Step 1: Schema Extraction

To train our model effectively, we need to provide it with database schema information in a readable format.

**Why schemas matter**: The model needs to understand:
- What tables exist in the database
- What columns are in each table  
- Data types and constraints (PRIMARY KEY, NOT NULL, etc.)
- Relationships between tables (which columns can be used to JOIN them)

Let's create utilities to extract and format this schema information from our SQLite databases.
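The notebook's helpers use DuckDB's SQLite extension. For comparison, here is a sketch using only Python's built-in `sqlite3` module, which can also surface foreign keys through `PRAGMA foreign_key_list`. The function name `describe_sqlite_schema` and its output format are illustrative assumptions; they are not what the notebook's own formatter produces.

```python
import sqlite3


def describe_sqlite_schema(conn: sqlite3.Connection) -> str:
    """Render tables, columns, constraints, and foreign keys as readable text."""
    lines = []
    tables = conn.execute(
        "SELECT name FROM sqlite_master WHERE type = 'table' "
        "AND name NOT LIKE 'sqlite_%' ORDER BY name"
    ).fetchall()
    for (table,) in tables:
        lines.append(f"Table: {table}")
        # PRAGMA table_info rows: (cid, name, type, notnull, dflt_value, pk)
        for _, name, col_type, notnull, _, pk in conn.execute(
            f'PRAGMA table_info("{table}")'
        ):
            parts = [col_type or "ANY"]
            if pk:
                parts.append("PRIMARY KEY")
            if notnull:
                parts.append("NOT NULL")
            lines.append(f"  - {name} ({', '.join(parts)})")
        # PRAGMA foreign_key_list rows: (id, seq, table, from, to, on_update, on_delete, match)
        for row in conn.execute(f'PRAGMA foreign_key_list("{table}")'):
            ref_table, from_col, to_col = row[2], row[3], row[4]
            lines.append(f"  - FOREIGN KEY {from_col} -> {ref_table}.{to_col}")
        lines.append("")  # blank line between tables
    return "\n".join(lines).strip()


# Usage with a small in-memory database
conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE customers (id INTEGER PRIMARY KEY, name TEXT NOT NULL);
    CREATE TABLE orders (
        id INTEGER PRIMARY KEY,
        customer_id INTEGER NOT NULL REFERENCES customers(id),
        total REAL
    );
    """
)
print(describe_sqlite_schema(conn))
```

Listing foreign keys explicitly gives the model direct evidence of which JOINs are valid, instead of leaving it to infer relationships from column names.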

In [None]:
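# Utilities: install and load DuckDB's SQLite extension so DuckDB can read SQLite files
duckdb.execute("""
INSTALL sqlite;
LOAD sqlite;
""")


def query_sqlite(query: str, db_path: str) -> pd.DataFrame:
    """Run a query against a SQLite file through DuckDB and return a DataFrame."""
    # The context manager closes each connection after the query runs
    with duckdb.connect(db_path) as conn:
        return conn.execute(query).fetch_df()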

In [None]:
path = df.db_path.iloc[0]
print(path)

In [None]:
tables = query_sqlite("SHOW TABLES", path)
tables

In [None]:
for table in tables.name:
    print(table)
    display(query_sqlite(f"DESCRIBE TABLE {table}", path).dropna(how="all", axis=1))

In [None]:
def table_metadata(db_path: str) -> dict[str, pd.DataFrame]:
    """
    Get metadata for all tables in the database.
    """
    tables = query_sqlite("SHOW TABLES", db_path)
    metadata = {}
    for table in tables.name:
        metadata[table] = query_sqlite(f"DESCRIBE TABLE {table}", db_path).dropna(
            how="all", axis=1
        )
    return metadata

In [None]:
def format_schema_for_chat(metadata: dict[str, pd.DataFrame]) -> str:
    """
    Format table metadata into a readable string for chat training.
    """
    schema_lines = []

    for table_name, columns in metadata.items():
        schema_lines.append(f"Table: {table_name}")

        for _, row in columns.iterrows():
            col_info = f"  - {row['column_name']} ({row['column_type']}"
            # Use .get: dropna(axis=1) drops the "key" column for tables with no keys
            key = row.get("key")
            if key == "PRI":
                col_info += ", PRIMARY KEY"
            elif pd.notna(key) and key != "None":
                col_info += f", {key}"
            if row.get("null") == "NO":
                col_info += ", NOT NULL"
            col_info += ")"
            schema_lines.append(col_info)

        schema_lines.append("")  # Add blank line between tables

    return "\n".join(schema_lines).strip()


def create_chatml_dataset(df: pd.DataFrame) -> list[dict]:
    """
    Create ChatML conversations format dataset directly from DataFrame.
    Includes system prompt with SQLite constraints.
    """
    # SQLite-focused system prompt based on constraints from prompts.py
    system_prompt = """You are an expert SQL developer specializing in SQLite. Generate accurate SQL queries that follow these requirements:

SQLITE REQUIREMENTS:
1. Use only standard SQLite-compatible SQL syntax
2. Avoid PostgreSQL-specific functions like INTERVAL - use date arithmetic instead
3. Use SQLite date functions: date(), datetime(), julianday() for date calculations
4. For "recent month" use: WHERE date_column >= date((SELECT MAX(date_column) FROM table_name), '-1 month')
5. For "past 30 days" use: WHERE date_column >= date((SELECT MAX(date_column) FROM table_name), '-30 days')
6. For rolling averages, use window functions or subqueries
7. Use MAX(appropriate_date_column) to find the most recent date
8. Use proper SQLite column types: INTEGER, TEXT, REAL, BLOB
9. Handle foreign key relationships correctly with proper JOIN syntax

QUERY STANDARDS:
- Write clear, efficient queries
- Use appropriate aggregation functions (COUNT, SUM, AVG, MAX, MIN)
- Include proper GROUP BY and ORDER BY clauses when needed
- Use table aliases for readability in complex queries"""

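    chatml_data = []
    schema_cache: dict[str, str] = {}  # format each database's schema only once

    for _, row in tqdm(df.iterrows(), total=len(df), desc="Creating ChatML dataset"):
        # Skip invalid entries
        if not row["is_valid"]:
            continue

        # Get (cached) table metadata for this database
        db_path = row["db_path"]
        if db_path not in schema_cache:
            schema_cache[db_path] = format_schema_for_chat(table_metadata(db_path))
        schema_text = schema_cache[db_path]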

        # Create user message with schema and question
        user_content = f"Database Schema:\n{schema_text}\n\nQuestion: {row['question']}"

        # Create ChatML conversation with system prompt
        chatml_entry = {
            "conversations": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_content},
                {"role": "assistant", "content": row["query"]},
            ]
        }

        chatml_data.append(chatml_entry)

    return chatml_data

## Step 2: Creating the ChatML Dataset

Now we'll convert our validated question/query pairs into the ChatML format for fine-tuning.

**Key components of our dataset creation**:

1. **System Prompt**: Detailed instructions about SQLite syntax and constraints
2. **User Message**: Database schema + natural language question
3. **Assistant Response**: The validated SQL query

**Important notes**:
- We only include validated examples (`is_valid=True`)
- We hold out 10% of the examples for evaluation, stratifying by database (`db_path`) so each database appears in roughly the same proportion in both sets
- Each conversation teaches the model one complete SQL generation task
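Stratifying by `db_path` means the hold-out set is drawn per database rather than from the pooled rows. The sketch below reproduces the idea with the standard library only, so the per-database counts are easy to inspect. `stratified_split` is a hypothetical helper; it may allocate rows differently from scikit-learn's `train_test_split`, which the notebook actually uses.

```python
import random
from collections import defaultdict


def stratified_split(rows, key, test_size=0.1, seed=42):
    """Split rows into (train, test), sampling test_size of each group.

    Every group with at least 2 rows contributes at least one test example.
    """
    groups = defaultdict(list)
    for row in rows:
        groups[key(row)].append(row)

    rng = random.Random(seed)
    train, test = [], []
    for group_rows in groups.values():
        shuffled = group_rows[:]
        rng.shuffle(shuffled)
        n_test = max(1, round(len(shuffled) * test_size)) if len(shuffled) > 1 else 0
        test.extend(shuffled[:n_test])
        train.extend(shuffled[n_test:])
    return train, test


# Usage: 30 questions about db_a and 10 about db_b
rows = [{"db_path": "db_a.sqlite", "id": i} for i in range(30)]
rows += [{"db_path": "db_b.sqlite", "id": i} for i in range(10)]
train, test = stratified_split(rows, key=lambda r: r["db_path"])
print(len(train), len(test))  # 36 4
```

Without stratification, a small database could by chance end up entirely in the training set, and its questions would never be evaluated.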

In [None]:
df_filtered = df[df.is_valid]
train_df, test_df = train_test_split(
    df_filtered, test_size=0.1, random_state=42, stratify=df_filtered.db_path
)

In [None]:
# Create the ChatML format dataset for Axolotl
print("Creating ChatML format dataset for Axolotl...")
train_chatml_dataset = create_chatml_dataset(train_df)
eval_chatml_dataset = create_chatml_dataset(test_df)

print(f"Created {len(train_chatml_dataset)} ChatML training examples")
print(f"Created {len(eval_chatml_dataset)} ChatML evaluation examples")

# Save to JSONL file for Axolotl
train_chatml_output_file = "chatml_training_data.jsonl"
eval_chatml_output_file = "chatml_evaluation_data.jsonl"

with open(train_chatml_output_file, "w") as f:
    for entry in train_chatml_dataset:
        f.write(json.dumps(entry) + "\n")

print(f"Saved ChatML training data to {train_chatml_output_file}")

with open(eval_chatml_output_file, "w") as f:
    for entry in eval_chatml_dataset:
        f.write(json.dumps(entry) + "\n")


print(f"Saved ChatML evaluation data to {eval_chatml_output_file}")

# Display a sample ChatML entry
print("\nSample ChatML training entry:")
print(json.dumps(train_chatml_dataset[0], indent=2))

# Part 2: Fine-tuning with LoRA

## What is LoRA?

**LoRA (Low-Rank Adaptation)** is an efficient fine-tuning technique that:
- Trains only a small number of additional parameters (roughly 1-2% of the original model's parameter count for the configuration used here)
- Keeps the original model weights frozen (unchanged)
- Reduces memory requirements and training time
- Produces a small "adapter" that can be easily shared and applied
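Concretely, LoRA keeps a frozen weight matrix `W` (d × k) and learns two small matrices, `B` (d × r) and `A` (r × k), so the adapted layer computes `h = W x + (alpha / r) * B A x`. The sketch below uses plain Python lists so it runs anywhere; the toy sizes and initial values are assumptions, and real implementations such as Hugging Face PEFT do this with GPU tensors. It shows that `W` is never modified, that with `B` initialized to zero the adapted layer starts out identical to the base layer, and that the trainable parameter count is `r * (d + k)` instead of `d * k`.

```python
import random


def matvec(M, x):
    """Multiply matrix M (list of rows) by vector x."""
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]


def lora_forward(W, A, B, x, alpha, r):
    """h = W x + (alpha / r) * B (A x); W is frozen, A and B are trainable."""
    base = matvec(W, x)
    update = matvec(B, matvec(A, x))
    return [b + (alpha / r) * u for b, u in zip(base, update)]


d, k, r, alpha = 6, 4, 2, 4  # toy sizes chosen for illustration
rng = random.Random(0)
W = [[rng.uniform(-1, 1) for _ in range(k)] for _ in range(d)]
A = [[rng.gauss(0, 0.02) for _ in range(k)] for _ in range(r)]  # small random init
B = [[0.0] * r for _ in range(d)]  # zero init, so the update starts at 0
x = [1.0, -2.0, 0.5, 3.0]

# At initialization the adapted layer matches the frozen base layer exactly
assert lora_forward(W, A, B, x, alpha, r) == matvec(W, x)

print(d * k, r * (d + k))                # toy: 24 20
print(2048 * 2048, 32 * (2048 + 2048))   # 2048 x 2048 projection, r = 32: 4194304 131072
```

At toy sizes the savings are small, but for a 2048 × 2048 projection with rank 32 the adapter has about 3% of the parameters of the full matrix.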

## Why Use Google Colab?

Fine-tuning requires significant computational resources:
- **GPU Memory**: Needs at least 8 GB of GPU memory (e.g., T4, V100, L4, A100)
- **Flash Attention**: Speeds up training and reduces memory usage, but requires an Ampere-or-newer GPU (of the GPUs above, the L4 and A100, not the T4 or V100)
- **Cost Effective**: Colab Pro gives access to powerful GPUs without infrastructure setup

## Training Strategy

We'll use **Axolotl**, a popular fine-tuning framework that:
- Handles data preprocessing automatically
- Supports various model architectures and adapters
- Provides sensible defaults for LoRA training
- Integrates with HuggingFace for easy model sharing

## Step 1: Install Dependencies

We need to install Axolotl with Flash Attention support. This installation:
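- **Requires a recent NVIDIA GPU**: Flash Attention only runs on CUDA GPUs of the Ampere generation or newer
- **Takes time**: Compiling Flash Attention can take 5-10 minutes
- **Memory intensive**: Make sure your runtime has enough memory (compilation uses system RAM; training uses GPU memory)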

In [1]:
# Note: flash-attn requires an Ampere-or-newer GPU (e.g., L4 or A100, on Colab or elsewhere)
!uv pip install -Uqqq packaging setuptools wheel ninja
!uv pip install -qqq --no-build-isolation "axolotl[flash-attn,deepspeed]"

## Step 2: Setup Environment

Configure your Hugging Face token (stored as a Colab secret named `HF_TOKEN`) for downloading models and, later, uploading to the Hub.

In [2]:
import os

from google.colab import userdata

os.environ["HF_TOKEN"] = userdata.get("HF_TOKEN")

## Step 3: Download Training Data

Download our prepared ChatML datasets from the repository.

In [3]:
import os

# Define file URLs and local filenames
urls = [
    "https://raw.githubusercontent.com/mgfrantz/CalTech-CTME-AramCo-2025/refs/heads/main/notebooks/chatml_evaluation_data.jsonl",
    "https://raw.githubusercontent.com/mgfrantz/CalTech-CTME-AramCo-2025/refs/heads/main/notebooks/chatml_training_data.jsonl",
]
filenames = [
    "chatml_evaluation_data.jsonl",
    "chatml_training_data.jsonl",
]

# Download files if they don't exist
for url, filename in zip(urls, filenames):
    if not os.path.exists(filename):
        print(f"Downloading {filename}...")
        !wget {url} -O {filename}
    else:
        print(f"{filename} already exists.")

Downloading chatml_evaluation_data.jsonl...
--2025-06-26 03:38:23--  https://raw.githubusercontent.com/mgfrantz/CalTech-CTME-AramCo-2025/refs/heads/main/notebooks/chatml_evaluation_data.jsonl
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 34558 (34K) [text/plain]
Saving to: ‘chatml_evaluation_data.jsonl’


2025-06-26 03:38:23 (27.8 MB/s) - ‘chatml_evaluation_data.jsonl’ saved [34558/34558]

Downloading chatml_training_data.jsonl...
--2025-06-26 03:38:23--  https://raw.githubusercontent.com/mgfrantz/CalTech-CTME-AramCo-2025/refs/heads/main/notebooks/chatml_training_data.jsonl
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.

## Step 4: Configure Training

This YAML configuration defines our training setup:

**Model Configuration**:
- `base_model`: Llama-3.2-1B, a small base (not instruction-tuned) model that is fast to fine-tune
- `load_in_8bit`: Loads the frozen base weights in 8-bit precision, roughly halving their memory footprint compared with 16-bit

**LoRA Configuration**:
- `lora_r`: 32 (rank of the adapter matrices; higher = more trainable parameters and capacity, at the cost of memory)
- `lora_alpha`: 16 (scaling factor; the adapter's update is multiplied by `lora_alpha / lora_r` = 0.5)
- `lora_target_modules`: Which layers to adapt (all attention and MLP projection layers)

**Training Configuration**:
- `num_epochs`: 10 (how many times to see the full dataset)
- `learning_rate`: 0.0002 (a common starting point for LoRA, used here with 10 warmup steps and cosine decay)
- `micro_batch_size`: 2 (small batches to fit in GPU memory)
- `gradient_accumulation_steps`: 4 (effective batch size per GPU = 2 × 4 = 8)

**Monitoring**:
- `evals_per_epoch`: 4 (check validation loss 4 times per epoch)
- `saves_per_epoch`: 1 (save checkpoints once per epoch)
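Several of these settings combine in ways worth checking by hand. The sketch below computes the LoRA scaling factor, the effective batch size, and the number of trainable LoRA parameters for this config. The Llama-3.2-1B dimensions (hidden size 2048, MLP size 8192, 16 layers, 8 key/value heads of size 64) and the ~1.24B total parameter count are assumptions taken from the model's published config; recheck them if you change `base_model`.

```python
# Values copied from axolotl.yaml
lora_r = 32
lora_alpha = 16
micro_batch_size = 2
gradient_accumulation_steps = 4

# Assumed Llama-3.2-1B dimensions (check config.json if you change base_model)
hidden = 2048         # hidden_size
intermediate = 8192   # intermediate_size
num_layers = 16       # num_hidden_layers
kv_dim = 8 * 64       # num_key_value_heads * head_dim
base_params = 1.24e9  # approximate total parameter count (assumption)

# (in_features, out_features) of each module listed in lora_target_modules
target_shapes = {
    "q_proj": (hidden, hidden),
    "k_proj": (hidden, kv_dim),
    "v_proj": (hidden, kv_dim),
    "o_proj": (hidden, hidden),
    "gate_proj": (hidden, intermediate),
    "up_proj": (hidden, intermediate),
    "down_proj": (intermediate, hidden),
}

# Each adapted linear layer gets A (r x in) and B (out x r): r * (in + out) parameters
per_layer = sum(lora_r * (n_in + n_out) for n_in, n_out in target_shapes.values())
lora_params = per_layer * num_layers

scaling = lora_alpha / lora_r
effective_batch = micro_batch_size * gradient_accumulation_steps

print(f"LoRA scaling (alpha / r): {scaling}")                  # 0.5
print(f"Effective batch size per GPU: {effective_batch}")       # 8
print(f"Trainable LoRA parameters: {lora_params:,}")            # 22,544,384
print(f"Share of base model: {lora_params / base_params:.1%}")  # 1.8%
print(f"Adapter size at 32-bit precision: {lora_params * 4 / 1e6:.1f} MB")  # 90.2 MB
```

The 90.2 MB estimate is consistent with the size of `adapter_model.safetensors` in the upload log at the end of this notebook, assuming the adapter is stored in 32-bit floats. With `sample_packing: true`, several short examples can share one 2048-token sequence, so an optimizer step may cover more than 8 examples.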

In [4]:
%%writefile axolotl.yaml
# Configure the base model and output directory
base_model: NousResearch/Llama-3.2-1B
output_dir: ./outputs/lora-out

# LoRA configuration
load_in_8bit: true
load_in_4bit: false
strict: false
adapter: lora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_target_modules:
  - gate_proj
  - down_proj
  - up_proj
  - q_proj
  - v_proj
  - k_proj
  - o_proj


# Data configuration
chat_template: llama3
datasets:
  - path: chatml_training_data.jsonl
    type: chat_template
    field_messages: conversations
dataset_prepared_path: last_run_prepared

test_datasets:
  - path: chatml_evaluation_data.jsonl
    type: chat_template
    field_messages: conversations
    split: train

sequence_len: 2048
sample_packing: true
eval_sample_packing: false # with a larger eval dataset we would enable packing here; ours is too small
pad_to_sequence_len: true

# [optional] weights and biases configuration
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

# Training hyperparameters
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 10
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

# Masking: only compute the loss on the assistant's responses
train_on_inputs: false
group_by_length: false

# Precision
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3

warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  pad_token: <|finetune_right_pad_id|>
  eos_token: <|eot_id|>

# Added for saving best checkpoint and pushing to Hugging Face Hub
save_only_k_checkpoints: 1
save_total_limit: 1
load_best_model_at_end: true
metric_for_best_model: eval_loss # Or any other metric you want to track
greater_is_better: false # True if the metric should be maximized, False if minimized

# Push to Hugging Face Hub
# push_to_hub: false
# hub_model_id: your_huggingface_username/your_model_name # Replace with your desired repo ID
# hub_private_repo: false # Set to true if you want a private repo
# hub_always_push: false
# hub_strategy: every_save # "end" to push only at the end of training

Writing axolotl.yaml


## Step 5: Preprocess Data

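Axolotl applies the chat template, tokenizes our JSONL files, and caches the result in `dataset_prepared_path` (`last_run_prepared`), so training can start directly from the prepared data.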
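Part of preprocessing is building the labels used in the loss. With `train_on_inputs: false`, tokens from the system and user turns are masked so the model is only trained to produce the assistant's SQL. The sketch below shows the idea on pre-tokenized segments. The `-100` ignore index follows the PyTorch/Hugging Face convention; the segment structure, the toy token ids, and the `build_labels` name are illustrative assumptions rather than Axolotl's internal code.

```python
IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss skips by default


def build_labels(segments: list[tuple[str, list[int]]], train_on_inputs: bool = False):
    """Concatenate (role, token_ids) segments into input_ids and labels.

    When train_on_inputs is False, only assistant tokens keep their ids as labels;
    all other tokens get IGNORE_INDEX so they do not contribute to the loss.
    """
    input_ids, labels = [], []
    for role, token_ids in segments:
        input_ids.extend(token_ids)
        if train_on_inputs or role == "assistant":
            labels.extend(token_ids)
        else:
            labels.extend([IGNORE_INDEX] * len(token_ids))
    return input_ids, labels


# Toy token ids standing in for a tokenized system/user/assistant conversation
segments = [
    ("system", [101, 102, 103]),
    ("user", [201, 202]),
    ("assistant", [301, 302, 303, 304]),
]
input_ids, labels = build_labels(segments)
print(input_ids)  # [101, 102, 103, 201, 202, 301, 302, 303, 304]
print(labels)     # [-100, -100, -100, -100, -100, 301, 302, 303, 304]
```

The model still sees the schema and question as context; it is simply not penalized for failing to predict them. Hugging Face causal language models shift the labels by one position internally when computing the next-token loss.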

In [5]:
# Prepare the datasets
!axolotl preprocess axolotl.yaml


## Step 6: Start Training

This will begin the fine-tuning process. Expected behavior:

**Training Progress**:
- 10 epochs of training (may take 30-60 minutes on L4 GPU)
- Loss should decrease over time
- Validation evaluations every ~25% of each epoch

**What to Watch For**:
- **Training loss decreasing**: Model is learning
- **Validation loss stable/decreasing**: Not overfitting
- **GPU memory usage**: Should be stable throughout training

**Troubleshooting**:
- If OOM (Out of Memory): Reduce `micro_batch_size` to 1
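- If loss is not decreasing: Check the data format, then try a modestly higher learning rate
- If overfitting (validation loss rising while training loss keeps falling): Reduce `num_epochs`, set `early_stopping_patience`, or increase regularization (e.g., `lora_dropout` or `weight_decay`)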
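The "validation loss stable/decreasing" check can be made concrete. Given the eval losses logged during training, the sketch below finds the best evaluation (the checkpoint that `load_best_model_at_end` with `metric_for_best_model: eval_loss` aims to keep) and flags when the loss has not improved for `patience` consecutive evaluations, the idea behind `early_stopping_patience`. The function and the loss values are hypothetical; they are not read from this run's logs.

```python
def summarize_eval_losses(losses: list[float], patience: int = 3):
    """Return (best_index, best_loss, should_stop) for a sequence of eval losses.

    should_stop is True when the last `patience` evaluations all failed to beat
    the best loss seen so far, a simple early-stopping rule.
    """
    best_index = min(range(len(losses)), key=losses.__getitem__)
    evals_since_best = len(losses) - 1 - best_index
    return best_index, losses[best_index], evals_since_best >= patience


# Hypothetical eval losses logged 4 times per epoch
losses = [1.90, 1.20, 0.85, 0.70, 0.66, 0.64, 0.65, 0.67, 0.71]
best_index, best_loss, should_stop = summarize_eval_losses(losses)
print(best_index, best_loss, should_stop)  # 5 0.64 True
```

Here validation loss bottoms out at the sixth evaluation and then climbs for three evaluations in a row, the typical signature of overfitting.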

In [8]:
!axolotl train axolotl.yaml


## Step 7: Optional - Merge LoRA Weights

**What this does**: Folds the LoRA adapter's weight updates into the base model, producing a single standalone model

**When to use**:
- ✅ **For inference/deployment**: Single file is easier to use
- ✅ **For sharing**: Recipients don't need both base model + adapter
- ❌ **For further training**: Keep adapter separate for additional fine-tuning
- ❌ **With multiple adapters**: If you have multiple adapters you want to serve from the same base model, it's better to keep them separate.

**Trade-offs**:
- **Pros**: Simpler deployment, no adapter loading required
- **Cons**: Larger file size, can't easily swap adapters
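Merging works because the LoRA update is linear: the adapter's contribution `(alpha / r) * B A` can be added into the base weight once, `W_merged = W + (alpha / r) * B A`, and the merged layer then gives the same outputs as base-plus-adapter with no extra computation at inference. The pure-Python sketch below checks this on toy matrices; the sizes and values are arbitrary assumptions, `merge_lora_weights` is a hypothetical helper, and real merges (such as the commented-out `axolotl.cli.merge_lora` command in the next cell) operate on full model tensors.

```python
def matmul(X, Y):
    """Multiply matrices given as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*Y)] for row in X]


def matvec(M, x):
    """Multiply matrix M (list of rows) by vector x."""
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]


def merge_lora_weights(W, A, B, alpha, r):
    """Return W + (alpha / r) * (B @ A) as a new matrix."""
    delta = matmul(B, A)
    scale = alpha / r
    return [[w + scale * dw for w, dw in zip(w_row, d_row)] for w_row, d_row in zip(W, delta)]


W = [[0.5, -1.0, 2.0], [1.5, 0.0, -0.5]]  # frozen base weight, 2 x 3
A = [[1.0, 2.0, 0.0]]                     # r x k with r = 1
B = [[0.5], [-1.0]]                       # d x r
alpha, r = 2, 1
x = [1.0, 1.0, 2.0]

# Base layer plus adapter, computed separately (what an unmerged adapter does)
unmerged = [h + (alpha / r) * u for h, u in zip(matvec(W, x), matvec(B, matvec(A, x)))]
# Single merged weight matrix
merged = matvec(merge_lora_weights(W, A, B, alpha, r), x)
print(unmerged, merged)  # [6.5, -5.5] [6.5, -5.5]
```

The catch is that once `B A` is folded into `W`, the base weights are no longer shared, which is why merging makes it harder to swap or stack adapters.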

In [None]:
# !python -m axolotl.cli.merge_lora axolotl.yaml

## Step 8: Upload to HuggingFace Hub

**Sharing your model**:
- Makes your fine-tuned model available to others (Hub repos are public unless created as private)
- Others can load it with the `transformers` and `peft` libraries, or serve it with inference engines such as vLLM
- Enables easy deployment and inference

**What gets uploaded**:
- LoRA adapter weights (`adapter_model.safetensors`, as shown in the upload log below)
- Configuration files (`adapter_config.json`)
- Training metadata and logs
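To query the uploaded adapter, the inference prompt should match the template used in training (`chat_template: llama3`) and end with an open assistant turn so the model writes the SQL. The sketch below builds that prompt string by hand from the same system and user messages, assuming the standard Llama 3 chat layout (`<|start_header_id|>`, `<|end_header_id|>`, `<|eot_id|>`). If the tokenizer saved with the adapter carries its chat template, `tokenizer.apply_chat_template(messages, add_generation_prompt=True)` from `transformers` does this for you; compare its output against this sketch before relying on either.

```python
def build_llama3_prompt(messages: list[dict]) -> str:
    """Render messages in the Llama 3 chat layout, ending with an open assistant turn.

    Assumes the standard Llama 3 format; verify against the tokenizer's own template.
    """
    prompt = "<|begin_of_text|>"
    for message in messages:
        prompt += (
            f"<|start_header_id|>{message['role']}<|end_header_id|>\n\n"
            f"{message['content'].strip()}<|eot_id|>"
        )
    # Generation prompt: the model continues from here and should finish with <|eot_id|>
    prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"
    return prompt


messages = [
    {"role": "system", "content": "You are an expert SQL developer specializing in SQLite."},
    {
        "role": "user",
        "content": "Database Schema:\nTable: orders\n  - id (INTEGER, PRIMARY KEY)\n\n"
        "Question: How many orders are there?",
    },
]
print(build_llama3_prompt(messages))
```

If you tokenize this string yourself, pass `add_special_tokens=False` so the tokenizer does not prepend a second `<|begin_of_text|>`, and use the same system prompt as in training, since the model learned to expect it.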


In [9]:
# [optional] upload to hf hub
import os

username = "mgfrantz"  # replace with your Hugging Face username, or set to "" to be prompted
repo_name = "NousResearch-Llama-3.2-1B-ctme-sql-demo"
if not username:
    username = input("Username: ")
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
!rm -r outputs/lora-out/checkpoint-* # we only want to push the final checkpoint since we loaded the best one back at the end
!huggingface-cli upload {username}/{repo_name} ./outputs/lora-out/ .

Start hashing 8 files.
Finished hashing 8 files.
adapter_model.safetensors: 96.0MB [00:03, 28.4MB/s]
100% 1/1 [00:03<00:00,  3.76s/it]
Removing 5 file(s) from commit that have not changed.
https://huggingface.co/mgfrantz/NousResearch-Llama-3.2-1B-ctme-sql-demo/tree/main/.
