In [22]:
import os
from glob import glob
import json
import pandas as pd
from tqdm import tqdm
from matplotlib import pyplot as plt
from IPython.display import Image, display
import time
import numpy as np
import cv2
from collections import defaultdict

In [23]:
# ChartQA sample images, one list of file names per chart category.
# NOTE(review): the "text_" / "notext_" prefixes presumably distinguish charts
# with and without text annotations -- confirm against the dataset.
chartqa = dict(
    text_bar=["05411753006467.png"],
    text_stacked_bar=["9280.png"],
    text_pie=["43.png"],
    text_line=["19371621021871.png"],
    notext_line=["two_col_4524.png"],
    notext_pie=["two_col_61107.png"],
    notext_bar=["two_col_40186.png"],
    notext_stacked_bar=["multi_col_60949.png"],
)

# Questions for the ChartQA images; indexed later as image_questions[image_id]["Q0"].
with open("/Users/minsukchang/Research/ChartQA/image_questions.json") as f:
    image_questions = json.loads(f.read())

In [24]:
# Parquet file name of each CharXiv split; only the validation split is loaded here.
splits = {"validation": "val.parquet", "test": "test.parquet"}
df = pd.read_parquet("hf://datasets/princeton-nlp/CharXiv/" + splits["validation"])
# Reduce figure_path to its last "/"-separated component so it can be compared
# against bare file names.
df["figure_path"] = df["figure_path"].apply(
    lambda figure_path: figure_path.rsplit("/", 1)[-1]
)

In [25]:
# Drop every backslash from the reasoning questions.
df["reasoning_q"] = df["reasoning_q"].apply(
    lambda question: "".join(question.split("\\"))
)

In [26]:
# CharXiv sample figures, a list of file names per chart category.
charxiv = dict([
    ("scatterplot", ["17.jpg", "617.jpg"]),
    ("hist", ["20.jpg", "81.jpg"]),
    ("contour", ["954.jpg", "1248.jpg"]),
    ("3d", ["1314.jpg", "1186.jpg"]),
    ("heatmap", ["568.jpg", "446.jpg"]),
    ("geo", ["433.jpg", "550.jpg"]),
])

In [27]:
# OCR results for both datasets merged into one lookup keyed by image file name
# (ocr() indexes it with os.path.basename(path)). Entries from the CharXiv file
# win if a key appears in both files.
ocr_dict = {}
with open("/Users/minsukchang/Research/ChartQA/OCR/chartqa.json") as f:
    ocr_dict |= json.load(f)
with open("/Users/minsukchang/Research/ChartQA/OCR/charxiv.json") as f:
    ocr_dict |= json.load(f)

def ocr(path):
    """Return the image at ``path`` with every OCR text box painted over.

    For each box listed in ``ocr_dict`` under the image's base name, the pixels
    on the box outline are sampled, the most frequent value of each colour
    channel is taken independently, and the box is filled with that colour.

    Args:
        path: Path of the image file. Its base name is the key into ``ocr_dict``.

    Returns:
        The image array as loaded by ``cv2.imread`` (BGR channel order), with
        the text boxes filled in.

    Raises:
        OSError: if ``cv2.imread`` cannot read the file.
        KeyError: if ``ocr_dict`` has no entry for the image's base name.
    """
    basename = os.path.basename(path)
    img = cv2.imread(path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise OSError(f"could not read image: {path}")

    max_row = img.shape[0] - 1
    max_col = img.shape[1] - 1
    for bbox in ocr_dict[basename]:
        # Element 0 is used as the (left, top) corner and element 2 as the
        # (right, bottom) corner.
        # NOTE(review): presumably a 4-point quad from the OCR tool -- confirm.
        left, top = bbox[0]
        right, bottom = bbox[2]

        # Truncate to int and clamp to valid pixel indices.
        left = max(0, min(int(left), max_col))
        right = max(0, min(int(right), max_col))
        top = max(0, min(int(top), max_row))
        bottom = max(0, min(int(bottom), max_row))

        # Collect the pixels on the four sides of the box.
        border = np.vstack([
            img[top, left:right, :].reshape(-1, 3),
            img[bottom, left:right, :].reshape(-1, 3),
            img[top:bottom, left, :].reshape(-1, 3),
            img[top:bottom, right, :].reshape(-1, 3),
        ])
        if border.shape[0] == 0:
            # Zero-width and zero-height box (e.g. clamped onto a single pixel):
            # nothing to sample, and np.argmax would fail on an empty bincount.
            continue

        # Per-channel mode, kept in the image's own channel order so it can be
        # handed straight back to cv2.rectangle.
        fill_color = tuple(
            int(np.argmax(np.bincount(border[:, channel]))) for channel in range(3)
        )
        # thickness -1 draws a filled rectangle
        cv2.rectangle(img, (left, top), (right, bottom), fill_color, -1)
    return img


def canny(img):
    """Convert a BGR image to grayscale and return its Canny edge map.

    The hysteresis thresholds are fixed at 50 (low) and 150 (high).
    """
    return cv2.Canny(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY), 50, 150)


