In [1]:
from mhr.utils.utils import process_jsonl, load_json_file, write_json_file
import random
import math
from tqdm import tqdm
import argparse

In [2]:
# Per-key thresholds; build_prob_dict uses them as the numerator of
# threshold / len(ids) when turning reverse-index counts into keep-probabilities.
threshold_dict = dict(
    object=304,
    token=120,
    co_occurrence=24,
    what_word=4895,
)

# Reverse-index JSONL file for each statistic key (same keys, same order as above).
reverse_index_file_dict = dict(
    object='/mnt/petrelfs/songmingyang/songmingyang/data/llava_train/LLaVA-Instruct-150K/reverse_index/llava_v1_5_mix665k_dino_stat_reverse_index.jsonl',
    token='/mnt/petrelfs/songmingyang/songmingyang/data/llava_train/LLaVA-Instruct-150K/reverse_index/llava_v1_5_mix665k_token_reverse_index.jsonl',
    co_occurrence='/mnt/petrelfs/songmingyang/songmingyang/data/llava_train/LLaVA-Instruct-150K/reverse_index/llava_v1_5_mix665k_co_occurrence_reverse_index.jsonl',
    what_word='/mnt/petrelfs/songmingyang/songmingyang/data/llava_train/LLaVA-Instruct-150K/reverse_index/llava_v1_5_mix665k_what_word_reverse_index.jsonl',
)

def build_prob_dict(file_dict, threshold_dict, keys=("token", "what_word")):
    """
    Build per-entry keep-probabilities from reverse-index JSONL files.

    For each ``key`` in ``keys``, the file ``file_dict[key]`` is read with
    ``process_jsonl``. Every record contributes
    ``threshold_dict[key] / len(record['ids'])`` under ``record['object']``.
    Entries with fewer ids than the threshold therefore get a value above 1,
    so a ``random.random() < p`` test always passes for them. More frequent
    entries are down-weighted in inverse proportion to their id count.
    Values are intentionally not clipped to 1.

    Args:
        file_dict: mapping key -> path of a JSONL reverse-index file whose
            records carry ``'object'`` and ``'ids'`` fields.
        threshold_dict: mapping key -> numeric threshold.
        keys: which keys to load. Defaults to ``("token", "what_word")``,
            the set this function previously hard-coded. Pass e.g.
            ``("object", "co_occurrence")`` to build the other tables.

    Returns:
        dict mapping each loaded key to a ``{object: probability}`` dict.
    """
    entry_prob = {}
    for key in keys:
        threshold = threshold_dict[key]
        probs = {}
        for item in process_jsonl(file_dict[key]):
            length = len(item['ids'])
            if length == 0:
                # Would divide by zero. Leaving the entry out makes lookups
                # fall back to the caller's .get() default instead of crashing.
                continue
            probs[item['object']] = threshold / length
        entry_prob[key] = probs
    return entry_prob

def sample_data_compose_alpha(threshold_dict, reverse_index_file_dict, input_dataset_file_path, output_file_path, compose_list, alpha, pass_num=0, seed=None):
    """
    Probabilistically select items from a dataset using reverse-index probabilities.

    For each item and each key in ``compose_list``, the item's entries in
    ``item['statistics'][key]`` are tried in order. The key "passes" as soon as
    one entry wins a trial with probability ``entry_prob[key][entry]``; entries
    missing from the probability table never pass. An item is kept when more
    than ``pass_num`` keys pass AND it additionally wins a trial with
    probability ``alpha``.

    Args:
        threshold_dict: key -> threshold, forwarded to ``build_prob_dict``.
        reverse_index_file_dict: key -> reverse-index JSONL path, forwarded to
            ``build_prob_dict``.
        input_dataset_file_path: file read with ``load_json_file``. Presumably
            a list of dicts, each with a ``'statistics'`` mapping from key to an
            iterable of entries (verify against the data).
        output_file_path: accepted for interface compatibility but NOT used.
            The caller is responsible for writing the returned items.
        compose_list: list or tuple of keys to test. Each key must have a
            probability table in the ``build_prob_dict`` result.
        alpha: extra keep-probability applied only to items that pass the key tests.
        pass_num: an item must pass strictly more than this many keys.
        seed: optional seed for a private ``random.Random``. When None, the
            global ``random`` module is used, so ``random.seed`` still applies.

    Returns:
        list of the selected item objects (not copies), in input order.

    Raises:
        TypeError: if ``compose_list`` is not a list or tuple.
        ValueError: if a key in ``compose_list`` has no probability table.
    """
    # Validate cheap arguments before the expensive file loads.
    if not isinstance(compose_list, (list, tuple)):
        raise TypeError(
            f"compose_list must be a list or tuple of keys, got {type(compose_list).__name__}"
        )
    rng = random if seed is None else random.Random(seed)

    entry_prob = build_prob_dict(reverse_index_file_dict, threshold_dict)
    missing = [key for key in compose_list if key not in entry_prob]
    if missing:
        raise ValueError(
            f"no probability table for key(s) {missing}; available: {sorted(entry_prob)}"
        )
    origin_data = load_json_file(input_dataset_file_path)

    selected = []
    for item in tqdm(origin_data):
        statistics = item['statistics']
        # A key passes once any of its entries wins its trial. any() stops at
        # the first win, so later entries draw no random numbers.
        pass_cnt = sum(
            any(rng.random() < entry_prob[key].get(obj, 0) for obj in statistics[key])
            for key in compose_list
        )
        # The alpha trial is drawn only for items that already passed
        # (short-circuit `and`).
        if pass_cnt > pass_num and rng.random() < alpha:
            selected.append(item)
    print(f"length of D_star: {len(selected)}")
    return selected
    

In [4]:
# Seed the global RNG so the sampled subset is reproducible under Restart & Run All.
SEED = 42
random.seed(SEED)

compose_list = ["token", "what_word"]
input_data_file = "/mnt/petrelfs/songmingyang/songmingyang/data/llava_train/LLaVA-Instruct-150K/reformed_data/origin_data/llava_v1_5_mix665k.json"
output_data_file = "/mnt/petrelfs/songmingyang/songmingyang/data/llava_train/LLaVA-Instruct-150K/reformed_data/dr_algo/aug/token_aug/lm_rewrite_tail/llava_ft_to_aug.json"
# alpha: extra keep-probability for items that pass the per-key tests.
# pass_num=0: an item must pass at least one of the keys in compose_list.
D_star = sample_data_compose_alpha(
    threshold_dict,
    reverse_index_file_dict,
    input_data_file,
    output_data_file,
    compose_list,
    alpha=0.4,
    pass_num=0,
)

100%|██████████| 665298/665298 [00:01<00:00, 378367.63it/s]

length of D_star: 209028





In [5]:
# Persist the sampled subset to the output path configured in the sampling cell.
write_json_file(D_star, output_data_file)