In [1]:
%reload_ext autoreload
%autoreload 2

import argparse
import re
from pathlib import Path
import pandas as pd
import os
import json
from build_stimulus_table import build_stimulus_table
from sample_experiment import generate_profile, generate_debug_profile, save_profile

In [2]:
# Paths: everything is resolved under $WORK, falling back to the current directory.
WORK = Path(os.environ["WORK"]) if "WORK" in os.environ else Path.cwd()
DATA = WORK.joinpath("color-concept-entanglement", "data")
OBJ_IMG_ROOT = DATA.joinpath("color_images", "gpt-4o")

## Create a stimulus table for each condition

In [3]:
# Percent-colored levels kept in every stimulus table: each 10% step from 0 to 100
# plus the extra 5% and 55% levels -> [0, 5, 10, 20, 30, 40, 50, 55, 60, 70, 80, 90, 100].
percentage_buckets = sorted({*range(0, 101, 10), 5, 55})

### 1. Objects with color priors

In [4]:
# Build the color-prior stimulus table, then drop images colored with black and
# keep only rows whose percent_colored is one of the survey buckets.
df_priors = (
    build_stimulus_table(
        dataset_root=OBJ_IMG_ROOT / "image_priors",
        stimulus_type="correct_prior",
        data_root=DATA,
    )
    .loc[lambda table: table["manipulation_color"] != "black"]
    .loc[lambda table: table["percent_colored"].isin(percentage_buckets)]
)

df_priors.to_csv(DATA / "prolific_stimuli" / "stimulus_table_image_priors.csv", index=False)
display(df_priors.head())
display(df_priors.value_counts("target_color"))

Unnamed: 0,object,stimulus_type,manipulation_color,target_color,variant_region,percent_colored,mode,variant_label,image_path
0,Band Aid,correct_prior,pink,white,BG,0,seq,BG 0% (seq),color_images/gpt-4o/image_priors/Band_Aid_1_e7...
5,Band Aid,correct_prior,pink,white,BG,5,seq,BG 5% (seq),color_images/gpt-4o/image_priors/Band_Aid_1_e7...
10,Band Aid,correct_prior,pink,white,BG,10,seq,BG 10% (seq),color_images/gpt-4o/image_priors/Band_Aid_1_e7...
11,Band Aid,correct_prior,pink,white,BG,20,seq,BG 20% (seq),color_images/gpt-4o/image_priors/Band_Aid_1_e7...
12,Band Aid,correct_prior,pink,white,BG,30,seq,BG 30% (seq),color_images/gpt-4o/image_priors/Band_Aid_1_e7...


target_color
white     2786
brown      852
grey       336
green      336
red        324
blue       168
orange     156
yellow     132
pink        48
purple      36
Name: count, dtype: int64

### 2. Objects with counterfactual colors

In [5]:
# Counterfactual-color stimulus table: same filtering as the priors table,
# expressed as a single combined row mask.
cf_raw = build_stimulus_table(
    dataset_root=OBJ_IMG_ROOT / "counterfact",
    stimulus_type="counterfact",
    data_root=DATA,
)
is_not_black = cf_raw["manipulation_color"] != "black"
is_in_bucket = cf_raw["percent_colored"].isin(percentage_buckets)
df_cf = cf_raw[is_not_black & is_in_bucket]

df_cf.to_csv(DATA / "prolific_stimuli" / "stimulus_table_counterfact.csv", index=False)
display(df_cf.head())
display(df_cf.value_counts("target_color"))

Unnamed: 0,object,stimulus_type,manipulation_color,target_color,variant_region,percent_colored,mode,variant_label,image_path
0,Band Aid,counterfact,purple,white,BG,0,seq,BG 0% (seq),color_images/gpt-4o/counterfact/Band_Aid_1_e73...
5,Band Aid,counterfact,purple,white,BG,5,seq,BG 5% (seq),color_images/gpt-4o/counterfact/Band_Aid_1_e73...
10,Band Aid,counterfact,purple,white,BG,10,seq,BG 10% (seq),color_images/gpt-4o/counterfact/Band_Aid_1_e73...
11,Band Aid,counterfact,purple,white,BG,20,seq,BG 20% (seq),color_images/gpt-4o/counterfact/Band_Aid_1_e73...
12,Band Aid,counterfact,purple,white,BG,30,seq,BG 30% (seq),color_images/gpt-4o/counterfact/Band_Aid_1_e73...


