In [None]:
from _init import *

import os
# Restrict this process to the GPU with index 1. The variable is set here,
# before torch is imported in the following cell.
# NOTE(review): presumably the single visible GPU is then addressed as
# cuda:0 inside this process — confirm against the device used in the config.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [None]:
import random, torch

from ranger.utils import json_utils
from ranger.chain_generate.chain_generator import ChainGenerator

In [None]:
def set_seed(seed: int):
    """Seed torch (global and CUDA) and Python's `random`, then log the seed.

    Args:
        seed: Integer seed handed unchanged to every seeding function.
    """
    # Same order as before: torch global generator, torch CUDA, stdlib random.
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, random.seed):
        seed_fn(seed)

    print(f'set_seed() seed : {seed}')

# Notebook-wide seed: applied to the global RNGs here and passed to
# load_datas() further down for the dataset shuffles.
seed = 42
set_seed(seed)

In [None]:
# Tunable generation settings; each one is copied into vllm_config below.
model_name = 'meta-llama/Llama-3.2-3B-Instruct'
device = 0
dtype = 'float16'
max_seq_length = 4096
max_new_tokens = 128
temperature = 0.0
gpu_memory_utilization = 0.8

# First argument passed to ChainGenerator.
vllm_config = dict(
    model_name=model_name,
    device=f'cuda:{device}',
    dtype=dtype,
    max_seq_length=max_seq_length,
    max_new_tokens=max_new_tokens,
    temperature=temperature,
    gpu_memory_utilization=gpu_memory_utilization,
    n_log_prob=20,
)

# Second argument passed to ChainGenerator.
corag_config = dict(
    top_k_query=20,
    top_k_sub_query=5,
    task_desc='answer multi-hop questions',
)

In [None]:
# Build the generator once from the two config dicts defined above; the
# cells below reuse this instance for every generate()/reset() call.
chain_generator = ChainGenerator(vllm_config, corag_config)

In [None]:
def datas_shuffle(datas: list, seed: int):
    """Shuffle `datas` in place with a dedicated RNG seeded by `seed`.

    A fresh `random.Random` instance is used, so the module-level `random`
    state is left untouched. Nothing is returned; the list itself is reordered.

    Args:
        datas: List to reorder in place.
        seed: Seed for the private RNG.
    """
    random.Random(seed).shuffle(datas)


def load_datas(train_data_path: str, test_data_path: str, seed: int, do_print=False):
    """Load the train and test splits and shuffle each one deterministically.

    Args:
        train_data_path: Path handed to `json_utils.load_file` for the train split.
        test_data_path: Path handed to `json_utils.load_file` for the test split.
        seed: Seed forwarded to `datas_shuffle` for both splits.
        do_print: When true, print how many items each split contains.
            (Previously this flag was accepted but had no effect.)

    Returns:
        A `(train_datas, test_datas)` tuple of the loaded, shuffled splits.
    """
    train_datas = json_utils.load_file(train_data_path)
    test_datas = json_utils.load_file(test_data_path)

    # Both splits use the same seed; datas_shuffle reorders each list in place.
    datas_shuffle(train_datas, seed)
    datas_shuffle(test_datas, seed)

    if do_print:
        print(f'load_datas() train_datas : {len(train_datas)}, test_datas : {len(test_datas)}')

    return train_datas, test_datas

In [None]:
# Project root. Defaults to the original development checkout; set the
# RANGER_WORK_DIR environment variable to run the notebook from another
# location without editing this cell. An empty value falls back to the default.
work_dir = os.environ.get('RANGER_WORK_DIR') or '/home/nlpshlee/dev_env/git/repos/ranger'
data_dir = f'{work_dir}/data'
out_dir = f'{work_dir}/output'

train_data_path = f'{data_dir}/custom_musique_train_5000_final.jsonl'
test_data_path = f'{data_dir}/custom_multihopqa_eval_1000.jsonl'
# Load both splits and shuffle each with the notebook-wide seed.
train_datas, test_datas = load_datas(train_data_path, test_data_path, seed, do_print=False)

In [None]:
# Run generate() over several workloads, resetting the generator after each.
# Each row: (number of leading train items, batch_size, n_chains, chain_depth)
run_settings = [
    (5, 2, 2, 3),
    (10, 2, 2, 3),
    (100, 16, 2, 3),
    (100, 32, 2, 3),
    (100, 16, 4, 3),
    (100, 16, 2, 6),
    (100, 2, 2, 3),
]

for n_items, *generate_args in run_settings:
    # `results` is rebound on every pass, so only the final run's output is kept.
    results = chain_generator.generate(train_datas[:n_items], *generate_args)
    chain_generator.reset()