In [32]:
import os
from glob import glob
import json
import pandas as pd
from tqdm import tqdm

In [33]:
# Per-question patch-importance tables. The "Q" column records which question
# (Q0 or Q1) each row came from, so both can live in one frame.
PATCH_CSV_DIR = "/Users/minsukchang/Research/PunchHole/initial_patch_csv"
Q0_df = pd.read_csv(f"{PATCH_CSV_DIR}/128_gpt_Q0.csv").assign(Q="Q0")
Q1_df = pd.read_csv(f"{PATCH_CSV_DIR}/128_gpt_Q1.csv").assign(Q="Q1")
# ignore_index=True gives the combined frame a unique RangeIndex instead of two
# overlapping 0..n-1 ranges, so label-based lookups stay unambiguous.
df = pd.concat([Q0_df, Q1_df], ignore_index=True)
# Question text per image, looked up as image_questions[image_id][Q] below.
with open("/Users/minsukchang/Research/PunchHole/SalChartQA/data/image_questions.json", encoding="utf-8") as f:
    image_questions = json.load(f)

In [39]:
# File names (no directory) of every chart image. Sorted because glob() returns
# entries in arbitrary filesystem order, and image_lists[:limit] must select the
# same images on every machine and every run.
image_lists = sorted(
    os.path.basename(path)
    for path in glob("/Users/minsukchang/Research/ChartQA/raw_img/*.png")
)


def create_default_components():
    """Build the fixed (non-trial) study components.

    Returns a fresh dict on every call with three entries, in this order:
    ``welcome`` and ``consent`` (markdown pages under ``punch/assets``) and
    ``demographics`` (a questionnaire with two required numerical items:
    age and height in inches).
    """
    def numerical_question(question_id, prompt):
        # A required numeric-entry questionnaire item.
        return {
            "id": question_id,
            "type": "numerical",
            "required": True,
            "prompt": prompt,
        }

    welcome_page = {
        "type": "markdown",
        "path": "punch/assets/welcome.md",
        "response": [],
    }
    consent_page = {
        "type": "markdown",
        "path": "punch/assets/consent.md",
        "nextButtonText": "I agree",
        "response": [],
    }
    demographics_form = {
        "type": "questionnaire",
        "response": [
            numerical_question("age", "What is your age?"),
            numerical_question("height", "What is your height in inches?"),
        ],
    }
    return {
        "welcome": welcome_page,
        "consent": consent_page,
        "demographics": demographics_form,
    }

def create_initial_components(limit=10000):
    """Build one "punch" trial component per (image, question) pair.

    Reads the module-level ``image_lists``, ``image_questions`` and ``df``.
    For each of the first ``limit`` images and each of "Q0"/"Q1" that
    ``image_questions`` defines for it, emits ``trial_<image_id>_<Q>`` with:
    the hosted grid-image URL, the question text, the ``holes`` (the
    (x=j, y=i) cells whose ``importance`` is > 0), and the distinct ``j`` /
    ``i`` values present for that image/question (``x_grids`` / ``y_grids``).

    Args:
        limit: maximum number of images taken from ``image_lists``
            (``None`` takes all of them).

    Returns:
        dict mapping component name -> component config.
    """
    all_components = {}
    for image_id in tqdm(image_lists[:limit]):
        # Images with no entry in image_questions are skipped (previously a KeyError),
        # matching how a missing individual question is skipped below.
        questions = image_questions.get(image_id, {})
        image_link = f"https://raw.githubusercontent.com/jangsus1/ChartQA/main/grid_img/{image_id}"
        for Q in ("Q0", "Q1"):
            if Q not in questions:
                continue
            rows = df[(df["image_id"] == image_id) & (df["Q"] == Q)]
            # Filter once and reuse for both coordinates.
            important = rows[rows["importance"] > 0]

            all_components[f"trial_{image_id}_{Q}"] = {
                "baseComponent": "punch",
                "parameters": {
                    "image": image_link,
                    "question": questions[Q],
                    "holes": [{"x": x_, "y": y_} for x_, y_ in zip(important["j"], important["i"])],
                    "x_grids": rows["j"].unique().tolist(),
                    "y_grids": rows["i"].unique().tolist(),
                },
            }
    return all_components


default_components = create_default_components()
main_components = create_initial_components(limit=10)
# Non-trial pages first, then the generated trials; the unpacking merge keeps that order.
components = {**default_components, **main_components}

100%|██████████| 10/10 [00:00<00:00, 45.14it/s]


In [40]:
def sequence_generator(components, num_samples=10):
    """Build the study's top-level ``sequence`` definition.

    The fixed-order sequence is "welcome", "consent", then a block with
    ``"order": "random"`` over ``components`` and ``"numSamples"`` set to
    ``num_samples``, then "demographics".

    Args:
        components: iterable of trial component names (copied into a list).
        num_samples: value written to the random block's "numSamples"
            (default 10, the previously hard-coded value).

    Returns:
        dict suitable for the config's "sequence" entry.
    """
    return {
        "order": "fixed",
        "components": [
            "welcome",
            "consent",
            {
                "order": "random",
                "numSamples": num_samples,
                # Copy so later mutation of the caller's list cannot silently
                # change this sequence; also accepts dict views and generators.
                "components": list(components),
            },
            "demographics",
        ],
    }

# Every generated trial name is eligible for the random block (iterating a dict yields its keys).
sequence = sequence_generator(list(main_components))


In [41]:
# Merge the generated components and sequence into the existing config.json,
# keeping any other top-level keys it already contains.
with open("config.json", "r", encoding="utf-8") as f:
    config = json.load(f)
config['components'] = components
config['sequence'] = sequence
# Serialize BEFORE opening for write: open(..., "w") truncates the file, so a
# serialization error (e.g. a non-JSON-serializable value) raised inside
# json.dump would otherwise leave config.json empty or half-written.
config_text = json.dumps(config, indent=4)
with open("config.json", "w", encoding="utf-8") as f:
    f.write(config_text)