def generate_line_kernels(size):
    """Build a fan of ``size`` x ``size`` line kernels, sorted by slope.

    Each kernel is a uint8 array of zeros with a 1-pixel-wide line of ones drawn
    between two opposite edges of the square:

    * ``size`` kernels run from the left edge to the right edge
      (slopes from -1 to 1), and
    * ``size`` kernels run from the top edge to the bottom edge
      (slopes of magnitude >= 1, both signs).

    The two diagonals are produced by both families, so they appear twice.

    Args:
        size: Side length of each kernel; must be even.

    Returns:
        ``(kernels, slopes)``: ``kernels`` has shape ``(2 * size, size, size)``
        and ``slopes`` holds the matching slopes in ascending order. Slopes are
        negated relative to array coordinates, whose row index grows downwards.
    """
    # An even size keeps (size - 1) odd, so (size - 1 - 2 * x) below is never zero.
    assert size % 2 == 0, "size must be even"
    last = size - 1
    kernels = []
    slopes = []

    # Shallow lines: left edge (0, y) to right edge (last, last - y).
    for y in range(size):
        kernel = np.zeros((size, size), dtype=np.uint8)
        cv2.line(kernel, (0, y), (last, last - y), 1, 1)
        kernels.append(kernel)
        slopes.append(-(last - 2 * y) / last)

    # Steep lines: top edge (x, 0) to bottom edge (last - x, last).
    # The original bound was written `size+1//2`, which Python evaluates as
    # `size + (1 // 2)`, i.e. `size`. Iterating over every x is what yields both
    # positive and negative steep slopes, so the bound is spelled out explicitly.
    for x in range(size):
        kernel = np.zeros((size, size), dtype=np.uint8)
        cv2.line(kernel, (x, 0), (last - x, last), 1, 1)
        kernels.append(kernel)
        slopes.append(-last / (last - 2 * x))

    # Order both arrays by ascending slope.
    kernels = np.array(kernels)
    slopes = np.array(slopes)
    idx = np.argsort(slopes)
    return kernels[idx], slopes[idx]


def merge_grids(grid, min_grid_size):
    """Thin out grid positions until neighbours are more than ``min_grid_size`` apart.

    The positions are the indices where ``grid`` is positive. While the smallest
    gap between neighbouring positions is at most ``min_grid_size``, one of the
    two positions forming that gap is removed. Merging also stops as soon as
    exactly four positions (three gaps) remain.

    Args:
        grid: Array whose positive entries mark grid positions.
            NOTE(review): assumed 1-D -- only the first index array of
            ``np.where`` is used.
        min_grid_size: Largest gap that is still considered too small.

    Returns:
        Array of the surviving positions. Inputs with fewer than two positions
        are returned unchanged.
    """
    points = np.where(grid > 0)[0]

    while True:
        intervals = np.diff(points)
        if intervals.size == 0:
            # Fewer than two points left: no gap to inspect, and np.argmin
            # would raise on an empty array.
            break
        min_index = np.argmin(intervals)
        if intervals[min_index] > min_grid_size:
            break
        if len(intervals) == 3:
            break

        # Remove one endpoint of the smallest gap.
        if min_index == 0:
            # First gap: keep the first point, drop its right neighbour.
            points = np.delete(points, min_index + 1)
        elif min_index == len(intervals) - 1:
            # Last gap: keep the last point, drop its left neighbour.
            points = np.delete(points, min_index)
        elif intervals[min_index - 1] >= intervals[min_index + 1]:
            # Interior gap: drop the endpoint whose other neighbouring gap is smaller.
            points = np.delete(points, min_index + 1)
        else:
            points = np.delete(points, min_index)
    return points