target_color
white     3010
purple     480
pink       480
blue       384
green      348
yellow     312
orange     252
red        240
brown       84
Name: count, dtype: int64

### 3. Shapes

In [6]:
# Colored-shape stimulus table; black manipulations and off-bucket percentages are dropped.
df_shapes = (
    build_stimulus_table(
        dataset_root=DATA / "shapes" / "shape_colored",
        stimulus_type="shape",
        data_root=DATA,
    )
    .loc[lambda table: table["manipulation_color"] != "black"]
    .loc[lambda table: table["percent_colored"].isin(percentage_buckets)]
)

df_shapes.to_csv(DATA / "prolific_stimuli" / "stimulus_table_shapes.csv", index=False)
display(df_shapes.head(), df_shapes.value_counts("target_color"))

Unnamed: 0,object,stimulus_type,manipulation_color,target_color,variant_region,percent_colored,mode,variant_label,image_path
42,circle,shape,blue,white,BG,0,seq,BG 0% (seq),shapes/shape_colored/circle_v0_blue/BG_000_seq...
47,circle,shape,blue,white,BG,5,seq,BG 5% (seq),shapes/shape_colored/circle_v0_blue/BG_005_seq...
52,circle,shape,blue,white,BG,10,seq,BG 10% (seq),shapes/shape_colored/circle_v0_blue/BG_010_seq...
53,circle,shape,blue,white,BG,20,seq,BG 20% (seq),shapes/shape_colored/circle_v0_blue/BG_020_seq...
54,circle,shape,blue,white,BG,30,seq,BG 30% (seq),shapes/shape_colored/circle_v0_blue/BG_030_seq...


target_color
white     3150
blue       300
brown      300
green      300
orange     300
grey       300
pink       300
purple     300
red        300
yellow     300
Name: count, dtype: int64

## Check datasets

In [7]:
# Image and object counts for all datasets.
# For shapes, an "object" is one colored shape variant (e.g. circle_v0_blue), i.e.
# shape x augmentation variant x color. It is counted from the image directories instead
# of the previously hard-coded 5 variants x 9 colors, so it stays correct if the set changes.
# NOTE(review): assumes each shape variant's images sit directly in their own directory,
# as in the displayed paths (shapes/shape_colored/circle_v0_blue/...) — confirm.
n_shape_variants = df_shapes["image_path"].map(lambda p: Path(p).parent.name).nunique()

print("Object and image counts:")
print(f"Image priors: {df_priors['object'].nunique()} objects, {df_priors.shape[0]} images")
print(f"Counterfactuals: {df_cf['object'].nunique()} objects, {df_cf.shape[0]} images")
print(f"Shapes: {n_shape_variants} objects, {df_shapes.shape[0]} images")

Object and image counts:
Image priors: 199 objects, 5174 images
Counterfactuals: 215 objects, 5590 images
Shapes: 225 objects, 5850 images


In [8]:
# Check that every percentage bucket is represented for each object, in all three datasets.
def missing_by_object(df, buckets=None):
    """Return, per object, the percentage buckets that have no row in `df`.

    Args:
        df: Stimulus table with at least `object` and `percent_colored` columns.
        buckets: Expected percent-colored levels. When omitted, the notebook-level
            `percentage_buckets` list is used.

    Returns:
        Series indexed by object name whose values are sorted lists of missing
        buckets (an empty list means the object covers every bucket).

    Note: the check is per object only; it does not verify that every
    (object, color, region) combination covers all buckets.
    """
    # Build the expected set once rather than once per group.
    expected = set(percentage_buckets if buckets is None else buckets)
    return df.groupby("object")["percent_colored"].apply(lambda x: sorted(expected - set(x)))

# Report missing buckets for each dataset; empty output means every object is complete.
missing_buckets = {}
for label, stim_df in [
    ("image priors", df_priors),
    ("counterfactuals", df_cf),
    ("shapes", df_shapes),
]:
    missing = missing_by_object(stim_df)
    missing = missing[missing.map(len) > 0]  # keep only objects with at least one gap
    missing_buckets[label] = missing

    print(f"\nMissing percentage buckets in {label} by object:")
    print(missing)
    print(f"\nObjects with missing buckets ({label}):")
    print(missing.index.tolist())

