In [2]:
%reload_ext autoreload
%autoreload 2

import argparse
import re
from pathlib import Path
import pandas as pd
import os
import json
from build_stimulus_table import build_stimulus_table
from sample_experiment import generate_profile, generate_debug_profile, save_profile

In [3]:
# Paths: all data lives under $WORK; fall back to the current working directory
# when the environment variable is not set.
_work_env = os.environ.get("WORK")
WORK = Path(_work_env) if _work_env is not None else Path.cwd()
DATA = WORK.joinpath("color-concept-entanglement", "data")
OBJ_IMG_ROOT = DATA.joinpath("color_images", "gpt-4o")

## Create a stimulus table for each condition

### 1. Objects with color priors

In [4]:
# Build the stimulus table for objects shown in their prior (typical) color.
df_priors = build_stimulus_table(
    dataset_root=OBJ_IMG_ROOT / "image_priors",
    stimulus_type="correct_prior",
    data_root=DATA,
)
# Drop images whose target color is black
df_priors = df_priors[df_priors["target_color"] != "black"]

# Create the output directory if needed so the notebook also runs on a fresh checkout.
stimuli_dir = DATA / "prolific_stimuli"
stimuli_dir.mkdir(parents=True, exist_ok=True)
df_priors.to_csv(stimuli_dir / "stimulus_table_image_priors.csv", index=False)

display(df_priors)
display(df_priors.value_counts("target_color"))

Unnamed: 0,object,stimulus_type,target_color,variant_region,percent_colored,mode,variant_label,image_path
0,Band Aid,correct_prior,pink,BG,0,seq,BG 0% (seq),color_images/gpt-4o/image_priors/Band_Aid_1_e7...
1,Band Aid,correct_prior,pink,BG,1,seq,BG 1% (seq),color_images/gpt-4o/image_priors/Band_Aid_1_e7...
2,Band Aid,correct_prior,pink,BG,2,seq,BG 2% (seq),color_images/gpt-4o/image_priors/Band_Aid_1_e7...
3,Band Aid,correct_prior,pink,BG,3,seq,BG 3% (seq),color_images/gpt-4o/image_priors/Band_Aid_1_e7...
4,Band Aid,correct_prior,pink,BG,4,seq,BG 4% (seq),color_images/gpt-4o/image_priors/Band_Aid_1_e7...
...,...,...,...,...,...,...,...,...
8327,worm,correct_prior,brown,FG,60,seq,FG 60% (seq),color_images/gpt-4o/image_priors/worm_1_69c913...
8328,worm,correct_prior,brown,FG,70,seq,FG 70% (seq),color_images/gpt-4o/image_priors/worm_1_69c913...
8329,worm,correct_prior,brown,FG,80,seq,FG 80% (seq),color_images/gpt-4o/image_priors/worm_1_69c913...
8330,worm,correct_prior,brown,FG,90,seq,FG 90% (seq),color_images/gpt-4o/image_priors/worm_1_69c913...


target_color
brown     2974
green     1176
grey      1168
red       1128
blue       584
orange     546
yellow     462
pink       168
purple     126
Name: count, dtype: int64

### 2. Objects with counterfactual colors

In [5]:
# Build the stimulus table for objects shown in a counterfactual color.
df_cf = build_stimulus_table(
    dataset_root=OBJ_IMG_ROOT / "counterfact",
    stimulus_type="counterfact",
    data_root=DATA,
)
# Drop images whose target color is black
df_cf = df_cf[df_cf["target_color"] != "black"]

# Create the output directory if needed so the notebook also runs on a fresh checkout.
stimuli_dir = DATA / "prolific_stimuli"
stimuli_dir.mkdir(parents=True, exist_ok=True)
df_cf.to_csv(stimuli_dir / "stimulus_table_counterfact.csv", index=False)

display(df_cf)
display(df_cf.value_counts("target_color"))

Unnamed: 0,object,stimulus_type,target_color,variant_region,percent_colored,mode,variant_label,image_path
0,Band Aid,counterfact,purple,BG,0,seq,BG 0% (seq),color_images/gpt-4o/counterfact/Band_Aid_1_e73...
1,Band Aid,counterfact,purple,BG,1,seq,BG 1% (seq),color_images/gpt-4o/counterfact/Band_Aid_1_e73...
2,Band Aid,counterfact,purple,BG,2,seq,BG 2% (seq),color_images/gpt-4o/counterfact/Band_Aid_1_e73...
3,Band Aid,counterfact,purple,BG,3,seq,BG 3% (seq),color_images/gpt-4o/counterfact/Band_Aid_1_e73...
4,Band Aid,counterfact,purple,BG,4,seq,BG 4% (seq),color_images/gpt-4o/counterfact/Band_Aid_1_e73...
...,...,...,...,...,...,...,...,...
8985,worm,counterfact,purple,FG,60,seq,FG 60% (seq),color_images/gpt-4o/counterfact/worm_1_69c913c...
8986,worm,counterfact,purple,FG,70,seq,FG 70% (seq),color_images/gpt-4o/counterfact/worm_1_69c913c...
8987,worm,counterfact,purple,FG,80,seq,FG 80% (seq),color_images/gpt-4o/counterfact/worm_1_69c913c...
8988,worm,counterfact,purple,FG,90,seq,FG 90% (seq),color_images/gpt-4o/counterfact/worm_1_69c913c...