In [28]:
def _select_grid_offsets(weights, bin_size, grid_count, probabilistic):
    """Sum ``weights`` in bins of ``bin_size`` and return pixel offsets of the chosen bins."""
    binned = np.array([
        np.sum(weights[start:start + bin_size])
        for start in range(0, len(weights), bin_size)
    ])
    if probabilistic:
        # Sample distinct bins with probability proportional to their edge weight.
        probs = binned / np.sum(binned)
        chosen = np.random.choice(np.arange(len(probs)), grid_count, p=probs, replace=False)
    else:
        # Take the grid_count heaviest bins, in ascending order of weight.
        chosen = np.argsort(binned)[-grid_count:]
    # Convert bin index to the pixel offset where that bin starts.
    return (chosen * bin_size).tolist()


def adaptive_grids(img_path, min_grid_ratio=0.05, grid_count=10, probabilistic=False):
    """Place grid lines where the chart has the most edge content.

    The image is cleaned of OCR text boxes with ``ocr``, edge-detected with
    ``canny``, and the edge map is summed per column and per row. Each profile
    is split into bins and ``grid_count`` bins are selected per axis.

    Args:
        img_path: Image path or GitHub raw URL; the URL prefix is mapped to the
            local checkout before reading.
        min_grid_ratio: Bin width as a fraction of the image size along the
            same axis.
        grid_count: Number of grid lines to select per axis.
        probabilistic: If True, bins are sampled without replacement with
            probability proportional to edge weight; otherwise the heaviest
            bins are taken.

    Returns:
        ``(x_positions, y_positions)``: two lists of pixel offsets. The lists
        are not sorted by position.

    Raises:
        ValueError: from ``np.random.choice`` in probabilistic mode when it
            cannot draw ``grid_count`` distinct bins (for example when too few
            bins have non-zero weight).
    """
    img_path = img_path.replace("https://raw.githubusercontent.com/jangsus1/ChartQA/main/", "/Users/minsukchang/Research/ChartQA/")
    ocr_removed = ocr(img_path)
    edges = canny(ocr_removed)

    image_height, image_width = edges.shape[:2]
    # Bin widths follow the axis they partition: columns are binned by a
    # fraction of the width, rows by a fraction of the height (these were
    # previously swapped). At least 1 so that range() gets a valid step.
    min_x_grid = max(1, int(min_grid_ratio * image_width))
    min_y_grid = max(1, int(min_grid_ratio * image_height))

    x_weights = np.sum(edges, axis=0)  # one value per column
    y_weights = np.sum(edges, axis=1)  # one value per row

    grid_points_x = _select_grid_offsets(x_weights, min_x_grid, grid_count, probabilistic)
    grid_points_y = _select_grid_offsets(y_weights, min_y_grid, grid_count, probabilistic)
    return grid_points_x, grid_points_y

def static_grids(img_path, grid_count=10):
    """Return evenly spaced grid-line positions for an image.

    The spacing on each axis is ``ceil(size / (grid_count + 1))`` pixels and
    lines are placed at every multiple of that spacing strictly inside the
    image. Because the spacing is rounded up, an axis can end up with fewer
    than ``grid_count`` lines (never more).

    Args:
        img_path: Image path or GitHub raw URL; the URL prefix is mapped to the
            local checkout before reading.
        grid_count: Target number of grid lines per axis.

    Returns:
        ``(x_grids, y_grids)``: ascending lists of pixel positions.

    Raises:
        OSError: if ``cv2.imread`` cannot read the file.
    """
    img_path = img_path.replace("https://raw.githubusercontent.com/jangsus1/ChartQA/main/", "/Users/minsukchang/Research/ChartQA/")
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise OSError(f"could not read image: {img_path}")

    image_height, image_width = img.shape[:2]
    interval_y = int(np.ceil(image_height / (grid_count + 1)))
    interval_x = int(np.ceil(image_width / (grid_count + 1)))

    x_grids = list(range(interval_x, image_width, interval_x))
    y_grids = list(range(interval_y, image_height, interval_y))
    return x_grids, y_grids

