In [None]:
from _init import *

import os
# Expose only the GPU with index 1 to CUDA-aware libraries used by this kernel.
# NOTE(review): this runs after `from _init import *`; presumably nothing imported
# there has already initialized CUDA, otherwise the setting may come too late — confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

In [None]:
import numpy as np
import random, json
from typing import List
import matplotlib.pyplot as plt

from ranger.utils import common_utils, json_utils, tokenizer_utils, file_utils, container_utils
from ranger.corag.corag_result import ChainResult, QueryResult
from ranger.corag import corag_prompts

In [None]:
# Fix the seed through the project helper so repeated runs are comparable
# (what exactly gets seeded is defined in ranger.utils.common_utils).
seed = 42
common_utils.set_seed(seed)

# Repository root. Defaults to the original author's checkout; set the
# RANGER_WORK_DIR environment variable to run the notebook from another location.
work_dir = os.environ.get('RANGER_WORK_DIR', '/home/nlpshlee/dev_env/git/repos/ranger')
data_dir = f'{work_dir}/data'
# Root of the SFT data; the chain-result directories analysed below live here.
sft_dir = f'{data_dir}/sft'
# Destination of select_candidate() and input of make_train().
selected_dir = f'{sft_dir}/selected'
# Destination of make_train() (per-query JSON files and the merged JSONL file).
train_dir = f'{sft_dir}/selected_train'

In [None]:
# def plot_graph(data, xlabel, ylabel, title):
#     sorted_items = sorted(data.items(), key=lambda x: int(x[0]))

#     keys = [item[0] for item in sorted_items]
#     values = [item[1] for item in sorted_items]

#     plt.figure(figsize=(8, 6))
#     bars = plt.bar(keys, values, color='skyblue')

#     for bar in bars:
#         yval = bar.get_height()
#         # x: bar center, y: bar height + 5 (padding), s: value text
#         plt.text(bar.get_x() + bar.get_width()/2, yval + 5, int(yval), ha='center', va='bottom')

#     plt.xlabel(xlabel)
#     plt.ylabel(ylabel)
#     plt.title(title)
#     # plt.savefig('dict_bar_chart_with_values.png')


def plot_graph(data, xlabel, ylabel, title):
    """Draw a bar chart of ``data`` and annotate every bar with its share.

    Bars are ordered by their key interpreted as a number. Above each bar the
    text ``value(percent%) / cumulative(cumulative percent%)`` is drawn, where
    the cumulative figures accumulate in that key order.

    Args:
        data: Mapping from a numeric-looking key (e.g. ``'3'`` or ``'0.5'``)
            to a count.
        xlabel: Label of the x axis.
        ylabel: Label of the y axis.
        title: Title of the figure.
    """
    # Order by numeric key; float() handles integer-like and decimal keys alike.
    sorted_items = sorted(data.items(), key=lambda item: float(item[0]))
    keys = [key for key, _ in sorted_items]
    values = [value for _, value in sorted_items]

    # Denominator for the percentages.
    total_sum = sum(values)
    # Gap between bar top and label, scaled to the tallest bar. Computed once;
    # default=0 keeps an empty mapping from raising.
    label_offset = max(values, default=0) * 0.01

    # Wide figure because the per-bar labels are long.
    fig, ax = plt.subplots(figsize=(12, 6))
    bars = ax.bar(keys, values, color='skyblue')

    cum_value = 0
    for bar, value in zip(bars, values):
        cum_value += value

        # Guard against a zero total (all counts are 0) instead of dividing by zero.
        percent = (value / total_sum) * 100 if total_sum else 0.0
        cum_percent = (cum_value / total_sum) * 100 if total_sum else 0.0

        # Example: 232(13.10%) / 1594(90.00%)
        label_text = f"{int(value)}({percent:.2f}%) / {int(cum_value)}({cum_percent:.2f}%)"

        ax.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height() + label_offset,
            label_text,
            ha='center',
            va='bottom',
            fontsize=9,  # slightly smaller font because the label is long
        )

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)

    # Adjust margins automatically so the labels are not clipped.
    fig.tight_layout()
    plt.show()
    # To save the chart to a file: fig.savefig('custom_bar_chart.png')

