In [None]:
from _init import *

import os
# Set CUDA_VISIBLE_DEVICES to '1' for this kernel process (by the usual CUDA
# convention this exposes only GPU index 1 to CUDA-based libraries).
# NOTE(review): this normally has to run before any library initializes CUDA;
# `from _init import *` executes first — confirm `_init` does not touch CUDA.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

In [None]:
import numpy as np
import random, json
from typing import List

from transformers import PreTrainedTokenizerFast
from ranger.utils import common_utils, json_utils, tokenizer_utils, file_utils, container_utils

from ranger.corag.corag_result import QueryResult, ChainResult
from ranger.chain_generate.chain_generate_client import request_chain_generate
from ranger.reward.reward_calculator import RewardCalculator

In [None]:
# Fixed seed for the run, handed to the project's `common_utils.set_seed` helper.
# NOTE(review): which RNGs `set_seed` covers (random / numpy / torch) is not
# visible in this notebook — confirm in ranger.utils.common_utils.
seed = 42
common_utils.set_seed(seed)


def datas_shuffle(datas: list, seed: int):
    """Shuffle ``datas`` in place, reproducibly for a given ``seed``.

    A dedicated ``random.Random`` generator is created for the call, so the
    module-level ``random`` state is neither read nor modified.

    Args:
        datas: List whose elements are reordered in place.
        seed: Seed for the dedicated generator; the same seed and the same
            input order always yield the same permutation.

    Returns:
        None. The caller's list object itself is modified.
    """
    random.Random(seed).shuffle(datas)


def load_datas(train_data_path: str, test_data_path: str, seed=-1):
    """Load the train and test datasets and optionally shuffle both.

    Each path is read with ``json_utils.load_file``. When ``seed`` is anything
    other than the sentinel ``-1``, each loaded dataset is shuffled in place
    with ``datas_shuffle`` using that same seed.

    Args:
        train_data_path: Path of the training data file.
        test_data_path: Path of the evaluation data file.
        seed: Shuffle seed; the default ``-1`` means "keep the loaded order".

    Returns:
        A ``(train_datas, test_datas)`` tuple.
    """
    # Load both files first (train, then test) before any shuffling happens.
    splits = [json_utils.load_file(path) for path in (train_data_path, test_data_path)]

    shuffle_requested = seed != -1
    if shuffle_requested:
        for split in splits:
            datas_shuffle(split, seed)

    train_datas, test_datas = splits
    return train_datas, test_datas


# Project root. Defaults to the original checkout location; set the
# RANGER_WORK_DIR environment variable to run this notebook from another
# machine or checkout without editing the cell. An unset or empty variable
# falls back to the default.
work_dir = os.environ.get('RANGER_WORK_DIR') or '/home/nlpshlee/dev_env/git/repos/ranger'
data_dir = f'{work_dir}/data'
out_dir = f'{work_dir}/data/sft'

train_data_path = f'{data_dir}/custom_musique_train_5000_final.jsonl'
test_data_path = f'{data_dir}/custom_multihopqa_eval_1000.jsonl'
# No seed is passed, so load_datas uses its default (-1) and both datasets
# keep the order in which they were loaded.
# NOTE(review): `seed` defined above is not forwarded here — confirm that
# unshuffled data is intended.
train_datas, test_datas = load_datas(train_data_path, test_data_path)

# VLLM_CONFIG / REWARD_CONFIG are not defined in this notebook; presumably they
# come from the `from _init import *` star import — verify.
tokenizer = tokenizer_utils.load_tokenizer(VLLM_CONFIG['model_name'])
# Only referenced by the commented-out calculate_max_new_tokens cell.
bos_token_id = tokenizer.bos_token_id

# Read as a global by chain_generate_for_sft().
reward_calculator = RewardCalculator(REWARD_CONFIG['reward_option'])

In [None]:
# def calculate_max_new_tokens(answer: str, tokenizer: PreTrainedTokenizerFast):
#     token_ids = tokenizer(answer, add_special_tokens=False)['input_ids']
#     # print(f'token_ids len : {len(token_ids)}')

#     start_idx = 1 if token_ids[0] == bos_token_id else 0
#     cur_len = -1

#     for i in range(start_idx, len(token_ids)):
#         decoded = tokenizer.decode(token_ids[start_idx:i+1])

#         if answer == decoded:
#             cur_len = (i+1)-start_idx
#             # print(f'{start_idx} - {i+1}')
#             # print(decoded)
#             # print(f'cur_len : {cur_len}')
#             break
    
#     return cur_len


# datas = train_datas + test_datas
# print(f'datas size : {len(datas)}')

# max_len = -1

# for data in datas:
#     answers = data['answers']

#     for answer in answers:
#         # print(f'answer : {answer}')
#         cur_len = calculate_max_new_tokens(answer, tokenizer)

#         if max_len < cur_len:
#             max_len = cur_len

# print(f'max_len : {max_len}')

In [None]:
# all_responses_dict = []

# for i, data in enumerate(train_datas):
#     print(f'[{i}] query : {data["query"]}')

#     responses: List[QueryResult] = request_chain_generate([data], 1, 5, 5)
#     responses_dict = [query_result.to_dict() for query_result in responses]
#     all_responses_dict.extend(responses_dict)

#     if i == 4:
#         break

