# SQL-R1: Text-to-SQL RL Training on Kaggle

**Requirements**: Kaggle GPU Runtime (T4 16GB)

## Overview
- **Paper**: SQL-R1: Training Natural Language to SQL Reasoning Model By Reinforcement Learning
- **Algorithm**: GRPO (Group Relative Policy Optimization)
- **Model**: Qwen2.5-Coder-3B-Instruct

## 1. Environment Setup

⚠️ **IMPORTANT**: After running the installation cell, you MUST restart the kernel before continuing!

In [None]:
# Sanity check: list the NVIDIA GPU(s) visible to this session before installing anything.
!nvidia-smi

In [None]:
# Step 1: Install dependencies
# After this cell completes, RESTART THE KERNEL (Runtime -> Restart runtime)
# NOTE(review): "Runtime -> Restart runtime" is Colab's menu wording; confirm
# the equivalent restart action in the Kaggle notebook UI.

# Core training/inference stack. Only vllm is version-pinned (0.6.3); the
# other packages resolve to whatever pip selects at install time.
!pip install vllm==0.6.3 ray transformers accelerate --quiet
# Supporting packages (wandb, sqlparse, func_timeout, nltk, ijson).
!pip install wandb sqlparse func_timeout nltk ijson --quiet
# Configuration libraries.
!pip install hydra-core omegaconf --quiet
# --no-build-isolation makes pip build flash-attn against the packages already
# installed in this environment instead of a clean, isolated build environment.
!pip install flash-attn --no-build-isolation --quiet

# Clone SQL-R1
# Both the existence check and the clone are relative to the current working
# directory, so the repository is fetched only when it is not already there.
# NOTE(review): re-running this cell after the %cd below would look for the
# directory inside the checkout and clone a nested copy — confirm intended use
# is "run once per session".
import os
if not os.path.exists('SellWizr-Assignment'):
    !git clone https://github.com/dancinglightning/SellWizr-Assignment.git

# Install verl
# %cd changes the notebook's working directory; `pip install -e .` then
# installs the package found there in editable (development) mode.
%cd SellWizr-Assignment/SQL-R1
!pip install -e . --quiet

# Banner reminding the user to restart the kernel before running later cells.
print("\n" + "="*60)
print("INSTALLATION COMPLETE!")
print("Please RESTART THE KERNEL now: Runtime -> Restart runtime")
print("Then skip this cell and run the next one.")
print("="*60)

In [None]:
# Step 2: Run this cell AFTER kernel restart
# Moves back into the SQL-R1 checkout and reports the versions of the core
# numeric libraries plus whether CUDA is usable.

import os

os.chdir('/kaggle/working/SellWizr-Assignment/SQL-R1')
print(f"Working directory: {os.getcwd()}")

# Imported after the chdir, in the same order as before, so module resolution
# is unchanged.
import torch
import pandas as pd
import numpy as np

# One (label, value) row per line of the environment report.
_report_rows = (
    ("PyTorch", torch.__version__),
    ("NumPy", np.__version__),
    ("Pandas", pd.__version__),
    ("CUDA", torch.cuda.is_available()),
)
for _label, _value in _report_rows:
    print(f"{_label}: {_value}")

## 2. Create Test Databases

In [None]:
import os
import sqlite3
import shutil

# Locations where the placeholder databases are created. Both trees receive
# the same set of schemas.
synsql_db_path = 'data/NL2SQL/SynSQL-2.5M/databases'
spider_db_path = 'data/spider/database'

# Create directories
for _base_dir in (synsql_db_path, spider_db_path):
    os.makedirs(_base_dir, exist_ok=True)

# Create test databases for reward function
# Maps database name -> DDL statements. IF NOT EXISTS keeps the cell
# re-runnable and leaves any database that is already present untouched,
# instead of failing with "table ... already exists".
test_schemas = {
    'concert_singer': [
        'CREATE TABLE IF NOT EXISTS singer (singer_id INT PRIMARY KEY, name TEXT, country TEXT, age INT)',
        'CREATE TABLE IF NOT EXISTS concert (concert_id INT PRIMARY KEY, concert_name TEXT, year INT)',
    ],
    'pets_1': [
        'CREATE TABLE IF NOT EXISTS pets (pet_id INT PRIMARY KEY, pet_type TEXT, pet_age INT)',
        'CREATE TABLE IF NOT EXISTS owners (owner_id INT PRIMARY KEY, name TEXT, age INT)',
    ],
    'employee_hire_evaluation': [
        'CREATE TABLE IF NOT EXISTS employees (employee_id INT PRIMARY KEY, name TEXT, department TEXT, salary INT)',
        'CREATE TABLE IF NOT EXISTS evaluations (eval_id INT PRIMARY KEY, employee_id INT, score INT)',
    ],
    'world_1': [
        'CREATE TABLE IF NOT EXISTS country (code TEXT PRIMARY KEY, name TEXT, continent TEXT, population INT)',
        'CREATE TABLE IF NOT EXISTS city (id INT PRIMARY KEY, name TEXT, country_code TEXT, population INT)',
    ],
}

for db_name, schemas in test_schemas.items():
    for base_path in [spider_db_path, synsql_db_path]:
        # Layout: <base>/<db_name>/<db_name>.sqlite
        db_dir = f'{base_path}/{db_name}'
        os.makedirs(db_dir, exist_ok=True)
        db_file = f'{db_dir}/{db_name}.sqlite'

        conn = sqlite3.connect(db_file)
        try:
            # `with conn` commits on success and rolls back on error.
            with conn:
                for schema in schemas:
                    conn.execute(schema)
        finally:
            # Always release the file handle, even if a statement fails.
            conn.close()

