In [11]:
import json

import duckdb
import pandas as pd
from IPython.display import display
from tqdm.auto import tqdm

In [12]:
# Load the validated dataset from the parquet file under ../output.
# Later cells read its `db_path`, `is_valid`, `question` and `query` columns.
df = pd.read_parquet("../output/validated_dataset.parquet")

In [None]:
# Preview the first rows of the loaded dataset.
df.head()

In [14]:
# Utilities
# Install and load DuckDB's `sqlite` extension before any database is queried
# (presumably what lets `query_sqlite` below open SQLite files — confirm).
duckdb.execute("""
INSTALL sqlite;
LOAD sqlite;               
""")

def query_sqlite(query:str, db_path:str) -> pd.DataFrame:
    """
    Run `query` against the database at `db_path` and return the result.

    Args:
        query: SQL text to execute.
        db_path: Path handed to `duckdb.connect`.

    Returns:
        The complete result set as a pandas DataFrame.
    """
    conn = duckdb.connect(db_path)
    try:
        return conn.execute(query).fetch_df()
    finally:
        # This helper is called repeatedly inside loops; always release the
        # handle (also on query errors) so open connections do not pile up.
        conn.close()

In [None]:
# Use the database referenced by the first dataset row for the exploration below.
path = df["db_path"].iloc[0]
print(path)

In [None]:
# List the tables of the sample database; the bare name on the last line displays it.
tables = query_sqlite(query="SHOW TABLES", db_path=path)
tables

In [None]:
# Show the column-level description of every table in the sample database.
for table in tables.name:
    print(table)
    # Double-quote the identifier (doubling embedded quotes) so table names with
    # spaces, punctuation or reserved words still form valid SQL.
    quoted_table = '"' + table.replace('"', '""') + '"'
    description = query_sqlite(f'DESCRIBE TABLE {quoted_table}', path)
    # Hide columns that are empty for every row of the description.
    display(description.dropna(how='all', axis=1))

In [18]:
def table_metadata(db_path:str) -> dict[str, pd.DataFrame]:
    """
    Get metadata for all tables in the database.

    Args:
        db_path: Path of the database to inspect (passed to `query_sqlite`).

    Returns:
        Mapping of table name to that table's `DESCRIBE TABLE` result. Columns
        that are empty for every row are removed, so optional columns may be
        absent from a given frame.
    """
    tables = query_sqlite('SHOW TABLES', db_path)
    metadata = {}
    for table in tables.name:
        # Double-quote the identifier (doubling embedded quotes) so table names
        # with spaces, punctuation or reserved words still form valid SQL.
        quoted_table = '"' + table.replace('"', '""') + '"'
        described = query_sqlite(f'DESCRIBE TABLE {quoted_table}', db_path)
        metadata[table] = described.dropna(how='all', axis=1)
    return metadata

In [21]:
def format_schema_for_chat(metadata: dict[str, pd.DataFrame]) -> str:
    """
    Format table metadata into a readable string for chat training.

    Args:
        metadata: Mapping of table name to a DataFrame with one row per column.
            `column_name` and `column_type` are required. `key` and `null` are
            optional: a frame that lacks them (for example because all-empty
            columns were dropped upstream) is rendered without those attributes.

    Returns:
        One "Table: <name>" header per table followed by one
        "  - <column> (<type>[, <key>][, NOT NULL])" line per column, with a
        blank line between tables and no leading/trailing whitespace.
        An empty mapping yields an empty string.
    """
    schema_lines = []

    for table_name, columns in metadata.items():
        schema_lines.append(f"Table: {table_name}")

        for column in columns.to_dict("records"):
            attributes = [f"{column['column_type']}"]

            # Check for missing values first: comparing pd.NA and then
            # truth-testing the result would raise.
            key = column.get("key")
            if pd.notna(key) and key != "None":
                attributes.append("PRIMARY KEY" if key == "PRI" else f"{key}")

            nullable = column.get("null")
            if pd.notna(nullable) and nullable == "NO":
                attributes.append("NOT NULL")

            schema_lines.append(f"  - {column['column_name']} ({', '.join(attributes)})")

        schema_lines.append("")  # Add blank line between tables

    return "\n".join(schema_lines).strip()