# Per-dataset results as individual variables.
missing_priors = missing_buckets["image priors"]
missing_cf = missing_buckets["counterfactuals"]
missing_shapes = missing_buckets["shapes"]

Missing percentage buckets in image priors by object:
Series([], Name: percent_colored, dtype: object)

Objects with missing buckets (image priors):
[]

Missing percentage buckets in counterfactuals by object:
Series([], Name: percent_colored, dtype: object)

Objects with missing buckets (counterfactuals):
[]

Missing percentage buckets in shapes by object:
Series([], Name: percent_colored, dtype: object)

Objects with missing buckets (shapes):
[]


## Create survey profiles

In [9]:
# Images per percentage bucket in the priors table; a balanced table shows the same count everywhere.
prior_bucket_counts = df_priors["percent_colored"].value_counts()
prior_bucket_counts.sort_index()

percent_colored
0      398
5      398
10     398
20     398
30     398
40     398
50     398
55     398
60     398
70     398
80     398
90     398
100    398
Name: count, dtype: int64

In [10]:
# 37 base seeds x 2 introspection positions = 74 survey profiles.
N_BASE_PROFILES = 37

# Each base seed produces two profiles, one per introspection position,
# both generated with the same seed.
profiles = [
    {
        "profile_id": f"{seed}_{position}",
        "base_id": seed,
        "introspection_position": position,
        "questions": generate_profile(
            df_priors=df_priors,
            df_cf=df_cf,
            df_shapes=df_shapes,
            seed=seed,
            introspection_position=position,
        ),
    }
    for seed in range(N_BASE_PROFILES)
    for position in ("first", "last")
]

In [11]:
# Write each generated profile to its own JSON file.
out_dir = DATA / "prolific_stimuli" / "profiles"
# parents=True so the cell also works when prolific_stimuli/ does not exist yet.
out_dir.mkdir(parents=True, exist_ok=True)

# Remove profile files left over from earlier runs (e.g. with a larger N_BASE_PROFILES).
# The stimulus-recovery cell reads every JSON in this directory, so stale files would leak in.
# The debug profile (debug_profile.json) does not match this pattern and is kept.
current_names = {f"profile_{p['profile_id']}.json" for p in profiles}
for stale_path in out_dir.glob("profile_*.json"):
    if stale_path.name not in current_names:
        stale_path.unlink()

for p in profiles:
    out_path = out_dir / f"profile_{p['profile_id']}.json"
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(p, f, indent=2)

In [12]:
# Debug profile built from the priors and shapes tables, for manually testing the survey server.
# test server.py with http://127.0.0.1:5000/?PROLIFIC_PID=DEBUG
debug_profile = generate_debug_profile(df_priors, df_shapes)
save_profile(debug_profile, out_dir / "debug_profile.json")

In [13]:
# Recover the exact stimulus subsets used in the profile JSONs on disk.
# (json, Path and pandas are already imported in the first cell.)

PROFILE_DIR = DATA / "prolific_stimuli" / "profiles"

# Fields copied from each stimulus question; this also fixes the CSV column order.
STIMULUS_FIELDS = [
    "image_path",
    "object",
    "stimulus_type",
    "manipulation_color",
    "target_color",
    "variant_region",
    "percent_colored",
    "variant_label",
    "mode",
]

# Collected question rows, keyed by the `stimulus_type` value found in the profiles.
rows_by_type = {"correct_prior": [], "counterfact": [], "shape": []}

# sorted() makes row order (and hence the written CSVs) independent of filesystem order.
for profile_path in sorted(PROFILE_DIR.glob("*.json")):
    # Ignore the debug profile.
    if profile_path.name.lower().startswith("debug"):
        continue

    with open(profile_path, "r", encoding="utf-8") as f:
        profile = json.load(f)

    for q in profile.get("questions", []):
        # Only stimulus questions have an image_path and a known stimulus_type;
        # anything else (including a missing type) is skipped.
        if q.get("image_path") is None or q.get("stimulus_type") not in rows_by_type:
            continue
        rows_by_type[q["stimulus_type"]].append({field: q.get(field) for field in STIMULUS_FIELDS})

