Make sure to run this in your `genparse` conda environment.

In [1]:
import sys
import os
import getpass

if getpass.getuser() == "benjamin.lebrun": # change to your user if you want to set these
    # @TIMO you may need to set this to your local genparse repo 
    sys.path.append("/home/mila/b/benjamin.lebrun/genparse")
    # @TIMO also set your cache IF you run into disk quota issues
    os.environ["HF_HOME"] = os.path.join(os.environ["SCRATCH"], "hf_cache")
    print("HF cache set; path updated")

HF cache set; path updated


In [2]:
from hfppl import Model, CachedCausalLM, LMContext #, smc_standard, smc_steer
from transformers import AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
MODEL_ID = "codellama/CodeLlama-7b-Instruct-hf"
hfppl_llm = CachedCausalLM.from_pretrained(MODEL_ID, load_in_8bit=True)
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID, 
    use_fast=True,
    prefix_token=None, 
    middle_token=None, 
    suffix_token=None, 
    eot_token=None, 
    fill_token=None
)

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
Loading checkpoint shards: 100%|██████████| 2/2 [00:22<00:00, 11.28s/it]


In [4]:
import genparse
from genparse.cfglm import EarleyBoolMaskCFGLM 
from genparse.util import LarkStuff
from genparse import EOS, Float
from genparse.proposal import CharacterProposal
from genparse.inference import smc_standard, smc_steer

In [5]:
prompt = """
You have access to a political survey data table named "data", which includes the following columns:
- "age" (integer)
- "gender" ("male" or "female"),
- "year" (integer)
- "state_color" ("blue" or "red")
- "zipcode" (integer)
- "vote" ("democrat" or "republican") 
- "race_ethnicity" ("white", "black", or "latino").

Q: Write a SQL query that shows individuals' age and gender, for people over 50 years old.
A: SELECT age, gender FROM data WHERE age>50 </s>
Q: Write a SQL query that shows individuals' vote and zipcode, ordered from lowest to highest age.
A: SELECT vote, zipcode, age FROM data ORDER BY age ASC </s>

Q: Write a SQL query that returns white voters' average age for each state color.
A:"""

guide = EarleyBoolMaskCFGLM(
    LarkStuff(
        r"""
            start: WS? "SELECT" WS select_expr WS "FROM" WS from_expr [WS "WHERE" WS bool_condition] [WS "GROUP BY" WS var_list] [WS "ORDER BY" WS orderby_expr] WS EOS
            EOS: "</s>"
            select_expr: STAR | select_list
            bool_condition: bool_expr | "(" bool_condition WS "AND" WS bool_condition ")" | "(" bool_condition WS "OR" WS bool_condition ")"
            bool_expr: var "=" value | var ">" value | var "<" value
            from_expr: "data"
            orderby_expr: var_list WS "ASC" | var_list WS "DESC"
            select_list: select_var ("," WS select_var)*
            var_list: var ("," WS var)*
            select_var: var | "AVG(" var ")" | "MEDIAN(" var ")" | "COUNT(" var ")"
            var: "age" | "gender" | "year" | "state_color" | "zipcode" | "vote" | "race_ethnicity"
            value: NUMBER | "'red'" | "'blue'" | "'white'" | "'black'" | "'latino'" | "'republican'" | "'democrat'" | "'male'" | "'female'"
            STAR: "*"
            NUMBER: /\d+/
            WS: /[ ]/

        """
    ).char_cfg(.99, ignore='[ ]?')
)

In [6]:
import numpy as np
import asyncio
import nest_asyncio
nest_asyncio.apply()

In [10]:
from genparse import Float