# responses: List[QueryResult] = request_chain_generate(train_datas[:5], 5, 5, 5)
# responses_dict = [query_result.to_dict() for query_result in responses]
# all_responses_dict.extend(responses_dict)

# print(f'{json_utils.to_str(all_responses_dict)}')

In [None]:
def chain_generate_for_sft(datas, batch_size, n_chains, chain_depth, out_dir):
    """Generate chains for every record, score them, and write one JSON file per query.

    ``datas`` is split into chunks with ``container_utils.chunks(datas, batch_size)``.
    Each chunk is sent to ``request_chain_generate``; the returned results are
    passed to the module-level ``reward_calculator`` and then paired
    positionally with the records that were sent. For every pair whose query
    ids agree, ``query_result.to_dict()`` is stored under
    ``data['query_result']`` and the record is dumped to
    ``{out_dir}/{query_id}.json``.

    Problems are reported with a printed ``# [error]`` line and processing
    continues with the next item:
      * the record and the result at the same position carry different query ids;
      * the service returned fewer or more results than records were sent
        (a plain ``zip`` would silently drop the unmatched tail).

    Args:
        datas: Records to process; each must have a ``'query_id'`` key.
        batch_size: Chunk size, also forwarded to ``request_chain_generate``.
        n_chains: Forwarded to ``request_chain_generate``.
        chain_depth: Forwarded to ``request_chain_generate``.
        out_dir: Directory that receives the per-query ``.json`` files.

    Returns:
        None.

    Side effects:
        Adds/overwrites the ``'query_result'`` key on each successfully matched
        input record (the caller's dicts are modified in place), writes files
        under ``out_dir`` and prints error lines.
    """
    from itertools import zip_longest

    # Unique filler so a short side of the pairing can be told apart from real items.
    missing = object()

    for datas_batch in container_utils.chunks(datas, batch_size):
        query_results: List[QueryResult] = request_chain_generate(datas_batch, batch_size, n_chains, chain_depth)

        # Return value is not used; presumably the results are annotated in place
        # before being serialized below — verify against RewardCalculator.
        reward_calculator.calculate_reward_and_advantage(query_results)

        for data, query_result in zip_longest(datas_batch, query_results, fillvalue=missing):
            if query_result is missing:
                # More records were sent than results came back.
                lost_query_id = data['query_id']
                print(f'# [error] chain_generate_for_sft() query_result is missing : {lost_query_id}')
                continue

            if data is missing:
                # More results came back than records were sent.
                print(f'# [error] chain_generate_for_sft() data is missing : {query_result._query_id}')
                continue

            query_id = data['query_id']

            if query_id == query_result._query_id:
                data['query_result'] = query_result.to_dict()

                out_file_path = f'{out_dir}/{query_id}.json'
                file_utils.make_parent(out_file_path)

                with open(out_file_path, 'w', encoding='utf-8') as out_file:
                    json.dump(data, out_file, ensure_ascii=False, indent=4)
            else:
                print(f'# [error] chain_generate_for_sft() query_id is diff : {query_id} - {query_result._query_id}')


def merge_to_jsonl(in_dir, out_file_path):
    """Concatenate the JSON files found under ``in_dir`` into one JSONL file.

    The file list comes from ``file_utils.get_file_paths(in_dir)`` and is
    processed in the order that helper returns it. Every file is parsed as a
    single JSON document and re-serialized on one line (non-ASCII characters
    kept as-is), followed by a newline.

    Args:
        in_dir: Directory holding the per-query JSON files.
        out_file_path: Destination JSONL path; overwritten if it exists.

    Returns:
        None.
    """
    source_paths = file_utils.get_file_paths(in_dir)
    print(f'{in_dir} : {len(source_paths)} files merge')

    file_utils.make_parent(out_file_path)
    with open(out_file_path, 'w', encoding='utf-8') as merged_file:
        for source_path in source_paths:
            with open(source_path, 'r', encoding='utf-8') as source_file:
                record = json.load(source_file)

            # One compact JSON document per output line.
            merged_file.write(json.dumps(record, ensure_ascii=False) + '\n')

In [None]:
# Generation settings used by the train and test generation cells; both values
# are forwarded unchanged to chain_generate_for_sft() and appear in the
# output directory names.
n_chains = 32
chain_depth = 10

In [None]:
# Train split: per-query JSON files are written under `out_path`, then merged
# into the sibling file `<out_path>_merged.jsonl`.
out_path = f'{out_dir}/train_5000_n_chains-{n_chains}_chain_depth-{chain_depth}'

# 100 is the batch_size argument. Nothing here skips already-written files, so
# re-running this cell regenerates every query.
chain_generate_for_sft(train_datas, 100, n_chains, chain_depth, out_path)
merge_to_jsonl(out_path, f'{out_path}_merged.jsonl')

In [None]:
# Test split: per-query JSON files are written under `out_path`, then merged
# into the sibling file `<out_path>_merged.jsonl`. Note that `out_path` is
# reassigned here, replacing the train-split value from the previous cell.
out_path = f'{out_dir}/test_1000_n_chains-{n_chains}_chain_depth-{chain_depth}'

# 100 is the batch_size argument. Nothing here skips already-written files, so
# re-running this cell regenerates every query.
chain_generate_for_sft(test_datas, 100, n_chains, chain_depth, out_path)
merge_to_jsonl(out_path, f'{out_path}_merged.jsonl')