# One row per unique image. Passing explicit columns keeps drop_duplicates working
# (instead of raising KeyError) when a stimulus type has no rows at all.
prolific_tables = {
    stimulus_type: pd.DataFrame(rows, columns=STIMULUS_FIELDS).drop_duplicates(subset="image_path")
    for stimulus_type, rows in rows_by_type.items()
}
prior_df_prolific = prolific_tables["correct_prior"]
counterfact_df_prolific = prolific_tables["counterfact"]
shape_df_prolific = prolific_tables["shape"]

# These are counts of unique images, not of unique objects.
print("Recovered stimulus counts (unique images):")
print(f"  Correct-prior images: {len(prior_df_prolific)}")
print(f"  Counterfactual images: {len(counterfact_df_prolific)}")
print(f"  Shape images: {len(shape_df_prolific)}")

out_root = DATA / "prolific_stimuli"
prior_df_prolific.to_csv(out_root / "stimulus_table_image_priors_prolific.csv", index=False)
counterfact_df_prolific.to_csv(out_root / "stimulus_table_counterfact_prolific.csv", index=False)
shape_df_prolific.to_csv(out_root / "stimulus_table_shapes_prolific.csv", index=False)

display(prior_df_prolific.head(), counterfact_df_prolific.head(), shape_df_prolific.head())

Recovered stimulus counts:
  Correct-prior objects: 1260
  Counterfactual objects: 412
  Shapes: 1331


Unnamed: 0,image_path,object,stimulus_type,manipulation_color,target_color,variant_region,percent_colored,variant_label,mode
0,color_images/gpt-4o/image_priors/cheese_1_78f6...,cheese,correct_prior,yellow,white,BG,80,BG 80% (seq),seq
1,color_images/gpt-4o/image_priors/espresso_make...,espresso maker,correct_prior,red,red,FG,5,FG 5% (seq),seq
2,color_images/gpt-4o/image_priors/tile_roof_2_f...,tile roof,correct_prior,red,red,FG,100,FG 100% (seq),seq
3,color_images/gpt-4o/image_priors/cloud_3_29898...,cloud,correct_prior,grey,grey,FG,55,FG 55% (seq),seq
4,color_images/gpt-4o/image_priors/frilled_lizar...,frilled lizard,correct_prior,brown,brown,FG,100,FG 100% (seq),seq


Unnamed: 0,image_path,object,stimulus_type,manipulation_color,target_color,variant_region,percent_colored,variant_label,mode
0,color_images/gpt-4o/counterfact/rose_3_6471302...,rose,counterfact,blue,blue,FG,100,FG 100% (seq),seq
1,color_images/gpt-4o/counterfact/Sealyham_terri...,sealyham terrier,counterfact,purple,purple,FG,60,FG 60% (seq),seq
2,color_images/gpt-4o/counterfact/iguana_2_a2663...,iguana,counterfact,orange,orange,FG,10,FG 10% (seq),seq
3,color_images/gpt-4o/counterfact/hartebeest_3_5...,hartebeest,counterfact,red,red,FG,55,FG 55% (seq),seq
4,color_images/gpt-4o/counterfact/mouse_2_cf4ddb...,mouse,counterfact,red,red,FG,20,FG 20% (seq),seq


Unnamed: 0,image_path,object,stimulus_type,manipulation_color,target_color,variant_region,percent_colored,variant_label,mode
0,shapes/shape_colored/hexagon_v3_yellow/FG_060_...,hexagon,shape,yellow,yellow,FG,60,FG 60% (seq),seq
1,shapes/shape_colored/pentagon_v0_purple/FG_055...,pentagon,shape,purple,purple,FG,55,FG 55% (seq),seq
2,shapes/shape_colored/square_v3_blue/FG_005_seq...,square,shape,blue,blue,FG,5,FG 5% (seq),seq
3,shapes/shape_colored/triangle_v0_brown/FG_050_...,triangle,shape,brown,brown,FG,50,FG 50% (seq),seq
4,shapes/shape_colored/pentagon_v2_orange/FG_090...,pentagon,shape,orange,orange,FG,90,FG 90% (seq),seq
