In [2]:
# Reload edited Python modules (e.g. a local genparse checkout) before each cell executes.
# NOTE(review): the saved output shows this cell was re-run on a live kernel; the output
# itself suggests `%reload_ext autoreload` as the form that can be re-run without the notice.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
import sys
import os
import getpass

# Developer-specific setup (Mila cluster): use a local genparse checkout and keep the
# Hugging Face cache under $SCRATCH. HF_HOME is set before `transformers` is imported
# further down in this cell, because the cache location is resolved at import time.
if getpass.getuser() == "benjamin.lebrun":
    sys.path.append("/home/mila/b/benjamin.lebrun/genparse")
    scratch_dir = os.environ.get("SCRATCH")
    if scratch_dir:
        os.environ["HF_HOME"] = os.path.join(scratch_dir, "hf_cache")
        print("HF cache set; path updated")
    else:
        # Previously a KeyError that aborted the whole setup cell; fall back to the
        # default HF cache location instead.
        print("SCRATCH is not set; using the default HF cache; path updated")

import numpy as np
import asyncio
import nest_asyncio

# Jupyter already runs an event loop; nest_asyncio allows nested loops so that
# `asyncio.run(...)` (used to run SMC later) works inside the notebook.
nest_asyncio.apply()

from random import seed
from torch import manual_seed
from transformers import set_seed

RANDOM_SEED = 80808
# transformers.set_seed is documented to seed `random`, numpy and torch; the explicit
# `random` / torch calls are kept so seeding does not depend on that helper's internals.
set_seed(RANDOM_SEED)
seed(RANDOM_SEED)
manual_seed(RANDOM_SEED)

In [None]:
from hfppl import Model, CachedCausalLM, LMContext
from transformers import AutoTokenizer

In [None]:
MODEL_ID = "codellama/CodeLlama-7b-Instruct-hf"

# hfppl wrapper around the checkpoint; `load_in_8bit=True` is forwarded to hfppl's loader.
hfppl_llm = CachedCausalLM.from_pretrained(MODEL_ID, load_in_8bit=True)

# Plain HF tokenizer for the same checkpoint. Every CodeLlama infilling special token
# is passed as None, i.e. left unconfigured on this tokenizer.
# NOTE(review): presumably because only left-to-right generation is used here — confirm.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    use_fast=True,
    **dict.fromkeys(
        ("prefix_token", "middle_token", "suffix_token", "eot_token", "fill_token")
    ),
)

In [None]:
import genparse
from genparse.cfglm import EarleyBoolMaskCFGLM, BoolMaskCFGLM
from genparse.util import LarkStuff
from genparse import EOS, Float
from arsenal.maths import sample_dict, logsumexp
from genparse.proposal import CharacterProposal
from genparse.lm import AsyncGreedilyTokenizedLLM
from genparse.inference import smc_standard, smc_steer

In [None]:
# Few-shot prompt for the LLM:
# - a description of the columns of the "data" table,
# - two example Q/A pairs whose answers end with genparse's EOS marker,
# - the target question, left open after "A:" for the model to complete.
# NOTE(review): the column bullets use inconsistent punctuation (a trailing comma after
# "gender", a trailing space after "vote", a final period after "race_ethnicity"). Also,
# only the second example is followed by a blank line. Both are left as-is because any
# change to this text changes the model's outputs.
prompt = f"""
You have access to a political survey data table named "data", which includes the following columns:
- "age" (integer)
- "gender" ("male" or "female"),
- "year" (integer)
- "state_color" ("blue" or "red")
- "zipcode" (integer)
- "vote" ("democrat" or "republican") 
- "race_ethnicity" ("white", "black", or "latino").

Q: Write a SQL query that shows individuals' age and gender, for people over 50 years old.
A: SELECT age, gender FROM data WHERE age>50 {EOS}
Q: Write a SQL query that shows individuals' vote and zipcode, ordered from lowest to highest age.
A: SELECT vote, zipcode, age FROM data ORDER BY age ASC {EOS}

Q: Write a SQL query that returns white voters' average age for each state color and sort the results.
A:"""

