# Experiment 4

Task 4: Given a sequence of actions $x_1,...,x_{t-1}$, which $x_t$ is legal?

Going from an initial state, ensure that the model can follow through intermediary states, and output a final legal action on the second to last state.


Input:  An initial state and a list of actions that were taken.

Output: An action $x_t$. We test whether the action is legal

In [1]:
#install dependencies
!pip install openai
! pip install pandas



In [2]:
#imports
from openai import OpenAI
from pydantic import BaseModel
import json
import pandas as pd
from tqdm import tqdm

import os
from dotenv import load_dotenv

# Load environment variables from .env file
load_dotenv()

client = OpenAI(api_key=os.getenv('OPENAI_API_KEY'))

In [3]:
#utils
import re
import random

def parse_expression(expression):
    if expression[0] == "(":
        #remove the parentheses
        expression = expression[1:-1]
    parts = expression.split()
    predicate = parts[0]
    args = parts[1:] if len(parts) > 1 else []
    return predicate, args

def parse_state(state):
    pattern = r"\((.*?)\)"
    matches = re.findall(pattern, state)
    expressions = []
    for match in matches:
        predicate, args = parse_expression(match)
        expressions.append((predicate, args))
    return expressions



# Possible actions
# - (pick-up b1): Pick up block b1. --> For this, a block needs to be clear and the arm needs to be empty 
# - (put-down b1): Put down block b1. --> For this, a block needs to be in the arm 
# - (stack b1 b2): Stack block b1 on top of block b2. --> For this, a block needs to be in the arm and the arm needs to be empty
# - (unstack b1 b2): Unstack block b1 from block b2. --> For this, a block needs to be clear and the arm needs to be empty

def is_block_clear(state, block):
    clear = False
    for expression in state:
        predicate, args = expression
        if predicate == "clear" and block in args:
            # the block has been marked as clear, so return true
            return True 
        elif predicate == "on" and block == args[0]:
            clear = True
        elif predicate == "on" and block == args[1]:
            #Another block is currently on top of this block, immediately return false
            return False
    return clear

def is_arm_empty(state):
    for expression in state:
        predicate, args = expression
        if predicate == "arm-empty":
            return True
    return False

def is_block_on_other(state, block, other_block):
    for expression in state:
        predicate, args = expression
        if predicate == "on" and block == args[0] and other_block == args[1]:
            return True
    return False

def is_block_in_arm(state, block):
    for expression in state:
        predicate, args = expression
        if predicate == "holding" and block in args:
            return True
    return False

def is_action_legal(state, action):
    """Given a state and an action, return whether the action is legal or not."""
    
    state = parse_state(state)
    #Get the blocks in the action
    action_predicate, action_blocks = parse_expression(action)
    if action_predicate == "pick-up":
        block = action_blocks[0]
        #check that the arm is empty and the block is clear
        arm_empty = is_arm_empty(state)
        block_clear = is_block_clear(state, block)
        if arm_empty and block_clear:
            return True
        else:
            return False
    elif action_predicate == "put-down":
        block = action_blocks[0]
        #check that the arm is empty and the block is in the arm
        block_in_arm = is_block_in_arm(state, block)
        if block_in_arm:
            return True
        else:
            return False
    elif action_predicate == "stack":
        block = action_blocks[0]
        other_block = action_blocks[1]
        block_in_arm = is_block_in_arm(state, block)
        other_block_clear = is_block_clear(state, other_block)
        if block_in_arm and other_block_clear:
            return True
        else:
            return False 
    elif action_predicate == "unstack":
        block = action_blocks[0]
        other_block = action_blocks[1]
        block_clear = is_block_clear(state, block)
        block_on_other = is_block_on_other(state, block, other_block)
        arm_empty = is_arm_empty(state)
        if block_clear and block_on_other and arm_empty:
            return True
        else:
            return False

def generate_illegal_action(state):
    """Returns a legal action given a state"""
    parsed_state = parse_state(state)
    #Get the blocks in the state
    blocks = []
    for expression in parsed_state:
        predicate, args = expression
        blocks.extend(args)
    blocks = set(blocks)
    #Randomly select an action
    action = random.choice(["pick-up", "put-down", "stack", "unstack"])
    if action == "pick-up" or action == "put-down":
        block = random.choice(list(blocks))
        action = f"({action} {block})"
        if not is_action_legal(state, action):
            return state
        else:
            return generate_illegal_action(state)
    elif action == "stack" or action == "unstack":
        block = random.choice(list(blocks))
        other_block = random.choice(list(blocks - {block}))
        action = f"({action} {block} {other_block})"
        if not is_action_legal(state, action):
            return state
        else:
            return generate_illegal_action(state)


Dataset curation: We keep 1000 (initial_state, [actions])

