In [2]:
import os
import sys
from pathlib import Path

from genlm_control import InferenceEngine, PromptedLLM, BoolCFG, eager_token_sampler

# Put the directory two levels above the notebook's working directory at the
# front of sys.path, so the local `benchmark` package imported next resolves.
# NOTE(review): presumably this is the genlm-control checkout root — this
# only holds when Jupyter was started from the notebook's own directory.
genlm_control_path = Path.cwd().parent.parent
_path_entry = str(genlm_control_path)
if not any(existing == _path_entry for existing in sys.path):
    sys.path.insert(0, _path_entry)

from benchmark.text_to_sql.run_inference import spider_setup  # noqa: E402

In [3]:
# Data locations (relative paths, resolved against the working directory).
raw_spider_dir, grammar_dir = "data/spider_data", "data/grammars"
# Model identifier and the backend used to load it.
model_name, lm_backend = "meta-llama/Llama-3.2-1B-Instruct", "hf"

In [4]:
# Load the Spider data via the benchmark helper, keeping the dev examples and
# the prompt formatter; the middle return value is discarded.
# NOTE(review): assumes spider_setup returns exactly three values — the tuple
# unpacking raises ValueError otherwise. Binding `_` here also shadows
# IPython's "last output" variable for the rest of the session.
dev_data, _, prompt_formatter = spider_setup(raw_spider_dir)

In [None]:
# Build the prompted language model from the configured model name and
# backend; its prompt (`llm.prompt_ids`) is set per example in a later cell.
llm = PromptedLLM.from_name(model_name, backend=lm_backend)



In [6]:
# Select the first dev example and install its tokenized chat prompt on the
# LLM so that generation is conditioned on it.
# NOTE(review): presumably format_openai returns OpenAI-style chat messages —
# confirm against the prompt formatter.
datum = dev_data[0]
chat_messages = prompt_formatter.format_openai(datum)
llm.prompt_ids = llm.model.tokenizer.apply_chat_template(
    chat_messages,
    tokenize=True,
    add_generation_prompt=True,
)

In [7]:
# Load the Lark grammar for this example's database schema and compile it
# into a BoolCFG. The `with` block guarantees the file handle is closed, and
# the explicit encoding keeps decoding identical across platforms.
grammar_path = os.path.join(grammar_dir, f"{datum.schema_name}.lark")
with open(grammar_path, "r", encoding="utf-8") as grammar_file:
    grammar = grammar_file.read()
bool_cfg = BoolCFG.from_lark(grammar)

In [8]:
# Build an eager token sampler from the LLM and the grammar potential
# (presumably restricting generation to grammar-valid tokens — confirm).
# NOTE(review): `sampler` is reassigned to a GumbelMaxAdaptiveRejectionSampler
# in a later cell, before inference runs, so this sampler is not the one the
# final inference cell uses when cells are run in order.
sampler = eager_token_sampler(llm, bool_cfg)

  ).to_sparse_csr()


In [10]:
from genlm_control.experimental.vegas import GumbelMaxAdaptiveRejectionSampler

# Coerce the grammar to the LLM's vocabulary: `b"".join` concatenates a
# sequence of byte tokens into one bytes string.
# NOTE(review): presumably coerce() applies `f` to LLM token sequences before
# the CFG scores them — confirm in the genlm_control docs.
byte_joined_cfg = bool_cfg.coerce(llm, f=b"".join)
# This replaces the eager sampler bound to `sampler` earlier.
sampler = GumbelMaxAdaptiveRejectionSampler(llm, byte_joined_cfg)

In [None]:
# Re-select the first dev example and re-install its tokenized chat prompt
# before running inference with the rejection sampler.
# TODO: this repeats the earlier prompt-setup cell verbatim; consider a small
# helper function so the two cannot drift apart.
datum = dev_data[0]
chat_messages = prompt_formatter.format_openai(datum)
llm.prompt_ids = llm.model.tokenizer.apply_chat_template(
    chat_messages,
    tokenize=True,
    add_generation_prompt=True,
)

In [12]:
sequences = await InferenceEngine(sampler)(
    n_particles=3,
    max_tokens=100,
    ess_threshold=0.9,
    verbosity=1,
    # json_path=os.path.join('results/test', f'0.json')
)

-0.00:	[0;35m[[0mb'SELECT'[0;35m][0m
-0.00:	[0;35m[[0mb'SELECT'[0;35m][0m
-0.00:	[0;35m[[0mb'SELECT'[0;35m][0m
-0.00:	[0;35m[[0mb'SELECT'[0;35m|[0mb' COUNT'[0;35m][0m
-0.00:	[0;35m[[0mb'SELECT'[0;35m|[0mb' COUNT'[0;35m][0m
-0.00:	[0;35m[[0mb'SELECT'[0;35m|[0mb' COUNT'[0;35m][0m
-0.00:	[0;35m[[0mb'SELECT'[0;35m|[0mb' COUNT'[0;35m|[0mb'(S'[0;35m][0m
-0.00:	[0;35m[[0mb'SELECT'[0;35m|[0mb' COUNT'[0;35m|[0mb'(S'[0;35m][0m
-0.00:	[0;35m[[0mb'SELECT'[0;35m|[0mb' COUNT'[0;35m|[0mb'(S'[0;35m][0m
-0.00:	[0;35m[[0mb'SELECT'[0;35m|[0mb' COUNT'[0;35m|[0mb'(S'[0;35m|[0mb'inger'[0;35m][0m
-0.00:	[0;35m[[0mb'SELECT'[0;35m|[0mb' COUNT'[0;35m|[0mb'(S'[0;35m|[0mb'inger'[0;35m][0m
-0.00:	[0;35m[[0mb'SELECT'[0;35m|[0mb' COUNT'[0;35m|[0mb'(S'[0;35m|[0mb'inger'[0;35m][0m
-0.00:	[0;35m[[0mb'SELECT'[0;35m|[0mb' COUNT'[0;35m|[0mb'(S'[0;35m|[0mb'inger'[0;35m|[0mb'_ID'[0;35m][0m
-0.00:	[0;35m[[0mb'SELECT'[0;35m|[0m