In [2]:
# One-time setup
!mkdir -p "$HOME/spider_eval"
!cp "$HOME/evaluation.py" "$HOME/spider_eval/"
!cp "$HOME/process_sql.py" "$HOME/spider_eval/"
!ls -lh "$HOME/spider_eval"

!"$HOME/t2sql-env/bin/pip" install -q nltk
!"$HOME/t2sql-env/bin/python" - <<'PY'
import nltk; nltk.download('punkt')
print("NLTK punkt downloaded.")

total 52K
drwxr-xr-x 2 sagemaker-user users  41 Oct 17 01:20 __pycache__
-rw-r--r-- 1 sagemaker-user users 30K Oct 17 04:12 evaluation.py
-rw-r--r-- 1 sagemaker-user users 17K Oct 17 04:12 process_sql.py
NLTK punkt downloaded.


[nltk_data] Downloading package punkt to /home/sagemaker-
[nltk_data]     user/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [3]:
!mkdir -p "$HOME/"
!unzip -q -o "$HOME/spider_data.zip" -d "$HOME/"

In [4]:
import os, nltk

# Per-user directory that will hold the NLTK tokenizer resources.
NLTK_DATA_DIR = os.path.expanduser("~/nltk_data")
os.makedirs(NLTK_DATA_DIR, exist_ok=True)

# nltk reads $NLTK_DATA only when it is first imported, so setting it here does not change
# the search path of this kernel (nltk may already be loaded). Keep the env var for child
# processes (subprocess.run inherits os.environ), and register the dir for this kernel explicitly.
os.environ["NLTK_DATA"] = NLTK_DATA_DIR
if NLTK_DATA_DIR not in nltk.data.path:
    nltk.data.path.insert(0, NLTK_DATA_DIR)

# punkt: classic sentence/word tokenizer; punkt_tab: extra tables required by NLTK>=3.9.
for _pkg in ("punkt", "punkt_tab"):
    # Download into the directory registered above; nltk.download returns False on failure,
    # which would otherwise only surface later as a LookupError during tokenization.
    if not nltk.download(_pkg, download_dir=NLTK_DATA_DIR):
        raise RuntimeError(f"NLTK download failed: {_pkg}")

print("NLTK data dir:", os.environ["NLTK_DATA"])

NLTK data dir: /home/sagemaker-user/nltk_data


[nltk_data] Downloading package punkt to /home/sagemaker-
[nltk_data]     user/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to /home/sagemaker-
[nltk_data]     user/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


In [6]:
# 0) make sure we're using the venv's pip
!"$HOME/t2sql-env/bin/python" -m pip install -q --upgrade pip

# 1) if you previously installed a conflicting multiprocess, remove it
!"$HOME/t2sql-env/bin/pip" uninstall -y multiprocess || true

# 2) install a known-good, minimal stack (no multiprocess pin; datasets pulls a compatible one)
!"$HOME/t2sql-env/bin/pip" install -q \
  "transformers==4.44.2" \
  "datasets==2.19.0" \
  "peft==0.12.0" \
  "accelerate==0.33.0" \
  sentencepiece \
  "pyarrow==19.0.0"

# 3) confirm versions actually installed inside the venv
!"$HOME/t2sql-env/bin/python" - <<'PY'
import transformers, datasets, peft, accelerate, pyarrow, sys
import multiprocess
print("Python:", sys.version)
print("transformers:", transformers.__version__)
print("datasets:", datasets.__version__)
print("peft:", peft.__version__)
print("accelerate:", accelerate.__version__)
print("pyarrow:", pyarrow.__version__)
print("multiprocess:", multiprocess.__version__)

Found existing installation: multiprocess 0.70.16
Uninstalling multiprocess-0.70.16:
  Successfully uninstalled multiprocess-0.70.16


  from .autonotebook import tqdm as notebook_tqdm


Python: 3.12.9 | packaged by conda-forge | (main, Feb 14 2025, 08:00:06) [GCC 13.3.0]
transformers: 4.44.2
datasets: 2.19.0
peft: 0.12.0
accelerate: 0.33.0
pyarrow: 19.0.0
multiprocess: 0.70.16


In [17]:
import json, os, re

HOME = os.path.expanduser("~")
DATA_DIR = os.path.join(HOME, "text2sql_data")
DEV_JSON = os.path.join(DATA_DIR, "dev.json")
EVAL_DIR = os.path.join(HOME, "text2sql_outputs", "spider_eval")
os.makedirs(EVAL_DIR, exist_ok=True)

with open(DEV_JSON, "r", encoding="utf-8") as f:
    dev = json.load(f)

# gold.tsv holds one "<sql>\t<db_id>" record per line (one per dev example, in order).
# A tab or line break inside a query would split or shift a record and misalign every
# later row against pred.txt, so such characters are replaced with a single space.
_RECORD_BREAKERS = re.compile(r"[\t\r\n]+")

gold_path = os.path.join(EVAL_DIR, "gold.tsv")
with open(gold_path, "w", encoding="utf-8") as g:
    for ex in dev:
        query = _RECORD_BREAKERS.sub(" ", ex["query"])
        g.write(f"{query}\t{ex['db_id']}\n")