In [None]:
def check_is_stop(in_dir):
    """Print and plot chain statistics for the result files found under ``in_dir``.

    Every file is read as JSON and its ``query_result.chain_results`` list is
    inspected. Chain depth is the length of a chain's ``final_answers`` list.

    * If at least one chain has a truthy ``is_stop``, the query goes into the
      "is_stop" statistics: depth of every stopped chain, smallest stopped
      depth of the query, and number of stopped chains of the query.
    * Otherwise, the chains with ``f1 > 0`` go into the "f1" statistics:
      f1 formatted to one decimal, and chain depth.

    The totals are printed and four bar charts are drawn with ``plot_graph``.

    Args:
        in_dir: Directory handed to ``file_utils.get_file_paths`` to list the files.
    """
    in_file_paths = file_utils.get_file_paths(in_dir)
    print(f'{in_dir} size : {len(in_file_paths)}')

    # Histograms keyed by the stringified bucket value.
    depth_hist = {}
    f1_depth_hist = {}
    stop_count_hist = {}
    min_depth_hist = {}
    f1_hist = {}

    # "*_queries" count files, "*_chains" count individual chain results.
    stop_queries = stop_chains = 0
    f1_queries = f1_chains = 0

    for in_file_path in in_file_paths:
        with open(in_file_path, 'r', encoding='utf-8') as in_file:
            chains = json.load(in_file)['query_result']['chain_results']

        n_stopped = 0
        shallowest = 10000  # sentinel; assumed to exceed any real chain depth
        for chain in chains:
            if not bool(chain['is_stop']):
                continue

            n_stopped += 1
            depth = len(chain['final_answers'])
            container_utils.add_str_int(depth_hist, str(depth), 1)
            shallowest = min(shallowest, depth)

        if n_stopped > 0:
            stop_queries += 1
            stop_chains += n_stopped
            container_utils.add_str_int(stop_count_hist, str(n_stopped), 1)
            container_utils.add_str_int(min_depth_hist, str(shallowest), 1)
            continue

        # No stopped chain for this query: look at chains with a positive f1 instead.
        n_positive = 0
        for chain in chains:
            f1 = float(chain['f1'])
            if f1 > 0.0:
                n_positive += 1
                container_utils.add_str_int(f1_hist, f'{f1:.1f}', 1)
                container_utils.add_str_int(f1_depth_hist, str(len(chain['final_answers'])), 1)

        if n_positive > 0:
            f1_queries += 1
            f1_chains += n_positive

    print(f'all_cnt_is_stop : {stop_queries}')
    print(f'sum_cnt_is_stop : {stop_chains}')
    print(f'chain_depth_cnt_dict : {sum(depth_hist.values())} : {container_utils.sorted_dict_key(depth_hist)}')
    print(f'min_chain_depth_cnt_dict : {sum(min_depth_hist.values())} : {container_utils.sorted_dict_key(min_depth_hist)}')
    print(f'chain_num_cnt_dict : {sum(stop_count_hist.values())} : {container_utils.sorted_dict_key(stop_count_hist)}\n')

    print(f'all_cnt_f1 : {f1_queries}')
    print(f'sum_cnt_f1 : {f1_chains}')
    print(f'f1_cnt_dict : {sum(f1_hist.values())} : {container_utils.sorted_dict_key(f1_hist)}')
    print(f'f1_chain_depth_cnt_dict : {sum(f1_depth_hist.values())} : {container_utils.sorted_dict_key(f1_depth_hist)}\n')

    print(f'all_cnt : {stop_queries + f1_queries}')
    print(f'sum_cnt : {stop_chains + f1_chains}\n')

    # One chart per histogram; f1_depth_hist is printed above but not plotted.
    charts = (
        (depth_hist, 'chain depth'),
        (min_depth_hist, 'min chain depth per query'),
        (stop_count_hist, 'number of chain(is_stop) per query'),
        (f1_hist, 'f1-score'),
    )
    for histogram, x_label in charts:
        plot_graph(histogram, x_label, 'count', '')

In [None]:
# in_dir = f'{sft_dir}/train_5000_n_chains-5_chain_depth-5'
# check_is_stop(in_dir)

In [None]:
# in_dir = f'{sft_dir}/train_5000_n_chains-32_chain_depth-10'
# check_is_stop(in_dir)

