In [None]:
import asyncio
import os
import time
from pprint import pprint as pp
from typing import List

from genlm.backend import load_model_by_name
from genlm.bytes import BeamParams
from genlm.control import AWRS, ByteLLM, direct_token_sampler
from genlm.eval.core import run_evaluation
from genlm.eval.core.model import ModelOutput, ModelResponse
from genlm.eval.domains.spider.spider import (
    SpiderDataset,
    SpiderEvaluator,
    default_prompt_formatter,
)
from genlm.eval.domains.spider.table_column_potential import SpiderTableColumnVerifier
# Disable HF tokenizers' internal parallelism (the variable is set unconditionally).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# HF_TOKEN should be set via environment variable or `huggingface-cli login`.

# Data locations. Each can be overridden through an environment variable of the same
# name; the defaults point at the original Lightning studio layout.
_SPIDER_STUDIO_ROOT = os.environ.get("SPIDER_STUDIO_ROOT", "/teamspace/studios/this_studio")
SPIDER_SAMPLE_DIR = os.environ.get(
    "SPIDER_SAMPLE_DIR",
    os.path.join(_SPIDER_STUDIO_ROOT, "genlm-eval/assets/spider/spider_sample"),
)
SPIDER_DATA_DIR = os.environ.get(
    "SPIDER_DATA_DIR",
    os.path.join(_SPIDER_STUDIO_ROOT, "spider_data"),
)
SPIDER_GRAMMARS = os.environ.get(
    "SPIDER_GRAMMARS",
    os.path.join(_SPIDER_STUDIO_ROOT, "genlm-eval/assets/spider/grammars.json"),
)


  from .autonotebook import tqdm as notebook_tqdm


INFO 09-24 19:05:07 [__init__.py:235] Automatically detected platform cuda.


In [21]:
def build_bytelm(
    model_name: str = "meta-llama/Llama-3.2-1B-Instruct",
    backend: str = "hf",
    eos_tokens=None,
):
    """Load a language model and wrap it as a byte-level ``ByteLLM``.

    Args:
        model_name: Model identifier passed to ``load_model_by_name``
            (e.g. ``"gpt2"`` for a quick local test).
        backend: Backend name passed to ``load_model_by_name``.
        eos_tokens: Byte strings treated as end-of-sequence by the byte beam.
            Defaults to newline, double newline and ``b"<|eot_id|>"``; the last
            one is Llama-3-specific, so pass your own list when switching models.

    Returns:
        A ``ByteLLM`` built from the loaded model and a ``BeamParams`` with
        ``K=1``, ``prune_threshold=0.0`` and ``heal=True``.
    """
    llm = load_model_by_name(model_name, backend=backend)
    if eos_tokens is None:
        eos_tokens = [b"\n", b"\n\n", b"<|eot_id|>"]
    beam_params = BeamParams(
        K=1,  # NOTE(review): presumably the beam width — confirm in genlm.bytes docs
        prune_threshold=0.0,
        eos_tokens=list(eos_tokens),  # copy so a caller's list is never aliased
        heal=True,
    )
    return ByteLLM(llm, beam_params)

In [22]:
BYTE_LLM = build_bytelm()

# Full Spider split and the small sample split share one grammar file and an
# empty few-shot id list (presumably zero-shot prompting). The generator is
# consumed in order: full split first, then the sample split.
dataset, sampleset = (
    SpiderDataset.from_spider_dir(
        spider_dir,
        grammar_json_path=SPIDER_GRAMMARS,
        few_shot_example_ids=[],
    )
    for spider_dir in (SPIDER_DATA_DIR, SPIDER_SAMPLE_DIR)
)


In [23]:
from itertools import islice

