In [41]:
import openai
import os
import pandas as pd
import re
import datetime
import pickle

from tqdm import tqdm

In [2]:
# Configure OpenAI API

# The key is read from the environment; a KeyError is raised here if
# OPENAI_API_KEY is not set.
openai.api_key = os.environ["OPENAI_API_KEY"]

# Request parameters that query_codex forwards to openai.Completion.create.
ENGINE = "code-davinci-002"  # passed as `engine`
MAX_TOKENS = 512  # passed as `max_tokens`
N_SAMPLES = 1  # passed as `n`
TEMPERATURE = 0.0  # passed as `temperature`
STOP = "/**"  # passed as `stop`; the same marker opens every example header in the prompt

In [3]:
# Location of the WebPPL generative model source; its text is later passed to
# construct_prompt as the global prompt header.
GENERATIVE_MODEL_FILE = "data/generative_model.js"

# Read the whole file into a single string (text mode is the default for open()).
with open(GENERATIVE_MODEL_FILE) as model_file:
    generative_model_text = model_file.read()

In [4]:
# Show the raw model source (print keeps the newlines readable).
print(generative_model_text)

/**
 * WebPPL generative model of a blockworld.
 */
 var makeBlockWorld = function () {

    //// Distributions and parameters ////
    var truncGeom = function (p, m, n) {
        if (m > n) {
            return uniformDraw(_.range(1, n + 1));
        } else {
            return flip(p) ? truncGeom(p, m + 1, n) : m;
        }
    }

    var dim = 10;
    var tableSize = 100;
    var color = function () { return flip() ? 'red' : 'yellow' };
    var monoColor = flip();
    var stackHeight = function () { return truncGeom(0.5, 1, 8) };
    var numStacks = truncGeom(0.5, 1, 8);
    var xpositions = _.range(worldWidth / 2 - tableSize, worldWidth / 2 + tableSize + 20, 20);

    //// Object definitions ////
    var ground = {
        shape: 'rect',
        static: true,
        dims: [100000 * worldWidth, 10],
        x: worldWidth / 2,
        y: worldHeight
    }

    var table = {
        shape: 'rect',
        static: false,
        dims: [tableSize, tableSize],
        x: worldWidth / 2

In [5]:
INPUT_FILE = "data/phys_lang_examples.csv"

# keep_default_na=False makes empty cells load as "" instead of NaN, so that
# construct_prompt can skip unused phrase slots with a plain truthiness check.
df = pd.read_csv(
    INPUT_FILE,
    index_col="task_id",
    keep_default_na=False,
)
df

Unnamed: 0_level_0,language_full,language_phrase_1,language_phrase_2,language_phrase_3,code_phrase_1,code_phrase_2,code_phrase_3
task_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
1,There is a stack of yellow blocks on the left ...,There is a stack of yellow blocks on the left ...,There are a few red blocks on the middle of th...,,"condition(filter(isYellow, filter(isOnLeft, wo...","condition(filter(isOnMiddle, filter(isRed, wor...",
2,There is a tall stack of yellow blocks on the ...,There is a tall stack of yellow blocks on the ...,There are some red blocks near the yellow stack.,,"condition(filter(isTall, filter(isOnMiddle, fi...","condition(filter(isRed, filter(isNear(filter(i...",
3,"Half of the blocks are red, and half of the bl...",Half of the blocks are red.,Half of the blocks are yellow.,,"condition(filter(isRed, world.blocks).length =...","condition(filter(isYellow, world.blocks).lengt...",
4,"There are several stacks of yellow blocks, and...",There are several stacks of yellow blocks.,There is one stack of red blocks.,,"condition(filter(isYellow, world.stacks).lengt...","condition(filter(isRed, world.stacks).length =...",
5,There are two stacks of yellow blocks on the l...,There are two stacks of yellow blocks on the l...,There are also two stacks of red blocks on the...,,"condition(filter(isOnLeft, (filter(isYellow, w...","condition(filter(isOnRight, (filter(isRed, wor...",
6,There are two stacks of yellow blocks on the l...,There are two stacks of yellow blocks on the l...,There are also two stacks of red blocks on the...,The table is bumped from the left.,"condition(filter(isOnLeft, (filter(isYellow, w...","condition(filter(isOnRight, (filter(isRed, wor...",condition(isOnLeft(world.force));
7,"There is a short stack of red blocks, and ther...",There is a short stack of red blocks.,There is a tall stack of yellow blocks near th...,,"condition(filter(isShort, filter(isRed, world....","condition(filter(isTall, filter(isYellow, filt...",
8,There are many stacks of blocks on the table. ...,There are many stacks of blocks on the table.,All of the blocks are the right side are red.,Most of the blocks on the left side are yellow.,condition(world.stacks.length > 4);,"condition(all(isRed, filter(isOnRight, world.b...","condition(filter(isYellow, filter(isOnLeft, wo..."
9,There are more yellow blocks than red blocks o...,There are more yellow blocks than red blocks o...,There are more red blocks than yellow blocks o...,,"condition(filter(isRed, world.blocks).length >...","condition(filter(isOnEdge, filter(isRed, world...",
10,There are a short stack of red blocks on the l...,There are a short stack of red blocks on the l...,There is a tall stack of yellow blocks on the ...,The table is bumped from the left.,"condition(filter(isShort, filter(isOnLeft, (fi...","condition(filter(isTall, filter(isOnRight, (fi...",condition(isOnLeft(world.force));


In [6]:
TEMPLATE_EXAMPLE_HEADER = """
/**
 * Example:
 * {language_full}
 */
"""

def construct_prompt(df, task_id, global_header):
    """Build a few-shot prompt that leaves `task_id` for the model to complete.

    The prompt consists of `global_header`, then every other row of `df`
    rendered as a worked example (an example header followed by
    `// phrase` / code pairs), and finally the example header of the held-out
    task with no body underneath it.

    Args:
        df: Task table indexed by task id, with `language_full`,
            `language_phrase_{1..3}` and `code_phrase_{1..3}` columns.
        task_id: Index label of the task to hold out.
        global_header: Text placed at the very start of the prompt.

    Returns:
        The assembled prompt string.
    """
    # hold one out
    held_out_index = df.loc[[task_id]].index
    pieces = [global_header + "\n"]

    for _, example in df.drop(held_out_index).iterrows():
        pieces.append(TEMPLATE_EXAMPLE_HEADER.format(language_full=example["language_full"]))

        for slot in (1, 2, 3):
            phrase = example[f"language_phrase_{slot}"]
            if not phrase:
                # Unused phrase slots are empty and contribute nothing.
                continue
            pieces.append("\n" + "// " + phrase + "\n" + example[f"code_phrase_{slot}"] + "\n")

    # Close with the held-out task's header so the model writes its body.
    target_text = df.loc[[task_id], "language_full"].item()
    pieces.append(TEMPLATE_EXAMPLE_HEADER.format(language_full=target_text))

    return "".join(pieces)

# Build one sample prompt (task 1 held out) so it can be inspected below.
prompt = construct_prompt(df, task_id=1, global_header=generative_model_text)

In [7]:
# Show the assembled prompt text.
print(prompt)

/**
 * WebPPL generative model of a blockworld.
 */
 var makeBlockWorld = function () {

    //// Distributions and parameters ////
    var truncGeom = function (p, m, n) {
        if (m > n) {
            return uniformDraw(_.range(1, n + 1));
        } else {
            return flip(p) ? truncGeom(p, m + 1, n) : m;
        }
    }

    var dim = 10;
    var tableSize = 100;
    var color = function () { return flip() ? 'red' : 'yellow' };
    var monoColor = flip();
    var stackHeight = function () { return truncGeom(0.5, 1, 8) };
    var numStacks = truncGeom(0.5, 1, 8);
    var xpositions = _.range(worldWidth / 2 - tableSize, worldWidth / 2 + tableSize + 20, 20);

    //// Object definitions ////
    var ground = {
        shape: 'rect',
        static: true,
        dims: [100000 * worldWidth, 10],
        x: worldWidth / 2,
        y: worldHeight
    }

    var table = {
        shape: 'rect',
        static: false,
        dims: [tableSize, tableSize],
        x: worldWidth / 2

In [17]:
def query_codex(prompt):
    """Request a completion for `prompt` from the OpenAI Completion endpoint.

    The engine, temperature, number of samples, stop sequence and token limit
    are taken from the module-level configuration constants.

    Returns:
        The response object from `openai.Completion.create`, unchanged.
    """
    request = dict(
        engine=ENGINE,
        prompt=prompt,
        temperature=TEMPERATURE,
        n=N_SAMPLES,
        stop=STOP,
        max_tokens=MAX_TOKENS,
        logprobs=None,
    )
    return openai.Completion.create(**request)

In [118]:
def extract_conditions(text):
    """Return every `condition...;` statement found in `text`.

    A match starts at the literal word "condition". Because the wildcard is
    greedy and does not cross newlines, it runs to the last ";" on that line.
    """
    statement = re.compile(r"condition.*;")
    return [found.group(0) for found in statement.finditer(text)]

def extract_language(text):
    """Return the text of every `// ` comment in `text` that ends with a newline.

    Only the part after the "// " marker is returned. A comment on a final
    line with no trailing newline is not matched.
    """
    comment = re.compile(r"// (.*)\n")
    return [found.group(1) for found in comment.finditer(text)]

def parse_choice(choice):
    """Flatten one completion choice into a dict of parsed fields.

    The record always holds `choice_index`, `finish_reason` and `text`.
    Comments and condition statements extracted from the text are paired by
    position and stored as `codex_language_phrase_{k}` / `codex_code_phrase_{k}`
    (k starting at 1). Pairing stops at the shorter of the two lists.

    A warning is printed when the choice did not finish with reason "stop".
    """
    if choice.finish_reason != "stop":
        print(f"WARNING: Completion choice {choice.index} encountered non-terminal finish reason: {choice.finish_reason}")

    parsed = dict(
        choice_index=choice.index,
        finish_reason=choice.finish_reason,
        text=choice.text,
    )

    phrases = extract_language(choice.text)
    statements = extract_conditions(choice.text)
    for position, (language, code) in enumerate(zip(phrases, statements), start=1):
        parsed[f"codex_language_phrase_{position}"] = language
        parsed[f"codex_code_phrase_{position}"] = code

    return parsed

In [119]:
def run_experiment(df: pd.DataFrame, restore_ckpt: str = None):
    """Collect one completion per task and parse every returned choice.

    Args:
        df: Task table indexed by task id. Ids are formatted with `:03d` when
            building checkpoint file names, so they must support integer
            formatting.
        restore_ckpt: Name of a run directory under `experiments/`. When
            given, previously pickled completions are loaded from it and no
            API request is made. When None, every task is sent to the API and
            each response is pickled into a new timestamped run directory.

    Returns:
        A dict `{"results": [...]}` with one flat record per completion
        choice: the task id plus the fields produced by `parse_choice`.
    """
    completions = []

    if restore_ckpt is not None:
        # Load all completions from checkpoint
        ckpt_dir = os.path.join("experiments", restore_ckpt)
        for task_id in df.index:
            ckpt_path = os.path.join(ckpt_dir, f"completion_task_{task_id:03d}.pkl")
            # NOTE: unpickling can execute arbitrary code; only restore
            # checkpoints that were written by a trusted run.
            with open(ckpt_path, "rb") as ckpt_file:
                completions.append(pickle.load(ckpt_file))
        print(f"Restored completions from {restore_ckpt}")
    else:
        # Make new checkpoint
        run_name = datetime.datetime.now().strftime('run-%Y-%m-%d-%H-%M-%S')
        ckpt_dir = os.path.join("experiments", run_name)
        os.makedirs(ckpt_dir, exist_ok=True)

        # Query OpenAI for completions, saving each one as soon as it arrives
        for task_id in tqdm(df.index):
            completion = query_codex(construct_prompt(df, task_id, generative_model_text))
            completions.append(completion)

            ckpt_path = os.path.join(ckpt_dir, f"completion_task_{task_id:03d}.pkl")
            with open(ckpt_path, "wb") as ckpt_file:
                pickle.dump(completion, ckpt_file)

    # One record per (task, choice); parsed fields follow the task id.
    records = [
        {"task_id": task_id, **parse_choice(choice)}
        for task_id, completion in zip(df.index, completions)
        for choice in completion.choices
    ]

    return {"results": records}

In [120]:
# Passing restore_ckpt loads the pickled completions from that run directory
# under experiments/ instead of querying the API again.
results_json = run_experiment(df, restore_ckpt="run-2023-01-27-11-52-34")

Restored completions from run-2023-01-27-11-52-34


In [121]:
# Inspect the parsed records (one dict per completion choice).
results_json

{'results': [{'task_id': 1,
   'choice_index': 0,
   'finish_reason': 'stop',
   'text': '\n// There is a stack of yellow blocks on the left side of the table.\ncondition(filter(isOnLeft, (filter(isYellow, world.stacks))).length == 1);\n\n// There are a few red blocks on the middle of the table.\ncondition(filter(isOnMiddle, (filter(isRed, world.blocks))).length > 0 && filter(isOnMiddle, (filter(isRed, world.blocks))).length < 4);\n\n',
   'codex_language_phrase_1': 'There is a stack of yellow blocks on the left side of the table.',
   'codex_code_phrase_1': 'condition(filter(isOnLeft, (filter(isYellow, world.stacks))).length == 1);',
   'codex_language_phrase_2': 'There are a few red blocks on the middle of the table.',
   'codex_code_phrase_2': 'condition(filter(isOnMiddle, (filter(isRed, world.blocks))).length > 0 && filter(isOnMiddle, (filter(isRed, world.blocks))).length < 4);'},
  {'task_id': 2,
   'choice_index': 0,
   'finish_reason': 'stop',
   'text': '\n// There is a tall st

In [131]:
# One row per completion choice, indexed by task id so it can be joined with `df`.
df_results = pd.DataFrame(results_json["results"]).set_index("task_id")
df_results

Unnamed: 0_level_0,choice_index,finish_reason,text,codex_language_phrase_1,codex_code_phrase_1,codex_language_phrase_2,codex_code_phrase_2,codex_language_phrase_3,codex_code_phrase_3
task_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
1,0,stop,\n// There is a stack of yellow blocks on the ...,There is a stack of yellow blocks on the left ...,"condition(filter(isOnLeft, (filter(isYellow, w...",There are a few red blocks on the middle of th...,"condition(filter(isOnMiddle, (filter(isRed, wo...",,
2,0,stop,\n// There is a tall stack of yellow blocks on...,There is a tall stack of yellow blocks on the ...,"condition(filter(isTall, filter(isOnMiddle, (f...",There are some red blocks near the yellow stack.,"condition(filter(isNear(filter(isTall, filter(...",,
3,0,stop,\n// Half of the blocks are red.\ncondition(fi...,Half of the blocks are red.,"condition(filter(isRed, world.blocks).length =...",Half of the blocks are yellow.,"condition(filter(isYellow, world.blocks).lengt...",,
4,0,stop,\n// There are several stacks of yellow blocks...,There are several stacks of yellow blocks.,"condition(filter(isYellow, world.stacks).lengt...",There is one stack of red blocks.,"condition(filter(isRed, world.stacks).length =...",,
5,0,stop,\n// There are two stacks of yellow blocks on ...,There are two stacks of yellow blocks on the l...,"condition(filter(isOnLeft, (filter(isYellow, w...",There are also two stacks of red blocks on the...,"condition(filter(isOnRight, (filter(isRed, wor...",,
6,0,stop,\n// There are two stacks of yellow blocks on ...,There are two stacks of yellow blocks on the l...,"condition(filter(isOnLeft, (filter(isYellow, w...",There are also two stacks of red blocks on the...,"condition(filter(isOnRight, (filter(isRed, wor...",The table is bumped from the left.,condition(isOnLeft(world.force));
7,0,stop,\n// There is a short stack of red blocks.\nco...,There is a short stack of red blocks.,"condition(filter(isShort, filter(isRed, world....",There is a tall stack of yellow blocks near th...,"condition(filter(isTall, filter(isYellow, filt...",,
8,0,stop,\n// There are many stacks of blocks on the ta...,There are many stacks of blocks on the table.,condition(world.stacks.length > 5);,All of the blocks are the right side are red.,"condition(filter(isOnRight, world.blocks).leng...",Most of the blocks on the left side are yellow.,"condition(filter(isOnLeft, world.blocks).lengt..."
9,0,stop,\n// There are more yellow blocks than red blo...,There are more yellow blocks than red blocks o...,"condition(filter(isYellow, world.blocks).lengt...",There are more red blocks than yellow blocks o...,"condition(filter(isRed, filter(isOnEdge, world...",,
10,0,stop,\n// There are a short stack of red blocks on ...,There are a short stack of red blocks on the l...,"condition(filter(isShort, filter(isRed, filter...",There is a tall stack of yellow blocks on the ...,"condition(filter(isTall, filter(isYellow, filt...",The table is bumped from the left.,condition(isOnLeft(world.force));


In [140]:
# Join ground truth with the Codex output once, save it next to the run's
# checkpoints, then display it.
results_csv_path = os.path.join("experiments", "run-2023-01-27-11-52-34", "results.csv")
df_joined = df.join(df_results)
df_joined.to_csv(results_csv_path)
df_joined

Unnamed: 0_level_0,language_full,language_phrase_1,language_phrase_2,language_phrase_3,code_phrase_1,code_phrase_2,code_phrase_3,choice_index,finish_reason,text,codex_language_phrase_1,codex_code_phrase_1,codex_language_phrase_2,codex_code_phrase_2,codex_language_phrase_3,codex_code_phrase_3
task_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1
1,There is a stack of yellow blocks on the left ...,There is a stack of yellow blocks on the left ...,There are a few red blocks on the middle of th...,,"condition(filter(isYellow, filter(isOnLeft, wo...","condition(filter(isOnMiddle, filter(isRed, wor...",,0,stop,\n// There is a stack of yellow blocks on the ...,There is a stack of yellow blocks on the left ...,"condition(filter(isOnLeft, (filter(isYellow, w...",There are a few red blocks on the middle of th...,"condition(filter(isOnMiddle, (filter(isRed, wo...",,
2,There is a tall stack of yellow blocks on the ...,There is a tall stack of yellow blocks on the ...,There are some red blocks near the yellow stack.,,"condition(filter(isTall, filter(isOnMiddle, fi...","condition(filter(isRed, filter(isNear(filter(i...",,0,stop,\n// There is a tall stack of yellow blocks on...,There is a tall stack of yellow blocks on the ...,"condition(filter(isTall, filter(isOnMiddle, (f...",There are some red blocks near the yellow stack.,"condition(filter(isNear(filter(isTall, filter(...",,
3,"Half of the blocks are red, and half of the bl...",Half of the blocks are red.,Half of the blocks are yellow.,,"condition(filter(isRed, world.blocks).length =...","condition(filter(isYellow, world.blocks).lengt...",,0,stop,\n// Half of the blocks are red.\ncondition(fi...,Half of the blocks are red.,"condition(filter(isRed, world.blocks).length =...",Half of the blocks are yellow.,"condition(filter(isYellow, world.blocks).lengt...",,
4,"There are several stacks of yellow blocks, and...",There are several stacks of yellow blocks.,There is one stack of red blocks.,,"condition(filter(isYellow, world.stacks).lengt...","condition(filter(isRed, world.stacks).length =...",,0,stop,\n// There are several stacks of yellow blocks...,There are several stacks of yellow blocks.,"condition(filter(isYellow, world.stacks).lengt...",There is one stack of red blocks.,"condition(filter(isRed, world.stacks).length =...",,
5,There are two stacks of yellow blocks on the l...,There are two stacks of yellow blocks on the l...,There are also two stacks of red blocks on the...,,"condition(filter(isOnLeft, (filter(isYellow, w...","condition(filter(isOnRight, (filter(isRed, wor...",,0,stop,\n// There are two stacks of yellow blocks on ...,There are two stacks of yellow blocks on the l...,"condition(filter(isOnLeft, (filter(isYellow, w...",There are also two stacks of red blocks on the...,"condition(filter(isOnRight, (filter(isRed, wor...",,
6,There are two stacks of yellow blocks on the l...,There are two stacks of yellow blocks on the l...,There are also two stacks of red blocks on the...,The table is bumped from the left.,"condition(filter(isOnLeft, (filter(isYellow, w...","condition(filter(isOnRight, (filter(isRed, wor...",condition(isOnLeft(world.force));,0,stop,\n// There are two stacks of yellow blocks on ...,There are two stacks of yellow blocks on the l...,"condition(filter(isOnLeft, (filter(isYellow, w...",There are also two stacks of red blocks on the...,"condition(filter(isOnRight, (filter(isRed, wor...",The table is bumped from the left.,condition(isOnLeft(world.force));
7,"There is a short stack of red blocks, and ther...",There is a short stack of red blocks.,There is a tall stack of yellow blocks near th...,,"condition(filter(isShort, filter(isRed, world....","condition(filter(isTall, filter(isYellow, filt...",,0,stop,\n// There is a short stack of red blocks.\nco...,There is a short stack of red blocks.,"condition(filter(isShort, filter(isRed, world....",There is a tall stack of yellow blocks near th...,"condition(filter(isTall, filter(isYellow, filt...",,
8,There are many stacks of blocks on the table. ...,There are many stacks of blocks on the table.,All of the blocks are the right side are red.,Most of the blocks on the left side are yellow.,condition(world.stacks.length > 4);,"condition(all(isRed, filter(isOnRight, world.b...","condition(filter(isYellow, filter(isOnLeft, wo...",0,stop,\n// There are many stacks of blocks on the ta...,There are many stacks of blocks on the table.,condition(world.stacks.length > 5);,All of the blocks are the right side are red.,"condition(filter(isOnRight, world.blocks).leng...",Most of the blocks on the left side are yellow.,"condition(filter(isOnLeft, world.blocks).lengt..."
9,There are more yellow blocks than red blocks o...,There are more yellow blocks than red blocks o...,There are more red blocks than yellow blocks o...,,"condition(filter(isRed, world.blocks).length >...","condition(filter(isOnEdge, filter(isRed, world...",,0,stop,\n// There are more yellow blocks than red blo...,There are more yellow blocks than red blocks o...,"condition(filter(isYellow, world.blocks).lengt...",There are more red blocks than yellow blocks o...,"condition(filter(isRed, filter(isOnEdge, world...",,
10,There are a short stack of red blocks on the l...,There are a short stack of red blocks on the l...,There is a tall stack of yellow blocks on the ...,The table is bumped from the left.,"condition(filter(isShort, filter(isOnLeft, (fi...","condition(filter(isTall, filter(isOnRight, (fi...",condition(isOnLeft(world.force));,0,stop,\n// There are a short stack of red blocks on ...,There are a short stack of red blocks on the l...,"condition(filter(isShort, filter(isRed, filter...",There is a tall stack of yellow blocks on the ...,"condition(filter(isTall, filter(isYellow, filt...",The table is bumped from the left.,condition(isOnLeft(world.force));


In [139]:
def _phrase_text(row, column):
    """Return `row[column]` for printing, or "" if the column is absent or the cell is NaN.

    parse_choice only creates a `codex_*_phrase_{i}` key for phrases it
    actually found. A column can therefore be missing entirely (when no
    completion produced an i-th phrase) or hold NaN for some rows.
    """
    value = row.get(column, "")
    return "" if pd.isna(value) else value


# Print the ground-truth annotation next to the Codex output for every task,
# one phrase slot at a time.
for task_id, row in df.join(df_results).iterrows():
    print("---------")
    print(f"task_id {task_id}: {row['language_full']}")
    print()

    for i in range(1, 4):
        truth_language = _phrase_text(row, f"language_phrase_{i}")
        codex_language = _phrase_text(row, f"codex_language_phrase_{i}")
        truth_code = _phrase_text(row, f"code_phrase_{i}")
        codex_code = _phrase_text(row, f"codex_code_phrase_{i}")

        print(f"Language {i}")
        print(f"TRUTH: {truth_language}")
        print(f"CODEX: {codex_language}")
        print()

        print(f"Code {i}")
        print(f"TRUTH: {truth_code}")
        print(f"CODEX: {codex_code}")
        print()

---------
task_id 1: There is a stack of yellow blocks on the left side of the table, and there are a few red blocks on the middle of the table.

Language 1
TRUTH: There is a stack of yellow blocks on the left side of the table.
CODEX: There is a stack of yellow blocks on the left side of the table.

Code 1
TRUTH: condition(filter(isYellow, filter(isOnLeft, world.stacks)).length == 1);
CODEX: condition(filter(isOnLeft, (filter(isYellow, world.stacks))).length == 1);

Language 2
TRUTH: There are a few red blocks on the middle of the table.
CODEX: There are a few red blocks on the middle of the table.

Code 2
TRUTH: condition(filter(isOnMiddle, filter(isRed, world.blocks)).length > 0 && filter(isOnMiddle, filter(isRed, world.blocks)).length <= 3);
CODEX: condition(filter(isOnMiddle, (filter(isRed, world.blocks))).length > 0 && filter(isOnMiddle, (filter(isRed, world.blocks))).length < 4);

Language 3
TRUTH: 
CODEX: nan

Code 3
TRUTH: 
CODEX: nan

---------
task_id 2: There is a tall stac