In [29]:
def plot_figure(chart_type, image_id, x_grids, y_grids, image_link):
    """Show an image with its grid lines overlaid, for visual inspection.

    Args:
        chart_type: Chart category; used in the figure title.
        image_id: Image file name; appended to ``chart_type`` in the title.
        x_grids: x positions at which vertical lines are drawn.
        y_grids: y positions at which horizontal lines are drawn.
        image_link: Image path or GitHub raw URL; the URL prefix is mapped to
            the local checkout before reading.

    Raises:
        OSError: if ``cv2.imread`` cannot read the file.
    """
    local_path = image_link.replace("https://raw.githubusercontent.com/jangsus1/ChartQA/main/", "/Users/minsukchang/Research/ChartQA/")
    bgr = cv2.imread(local_path)
    if bgr is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise OSError(f"could not read image: {local_path}")

    plt.figure()
    # OpenCV loads images as BGR while matplotlib expects RGB; without this
    # conversion red and blue are swapped in the displayed chart.
    plt.imshow(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
    plt.title(chart_type + image_id)
    for x in x_grids:
        plt.axvline(x=x, color='black', linewidth=0.8)
    for y in y_grids:
        plt.axhline(y=y, color='black', linewidth=0.8)
    plt.tight_layout()
    plt.axis('off')
    plt.show()

In [30]:
def create_default_components():
    """Return the fixed markdown pages of the study, keyed by component name.

    Every page has ``type`` "markdown", a ``path`` to its markdown file and an
    empty ``response`` list; the consent page additionally sets
    ``nextButtonText``.
    """
    # (component name, markdown path, extra fields placed before "response")
    pages = (
        ("welcome", "importance/assets/welcome.md", {}),
        ("consent", "importance/assets/consent.md", {"nextButtonText": "I agree"}),
        ("example_start", "importance/assets/example_start.md", {}),
        ("main_start", "importance/assets/main_start.md", {}),
    )

    components = {}
    for name, markdown_path, extras in pages:
        page = {"type": "markdown", "path": markdown_path}
        page.update(extras)
        page["response"] = []  # fresh list per page
        components[name] = page
    return components


def create_example_components():
    """Return the four practice components shown before the main study.

    All of them use the same example image; the two grid examples share one
    static grid with ``grid_count=10``.
    """
    image_link = "https://raw.githubusercontent.com/jangsus1/ChartQA/main/chartqa/34.png"
    x_grids, y_grids = static_grids(image_link, grid_count=10)

    def grid_parameters(question):
        # Parameters of a grid example: image, question, grid lines, example flag.
        return {
            "image": image_link,
            "question": question,
            "x_grids": x_grids,
            "y_grids": y_grids,
            "example": True,
        }

    def plain_parameters(question):
        # Parameters of an example that has no grid overlay.
        return {"image": image_link, "question": question, "example": True}

    second_grid_parameters = grid_parameters("How many color of lines are there?")
    second_grid_parameters["ourDefinition"] = True

    examples = {}
    examples["grid_example"] = {
        "baseComponent": "grid",
        "parameters": grid_parameters("What is the highest number for the brown line?"),
    }
    examples["grid_example2"] = {
        "baseComponent": "grid",
        "parameters": second_grid_parameters,
    }
    examples["bubble_example"] = {
        "baseComponent": "bubble",
        "parameters": plain_parameters("Do the two lines meet?"),
    }
    examples["importAnnot_example"] = {
        "baseComponent": "important_annot",
        "parameters": plain_parameters("How many years are the graph?"),
    }
    return examples


# Grid generators to include in the study, keyed by the name that ends up in
# component ids. Only the static grid is enabled; to re-enable the adaptive
# one, add: adaptive=adaptive_grids
grid_methods = dict(
    static=static_grids,
)


def create_initial_components():
    """Build the main-study components, grouped by image.

    For every chart category in ``chartqa`` and ``charxiv`` only the first
    image of the list is used. For each enabled grid method (with 10 grid
    lines) that image gets two grid components -- without and with
    ``ourDefinition`` -- plus one bubble and one important-annotation
    component.

    Returns:
        ``defaultdict`` mapping image id to a dict of component name ->
        component definition.
    """
    group = defaultdict(dict)
    merged = chartqa | charxiv
    for grid_method, func in grid_methods.items():
        for grid_count in [10]:
            progress = tqdm(merged.items(), desc=f"{grid_count}", total=len(merged.keys()))
            for chart_type, image_lists in progress:
                image_id = image_lists[0]

                # Image URL, question and dataset label depend on the source dataset.
                if chart_type in chartqa:
                    dataset = "chartqa"
                    image_link = f"https://raw.githubusercontent.com/jangsus1/ChartQA/main/chartqa/{image_id}"
                    question = image_questions[image_id]["Q0"]
                else:
                    dataset = "charxiv"
                    image_link = f"https://raw.githubusercontent.com/jangsus1/ChartQA/main/charxiv/{image_id}"
                    matching_rows = df[df['figure_path'] == image_id]
                    question = matching_rows['reasoning_q'].values[0]

                x_grids, y_grids = func(image_link, grid_count=grid_count)

                grid_parameters = {
                    "image": image_link,
                    "question": question,
                    "x_grids": x_grids,
                    "y_grids": y_grids,
                    "chart_type": chart_type,
                    "grid": f"{grid_method}-{grid_count}",
                }
                plain_parameters = {
                    "image": image_link,
                    "question": question,
                    "chart_type": chart_type,
                }

                per_image = group[image_id]
                # Grid variant without the "ourDefinition" flag.
                per_image[f"{grid_method}_{dataset}_{chart_type}_{image_id}_{grid_count}_false"] = {
                    "baseComponent": "grid",
                    "parameters": dict(grid_parameters),
                }
                # Same grid, with the "ourDefinition" flag set.
                per_image[f"{grid_method}_{dataset}_{chart_type}_{image_id}_{grid_count}_true"] = {
                    "baseComponent": "grid",
                    "parameters": {**grid_parameters, "ourDefinition": True},
                }
                per_image[f"bubble_{dataset}_{chart_type}_{image_id}"] = {
                    "baseComponent": "bubble",
                    "parameters": dict(plain_parameters),
                }
                per_image[f"importAnnot_{dataset}_{chart_type}_{image_id}"] = {
                    "baseComponent": "important_annot",
                    "parameters": dict(plain_parameters),
                }
                # plot_figure(chart_type, image_id, x_grids, y_grids, image_link)

    return group


def sequence_generator(group, example_components):
    """Assemble the study sequence definition.

    Args:
        group: Mapping of image id -> dict of that image's components.
        example_components: Mapping whose keys are the example component names.

    Returns:
        A dict describing a fixed-order sequence: the intro pages, the example
        components, the "main_start" page and finally an "images" block in
        random order. The "images" block holds one "latinSquare" sub-block per
        image (``numSamples`` 1) listing that image's component names.
    """
    per_image_blocks = []
    for image_key, image_components in group.items():
        per_image_blocks.append({
            "id": image_key,
            "order": "latinSquare",
            "numSamples": 1,
            "components": list(image_components),
        })

    steps = ["welcome", "consent", "example_start"]
    steps.extend(example_components)  # iterating a dict yields its keys
    steps.append("main_start")
    steps.append({
        "id": "images",
        "order": "random",
        "components": per_image_blocks,
    })

    return {"order": "fixed", "components": steps}

In [31]:
# Build the three component families.
default_components = create_default_components()
main_component_groups = create_initial_components()
example_components = create_example_components()

# Flatten the per-image groups into one name -> component mapping.
main_components = {}
for image_id, components in main_component_groups.items():
    main_components.update(components)

# Later mappings win on a name clash: defaults < main < examples.
components = {**default_components, **main_components, **example_components}
sequence = sequence_generator(main_component_groups, example_components)
print(f"Total number of components: {len(components)}")

10: 100%|██████████| 14/14 [00:00<00:00, 137.54it/s]

Total number of components: 64





In [32]:
# Replace the "components" and "sequence" sections of config.json in place,
# leaving every other key of the existing file untouched.
with open("config.json", "r") as f:
    config = json.load(f)
config['components'] = components
config['sequence'] = sequence
# Serialise before opening the file for writing: mode "w" truncates immediately,
# so a serialisation error raised while dumping straight into the file would
# leave config.json empty or half-written.
serialized_config = json.dumps(config, indent=4)
with open("config.json", "w") as f:
    f.write(serialized_config)