In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
import os
import getpass

if getpass.getuser() == "benjamin.lebrun":
    sys.path.append("/home/mila/b/benjamin.lebrun/genparse")
    os.environ["HF_HOME"] = os.path.join(os.environ["SCRATCH"], "hf_cache")
    print("HF cache set; path updated")

import numpy as np
import asyncio
import nest_asyncio
nest_asyncio.apply()

HF cache set; path updated


In [3]:
from hfppl import Model, CachedCausalLM

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from genparse import EOS, add_EOS, Boolean, Float

In [14]:
from genparse.lm import LM

class EarleyBoolMaskCFGLM2(LM):
    "LM-like interface for Boolean-masking CFG models; uses Earley's algorithm for inference."

    def __init__(self, cfg):
        from genparse.experimental.earley import Earley
        if EOS not in cfg.V: cfg = add_EOS(cfg)
        cfg = cfg.nullaryremove(binarize=True).unarycycleremove().renumber()
        if cfg.R != Boolean: cfg = cfg.map_values(lambda x: Boolean(x>0), Boolean)
        self.model = Earley(cfg.prefix_grammar)
        super().__init__(eos = EOS, V = cfg.V)

    def p_next(self, context):
        p = self.model.p_next(context).trim()
        return Float.chart({w: 1 for w in p})

    def __call__(self, context):
        assert context[-1] == EOS
        return float(self.model(context) != Boolean.zero)

In [5]:
from genparse.lm import AsyncGreedilyTokenizedLLM
llm = AsyncGreedilyTokenizedLLM("codellama/CodeLlama-7b-Instruct-hf")

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
Loading checkpoint shards: 100%|██████████| 2/2 [00:25<00:00, 12.58s/it]


In [6]:
restricted_sql = r"""
start: query_expr "</s>"

query_expr: select [ "ORDER" "BY" (order_by_expr ",")*  order_by_expr] [ "LIMIT" integer_ ]

select: "SELECT" [(select_expr ",")*] select_expr "FROM" "data" [ "WHERE" bool_expression ] [ "GROUP" "BY" [(expression ",")*] expression ]

select_expr.0: expression_math [ [ "AS" ] alias ] -> select_expression

?expression_math: expression_product
               | expression_math PLUS expression_product -> expression_add
               | expression_math MINUS expression_product -> expression_sub
               | AGGREGATION expression_math /\)/ -> sql_aggregation

?expression: (name | STAR) -> column_name
            | literal

?expression_product: expression_parens
                  | expression_product STAR expression_parens -> expression_mul
                  | expression_product "/" expression_parens -> expression_div

?expression_parens: expression
                  | /\(/ expression_parens STAR expression /\)/ -> expression_mul
                  | /\(/  expression_parens "/" expression /\)/ -> expression_div
                  | /\(/  expression_parens PLUS expression /\)/ -> expression_add
                  | /\(/  expression_parens MINUS expression /\)/ -> expression_sub

bool_expression: bool_parentheses
                 | bool_expression "AND" bool_parentheses -> bool_and
                 | bool_expression "OR" bool_parentheses -> bool_or
bool_parentheses: comparison_type
                 | /\(/   bool_expression "AND" comparison_type /\)/ -> bool_and
                 | /\(/  bool_expression "OR" comparison_type /\)/ -> bool_or
comparison_type: equals | not_equals | greater_than | less_than | greater_than_or_equal
| less_than_or_equal | is_null | is_not_null
equals: expression_math "=" expression_math
not_equals: expression_math ("<>" | "!=") expression_math
greater_than: expression_math ">" expression_math
less_than: expression_math "<" expression_math
greater_than_or_equal: expression_math ">=" expression_math
less_than_or_equal: expression_math "<=" expression_math
is_null: expression_math "is" "null"
is_not_null: expression_math "is" "not" "null"

alias: /[A-Za-z]/
name: /[A-Za-z]/
PLUS: /\+/
MINUS: /[\-]/

order_by_expr: expression_math ["ASC"] -> order_asc
        | expression_math "DESC" -> order_desc

AGGREGATION.8: ("sum(" | "avg(" | "min(" | "max(" | "count(" "distinct" | "count(")
STAR: /\*/
integer_: /[1-9][0-9]*/
?literal: boolean -> bool
       | integer_ -> number
       | /'([^']|\s)+'|''/ -> string

boolean: "true" -> true
       | "false" -> false

%import common.WS
%ignore WS
"""

