Setup:

```###### create env ######
conda create -n genparse python=3.10 # python3.10(.14) seems to be the sweet spot
conda activate genparse

###### install genparse ######
cd genparse
pip install --user -e . # build via setup.py
conda install python-graphviz # install graphviz executable, which is different from the graphviz package
sh run-tests # all tests should pass (with a few warnings)
conda install nb_conda_kernels # for Jupyter notebook support

###### install hfppl ######
git clone https://github.com/probcomp/hfppl
cd hfppl
pip install . # i had trouble with poetry
```


In [1]:
import sys
import os
import getpass

if getpass.getuser() == 'benjamin.lebrun':  # change to your user if you want to set these
    # @TIMO you may need to set this to your local genparse repo
    sys.path.append('/home/mila/b/benjamin.lebrun/genparse')
    # @TIMO also set your cache IF you run into disk quota issues
    os.environ['HF_HOME'] = os.path.join(os.environ['SCRATCH'], 'hf_cache')
    print('HF cache set; path updated')

HF cache set; path updated


In [2]:
from hfppl import Model, CachedCausalLM  # , smc_standard, smc_steer
from transformers import AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [12]:
MODEL_ID = 'codellama/CodeLlama-7b-Instruct-hf'
hfppl_llm = CachedCausalLM.from_pretrained(MODEL_ID, load_in_8bit=True)
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    use_fast=True,
    prefix_token=None,
    middle_token=None,
    suffix_token=None,
    eot_token=None,
    fill_token=None,
)

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
Loading checkpoint shards: 100%|██████████| 2/2 [00:24<00:00, 12.39s/it]


In [3]:
import genparse
from genparse.cfglm import BoolMaskCFGLM
from genparse.util import LarkStuff
from genparse.proposal import CharacterProposal
from genparse.inference import smc_standard
from genparse.lm import AsyncGreedilyTokenizedLLM

In [5]:
prompt = """
You have access to a political survey data table named "data", which includes the following columns:
- "age" (integer)
- "gender" ("male" or "female"),
- "year" (integer)
- "state_color" ("blue" or "red")
- "zipcode" (integer)
- "vote" ("democrat" or "republican") 
- "race_ethnicity" ("white", "black", or "latino").

Q: Write a SQL query that shows individuals' age and gender, for people over 50 years old.
A: SELECT age, gender FROM data WHERE age>50 </s>
Q: Write a SQL query that shows individuals' vote and zipcode, ordered from lowest to highest age.
A: SELECT vote, zipcode, age FROM data ORDER BY age ASC </s>

Q: Write a SQL query that returns white voters' average age for each state color.
A:"""

guide = BoolMaskCFGLM(
    LarkStuff(
        r"""
            start: WS? "SELECT" WS select_expr WS "FROM" WS from_expr [WS "WHERE" WS bool_condition] [WS "GROUP BY" WS var_list] [WS "ORDER BY" WS orderby_expr] WS EOS
            EOS: "</s>"
            select_expr: STAR | select_list
            bool_condition: bool_expr | "(" bool_condition WS "AND" WS bool_condition ")" | "(" bool_condition WS "OR" WS bool_condition ")"
            bool_expr: var "=" value | var ">" value | var "<" value
            from_expr: "data"
            orderby_expr: var_list WS "ASC" | var_list WS "DESC"
            select_list: select_var ("," WS select_var)*
            var_list: var ("," WS var)*
            select_var: var | "AVG(" var ")" | "MEDIAN(" var ")" | "COUNT(" var ")"
            var: "age" | "gender" | "year" | "state_color" | "zipcode" | "vote" | "race_ethnicity"
            value: NUMBER | "'red'" | "'blue'" | "'white'" | "'black'" | "'latino'" | "'republican'" | "'democrat'" | "'male'" | "'female'"
            STAR: "*"
            NUMBER: /\d+/
            WS: /[ ]/

        """
    ).char_cfg(0.99, ignore='[ ]?')
)

In [9]:
import numpy as np
import asyncio
import nest_asyncio

nest_asyncio.apply()

