In [1]:
# Enable IPython's autoreload extension in "all" mode so that imported modules are
# reloaded before each cell executes (edits to module source take effect without a kernel restart).
%load_ext autoreload
%autoreload all

In [8]:
import nest_asyncio
# Jupyter already runs an asyncio event loop; nest_asyncio.apply() patches asyncio so that
# event loops can be nested. Presumably needed by the async LLM wrapper used below — TODO confirm.
nest_asyncio.apply()

# Few-shot prompt whose final question is unambiguous. It is kept under its own name so it
# stays reachable after `prompt` is reassigned to the ambiguous variant further down in this cell.
unambiguous_prompt = """
    You have access to a political survey data table named "data", which includes the following columns:
    - "age" (integer)
    - "gender" ("male" or "female"),
    - "year" (integer)
    - "state_color" ("blue" or "red")
    - "zipcode" (integer)
    - "vote" ("democrat" or "republican") 
    - "race_ethnicity" ("white", "black", or "latino").

    Q: Write a SQL query that shows individuals' age and gender, for people over 50 years old.
    A: SELECT age, gender FROM data WHERE age>50 </s>
    Q: Write a SQL query that shows individuals' vote and zipcode, ordered from lowest to highest age.
    A: SELECT vote, zipcode, age FROM data ORDER BY age ASC </s>
    Q: Write a SQL query that returns white voters' average age for each state color. 
    A:"""
prompt = unambiguous_prompt

# An ambiguous prompt: same schema and few-shot examples as the prompt above, but the final
# question leaves "young" and "old" undefined. This assignment replaces the `prompt` set above,
# so it is the value seen by any later cell that reads `prompt`.
prompt = """
    You have access to a political survey data table named "data", which includes the following columns:
    - "age" (integer)
    - "gender" ("male" or "female"),
    - "year" (integer)
    - "state_color" ("blue" or "red")
    - "zipcode" (integer)
    - "vote" ("democrat" or "republican") 
    - "race_ethnicity" ("white", "black", or "latino").

    Q: Write a SQL query that shows individuals' age and gender, for people over 50 years old.
    A: SELECT age, gender FROM data WHERE age>50 </s>
    Q: Write a SQL query that shows individuals' vote and zipcode, ordered from lowest to highest age.
    A: SELECT vote, zipcode, age FROM data ORDER BY age ASC </s>
    Q: Write a SQL query that returns the average proportion of votes by party for young versus old voters. 
    A:"""

# Grammar text (Lark-style rules; it is handed to LarkStuff in a later cell) for a small subset of SQL:
# a single SELECT over the table "data" with optional WHERE, GROUP BY and ORDER BY clauses,
# terminated by the EOS terminal "</s>". Column names and literal values are enumerated explicitly,
# matching the schema described in the prompt. WS is exactly one space character.
very_restricted_sql = r"""
    start: WS? "SELECT" WS select_expr WS "FROM" WS from_expr [WS "WHERE" WS bool_condition] [WS "GROUP BY" WS var_list] [WS "ORDER BY" WS orderby_expr] WS EOS
    EOS: "</s>"
    select_expr: STAR | select_list
    bool_condition: bool_expr | "(" bool_condition WS "AND" WS bool_condition ")" | "(" bool_condition WS "OR" WS bool_condition ")"
    bool_expr: var "=" value | var ">" value | var "<" value
    from_expr: "data"
    orderby_expr: var_list WS "ASC" | var_list WS "DESC"
    select_list: select_var ("," WS select_var)*
    var_list: var ("," WS var)*
    select_var: var | "AVG(" var ")" | "MEDIAN(" var ")" | "COUNT(" var ")"
    var: "age" | "gender" | "year" | "state_color" | "zipcode" | "vote" | "race_ethnicity"
    value: NUMBER | "'red'" | "'blue'" | "'white'" | "'black'" | "'latino'" | "'republican'" | "'democrat'" | "'male'" | "'female'"
    STAR: "*"
    NUMBER: /\d+/
    WS: /[ ]/
"""

In [18]:
from genparse.lm import AsyncGreedilyTokenizedLLM
from torch.cuda import is_available as is_cuda_available