target_color
purple    1672
pink      1672
blue      1296
green     1214
yellow    1048
red        834
orange     798
brown      456
Name: count, dtype: int64

### 3. Shapes

In [6]:
# Build the stimulus table for colored geometric shapes.
df_shapes = build_stimulus_table(
    dataset_root=DATA / "shapes" / "shape_colored",
    stimulus_type="shape",
    data_root=DATA,
)
# Drop images whose target color is black
df_shapes = df_shapes[df_shapes["target_color"] != "black"]

# Create the output directory if needed so the notebook also runs on a fresh checkout.
stimuli_dir = DATA / "prolific_stimuli"
stimuli_dir.mkdir(parents=True, exist_ok=True)
df_shapes.to_csv(stimuli_dir / "stimulus_table_shapes.csv", index=False)

display(df_shapes)
display(df_shapes.value_counts("target_color"))

Unnamed: 0,object,stimulus_type,target_color,variant_region,percent_colored,mode,variant_label,image_path
42,circle,shape,blue,BG,0,seq,BG 0% (seq),shapes/shape_colored/circle_v0_blue/BG_000_seq...
43,circle,shape,blue,BG,1,seq,BG 1% (seq),shapes/shape_colored/circle_v0_blue/BG_001_seq...
44,circle,shape,blue,BG,2,seq,BG 2% (seq),shapes/shape_colored/circle_v0_blue/BG_002_seq...
45,circle,shape,blue,BG,3,seq,BG 3% (seq),shapes/shape_colored/circle_v0_blue/BG_003_seq...
46,circle,shape,blue,BG,4,seq,BG 4% (seq),shapes/shape_colored/circle_v0_blue/BG_004_seq...
...,...,...,...,...,...,...,...,...
10495,triangle,shape,yellow,FG,60,seq,FG 60% (seq),shapes/shape_colored/triangle_v4_yellow/FG_060...
10496,triangle,shape,yellow,FG,70,seq,FG 70% (seq),shapes/shape_colored/triangle_v4_yellow/FG_070...
10497,triangle,shape,yellow,FG,80,seq,FG 80% (seq),shapes/shape_colored/triangle_v4_yellow/FG_080...
10498,triangle,shape,yellow,FG,90,seq,FG 90% (seq),shapes/shape_colored/triangle_v4_yellow/FG_090...


target_color
blue      1050
brown     1050
green     1050
grey      1050
orange    1050
pink      1050
purple    1050
red       1050
yellow    1050
Name: count, dtype: int64

## Create survey profiles

In [32]:
# Number of distinct seeds. Each seed yields one profile per introspection
# position, so 2 * N_BASE_PROFILES profiles are generated in total.
N_BASE_PROFILES = 37
INTROSPECTION_POSITIONS = ("first", "last")

# Both variants of a base profile use the same seed (the base id). This is
# presumably so they share the same question sample and differ only in the
# introspection position; confirm against generate_profile.
profiles = [
    {
        "profile_id": f"{seed}_{position}",
        "base_id": seed,
        "introspection_position": position,
        "questions": generate_profile(
            df_priors=df_priors,
            df_cf=df_cf,
            df_shapes=df_shapes,
            seed=seed,
            introspection_position=position,
        ),
    }
    for seed in range(N_BASE_PROFILES)
    for position in INTROSPECTION_POSITIONS
]

In [33]:
# Write each generated profile to its own JSON file.
out_dir = DATA / "prolific_stimuli" / "profiles"
# parents=True also creates prolific_stimuli/ if it does not exist yet;
# without it, mkdir raises FileNotFoundError on a fresh checkout.
out_dir.mkdir(parents=True, exist_ok=True)

# NOTE: files from earlier runs (e.g. with a larger N_BASE_PROFILES) are not removed.
for p in profiles:
    out_path = out_dir / f"profile_{p['profile_id']}.json"
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(p, f, indent=2)

In [9]:
# Save a debug profile next to the regular profiles. The directory is derived
# here instead of reusing `out_dir` from another cell, so this cell does not
# depend on hidden kernel state or on cell execution order.
debug_dir = DATA / "prolific_stimuli" / "profiles"
debug_dir.mkdir(parents=True, exist_ok=True)
save_profile(
    generate_debug_profile(df_priors, df_shapes),
    debug_dir / "debug_profile.json",
)
# test server.py with http://127.0.0.1:5000/?PROLIFIC_PID=DEBUG