In [1]:
import os
from glob import glob
import json
import pandas as pd
from tqdm import tqdm
from matplotlib import pyplot as plt
from IPython.display import Image, display
import time
import numpy as np
import cv2
from collections import defaultdict
from ortools.sat.python import cp_model

In [2]:
# Hand-picked ChartQA images. Each key joins a "text"/"notext" prefix with a chart type
# (presumably whether the chart carries printed data labels — TODO confirm).
chartqa = dict(
    text_bar=["05411753006467.png"],
    text_stacked_bar=["9280.png"],
    text_pie=["43.png"],
    text_line=["19371621021871.png"],
    notext_line=["two_col_4524.png"],
    notext_pie=["two_col_61107.png"],
    notext_bar=["two_col_40186.png"],
    notext_stacked_bar=["multi_col_60949.png"],
)

# Question lookup for the ChartQA images; read later as image_questions[image_id]["Q0"].
with open("/Users/minsukchang/Research/ChartDataset/image_questions.json") as f:
    image_questions = json.load(f)

In [3]:
# splits = {'validation': 'val.parquet', 'test': 'test.parquet'}
# df = pd.read_parquet("hf://datasets/princeton-nlp/CharXiv/" + splits["validation"])
# df['figure_path'] = df['figure_path'].apply(lambda x: x.split("/")[-1])
# df = df.drop(columns=["image"])
# df.to_csv("/Users/minsukchang/Research/ChartDataset/charxiv_qna.csv", index=False)


In [4]:
# CharXiv question table; later cells filter it on `figure_path` and read `reasoning_q`.
df = pd.read_csv(filepath_or_buffer="/Users/minsukchang/Research/ChartDataset/charxiv_qna.csv")

In [5]:
# Spot-check: display the reasoning question stored for figure 1186.jpg.
df.loc[df['figure_path'] == '1186.jpg', 'reasoning_q'].values[0]

'What is the difference between the range of z axis and x1 axis?'

In [6]:
# Hand-picked CharXiv figures, two per chart type.
charxiv = dict(
    scatterplot=["17.jpg", "617.jpg"],
    hist=["20.jpg", "81.jpg"],
    contour=["954.jpg", "1248.jpg"],
    heatmap=["568.jpg", "446.jpg"],
    geo=["433.jpg", "550.jpg"],
)

In [7]:
# Merge the stored OCR boxes of both datasets into one lookup keyed by image file name.
# The CharXiv file is merged last, so its entry wins if a file name occurs in both.
ocr_dict = {}
for ocr_json in (
    "/Users/minsukchang/Research/ChartDataset/OCR/chartqa.json",
    "/Users/minsukchang/Research/ChartDataset/OCR/charxiv.json",
):
    with open(ocr_json) as f:
        ocr_dict = ocr_dict | json.load(f)