# Build `genparse_llm`: a small GPT-2 wrapper when no GPU is available, CodeLlama otherwise.
if not is_cuda_available():
    import transformers
    from genparse.lm import LLM

    MODEL_ID = "gpt2"
    MAX_TOKENS = 100
    BATCH_SIZE = 80

    # Wrap the Hugging Face causal LM, then pair it with its tokenizer.
    hfppl_llm = LLM(transformers.AutoModelForCausalLM.from_pretrained(MODEL_ID))
    tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_ID)
    hfppl_llm.batch_size = BATCH_SIZE

    genparse_llm = AsyncGreedilyTokenizedLLM(
        model=hfppl_llm,
        tokenizer=tokenizer,
        batch_size=BATCH_SIZE,
    )
else:
    genparse_llm = AsyncGreedilyTokenizedLLM.from_name(
        "codellama/CodeLlama-7b-Instruct-hf",
        batch_size=40,
    )

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [19]:
from genparse.cfglm import EarleyBoolMaskCFGLM
from genparse.util import LarkStuff
from genparse.steer import HFPPLSampler
from genparse.proposal import TokenProposal, CharacterProposal

# Character-level CFG derived from the restricted-SQL grammar. The positional 0.99 and the
# `ignore` pattern are passed straight to char_cfg; their exact meaning is not visible here — TODO confirm.
sql_char_cfg = LarkStuff(very_restricted_sql).char_cfg(.99, ignore='[ ]?')

# Grammar guide shared by the sampler and the proposal.
guide = EarleyBoolMaskCFGLM(sql_char_cfg)

sampler = HFPPLSampler(llm=genparse_llm, guide=guide)
proposal = CharacterProposal(llm=genparse_llm, guide=guide)

Testing a modified version of SMC inference that also returns a step-by-step record (`return_record = True`).

In [119]:
# Run grammar-guided SMC on `prompt` and keep both the particle approximation and the
# per-step record that the cells below tabulate and plot.
particle_approx, record = sampler.run_inference(
    prompt=prompt,
    proposal=proposal,
    method="smc-standard",
    n_particles=5,
    max_tokens=10,
    verbosity=1,
    return_record=True,  # use the version of SMC that keeps a record
)