In [None]:
def select_candidate(in_dir, max_depth, max_chains, out_dir):
    """Select up to ``max_chains`` distinct stopped chains per query and save them.

    For every file listed by ``file_utils.get_file_paths(in_dir)``:

    1. Keep the chains whose ``_is_stop`` is truthy and whose depth (number of
       final answers) is at most ``max_depth``.
    2. Visit them in the depth order returned by
       ``container_utils.sorted_dict_key`` (presumably shallowest first —
       confirm against the helper), skipping any chain whose sequence of
       sub-queries has already been selected.
    3. Stop as soon as ``max_chains`` chains have been selected.

    A file with at least one selected chain is written to ``out_dir`` under
    the same file name, with ``query_result.chain_results`` replaced by the
    selection. Summary counters are printed at the end.

    Args:
        in_dir: Directory containing the per-query result JSON files.
        max_depth: Largest chain depth that is still eligible.
        max_chains: Number of chains at which selection for a query stops.
        out_dir: Directory receiving the filtered JSON files.
    """
    in_file_paths = file_utils.get_file_paths(in_dir)
    print(f'{in_dir} size : {len(in_file_paths)}')

    # "queries_*" count files, "chains_*" count individual chains.
    queries_with_stop = chains_with_stop = 0
    queries_selected = chains_selected = 0

    for in_file_path in in_file_paths:
        with open(in_file_path, 'r', encoding='utf-8') as in_file:
            data = json.load(in_file)

        # Bucket the eligible chains by depth, preserving file order inside a bucket.
        by_depth = {}
        for raw_chain in data['query_result']['chain_results']:
            chain: ChainResult = ChainResult.from_dict(raw_chain)
            depth = len(chain._final_answers)

            if chain._is_stop and depth <= max_depth:
                by_depth.setdefault(depth, []).append(chain)

        n_eligible = sum(len(bucket) for bucket in by_depth.values())
        if n_eligible > 0:
            queries_with_stop += 1
            chains_with_stop += n_eligible

        # All eligible chains as one stream, bucket after bucket in sorted-key order.
        ordered_chains = (
            chain
            for depth in container_utils.sorted_dict_key(by_depth).keys()
            for chain in by_depth[depth]
        )

        seen_signatures = set()
        picked = []
        for chain in ordered_chains:
            # Chains with the same sub-query sequence count as duplicates.
            signature = '#%#'.join(chain._sub_querys)
            if signature in seen_signatures:
                continue

            picked.append(chain)
            seen_signatures.add(signature)

            if len(picked) == max_chains:
                break

        if not picked:
            continue

        queries_selected += 1
        chains_selected += len(picked)

        data['query_result']['chain_results'] = [chain.to_dict() for chain in picked]

        out_file_path = f'{out_dir}/{file_utils.get_file_name(in_file_path)}'
        file_utils.make_parent(out_file_path)

        with open(out_file_path, 'w', encoding='utf-8') as out_file:
            json.dump(data, out_file, ensure_ascii=False, indent=4)

    print(f'all_cnt_is_stop : {queries_with_stop}')
    print(f'sum_cnt_is_stop : {chains_with_stop}')
    print(f'all_cnt_selected : {queries_selected}')
    print(f'sum_cnt_selected : {chains_selected}\n')

In [None]:
# target_dir = 'train_5000_n_chains-32_chain_depth-10'
# max_depth, max_chains = 5, 5

# in_dir = f'{sft_dir}/{target_dir}'
# out_dir = f'{selected_dir}/{target_dir}_max_depth-{max_depth}_max_chains-{max_chains}'
# select_candidate(in_dir, max_depth, max_chains, out_dir)
# check_is_stop(out_dir)