In [6]:
#randomly choose 1000 states from the full dataset
import random
import json
with open("../blockworld_dataset.json", "r") as f:
    full_dataset = json.load(f)

experiment_dataset = []
N = 1000
#Get valid actions  
seen = set()
while len(experiment_dataset) < N:
    group = random.choice(full_dataset)
    if group["goal"] not in seen:
        actions = group["actions"][:-1]
        state = group["states"][0]
        datapoint = {"initial_state": state, "actions": actions, "implicit_state": group["states"][-2]}
        experiment_dataset.append(datapoint)
        seen.add(group["goal"])
random.shuffle(experiment_dataset)
with open("experiment4_dataset.json", "w") as f:
    json.dump(experiment_dataset, f)
print("The dataset has ", len(experiment_dataset), "states")

The dataset has  1000 states


Experiment code:

Going from an initial state, ensure that the model can follow through intermediary states, and output a final legal action on the second to last state.

Input:  An initial state and a list of actions that were taken.

Output: An action $x_t$. We test whether the action is legal

In [8]:
prompt = """
You are a blockworld planner. You are given an initial state and a list of actions. You must first implicitly generate the state S that we will reach after all the actions have been executed on the initial state. Finally, you must generate an action that is feasible in S. 
The possible commands in the game are (the blocks given are just examples, the model should be able to handle any block):
- (pick-up b1): Pick up block b1.
- (put-down b1): Put down block b1.
- (stack b1 b2): Stack block b1 on top of block b2.
- (unstack b1 b2): Unstack block b1 from block b2.
You must understand the state S, and generate an action that makes sense given the state. All actions are accomplished by the arm. The arm is either empty (arm-empty) or holding a block (holding b1).
Here are a few examples of states:
1. (on b2 b1) (arm-empty) (on b1 b3) (on-table b3) (clear b2)
2. (on b1 b3) (holding b2) (clear b1) (on-table b3)
3.(on-table b1) (on-table b2) (clear b2) (holding b3) (clear b1)
4. (on-table b1) (on-table b2) (clear b2) (on b3 b1) (arm-empty)
5.(on b2 b1) (arm-empty) (clear b3) (on-table b3) (on-table b1) (clear b2)
"""

class Response(BaseModel):
    implicit_state_S: str
    action: str
    explanation: str

def call_model(state, actions):
    response = client.beta.chat.completions.parse(
        model="gpt-4o",
        response_format= Response,
        messages=[
            {"role": "system", "content": prompt},
            {"role": "user", "content": f"state: {state}\n actions: {actions}"}
        ],
        temperature=0.0,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0
    )
    return response.choices[0].message.parsed.implicit_state_S, response.choices[0].message.parsed.action, response.choices[0].message.parsed.explanation

def validate_output(action):
    try:
        predicate, args = parse_expression(action)
        if (predicate == "pick-up" or predicate == "put-down") and len(args) == 1:
            return True
        elif (predicate == "stack" or predicate == "unstack") and len(args) == 2:
            return True
        return False
    except Exception as e:
        print("Validation Error:", e)
        return False


In [9]:
with open("experiment4_dataset.json", "r") as f:
    dataset = json.load(f)
df = pd.DataFrame(columns=["initial_state", "actions", "implicit_state", "action", "explanation", "is_valid"])
def check_and_save(initial_state, actions, implicit_state, action, explanation, df):
    is_valid = is_action_legal(implicit_state, action)
    df.loc[len(df)] = {"initial_state": initial_state, "actions": actions, "implicit_state": implicit_state, "action": action, "explanation": explanation, "is_valid": is_valid}
    
for datapoint in tqdm(dataset):
    tries = 0
    implicit_state, action, explanation = call_model(datapoint["initial_state"], datapoint["actions"])
    clean_output= validate_output(action)
    if clean_output:
        check_and_save(datapoint["initial_state"], datapoint["actions"], implicit_state, action, explanation, df)   
    else:
        tries += 1
        if tries > 3:
            continue
        else:
            implicit_state, action, explanation = call_model(datapoint["initial_state"], datapoint["actions"])
            clean_output= validate_output(action)
            if clean_output:
                check_and_save(datapoint["initial_state"], datapoint["actions"], implicit_state, action, explanation, df)
df.to_csv("experiment4_results.csv", index=False)

100%|██████████| 1000/1000 [1:15:10<00:00,  4.51s/it]   


In [10]:
df = pd.read_csv("experiment4_results.csv")
df["is_valid"] = df["is_valid"].apply(lambda x: 1 if x else 0)
df.to_csv("experiment4_results_binary.csv", index=False)


In [12]:
accuracy = df["is_valid"].sum() / len(df)
print(f"Accuracy: {accuracy}")

Accuracy: 0.96