class SteeringModel(Model):
    def __init__(self, llm, guide, proposal, prompt, max_tokens, compare_time=False):
        super().__init__()
        self.llm = llm # GreedilyTokenizedLM
        self.guide = guide # PCFGLM
        self.prompt = prompt
        self.context = []
        self.proposal = proposal # CharacterProposal
        self.max_tokens = max_tokens
        self.compare_time = compare_time

    async def step(self):
        (token, llm_prob, guide_prob, proposal_prob) = await self.proposal.sample_next_token(
            context=''.join(self.context), 
            prompt=self.prompt, 
            compare_time=self.compare_time
        )
        self.context.append(token)
        self.weight += np.log(llm_prob) + np.log(guide_prob) - np.log(proposal_prob)
        self.max_tokens -= 1

        print(f"`{token}` : {''.join(self.context)} : {self.weight}")

        if token == self.llm.eos or self.max_tokens == 0 or token == genparse.EOS:
            self.finish()
            return
        
    def immutable_properties(self):
        return ['llm', 'prompt', 'guide', 'compare_token']
    
    def __repr__(self):
        return f"`{'' if not self.context else self.context[-1]}` : {''.join(self.context)} : {self.weight}"

class GreedilyTokenizedLLM:
    def __init__(self, llm, tokenizer):
        self.tokenizer = tokenizer
        self._model = llm # hfppl Model
        self._decode = [self.tokenizer.decode([i]) for i in range(self.tokenizer.vocab_size)]
        self.V = set(self._decode)
        self.eos = self.tokenizer.eos_token

    def __call__(self, xs):
        return self.model(self.tokenizer.encode(xs))

    async def p_next(self, xs, top=None):
        return await self._p_next(xs, top=top)

    async def _p_next(self, xs, top=None):
        assert isinstance(xs, str)
        tokens = self.tokenizer.encode(xs)

        _logp = await self._model.next_token_logprobs(tokens)
        _p = np.exp(_logp)

        if top is None:
            top_p = _p.argsort()
        else:
            top_p = _p.argsort()[-top:]
        pp = Float.chart()
        for i in reversed(top_p):
            pp[self._decode[i]] = _p[i]
        if top is None:
            return pp
        else:
            return pp.normalize()

In [11]:
MAX_TOKENS = 100
BATCH_SIZE = 80

hfppl_llm.batch_size = BATCH_SIZE
genparse_llm = GreedilyTokenizedLLM(hfppl_llm, tokenizer)
proposal = CharacterProposal(llm=genparse_llm, guide=guide)
steering_model = SteeringModel(
    genparse_llm, guide, proposal, prompt, MAX_TOKENS, compare_time=False
)

In [12]:
particles = asyncio.run(
    smc_standard(steering_model, n_particles=10)
)

` ` :   : -1.5632851851974077
` ` :   : -1.5632851851974077
` ` :   : -1.5632851851974077
` ` :   : -1.5632851851974077
`  ` :    : 0.07733979931194535
` ` :   : -1.5632851851974077
` ` :   : -1.5632851851974077
`SELECT` : SELECT : -0.717184941230105
` ` :   : -1.5632851851974077
`SELECT` : SELECT : -0.717184941230105
`SELECT` :  SELECT : -3.3094512865053862
`SELECT` :  SELECT : -3.3094512865053862
`SELECT` :  SELECT : -3.3094512865053862
`SELECT` :  SELECT : -3.3094512865053862
`SELECT` :   SELECT : -2.805157440555118
`SELECT` :  SELECT : -3.3094512865053862
`SELECT` :  SELECT : -3.3094512865053862
` ` : SELECT  : -4.646338486946215
`SELECT` :  SELECT : -3.3094512865053862
` ` : SELECT  : -4.646338486946215
` ` :  SELECT  : -8.085529644287872
` ` :  SELECT  : -8.085529644287872
`  ` :  SELECT   : -5.8042791767004775
` ` :  SELECT  : -8.085529644287872
` ` :   SELECT  : -6.955299892692902
` ` :  SELECT  : -8.085529644287872
` ` :  SELECT  : -8.085529644287872
`race` : SELECT race : -5.