-
Notifications
You must be signed in to change notification settings - Fork 84
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
4640ee3
commit 1f97c7e
Showing
13 changed files
with
883 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# Our code is modified from https://github.com/maitrix-org/llm-reasoners/examples/rap_gsm8k | ||
|
||
An example script: | ||
|
||
```bash | ||
python examples/rap_gsm8k/inference.py --base_lm exllama --exllama_model_dir $LLAMA2_CKPTS --exllama_lora_dir None --exllama_mem_map '[16,22]' --n_action 1 --n_confidence 1 --n_iters 1 --temperature 0.0 | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from world_model import GSM8kState, GSM8kAction, GSM8kWorldModel | ||
from search_config import GSM8kConfig |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import pickle | ||
from typing import Optional | ||
import glob | ||
import os | ||
|
||
from tqdm import tqdm | ||
from datasets import load_dataset | ||
|
||
from reasoners.algorithm import MCTSAggregation, MCTSResult | ||
|
||
import utils | ||
|
||
|
||
def aggregate_rap_gsm8k(log_dir: str,
                        start: int = 0):
    """Re-score saved MCTS results on GSM8K and log a running accuracy.

    Every ``<index>.pkl`` found in ``{log_dir}/algo_output`` whose numeric
    stem is at least ``start`` is unpickled, its ``tree_state`` is reduced to
    a single answer by an ``MCTSAggregation`` (``weight_policy='edge'``), and
    that answer is judged against the GSM8K test example at position
    ``index - 1``. One summary line per case is written through ``tqdm``.

    Args:
        log_dir: Run directory that contains the ``algo_output`` folder.
        start: Smallest result index to include.
    """
    aggregator = MCTSAggregation(utils.retrieve_answer, weight_policy='edge')

    # The file stem is the result index; [:-4] drops the '.pkl' suffix.
    stems = (int(os.path.basename(path)[:-4])
             for path in glob.glob(f'{log_dir}/algo_output/*.pkl'))
    indices = sorted(stem for stem in stems if stem >= start)

    dataset = load_dataset("gsm8k", "main", split='test')
    correct_count = 0
    for seen, index in enumerate(tqdm(indices), start=1):
        with open(f'{log_dir}/algo_output/{index}.pkl', 'rb') as f:
            result: MCTSResult = pickle.load(f)

        output = aggregator(result.tree_state)
        answer = utils.retrieve_answer_from_dataset(dataset[index - 1]['answer'])
        correct = utils.judge_answer(output, answer)

        correct_count += correct
        accuracy = correct_count / seen
        # The `{name=}` fields embed the variable names in the message, so
        # these locals must keep their names for the log text to stay stable.
        tqdm.write(f'Case #{seen}({index}): {correct=}, {output=}, {answer=} ; '
                   f'{accuracy=:.3f} ({correct_count}/{seen})')
|
||
|
||
if __name__ == '__main__':
    # Let Fire turn aggregate_rap_gsm8k's parameters into command-line flags.
    from fire import Fire

    Fire(aggregate_rap_gsm8k)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
import pickle | ||
# add path | ||
import sys | ||
sys.path.append('..') | ||
import os | ||
# print(os.cwd()) | ||
from reasoners.visualization import visualize | ||
from reasoners.visualization.tree_snapshot import NodeData | ||
from reasoners.algorithm.mcts import MCTSNode | ||
from datasets import load_dataset | ||
import pandas as pd | ||
import fire | ||
df = pd.DataFrame(columns=['question', 'cot']) | ||
# data = load_dataset('gsm8k','main','test') | ||
from datasets import Dataset | ||
def data_reader(dataset, dataset_path, split=None, sample_size=None):
    """Load a multiple-choice JSON dataset and flatten it into three columns.

    Reads ``<dataset_path>/<dataset>.json`` with ``Dataset.from_json``. Each
    row must be a dict with ``question``, ``options`` and ``correct`` keys.
    The options are joined with spaces, a space is inserted after each
    ``A)`` .. ``E)`` marker, and the result is appended to the question text
    after ``" Options: "`` and terminated with ``"."``.

    Args:
        dataset: File stem of the JSON file (without ``.json``).
        dataset_path: Directory that contains the file.
        split: Optional ``(start, end)`` pair; only rows ``start:end`` (slice
            semantics) are kept.
        sample_size: Accepted for call compatibility; currently unused.

    Returns:
        A ``Dataset`` with the columns ``question``, ``answer`` and
        ``options``.

    Raises:
        ValueError: If a row is not a dict.
    """
    filename = os.path.join(dataset_path, f'{dataset}.json')
    lines = Dataset.from_json(filename)
    if split is not None:
        start, end = split
        # Slicing a Dataset (`lines[start:end]`) returns a dict of columns
        # rather than rows, which breaks the per-row loop below. `select`
        # keeps a Dataset; `slice.indices` clamps the bounds like a slice.
        lines = lines.select(range(*slice(start, end).indices(len(lines))))

    questions = []
    answers = []
    options = []
    for data in lines:
        if not isinstance(data, dict):
            raise ValueError("Unexpected data format")
        options_list = data['options']
        option_text = " ".join(options_list)
        # Same replacements, in the same order, as the A) .. E) chain.
        for letter in 'ABCDE':
            option_text = option_text.replace(f'{letter})', f'{letter}) ')
        questions.append(data['question'] + " Options: " + option_text + ".")
        answers.append(data['correct'])
        options.append(options_list)
    return Dataset.from_dict({"question": questions, "answer": answers, "options": options})
|
||
|
||
def get_trace_gsm8k(log_dir='/data/haotian/RAP_tune/llm-reasoners/logs/gsm8k_unknown/02292024-025642',
                    output_name='cot1.json'):
    """Export saved GSM8K chain-of-thought outputs as numbered steps.

    For every GSM8K test question ``i`` (1-based) this loads
    ``<log_dir>/algo_output/<i>.pkl``, takes element ``0`` of the unpickled
    object as the generated text, cuts it at the first ``'Q:'``, splits it on
    ``'. '`` and rewrites each piece as ``'Step k: <piece>.'`` (one per line).
    The ``[question, steps]`` pair is stored in row ``i - 1`` of the
    module-level ``df``, which is finally written to
    ``<log_dir>/<output_name>`` as JSON.

    Args:
        log_dir: Run directory containing ``algo_output``; also receives the
            exported JSON file.
        output_name: File name of the exported JSON inside ``log_dir``.
    """
    # `split` has to be a keyword: the third positional parameter of
    # load_dataset is `data_dir`, so a positional 'test' does not select the
    # test split.
    data = load_dataset('gsm8k', 'main', split='test')
    for i, example in enumerate(data):
        # NOTE: unpickling can execute arbitrary code; only use trusted logs.
        # The context manager closes each result file (previously leaked).
        with open(os.path.join(log_dir, 'algo_output', f'{i + 1}.pkl'), 'rb') as f:
            mcts_result = pickle.load(f)
        question = example['question']
        # Drop anything generated after the model starts a new 'Q:' block.
        cot = mcts_result[0].split('Q:')[0]
        print(cot)
        cot_final = '\n'.join(f'Step {j}: {step}.'
                              for j, step in enumerate(cot.split('. '), start=1))
        df.loc[i] = [question, cot_final]

    df.to_json(os.path.join(log_dir, output_name))
|
||
def get_trace_sq(data_file='/data/haotian/RAP_tune/llm-reasoners/examples/rap_strategyQA/data/strategyqa_test.json',
                 log_dir='/data/haotian/RAP_tune/llm-reasoners/logs/strategyqa_cot/03062024-052230_anthropic',
                 output_name='cot.json'):
    """Export saved StrategyQA chain-of-thought outputs as numbered steps.

    ``data_file`` is read as JSON and indexed as a list of records with a
    ``question`` key. For record ``i`` (1-based) this loads
    ``<log_dir>/algo_output/<i>.pkl``, treats the unpickled object itself as
    the generated text, keeps only the part before the first ``'Q:'`` and
    then only its first line, splits that on ``'. '`` and rewrites each piece
    as ``'Step k: <piece>.'`` (one per line). The ``[question, steps]`` pair
    is stored in row ``i - 1`` of the module-level ``df``, which is finally
    written to ``<log_dir>/<output_name>`` as JSON.

    Args:
        data_file: Path of the StrategyQA JSON file.
        log_dir: Run directory containing ``algo_output``; also receives the
            exported JSON file.
        output_name: File name of the exported JSON inside ``log_dir``.
    """
    import json

    with open(data_file, 'r') as f:
        data = json.load(f)

    for i, example in enumerate(data):
        # NOTE: unpickling can execute arbitrary code; only use trusted logs.
        # The context manager closes each result file (previously leaked).
        with open(os.path.join(log_dir, 'algo_output', f'{i + 1}.pkl'), 'rb') as f:
            mcts_result = pickle.load(f)
        question = example['question']
        # Keep the text before a new 'Q:' block, then only its first line.
        cot = mcts_result.split('Q:')[0].split('\n')[0]
        print(cot)
        cot_final = '\n'.join(f'Step {j}: {step}.'
                              for j, step in enumerate(cot.split('. '), start=1))
        df.loc[i] = [question, cot_final]

    df.to_json(os.path.join(log_dir, output_name))
|
||
|
||
def get_trace_aqua(dataset_path='/data/haotian/RAP_tune/llm-reasoners/dataset/AQuA',
                   log_dir='/data/haotian/RAP_tune/llm-reasoners/logs/AQuAcot/03052024-051243_anthropic',
                   output_name='cot.json'):
    """Export saved AQuA chain-of-thought outputs without step numbering.

    Questions come from ``data_reader('AQuA', dataset_path)``. For question
    ``i`` (1-based) this loads ``<log_dir>/algo_output/<i>.pkl``, takes
    element ``0`` of the unpickled object as the generated text, cuts it at
    the first ``'Q:'`` and strips trailing newlines. The
    ``[question, text]`` pair is stored in row ``i - 1`` of the module-level
    ``df``, which is finally written to ``<log_dir>/<output_name>`` as JSON.

    Args:
        dataset_path: Directory that contains ``AQuA.json``.
        log_dir: Run directory containing ``algo_output``; also receives the
            exported JSON file.
        output_name: File name of the exported JSON inside ``log_dir``.
    """
    data = data_reader('AQuA', dataset_path)
    for i, example in enumerate(data):
        # NOTE: unpickling can execute arbitrary code; only use trusted logs.
        # The context manager closes each result file (previously leaked).
        with open(os.path.join(log_dir, 'algo_output', f'{i + 1}.pkl'), 'rb') as f:
            mcts_result = pickle.load(f)
        question = example['question']
        cot = mcts_result[0]
        print('------------', cot)
        # Drop anything generated after the model starts a new 'Q:' block.
        cot = cot.split('Q:')[0]
        print(cot)
        df.loc[i] = [question, cot.rstrip('\n')]

    df.to_json(os.path.join(log_dir, output_name))
|
||
|
||
if __name__ == '__main__':
    # Start the CLI only when this file is run as a script, so importing the
    # module for its helpers no longer launches an export.
    # Swap the target for get_trace_sq or get_trace_aqua to export those
    # traces instead.
    fire.Fire(get_trace_gsm8k)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
from typing import Type, Callable, Optional, Literal | ||
|
||
import numpy as np | ||
|
||
from reasoners.benchmark import GSM8KEvaluator | ||
|
||
from reasoners import LanguageModel, Reasoner, SearchAlgorithm | ||
from reasoners.algorithm import MCTS, MCTSNode, MCTSAggregation | ||
|
||
from world_model import GSM8kWorldModel, GSM8kState, GSM8kAction, GSM8kPromptDict | ||
from search_config import GSM8kConfig, GSM8kUsefulPrompt | ||
import utils | ||
|
||
|
||
def node_visualizer(x: MCTSNode):
    """Return the display fields for one MCTS node.

    A node whose ``state`` is empty or ``None`` maps to ``{}``; otherwise the
    last entry of ``state`` supplies the ``question`` (``sub_question``) and
    ``answer`` (``sub_answer``) fields.
    """
    steps = x.state
    if steps:
        latest = steps[-1]
        return {"question": latest.sub_question, "answer": latest.sub_answer}
    return {}
|
||
def rap_gsm8k(base_model: LanguageModel,
              prompt: GSM8kPromptDict,
              useful_prompt: GSM8kUsefulPrompt,
              search_algo: Type[SearchAlgorithm] = MCTS,
              resume: int = 0,
              n_action: int = 1,  # TODO: was 4
              n_confidence: int = 1,  # TODO: was 8
              depth_limit: int = 5,
              force_terminating_on_depth_limit: bool = True,
              batch_size: int = 2,
              temperature: float = 0.,  # TODO: was 0.8
              early_stop_base: int = 2,
              early_stop_threshold: float = 0.5,
              reward_alpha: float = 0.5,
              reward_confidence_default: float = 0.8,
              cum_reward: Callable[[list[float]], float] = np.mean,
              calc_q: Callable[[list[float]], float] = max,
              log_dir: Optional[str] = None,
              disable_log: bool = False,
              disable_tqdm: bool = False,
              output_trace_in_each_iter: bool = True,
              aggregate: bool = True,
              **search_algo_params):
    """Evaluate a RAP reasoner on GSM8K and print the resulting accuracy.

    Builds a ``GSM8kWorldModel``, a ``GSM8kConfig`` and an instance of
    ``search_algo``, wires them into a ``Reasoner`` and runs it through
    ``GSM8KEvaluator.evaluate`` with ``num_shot=4``.

    Args:
        base_model: Language model shared by the world model and the config.
        prompt: Passed to the evaluator as ``init_prompt``.
        useful_prompt: Passed to ``GSM8kConfig``.
        search_algo: Search algorithm class; instantiated with the search
            keyword arguments assembled below.
        resume: Forwarded to ``evaluator.evaluate``.
        n_action: Forwarded to the config as ``n_actions``.
        n_confidence, early_stop_base, early_stop_threshold: Forwarded to the
            world model.
        depth_limit, force_terminating_on_depth_limit, reward_alpha,
            reward_confidence_default: Forwarded to the config.
        batch_size, temperature: Forwarded to both world model and config.
        cum_reward, calc_q, output_trace_in_each_iter: Forwarded to the
            search algorithm.
        log_dir: Forwarded to ``evaluator.evaluate``.
        disable_log: Forwarded to the evaluator.
        disable_tqdm: Forwarded to the search algorithm and the evaluator.
        aggregate: When true, an ``MCTSAggregation`` (``weight_policy='edge'``)
            is handed to the search algorithm; otherwise ``None`` is.
        **search_algo_params: Extra keyword arguments for ``search_algo``;
            the explicitly named options above override duplicates.
    """
    aggregator = (MCTSAggregation(utils.retrieve_answer, weight_policy='edge')
                  if aggregate else None)

    search_algo_params.update(
        cum_reward=cum_reward,
        calc_q=calc_q,
        disable_tqdm=disable_tqdm,
        output_trace_in_each_iter=output_trace_in_each_iter,
        node_visualizer=node_visualizer,
        aggregator=aggregator,
    )

    world_model = GSM8kWorldModel(
        base_model=base_model,
        n_confidence=n_confidence,
        batch_size=batch_size,
        temperature=temperature,
        early_stop_base=early_stop_base,
        early_stop_threshold=early_stop_threshold,
    )
    config = GSM8kConfig(
        base_model=base_model,
        useful_prompt=useful_prompt,
        n_actions=n_action,
        batch_size=batch_size,
        temperature=temperature,
        reward_alpha=reward_alpha,
        reward_confidence_default=reward_confidence_default,
        force_terminating_on_depth_limit=force_terminating_on_depth_limit,
        depth_limit=depth_limit,
    )
    searcher = search_algo(**search_algo_params)
    reasoner = Reasoner(world_model=world_model, search_config=config, search_algo=searcher)

    evaluator = GSM8KEvaluator(
        output_extractor=utils.retrieve_answer,
        answer_extractor=utils.retrieve_answer_from_dataset,
        init_prompt=prompt,
        sample_prompt_type="rap",
        disable_log=disable_log,
        disable_tqdm=disable_tqdm,
    )

    accuracy = evaluator.evaluate(reasoner, num_shot=4, resume=resume, log_dir=log_dir)
    print(accuracy)
|
||
|
||
if __name__ == '__main__':
    import os
    import sys
    import json
    import warnings
    import fire
    import random

    # Checkpoint locations come from the environment and can be overridden
    # on the command line.
    llama_ckpts = os.environ.get("LLAMA_CKPTS", None)
    llama_2_ckpts = os.environ.get("LLAMA_2_CKPTS", None)

    # Only rank 0 reports: other ranks get stdout discarded and warnings
    # silenced.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if local_rank != 0:
        sys.stdout = open(os.devnull, 'w')
        warnings.filterwarnings('ignore')

    def main(base_lm: Literal['llama', 'llama.cpp', 'llama-2', 'hf', 'exllama'] = 'hf',
             llama_ckpts: str = llama_ckpts,
             llama_2_ckpts: str = llama_2_ckpts,
             llama_size: str = '7B',
             llama_cpp_path: Optional[str] = None,
             llama_cpp_n_batch: int = 512,
             hf_path: str = '/mnt/afs/niuyazhe/data/llama-2-7b-hf',  # TODO
             hf_peft_path: Optional[str] = None,
             hf_quantized: Optional[Literal['awq', 'int8', 'fp4', 'nf4']] = None,
             hf_load_awq_path: Optional[str] = None,
             exllama_model_dir: str = 'WizardMath-13B-V1.0-GPTQ',
             exllama_lora_dir: Optional[str] = None,
             exllama_mem_map: Optional[str] = None,
             batch_size: int = 1,
             useful_prompt: str = '/mnt/afs/niuyazhe/code/llm-reasoners/examples/rap_gsm8k/prompts/useful_examples.json',
             prompt: str = '/mnt/afs/niuyazhe/code/llm-reasoners/examples/rap_gsm8k/prompts/prompt_pool.json',
             disable_log: bool = False,
             disable_tqdm: bool = False,
             **kwargs):
        """Load the prompts and the chosen language model, then run rap_gsm8k.

        Args:
            base_lm: Backend selector; decides which model class is built.
            llama_ckpts, llama_2_ckpts, llama_size: Used by the 'llama' and
                'llama-2' backends.
            llama_cpp_path, llama_cpp_n_batch: Used by the 'llama.cpp' backend.
            hf_path, hf_peft_path, hf_quantized, hf_load_awq_path: Used by the
                'hf' backend (``hf_path`` is passed twice to ``HFModel``).
            exllama_model_dir, exllama_lora_dir, exllama_mem_map: Used by the
                'exllama' backend.
            batch_size: Model ``max_batch_size``; also forwarded to rap_gsm8k.
            useful_prompt: Path of the JSON file loaded as ``useful_prompt``.
            prompt: Path of the JSON file loaded as ``prompt``.
            disable_log, disable_tqdm: Forwarded to rap_gsm8k; forced on for
                non-zero local ranks.
            **kwargs: Forwarded unchanged to rap_gsm8k.

        Raises:
            ValueError: If ``base_lm`` is not one of the supported backends.
        """
        with open(useful_prompt) as f:
            useful_prompt = json.load(f)
        with open(prompt) as f:
            prompt = json.load(f)

        # Seed both LLaMA backends. The check used to test 'llama2', which
        # never matches the accepted value 'llama-2', so that backend ran
        # unseeded.
        if base_lm in ('llama', 'llama-2'):
            import torch
            import torch.backends.cudnn
            np.random.seed(0)
            random.seed(0)
            torch.manual_seed(0)
            torch.cuda.manual_seed(0)
            torch.backends.cudnn.deterministic = True

        if base_lm == 'llama':
            from reasoners.lm import LlamaModel
            base_model = LlamaModel(llama_ckpts, llama_size, max_batch_size=batch_size)
        elif base_lm == 'llama.cpp':
            from reasoners.lm import LlamaCppModel
            base_model = LlamaCppModel(llama_cpp_path, n_batch=llama_cpp_n_batch)
        elif base_lm == 'llama-2':
            from reasoners.lm import Llama2Model
            base_model = Llama2Model(llama_2_ckpts, llama_size, max_batch_size=batch_size)
        elif base_lm == 'hf':
            from reasoners.lm import HFModel
            base_model = HFModel(hf_path, hf_path, max_batch_size=batch_size, max_new_tokens=512,
                                 peft_pth=hf_peft_path, quantized=hf_quantized,
                                 load_awq_pth=hf_load_awq_path)
        elif base_lm == 'exllama':
            from reasoners.lm import ExLlamaModel
            base_model = ExLlamaModel(exllama_model_dir, exllama_lora_dir, mem_map=exllama_mem_map,
                                      max_batch_size=batch_size, max_new_tokens=200,
                                      max_seq_length=3072)
        else:
            # An explicit raise survives `python -O`, which strips asserts and
            # would otherwise leave `base_model` unbound below.
            raise ValueError(f'cannot resolve {base_lm=}')

        rap_gsm8k(base_model=base_model,
                  useful_prompt=useful_prompt,
                  prompt=prompt,
                  batch_size=batch_size,
                  disable_log=disable_log or local_rank != 0,
                  disable_tqdm=disable_tqdm or local_rank != 0,
                  **kwargs)

    fire.Fire(main)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
{ | ||
"input": "Given a question, please decompose it into sub-questions. For each sub-question, please answer it in a complete sentence, ending with \"The answer is\". When the original question is answerable, please start the subquestion with \"Now we can answer the question: \".\n\nQuestion 1: Four years ago, Kody was only half as old as Mohamed. If Mohamed is currently twice as 30 years old, how old is Kody?\nQuestion 1.1: How old is Mohamed?\nAnswer 1.1: He is currently 30 * 2 = 60 years old. The answer is 60.\nQuestion 1.2: How old was Mohamed four years ago?\nAnswer 1.2: Four years ago, he must have been 60 - 4 = 56 years old. The answer is 56.\nQuestion 1.3: How old was Kody four years ago?\nAnswer 1.3: Kody was half as old as Mohamed four years ago. Thus, Kody was 56 / 2 = 28 years old. The answer is 28.\nQuestion 1.4: Now we can answer the question: How old is Kody?\nAnswer 1.4: She is currently 28 + 4 = 32 years old. The answer is 32.\n\nQuestion 2: On a moonless night, three fireflies danced in the evening breeze. They were joined by four less than a dozen more fireflies before two of the fireflies flew away. How many fireflies remained?\nQuestion 2.1: How many fireflies joined?\nAnswer 2.1: The fireflies were joined by four less than a dozen more fireflies, which are 12 - 4 = 8 fireflies. The answer is 8.\nQuestion 2.2: Now we can answer the question: How many fireflies remained?\nAnswer 2.2: Three fireflies were dancing originally. They were joined by 8 fireflies before two of them flew away. So there were 3 + 8 - 2 = 9 remaining. The answer is 9.\n\nQuestion 3: Ali has four $10 bills and six $20 bills that he saved after working for Mr. James on his farm. Ali gives her sister half of the total money he has and uses 3/5 of the remaining amount of money to buy dinner. Calculate the amount of money he has after buying the dinner.\nQuestion 3.1: How much money does Ali have in total?\nAnswer 3.1: Ali has four $10 bills and six $20 bills. 
So he has 4 * 10 + 6 * 20 = 160 dollars. The answer is 160.\nQuestion 3.2: How much money does Ali give to his sister?\nAnswer 3.2: Ali gives half of the total money he has to his sister. So he gives 160 / 2 = 80 dollars to his sister. The answer is 80.\nQuestion 3.3: How much money does Ali have after giving his sister the money?\nAnswer 3.3: After giving his sister the money, Ali has 160 - 80 = 80 dollars left. The answer is 80.\nQuestion 3.4: How much money does Ali use to buy dinner?\nAnswer 3.4: Ali uses 3/5 of the remaining amount of money to buy dinner. So he uses 80 * 3/5 = 48 dollars to buy dinner. The answer is 48.\nQuestion 3.5: Now we can answer the question: How much money does Ali have after buying the dinner?\nAnswer 3.5: After buying the dinner, Ali has 80 - 48 = 32 dollars left. The answer is 32.\n\nQuestion 4: A car is driving through a tunnel with many turns. After a while, the car must travel through a ring that requires a total of 4 right-hand turns. After the 1st turn, it travels 5 meters. After the 2nd turn, it travels 8 meters. After the 3rd turn, it travels a little further and at the 4th turn, it immediately exits the tunnel. If the car has driven a total of 23 meters around the ring, how far did it have to travel after the 3rd turn?\nQuestion 4.1: How far did the car travel except for the 3rd turn?\nAnswer 4.1: It travels 5 meters after the 1st, 8 meters after the 2nd, and 0 meters after the 4th turn. It's a total of 5 + 8 + 0 = 13 meters. The answer is 13.\nQuestion 4.2: Now we can answer the question: How far did the car have to travel after the 3rd turn?\nAnswer 4.2: The car has driven a total of 23 meters around the ring. It travels 13 meters except for the 3rd turn. So it has to travel 23 - 13 = 10 meters after the 3rd turn. The answer is 10.\n\n", | ||
"question_prefix": "Question 5: ", | ||
"subquestion_prefix": "Question 5.{}:", | ||
"overall_question_prefix": "Now we can answer the question:", | ||
"answer_prefix": "Answer 5.{}:", | ||
"index": 5 | ||
} |
Oops, something went wrong.