In [None]:
def add_train_data(train_datas: list, query_id, query, prompt, generated_text, is_final_answer, is_last):
    """Append one training record to ``train_datas``.

    The record has the keys ``query_id``, ``query``, ``source`` (the prompt)
    and ``target`` (the generated text).

    For a final-answer step (``is_final_answer`` truthy) the prompt content is
    passed through ``corag_prompts.add_decide_stop_or_continue_prompt`` and
    the target is prefixed with ``<STOP>`` when ``is_last`` is truthy, or with
    ``<CONTINUE>`` otherwise. The caller's ``prompt`` is never modified: the
    record stores a modified copy, so passing the same prompt twice does not
    apply the decide-stop prompt twice.

    Args:
        train_datas: List that receives the new record (mutated in place).
        query_id: Identifier of the query the record belongs to.
        query: Query text.
        prompt: Mapping with at least a ``'content'`` entry.
        generated_text: Text generated for ``prompt``.
        is_final_answer: Whether this step is a final-answer step.
        is_last: Whether this is the last step of its chain; only used when
            ``is_final_answer`` is truthy.
    """
    if is_final_answer:
        # Build a new mapping instead of assigning into the argument.
        prompt = {
            **prompt,
            'content': corag_prompts.add_decide_stop_or_continue_prompt(prompt['content']),
        }

        decision = '<STOP>' if is_last else '<CONTINUE>'
        generated_text = f'{decision} {generated_text}'

    train_datas.append({
        'query_id': query_id,
        'query': query,
        'source': prompt,
        'target': generated_text,
    })


def make_train(in_dir, out_dir, merge_file_path):
    """Turn chain-result files into SFT training records.

    For every file listed by ``file_utils.get_file_paths(in_dir)``, each chain
    contributes three records per depth step, in this order: sub-query,
    sub-answer, final answer. Records are built by ``add_train_data``; the
    final-answer record of the last step is flagged as last, all earlier
    final-answer records are not.

    The records of one query are written to ``{out_dir}/{query_id}.json`` as a
    JSON list, and also appended to ``merge_file_path`` as one JSON object per
    line. The total number of merged records is printed at the end.

    NOTE(review): the last step of every chain is flagged as last regardless
    of the chain's ``_is_stop``; presumably the input only holds stopped
    chains (as select_candidate produces) — confirm.

    Args:
        in_dir: Directory containing the chain-result JSON files.
        out_dir: Directory receiving one JSON file per query.
        merge_file_path: Path of the merged JSONL file (overwritten).
    """
    file_utils.make_parent(merge_file_path)
    with open(merge_file_path, 'w', encoding='utf-8') as merge_file:
        in_file_paths = file_utils.get_file_paths(in_dir)
        print(f'{in_dir} size : {len(in_file_paths)}')

        merged_cnt = 0
        for in_file_path in in_file_paths:
            with open(in_file_path, 'r', encoding='utf-8') as in_file:
                data = json.load(in_file)

            query_id, query = data['query_id'], data['query']
            records = []

            for raw_chain in data['query_result']['chain_results']:
                chain: ChainResult = ChainResult.from_dict(raw_chain)
                depth = len(chain._final_answers)

                # Per the original author's note: each per-depth prompt entry
                # is a list, but its size is always 1, hence the [0].
                for step in range(depth):
                    add_train_data(
                        records, query_id, query,
                        chain._sub_query_prompts[step][0], chain._sub_querys[step],
                        False, False,
                    )
                    add_train_data(
                        records, query_id, query,
                        chain._sub_answer_prompts[step][0], chain._sub_answers[step],
                        False, False,
                    )
                    # Only the final answer of the deepest step is marked as last.
                    add_train_data(
                        records, query_id, query,
                        chain._final_answer_prompts[step][0], chain._final_answers[step],
                        True, step == depth - 1,
                    )

            out_file_path = f'{out_dir}/{query_id}.json'
            file_utils.make_parent(out_file_path)

            with open(out_file_path, 'w', encoding='utf-8') as out_file:
                json.dump(records, out_file, ensure_ascii=False, indent=4)

            merge_file.writelines(
                json.dumps(record, ensure_ascii=False) + '\n' for record in records
            )
            merged_cnt += len(records)

    print(f'{merge_file_path} size : {merged_cnt}')

In [None]:
# Build the training data from one selected-candidate set.
# The directory name follows the '<source>_max_depth-<d>_max_chains-<c>' pattern
# that the (commented-out) select_candidate cell uses for its output directory.
target_dir = 'train_5000_n_chains-32_chain_depth-10_max_depth-5_max_chains-5'
# Input: the selected chain results.
in_dir = f'{selected_dir}/{target_dir}'
# Output: one JSON file of training records per query.
out_dir = f'{train_dir}/{target_dir}'
# Output: all training records in a single JSONL file.
merge_file_path = f'{train_dir}/{target_dir}_merged.jsonl'
make_train(in_dir, out_dir, merge_file_path)