# Fine-Tune Gemma using Hugging Face Transformers and QloRA

### Imports

In [1]:
# Install Pytorch & other libraries
%pip install "torch>=2.4.0" tensorboard

# Install Gemma release branch from Hugging Face
%pip install "transformers>=4.51.3"

# Install Hugging Face libraries
%pip install  --upgrade \
  "datasets>=3.3.2" \
  "accelerate==1.4.0" \
  "evaluate==0.4.3" \
  "bitsandbytes==0.45.3" \
  "trl==0.21.0" \
  "peft==0.14.0" \
  "protobuf==5.29.1" \
  "fsspec==2025.3.0" \
  python-Levenshtein \
  sentencepiece
  # flash_attn \




In [2]:
import os
os.environ["HF_HUB_DISABLE_XET"] = "1"

import jax
print("JAX backend:", jax.default_backend())
print("JAX devices:", jax.devices())

JAX backend: gpu
JAX devices: [CudaDevice(id=0)]


In [None]:
from google.colab import drive
from google.colab import userdata
from huggingface_hub import login

import os
from pathlib import Path
import json
import random
import re
import sqlite3

import torch
from datasets import load_dataset, Dataset
import pandas as pd
import Levenshtein

from huggingface_hub import login
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModelForImageTextToText,
    BitsAndBytesConfig,
)
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer

# Login into Hugging Face Hub
hf_token = userdata.get('HF_TOKEN') # If you are running inside a Google Colab
login(hf_token)

### Setup Repo Structure

In [None]:
drive.mount('/content/drive')

ROOT_DRIVE_DIR = "/content/drive/MyDrive/gemma_lora_ft"

WIKISQL_DIR = Path("/content/WikiSQL")
DATA_DIR = Path("/content/data")


Mounted at /content/drive


### System/User Prompt with Prompt Builder

In [None]:
# User prompt template (used in both training & inference)

def generate_raw_prompt(question: str, schema_text: str) -> str:
    return f"""<INSTRUCTIONS>
You are a precise text-to-SQL generator. Using the known schema of the sql database you must output only a valid SQL query and nothing else.
</INSTRUCTIONS>

<SCHEMA>
{schema_text}
</SCHEMA>

<QUESTION>
{question}
</QUESTION>

<SQL_Query>
""".strip()

# WikiSQL Dataset

In [None]:
if not os.path.exists(WIKISQL_DIR):
  !git clone https://github.com/salesforce/WikiSQL

if not os.path.exists(DATA_DIR):
  !tar xvjf /content/WikiSQL/data.tar.bz2

