In [1]:
%load_ext autoreload
%autoreload all

In [2]:
# nest_asyncio patches asyncio so an event loop that is already running
# (as the Jupyter kernel's is) can be re-entered from notebook code.
# NOTE(review): presumably required because the genparse sampler drives
# asyncio internally when run_inference is called — confirm.
import nest_asyncio

nest_asyncio.apply()

# Few-shot prefix shared by every prompt: the table schema plus two worked
# question -> SQL examples. Prompts differ only in the final question.
# Whitespace (including the trailing space on the "vote" line) is kept exactly
# as in the original prompts, since it changes the tokens the LLM sees.
PROMPT_PREFIX = """
    You have access to a political survey data table named "data", which includes the following columns:
    - "age" (integer)
    - "gender" ("male" or "female"),
    - "year" (integer)
    - "state_color" ("blue" or "red")
    - "zipcode" (integer)
    - "vote" ("democrat" or "republican") 
    - "race_ethnicity" ("white", "black", or "latino").

    Q: Write a SQL query that shows individuals' age and gender, for people over 50 years old.
    A: SELECT age, gender FROM data WHERE age>50 </s>
    Q: Write a SQL query that shows individuals' vote and zipcode, ordered from lowest to highest age.
    A: SELECT vote, zipcode, age FROM data ORDER BY age ASC </s>
"""


def make_prompt(question: str) -> str:
    """Return the few-shot prompt ending with ``question`` and an open "A:".

    The answer slot is left empty so the model completes the SQL after "A:".
    The worked answers in the prefix end with " </s>", which matches the
    grammar's required trailing space + EOS.
    """
    # The space before the newline reproduces the original prompts exactly.
    return f"{PROMPT_PREFIX}    Q: {question} \n    A:"


# An unambiguous question: the target query is well defined by the schema.
unambiguous_prompt = make_prompt(
    "Write a SQL query that returns white voters' average age for each state color."
)

# An ambiguous question ("young" vs "old" is not defined by the schema).
# This is the prompt the inference cell uses.
prompt = make_prompt(
    "Write a SQL query that returns the average proportion of votes by party for young versus old voters."
)

# Lark grammar restricting generations to a small SELECT subset over `data`:
# optional leading space, SELECT (* or comma-separated columns, optionally
# wrapped in AVG/MEDIAN/COUNT), FROM data, optional WHERE (a single =, >, <
# comparison against a number or fixed quoted literal, or a parenthesized
# AND/OR of two conditions), optional GROUP BY, optional ORDER BY ... ASC|DESC,
# and finally exactly one space followed by "</s>".
very_restricted_sql = r"""
    start: WS? "SELECT" WS select_expr WS "FROM" WS from_expr [WS "WHERE" WS bool_condition] [WS "GROUP BY" WS var_list] [WS "ORDER BY" WS orderby_expr] WS EOS
    EOS: "</s>"
    select_expr: STAR | select_list
    bool_condition: bool_expr | "(" bool_condition WS "AND" WS bool_condition ")" | "(" bool_condition WS "OR" WS bool_condition ")"
    bool_expr: var "=" value | var ">" value | var "<" value
    from_expr: "data"
    orderby_expr: var_list WS "ASC" | var_list WS "DESC"
    select_list: select_var ("," WS select_var)*
    var_list: var ("," WS var)*
    select_var: var | "AVG(" var ")" | "MEDIAN(" var ")" | "COUNT(" var ")"
    var: "age" | "gender" | "year" | "state_color" | "zipcode" | "vote" | "race_ethnicity"
    value: NUMBER | "'red'" | "'blue'" | "'white'" | "'black'" | "'latino'" | "'republican'" | "'democrat'" | "'male'" | "'female'"
    STAR: "*"
    NUMBER: /\d+/
    WS: /[ ]/
"""

In [3]:
from genparse.util import load_model_by_name, lark_guide
from torch.cuda import is_available as is_cuda_available

# Pick the LLM by hardware: GPT-2 (batch of 5) when no CUDA device is
# available, otherwise CodeLlama (batch of 40).
if not is_cuda_available():
    genparse_llm = load_model_by_name('gpt2', batch_size=5)
else:
    genparse_llm = load_model_by_name('codellama', batch_size=40)

# Generation-length budget. NOTE(review): the SMC inference cell passes its own
# max_tokens value rather than reading this constant — confirm intent.
MAX_TOKENS = 100

In [5]:
from genparse.steer import HFPPLSampler
from genparse.proposal import CharacterProposal

# Build the grammar guide from the restricted-SQL Lark grammar; the posterior
# strings shown below all conform to that grammar.
guide = lark_guide(very_restricted_sql)

# Sampler combining the LLM with the grammar guide; run_inference is called on
# it with method='smc-standard'.
sampler = HFPPLSampler(llm=genparse_llm, guide=guide)