def show_first_two(ds, label):
    """Print schema name, utterance and gold SQL for the first two items of ``ds``.

    Args:
        ds: Iterable of Spider instances (each must expose ``schema_name``,
            ``utterance`` and ``gold``).
        label: Tag shown in front of every entry, e.g. ``"full"`` or ``"sample"``.
    """
    # Pull both items before printing anything.
    head = list(islice(ds, 2))
    for idx, inst in enumerate(head):
        print(
            f"[{label}] #{idx} schema={inst.schema_name}\n"
            f"utterance: {inst.utterance}\n"
            f"gold: {inst.gold}\n"
        )

# Preview both splits, full split first.
for _split_ds, _split_label in ((dataset, "full"), (sampleset, "sample")):
    show_first_two(_split_ds, _split_label)

[full] #0 schema=concert_singer
utterance: How many singers do we have?
gold: SELECT count(*) FROM singer

[full] #1 schema=concert_singer
utterance: What is the total number of singers?
gold: SELECT count(*) FROM singer

[sample] #0 schema=concert_singer
utterance: How many singers do we have?
gold: SELECT count(*) FROM singer

[sample] #1 schema=concert_singer
utterance: What is the total number of singers?
gold: SELECT count(*) FROM singer



In [24]:
# First instance of each split (fetched full split first, then sample split).
first_full, first_sample = (next(iter(split)) for split in (dataset, sampleset))

for _peek in (first_full, first_sample):
    print(_peek.schema_name, _peek.utterance, _peek.gold)

# Plain-text (non-chat) prompt token ids for both instances.
full_prompt_ids, sample_prompt_ids = (
    default_prompt_formatter(
        BYTE_LLM.llm.tokenizer,
        inst,
        use_chat_format=False,
    )
    for inst in (first_full, first_sample)
)

print("full prompt: ", BYTE_LLM.llm.tokenizer.decode(full_prompt_ids))
print("*" * 80)
print("sample prompt: ", BYTE_LLM.llm.tokenizer.decode(sample_prompt_ids))


concert_singer How many singers do we have? SELECT count(*) FROM singer
concert_singer How many singers do we have? SELECT count(*) FROM singer
full prompt:  <|begin_of_text|>You are a coding assistant helping an analyst answer questions over business data in SQL. More specifically, the analyst provides you a database schema (tables in the database along with their column names and types) and asks a question about the data that can be solved by issuing a SQL query to the database. In response, you write the SQL statement that answers the question. You do not provide any commentary or explanation of what the code does, just the SQL statement ending in a semicolon.



Here is a database schema:
stadium
* Stadium_ID (number): stadium id
* Location (text): location
* Name (text): name
* Capacity (number): capacity
* Highest (number): highest
* Lowest (number): lowest
* Average (number): average

singer
* Singer_ID (number): singer id
* Name (text): name
* Country (text): country
* Song_Nam

In [25]:
BYTE_LLM.set_prompt_from_str(BYTE_LLM.llm.tokenizer.decode(full_prompt_ids))
condition = SpiderTableColumnVerifier(grammar=first_full.lark_grammar, tables=first_full.tables).coerce(BYTE_LLM, f=b"".join)
sampler = AWRS(BYTE_LLM, condition)
sequences = await sampler.smc(
        n_particles=5,
        ess_threshold=0.9,
        max_tokens=70,
        verbosity=1,    
        )

print("Decoded posterior:", sequences.decoded_posterior)



Task was destroyed but it is pending!
task: <Task cancelling name='Task-6762' coro=<AsyncTokenByteTrie._background_loop() running at /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/genlm/bytes/trie.py:570> wait_for=<Future cancelled>>
Task was destroyed but it is pending!
task: <Task cancelling name='Task-6763' coro=<AsyncTokenByteTrie._background_loop() running at /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/genlm/bytes/trie.py:570> wait_for=<Future cancelled>>
Task was destroyed but it is pending!
task: <Task cancelling name='Task-6764' coro=<AsyncTokenByteTrie._background_loop() running at /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/genlm/bytes/trie.py:570> wait_for=<Future cancelled>>
Task was destroyed but it is pending!
task: <Task cancelling name='Task-6765' coro=<AsyncTokenByteTrie._background_loop() running at /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/genlm/bytes/trie.py:570> wait_for=<Fut