In [10]:
class SteeringModel(Model):
    def __init__(self, llm, guide, proposal, prompt, max_tokens, compare_time=False):
        super().__init__()
        self.llm = llm  # AsyncGreedilyTokenizedLM
        self.guide = guide  # PCFGLM
        self.prompt = prompt
        self.context = []
        self.proposal = proposal  # CharacterProposal
        self.max_tokens = max_tokens
        self.compare_time = compare_time

    async def step(self):
        (
            token,
            llm_prob,
            guide_prob,
            proposal_prob,
        ) = await self.proposal.sample_next_token(
            context=''.join(self.context),
            prompt=self.prompt,
            compare_time=self.compare_time,
        )
        self.context.append(token)
        self.weight += np.log(llm_prob) + np.log(guide_prob) - np.log(proposal_prob)
        self.max_tokens -= 1

        print(f"`{token}` : {''.join(self.context)} : {self.weight}")

        if token == self.llm.eos or self.max_tokens == 0 or token == genparse.EOS:
            self.finish()
            return

    def immutable_properties(self):
        return ['llm', 'prompt', 'guide', 'compare_token']

    def __repr__(self):
        return f"`{'' if not self.context else self.context[-1]}` : {''.join(self.context)} : {self.weight}"

In [13]:
MAX_TOKENS = 100
BATCH_SIZE = 80

hfppl_llm.batch_size = BATCH_SIZE
genparse_llm = AsyncGreedilyTokenizedLLM(hfppl_llm, tokenizer)
proposal = CharacterProposal(llm=genparse_llm, guide=guide)
steering_model = SteeringModel(
    genparse_llm, guide, proposal, prompt, MAX_TOKENS, compare_time=False
)

In [14]:
particles = asyncio.run(smc_standard(steering_model, n_particles=10))

` SELECT` :  SELECT : -0.19492441859449025
` SELECT` :  SELECT : -0.19492441859449025
` SELECT` :  SELECT : -0.19492441859449025
` SELECT` :  SELECT : -0.19492441859449025
` SELECT` :  SELECT : -0.19492441859449025
` SELECT` :  SELECT : -0.19492441859449025
` SELECT` :  SELECT : -0.19492441859449025
` SELECT` :  SELECT : -0.19492441859449025
` SELECT` :  SELECT : -0.19492441859449025
` SELECT` :  SELECT : -0.19492441859449025
` state` :  SELECT state : -0.2687663107216889
` state` :  SELECT state : -0.2687663107216889
` state` :  SELECT state : -0.2687663107216889
` state` :  SELECT state : -0.2687663107216889
` state` :  SELECT state : -0.2687663107216889
` state` :  SELECT state : -0.2687663107216889
` state` :  SELECT state : -0.2687663107216889
` state` :  SELECT state : -0.2687663107216889
` state` :  SELECT state : -0.2687663107216889
` state` :  SELECT state : -0.2687663107216889
`_` :  SELECT state_ : -24.32555761626368
`_` :  SELECT state_ : -24.32555761626368
`_` :  SELECT st

  self.weight += np.log(llm_prob) + np.log(guide_prob) - np.log(proposal_prob)


` ` :  SELECT state_color , state_color FROM data WHERE vote = 'republican' GROUP BY vote , zipcode , age ORDER BY vote , zipcode , age , age ASC   </s>  : -199.49593366515407
` ` :  SELECT state_color , state_color FROM data WHERE vote = 'republican' GROUP BY vote , zipcode , age ORDER BY vote , zipcode , age , age ASC   </s>  : -199.49593366515407
`▪` :  SELECT state_color , state_color FROM data WHERE vote = 'republican' GROUP BY vote , zipcode , age ORDER BY vote , zipcode , age , age ASC   </s>▪ : -inf
` ` :  SELECT state_color , state_color FROM data WHERE vote = 'republican' GROUP BY vote , zipcode , age ORDER BY vote , zipcode , age , age ASC   </s>  : -199.49593366515407
` ` :  SELECT state_color , state_color FROM data WHERE vote = 'republican' GROUP BY vote , zipcode , age ORDER BY vote , zipcode , age , age ASC   </s>  : -199.49593366515407
`▪` :  SELECT state_color , state_color FROM data WHERE vote = 'republican' GROUP BY vote , zipcode , age ORDER BY vote , zipcode , age

  arr -= vmax