# Character-level grammar for the SQL subset that answers may use, wrapped as a guide LM.
# The EOS terminal is built from genparse's `EOS` rather than a hard-coded character.
# The prompt's example answers also end with `EOS`, and generation stops on
# `token == genparse.EOS`, so all three must agree on the same marker.
# The grammar text contains no `{`/`}`, so the rf-string only interpolates `{EOS}`.
# `.99` and `ignore='[ ]?'` are passed through to LarkStuff.char_cfg unchanged.
character_cfg = LarkStuff(
    rf"""
        start: WS? "SELECT" WS select_expr WS "FROM" WS from_expr [WS "WHERE" WS bool_condition] [WS "GROUP BY" WS var_list] [WS "ORDER BY" WS orderby_expr] WS EOS
        EOS: "{EOS}"
        select_expr: STAR | select_list
        bool_condition: bool_expr | "(" bool_condition WS "AND" WS bool_condition ")" | "(" bool_condition WS "OR" WS bool_condition ")"
        bool_expr: var "=" value | var ">" value | var "<" value
        from_expr: "data"
        orderby_expr: var_list WS "ASC" | var_list WS "DESC"
        select_list: select_var ("," WS select_var)*
        var_list: var ("," WS var)*
        select_var: var | "AVG(" var ")" | "MEDIAN(" var ")" | "COUNT(" var ")"
        var: "age" | "gender" | "year" | "state_color" | "zipcode" | "vote" | "race_ethnicity"
        value: NUMBER | "'red'" | "'blue'" | "'white'" | "'black'" | "'latino'" | "'republican'" | "'democrat'" | "'male'" | "'female'"
        STAR: "*"
        NUMBER: /\d+/
        WS: /[ ]/

    """
).char_cfg(.99, ignore='[ ]?')

# Grammar-based LM that constrains the characters the proposal may generate.
guide = EarleyBoolMaskCFGLM(character_cfg)

In [None]:
class PureModel(Model):
    """Unconstrained baseline: sample tokens directly from the LLM, with no grammar.

    Each step samples one next token from an hfppl `LMContext`, prints it together with
    the current context, and finishes on the tokenizer's EOS token or when the token
    budget is used up.
    """

    def __init__(self, llm, prompt, max_tokens):
        """
        Args:
            llm: hfppl language model (presumably a CachedCausalLM) exposing `.tokenizer`.
            prompt: text the LMContext is initialized with.
            max_tokens: maximum number of tokens to sample.
        """
        super().__init__()
        self.LLM = llm
        self.context = LMContext(self.LLM, prompt)
        self.max_tokens = max_tokens

    async def step(self):
        """Sample and log one token, then finish on EOS or when the budget is exhausted."""
        token = await self.sample(self.context.next_token())

        self.max_tokens -= 1

        print(f"{token} : {str(self.context)}")

        # `<= 0` rather than `== 0`: a non-positive initial budget would otherwise never
        # trigger termination, and generation would run until EOS.
        if token == self.LLM.tokenizer.eos_token_id or self.max_tokens <= 0:
            self.finish()
            return