def ocr(path):
    """Paint over every OCR text box of an image with the box's dominant border colour.

    Args:
        path: Path of the image file. Its base name is the key into the
            module-level ``ocr_dict``.

    Returns:
        ``(img, text_areas)``: ``img`` is the image loaded by ``cv2.imread`` with
        every text box filled in; ``text_areas`` holds one
        ``[top, left, height, width]`` pixel box per OCR entry.

    Raises:
        FileNotFoundError: If ``cv2.imread`` could not read ``path``.
        KeyError: If the image has no entry in ``ocr_dict``.
    """
    basename = os.path.basename(path)
    img = cv2.imread(path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {path}")
    max_row, max_col = img.shape[0] - 1, img.shape[1] - 1

    text_areas = []
    for bbox in ocr_dict[basename]:
        # assumes bbox is a 4-point polygon with the top-left corner first and
        # the bottom-right corner third — TODO confirm against the OCR export
        left, top = bbox[0]
        right, bottom = bbox[2]
        left, top, right, bottom = int(left), int(top), int(right), int(bottom)

        # Clamp to the image so the row/column lookups below cannot go out of range.
        left = max(0, min(left, max_col))
        right = max(0, min(right, max_col))
        top = max(0, min(top, max_row))
        bottom = max(0, min(bottom, max_row))

        # Order the corners so width/height never come out negative.
        left, right = min(left, right), max(left, right)
        top, bottom = min(top, bottom), max(top, bottom)

        # Sample the four sides of the box outline.
        border = np.vstack([
            img[top, left:right, :].reshape(-1, 3),
            img[bottom, left:right, :].reshape(-1, 3),
            img[top:bottom, left, :].reshape(-1, 3),
            img[top:bottom, right, :].reshape(-1, 3),
        ])
        if border.size == 0:
            # The box collapsed to a single pixel: np.bincount on an empty
            # sample would make np.argmax raise, so use that pixel instead.
            border = img[top:bottom + 1, left:right + 1, :].reshape(-1, 3)

        # Most frequent value of each channel, taken independently per channel.
        # The tuple is in the image's own channel order, so no conversion is needed.
        fill = tuple(int(np.argmax(np.bincount(border[:, ch]))) for ch in range(3))

        # Thickness -1 draws a filled rectangle.
        cv2.rectangle(img, (left, top), (right, bottom), fill, -1)
        # Stored as y, x, h, w.
        text_areas.append([top, left, bottom - top, right - left])
    return img, text_areas


def canny(img):
    """Run Canny (thresholds 50 and 150) on each of the first three channels and OR the results.

    Args:
        img: Multi-channel image accepted by ``cv2.split``; at least three
            channels are required.

    Returns:
        A single edge map that is non-zero wherever any of the three
        per-channel edge maps is non-zero.
    """
    edge_maps = []
    for plane in cv2.split(img):
        edge_maps.append(cv2.Canny(plane, 50, 150))

    first, second, third = edge_maps[0], edge_maps[1], edge_maps[2]
    return cv2.bitwise_or(cv2.bitwise_or(first, second), third)


In [8]:
def rect_to_grid(rectangles, n_rows, n_cols):
    """Build an ``n_rows`` x ``n_cols`` int32 grid with 1 at every listed ``(row, col)`` cell."""
    grid = np.zeros((n_rows, n_cols), dtype=np.int32)
    cells = list(rectangles)
    if cells:
        # Split the (row, col) pairs into two index lists and mark them in one assignment.
        row_idx, col_idx = zip(*cells)
        grid[list(row_idx), list(col_idx)] = 1
    return grid

def compute_integral_image(area):
    """Return the 2-D prefix-sum table of ``area``.

    Entry ``[i, j]`` is the sum of ``area[:i + 1, :j + 1]``.
    """
    down_columns = np.cumsum(area, axis=0)
    return np.cumsum(down_columns, axis=1)

def is_tile_valid(area, integral, w, h):
    """Tell whether some ``h``-row by ``w``-column window of the grid is entirely 1s.

    Args:
        area: Binary grid; only its shape is read here.
        integral: Prefix-sum table of ``area`` (see ``compute_integral_image``).
        w: Tile width in cells (columns).
        h: Tile height in cells (rows).

    Returns:
        True as soon as one window sums to ``w * h``, otherwise False.
    """
    n_rows, n_cols = area.shape
    wanted = w * h

    def window_sum(top, left):
        # Inclusion-exclusion on the prefix sums; terms that would fall
        # outside the grid (row -1 / column -1) are simply left out.
        bottom, right = top + h - 1, left + w - 1
        acc = integral[bottom, right]
        if top > 0:
            acc -= integral[top - 1, right]
        if left > 0:
            acc -= integral[bottom, left - 1]
        if top > 0 and left > 0:
            acc += integral[top - 1, left - 1]
        return acc

    return any(
        window_sum(top, left) == wanted
        for top in range(n_rows - h + 1)
        for left in range(n_cols - w + 1)
    )

def get_valid_tiles(area, tile_shape, max_tile_size):
    """List every tile size ``(w, h)`` that fits somewhere inside the occupied cells.

    Args:
        area: Binary grid (anything ``np.array`` accepts).
        tile_shape: ``"square"`` tries only ``s x s`` tiles; any other value
            tries every ``w x h`` combination.
        max_tile_size: Largest side length, in cells, to try.

    Returns:
        List of ``(w, h)`` tuples, ordered by ``w`` and then ``h``.

    Raises:
        AssertionError: If no size fits at all.
    """
    grid = np.array(area)  # accept lists as well as arrays
    integral = compute_integral_image(grid)
    sides = range(1, max_tile_size + 1)

    if tile_shape == "square":
        candidates = [(side, side) for side in sides]
    else:
        candidates = [(w, h) for w in sides for h in sides]

    tiles = [(w, h) for w, h in candidates if is_tile_valid(grid, integral, w, h)]
    assert len(tiles) > 0, "No valid tiles found"
    return tiles

def optimal_tiling(rectangles, n_rows, n_cols, tile_shape="sqaure", max_tile=4):
    """Cover a set of grid cells with as few non-overlapping tiles as possible (CP-SAT).

    Args:
        rectangles: Iterable of ``(row, col)`` cells that must be covered.
        n_rows: Number of grid rows.
        n_cols: Number of grid columns.
        tile_shape: ``"square"`` restricts tiles to squares; any other value
            allows arbitrary ``w x h`` rectangles.
            NOTE(review): the default is spelled "sqaure", so calling without
            this argument selects the rectangle branch. The default is left
            untouched to keep the interface stable — confirm what was intended.
        max_tile: Largest tile side in cells, or ``"max"`` for
            ``max(n_rows, n_cols)``.

    Returns:
        List of ``(row, col, w, h)`` tiles, where ``(row, col)`` is the
        top-left cell and ``w`` / ``h`` are the width / height in cells.
        If the solver returns no solution at all, every occupied cell is
        returned as its own 1x1 tile so the region is never silently dropped.

    Raises:
        AssertionError: From ``get_valid_tiles`` when no tile size fits
            (e.g. ``rectangles`` is empty).
    """
    model = cp_model.CpModel()

    # Convert the cell list to a binary grid.
    grid = rect_to_grid(rectangles, n_rows, n_cols)

    if max_tile == "max":
        max_tile = max(n_rows, n_cols)

    valid_tiles = get_valid_tiles(grid, tile_shape, max_tile)

    # One boolean per tile placement that lies entirely on occupied cells.
    # Placements touching an empty cell could never be selected, so they are
    # skipped up front instead of being created and then forced to 0.
    tile_vars = {}
    covering = {}  # (row, col) -> variables of the placements covering that cell
    for i in range(n_rows):
        for j in range(n_cols):
            if grid[i, j] != 1:
                continue
            for w, h in valid_tiles:
                if i + h > n_rows or j + w > n_cols:
                    continue
                if not grid[i:i + h, j:j + w].all():
                    continue
                var = model.NewBoolVar(f't_{i}_{j}_{w}_{h}')
                tile_vars[(i, j, w, h)] = var
                for r in range(i, i + h):
                    for c in range(j, j + w):
                        covering.setdefault((r, c), []).append(var)

    # "Covered at least once" together with "no overlap" means every occupied
    # cell is covered by exactly one selected placement.
    for cell_vars in covering.values():
        model.Add(sum(cell_vars) == 1)

    # Objective: use as few tiles as possible.
    model.Minimize(sum(tile_vars.values()))

    solver = cp_model.CpSolver()
    solver.parameters.max_time_in_seconds = 60  # Timeout for large problems
    status = solver.Solve(model)
    if status == cp_model.FEASIBLE:
        print("⚠️ Found a feasible (but not optimal) solution due to timeout!")
    elif status != cp_model.OPTIMAL:
        # No usable solution (e.g. timeout before any was found). Fall back to
        # the trivial tiling rather than returning nothing for this region.
        print("⚠️ Solver returned no solution; falling back to 1x1 tiles.")
        return [(i, j, 1, 1)
                for i in range(n_rows)
                for j in range(n_cols)
                if grid[i, j] == 1]

    return [key for key, var in tile_vars.items() if solver.Value(var) == 1]
     

def _connected_cell_groups(mask):
    """Split a boolean cell mask into connected components, each a list of (row, col) cells."""
    num_labels, labels = cv2.connectedComponents(mask.astype(np.uint8))
    groups = []
    for label_id in range(1, num_labels):  # label 0 is the mask's background
        points = np.argwhere(labels == label_id)
        if len(points) == 0:
            continue
        groups.append([(r, c) for r, c in points])
    return groups


def _tiles_for_groups(groups, n_rows, n_cols, min_x, min_y, tile_shape, max_tile):
    """Tile every cell group with optimal_tiling and convert the tiles from cells to pixels."""
    tiles = []
    for cells in groups:
        if len(cells) == 0:
            continue
        for r, c, w, h in optimal_tiling(cells, n_rows, n_cols, tile_shape, max_tile):
            tiles.append((r * min_y, c * min_x, w * min_x, h * min_y))
    return tiles


def optimized_tile(edges, text_area, min_grid_ratio, tile_shape="rectangle", max_tile="max"):
    """Tile an image into background, edge and text regions on a coarse grid.

    The image is divided into cells of ``min_grid_ratio`` times the image
    size. Each cell is classified as text (overlaps a text box), edge
    (contains a non-zero edge pixel) or background (neither). Text boxes and
    the connected components of the edge / background cells are then tiled
    separately with ``optimal_tiling``.

    Args:
        edges: 2-D edge map; non-zero marks an edge pixel.
        text_area: List of ``(y, x, h, w)`` text boxes in pixels.
        min_grid_ratio: Cell size as a fraction of the image width / height.
        tile_shape: Forwarded to ``optimal_tiling``.
        max_tile: Forwarded to ``optimal_tiling``.

    Returns:
        ``(background_solutions, edge_solutions, text_solutions)``; each is a
        list of ``(y, x, width, height)`` tiles in pixels. Tiles are whole
        cells, so tiles in the last row / column may extend past the image.
    """
    # Cell size in pixels; at least 1 so tiny images or ratios cannot make the
    # grid-size division below divide by zero.
    min_x = max(1, int(min_grid_ratio * edges.shape[1]))
    min_y = max(1, int(min_grid_ratio * edges.shape[0]))

    # Number of cells needed to cover the image (ceiling division).
    n_rows = (edges.shape[0] + min_y - 1) // min_y
    n_cols = (edges.shape[1] + min_x - 1) // min_x

    # text_grid holds the 1-based index of the text box owning each cell;
    # a cell overlapped by several boxes belongs to the first one.
    text_grid = np.zeros((n_rows, n_cols), dtype=np.int32)
    text_rectangles = []
    for i, (y_text, x_text, h_text, w_text) in enumerate(text_area):
        subrects = []
        for row in range(n_rows):
            for col in range(n_cols):
                if text_grid[row, col] > 0:
                    continue
                y_grid = row * min_y
                x_grid = col * min_x
                # Overlap test between the text box and this cell.
                if (x_text < x_grid + min_x and x_text + w_text > x_grid and
                        y_text < y_grid + min_y and y_text + h_text > y_grid):
                    text_grid[row, col] = i + 1
                    subrects.append((row, col))
        text_rectangles.append(subrects)

    has_text_grid = text_grid > 0
    has_edge_grid = np.zeros((n_rows, n_cols), dtype=bool)

    # A non-text cell is an edge cell when it contains any edge pixel.
    for row in range(n_rows):
        for col in range(n_cols):
            if has_text_grid[row, col]:
                continue
            y_start = row * min_y
            x_start = col * min_x
            cell = edges[y_start:y_start + min_y, x_start:x_start + min_x]
            has_edge_grid[row, col] = cell.any()

    # Text and edge cells must be disjoint.
    assert np.all(~(has_text_grid & has_edge_grid))

    # Group the background (neither edge nor text) and the edge cells into
    # connected components.
    background_rectangles_groups = _connected_cell_groups(~(has_edge_grid | has_text_grid))
    edge_rectangles_groups = _connected_cell_groups(has_edge_grid)

    # Tile each group on its own and scale the result to pixels.
    background_solutions = _tiles_for_groups(
        background_rectangles_groups, n_rows, n_cols, min_x, min_y, tile_shape, max_tile)
    edge_solutions = _tiles_for_groups(
        edge_rectangles_groups, n_rows, n_cols, min_x, min_y, tile_shape, max_tile)
    text_solutions = _tiles_for_groups(
        text_rectangles, n_rows, n_cols, min_x, min_y, tile_shape, max_tile)
    return background_solutions, edge_solutions, text_solutions

In [9]:
def optimized_tiles(img_path, total_tile_target=50):
    """Compute adaptive tiles for an image, coarsening the grid until the tile budget is met.

    Args:
        img_path: Image URL or path; the GitHub raw prefix is swapped for the
            local dataset directory before the file is read.
        total_tile_target: Maximum number of tiles wanted.

    Returns:
        List of ``{"x", "y", "width", "height"}`` dicts in pixels. Ten grid
        ratios from 0.05 to 0.1 are tried in increasing order; the first
        tiling within the budget is returned, otherwise the last one tried.
    """
    local_path = img_path.replace(
        "https://raw.githubusercontent.com/jangsus1/ChartDataset/main/",
        "/Users/minsukchang/Research/ChartDataset/",
    )
    text_free_img, text_boxes = ocr(local_path)
    edge_map = canny(text_free_img)

    for ratio in np.linspace(0.05, 0.1, 10):
        background, edge, text = optimized_tile(
            edge_map, text_boxes, ratio, tile_shape="rect", max_tile=10)
        tiles = background + edge + text
        if len(tiles) <= total_tile_target:
            break

    result = []
    for y, x, width, height in tiles:
        result.append({"x": x, "y": y, "width": width, "height": height})
    return result

def static_grids(img_path, grid_count=8):
    """Return evenly spaced vertical and horizontal grid-line positions for an image.

    Args:
        img_path: Image URL or path; the GitHub raw prefix is swapped for the
            local dataset directory before the file is read.
        grid_count: The spacing is ``ceil(size / (grid_count + 1))``.

    Returns:
        ``(x_grids, y_grids)``: pixel positions of the grid lines, starting
        one spacing in and stopping before the image width / height.
    """
    local_path = img_path.replace(
        "https://raw.githubusercontent.com/jangsus1/ChartDataset/main/",
        "/Users/minsukchang/Research/ChartDataset/",
    )
    image_height, image_width = cv2.imread(local_path).shape[:2]

    n_cells = grid_count + 1
    step_y = int(np.ceil(image_height / n_cells))
    step_x = int(np.ceil(image_width / n_cells))

    x_grids = list(range(step_x, image_width, step_x))
    y_grids = list(range(step_y, image_height, step_y))
    return x_grids, y_grids

In [10]:
def plot_figure(chart_type, image_id, x_grids, y_grids, image_link):
    """Show an image with a static grid drawn over it.

    Args:
        chart_type: Chart type; first half of the figure title.
        image_id: Image id; second half of the figure title.
        x_grids: x positions (pixels) of the vertical grid lines.
        y_grids: y positions (pixels) of the horizontal grid lines.
        image_link: Image URL or path; the GitHub raw prefix is swapped for
            the local dataset directory before the file is read.

    Raises:
        FileNotFoundError: If ``cv2.imread`` could not read the image.
    """
    local_path = image_link.replace(
        "https://raw.githubusercontent.com/jangsus1/ChartDataset/main/",
        "/Users/minsukchang/Research/ChartDataset/",
    )
    img = cv2.imread(local_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {local_path}")

    plt.figure()
    # OpenCV loads images as BGR while matplotlib expects RGB; without the
    # conversion the red and blue channels are swapped on screen.
    plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    plt.title(chart_type + image_id)
    for x in x_grids:
        plt.axvline(x=x, color='black', linewidth=0.8)
    for y in y_grids:
        plt.axhline(y=y, color='black', linewidth=0.8)
    plt.tight_layout()
    plt.axis('off')
    plt.show()

In [11]:
def create_base_components():
    """Return the base component definitions for the two patch-selection tools.

    Both tools ("grid" and "adaptive_tile") share the same layout and the same
    two responses (a short-text answer and the reactive list of selected
    patches); only the React asset path differs.
    """
    def selection_component(asset_path):
        # Fresh dicts on every call so the two components share no state.
        answer_box = {
            "id": "answer",
            "prompt": "Answer for the question:",
            "required": True,
            "location": "sidebar",
            "type": "shortText"
        }
        selected_patches = {
            "id": "patches",
            "prompt": "Selected Patches:",
            "required": True,
            "location": "sidebar",
            "type": "reactive"
        }
        return {
            "type": "react-component",
            "path": asset_path,
            "response": [answer_box, selected_patches],
            "instructionLocation": "aboveStimulus",
            "nextButtonLocation": "sidebar"
        }

    # Currently disabled tools: "bubble" (tools_extend/assets/bubble.jsx,
    # reactive response "circles", parameter radius_count=10) and
    # "important_annot" (tools_extend/assets/importantannot.jsx, reactive
    # response "annotations").
    return {
        "grid": selection_component("tools_extend/assets/grid.jsx"),
        "adaptive_tile": selection_component("tools_extend/assets/tilegrid.jsx"),
    }

In [12]:

def create_default_components():
    """Return the static study pages: intro pages, demographics form and NASA questionnaire.

    Every page is a markdown component under ``tools_extend/assets``. The
    "demographics" and "NASA" pages use the blank page and carry their
    questions in ``response``.
    """
    def markdown_page(asset_name, responses=(), next_button_text=None):
        page = {"type": "markdown", "path": f"tools_extend/assets/{asset_name}.md"}
        if next_button_text is not None:
            page["nextButtonText"] = next_button_text
        page["response"] = list(responses)
        return page

    def radio(question_id, prompt, options, with_other=False):
        item = {
            "id": question_id,
            "prompt": prompt,
            "required": True,
            "location": "aboveStimulus",
            "type": "radio",
        }
        if with_other:
            item["withOther"] = True
        item["options"] = list(options)
        item["withDivider"] = True
        return item

    def likert(question_id, prompt, right_label="Very High", left_label="Very Low"):
        return {
            "id": question_id,
            "prompt": prompt,
            "required": True,
            "location": "aboveStimulus",
            "type": "likert",
            "numItems": 7,
            "rightLabel": right_label,
            "leftLabel": left_label,
            "withDivider": True,
        }

    demographic_questions = [
        radio(
            "gender",
            "What is your **gender**?",
            ["Woman", "Man", "Prefer not to say"],
            with_other=True,
        ),
        radio(
            "age",
            "What is your **age**?",
            [
                "Under 18 years",
                "18-24 years",
                "25-34 years",
                "35-44 years",
                "45-54 years",
                "55-64 years",
                "65 years or older",
                "Prefer not to say",
            ],
        ),
        radio(
            "education",
            "What is the **highest degree or level of education** you have completed?",
            [
                "Less than high school",
                "High school diploma or equivalent",
                "Bachelor's degree or equivalent",
                "Master's degree or equivalent",
                "Doctoral degree or equivalent",
            ],
            with_other=True,
        ),
    ]

    workload_questions = [
        likert("mental-demand", "How **mentally demanding** was the task?"),
        likert("physical-demand", "How **physically demanding** was the task?"),
        likert("temporal-demand", "How **hurried or rushed** was the pace of the task?"),
        likert(
            "performance",
            "How **successful** were you in accomplishing what you were asked to do?",
            right_label="Perfect",
            left_label="Failure",
        ),
        likert("effort", "How **hard** did you have to work to accomplish your level of performance?"),
        likert("frustration", "How **insecure, discouraged, irritated, stressed, and annoyed** were you?"),
    ]

    # Intro pages for the disabled tools (intro_bubble, intro_importAnnot) are
    # intentionally left out.
    return {
        "welcome": markdown_page("welcome"),
        "instructions": markdown_page("instructions"),
        "consent": markdown_page("consent", next_button_text="I agree"),
        "intro_static": markdown_page("intro_static"),
        "intro_adaptive": markdown_page("intro_adaptive"),
        "vlat_intro": markdown_page("vlat_intro"),
        "demographics": markdown_page("blank", demographic_questions),
        "NASA": markdown_page("blank", workload_questions),
    }


def create_example_components():
    """Build the practice ("example") trials for the static and adaptive tools.

    Two practice charts (line, bar) each get three instructions; every
    instruction yields one static-grid trial and one adaptive-tile trial.

    Returns:
        ``defaultdict`` with the keys ``"static"`` and ``"adaptive"``, each
        mapping a component id to its definition.
    """
    test_charts = {
        "line": "https://raw.githubusercontent.com/jangsus1/ChartDataset/main/chartqa/127.png",
        "bar": "https://raw.githubusercontent.com/jangsus1/ChartDataset/main/chartqa/166.png",
    }
    test_instructions = {
        "line": [
            "Please annotate important areas in the line chart and briefly describe it in the textbox.",
            "Please annotate the minimum area required to identify the highest value of 'Democrats' in the line chart, then provide your answer in the textbox.",
            "Please annotate the minimum area required to identify the difference between 'Democrats' and 'Republicans' in the year 2017, then provide your answer in the textbox."
        ],
        "bar": [
            "Please annotate important areas in the bar chart and briefly describe it in the textbox.",
            "Please annotate the minimum area required to identify the lowest value in the bar chart, then provide your answer in the textbox.",
            "Please annotate the minimum area required to identify the difference between the values of 'Charismatic' and 'Dangerous', then provide your answer in the textbox."
        ]
    }
    groups = defaultdict(dict)
    for chart_type, instructions in test_instructions.items():
        chart_url = test_charts[chart_type]

        # The grid and the tiles depend only on the chart, not on the
        # instruction, so compute them once per chart instead of once per
        # instruction (the tile solver is the expensive part).
        x_grids, y_grids = static_grids(chart_url, grid_count=8)
        tiles = optimized_tiles(chart_url, (8+1)**2)

        for i, instruction in enumerate(instructions):
            # Copies keep the component definitions independent of each other.
            groups["static"][f"ex_static_{chart_type}_{i}"] = {
                "baseComponent": "grid",
                "parameters": {
                    "image": chart_url,
                    "question": instruction,
                    "x_grids": list(x_grids),
                    "y_grids": list(y_grids),
                    "example": True
                }
            }
            # The "bubble" and "important_annot" example trials are disabled.
            groups["adaptive"][f"ex_adaptive_{chart_type}_{i}"] = {
                "baseComponent": "adaptive_tile",
                "parameters": {
                    "image": chart_url,
                    "question": instruction,
                    "example": True,
                    "tiles": [dict(tile) for tile in tiles],
                }
            }
    return groups


def create_initial_components():
    """Build the main study trials: one adaptive-tile and one static-grid trial per image.

    Images come from the module-level ``chartqa`` and ``charxiv`` selections.
    ChartQA questions are read from ``image_questions`` (entry "Q0"), CharXiv
    questions from the ``reasoning_q`` column of ``df``.

    Returns:
        ``defaultdict`` with the keys ``"adaptive"`` and ``"static"``, each
        mapping a component id to its definition.
    """
    group = defaultdict(dict)
    merged = chartqa | charxiv
    repo_root = "https://raw.githubusercontent.com/jangsus1/ChartDataset/main/"

    for grid_count in [8]:
        progress = tqdm(merged.items(), desc=f"{grid_count}", total=len(merged.keys()))
        for chart_type, image_lists in progress:
            # The dataset is decided by which selection the chart type came from.
            from_chartqa = chart_type in chartqa
            dataset = "chartqa" if from_chartqa else "charxiv"

            for image_id in image_lists:
                image_link = f"{repo_root}{dataset}/{image_id}"
                if from_chartqa:
                    question = image_questions[image_id]["Q0"]
                else:
                    matching_rows = df[df['figure_path'] == image_id]
                    question = matching_rows['reasoning_q'].values[0]

                # The adaptive tool gets the same tile budget as the static
                # grid has cells: (grid_count + 1) squared.
                tiles = optimized_tiles(image_link, (grid_count+1)**2)
                adaptive_params = {
                    "image": image_link,
                    "question": question,
                    "tiles": tiles,
                    "chart_type": chart_type,
                    "grid": f"adaptive-{grid_count}"
                }
                group["adaptive"][f"adaptive_{dataset}_{chart_type}_{image_id}"] = {
                    "baseComponent": "adaptive_tile",
                    "parameters": adaptive_params
                }

                x_grids, y_grids = static_grids(image_link, grid_count=grid_count)
                static_params = {
                    "image": image_link,
                    "question": question,
                    "x_grids": x_grids,
                    "y_grids": y_grids,
                    "chart_type": chart_type,
                    "grid": f"static-{grid_count}"
                }
                group["static"][f"static_{dataset}_{chart_type}_{image_id}"] = {
                    "baseComponent": "grid",
                    "parameters": static_params
                }
                # The "bubble" and "important_annot" trials are disabled.

    return group

In [13]:
def sequence_generator(example_groups, group):
    """Build the study sequence: fixed intro pages, one sampled tool block, then questionnaires.

    Args:
        example_groups: Unused while the practice blocks are disabled.
        group: Mapping of tool name to that tool's ``{component_id: ...}`` dict.

    Returns:
        Sequence dict. The "tools" block samples one tool (latin square); each
        tool block shows ``intro_<tool>`` followed by all of the tool's
        components in latin-square order.
    """
    tool_blocks = []
    for tool, tasks in group.items():
        main_block = {
            "id": "main",
            "order": "latinSquare",
            "components": list(tasks.keys())
        }
        # A per-tool practice block (bar / line examples) used to sit between
        # the intro page and the main block; it is currently disabled.
        tool_blocks.append({
            "id": tool,
            "order": "fixed",
            "components": [f"intro_{tool}", main_block]
        })

    tools_block = {
        "id": "tools",
        "order": "latinSquare",
        "numSamples": 1,
        "components": tool_blocks
    }

    return {
        "order": "fixed",
        "components": [
            "welcome",
            "consent",
            "instructions",
            tools_block,
            "NASA",
            "vlat_intro",
            "$mini-vlat.se.full",
            "demographics"
        ]
    }

In [14]:
# Prolific completion URL; embedded in the study-end message written to config.json below.
prolificRedirection = "https://app.prolific.com/submissions/complete?cc=C1DEBJ8K"

In [15]:
# Build the main study trials for every selected image. This runs the tile
# solver per image, so the cell takes a while (see the progress bar below).
main_components_group = create_initial_components()
# Practice trials are currently disabled:
# example_component_groups = create_example_components()

8:   0%|          | 0/13 [00:00<?, ?it/s]

8: 100%|██████████| 13/13 [01:30<00:00,  6.93s/it]


In [16]:
default_components = create_default_components()
baseComponents = create_base_components()
# vlat = create_VLAT()

# Flatten the per-tool trial dicts into one component table on top of the
# static pages; each merge builds a new dict, so default_components is untouched.
components = default_components
for group in main_components_group.values():
    components = {**components, **group}
# for group in example_component_groups.values():
#     components = components | group

# example_groups is unused by sequence_generator, hence None.
sequence = sequence_generator(None, main_components_group)
print("Total number of components: {}".format(len(components)))

Total number of components: 44


In [17]:
# Patch the existing study config in place: keep everything else and replace
# the end message, the components, the sequence and the base components.
with open("config.json", "r") as f:
    config = json.load(f)

config['uiConfig']['studyEndMsg'] = (
    "**Thank you for completing the study. You may click this link and return to Prolific**: "
    f"[{prolificRedirection}]({prolificRedirection})"
)
config.update(
    components=components,
    sequence=sequence,
    baseComponents=baseComponents,
)

with open("config.json", "w") as f:
    json.dump(config, f, indent=4)