0.00:	[0;35m[[0m\n[0;35m][0m
0.00:	[0;35m[[0m\n[0;35m][0m
0.00:	[0;35m[[0m\n[0;35m][0m
0.00:	[0;35m[[0m␣[0;35m][0m
0.00:	[0;35m[[0m␣[0;35m][0m
0.00:	[0;35m[[0m\n[0;35m|[0m\n[0;35m][0m
0.00:	[0;35m[[0m\n[0;35m|[0m`[0;35m][0m
0.00:	[0;35m[[0m\n[0;35m|[0m\n[0;35m][0m
0.00:	[0;35m[[0m␣[0;35m|[0m␣[0;35m][0m
0.00:	[0;35m[[0m␣[0;35m|[0m\n[0;35m][0m
0.00:	[0;35m[[0m\n[0;35m|[0m\n[0;35m|[0m#[0;35m][0m
0.00:	[0;35m[[0m\n[0;35m|[0m\n[0;35m|[0m#[0;35m][0m
0.00:	[0;35m[[0m\n[0;35m|[0m`[0;35m|[0m`[0;35m][0m
0.00:	[0;35m[[0m␣[0;35m|[0m␣[0;35m|[0m;[0;35m][0m
0.00:	[0;35m[[0m␣[0;35m|[0m\n[0;35m|[0m1[0;35m][0m
0.00:	[0;35m[[0m\n[0;35m|[0m\n[0;35m|[0m#[0;35m|[0m␣[0;35m][0m
0.00:	[0;35m[[0m\n[0;35m|[0m\n[0;35m|[0m#[0;35m|[0m#[0;35m][0m
0.00:	[0;35m[[0m\n[0;35m|[0m`[0;35m|[0m`[0;35m|[0m`[0;35m][0m
0.00:	[0;35m[[0m␣[0;35m|[0m␣[0;35m|[0m;[0;35m|[0m␣[0;35m][0m
0.00:	[0;35m[[0m␣[

10


0.00:	[0;35m[[0m\n[0;35m|[0m\n[0;35m|[0m#[0;35m|[0m␣[0;35m|[0mC[0;35m|[0mr[0;35m|[0me[0;35m|[0ma[0;35m][0m
-3.69:	[0;35m[[0m\n[0;35m|[0m`[0;35m|[0m`[0;35m|[0m`[0;35m|[0ms[0;35m|[0mq[0;35m|[0ml[0;35m|[0m␣[0;35m][0m
0.00:	[0;35m[[0m\n[0;35m|[0m\n[0;35m|[0m#[0;35m|[0m#[0;35m|[0m#[0;35m|[0m␣[0;35m|[0mR[0;35m|[0ma[0;35m][0m
0.00:	[0;35m[[0m␣[0;35m|[0m␣[0;35m|[0m;[0;35m|[0m␣[0;35m|[0mS[0;35m|[0mE[0;35m|[0mL[0;35m|[0mE[0;35m][0m
-0.19:	[0;35m[[0m␣[0;35m|[0m\n[0;35m|[0m1[0;35m|[0m0[0;35m|[0m␣[0;35m|[0m7[0;35m|[0m␣[0;35m|[0mA[0;35m][0m
-0.26:	[0;35m[[0m\n[0;35m|[0m\n[0;35m|[0m#[0;35m|[0m#[0;35m|[0m#[0;35m|[0m␣[0;35m|[0mR[0;35m|[0ma[0;35m|[0mt[0;35m][0m
-0.26:	[0;35m[[0m\n[0;35m|[0m\n[0;35m|[0m#[0;35m|[0m␣[0;35m|[0mC[0;35m|[0mr[0;35m|[0me[0;35m|[0ma[0;35m|[0mt[0;35m][0m
-0.26:	[0;35m[[0m\n[0;35m|[0m\n[0;35m|[0m#[0;35m|[0m␣[0;35m|[0mC[0;35m|[0mr[0;35m|[0m

```sql


-0.26:	[0;35m[[0m\n[0;35m|[0m\n[0;35m|[0m#[0;35m|[0m␣[0;35m|[0mC[0;35m|[0mr[0;35m|[0me[0;35m|[0ma[0;35m|[0mt[0;35m|[0me[0;35m|[0m␣[0;35m|[0ma[0;35m][0m
-0.26:	[0;35m[[0m\n[0;35m|[0m\n[0;35m|[0m#[0;35m|[0m␣[0;35m|[0mC[0;35m|[0mr[0;35m|[0me[0;35m|[0ma[0;35m|[0mt[0;35m|[0me[0;35m|[0m␣[0;35m|[0mt[0;35m][0m
-0.26:	[0;35m[[0m\n[0;35m|[0m`[0;35m|[0m`[0;35m|[0m`[0;35m|[0ms[0;35m|[0mq[0;35m|[0ml[0;35m|[0m␣[0;35m|[0m\n[0;35m|[0mS[0;35m|[0mE[0;35m|[0mL[0;35m][0m
-0.26:	[0;35m[[0m\n[0;35m|[0m`[0;35m|[0m`[0;35m|[0m`[0;35m|[0ms[0;35m|[0mq[0;35m|[0ml[0;35m|[0m␣[0;35m|[0m\n[0;35m|[0mS[0;35m|[0mE[0;35m|[0mL[0;35m][0m
-0.26:	[0;35m[[0m\n[0;35m|[0m\n[0;35m|[0m#[0;35m|[0m#[0;35m|[0m#[0;35m|[0m␣[0;35m|[0mR[0;35m|[0ma[0;35m|[0mt[0;35m|[0mi[0;35m|[0mn[0;35m|[0mg[0;35m|[0m␣[0;35m][0m
-0.26:	[0;35m[[0m\n[0;35m|[0m\n[0;35m|[0m#[0;35m|[0m␣[0;35m|[0mC[0;35m|[0mr[0;35m|[


# Create table

### Rating scale


-0.26:	[0;35m[[0m\n[0;35m|[0m\n[0;35m|[0m#[0;35m|[0m␣[0;35m|[0mC[0;35m|[0mr[0;35m|[0me[0;35m|[0ma[0;35m|[0mt[0;35m|[0me[0;35m|[0m␣[0;35m|[0ma[0;35m|[0mn[0;35m|[0m␣[0;35m|[0ma[0;35m|[0mg[0;35m|[0mg[0;35m|[0mr[0;35m|[0me[0;35m][0m
-0.26:	[0;35m[[0m\n[0;35m|[0m`[0;35m|[0m`[0;35m|[0m`[0;35m|[0ms[0;35m|[0mq[0;35m|[0ml[0;35m|[0m␣[0;35m|[0m\n[0;35m|[0mS[0;35m|[0mE[0;35m|[0mL[0;35m|[0mE[0;35m|[0mC[0;35m|[0mT[0;35m|[0m␣[0;35m|[0mC[0;35m|[0mO[0;35m|[0mU[0;35m][0m
-0.26:	[0;35m[[0m\n[0;35m|[0m`[0;35m|[0m`[0;35m|[0m`[0;35m|[0ms[0;35m|[0mq[0;35m|[0ml[0;35m|[0m␣[0;35m|[0m\n[0;35m|[0mS[0;35m|[0mE[0;35m|[0mL[0;35m|[0mE[0;35m|[0mC[0;35m|[0mT[0;35m|[0m␣[0;35m|[0mC[0;35m|[0mO[0;35m|[0mU[0;35m][0m
-0.95:	[0;35m[[0m\n[0;35m|[0m\n[0;35m|[0m#[0;35m|[0m␣[0;35m|[0mC[0;35m|[0mr[0;35m|[0me[0;35m|[0ma[0;35m|[0mt[0;35m|[0me[0;35m|[0m␣[0;35m|[0mt[0;35m|[0ma[0;35m|[0mb[


# Create table aliases


-0.49:	[0;35m[[0m\n[0;35m|[0m\n[0;35m|[0m#[0;35m|[0m␣[0;35m|[0mC[0;35m|[0mr[0;35m|[0me[0;35m|[0ma[0;35m|[0mt[0;35m|[0me[0;35m|[0m␣[0;35m|[0mt[0;35m|[0ma[0;35m|[0mb[0;35m|[0ml[0;35m|[0me[0;35m|[0m␣[0;35m|[0ma[0;35m|[0ml[0;35m|[0mi[0;35m|[0ma[0;35m|[0ms[0;35m|[0me[0;35m|[0ms[0;35m|[0m␣[0;35m|[0mf[0;35m|[0mo[0;35m][0m
-0.49:	[0;35m[[0m\n[0;35m|[0m\n[0;35m|[0m#[0;35m|[0m␣[0;35m|[0mC[0;35m|[0mr[0;35m|[0me[0;35m|[0ma[0;35m|[0mt[0;35m|[0me[0;35m|[0m␣[0;35m|[0ma[0;35m|[0mn[0;35m|[0m␣[0;35m|[0ma[0;35m|[0mg[0;35m|[0mg[0;35m|[0mr[0;35m|[0me[0;35m|[0mg[0;35m|[0ma[0;35m|[0mt[0;35m|[0me[0;35m|[0md[0;35m|[0m␣[0;35m|[0ms[0;35m|[0mu[0;35m][0m
-0.49:	[0;35m[[0m\n[0;35m|[0m`[0;35m|[0m`[0;35m|[0m`[0;35m|[0ms[0;35m|[0mq[0;35m|[0ml[0;35m|[0m␣[0;35m|[0m\n[0;35m|[0mS[0;35m|[0mE[0;35m|[0mL[0;35m|[0mE[0;35m|[0mC[0;35m|[0mT[0;35m|[0m␣[0;35m|[0mC[0;35m|[0mO[0;35m|[0

```sql 
SELECT COUNT(*) FROM singer


-0.49:	[0;35m[[0m\n[0;35m|[0m\n[0;35m|[0m#[0;35m|[0m␣[0;35m|[0mC[0;35m|[0mr[0;35m|[0me[0;35m|[0ma[0;35m|[0mt[0;35m|[0me[0;35m|[0m␣[0;35m|[0ma[0;35m|[0mn[0;35m|[0m␣[0;35m|[0ma[0;35m|[0mg[0;35m|[0mg[0;35m|[0mr[0;35m|[0me[0;35m|[0mg[0;35m|[0ma[0;35m|[0mt[0;35m|[0me[0;35m|[0md[0;35m|[0m␣[0;35m|[0ms[0;35m|[0mu[0;35m|[0mm[0;35m|[0m␣[0;35m|[0mo[0;35m|[0mf[0;35m|[0m␣[0;35m|[0ms[0;35m|[0mi[0;35m|[0mn[0;35m|[0mg[0;35m|[0me[0;35m|[0mr[0;35m|[0ms[0;35m][0m
-0.49:	[0;35m[[0m\n[0;35m|[0m\n[0;35m|[0m#[0;35m|[0m␣[0;35m|[0mC[0;35m|[0mr[0;35m|[0me[0;35m|[0ma[0;35m|[0mt[0;35m|[0me[0;35m|[0m␣[0;35m|[0mt[0;35m|[0ma[0;35m|[0mb[0;35m|[0ml[0;35m|[0me[0;35m|[0m␣[0;35m|[0ma[0;35m|[0ml[0;35m|[0mi[0;35m|[0ma[0;35m|[0ms[0;35m|[0me[0;35m|[0ms[0;35m|[0m␣[0;35m|[0mf[0;35m|[0mo[0;35m|[0mr[0;35m|[0m␣[0;35m|[0ms[0;35m|[0mt[0;35m|[0ma[0;35m|[0md[0;35m|[0mi[0;35m|[0mu[0;35m

```sql 
SELECT COUNT(*) FROM singer_in_concert

# Create an aggregated sum of singers by location


-1.31:	[0;35m[[0m\n[0;35m|[0m\n[0;35m|[0m#[0;35m|[0m␣[0;35m|[0mC[0;35m|[0mr[0;35m|[0me[0;35m|[0ma[0;35m|[0mt[0;35m|[0me[0;35m|[0m␣[0;35m|[0ma[0;35m|[0mn[0;35m|[0m␣[0;35m|[0ma[0;35m|[0mg[0;35m|[0mg[0;35m|[0mr[0;35m|[0me[0;35m|[0mg[0;35m|[0ma[0;35m|[0mt[0;35m|[0me[0;35m|[0md[0;35m|[0m␣[0;35m|[0ms[0;35m|[0mu[0;35m|[0mm[0;35m|[0m␣[0;35m|[0mo[0;35m|[0mf[0;35m|[0m␣[0;35m|[0ms[0;35m|[0mi[0;35m|[0mn[0;35m|[0mg[0;35m|[0me[0;35m|[0mr[0;35m|[0ms[0;35m|[0m␣[0;35m|[0mb[0;35m|[0my[0;35m|[0m␣[0;35m|[0ml[0;35m|[0mo[0;35m|[0mc[0;35m|[0ma[0;35m|[0mt[0;35m|[0mi[0;35m|[0mo[0;35m|[0mn[0;35m|[0m␣[0;35m][0m
-0.49:	[0;35m[[0m\n[0;35m|[0m\n[0;35m|[0m#[0;35m|[0m␣[0;35m|[0mC[0;35m|[0mr[0;35m|[0me[0;35m|[0ma[0;35m|[0mt[0;35m|[0me[0;35m|[0m␣[0;35m|[0ma[0;35m|[0mn[0;35m|[0m␣[0;35m|[0ma[0;35m|[0mg[0;35m|[0mg[0;35m|[0mr[0;35m|[0me[0;35m|[0mg[0;35m|[0ma[0;35m|[0mt[0;35m


# Create table aliases for stadium, singer, concert_name


-1.28:	[0;35m[[0m\n[0;35m|[0m\n[0;35m|[0m#[0;35m|[0m␣[0;35m|[0mC[0;35m|[0mr[0;35m|[0me[0;35m|[0ma[0;35m|[0mt[0;35m|[0me[0;35m|[0m␣[0;35m|[0mt[0;35m|[0ma[0;35m|[0mb[0;35m|[0ml[0;35m|[0me[0;35m|[0m␣[0;35m|[0ma[0;35m|[0ml[0;35m|[0mi[0;35m|[0ma[0;35m|[0ms[0;35m|[0me[0;35m|[0ms[0;35m|[0m␣[0;35m|[0mf[0;35m|[0mo[0;35m|[0mr[0;35m|[0m␣[0;35m|[0ms[0;35m|[0mt[0;35m|[0ma[0;35m|[0md[0;35m|[0mi[0;35m|[0mu[0;35m|[0mm[0;35m|[0m,[0;35m|[0m␣[0;35m|[0ms[0;35m|[0mi[0;35m|[0mn[0;35m|[0mg[0;35m|[0me[0;35m|[0mr[0;35m|[0m,[0;35m|[0m␣[0;35m|[0mc[0;35m|[0mo[0;35m|[0mn[0;35m|[0mc[0;35m|[0me[0;35m|[0mr[0;35m|[0mt[0;35m|[0m_[0;35m|[0mn[0;35m|[0ma[0;35m|[0mm[0;35m|[0me[0;35m|[0m,[0;35m][0m
-0.79:	[0;35m[[0m\n[0;35m|[0m\n[0;35m|[0m#[0;35m|[0m␣[0;35m|[0mC[0;35m|[0mr[0;35m|[0me[0;35m|[0ma[0;35m|[0mt[0;35m|[0me[0;35m|[0m␣[0;35m|[0ma[0;35m|[0mn[0;35m|[0m␣[0;35m|[0ma[0;35m

In [None]:
def _print_instance_debug(instance, prompt_text: str) -> None:
    """Print the instance's few-shot examples, the rendered prompt and the gold SQL."""
    print("\n--- Few-shot examples ---")
    if instance.few_shot_examples:
        for i, (inp, out) in enumerate(instance.few_shot_examples):
            # Flatten newlines and truncate so each example fits on one short line.
            preview = inp.replace("\n", " ")
            if len(preview) > 200:
                preview = preview[:200] + "..."
            print(f"[{i}] Input: {preview}")
            print(f"    Output: {out}")
    else:
        print("(none)")

    print("\n--- Prompt ---\n" + prompt_text)
    print("\n--- Gold SQL ---\n" + instance.gold)


async def spider_model_adaptor(
    instance,
    output_dir: str,
    replicate: int,
    *,
    n_particles: int = 5,
    ess_threshold: float = 0.9,
    max_tokens: int = 70,
    verbosity: int = 1,
) -> ModelOutput:
    """Run constrained SMC with ``BYTE_LLM`` on one Spider instance.

    The prompt is built with ``default_prompt_formatter`` (non-chat format) and
    sampling is constrained by ``SpiderTableColumnVerifier`` via the AWRS sampler.
    ``BYTE_LLM.cleanup()`` always runs, even if prompting or sampling raises.

    Args:
        instance: Spider instance; reads ``few_shot_examples``, ``gold``,
            ``lark_grammar`` and ``tables``.
        output_dir: Unused here. NOTE(review): presumably required by the
            adaptor signature ``run_evaluation`` calls — confirm.
        replicate: Unused here; same note as ``output_dir``.
        n_particles, ess_threshold, max_tokens, verbosity: Forwarded to
            ``sampler.smc``; defaults match the previously hard-coded values.

    Returns:
        ``ModelOutput`` with one ``ModelResponse`` per entry of
        ``sequences.decoded_posterior`` (weight = that entry's value as float)
        and ``runtime_seconds`` set to the wall-clock time of this call.
    """
    start = time.perf_counter()
    tokenizer = BYTE_LLM.llm.tokenizer
    prompt_text = tokenizer.decode(
        default_prompt_formatter(tokenizer, instance, use_chat_format=False)
    )

    try:
        BYTE_LLM.set_prompt_from_str(prompt_text)
        _print_instance_debug(instance, prompt_text)

        condition = SpiderTableColumnVerifier(
            grammar=instance.lark_grammar, tables=instance.tables
        ).coerce(BYTE_LLM, f=b"".join)
        sampler = AWRS(BYTE_LLM, condition)

        sequences = await sampler.smc(
            n_particles=n_particles,
            ess_threshold=ess_threshold,
            max_tokens=max_tokens,
            verbosity=verbosity,
        )

        posterior = sequences.decoded_posterior
        print("Decoded posterior:", posterior)

        responses: List[ModelResponse] = [
            ModelResponse(response=seq, weight=float(prob))
            for seq, prob in posterior.items()
        ]
    finally:
        # Release the LLM's async resources even on failure; otherwise background
        # tasks are left pending across evaluation instances.
        await BYTE_LLM.cleanup()

    return ModelOutput(
        responses=responses,
        runtime_seconds=time.perf_counter() - start,
    )