Cloning into 'WikiSQL'...
remote: Enumerating objects: 389, done.[K
remote: Counting objects: 100% (195/195), done.[K
remote: Compressing objects: 100% (41/41), done.[K
remote: Total 389 (delta 186), reused 154 (delta 154), pack-reused 194 (from 1)[K
Receiving objects: 100% (389/389), 50.72 MiB | 12.60 MiB/s, done.
Resolving deltas: 100% (213/213), done.
data/
data/train.jsonl
data/test.tables.jsonl
data/test.db
data/dev.tables.jsonl
data/dev.db
data/test.jsonl
data/train.tables.jsonl
data/train.db
data/dev.jsonl


### SQLite Setup and Table Builder

In [None]:
def load_tables(split="dev"):
    tables_path = DATA_DIR / f"{split}.tables.jsonl"
    tables = {}
    with open(tables_path, "r", encoding="utf-8") as f:
        for line in f:
            obj = json.loads(line)
            tables[obj["id"]] = obj
    return tables


def iter_split(split="dev"):
    q_path = DATA_DIR / f"{split}.jsonl"
    tables = load_tables(split)
    with open(q_path, "r", encoding="utf-8") as f:
        for line in f:
            ex = json.loads(line)
            table = tables[ex["table_id"]]
            yield ex, table


def build_sqlite_from_table(table_obj):
    df = pd.DataFrame(table_obj["rows"], columns=table_obj["header"])
    conn = sqlite3.connect(":memory:")
    df.to_sql("data", conn, index=False, if_exists="replace")
    return conn


AGG_OPS = ["", "MAX", "MIN", "COUNT", "SUM", "AVG"]
COND_OPS = ["=", ">", "<", "OP"]


def escape_identifier(name: str) -> str:
    # Escape internal quotes by doubling them
    cleaned = name.replace('"', '""')
    return f'"{cleaned}"'


def logical_to_sql(sql_obj, table_obj, table_name="data"):
    sel_idx = sql_obj["sel"]
    agg_idx = sql_obj["agg"]
    conds = sql_obj["conds"]

    columns = table_obj["header"]
    sel_col = escape_identifier(columns[sel_idx])
    agg = AGG_OPS[agg_idx]

    # FIX: Proper aggregation syntax
    if agg == "":
        select_expr = sel_col
    else:
        select_expr = f"{agg}({sel_col})"

    query = f"SELECT {select_expr} FROM {table_name}"

    where_clauses = []
    for col_idx, op_idx, val in conds:
        col_name = escape_identifier(columns[col_idx])
        op = COND_OPS[op_idx]

        if isinstance(val, str):
            v_str = "'" + val.replace("'", "''") + "'"
        else:
            v_str = str(val)

        if op == "OP":
            op = "="

        where_clauses.append(f"{col_name} {op} {v_str}")

    if where_clauses:
        query += " WHERE " + " AND ".join(where_clauses)

    return query

def execute_sql(conn, sql):
    try:
        cur = conn.cursor()
        cur.execute(sql)
        return cur.fetchall(), None
    except Exception as e:
        return None, str(e)


def make_schema_text(table_obj):
    """
    Convert WikiSQL table to the same schema format used for synthetic data,
    so the model sees consistent inputs.
    """
    cols = table_obj["header"]
    return "\n".join(f"- {c} (TEXT)" for c in cols)

In [None]:
ex, tbl = next(iter_split("train"))
example_gold_sql = logical_to_sql(ex["sql"], tbl)
# print("\nWikiSQL sanity check:")
# print("Question:", ex["question"])
# print("Gold SQL:", example_gold_sql)

print(ex)
print(tbl)

{'phase': 1, 'table_id': '1-1000181-1', 'question': 'Tell me what the notes are for South Australia ', 'sql': {'sel': 5, 'conds': [[3, 0, 'SOUTH AUSTRALIA']], 'agg': 0}}
{'id': '1-1000181-1', 'header': ['State/territory', 'Text/background colour', 'Format', 'Current slogan', 'Current series', 'Notes'], 'types': ['text', 'text', 'text', 'text', 'text', 'text'], 'rows': [['Australian Capital Territory', 'blue/white', 'YaaÂ·nna', 'ACT Â· CELEBRATION OF A CENTURY 2013', 'YILÂ·00A', 'Slogan screenprinted on plate'], ['New South Wales', 'black/yellow', 'aaÂ·nnÂ·aa', 'NEW SOUTH WALES', 'BXÂ·99Â·HI', 'No slogan on current series'], ['New South Wales', 'black/white', 'aaaÂ·nna', 'NSW', 'CPXÂ·12A', 'Optional white slimline series'], ['Northern Territory', 'ochre/white', 'CaÂ·nnÂ·aa', 'NT Â· OUTBACK AUSTRALIA', 'CBÂ·06Â·ZZ', 'New series began in June 2011'], ['Queensland', 'maroon/white', 'nnnÂ·aaa', 'QUEENSLAND Â· SUNSHINE STATE', '999Â·TLG', 'Slogan embossed on plate'], ['South Australia', 'bla

### Test SQL Helper Functions

In [None]:
import sqlite3
import re

def is_sql_syntax_valid(sql):
    """
    Checks SQL syntax by replacing table names with a dummy table
    so that missing tables do NOT cause an error.
    """
    sql_clean = sql.strip()

    # Replace any token after FROM or JOIN with the dummy table name
    sql_clean = re.sub(r"(FROM|JOIN)\s+[\w\.\-]+", r"\1 dummy", sql_clean, flags=re.IGNORECASE)

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE dummy(x TEXT);")

    try:
        conn.execute(sql_clean)
        return True
    except Exception as e:
        if "no such table" in str(e).lower():
            return True   # table does not exist â†’ ignore
        return False


In [None]:
wikisql_valid = 0
wikisql_total = 0
wikisql_errors = []

for ex, tbl in iter_split("dev"):
    wikisql_total += 1

    conn = build_sqlite_from_table(tbl)
    sql = logical_to_sql(ex["sql"], tbl)

    result, err = execute_sql(conn, sql)

    if err is None:
        wikisql_valid += 1
    else:
        wikisql_errors.append((sql, err))

print(f"WikiSQL valid queries: {wikisql_valid}/{wikisql_total} "
      f"({wikisql_valid / wikisql_total:.2%})")

print("\nExample WikiSQL errors:")
if not wikisql_errors:
    print("No errors â€” all WikiSQL queries are valid!")
else:
    for i, (sql, err) in enumerate(wikisql_errors[:5]):
        print("SQL:", sql)
        print("ERROR:", err, "\n")


WikiSQL valid queries: 8421/8421 (100.00%)

Example WikiSQL errors:
No errors â€” all WikiSQL queries are valid!


### Final Data Preprocessing

In [None]:
from datasets import Dataset

def normalize_sql(sql: str) -> str:
    sql = sql.strip().rstrip(";")
    sql = sql.replace("`", '"')
    return sql

def build_completion(sql: str) -> str:
    sql = normalize_sql(sql)
    return sql.lstrip() + "\n</SQL_Query>"

def build_training_row_from_wikisql(ex, tbl):
    schema = make_schema_text(tbl)
    prompt = generate_raw_prompt(ex["question"], schema)
    gold_sql = logical_to_sql(ex["sql"], tbl)
    completion = build_completion(gold_sql)

    return {
        "prompt": prompt,
        "completion": completion
    }


In [None]:
train_rows = []
for ex, tbl in iter_split("train"):
    train_rows.append(build_training_row_from_wikisql(ex, tbl))

test_rows = []
for ex, tbl in iter_split("test"):
    test_rows.append(build_training_row_from_wikisql(ex, tbl))

val_rows = []
for ex, tbl in iter_split("dev"):
    val_rows.append(build_training_row_from_wikisql(ex, tbl))

train_wikisql = Dataset.from_list(train_rows)
test_wikisql = Dataset.from_list(test_rows)
val_wikisql = Dataset.from_list(val_rows)

print(train_wikisql)
print(test_wikisql)
print(val_wikisql)

print("\n")

for i in range(3):
    print("PROMPT:\n", train_wikisql[i]["prompt"])
    print("COMPLETION:\n", train_wikisql[i]["completion"])
    print("="*80)


Dataset({
    features: ['prompt', 'completion'],
    num_rows: 56355
})
Dataset({
    features: ['prompt', 'completion'],
    num_rows: 15878
})
Dataset({
    features: ['prompt', 'completion'],
    num_rows: 8421
})


PROMPT:
 <INSTRUCTIONS>
You are a precise text-to-SQL generator. Using the known schema of the sql database you must output only a valid SQL query and nothing else.
</INSTRUCTIONS>

<SCHEMA>
- State/territory (TEXT)
- Text/background colour (TEXT)
- Format (TEXT)
- Current slogan (TEXT)
- Current series (TEXT)
- Notes (TEXT)
</SCHEMA>

<QUESTION>
Tell me what the notes are for South Australia 
</QUESTION>

<SQL_Query>
COMPLETION:
 SELECT "Notes" FROM data WHERE "Current slogan" = 'SOUTH AUSTRALIA'
</SQL_Query>
PROMPT:
 <INSTRUCTIONS>
You are a precise text-to-SQL generator. Using the known schema of the sql database you must output only a valid SQL query and nothing else.
</INSTRUCTIONS>

<SCHEMA>
- State/territory (TEXT)
- Text/background colour (TEXT)
- Format (TEXT)


# Gemma 3 4B Model

### Load Pretrained Model

In [None]:
# Hugging Face model id
model_id = "google/gemma-3-4b-pt"  # pre-trained (not instruction-tuned)
# For tokenizer we use the instruction-tuned tokenizer
tokenizer_id = "google/gemma-3-4b-it"

# Select model class based on id
if model_id == "google/gemma-3-4b-pt":
    model_class = AutoModelForCausalLM
else:
    model_class = AutoModelForImageTextToText

# Choose dtype based on GPU capability
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
    torch_dtype = torch.bfloat16
else:
    torch_dtype = torch.float16

# Define model init arguments
model_kwargs = dict(
    attn_implementation="sdpa", # Use "flash_attention_2" when running on Ampere or newer GPU
    torch_dtype=torch_dtype, # What torch dtype to use, defaults to auto
    device_map="auto", # Let torch decide how to load the model
)

model_kwargs["quantization_config"] = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=model_kwargs['torch_dtype'],
    bnb_4bit_quant_storage=model_kwargs['torch_dtype'],
)


print("ðŸ”„ Loading model...")
model = model_class.from_pretrained(model_id, **model_kwargs)

print("ðŸ”„ Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)

# Ensure tokenizer has EOS & PAD set correctly for generation & SFT
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"


ðŸ”„ Loading model...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/815 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json:   0%|          | 0.00/90.6k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.64G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/215 [00:00<?, ?B/s]

ðŸ”„ Loading tokenizer...


tokenizer_config.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

### Generate Cleaned SQL Query from Gemma Model with Prompt

In [None]:
import re

def clean_sql_output(text: str) -> str:
    """
    Extract a clean SQL statement from LLM output under the new tag-based format:

        <SQL_Query>
        SELECT ...
        </SQL_Query>

    Handles:
      - missing or extra whitespace
      - missing closing tags
      - trailing commentary after </SQL_Query>
      - selects the FIRST valid SQL statement
    """

    if not text or not isinstance(text, str):
        return ""

    raw = text.strip()

    # -------------------------------
    # 1. Extract content inside <SQL_Query> ... </SQL_Query>
    # -------------------------------
    m = re.search(
        r"<SQL_Query>(.*?)(</SQL_Query>|$)",
        raw,
        flags=re.IGNORECASE | re.DOTALL,
    )

    if m:
        candidate = m.group(1).strip()
    else:
        # fallback if tag missing
        candidate = raw

    # Remove any accidental tag echoes
    candidate = re.sub(r"</?SQL_Query>", "", candidate, flags=re.IGNORECASE).strip()

    # Strip markdown fences if any
    candidate = re.sub(r"```sql", "", candidate, flags=re.IGNORECASE)
    candidate = candidate.replace("```", "").strip()

    # Normalize whitespace
    candidate = re.sub(r"[ \t]+", " ", candidate)

    # -------------------------------
    # 2. Grab the first SQL keyword (fallback)
    # -------------------------------
    sql_start = re.compile(
        r"\b(SELECT|INSERT\s+INTO|UPDATE|DELETE\s+FROM|CREATE\s+TABLE)\b",
        flags=re.IGNORECASE,
    )

    match = sql_start.search(candidate)
    if not match:
        return candidate  # return raw candidate (probably empty)

    sql = candidate[match.start():].strip()

    # -------------------------------
    # 3. Remove trailing commentary or extra content
    # -------------------------------
    stop_tokens = [
        "</SQL_Query>",
        "<INSTRUCTIONS>",
        "<QUESTION>",
        "<SCHEMA>",
        "Explanation:",
        "Answer:",
        "Result:",
        "Note:",
        "\n#",
        "```",
    ]

    end_positions = []
    for tok in stop_tokens:
        pos = sql.find(tok)
        if pos > 0:
            end_positions.append(pos)

    if end_positions:
        sql = sql[:min(end_positions)].strip()

    # Remove trailing punctuation
    sql = sql.rstrip(";` ")

    return sql.strip()



def generate_sql_from_llm(question: str, schema_text: str, model) -> str:
    """
    Generate SQL from a raw text prompt using continuation-style generation.
    """
    # Build the same raw prompt used during training
    prompt = generate_raw_prompt(question, schema_text)

    print(prompt)

    # Tokenize
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    print
    # Compute input length to slice off the prompt from model output
    input_len = inputs["input_ids"].shape[-1]

    # Generate continuation
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=150,
            do_sample=False,  # deterministic for evaluation
            pad_token_id=tokenizer.eos_token_id
        )

    # Slice off the prompt to isolate model-generated SQL
    gen_tokens = outputs[0][input_len:]
    decoded = tokenizer.decode(gen_tokens, skip_special_tokens=True)

    return clean_sql_output(decoded)


# SQL similarity score evaluator
def tokenize_sql(sql: str):
    sql = sql.lower()
    sql = re.sub(r"[^a-z0-9_*]", " ", sql)
    tokens = sql.split()
    return tokens

### Test Example Prompt on Pretrained Model

In [None]:
print("\nðŸ”Ž Single WikiSQL example BEFORE training:")
ex, tbl = next(iter_split("train"))
conn = build_sqlite_from_table(tbl)
schema = make_schema_text(tbl)

# print(ex["question"])
# print(schema)

print(generate_sql_from_llm(ex["question"], schema, model))


ðŸ”Ž Single WikiSQL example BEFORE training:


The following generation flags are not valid and may be ignored: ['top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


<INSTRUCTIONS>
You are a precise text-to-SQL generator. Using the known schema of the sql database you must output only a valid SQL query and nothing else.
</INSTRUCTIONS>

<SCHEMA>
- State/territory (TEXT)
- Text/background colour (TEXT)
- Format (TEXT)
- Current slogan (TEXT)
- Current series (TEXT)
- Notes (TEXT)
</SCHEMA>

<QUESTION>
Tell me what the notes are for South Australia 
</QUESTION>

<SQL_Query>
SELECT notes FROM state_territory WHERE state_territory = 'South Australia'


In [None]:
print("\nðŸ”Ž Single WikiSQL example BEFORE training:")
ex, tbl = next(iter_split("train"))
conn = build_sqlite_from_table(tbl)
schema = make_schema_text(tbl)

gold_sql = logical_to_sql(ex["sql"], tbl)
pred_sql = generate_sql_from_llm(ex["question"], schema, model)

gold_res, ge = execute_sql(conn, gold_sql)
pred_res, pe = execute_sql(conn, pred_sql)

# print("Q:", ex["question"])
print("Gold:", gold_sql)
print("Pred:", pred_sql)
print("Gold result:", gold_res)
print("Pred result:", pred_res if pe is None else f"ERROR: {pe}")


ðŸ”Ž Single WikiSQL example BEFORE training:
<INSTRUCTIONS>
You are a precise text-to-SQL generator. Using the known schema of the sql database you must output only a valid SQL query and nothing else.
</INSTRUCTIONS>

<SCHEMA>
- State/territory (TEXT)
- Text/background colour (TEXT)
- Format (TEXT)
- Current slogan (TEXT)
- Current series (TEXT)
- Notes (TEXT)
</SCHEMA>

<QUESTION>
Tell me what the notes are for South Australia 
</QUESTION>

<SQL_Query>
Gold: SELECT "Notes" FROM data WHERE "Current slogan" = 'SOUTH AUSTRALIA'
Pred: SELECT notes FROM state_territory WHERE state_territory = 'South Australia'
Gold result: [('No slogan on current series',)]
Pred result: ERROR: no such table: state_territory


# Training Pipeline

### LoRA Config

In [None]:
peft_config = LoraConfig(
    r=32,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    bias="none",
    task_type="CAUSAL_LM",
    modules_to_save=["lm_head", "embed_tokens"]
)

### SFT Config

In [None]:
from trl import SFTConfig

training_args  = SFTConfig(
    output_dir="gemma-text-to-sql-train-rank-32",         # directory to save and repository id
    max_length=512,                         # max sequence length for model and packing of the dataset
    packing=True,                           # Groups multiple samples in the dataset into a single sequence
    num_train_epochs=3,                     # number of training epochs
    per_device_train_batch_size=1,          # batch size per device during training
    gradient_accumulation_steps=4,          # number of steps before performing a backward/update pass
    gradient_checkpointing=True,            # use gradient checkpointing to save memory
    optim="adamw_torch_fused",              # use fused adamw optimizer
    logging_steps=10,                       # log every 10 steps
    save_strategy="epoch",                  # save checkpoint every epoch
    learning_rate=5e-5,                     # learning rate, based on QLoRA paper
    fp16=True if torch_dtype == torch.float16 else False,   # use float16 precision
    bf16=True if torch_dtype == torch.bfloat16 else False,   # use bfloat16 precision
    max_grad_norm=0.3,                      # max gradient norm based on QLoRA paper
    warmup_ratio=0.03,                      # warmup ratio based on QLoRA paper
    lr_scheduler_type="constant",           # use constant learning rate scheduler
    push_to_hub=True,                       # push model to hub
    report_to="tensorboard",                # report metrics to tensorboard
)

### SFT Trainer

In [None]:
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_wikisql,
    eval_dataset=val_wikisql,
    processing_class=tokenizer,
    peft_config=peft_config,
)



Adding EOS to train dataset:   0%|          | 0/56355 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/56355 [00:00<?, ? examples/s]

Packing train dataset:   0%|          | 0/56355 [00:00<?, ? examples/s]

Adding EOS to eval dataset:   0%|          | 0/8421 [00:00<?, ? examples/s]

Tokenizing eval dataset:   0%|          | 0/8421 [00:00<?, ? examples/s]

Packing eval dataset:   0%|          | 0/8421 [00:00<?, ? examples/s]

### Train Loop & Model Saving


In [None]:
import datetime
import json
import os

# Create a timestamped directory for the current run
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
model_save_dir = f"{ROOT_DRIVE_DIR}/gemma_text_to_sql_run_train_rank_32_{timestamp}"

# Create the directory if it doesn't exist
Path(model_save_dir).mkdir(parents=True, exist_ok=True)

# Update the output_dir in training_args
training_args.output_dir = model_save_dir

print(f"Saving model artifacts to: {model_save_dir}")

# Start training, the model will be automatically saved to the Hub and the output directory
trainer.train()

# Save the final model again to the specified directory
trainer.save_model(model_save_dir)
tokenizer.save_pretrained(model_save_dir)


log_file_path = os.path.join(model_save_dir, "training_log.json")
log_history = trainer.state.log_history

with open(log_file_path, "w") as f:
    json.dump(log_history, f, indent=4)

print(f"Training log saved to: {log_file_path}")


The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 1, 'bos_token_id': 2, 'pad_token_id': 0}.


Saving model artifacts to: /content/drive/MyDrive/gemma_lora_ft/gemma_text_to_sql_run_train_rank_32_20251208_075911


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Step,Training Loss
10,0.3518
20,0.1267
30,0.104
40,0.099
50,0.0817
60,0.0749
70,0.0689
80,0.0775
90,0.0637
100,0.0745


'(MaxRetryError("HTTPSConnectionPool(host='hf-hub-lfs-us-east-1.s3-accelerate.amazonaws.com', port=443): Max retries exceeded with url: /repos/95/2f/952f14e4865b89e8d7112976dfcb8e995fbb9817d7f1980f68c8065d96c5be57/6f6c34200386e0ede70eb59f709f9e4287b997be1d84b468052ab76744e8a200?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Content-Sha256=UNSIGNED-PAYLOAD&X-Amz-Credential=AKIA2JU7TKAQLC2QXPN7%2F20251208%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20251208T203106Z&X-Amz-Expires=86400&X-Amz-Signature=c246f30658248c1028b6ee0cc86d384f9627750d223c5a307b66fab8fc5a451c&X-Amz-SignedHeaders=host&partNumber=147&uploadId=ybQCyIWoEAbTveGpacUUyzaeqqU.MKajNjm5YqYAFJk.ih5sB9e.NHudoPC7OpK91i9d8z7XT_fmjLuk.7lTYWcZZ62KIvX0W2FLRAdZKxHzwoEYhgfAY1b8Eqd_Fgeo&x-id=UploadPart (Caused by SSLError(SSLEOFError(8, 'EOF occurred in violation of protocol (_ssl.c:2427)')))"), '(Request ID: 41e3034a-6e45-45ef-8dd3-c85fe957b19d)')' thrown while requesting PUT https://hf-hub-lfs-us-east-1.s3-accelerate.amazonaws.com/repos/95/2f

Upload 0 LFS files: 0it [00:00, ?it/s]

No files have been modified since last commit. Skipping to prevent empty commit.


Training log saved to: /content/drive/MyDrive/gemma_lora_ft/gemma_text_to_sql_run_train_rank_32_20251208_075911/training_log.json


### Clean Cache

# Trained Model

### Load Trained Model from Parameters

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

BASE_MODEL = "google/gemma-3-4b-pt"

print("Loading base model...")
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    device_map="auto",
    torch_dtype=torch.bfloat16
)

print("Loading LoRA adapter...")
ft_model = PeftModel.from_pretrained(base_model, model_save_dir)

tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-4b-it")


Loading base model...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading LoRA adapter...


OutOfMemoryError: CUDA out of memory. Tried to allocate 1.25 GiB. GPU 0 has a total capacity of 22.16 GiB of which 1.25 GiB is free. Process 6707 has 20.91 GiB memory in use. Of the allocated memory 20.62 GiB is allocated by PyTorch, and 44.51 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
ft_model.print_trainable_parameters()

### Review Example Output of Trained Model

In [None]:
print("\nðŸ”Ž Single WikiSQL example BEFORE training:")
ex, tbl = next(iter_split("train"))
conn = build_sqlite_from_table(tbl)
schema = make_schema_text(tbl)

gold_sql = logical_to_sql(ex["sql"], tbl)
pred_sql = generate_sql_from_llm(ex["question"], schema, ft_model)

gold_res, ge = execute_sql(conn, gold_sql)
pred_res, pe = execute_sql(conn, pred_sql)

# print("Q:", ex["question"])
print("Gold:", gold_sql)
print("Pred:", pred_sql)
print("Gold result:", gold_res)
print("Pred result:", pred_res if pe is None else f"ERROR: {pe}")