print("Wrote gold.tsv:", gold_path, "rows:", len(dev))

Wrote gold.tsv: /home/sagemaker-user/text2sql_outputs/spider_eval/gold.tsv rows: 1034


In [18]:
import os, json, torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import PeftModel

# Paths are redefined here so this cell can run on its own.
HOME = os.path.expanduser("~")
OUT_DIR = os.path.join(HOME, "text2sql_outputs")
DATA_DIR = os.path.join(HOME, "text2sql_data")
DEV_JSON = os.path.join(DATA_DIR, "dev.json")
TABLES_JSON = os.path.join(DATA_DIR, "tables.json")
EVAL_DIR = os.path.join(OUT_DIR, "spider_eval")
adapters_dir = os.path.join(OUT_DIR, "lora_adapters")

# pred.txt is written into EVAL_DIR in this cell; don't rely on an earlier cell having created it.
os.makedirs(EVAL_DIR, exist_ok=True)
# A missing local adapter dir would otherwise surface as a confusing Hub lookup error.
if not os.path.isdir(adapters_dir):
    raise FileNotFoundError(f"LoRA adapters not found: {adapters_dir}")

with open(DEV_JSON, "r", encoding="utf-8") as f:
    dev = json.load(f)
with open(TABLES_JSON, "r", encoding="utf-8") as f:
    tables = json.load(f)

# build compact schema text like training (same format and caps you used)
def build_schema_texts(tables_json, keep_db_ids=None, cap_cols=6, cap_fks=6):
    """Render each database of a Spider-style ``tables.json`` as a compact schema string.

    Output per database (must stay byte-identical to the training prompts)::

        DB: <db_id>
        Tables:
        - <table>(<col>[ PK], <col>, ..., ...)
        FKs: <table>.<col>-><table>.<col>; ...

    Args:
        tables_json: iterable of database dicts with keys ``db_id``,
            ``table_names_original``, ``column_names_original`` ([table_idx, name]
            pairs; entries with table_idx -1, i.e. Spider's ``*`` column, are skipped),
            ``primary_keys`` (column indices; a composite key given as a list of
            indices is also accepted) and ``foreign_keys`` ([col_idx, col_idx] pairs).
        keep_db_ids: optional collection of db_ids to keep. Any falsy value
            (``None`` or an empty collection) keeps every database.
        cap_cols: maximum number of columns listed per table; more are shown as ``...``.
        cap_fks: maximum number of foreign keys listed; if more exist a trailing ``;``
            is appended.

    Returns:
        dict mapping db_id -> schema text.
    """
    out = {}
    for db in tables_json:
        dbid = db["db_id"]
        if keep_db_ids and dbid not in keep_db_ids:
            continue
        names = db["table_names_original"]
        cols = db["column_names_original"]
        fks = db["foreign_keys"]

        pks = set()
        for pk in db["primary_keys"]:
            # Flatten composite keys given as lists of column indices; a plain
            # set(...) over such entries would raise "unhashable type: 'list'".
            if isinstance(pk, (list, tuple)):
                pks.update(pk)
            else:
                pks.add(pk)

        # Column labels grouped by owning table index.
        per = {i: [] for i in range(len(names))}
        for idx, (t_idx, c_name) in enumerate(cols):
            if t_idx == -1:  # column not attached to any table
                continue
            tag = " PK" if idx in pks else ""
            per[t_idx].append(f"{c_name}{tag}")

        fk_lines = []
        for a, b in fks:
            ct, cc = cols[a]
            pt, pc = cols[b]
            if ct == -1 or pt == -1:
                continue
            fk_lines.append(f"{names[ct]}.{cc}->{names[pt]}.{pc}")

        lines = [f"DB: {dbid}", "Tables:"]
        for i, t in enumerate(names):
            c = ", ".join(per[i][:cap_cols])
            if len(per[i]) > cap_cols:
                c += ", ..."
            lines.append(f"- {t}({c})")
        if fk_lines:
            # Truncated FK lists are marked with a trailing ";" (kept to match training).
            fk_show = "; ".join(fk_lines[:cap_fks]) + (";" if len(fk_lines) > cap_fks else "")
            lines.append("FKs: " + fk_show)
        out[dbid] = "\n".join(lines)
    return out

# Pre-render schema text only for databases that occur in the dev split.
dev_db_ids = {ex["db_id"] for ex in dev}
schema_texts = build_schema_texts(tables, dev_db_ids, cap_cols=6, cap_fks=6)

# Base checkpoint + fine-tuned LoRA adapters, switched to eval mode and moved to the GPU if present.
BASE_CHECKPOINT = "google/flan-t5-base"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tok = AutoTokenizer.from_pretrained(BASE_CHECKPOINT)
base = AutoModelForSeq2SeqLM.from_pretrained(BASE_CHECKPOINT)
model = PeftModel.from_pretrained(base, adapters_dir).eval()
model.to(device)