# Proposal passed to run_inference. NOTE(review): presumably proposes
# continuations character-by-character under the guide (per its name) — confirm.
proposal = CharacterProposal(llm=genparse_llm, guide=guide)

In [6]:
# SMC settings for this run (previously inline magic numbers).
# NOTE(review): MAX_TOKENS from the model-loading cell is 100, but this run has
# always used a 60-token cap; confirm which limit is intended.
N_PARTICLES = 12
SMC_MAX_TOKENS = 60

particle_approx = sampler.run_inference(
    prompt=prompt,
    proposal=proposal,
    method='smc-standard',
    return_record=True,  # keep a per-step record of the particles for plotting
    n_particles=N_PARTICLES,
    max_tokens=SMC_MAX_TOKENS,
    verbosity=0,
)
# Per-step SMC record used by the plotting / export cells.
record = particle_approx.record

In [7]:
# Posterior over completed generations: each grammar-valid SQL string (ending
# in "</s>") mapped to its normalized weight.
particle_approx.posterior

0,1
key,value
"SELECT age, age, gender FROM data WHERE age>50 ORDER BY age ASC </s>▪",0.9999486132024705
"SELECT age, age, gender FROM data WHERE age>50 ORDER BY age DESC </s>▪",5.1386797529405986e-05


## Plotting the particle beam

In [11]:
# Visualize the particle beam across SMC steps, with the estimated log-prob.
# To zoom in on a subset of steps, also pass e.g. xrange=[0, 10].
record.plotly(
    height=600,
    show_est_logprob=True,
    # Reorder particles before plotting so resampling doesn't draw crossing lines.
    untangle=True,
)

In [10]:
# Save an example
import json
import numpy as np

# Toggles for saving the SMC record to JSON / reloading it instead (both off).
WRITE_EXAMPLE = False
READ_EXAMPLE = False


class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that also accepts NumPy scalars and arrays.

    NumPy integers, floats and booleans become the equivalent Python scalars,
    and arrays become (nested) lists. Anything else is delegated to
    ``json.JSONEncoder.default``, which raises ``TypeError``.
    """

    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        # np.bool_ is neither an np.integer nor a subclass of Python bool, so
        # the stock encoder rejects it (e.g. results of array comparisons).
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)


import os  # local import: `os` is otherwise first imported in a later cell

# Location of the saved example record (shared by the write and read paths).
EXAMPLE_PATH = os.path.join('smc_viz', 'MyExample.json')

if WRITE_EXAMPLE:
    # Create the directory on first use rather than failing with
    # FileNotFoundError on a fresh checkout.
    os.makedirs(os.path.dirname(EXAMPLE_PATH), exist_ok=True)
    # NOTE(review): relies on the record being JSON-serializable (dict-like),
    # as the SMCRecord(**...) round-trip below suggests — confirm.
    with open(EXAMPLE_PATH, 'w') as f:
        json.dump(particle_approx.record, f, cls=NumpyEncoder)
if READ_EXAMPLE:
    from genparse.record import SMCRecord

    # Rebuild the record from its saved fields; rebinds `record` for later cells.
    with open(EXAMPLE_PATH, 'r') as f:
        record = SMCRecord(**json.load(f))

## Write out example figures

In [21]:
import os


def write_images_scrollby(windowsize=10, outdir='figs', height=800, width=1000,
                          smc_record=None):
    """Write particle-beam plots of an SMC record as PNG files.

    Writes ``EXAMPLE.png`` (the full plot) plus one ``TMP{k:02d}.png`` per
    step ``k``, where frame ``k`` is plotted with ``xrange=[0, k]`` (steps
    0..k), e.g. for assembling an animation.

    Args:
        windowsize: currently unused. NOTE(review): the name suggests a
            sliding window (``xrange=[k - windowsize, k]``), but frames have
            always started at step 0 — confirm intent before relying on it.
        outdir: output directory; created if it does not already exist.
        height: figure height passed to ``plotly``.
        width: figure width passed to ``plotly``.
        smc_record: record to plot; defaults to the notebook-global ``record``.
    """
    rec = record if smc_record is None else smc_record
    # write_image does not create missing directories.
    os.makedirs(outdir, exist_ok=True)
    rec.plotly(height=height, width=width).write_image(
        os.path.join(outdir, 'EXAMPLE.png')
    )
    for step in range(len(rec['step'])):
        rec.plotly(xrange=[0, step], height=height, width=width).write_image(
            os.path.join(outdir, f'TMP{step:02d}.png')
        )


# Set to True to regenerate the PNG frames (one image per SMC step).
WRITE_FIGURES = False

if WRITE_FIGURES:
    write_images_scrollby()