class ChracterProposalSteeringModel(Model):
    """SMC model that samples tokens from a grammar-aware character proposal.

    At each step the proposal samples the next token given the prompt and the text
    generated so far. The particle's log-weight is then updated with
    log p_llm(token) + log p_guide(token) - log q_proposal(token).
    """

    def __init__(self, llm, guide, proposal, prompt, max_tokens, compare_time=False):
        """
        Args:
            llm: AsyncGreedilyTokenizedLLM; generating its `.eos` ends the particle.
            guide: grammar LM constraining the output (an EarleyBoolMaskCFGLM in this notebook).
            proposal: CharacterProposal used to sample each next token.
            prompt: prompt string passed to the proposal at every step.
            max_tokens: maximum number of tokens to generate.
            compare_time: forwarded unchanged to `proposal.sample_next_token`.
        """
        super().__init__()
        self.llm = llm  # AsyncGreedilyTokenizedLLM
        self.guide = guide  # grammar-based guide LM (EarleyBoolMaskCFGLM here)
        self.prompt = prompt
        self.context = []  # token strings generated so far, one per step
        self.proposal = proposal  # CharacterProposal
        self.max_tokens = max_tokens
        self.compare_time = compare_time

    async def step(self):
        """Extend the particle by one proposed token and update its log-weight."""
        (token, llm_prob, guide_prob, proposal_prob) = await self.proposal.sample_next_token(
            prompt=self.prompt, context=''.join(self.context), compare_time=self.compare_time
        )
        self.context.append(token)
        # Importance weight kept in log space: target (llm * guide) over proposal.
        self.weight += np.log(llm_prob) + np.log(guide_prob) - np.log(proposal_prob)
        self.max_tokens -= 1

        # __repr__ renders exactly "`<last token>` : <context> : <weight>".
        print(repr(self))

        # `<= 0` so that a non-positive initial budget still terminates.
        if token == self.llm.eos or self.max_tokens <= 0 or token == genparse.EOS:
            self.finish()
            return

    def immutable_properties(self):
        """Attributes that are shared (not deep-copied) when particles are copied.

        The previous list named a nonexistent attribute ('compare_token', a typo for
        'compare_time'). It also omitted the proposal, which holds the LLM and guide and
        would be deep-copied with every particle. `step` passes each particle's prompt and
        context to the proposal explicitly.
        NOTE(review): confirm CharacterProposal keeps no per-particle state internally.
        """
        return {'llm', 'prompt', 'guide', 'proposal', 'compare_time'}

    def __repr__(self):
        last_token = self.context[-1] if self.context else ''
        return f"`{last_token}` : {''.join(self.context)} : {self.weight}"

In [None]:
# Inference configuration.
MAX_TOKENS = 100  # per-particle generation budget
BATCH_SIZE = 80   # assigned to the hfppl model's batch_size below

hfppl_llm.batch_size = BATCH_SIZE

# Wrap the hfppl model for genparse, build the grammar-aware character proposal,
# then the SMC model that combines LLM, guide and proposal.
genparse_llm = AsyncGreedilyTokenizedLLM(hfppl_llm, tokenizer)
proposal = CharacterProposal(llm=genparse_llm, guide=guide)
steering_model = ChracterProposalSteeringModel(
    llm=genparse_llm,
    guide=guide,
    proposal=proposal,
    prompt=prompt,
    max_tokens=MAX_TOKENS,
    compare_time=False,
)

In [None]:
# Run standard SMC; asyncio.run works inside Jupyter thanks to nest_asyncio.apply() in setup.
N_PARTICLES = 20  # number of SMC particles (was an inline magic number)
particles = asyncio.run(smc_standard(steering_model, n_particles=N_PARTICLES))

In [None]:
# Normalized distribution over generated strings, built from the particles' log-weights
# (particles that produced the same string have their weights summed).
# The max log-weight is subtracted before exponentiating. Raw log-weights can be very
# negative, and np.exp would then underflow every particle to 0, leaving nothing to
# normalize. A common shift cancels out after normalization, so the result is unchanged.
max_log_weight = max((p.weight for p in particles), default=0.0)
if not np.isfinite(max_log_weight):
    # e.g. every particle has weight -inf; shifting by -inf would produce NaNs
    max_log_weight = 0.0

posterior = Float.chart()
for p in particles:
    posterior[''.join(p.context)] += np.exp(p.weight - max_log_weight)
# NOTE(review): if Chart.normalize() returns a new chart rather than normalizing in place,
# `posterior` itself stays unnormalized; only the displayed result is normalized.
posterior.normalize()