# Experiment 2

The goal of this experiment is to test whether a model can recognize a legal action given a state. 

Input:  A state s and an action $x_t$. 50% of the actions are legal and 50% of them are not

Output: The model's judgment on whether the action is legal or not. 

In [21]:
!pip install openai
! pip install pandas



In [22]:
from openai import OpenAI
from pydantic import BaseModel
import json
import pandas as pd
from tqdm import tqdm

import os
from dotenv import load_dotenv

# Load environment variables from .env file
load_dotenv()

client = OpenAI(api_key=os.getenv('OPENAI_API_KEY'))

In [27]:
#utils
import re
import random

def parse_expression(expression):
    if expression[0] == "(":
        #remove the parentheses
        expression = expression[1:-1]
    parts = expression.split()
    predicate = parts[0]
    args = parts[1:] if len(parts) > 1 else []
    return predicate, args

def parse_state(state):
    pattern = r"\((.*?)\)"
    matches = re.findall(pattern, state)
    expressions = []
    for match in matches:
        predicate, args = parse_expression(match)
        expressions.append((predicate, args))
    return expressions



# Possible actions
# - (pick-up b1): Pick up block b1. --> For this, a block needs to be clear and the arm needs to be empty 
# - (put-down b1): Put down block b1. --> For this, a block needs to be in the arm 
# - (stack b1 b2): Stack block b1 on top of block b2. --> For this, a block needs to be in the arm and the arm needs to be empty
# - (unstack b1 b2): Unstack block b1 from block b2. --> For this, a block needs to be clear and the arm needs to be empty

def is_block_clear(state, block):
    clear = False
    for expression in state:
        predicate, args = expression
        if predicate == "clear" and block in args:
            # the block has been marked as clear, so return true
            return True 
        elif predicate == "on" and block == args[0]:
            clear = True
        elif predicate == "on" and block == args[1]:
            #Another block is currently on top of this block, immediately return false
            return False
    return clear

def is_arm_empty(state):
    for expression in state:
        predicate, args = expression
        if predicate == "arm-empty":
            return True
    return False

def is_block_on_other(state, block, other_block):
    for expression in state:
        predicate, args = expression
        if predicate == "on" and block == args[0] and other_block == args[1]:
            return True
    return False

def is_block_in_arm(state, block):
    for expression in state:
        predicate, args = expression
        if predicate == "holding" and block in args:
            return True
    return False

def is_action_legal(state, action):
    """Given a state and an action, return whether the action is legal or not."""
    
    state = parse_state(state)
    #Get the blocks in the action
    action_predicate, action_blocks = parse_expression(action)
    if action_predicate == "pick-up":
        block = action_blocks[0]
        #check that the arm is empty and the block is clear
        arm_empty = is_arm_empty(state)
        block_clear = is_block_clear(state, block)
        if arm_empty and block_clear:
            return True
        else:
            return False
    elif action_predicate == "put-down":
        block = action_blocks[0]
        #check that the arm is empty and the block is in the arm
        block_in_arm = is_block_in_arm(state, block)
        if block_in_arm:
            return True
        else:
            return False
    elif action_predicate == "stack":
        block = action_blocks[0]
        other_block = action_blocks[1]
        block_in_arm = is_block_in_arm(state, block)
        other_block_clear = is_block_clear(state, other_block)
        if block_in_arm and other_block_clear:
            return True
        else:
            return False 
    elif action_predicate == "unstack":
        block = action_blocks[0]
        other_block = action_blocks[1]
        block_clear = is_block_clear(state, block)
        block_on_other = is_block_on_other(state, block, other_block)
        arm_empty = is_arm_empty(state)
        if block_clear and block_on_other and arm_empty:
            return True
        else:
            return False

def generate_illegal_action(state):
    """Returns a legal action given a state"""
    parsed_state = parse_state(state)
    #Get the blocks in the state
    blocks = []
    for expression in parsed_state:
        predicate, args = expression
        blocks.extend(args)
    blocks = set(blocks)
    #Randomly select an action
    action = random.choice(["pick-up", "put-down", "stack", "unstack"])
    if action == "pick-up" or action == "put-down":
        block = random.choice(list(blocks))
        action = f"({action} {block})"
        if not is_action_legal(state, action):
            return state
        else:
            return generate_illegal_action(state)
    elif action == "stack" or action == "unstack":
        block = random.choice(list(blocks))
        other_block = random.choice(list(blocks - {block}))
        action = f"({action} {block} {other_block})"
        if not is_action_legal(state, action):
            return state
        else:
            return generate_illegal_action(state)


Dataset curration: We keepm 1000 data points (state, action). 500 are valid, 500 are not

