In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
import sys
import os
# Machine-specific setup for the original author's environment. Both steps are
# written so that they are harmless on other machines and when the cell is re-run.

# Make the author's local genparse checkout importable; skip if already present
# so re-running the cell does not pile up duplicate sys.path entries.
GENPARSE_CHECKOUT = "/home/mila/b/benjamin.lebrun/genparse"
if GENPARSE_CHECKOUT not in sys.path:
    sys.path.append(GENPARSE_CHECKOUT)

# Point HF_HOME at $SCRATCH/hf_cache, but only when SCRATCH is defined:
# indexing os.environ["SCRATCH"] directly raises KeyError on machines without it.
scratch_dir = os.environ.get("SCRATCH")
if scratch_dir is not None:
    os.environ["HF_HOME"] = os.path.join(scratch_dir, "hf_cache")

In [4]:
from hfppl import Model, CachedCausalLM, LMContext, smc_standard, smc_steer
from hfppl.distributions import TokenCategorical
from transformers import AutoTokenizer, AutoModel

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# Hugging Face model id shared by the language model and its tokenizer.
MODEL_ID = "codellama/CodeLlama-7b-Instruct-hf"

# Load the model through hfppl's cached wrapper, asking for 8-bit weights.
hfppl_llm = CachedCausalLM.from_pretrained(MODEL_ID, load_in_8bit=True)

# Load the matching fast tokenizer with five special-token options set to None.
# NOTE(review): presumably this disables the CodeLlama infilling tokens —
# confirm against the tokenizer documentation.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    use_fast=True,
    **{
        special_token: None
        for special_token in (
            "prefix_token",
            "middle_token",
            "suffix_token",
            "eot_token",
            "fill_token",
        )
    },
)

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
Loading checkpoint shards: 100%|██████████| 2/2 [00:20<00:00, 10.47s/it]


In [6]:
import genparse
from genparse.cfglm import BoolMaskCFGLM
from genparse.util import LarkStuff
from genparse import EOS, Float
from arsenal.maths import sample_dict, logsumexp
from genparse.proposal import CharacterProposal
from genparse.lm import LM
from arsenal import timers

In [7]:
# Few-shot prompt: a description of the "data" table and its columns, one
# solved Q/A pair terminated by "</s>", and a final question whose answer is
# left open after "A:" for the model to complete.
prompt = """
You have access to a political survey data table named "data", which includes the following columns:
- "age" (integer)
- "gender" ("male" or "female"),
- "year" (integer)
- "state_color" ("blue" or "red")
- "zipcode" (integer)
- "vote" ("democrat" or "republican") 
- "race_ethnicity" ("white", "black", or "latino").

Q: Write a SQL query that shows individuals' age and gender, for people over 50 years old.
A: SELECT age, gender FROM data WHERE age>50 </s>
Q: Write a SQL query that shows individuals' vote and zipcode, ordered from lowest to highest age.
A:""" 

# Not part of the prompt. Presumably the intended answer to the open question
# above, followed by an additional question kept for later use:
# A: SELECT vote, zipcode, age FROM data ORDER BY age ASC </s>
# Q: Write a SQL query that returns white voters' average age for each state color.

# Lark grammar for the restricted SQL dialect the guide accepts: a single
# SELECT over the table "data" with optional WHERE / GROUP BY / ORDER BY
# clauses, terminated by the "</s>" EOS terminal.
SQL_GRAMMAR = r"""
            start: WS? "SELECT" WS select_expr WS "FROM" WS from_expr [WS "WHERE" WS bool_condition] [WS "GROUP BY" WS var_list] [WS "ORDER BY" WS orderby_expr] WS EOS
            EOS: "</s>"
            select_expr: STAR | select_list
            bool_condition: bool_expr | "(" bool_condition WS "AND" WS bool_condition ")" | "(" bool_condition WS "OR" WS bool_condition ")"
            bool_expr: var "=" value | var ">" value | var "<" value
            from_expr: "data"
            orderby_expr: var_list WS "ASC" | var_list WS "DESC"
            select_list: select_var ("," WS select_var)*
            var_list: var ("," WS var)*
            select_var: var | "AVG(" var ")" | "MEDIAN(" var ")" | "COUNT(" var ")"
            var: "age" | "gender" | "year" | "state_color" | "zipcode" | "vote" | "race_ethnicity"
            value: NUMBER | "'red'" | "'blue'" | "'white'" | "'black'" | "'latino'" | "'republican'" | "'democrat'" | "'male'" | "'female'"
            STAR: "*"
            NUMBER: /\d+/
            WS: /[ ]/

        """