print(f"Created {len(test_schemas)} databases!")

## 3. Download Model

In [None]:
import os
from huggingface_hub import snapshot_download

# Hub repository to fetch and the local directory it is stored in.
MODEL_NAME = "Qwen/Qwen2.5-Coder-3B-Instruct"
MODEL_PATH = "models/Qwen2.5-Coder-3B-Instruct"

# The download is skipped whenever the target path already exists.
# NOTE(review): an interrupted download also leaves this path in place, so a
# partial copy would be reported as "Model exists!" — confirm this is acceptable.
if os.path.exists(MODEL_PATH):
    print("Model exists!")
else:
    print(f"Downloading {MODEL_NAME}...")
    snapshot_download(
        repo_id=MODEL_NAME,
        local_dir=MODEL_PATH,
        local_dir_use_symlinks=False,
    )
    print("Done!")

## 4. Check Training Data

In [None]:
import pandas as pd

# Load the example training split and summarise its size and column names.
train_df = pd.read_parquet('example_data/train.parquet')

_sample_count = len(train_df)
_column_names = train_df.columns.tolist()
print(f"Training samples: {_sample_count}")
print(f"Columns: {_column_names}")

## 5. RL Training Config

In [None]:
import os

# Process-wide settings read by vLLM and the tokenizers library.
os.environ.update({
    'VLLM_ATTENTION_BACKEND': 'XFORMERS',
    'TOKENIZERS_PARALLELISM': 'false',
})

# The overrides are grouped by config section and merged below. Dict order is
# preserved, so the generated command line keeps this exact key order.
_data_cfg = {
    'data.train_files': 'example_data/train.parquet',
    'data.val_files': 'example_data/test.parquet',
    'data.train_batch_size': 2,
    'data.val_batch_size': 2,
    'data.max_prompt_length': 1024,
    'data.max_response_length': 512,
}

_model_cfg = {
    'actor_rollout_ref.model.path': 'models/Qwen2.5-Coder-3B-Instruct',
    'actor_rollout_ref.model.enable_gradient_checkpointing': True,
}

_actor_cfg = {
    'actor_rollout_ref.actor.ppo_mini_batch_size': 2,
    'actor_rollout_ref.actor.ppo_micro_batch_size': 1,
    'actor_rollout_ref.actor.fsdp_config.param_offload': True,
    'actor_rollout_ref.actor.fsdp_config.grad_offload': True,
    'actor_rollout_ref.actor.fsdp_config.optimizer_offload': True,
    # Kept as a string so the override is rendered exactly as written.
    'actor_rollout_ref.actor.optim.lr': '1e-6',
    'actor_rollout_ref.actor.use_kl_loss': True,
    'actor_rollout_ref.actor.kl_loss_coef': 0.001,
    'actor_rollout_ref.actor.kl_loss_type': 'low_var_kl',
}

_rollout_cfg = {
    'actor_rollout_ref.rollout.name': 'vllm',
    'actor_rollout_ref.rollout.tensor_model_parallel_size': 1,
    'actor_rollout_ref.rollout.gpu_memory_utilization': 0.3,
    'actor_rollout_ref.rollout.n': 4,
    'actor_rollout_ref.rollout.temperature': 1.0,
    'actor_rollout_ref.rollout.log_prob_micro_batch_size': 8,
}

_ref_cfg = {
    'actor_rollout_ref.ref.fsdp_config.param_offload': True,
    'actor_rollout_ref.ref.log_prob_micro_batch_size': 8,
}

_algorithm_cfg = {
    'algorithm.adv_estimator': 'grpo',
    'algorithm.kl_ctrl.kl_coef': 0.001,
}

_trainer_cfg = {
    'trainer.n_gpus_per_node': 1,
    'trainer.nnodes': 1,
    'trainer.total_epochs': 1,
    'trainer.save_freq': 50,
    'trainer.test_freq': 25,
    'trainer.critic_warmup': 0,
    # A list literal passed through as text.
    'trainer.logger': "['console']",
    'trainer.project_name': 'SQL-R1-Kaggle',
    'trainer.experiment_name': '3B-T4-GRPO',
    'trainer.default_local_dir': 'logs/kaggle_run',
}

TRAIN_CONFIG = {
    **_data_cfg,
    **_model_cfg,
    **_actor_cfg,
    **_rollout_cfg,
    **_ref_cfg,
    **_algorithm_cfg,
    **_trainer_cfg,
}

# Flatten to space-separated "key=value" overrides for the trainer command line.
cmd_args = ' '.join(f"{key}={value}" for key, value in TRAIN_CONFIG.items())
print("Config ready!")

In [None]:
# Run training
# Runs the verl.trainer.main_ppo module in a shell; IPython substitutes
# {cmd_args} with the space-separated key=value overrides string before the
# command is executed.
!python -m verl.trainer.main_ppo {cmd_args}

## 6. Test Reward Function

In [None]:
from verl.utils.reward_score.synsql import extract_solution

# Hand-written sample response: an assistant turn containing a <think> section
# followed by an <answer> section that wraps a fenced SQL query.
test = "\n".join([
    "<|im_start|>assistant",
    "<think>Query analysis</think>",
    "<answer>```sql",
    "SELECT * FROM employees",
    "```</answer>",
])

# extract_solution is unpacked into three values; the third is not used here.
answer, think, _ = extract_solution(test)
print(f"Answer: {answer}\nThink: {think}")

In [None]:
# Show GPU status again after the steps above (e.g. to inspect memory usage).
!nvidia-smi