├ Particle   0 (weight -0.2022). ` SELECT` :  SELECT
├ Particle   1 (weight -0.2022). ` SELECT` :  SELECT
├ Particle   2 (weight -0.2022). ` SELECT` :  SELECT
├ Particle   3 (weight -0.2022). ` SELECT` :  SELECT
├ Particle   4 (weight -0.2022). ` SELECT` :  SELECT
│ Step   1 average weight: -0.2022
└╼
├ Particle   0 (weight -0.4999). ` AV` :  SELECT AV
├ Particle   1 (weight -0.1050). ` ` :  SELECT 
├ Particle   2 (weight -0.4999). ` AV` :  SELECT AV
├ Particle   3 (weight -0.4999). ` AV` :  SELECT AV
├ Particle   4 (weight -0.4999). ` AV` :  SELECT AV
│ Step   2 average weight: -0.4074
└╼
├ Particle   0 (weight -0.5002). `G` :  SELECT AVG
├ Particle   1 (weight -2.5384). ` AV` :  SELECT  AV
├ Particle   2 (weight -0.5002). `G` :  SELECT AVG
├ Particle   3 (weight -0.5002). `G` :  SELECT AVG
├ Particle   4 (weight -0.5002). `G` :  SELECT AVG
│ Step   3 average weight: -0.6913
└╼
├ Particle   0 (weight -0.5815). `(` :  SELECT AVG(
├ Particle   1 (weight -2.5390). `G` :  SELECT  AVG
├ Pa

In [101]:
import pandas as pd


def _newest_tokens(step_contexts):
    """Return the last token of each particle's context, or [] where the context is empty."""
    newest = []
    for particle_context in step_contexts:
        newest.append(particle_context[-1] if len(particle_context) > 0 else [])
    return newest


# Tabulate the inference record; the saved output shows one row per step.
rec = pd.DataFrame(record)
rec['token'] = rec['context'].map(_newest_tokens)
# Negated step-to-step change in the average weight
# (biased (over)estimator of the marginalizing constant).
rec['surprisal estimate'] = rec['average weight'].diff().mul(-1)
rec

Unnamed: 0,step,context,weight,resample?,resampled as,average weight,token,surprisal estimate
0,0,"[[], [], [], [], []]","[0.0, 0.0, 0.0, 0.0, 0.0]",False,"[0, 1, 2, 3, 4]",0.0,"[[], [], [], [], []]",
1,1,"[[ SELECT], [ SELECT], [ SELECT], [ ], [ SELECT]]","[-0.20220203720224034, -0.20220203720224034, -...",False,"[0, 1, 2, 3, 4]",-0.202202,"[ SELECT, SELECT, SELECT, , SELECT]",0.202202
2,2,"[[ SELECT, ], [ SELECT, ], [ SELECT, AV], [...","[1.4068422106477152, -0.4998541155349871, -0.4...",True,"[0, 0, 0, 0, 1]",0.213856,"[ , , AV, SELECT, ]",-0.416058
3,3,"[[ SELECT, , vot], [ SELECT, , AV], [ SELE...","[-2.3212290152103714, -2.2195380989104763, -2....",False,"[0, 1, 2, 3, 4]",-2.258982,"[ vot, AV, vote, AV, AV]",2.472838
4,4,"[[ SELECT, , vot, e], [ SELECT, , AV, G], ...","[-7.4244995108957275, -2.2201792331383787, -2....",False,"[0, 1, 2, 3, 4]",-2.487693,"[e, G, ,, G, G]",0.228711
5,5,"[[ SELECT, , vot, e, ,], [ SELECT, , AV, G...","[-7.52079830473581, -2.3313144240434345, 0.182...",True,"[2, 2, 2, 2, 2]",-1.209402,"[,, (, , (, (]",-1.278291
6,6,"[[ SELECT, , vote, ,, , AV], [ SELECT, , ...","[-2.268713568316568, -2.325205749108936, -2.32...",False,"[0, 1, 2, 3, 4]",-2.302225,"[ AV, age, age, age, AV]",1.092822
7,7,"[[ SELECT, , vote, ,, , AV, G], [ SELECT, ...","[-2.2692935108304537, -2.11343842976186, -2.36...",False,"[0, 1, 2, 3, 4]",-2.270836,"[G, ,, , , G]",-0.031388
8,8,"[[ SELECT, , vote, ,, , AV, G, (], [ SELEC...","[-2.355483244898019, -2.140329013848588, -3.14...",False,"[0, 1, 2, 3, 4]",-2.541401,"[(, , FROM, FROM, (]",0.270564
9,9,"[[ SELECT, , vote, ,, , AV, G, (, age], [...","[-3.8550337856248773, -3.2130347271145263, -3....",False,"[0, 1, 2, 3, 4]",-3.250231,"[ age, COUNT, , , vote]",0.70883


In [105]:
import numpy as np  # np.exp is used below; numpy was never imported in this notebook

# Long format: one row per (step, particle). The four list-valued columns hold one entry
# per particle, so they are exploded together; scalar columns are repeated for each particle.
recs = rec.explode(['context', 'weight', 'resampled as', 'token']).reset_index(drop=True)
N_PARTS = len(rec['context'][0])

# After the reset, rows come in consecutive blocks of N_PARTS per step, so the running
# count within each block is the particle's index at that step.
recs['particle'] = recs.groupby(recs.index // N_PARTS).cumcount()

# Everything in the particle's context except its newest token.
recs['context_string'] = recs['context'].apply(lambda r: ''.join(r[:-1]))

# Exponentiate the weights (presumably log-weights — TODO confirm) and express each particle's
# weight relative to the step average. Exploded columns have object dtype, hence the float cast.
recs['exp_average weight'] = np.exp(recs['average weight'].astype(float))
recs['exp_weight'] = np.exp(recs['weight'].astype(float))
recs['prop_exp_weight'] = recs['exp_weight'] / recs['exp_average weight']
recs

Unnamed: 0,step,context,weight,resample?,resampled as,average weight,token,surprisal estimate,particle,context_string,exp_average weight,exp_weight,prop_exp_weight
0,0,[],0.0,False,0,0.0,[],,0,,1.0,1.0,1.0
1,0,[],0.0,False,1,0.0,[],,1,,1.0,1.0,1.0
2,0,[],0.0,False,2,0.0,[],,2,,1.0,1.0,1.0
3,0,[],0.0,False,3,0.0,[],,3,,1.0,1.0,1.0
4,0,[],0.0,False,4,0.0,[],,4,,1.0,1.0,1.0
5,1,[ SELECT],-0.202202,False,0,-0.202202,SELECT,0.202202,0,,0.81693,0.81693,1.0
6,1,[ SELECT],-0.202202,False,1,-0.202202,SELECT,0.202202,1,,0.81693,0.81693,1.0
7,1,[ SELECT],-0.202202,False,2,-0.202202,SELECT,0.202202,2,,0.81693,0.81693,1.0
8,1,[ ],-0.202202,False,3,-0.202202,,0.202202,3,,0.81693,0.81693,1.0
9,1,[ SELECT],-0.202202,False,4,-0.202202,SELECT,0.202202,4,,0.81693,0.81693,1.0


In [103]:
# Display the record's own plotting frame; the saved output has the same columns as the hand-built `recs`.
# NOTE(review): in the saved output the `particle` column reads NaN, 0.0, 1.0, 2.0, ... rather than
# cycling 0-4 within each step as in `recs` — confirm whether df_for_plotting() computes it correctly.
record.df_for_plotting()

Unnamed: 0,step,context,weight,resample?,resampled as,average weight,token,surprisal estimate,particle,context_string,exp_average weight,exp_weight,prop_exp_weight
0,0,[],0.0,False,0,0.0,[],,,,1.0,1.0,1.0
1,0,[],0.0,False,1,0.0,[],,0.0,,1.0,1.0,1.0
2,0,[],0.0,False,2,0.0,[],,1.0,,1.0,1.0,1.0
3,0,[],0.0,False,3,0.0,[],,2.0,,1.0,1.0,1.0
4,0,[],0.0,False,4,0.0,[],,3.0,,1.0,1.0,1.0
5,1,[ SELECT],-0.202202,False,0,-0.202202,SELECT,0.202202,4.0,,0.81693,0.81693,1.0
6,1,[ SELECT],-0.202202,False,1,-0.202202,SELECT,0.202202,5.0,,0.81693,0.81693,1.0
7,1,[ SELECT],-0.202202,False,2,-0.202202,SELECT,0.202202,6.0,,0.81693,0.81693,1.0
8,1,[ ],-0.202202,False,3,-0.202202,,0.202202,7.0,,0.81693,0.81693,1.0
9,1,[ SELECT],-0.202202,False,4,-0.202202,SELECT,0.202202,8.0,,0.81693,0.81693,1.0


In [98]:
# Display the posterior held by the particle approximation. In the saved output it maps each
# generated string to a value, and the values sum to approximately 1.
particle_approx.posterior


0,1
key,value
"SELECT vote, AVG( age)",0.11071171922499545
"SELECT vote, age , COUNT(",0.15052855263047052
"SELECT vote, age FROM data",0.5245225156847417
"SELECT vote, AVG(vote)",0.2142372124597923


-----------

## Plotting the particle beam

In [120]:
# Draw the history plot provided by the record object itself
# (presumably the packaged version of the hand-built beam plots below — TODO confirm).
record.plot_history()

In [107]:
import plotly.express as px
import plotly.graph_objects as go

# Static view of the particle beam built from `recs` (one row per step and particle):
# a bubble per row sized by relative weight, plus a line from each row to the next step.

# One color per particle index, interpolated from red to blue. n_colors interpolates between
# two endpoints, so at least two colors are requested even when there is a single particle.
particle_ids = recs['particle'].unique()
pallette = px.colors.n_colors(
    'rgb(255, 0, 0)', 'rgb(0, 125, 255)', max(len(particle_ids), 2), colortype='rgb'
)
color_map = {p: pallette[i % len(pallette)] for i, p in enumerate(particle_ids)}

fig = px.scatter(
    recs, x="step", y="particle", size='prop_exp_weight', hover_name="token", color='resampled as', text='token',
    hover_data=["context_string", "weight", "resample?"], opacity=0.2,
    range_x=None,
    width=1500, height=600,
    color_discrete_map=color_map
)

steps = recs['step'].unique()
resample_steps = recs[recs['resample?'] == True]['step'].unique()
no_resample_steps = recs[recs['resample?'] == False]['step'].unique()
resample_data = recs[recs['resample?'] == True]

# Vertical rule at every step where resampling happened.
for step in resample_steps:
    fig.add_vline(x=step, line_width=4, opacity=0.15, line_color="gray")

# Lines to the next step. The frame is filtered once per ancestor (not once per ancestor and
# step); rows stay in step order, so traces are added in the same order as before.
for ancestor in recs['resampled as'].unique():
    lineage = recs[recs['resampled as'] == ancestor]

    # Resampling steps: colored line from the ancestor's slot to the slot of the row that names it.
    resampled_rows = lineage[lineage['step'].isin(resample_steps)]
    for step, particle in zip(resampled_rows['step'], resampled_rows['particle']):
        fig.add_trace(go.Scatter(
            x=[step, step + 1], y=[ancestor, particle],
            mode='lines', line=dict(color=color_map[ancestor]), opacity=0.3,
            name=str(ancestor), showlegend=False))

    # Other steps: the particle keeps its slot, drawn as a gray horizontal line.
    carried_rows = lineage[lineage['step'].isin(no_resample_steps)]
    for step, particle in zip(carried_rows['step'], carried_rows['particle']):
        fig.add_trace(go.Scatter(
            x=[step, step + 1], y=[particle, particle],
            mode='lines', line=dict(color='gray'), opacity=0.2,
            name=str(ancestor), showlegend=False))

fig.update_traces(textposition='middle right')
fig.update_layout(plot_bgcolor="#fff", showlegend=False)
fig.show()

In [45]:
# # two frames at a time
# for f in fig.frames:
#     f.layout.update(xaxis_range = [int(f.name)-2,int(f.name)+2])
# fig.show()

In [12]:
import plotly.express as px
import plotly.graph_objects as go

# Animated view of the particle beam. Expects `recs`: one row per (step, particle) with the
# columns step, particle, token, context_string, weight, resample?, resampled as, prop_exp_weight.

MAX_MARKER_PX = 20  # on-screen diameter of the marker with the largest relative weight

# One color per particle index, interpolated from red to blue. The palette is sized from the
# data (at least two colors, since n_colors interpolates between two endpoints).
particle_ids = list(recs['particle'].unique())
pallette = px.colors.n_colors(
    'rgb(255, 0, 0)', 'rgb(0, 125, 255)', max(len(particle_ids), 2), colortype='rgb'
)
color_map = {p: pallette[i % len(pallette)] for i, p in enumerate(particle_ids)}

steps = sorted(recs['step'].unique())
# Computed here so the cell does not rely on a variable left behind by the static-plot cell.
resample_steps = recs[recs['resample?'] == True]['step'].unique()

# go.Scatter reads marker sizes as pixels, so scale the relative weights with the usual
# area-based bubble sizing: sizeref = 2 * max(size) / max_px ** 2.
largest_weight = float(recs['prop_exp_weight'].max())
marker_sizeref = 2.0 * largest_weight / MAX_MARKER_PX ** 2 if largest_weight > 0 else 1.0


def _hover_text(row):
    """Build the hover label for one (step, particle) row."""
    return (
        f"Token: {row['token']}<br>Context: {row['context_string']}"
        f"<br>Weight: {row['weight']}<br>Resample?: {row['resample?']}"
    )


def _segments(rows, start_col):
    """Return (x, y) lists holding one None-separated line segment per row.

    Each segment runs from (step, rows[start_col]) to (step + 1, particle), so any number
    of segments can live in a single trace.
    """
    xs, ys = [], []
    for step, y_start, y_end in zip(rows['step'], rows[start_col], rows['particle']):
        xs += [int(step), int(step) + 1, None]
        ys += [int(y_start), int(y_end), None]
    return xs, ys


def _traces_up_to(step):
    """Return the traces showing every recorded step <= `step`.

    The list always has the same layout: markers, gray carried-over links, then one ancestry
    trace per particle. Animation frames update a figure's existing traces by position, so the
    base figure and every frame must agree on this layout.
    """
    shown = recs[recs['step'] <= step]
    resampled = shown['step'].isin(resample_steps)

    markers = go.Scatter(
        x=shown["step"],
        y=shown["particle"],
        mode='markers+text',
        marker=dict(
            size=shown['prop_exp_weight'],
            sizemode='area',
            sizeref=marker_sizeref,
            color=[color_map[ancestor] for ancestor in shown['resampled as']],
            opacity=0.6,
        ),
        # Step-0 rows carry an empty list instead of a token; show them as blank labels.
        text=[token if isinstance(token, str) else '' for token in shown['token']],
        hoverinfo='text',
        hovertext=shown.apply(_hover_text, axis=1),
        textposition='top center',
    )

    # Steps without resampling: each particle keeps its slot (gray horizontal link).
    carried_x, carried_y = _segments(shown[~resampled], 'particle')
    carried = go.Scatter(
        x=carried_x, y=carried_y, mode='lines',
        line=dict(color='gray'), opacity=0.2, showlegend=False,
    )

    # Resampling steps: link from the ancestor's slot to the row that names it, colored by ancestor.
    ancestry = []
    for ancestor in particle_ids:
        from_ancestor = shown[resampled & (shown['resampled as'] == ancestor)]
        link_x, link_y = _segments(from_ancestor, 'resampled as')
        ancestry.append(go.Scatter(
            x=link_x, y=link_y, mode='lines',
            line=dict(color=color_map[ancestor]), opacity=0.3,
            name=str(ancestor), showlegend=False,
        ))

    return [markers, carried] + ancestry


def _resample_rules(step):
    """Return full-height vertical rules for every resampling step <= `step`."""
    return [
        go.layout.Shape(
            type="line",
            x0=int(s), x1=int(s), xref="x",
            y0=0, y1=1, yref="paper",  # paper coordinates: span the whole plot height
            line=dict(width=4, color="gray"),
            opacity=0.15,
        )
        for s in resample_steps if s <= step
    ]


# One frame per step, each showing everything up to and including that step.
frames = [
    go.Frame(
        data=_traces_up_to(step),
        name=str(step),
        layout=go.Layout(shapes=_resample_rules(step)),
    )
    for step in steps
]

# The base figure shows the complete beam, so a static render is still informative.
fig = go.Figure(data=_traces_up_to(steps[-1]))

# Update layout with animation controls
fig.update_layout(
    width=1200, height=500,
    # Links extend one step past the last recorded step; leave room for them and the labels.
    xaxis=dict(range=[int(steps[0]) - 1, int(steps[-1]) + 2]),
    yaxis=dict(type='category'),  # Ensure the y-axis is categorical
    plot_bgcolor="#fff",
    showlegend=False,
    shapes=_resample_rules(steps[-1]),
    updatemenus=[{
        "buttons": [
            {
                "args": [None, {"frame": {"duration": 500, "redraw": True}, "fromcurrent": True}],
                "label": "Play",
                "method": "animate"
            },
            {
                "args": [[None], {"frame": {"duration": 0, "redraw": True}, "mode": "immediate", "transition": {"duration": 0}}],
                "label": "Pause",
                "method": "animate"
            }
        ],
        "direction": "left",
        "pad": {"r": 10, "t": 87},
        "showactive": False,
        "type": "buttons",
        "x": 0.1,
        "xanchor": "right",
        "y": 0,
        "yanchor": "top"
    }],
    sliders=[{
        "steps": [
            {
                "args": [
                    [str(step)],
                    {"frame": {"duration": 300, "redraw": True}, "mode": "immediate", "transition": {"duration": 300}}
                ],
                "label": str(step),
                "method": "animate"
            } for step in steps
        ],
        "x": 0.1,
        "xanchor": "left",
        "y": 0,
        "yanchor": "top",
        "currentvalue": {
            "font": {"size": 20},
            "prefix": "Step:",
            "visible": True,
            "xanchor": "right"
        },
        "transition": {"duration": 300, "easing": "cubic-in-out"}
    }]
)

# Set frames to the figure
fig.frames = frames

# Show the figure
fig.show()
