In [1]:
import json

def load_json(file_path):
    """
    Load a JSON file and return its content as a Python dictionary.

    Parameters:
        file_path (str): The path to the JSON file.

    Returns:
        dict: The content of the JSON file as a dictionary.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)
    return data
def save_json(data, file_path):
    """
    Save a Python dictionary to a JSON file.

    Parameters:
        data (dict): The data to save.
        file_path (str): The path where the JSON file will be saved.
    """
    with open(file_path, 'w', encoding='utf-8') as file:
        json.dump(data, file, ensure_ascii=False, indent=4)

In [2]:
import os
import glob

json_root_dir = "/data_ssd/LLaVA-OneVision-Data"
path_list = glob.glob(os.path.join(json_root_dir, "*", "*_checked_image_tag.json"), recursive=True)
print(len(path_list))

91


In [3]:
dataset_name_list = [os.path.basename(path).split("_checked_image_tag")[0] for path in path_list]

In [4]:
path_list.append(("/data_ssd/M4-Instruct-Data/m4_instruct_annotations_fixed.json"))
dataset_name_list.append("M4-Instruct")

In [5]:
print(dataset_name_list)
from tqdm import tqdm
dataset_dict = {}
total_data_num = 0
for dataset_name, path in tqdm(zip(dataset_name_list, path_list)):
    data = load_json(path)
    dataset_dict[dataset_name] = {"data":data, "data_num": len(data)}
    total_data_num += len(data)
    if len(data) == 0:
        print(f"Warning: {dataset_name} has no data, please check the path: {path}")

for dataset_name, dataset_info in dataset_dict.items():
    dataset_dict[dataset_name]["weight"] = dataset_info["data_num"] / total_data_num

['CLEVR-Math(MathV360K)', 'FigureQA(MathV360K)', 'GEOS(MathV360K)', 'GeoQA+(MathV360K)', 'Geometry3K(MathV360K)', 'IconQA(MathV360K)', 'MapQA(MathV360K)', 'PMC-VQA(MathV360K)', 'Super-CLEVR(MathV360K)', 'TabMWP(MathV360K)', 'UniGeo(MathV360K)', 'VisualWebInstruct(filtered)', 'VizWiz(MathV360K)', 'ai2d(cauldron,llava_format)', 'ai2d(gpt4v)', 'ai2d(internvl)', 'allava_instruct_laion4v', 'allava_instruct_vflan4v', 'aokvqa(cauldron,llava_format)', 'chart2text(cauldron)', 'chartqa(cauldron,llava_format)', 'chrome_writting', 'clevr(cauldron,llava_format)', 'diagram_image_to_text(cauldron)', 'dvqa(cauldron,llava_format)', 'figureqa(cauldron,llava_format)', 'geo170k(align)', 'geo170k(qa)', 'geo3k', 'geomverse(cauldron)', 'hateful_memes(cauldron,llava_format)', 'hitab(cauldron,llava_format)', 'hme100k', 'iam(cauldron)', 'iconqa(cauldron,llava_format)', 'iiit5k', 'image_textualization(filtered)', 'infographic(gpt4v)', 'infographic_vqa', 'infographic_vqa_llava_format', 'intergps(cauldron,llava_fo

0it [00:00, ?it/s]

92it [01:50,  1.20s/it]


In [6]:
for dataset_name, dataset_info in dataset_dict.items():
    print(f"{dataset_name}: {dataset_info['data_num']} samples, weight: {dataset_info['weight']:.4f}")

CLEVR-Math(MathV360K): 5280 samples, weight: 0.0011
FigureQA(MathV360K): 17587 samples, weight: 0.0036
GEOS(MathV360K): 498 samples, weight: 0.0001
GeoQA+(MathV360K): 17162 samples, weight: 0.0035
Geometry3K(MathV360K): 9724 samples, weight: 0.0020
IconQA(MathV360K): 22589 samples, weight: 0.0047
MapQA(MathV360K): 5225 samples, weight: 0.0011
PMC-VQA(MathV360K): 35948 samples, weight: 0.0074
Super-CLEVR(MathV360K): 8642 samples, weight: 0.0018
TabMWP(MathV360K): 22452 samples, weight: 0.0046
UniGeo(MathV360K): 11949 samples, weight: 0.0025
VisualWebInstruct(filtered): 263584 samples, weight: 0.0543
VizWiz(MathV360K): 6604 samples, weight: 0.0014
ai2d(cauldron,llava_format): 2429 samples, weight: 0.0005
ai2d(gpt4v): 4864 samples, weight: 0.0010
ai2d(internvl): 12403 samples, weight: 0.0026
allava_instruct_laion4v: 49990 samples, weight: 0.0103
allava_instruct_vflan4v: 19990 samples, weight: 0.0041
aokvqa(cauldron,llava_format): 16534 samples, weight: 0.0034
chart2text(cauldron): 26956 s

In [7]:
print(total_data_num)
print(f"min data num: {min([dataset_info['data_num'] for dataset_info in dataset_dict.values()])}")

4850679
min data num: 295


# 最低保証サンプル

In [8]:
import random
random.seed(42)
from collections import Counter

sample_data_num = 20000 #50000 #20000
dataset_name_list = []
weights = []
import numpy as np
sample_num_per_dataset = []
for dataset_name, dataset_info in dataset_dict.items():
    dataset_name_list.append(dataset_name)
    sample_num_per_dataset.append(round(dataset_info["weight"] * sample_data_num) if dataset_info["weight"] > 0 else 1 )
    
min_sample_num = 10 #min(sample_num_per_dataset)

for dataset_info in dataset_dict.values():
    #weights.append((dataset_info["data_num"]-min_sample_num) / (total_data_num - len(dataset_dict) * min_sample_num))
    weights.append(round(dataset_info["weight"] * sample_data_num) - min_sample_num)

In [9]:
print(np.sum(sample_num_per_dataset))
print(f"Minimum sample number per dataset: {min_sample_num}")

19997
Minimum sample number per dataset: 10


In [10]:
sample_num_per_dataset = {k:min_sample_num for k in dataset_dict.keys()}
print(f"Sampled {sum(sample_num_per_dataset.values())} samples from the datasets.")

Sampled 920 samples from the datasets.


In [11]:
sample_dataset_iter = random.choices(dataset_name_list, weights=weights, k=(sample_data_num - len(dataset_name_list) * min_sample_num))
sample_num_counter = Counter(sample_dataset_iter)
for k, v in sample_num_counter.items():
    sample_num_per_dataset[k] += v

In [12]:
for k, v in sample_num_per_dataset.items():
    print(f"{k}: {v} samples, target: {round(dataset_dict[k]['weight'] * sample_data_num)}")
    
print(f"Total sampled data number: {sum(sample_num_per_dataset.values())}")

CLEVR-Math(MathV360K): 29 samples, target: 22
FigureQA(MathV360K): 58 samples, target: 73
GEOS(MathV360K): 10 samples, target: 2
GeoQA+(MathV360K): 66 samples, target: 71
Geometry3K(MathV360K): 37 samples, target: 40
IconQA(MathV360K): 89 samples, target: 93
MapQA(MathV360K): 18 samples, target: 22
PMC-VQA(MathV360K): 131 samples, target: 148
Super-CLEVR(MathV360K): 37 samples, target: 36
TabMWP(MathV360K): 86 samples, target: 93
UniGeo(MathV360K): 49 samples, target: 49
VisualWebInstruct(filtered): 1069 samples, target: 1087
VizWiz(MathV360K): 32 samples, target: 27
ai2d(cauldron,llava_format): 10 samples, target: 10
ai2d(gpt4v): 21 samples, target: 20
ai2d(internvl): 39 samples, target: 51
allava_instruct_laion4v: 231 samples, target: 206
allava_instruct_vflan4v: 102 samples, target: 82
aokvqa(cauldron,llava_format): 61 samples, target: 68
chart2text(cauldron): 116 samples, target: 111
chartqa(cauldron,llava_format): 82 samples, target: 75
chrome_writting: 36 samples, target: 36
clev

# 実際にサンプル

In [13]:
save_json_data = []

for dataset_name, sample_num in sample_num_per_dataset.items():
    data = dataset_dict[dataset_name]["data"]
    if sample_num > len(data):
        print(f"Warning: {dataset_name} has only {len(data)} samples, but requested {sample_num} samples.")
        sample_num = len(data)
    sampled_data = random.sample(data, sample_num)
    save_json_data.extend(sampled_data)
    
print(f"Total sampled data number: {len(save_json_data)}")

Total sampled data number: 20000


# image数チェック

In [14]:
image_folder_root = "/data_ssd/llava-onevision-data-symbolic-link"
for item in tqdm(save_json_data):
    if "image" in item:
        image_list = item["image"] if isinstance(item["image"], list) else [item["image"]]
        iamge_list = [os.path.join(image_folder_root, img) for img in image_list]
        
        for image_path in iamge_list:
            if not os.path.exists(image_path):
                print(f"Warning: Image path {image_path} does not exist, removing item from data.")
            
        image_count = 0
        for conversation in item["conversations"]:
            image_count += conversation["value"].count("<image>")
            
        if image_count != len(image_list):
            print(image_list[0])
            break
            
            # print(f"{item["image"]} has more than one <image> tag {image_count}, removing item from data.")
        
        

100%|██████████| 20000/20000 [00:07<00:00, 2747.68it/s] 


In [16]:
print(image_count)

8


In [17]:
for k, v in item.items():
    print(f"{k}: {v}")

datasource: nextqa
id: 3104055504
image: ['nextqa/0.0.0/c55fcb268ead378049e4743c77ca2db3142e12a0f7dfc42eb8267e08efa85f58/train_images/3104055504_0.jpg', 'nextqa/0.0.0/c55fcb268ead378049e4743c77ca2db3142e12a0f7dfc42eb8267e08efa85f58/train_images/3104055504_1.jpg', 'nextqa/0.0.0/c55fcb268ead378049e4743c77ca2db3142e12a0f7dfc42eb8267e08efa85f58/train_images/3104055504_2.jpg', 'nextqa/0.0.0/c55fcb268ead378049e4743c77ca2db3142e12a0f7dfc42eb8267e08efa85f58/train_images/3104055504_3.jpg', 'nextqa/0.0.0/c55fcb268ead378049e4743c77ca2db3142e12a0f7dfc42eb8267e08efa85f58/train_images/3104055504_4.jpg', 'nextqa/0.0.0/c55fcb268ead378049e4743c77ca2db3142e12a0f7dfc42eb8267e08efa85f58/train_images/3104055504_5.jpg', 'nextqa/0.0.0/c55fcb268ead378049e4743c77ca2db3142e12a0f7dfc42eb8267e08efa85f58/train_images/3104055504_6.jpg', 'nextqa/0.0.0/c55fcb268ead378049e4743c77ca2db3142e12a0f7dfc42eb8267e08efa85f58/train_images/3104055504_7.jpg']
conversations: [{'from': 'human', 'value': '<image><image><image><imag

In [18]:
save_json_path = os.path.join("/data_ssd/LLaVA-OneVision-Data-M4-Instrct-Json", f"llava-onevision-m4-instruct_{sample_data_num}.json")

In [20]:
save_json(save_json_data, save_json_path)

In [21]:
loaded_data = load_json(save_json_path)
print(f"Loaded {len(loaded_data)} samples from {save_json_path}.")

Loaded 20000 samples from /data_ssd/LLaVA-OneVision-Data-M4-Instrct-Json/llava-onevision-m4-instruct_20000.json.


In [22]:
print(loaded_data[0])  # Print the first item to verify the content

{'id': 'identity_168600', 'conversations': [{'from': 'human', 'value': '<image>\nHint: Please answer the question and provide the final answer at the end.\nQuestion: How many objects are there in total?'}, {'from': 'gpt', 'value': 'The answer is 9'}], 'data_source': 'CLEVR-Math(MathV360K)', 'image': 'LLaVA-OneVision-Data/CLEVR-Math(MathV360K)/train/identity_168600.png'}