# Generate one SQL prediction per dev example, in dev order, into pred.txt.
# Every prediction must occupy exactly one non-blank line so that row i of pred.txt
# lines up with row i of gold.tsv. NOTE(review): evaluation.py appears to pair rows
# by position and to skip blank lines; confirm before relying on other layouts.
# An empty prediction is replaced by an unparseable, non-empty placeholder so it is
# scored as wrong instead of shifting every later row.
EMPTY_PRED_PLACEHOLDER = "SELECT"

pred_path = os.path.join(EVAL_DIR, "pred.txt")
n_empty = 0
with open(pred_path, "w", encoding="utf-8") as p:
    for ex in dev:
        q, dbid = ex["question"], ex["db_id"]
        schema = schema_texts.get(dbid, f"DB: {dbid}\nTables:")
        prompt = f"translate to sql: {q}\n{schema}"
        enc = tok(prompt, return_tensors="pt", max_length=512, truncation=True).to(device)
        with torch.no_grad():
            out = model.generate(
                **enc, max_new_tokens=160, num_beams=5,
                no_repeat_ngram_size=3, length_penalty=0.8,
                early_stopping=True, do_sample=False,
                eos_token_id=tok.eos_token_id
            )
        sql = tok.decode(out[0], skip_special_tokens=True)
        # Collapse tabs/line breaks so the prediction stays a single line.
        sql = " ".join(sql.replace("\t", " ").splitlines()).strip()
        if not sql:
            n_empty += 1
            sql = EMPTY_PRED_PLACEHOLDER
        p.write(sql + "\n")

print("Wrote pred.txt:", pred_path, "rows:", len(dev))
if n_empty:
    print(f"  {n_empty} empty predictions replaced with {EMPTY_PRED_PLACEHOLDER!r}")

Wrote pred.txt: /home/sagemaker-user/text2sql_outputs/spider_eval/pred.txt rows: 1034


In [19]:
import os, shlex, subprocess, sys
HOME = os.path.expanduser("~")
EVAL_PY = os.path.join(HOME, "spider_eval", "evaluation.py")
TABLES_JSON = os.path.join(HOME, "text2sql_data", "tables.json")
SPIDER_DB_DIR = os.path.join(HOME, "spider_data", "database")  # your DB path
EVAL_DIR = os.path.join(HOME, "text2sql_outputs", "spider_eval")
gold_path = os.path.join(EVAL_DIR, "gold.tsv")
pred_path = os.path.join(EVAL_DIR, "pred.txt")

# Execution accuracy needs the SQLite databases; without them only exact-match is run.
use_exec = os.path.isdir(SPIDER_DB_DIR)
etype = "all" if use_exec else "match"


def _count_nonblank_lines(path):
    """Return how many lines of the UTF-8 text file *path* are not whitespace-only."""
    with open(path, "r", encoding="utf-8") as fh:
        return sum(1 for line in fh if line.strip())


# gold.tsv and pred.txt are written one line per dev example in the same order; if the
# counts differ, predictions end up scored against the wrong gold queries.
n_gold, n_pred = _count_nonblank_lines(gold_path), _count_nonblank_lines(pred_path)
if n_gold != n_pred:
    print(f"WARNING: gold has {n_gold} non-blank lines but pred has {n_pred}; rows will be misaligned.\n")

# Build argv directly (no shell string -> shlex.split round-trip, so paths need no quoting)
# and use the kernel's own interpreter instead of whichever `python` is first on PATH.
cmd = [
    sys.executable, EVAL_PY,
    "--gold", gold_path,
    "--pred", pred_path,
    "--table", TABLES_JSON,
    "--etype", etype,
]
if use_exec:
    cmd += ["--db", SPIDER_DB_DIR]

print("Running:\n", shlex.join(cmd), "\n")
res = subprocess.run(cmd, capture_output=True, text=True)
print(res.stdout)
if res.stderr:
    print("STDERR:\n", res.stderr)
if res.returncode != 0:
    print(f"evaluation.py exited with status {res.returncode}")

Running:
 python "/home/sagemaker-user/spider_eval/evaluation.py" --gold "/home/sagemaker-user/text2sql_outputs/spider_eval/gold.tsv" --pred "/home/sagemaker-user/text2sql_outputs/spider_eval/pred.txt" --table "/home/sagemaker-user/text2sql_data/tables.json" --etype all --db "/home/sagemaker-user/spider_data/database" 

medium pred: SELECT name, country, age FROM singer ORDER BY age ASC
medium gold: SELECT name ,  country ,  age FROM singer ORDER BY age DESC

eval_err_num:1
medium pred: SELECT avg(Singer_ID), min(Serge_ID), max(Age) FROM singer WHERE Country = "France"
medium gold: SELECT avg(age) ,  min(age) ,  max(age) FROM singer WHERE country  =  'France'

eval_err_num:2
medium pred: SELECT avg(Singer_ID), min(Serge_ID), max(Age) FROM singer WHERE country = 'France'
medium gold: SELECT avg(age) ,  min(age) ,  max(age) FROM singer WHERE country  =  'France'

medium pred: SELECT Song_Name, Song_release_year FROM singer ORDER BY Age DESC LIMIT 1
medium gold: SELECT song_name ,  song_r