# Convert the grammar to a character-level CFG and wrap it as the guide model.
# NOTE(review): the meaning of the 0.99 argument and of `ignore` is defined by
# genparse's LarkStuff.char_cfg — confirm there before changing either.
guide = BoolMaskCFGLM(LarkStuff(SQL_GRAMMAR).char_cfg(0.99, ignore='[ ]?'))

In [8]:
import numpy as np
import time
import asyncio
import nest_asyncio
# Patch asyncio to allow nested event loops, so that asyncio.run(...) can be
# called later in this notebook even though the kernel already runs a loop.
nest_asyncio.apply()

In [10]:
from genparse import Float

class SteeringModel(Model):
    """Particle model that grows a string one proposed token per step.

    Each call to ``step`` asks ``proposal`` for the next token given the
    prompt and the text generated so far, appends that token to ``context``,
    and adds ``log(llm_prob) + log(guide_prob) - log(proposal_prob)`` to the
    particle weight.

    Args:
        llm: Language-model wrapper (a ``GreedilyTokenizedLLM`` in this
            notebook); only its ``eos`` attribute is read by this class.
        guide: Grammar guide (a ``BoolMaskCFGLM`` in this notebook); stored
            but not used directly by ``step``.
        proposal: Object providing an awaitable ``sample_next_token(context=,
            prompt=, compare_time=)`` that returns a 4-tuple
            ``(token, llm_prob, guide_prob, proposal_prob)``.
        prompt: Prompt string forwarded to the proposal on every step.
        max_tokens: Remaining token budget; decremented once per step.
        compare_time: Flag forwarded unchanged to the proposal.
    """

    def __init__(self, llm, guide, proposal, prompt, max_tokens, compare_time=False):
        super().__init__()
        self.llm = llm  # GreedilyTokenizedLLM
        self.guide = guide  # grammar guide (BoolMaskCFGLM in this notebook)
        self.proposal = proposal  # CharacterProposal
        self.prompt = prompt
        self.context = ""  # text generated so far, excluding the prompt
        self.max_tokens = max_tokens
        self.compare_time = compare_time

    async def step(self):
        """Sample one token, update context/weight, and finish when done."""
        token, llm_prob, guide_prob, proposal_prob = await self.proposal.sample_next_token(
            context=self.context, prompt=self.prompt, compare_time=self.compare_time
        )
        self.context += token
        # Log-space weight update: target (llm * guide) over proposal.
        self.weight += np.log(llm_prob) + np.log(guide_prob) - np.log(proposal_prob)
        self.max_tokens -= 1

        print(f"Sampled token=`{token}`. Particle={self.context}")

        # Use `<= 0` rather than `== 0`: the budget is decremented before this
        # check, so a model created with max_tokens <= 0 would otherwise skip
        # past zero and never stop on the budget condition.
        reached_eos = token == self.llm.eos or token == genparse.EOS
        if reached_eos or self.max_tokens <= 0:
            self.finish()

class GreedilyTokenizedLLM:
    """Expose a token-level language model as a distribution over token strings.

    The wrapper tokenizes string contexts with ``tokenizer.encode`` and maps
    each token id back to text with ``tokenizer.decode([i])``.

    Attributes:
        tokenizer: Tokenizer providing ``encode``, ``decode``, ``vocab_size``
            and ``eos_token``.
        V: Set of decoded strings for token ids in ``range(vocab_size)``.
        eos: The tokenizer's ``eos_token``.
    """

    def __init__(self, llm, tokenizer):
        self.tokenizer = tokenizer
        self._model = llm  # hfppl Model
        # _decode[i] is the text of token id i, decoded in isolation.
        self._decode = [self.tokenizer.decode([i]) for i in range(self.tokenizer.vocab_size)]
        self.V = set(self._decode)
        self.eos = self.tokenizer.eos_token

    def __call__(self, xs):
        """Encode ``xs`` and forward the token ids to the wrapped model."""
        # Fixed: this used to read `self.model`, which is never assigned (the
        # attribute is `_model`), so every call raised AttributeError.
        return self._model(self.tokenizer.encode(xs))

    async def p_next(self, xs, top=None):
        """Public entry point; see ``_p_next``."""
        return await self._p_next(xs, top=top)

    async def _p_next(self, xs, top=None):
        """Return next-token probabilities keyed by decoded token string.

        Args:
            xs: Context string; it is encoded with the tokenizer.
            top: If given, keep only the ``top`` highest-probability token ids
                and renormalize the result; if ``None``, keep every token id
                and return the probabilities without renormalizing.

        Returns:
            A ``Float.chart()`` mapping decoded token strings to probability,
            filled in order of decreasing probability.
        """
        assert isinstance(xs, str)
        tokens = self.tokenizer.encode(xs)

        _logp = await self._model.next_token_logprobs(tokens)
        _p = np.exp(_logp)

        # argsort is ascending, so the most probable ids sit at the end.
        order = _p.argsort()
        if top is not None:
            order = order[-top:]

        # Different token ids can decode to the same string. Plain assignment
        # let the least probable duplicate overwrite the others (iteration is
        # from most to least probable), so sum the mass per string instead.
        mass = {}
        for i in reversed(order):
            text = self._decode[i]
            mass[text] = mass[text] + _p[i] if text in mass else _p[i]

        pp = Float.chart()
        for text, prob in mass.items():
            pp[text] = prob

        return pp if top is None else pp.normalize()