In [28]:
import random
import json
with open("../blockworld_dataset.json", "r") as f:
    full_dataset = json.load(f)

experiment_dataset = []
N = 500 
#Get valid actions  
seen = set()
valid_data = []
while len(valid_data) < N:
    group = random.choice(full_dataset)
    if group["goal"] not in seen:
        for i in range(len(group["actions"])):
            action = group["actions"][i]
            state = group["states"][i]
            valid_data.append((state, action, "legal"))
        seen.add(group["goal"])

experiment_dataset.extend(valid_data)
#Get invalid actions
invalid_data = []
while len(invalid_data) < N:
    group = random.choice(full_dataset)
    if group["goal"] not in seen:
        for i in range(len(group["actions"])):
            action = generate_illegal_action(group["states"][i])
            invalid_data.append((group["states"][i], action, "illegal"))
        seen.add(group["goal"])
experiment_dataset.extend(invalid_data)
random.shuffle(experiment_dataset)
with open("experiment2_dataset.json", "w") as f:
    json.dump(experiment_dataset, f)
print("The dataset has ", len(experiment_dataset), "pairs of action and state")

The dataset has  1002 pairs of action and state


Experiment code:

The goal of this experiment is to test whether a model can recognize a legal action given a state. 

Input:  A state s and an action $x_t$. 50% of the actions are legal and 50% of them are not

Output: The model's judgment on whether the action is legal or not. 

In [29]:
prompt = """
You are a blockworld planner. You are given a state, and an action. You must determine if the action is feasible given the current state. If it is, you must return True. If it is not, you must return False, and explain your reasoning.
The possible commands in the game are:
- (pick-up b1): Pick up block b1.
- (put-down b1): Put down block b1.
- (stack b1 b2): Stack block b1 on top of block b2.
- (unstack b1 b2): Unstack block b1 from block b2.
Here are a few examples of states and actions:
state: (on b2 b1) (arm-empty) (on b1 b3) (on-table b3) (clear b2); action: (unstack b2 b1). Return True. Explanation: The action is legal because the arm is empty, the block b2 is on b1, and the block b2 is clear.
state: (on b1 b3) (holding b2) (clear b1) (on-table b3); action: (put-down b2). Return True. Explanation: The action is legal because the arm is holding b2.
state: (on-table b2) (arm-empty) (on b1 b3) (clear b1) (on-table b3) (clear b2); action: (unstack b1 b3). Return True. Explanation: The action is legal because the arm is empty, the block b1 is on b3, and the block b1 is clear.
state: (on-table b1) (on-table b2) (clear b2) (holding b3) (clear b1); action: (stack b1 b3). Return False. Explanation: The action is not legal because the arm is holding b3 and b1 is on the table, so cannot be stacking b1 on b3.
state: (on-table b1) (on-table b2) (clear b2) (on b3 b1) (arm-empty); action: (pick-up b1). Return False. Explanation: The action is not legal because b3 is on b1, so b1 cannot be picked up.
state: (on b2 b1) (arm-empty) (clear b3) (on-table b3) (on-table b1) (clear b2); action: (stack b2 b1). Return False. Explanation: The action is not legal because b2 is alreadyon b1, so b2 cannot be stacked on b1.
"""

class Response(BaseModel):
    result: bool
    explanation: str

def call_model(state, action):
    response = client.beta.chat.completions.parse(
        model="gpt-4o",
        response_format= Response,
        messages=[
            {"role": "system", "content": prompt},
            {"role": "user", "content": f"state: {state}; action: {action}"}
        ],
        temperature=0.0,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0
    )
    return response.choices[0].message.parsed.result, response.choices[0].message.parsed.explanation


Parse dataset to get pairs of states and actions

In [30]:
with open("experiment2_dataset.json", "r") as f:
    dataset = json.load(f)
df = pd.DataFrame(columns=["state", "action", "result", "explanation", "correct"])
for datapoint in tqdm(dataset):
    state, action, true = datapoint
    result, explanation = call_model(state, action)
    if true == "legal" and result:
        correct = 1
    elif true == "illegal" and not result:
        correct = 1
    else:
        correct = 0
    df.loc[len(df)] = {"state": state, "action": action, "result": result, "explanation": explanation, "correct": correct}

df.to_csv("experiment2_results.csv", index=False)


100%|██████████| 1002/1002 [35:08<00:00,  2.10s/it] 


In [32]:
df.to_csv("experiment2_results.csv", index=False)
accuracy = df["correct"].sum() / len(df)
print(f"Accuracy: {accuracy}")

Accuracy: 0.9570858283433133
