In [5]:
import os
from time import sleep
import json
import argparse
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains import LLMChain
from langchain_google_genai import ChatGoogleGenerativeAI
import base64

In [55]:
# Experiment configuration, expressed as command-line style options. The
# parser is fed an explicit argument list (not sys.argv), so the notebook is
# reconfigured by editing `args_string` below.
parser = argparse.ArgumentParser(description="test LLM planning abilities")
parser.add_argument(
    "--dim",
    type=int,
    default=1,
    help="dimension of the problem 1 or 2 (for 1D or 2D)",
)
parser.add_argument(
    "--prompt_type",
    type=str,
    default='text',
    help="type of prompt: text, image, or both (text plus image)",
)
parser.add_argument(
    "--color",
    type=int,
    default=0,
    help="Use color images (1) or not (0)",
)
parser.add_argument(
    "--gap",
    type=int,
    default=0,
    help="Use images with gap (1) or not (0)",
)
parser.add_argument(
    "--model_name",
    type=str,
    default='all',
    help="name of the chat model to query (a gpt or gemini model), or all",
)
parser.add_argument(
    "--openai_key",
    type=str,
    default='',
    help="OpenAI API key; when empty the key is read from a local key file",
)
parser.add_argument(
    "--gemini_key",
    type=str,
    default='',
    help="Gemini API key; when empty the key is read from a local key file",
)

# split() without an argument tolerates repeated or trailing whitespace, which
# split(' ') would turn into empty arguments that argparse rejects.
args_string = '--dim=2 --prompt_type=both --color=0 --gap=1'
args_list = args_string.split()
args = parser.parse_args(args_list)

In [56]:
# Resolve the single model evaluated by this run. The rest of the notebook
# drives exactly one client and uses `model_name` as a string (substring tests
# here, output file name when saving), so the 'all' placeholder falls back to
# one default model rather than a list of names.
# Other model names previously used here: 'gpt-3.5-turbo', 'gpt-4o-mini',
# 'gemini-1.5-pro'.
default_model_name = 'gpt-4o-2024-08-06'
if args.model_name == 'all':
    model_name = default_model_name
else:
    model_name = args.model_name


def _resolve_api_key(cli_value, key_filepath):
    """Return the API key from the parsed arguments, else from a key file.

    Args:
        cli_value: Key supplied through the argument parser; an empty string
            means no key was given.
        key_filepath: Path of a text file containing the key. It is opened
            only when ``cli_value`` is empty.

    Returns:
        The key with surrounding whitespace removed, so that a trailing
        newline in the key file does not become part of the credential.

    Raises:
        OSError: If the key file has to be read and cannot be opened.
    """
    if cli_value:
        return cli_value.strip()
    with open(key_filepath, 'r') as f:
        return f.read().strip()


# Only the selected provider's key is loaded; the other stays None, so a
# missing key file for the unused provider does not stop the run.
openai_key = None
gemini_key = None

if 'gpt' in model_name:
    print('gpt model')
    openai_key = _resolve_api_key(args.openai_key, '../.openai_key.txt')
    client = ChatOpenAI(api_key=openai_key, model_name=model_name)
elif 'gemini' in model_name:
    print('gemini model')
    gemini_key = _resolve_api_key(args.gemini_key, '../.gemini_api_key.txt')
    client = ChatGoogleGenerativeAI(google_api_key=gemini_key, model=model_name)
else:
    # Fail here instead of leaving `client` undefined for the cells below.
    raise ValueError(
        f"Unsupported model name {model_name!r}: expected a name containing 'gpt' or 'gemini'"
    )

gpt model


In [57]:
# Load Data
# Pick the example directory that matches the requested problem dimension.
if args.dim == 1:
    examples_data_dir = '../data/brick_1D_50'
elif args.dim == 2:
    examples_data_dir = '../data/brick_2D_50'
else:
    # Without this branch an unsupported value would only surface later as a
    # NameError on `examples_data_dir`.
    raise ValueError(f'--dim must be 1 or 2, got {args.dim}')

# `data` is iterated item by item below; each item is indexed with the keys
# 'data', 'label' and 'target'.
with open(os.path.join(examples_data_dir, 'data.json'), encoding='utf-8') as f:
    data = json.load(f)

# Instruction placed in front of every prompt variant built in the cells below.
zeroshot_prompt = 'Lets think step by step, and provide the answer in the format of a sequence of bricks by a comma in the last sentence.'


In [58]:
if args.prompt_type.lower() in ['text', 'txt']:
    # Text-only evaluation: every example is sent as a plain question and the
    # raw model answer is stored next to the expected label.
    res_list = []

    for i, item in enumerate(data):
        print('step: ', i+1)
        question = zeroshot_prompt + '\n\n' + 'Question:' + '\n' + item['data'] + '\nAnswer:\n'
        # `invoke` is the supported way to call a LangChain chat model; calling
        # the model object directly is deprecated. Messages are passed as
        # (role, content) tuples, the same form the other cells hand to
        # ChatPromptTemplate.
        # NOTE(review): max_tokens / temperature are forwarded as per-call
        # keyword arguments as before - confirm the Gemini client honours them.
        response = client.invoke(
            [
                ("system", "You are a helpful assistant."),
                ("user", question),
            ],
            max_tokens=2048,
            temperature=0,
        )
        res = response.content

        dict_res = {'pred': res, 'label': item['label']}
        res_list.append(dict_res)


In [59]:
# Display the 'target' field of the first loaded example as a quick sanity check.
data[0]['target']

'R'

In [60]:
if args.prompt_type.lower() in ['image','img']:
    # Image-only evaluation: the model receives the task instruction as the
    # system message and the rendered brick layout as an image.

    def encode_image(image_path):
        """Return the contents of the file at ``image_path`` as a base64 text string."""
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')

    # Pieces of the image file name: img_<index>_<color><gap>.png
    image_subdir = 'images'
    color = 'color' if args.color else 'bw'
    gap = '_gap' if args.gap else ''

    img_prompt = 'The image shows a set of bricks that can be placed on top of each other. Now we have to get a specific brick. The bricks must now be grabbed from top to bottom, and if the lower brick is to be grabbed, the upper brick must be removed first. How to get brick {t}?'

    # The template and chain do not depend on the example, so they are built
    # once. The instruction is passed as a template variable instead of being
    # spliced into the template string, so braces in it are never parsed as
    # template placeholders. The images are .png files, hence image/png.
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "{instruction}"),
            (
                "user",
                [
                    {
                        "type": "image_url",
                        "image_url": {"url": "data:image/png;base64,{image_data}"},
                    }
                ],
            ),
        ]
    )
    chain = prompt | client

    res_list = []
    for i, item in enumerate(data):
        print('step: ', i+1)
        image_path = os.path.join(examples_data_dir, image_subdir, f'img_{i}_{color}{gap}.png')
        image_base64 = encode_image(image_path)

        # NOTE(review): the two prompt parts are joined without a separating
        # space; kept unchanged so results stay comparable with earlier runs.
        instruction = zeroshot_prompt + img_prompt.format(t=item['target'])
        res = chain.invoke({'instruction': instruction, 'image_data': image_base64}).content
        dict_res = {'pred': res, 'label': item['label']}
        res_list.append(dict_res)

In [61]:
# Display the collected results: one {'pred', 'label'} dict per processed example.
res_list

[{'pred': 'To get brick R, first remove the brick J. So the sequence is: J, R.',
  'label': 'JR'},
 {'pred': 'To get brick C, remove the bricks in the following order: F, W, R.',
  'label': 'FWRC'},
 {'pred': 'To get brick Z, you need to remove the bricks on top of it in the following order: R, F, J, C.',
  'label': 'RFJCZ'},
 {'pred': 'To get brick F, remove the bricks in this order: I, G.',
  'label': 'IGF'},
 {'pred': 'To get brick L, you need to remove the bricks above it in the same stack. The sequence is: S, U, Z, C, Q.',
  'label': 'SUZCQL'},
 {'pred': 'To get brick Z, remove the bricks above it in the following order: S, M, R, H, N, Q.',
  'label': 'SMRHNQZ'},
 {'pred': 'To get brick T, remove the bricks in this order: L, W, G.',
  'label': 'LWGT'},
 {'pred': 'To get brick F, you can grab it directly as it is not blocked by any other brick.',
  'label': 'F'},
 {'pred': 'To get brick S, remove the bricks in this order: Q, E, H, F, P, V.',
  'label': 'QEHFPVS'},
 {'pred': 'To get

In [63]:
if args.prompt_type.lower() in ['both','all','text+img']:
    # Text+image evaluation: the user message carries the textual task
    # description together with the rendered brick layout.

    def encode_image(image_path):
        """Return the contents of the file at ``image_path`` as a base64 text string."""
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')

    # Pieces of the image file name: img_<index>_<color><gap>.png
    image_subdir = 'images'
    color = 'color' if args.color else 'bw'
    gap = '_gap' if args.gap else ''

    txt_img_prompt_suffix = ' The described brick layout is visualized in the attached image. You can use both image and textual description to solve the task.'

    # The template and chain do not depend on the example, so they are built
    # once. Instruction and question are passed as template variables instead
    # of being spliced into the template, so braces in the example text are
    # never parsed as template placeholders. The images are .png files, hence
    # image/png.
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "{instruction}"),
            (
                "user",
                [
                    {
                        "type": "text",
                        "text": "{question}",
                    },
                    {
                        "type": "image_url",
                        "image_url": {"url": "data:image/png;base64,{image_data}"},
                    }
                ],
            ),
        ]
    )
    chain = prompt | client

    res_list = []
    for i, item in enumerate(data):
        print('step: ', i+1)
        image_path = os.path.join(examples_data_dir, image_subdir, f'img_{i}_{color}{gap}.png')
        image_base64 = encode_image(image_path)

        res = chain.invoke({
            'instruction': zeroshot_prompt,
            'question': item['data'] + txt_img_prompt_suffix,
            'image_data': image_base64,
        }).content
        dict_res = {'pred': res, 'label': item['label']}
        res_list.append(dict_res)

step:  1


In [64]:
# Display the collected results: one {'pred', 'label'} dict per processed example.
res_list

[{'pred': 'To get brick R, you need to follow these steps:\n\n1. Remove brick J (which is on top of brick R).\n2. Now you can grab brick R.\n\nSo the sequence is: J, R.',
  'label': 'JR'}]

In [54]:
# Persist the collected predictions as JSON. The file name encodes the run
# configuration: <dim>D_<prompt_type><color suffix><gap suffix>_<model>.json.
# The colour suffix is only added for prompt types that involve an image.
colorname = (
    ('_color' if args.color else '_bw')
    if args.prompt_type.lower() in ('img', 'image', 'both', 'all', 'text+img')
    else ''
)
gap = '_gap' if args.gap else ''

output_path = os.path.join(
    '../data/results/',
    f'{args.dim}D_{args.prompt_type}{colorname}{gap}_{model_name}.json',
)
# Create the target directory on first use.
os.makedirs(os.path.dirname(output_path), exist_ok=True)

with open(output_path, 'w') as outfile:
    outfile.write(json.dumps(res_list))

print(f'Results saved to {output_path}')

Results saved to ../data/results/2D_img_bw_gap_gpt-4o-2024-08-06.json