very_restricted_sql = r"""
        start: WS? "SELECT" WS select_expr WS "FROM" WS from_expr [WS "WHERE" WS bool_condition] [WS "GROUP BY" WS var_list] [WS "ORDER BY" WS orderby_expr] WS EOS
        EOS: "</s>"
        select_expr: STAR | select_list
        bool_condition: bool_expr | "(" bool_condition WS "AND" WS bool_condition ")" | "(" bool_condition WS "OR" WS bool_condition ")"
        bool_expr: var "=" value | var ">" value | var "<" value
        from_expr: "data"
        orderby_expr: var_list WS "ASC" | var_list WS "DESC"
        select_list: select_var ("," WS select_var)*
        var_list: var ("," WS var)*
        select_var: var | "AVG(" var ")" | "MEDIAN(" var ")" | "COUNT(" var ")"
        var: "age" | "gender" | "year" | "state_color" | "zipcode" | "vote" | "race_ethnicity"
        value: NUMBER | "'red'" | "'blue'" | "'white'" | "'black'" | "'latino'" | "'republican'" | "'democrat'" | "'male'" | "'female'"
        STAR: "*"
        NUMBER: /\d+/
        WS: /[ ]/
    """

In [7]:
from genparse.cfglm import EarleyBoolMaskCFGLM
from genparse.util import LarkStuff
guide = EarleyBoolMaskCFGLM(LarkStuff(very_restricted_sql).char_cfg(.99, ignore='[ ]?'))

In [17]:
from genparse.steer import SteeredSampler

In [18]:
sampler = SteeredSampler(llm, guide)

In [19]:
from genparse.proposal import TokenProposal, CharacterProposal
proposal = TokenProposal(llm=llm, guide=guide)

In [20]:
prompt = "Write an SQL query: "
sampler.run_inference(
    prompt = prompt,
    proposal = proposal,
    method = "smc-standard",
    n_particles = 2
)

In [21]:
sampler.posterior

0,1
key,value
"[' SELECT', ' *', ' FROM', ' ', ' dat', 'a', ' GROUP', ' BY', ' ', ' year', ',', ' year', ',', ' year', ',', ' year', ',', ' year', ',', ' year', ' ', ' ORDER', ' BY', ' ', ' year', ' DESC', ' ', ' <', '/', 's', '>', ' ', '▪']",1.8920063240579024e-41
"[' SE', 'L', 'ECT', ' *', ' FROM', ' ', ' data', ' ', ' WHERE', ' ', ' (', '((', 's', 't', 'at', 'e', '_', 'co', 'lor', '=', '0', ' OR', ' state', '_', 'color', '=', '2', ')', ' AND', ' (', ' (', 'state', '_', 'color', ' =', ' ', '0', ' OR', ' state', '_', 'color', ' =', ' ', '3', ')', ' AND', ' state', '_', 'color', ' =', ' ', '5', ')', ' ', ')', ' AND', ' (', ' (', 'state', '_', 'color', ' =', ' ', '6', ' AND', ' state', '_', 'color', ' =', ' ', '9', ')', ' OR', ' state', '_', 'color', ' =', ' ', '5', ')', ' )', ' ORDER', ' BY', ' ', ' state', '_', 'color', ' DESC', ' ', ' <', '/', 's', '>', ' ', '▪']",3.3962177866286767e-76


In [38]:
p_llm = await llm.p_next(prompt)

In [39]:
proposal._update_trie(p_llm)

In [40]:
paths = proposal._enumerate_paths("")

In [41]:
paths

[]

In [27]:
p_next = await proposal.sample_next_token(
    prompt = "Write an SQL query: ", context = ""
)

IndexError: index -1 is out of bounds for axis 0 with size 0

In [26]:
prompt = "Write an SQL query: "
sampler.run_inference(
    prompt = prompt,
    proposal = proposal,
    method = "smc-standard",
    n_particles = 2
)

IndexError: index -1 is out of bounds for axis 0 with size 0

In [4]:
from hfppl import Model, CachedCausalLM, LMContext
from transformers import AutoTokenizer