def create_chatml_dataset(df: pd.DataFrame) -> list[dict]:
    """
    Create ChatML conversations format dataset directly from DataFrame.
    Includes system prompt with SQLite constraints.

    Args:
        df: Frame with the columns `is_valid`, `db_path`, `question` and `query`.
            Rows whose `is_valid` is falsy are skipped.

    Returns:
        One dict per kept row of the form
        {"conversations": [system, user, assistant]}, where the user message
        holds the database schema plus the question and the assistant message
        holds the row's `query`.
    """
    # SQLite-focused system prompt based on constraints from prompts.py
    system_prompt = """You are an expert SQL developer specializing in SQLite. Generate accurate SQL queries that follow these requirements:

SQLITE REQUIREMENTS:
1. Use only standard SQLite-compatible SQL syntax
2. Avoid PostgreSQL-specific functions like INTERVAL - use date arithmetic instead
3. Use SQLite date functions: date(), datetime(), julianday() for date calculations
4. For "recent month" use: WHERE date_column >= date((SELECT MAX(date_column) FROM table_name), '-1 month')
5. For "past 30 days" use: WHERE date_column >= date((SELECT MAX(date_column) FROM table_name), '-30 days')
6. For rolling averages, use window functions or subqueries
7. Use MAX(appropriate_date_column) to find the most recent date
8. Use proper SQLite column types: INTEGER, TEXT, REAL, BLOB
9. Handle foreign key relationships correctly with proper JOIN syntax

QUERY STANDARDS:
- Write clear, efficient queries
- Use appropriate aggregation functions (COUNT, SUM, AVG, MAX, MIN)
- Include proper GROUP BY and ORDER BY clauses when needed
- Use table aliases for readability in complex queries"""

    chatml_data = []
    # Schema text per database path: rows pointing at the same database reuse
    # one lookup instead of re-reading the database for every row.
    schema_by_db = {}

    for _, row in tqdm(df.iterrows(), total=len(df), desc="Creating ChatML dataset"):
        # Skip invalid entries
        if not row['is_valid']:
            continue

        # Get (or reuse) the formatted schema for this database
        db_path = row['db_path']
        if db_path not in schema_by_db:
            schema_by_db[db_path] = format_schema_for_chat(table_metadata(db_path))
        schema_text = schema_by_db[db_path]

        # Create user message with schema and question
        user_content = f"Database Schema:\n{schema_text}\n\nQuestion: {row['question']}"

        # Create ChatML conversation with system prompt
        chatml_data.append({
            "conversations": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_content},
                {"role": "assistant", "content": row['query']},
            ]
        })

    return chatml_data

In [None]:
# Create the ChatML format dataset for Axolotl
print("Creating ChatML format dataset for Axolotl...")
chatml_dataset = create_chatml_dataset(df)

print(f"Created {len(chatml_dataset)} ChatML training examples")

# Save to JSONL file for Axolotl: one JSON object per line.
# The encoding is pinned so the file does not depend on the platform default.
chatml_output_file = "chatml_training_data.jsonl"
with open(chatml_output_file, 'w', encoding='utf-8') as f:
    for entry in chatml_dataset:
        f.write(json.dumps(entry) + '\n')

print(f"Saved ChatML training data to {chatml_output_file}")

# Display a sample ChatML entry; guarded because the dataset is empty when no
# row is valid, and indexing it would raise IndexError.
if chatml_dataset:
    print("\nSample ChatML training entry:")
    print(json.dumps(chatml_dataset[0], indent=2))
else:
    print("\nNo valid rows found - the training file is empty.")

In [None]:
%%writefile axolotl.yaml
# axolotl.yaml
# Full fine-tune configuration for Qwen3-0.6B on SQL generation dataset
# Adapted from https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/qwen3/32b-qlora.yaml

base_model: Qwen/Qwen3-0.6B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
strict: false

chat_template: qwen3
datasets:
  - path: chatml_training_data.jsonl
    type: chat_template
    field_messages: conversations
    message_property_mappings:
      role: role
      content: content
val_set_size: 0.2
output_dir: ./outputs/out
dataset_prepared_path: last_run_prepared

sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true

# Since we're training a small model, we load it in full precision and do not use LoRA.
# load_in_4bit: true
# adapter: qlora
# lora_r: 16
# lora_alpha: 32
# lora_target_modules:
#   - q_proj
#   - k_proj
#   - v_proj
#   - o_proj
#   - down_proj
#   - up_proj
# lora_mlp_kernel: true
# lora_qkv_kernel: true
# lora_o_kernel: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 2
micro_batch_size: 4
num_epochs: 4
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.0002

bf16: auto
tf32: true

gradient_checkpointing: offload
gradient_checkpointing_kwargs:
  use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true

warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens: