In [96]:
import os
from time import sleep
import json
import argparse
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains import LLMChain
from langchain_google_genai import ChatGoogleGenerativeAI
import base64

In [152]:
parser = argparse.ArgumentParser(description="test LLM planning abilities")
parser.add_argument(
        "--dim",
        type=int,
        default=1,
        help="dimension of the problem 1 or 2 (for 1D or 2D)",
)
parser.add_argument(
        "--prompt_type",
        type=str,
        default='text',
        help="type of prompt: text, image, metrical",
    )
parser.add_argument(
        "--color",
        type=int,
        default=0,
        help="Use color images (1) or not (0)",
    )
parser.add_argument(
        "--gap",
        type=int,
        default=0,
        help="Use images with gap (1) or not (0)",
    )
parser.add_argument(
        "--model_name",
        type=str,
        default='all',
        help="num",
    )
parser.add_argument(
        "--openai_key",
        type=str,
        default='',
        help="num",
    )
parser.add_argument(
        "--gemini_key",
        type=str,
        default='',
        help="num",
    )

args_string = '--dim=2 --prompt_type=img'
args_list = args_string.split(' ')
args = parser.parse_args(args_list)

In [150]:
if args.model_name=='all':
    model_name = [
        'gpt-3.5-turbo',
        'gpt-4o-mini',
        'gpt-4o-2024-08-06'
    ]
else:
    model_name = args.model_name
# model_name='gpt-3.5-turbo'
# model_name='gpt-4o-mini'
model_name='gpt-4o-2024-08-06'
# model_name='gemini-1.5-pro'

if not args.openai_key:
    apikey_filepath = '../.openai_key.txt'
    with open(apikey_filepath, 'r') as f:
        openai_key = f.read()
else:
    openai_key = args.key

if not args.gemini_key:
    gemini_key_filepath = '../.gemini_api_key.txt'
    with open(gemini_key_filepath, 'r') as f:
        gemini_key = f.read()
else:
    gemini_key = args.gemini_key

if 'gpt' in model_name:
    print('gpt model')
    client = ChatOpenAI(api_key=openai_key, model_name=model_name)
elif 'gemini' in model_name:
    print('gemini model')
    client = ChatGoogleGenerativeAI(google_api_key=gemini_key, model=model_name)

gpt model


In [151]:
# Load Data
if args.dim == 1:
    examples_data_dir = '../data/brick_1D_50'
elif args.dim == 2:
    examples_data_dir = '../data/brick_2D_50'

if args.prompt_type=='metrical':
    filename='data_metrical_descriptions.json'
else:
    filename='data.json'

with open(os.path.join(examples_data_dir, filename)) as f:
    data = json.load(f)

zeroshot_prompt = 'Lets think step by step, and provide the answer in the format of a sequence of bricks by a comma in the last sentence.'
zeroshot_prompt = 'Lets think step by step, and provide the answer in the format of a sequence of bricks by a comma in the last sentence.'


In [142]:
if args.prompt_type.lower() in ['text', 'txt']:
    res_list = []

    for i, item in enumerate(data):
        print('step: ', i+1)
        response = client(messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": zeroshot_prompt + '\n\n' + 'Question:' + '\n' + item['data'] + '\nAnswer:\n'}
        ],
        max_tokens=2048,
        temperature=0)
        res = response.content

        dict_res = {'pred': res, 'label': item['label']}
        res_list.append(dict_res)


In [143]:
if args.prompt_type.lower() in ['metrical']:
    res_list = []

    for i, item in enumerate(data):
        print('step: ', i+1)
        response = client(messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": zeroshot_prompt + '\n\n' + 'Question:' + '\n' + item['metrical_description'] + '\nAnswer:\n'}
        ],
        max_tokens=2048,
        temperature=0)
        res = response.content

        dict_res = {'pred': res, 'label': item['label']}
        res_list.append(dict_res)


step:  1
step:  2
step:  3
step:  4
step:  5
step:  6
step:  7
step:  8
step:  9
step:  10
step:  11
step:  12
step:  13
step:  14
step:  15
step:  16
step:  17
step:  18
step:  19
step:  20
step:  21
step:  22
step:  23
step:  24
step:  25
step:  26
step:  27
step:  28
step:  29
step:  30
step:  31
step:  32
step:  33
step:  34
step:  35
step:  36
step:  37
step:  38
step:  39
step:  40
step:  41
step:  42
step:  43
step:  44
step:  45
step:  46
step:  47
step:  48
step:  49
step:  50


In [144]:
if args.prompt_type.lower() in ['image','img']:

    def encode_image(image_path):
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')
    
    image_subdir = 'images'
    if args.color:
        color='color'
    else:
        color='bw'

    if args.gap:
        gap='_gap'
    else:
        gap=''

    img_prompt = 'The image shows a set of bricks that can be placed on top of each other. Now we have to get a specific brick. The bricks must now be grabbed from top to bottom, and if the lower brick is to be grabbed, the upper brick must be removed first. How to get brick {t}?'
    res_list = []
    i = 0

    for i, item in enumerate(data):
        print('step: ', i+1)
        image_path = os.path.join(examples_data_dir, image_subdir, f'img_{i}_{color}{gap}.png')
        image_base64 = encode_image(image_path)

        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", zeroshot_prompt + img_prompt.format(t=item['target'])),
                (
                    "user",
                    [
                        {
                            "type": "image_url",
                            "image_url": {"url": "data:image/jpeg;base64,{image_data}"},
                        }
                    ],
                ),
            ]
        )
        chain = prompt | client
        res = chain.invoke({'image_data':image_base64}).content
        dict_res = {'pred': res, 'label': item['label']}
        res_list.append(dict_res)

In [145]:
res_list

[{'pred': 'To get brick R, we need to follow this sequence: B, K, F, O, J, R.',
  'label': 'JR'},
 {'pred': 'To get brick C, we need to remove bricks U, I, N, K, E, R, W, F, T, Z, O, and D in that order. U, I, N, K, E, R, W, F, T, Z, O, D, C',
  'label': 'FWRC'},
 {'pred': 'To get brick Z, we need to follow this sequence: Y, N, K, C, Q, X, J, G, F, T, R, Z.',
  'label': 'RFJCZ'},
 {'pred': 'To get brick F, we need to remove bricks R, D, X, H, G, and I in that order. So, the sequence of bricks to be removed is R, D, X, H, G, I.',
  'label': 'IGF'},
 {'pred': 'To get brick L, we need to remove bricks Q, C, Z, U, S, G, N, E, M, Y, V, X, W, F, T, A, I, and O in that order. So, the sequence of bricks to be removed is Q, C, Z, U, S, G, N, E, M, Y, V, X, W, F, T, A, I, O.',
  'label': 'SUZCQL'},
 {'pred': 'To get brick Z, we need to follow this sequence: Brick Y, Brick V, Brick T, Brick B, Brick C, Brick U, Brick Q, Brick N, Brick H, Brick R, Brick M, Brick S, Brick J, Brick W, Brick O, Brick

In [146]:
if args.prompt_type.lower() in ['both','all','text+img']:

    def encode_image(image_path):
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')
    

    image_subdir = 'images'
    if args.color:
        color='color'
    else:
        color='bw'
    
    if args.gap:
        gap='_gap'
    else:
        gap=''

    txt_img_prompt_suffix = ' The described brick layout is visualized in the attached image. You can use both image and textual description to solve the task.'
    res_list = []
    i = 0

    for i, item in enumerate(data):
        print('step: ', i+1)
        image_path = os.path.join(examples_data_dir, image_subdir, f'img_{i}_{color}{gap}.png')
        image_base64 = encode_image(image_path)

        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", zeroshot_prompt),
                (
                    "user",
                    [
                        {
                            "type": "text",
                            "text": item['data'] + txt_img_prompt_suffix
                        },
                        {
                            "type": "image_url",
                            "image_url": {"url": "data:image/jpeg;base64,{image_data}"},
                        }
                    ],
                ),
            ]
        )
        chain = prompt | client
        res = chain.invoke({'image_data':image_base64}).content
        dict_res = {'pred': res, 'label': item['label']}
        res_list.append(dict_res)

In [147]:
res_list

[{'pred': 'To get brick R, we need to follow this sequence: B, K, F, O, J, R.',
  'label': 'JR'},
 {'pred': 'To get brick C, we need to remove bricks U, I, N, K, E, R, W, F, T, Z, O, and D in that order. U, I, N, K, E, R, W, F, T, Z, O, D, C',
  'label': 'FWRC'},
 {'pred': 'To get brick Z, we need to follow this sequence: Y, N, K, C, Q, X, J, G, F, T, R, Z.',
  'label': 'RFJCZ'},
 {'pred': 'To get brick F, we need to remove bricks R, D, X, H, G, and I in that order. So, the sequence of bricks to be removed is R, D, X, H, G, I.',
  'label': 'IGF'},
 {'pred': 'To get brick L, we need to remove bricks Q, C, Z, U, S, G, N, E, M, Y, V, X, W, F, T, A, I, and O in that order. So, the sequence of bricks to be removed is Q, C, Z, U, S, G, N, E, M, Y, V, X, W, F, T, A, I, O.',
  'label': 'SUZCQL'},
 {'pred': 'To get brick Z, we need to follow this sequence: Brick Y, Brick V, Brick T, Brick B, Brick C, Brick U, Brick Q, Brick N, Brick H, Brick R, Brick M, Brick S, Brick J, Brick W, Brick O, Brick

In [148]:
if args.prompt_type.lower() in ['img','image','both','all','text+img']:
    colorname = '_color' if args.color else '_bw'
else:
    colorname = ''

if args.gap:
    gap = '_gap'
else:
    gap = ''

output_path = os.path.join('../data/results/', f'{args.dim}D_{args.prompt_type}{colorname}{gap}_{model_name}.json')
os.makedirs(os.path.dirname(output_path), exist_ok=True)

with open(output_path, 'w') as outfile:
    json.dump(res_list, outfile)

print(f'Results saved to {output_path}')

Results saved to ../data/results/2D_metrical_gpt-3.5-turbo.json


## Ablation studies of horizontal and vertical gaps

In [157]:
name = 'img_0_0.1.png'
i = name.split('_')[1]
gap = name.split('_')[-1].strip('.png')
print(i,gap)

0 0.1


In [159]:
data[int(i)]

{'brick_layout': ['B,K', 'F,O', 'R,J'],
 'brick_colors': ['white,blue', 'yellow,white', 'yellow,blue'],
 'image': 'data/brick_2D_50/images/img_0_color.png',
 'target': 'R',
 'data': 'There is a set of bricks. The brick F is to the right of the brick B. The brick O is on top of the brick F . For the brick B, the color is white. The brick K is on top of the brick B . The brick R is to the right of the brick F. The brick J is on top of the brick R . Now we have to get a specific brick. The bricks must now be grabbed from top to bottom, and if the lower brick is to be grabbed, the upper brick must be removed first. How to get brick R?',
 'label': 'JR',
 'metrical_description': 'There is a set of bricks. Brick B is at position (0, 0). Brick K is at position (0, 1). Brick F is at position (1, 0). Brick O is at position (1, 1). Brick R is at position (2, 0). Brick J is at position (2, 1). Now we have to get a specific brick. The bricks must now be grabbed from top to bottom, and if the lower 

In [171]:
def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')

ablation_images_subdir = 'images/ablation'

img_prompt = 'The image shows a set of bricks that can be placed on top of each other. Now we have to get a specific brick. The bricks must now be grabbed from top to bottom, and if the lower brick is to be grabbed, the upper brick must be removed first. How to get brick {t}?'
res_list = []
counter = 0

ablation_subdirs = ['both']
# , 'horizontal', 'both']

# iterate through all images in the ablation subdirectories and generate predictions
for dir in ablation_subdirs:
    print(dir)
    ablation_data_dir = os.path.join(examples_data_dir, ablation_images_subdir, dir)
    for name in os.listdir(ablation_data_dir):
        print(counter)
        counter += 1
        print(name)
        image_path = os.path.join(ablation_data_dir, name)
        image_base64 = encode_image(image_path)

        i = name.split('_')[1]
        gap = name.split('_')[-1].strip('.png')
        print(i, gap)

        item = data[int(i)]

        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", zeroshot_prompt + img_prompt.format(t=item['target'])),
                (
                    "user",
                    [
                        {
                            "type": "image_url",
                            "image_url": {"url": "data:image/jpeg;base64,{image_data}"},
                        }
                    ],
                ),
            ]
        )
        chain = prompt | client
        res = chain.invoke({'image_data':image_base64}).content
        dict_res = {'pred': res, 'label': item['label'], 'gap_type': dir, 'gap': gap}
        res_list.append(dict_res)

both
0
img_23_0.1.png
23 0.1
1
img_10_0.4.png
10 0.4
2
img_20_0.5.png
20 0.5
3
img_15_0.9.png
15 0.9
4
img_35_0.5.png
35 0.5
5
img_25_0.2.png
25 0.2
6
img_43_0.7.png
43 0.7
7
img_21_0.4.png
21 0.4
8
img_30_0.6.png
30 0.6
9
img_0_0.4.png
0 0.4
10
img_22_0.4.png
22 0.4
11
img_39_0.3.png
39 0.3
12
img_24_0.8.png
24 0.8
13
img_28_1.0.png
28 1.0
14
img_24_0.3.png
24 0.3
15
img_42_0.2.png
42 0.2
16
img_6_0.4.png
6 0.4
17
img_25_1.0.png
25 1.0
18
img_36_0.3.png
36 0.3
19
img_44_0.1.png
44 0.1
20
img_12_0.5.png
12 0.5
21
img_26_0.1.png
26 0.1
22
img_37_0.1.png
37 0.1
23
img_22_0.5.png
22 0.5
24
img_0_0.9.png
0 0.9
25
img_2_0.5.png
2 0.5
26
img_35_0.7.png
35 0.7
27
img_6_0.8.png
6 0.8
28
img_39_1.0.png
39 1.0
29
img_16_0.4.png
16 0.4
30
img_21_0.8.png
21 0.8
31
img_7_0.1.png
7 0.1
32
img_28_0.9.png
28 0.9
33
img_13_1.0.png
13 1.0
34
img_14_0.1.png
14 0.1
35
img_8_0.8.png
8 0.8
36
img_4_0.2.png
4 0.2
37
img_37_1.0.png
37 1.0
38
img_37_0.4.png
37 0.4
39
img_24_0.1.png
24 0.1
40
img_25_0.3.png
25 

In [177]:
def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')

ablation_images_subdir = 'images/ablation_2'

img_prompt = 'The image shows a set of bricks that can be placed on top of each other. Now we have to get a specific brick. The bricks must now be grabbed from top to bottom, and if the lower brick is to be grabbed, the upper brick must be removed first. How to get brick {t}?'
res_list = []
counter = 0

# ablation_subdirs = ['vs_plus_03']
# ablation_subdirs = ['vs_times_2']
ablation_subdirs = ['vs_times_4']
# , 'horizontal', 'both']

# iterate through all images in the ablation subdirectories and generate predictions
for dir in ablation_subdirs:
    print(dir)
    ablation_data_dir = os.path.join(examples_data_dir, ablation_images_subdir, dir)
    for name in os.listdir(ablation_data_dir):
        print(counter)
        counter += 1
        print(name)
        image_path = os.path.join(ablation_data_dir, name)
        image_base64 = encode_image(image_path)

        i = name.split('_')[1]
        gap = name.split('_')[-1].strip('.png')
        print(i, gap)

        item = data[int(i)]

        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", zeroshot_prompt + img_prompt.format(t=item['target'])),
                (
                    "user",
                    [
                        {
                            "type": "image_url",
                            "image_url": {"url": "data:image/jpeg;base64,{image_data}"},
                        }
                    ],
                ),
            ]
        )
        chain = prompt | client
        res = chain.invoke({'image_data':image_base64}).content
        dict_res = {'pred': res, 'label': item['label'], 'gap_type': dir, 'gap': gap}
        res_list.append(dict_res)

vs_times_2
0
img_23_0.1.png
23 0.1
1
img_10_0.4.png
10 0.4
2
img_20_0.5.png
20 0.5
3
img_15_0.9.png
15 0.9
4
img_35_0.5.png
35 0.5
5
img_25_0.2.png
25 0.2
6
img_43_0.7.png
43 0.7
7
img_21_0.4.png
21 0.4
8
img_30_0.6.png
30 0.6
9
img_0_0.4.png
0 0.4
10
img_22_0.4.png
22 0.4
11
img_39_0.3.png
39 0.3
12
img_24_0.8.png
24 0.8
13
img_39_0.0.png
39 0.0
14
img_24_0.3.png
24 0.3
15
img_42_0.2.png
42 0.2
16
img_6_0.4.png
6 0.4
17
img_36_0.3.png
36 0.3
18
img_17_0.0.png
17 0.0
19
img_44_0.1.png
44 0.1
20
img_12_0.5.png
12 0.5
21
img_34_0.0.png
34 0.0
22
img_26_0.1.png
26 0.1
23
img_37_0.1.png
37 0.1
24
img_22_0.5.png
22 0.5
25
img_0_0.9.png
0 0.9
26
img_2_0.5.png
2 0.5
27
img_35_0.7.png
35 0.7
28
img_6_0.8.png
6 0.8
29
img_16_0.4.png
16 0.4
30
img_21_0.8.png
21 0.8
31
img_29_0.0.png
29 0.0
32
img_7_0.1.png
7 0.1
33
img_28_0.9.png
28 0.9
34
img_14_0.1.png
14 0.1
35
img_8_0.8.png
8 0.8
36
img_4_0.2.png
4 0.2
37
img_37_0.4.png
37 0.4
38
img_24_0.1.png
24 0.1
39
img_25_0.3.png
25 0.3
40
img_9_0.1.pn

In [178]:
output_path = os.path.join('../data/results/ablation', f'{dir}.json')
os.makedirs(os.path.dirname(output_path), exist_ok=True)

with open(output_path, 'w') as outfile:
    json.dump(res_list, outfile)

print(f'Results saved to {output_path}')

Results saved to ../data/results/ablation/vs_times_2.json