In [5]:
MODEL_ID = "codellama/CodeLlama-7b-Instruct-hf"
hfppl_llm = CachedCausalLM.from_pretrained(MODEL_ID, load_in_8bit=True)
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID, 
    use_fast=True,
    prefix_token=None, 
    middle_token=None, 
    suffix_token=None, 
    eot_token=None, 
    fill_token=None
)

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
Loading checkpoint shards: 100%|██████████| 2/2 [01:27<00:00, 43.91s/it]


In [43]:
import genparse
from genparse.cfglm import EarleyBoolMaskCFGLM, BoolMaskCFGLM
from genparse.util import LarkStuff
from genparse import EOS, Float
from arsenal.maths import sample_dict, logsumexp
from genparse.proposal import CharacterProposal, TokenProposal
from genparse.lm import AsyncGreedilyTokenizedLLM
from genparse.inference import smc_standard, smc_steer

In [15]:
from genparse.experimental.earley import Earley
from genparse.lm import LM
from genparse.semiring import Float, Boolean

In [40]:
class EarleyBoolMaskCFGLM2(LM):
    "LM-like interface for Boolean-masking CFG models; uses Earley's algorithm for inference."

    def __init__(self, cfg):
        from genparse.experimental.earley import Earley
        if EOS not in cfg.V: cfg = add_EOS(cfg)
        cfg = cfg.nullaryremove(binarize=True).unarycycleremove().renumber()
        if cfg.R != Boolean: cfg = cfg.map_values(lambda x: Boolean(x>0), Boolean)
        self.model = Earley(cfg.prefix_grammar)
        super().__init__(eos = EOS, V = cfg.V)

    def p_next(self, context):
        p = self.model.p_next(context).trim()
        return Float.chart({w: 1 for w in p})

    def __call__(self, context):
        assert context[-1] == EOS
        return float(self.model(context) != Boolean.zero)

In [41]:
# fails
cfg = LarkStuff(open("../assets/sql_grammar_case_sensitive.lark").read()).char_cfg(.99, ignore='[ ]?')
guide = EarleyBoolMaskCFGLM(cfg)

In [42]:
class PureModel(Model):
    def __init__(self, llm, prompt, max_tokens):
        super().__init__()
        self.LLM = llm
        self.context = LMContext(self.LLM, prompt)
        self.max_tokens = max_tokens
        
    async def step(self):
        token = await self.sample(self.context.next_token())

        self.max_tokens -= 1

        print(f"{token} : {str(self.context)}")
        
        # Check if done
        if token == self.LLM.tokenizer.eos_token_id or self.max_tokens == 0:
            self.finish()
            return

class SteeringModel(Model):
    def __init__(self, llm, guide, proposal, prompt, max_tokens, compare_time=False):
        super().__init__()
        self.llm = llm # AsyncGreedilyTokenizedLM
        self.guide = guide # PCFGLM
        self.prompt = prompt
        self.context = []
        self.proposal = proposal 
        self.max_tokens = max_tokens
        self.compare_time = compare_time

    async def step(self):
        (token, llm_prob, guide_prob, proposal_prob) = await self.proposal.sample_next_token(
            prompt=self.prompt, context=''.join(self.context), compare_time=self.compare_time
        )
        self.context.append(token)
        self.weight += np.log(llm_prob) + np.log(guide_prob) - np.log(proposal_prob)
        self.max_tokens -= 1

        print(f"`{token}` : {''.join(self.context)} : {self.weight}")

        if token == self.llm.eos or self.max_tokens == 0 or token == genparse.EOS:
            self.finish()
            return
        
    def immutable_properties(self):
        return ['llm', 'prompt', 'guide', 'compare_token']
    
    def __repr__(self):
        return f"`{'' if not self.context else self.context[-1]}` : {''.join(self.context)} : {self.weight}"

In [44]:
MAX_TOKENS = 100
BATCH_SIZE = 80

hfppl_llm.batch_size = BATCH_SIZE
genparse_llm = AsyncGreedilyTokenizedLLM(hfppl_llm, tokenizer)
proposal = TokenProposal(llm=genparse_llm, guide=guide)
steering_model = SteeringModel(
    genparse_llm, guide, proposal, prompt, MAX_TOKENS, compare_time=False
)

NameError: name 'prompt' is not defined

In [None]:
particles = asyncio.run(smc_standard(steering_model, n_particles=20))

In [None]:
posterior = Float.chart()
for p in particles:
    posterior[''.join(p.context)] += np.exp(p.weight)
posterior.normalize()