In [None]:
# Reference query for the open question at the end of `prompt`:
# SELECT vote, zipcode, age FROM data ORDER BY age ASC </s>

In [11]:
# Wrap the hfppl model so it can be queried with string contexts.
genparse_llm = GreedilyTokenizedLLM(llm=hfppl_llm, tokenizer=tokenizer)

# Proposal that combines the wrapped LLM with the grammar guide.
proposal = CharacterProposal(llm=genparse_llm, guide=guide)

# One particle template with a budget of 100 tokens.
steering_model = SteeringModel(
    llm=genparse_llm,
    guide=guide,
    proposal=proposal,
    prompt=prompt,
    max_tokens=100,
    compare_time=False,
)

# Run hfppl's standard SMC with 10 particles and block until it completes.
particles = asyncio.run(smc_standard(steering_model, n_particles=10))

SCRUN MAGIC
Sampled token=` `. Particle= 
Sampled token=` `. Particle= 
Sampled token=` `. Particle= 
Sampled token=`  `. Particle=  
Sampled token=` `. Particle= 
Sampled token=` `. Particle= 
Sampled token=` `. Particle= 
Sampled token=` `. Particle= 
Sampled token=` `. Particle= 
Sampled token=` `. Particle= 
Sampled token=`SELECT`. Particle=  SELECT
Sampled token=`SELECT`. Particle= SELECT
Sampled token=`SELECT`. Particle= SELECT
Sampled token=`SELECT`. Particle=  SELECT
Sampled token=`SELECT`. Particle=  SELECT
Sampled token=`SELECT`. Particle= SELECT
Sampled token=`SELECT`. Particle= SELECT
Sampled token=`SELECT`. Particle= SELECT
Sampled token=`SELECT`. Particle= SELECT
Sampled token=`SELECT`. Particle=  SELECT
Sampled token=` `. Particle=  SELECT 
Sampled token=` `. Particle= SELECT 
Sampled token=` `. Particle= SELECT 
Sampled token=`  `. Particle=  SELECT  
Sampled token=` `. Particle=  SELECT 
Sampled token=` `. Particle= SELECT 
Sampled token=` `. Particle= SELECT 
Sampled 

Scalene: An exception of type UnicodeDecodeError occurred. Arguments:
('ascii', b'/* PrismJS 1.26.0\nhttps://prismjs.com/download.html#themes=prism&languages=markup+css+clike+javascript+python&plugins=normalize-whitespace */\n/// <reference lib="WebWorker"/>\n\nvar _self =\n  typeof window !== "undefined"\n    ? window // if in browser\n    : typeof WorkerGlobalScope !== "undefined" &&\n      self instanceof WorkerGlobalScope\n    ? self // if in worker\n    : {}; // if in node js\n\n/**\n * Prism: Lightweight, robust, elegant syntax highlighting\n *\n * @license MIT <https://opensource.org/licenses/MIT>\n * @author Lea Verou <https://lea.verou.me>\n * @namespace\n * @public\n */\nvar Prism = (function (_self) {\n  // Private helper vars\n  var lang = /(?:^|\\s)lang(?:uage)?-([\\w-]+)(?=\\s|$)/i;\n  var uniqueId = 0;\n\n  // The grammar object for plaintext\n  var plainTextGrammar = {};\n\n  var _ = {\n    /**\n     * By default, Prism will attempt to highlight all code elements (by ca

In [22]:
# Tokens used by all particles: the initial budget of 100 minus what each
# particle has left.
tokens_generated = sum(100 - p.max_tokens for p in particles)
# Hard-coded duration of 18 minutes, in seconds.
# NOTE(review): presumably a manually measured wall-clock time for the SMC
# run — confirm, or measure it with a timer instead.
elapsed_seconds = 18 * 60
print(f'{tokens_generated/elapsed_seconds} tokens/sec')

0.42407407407407405 tokens/sec
