Skip to content

Commit

Permalink
feature(pu): add rap_gsm8k
Browse files Browse the repository at this point in the history
  • Loading branch information
puyuan1996 committed Apr 25, 2024
1 parent 4640ee3 commit 1f97c7e
Show file tree
Hide file tree
Showing 13 changed files with 883 additions and 0 deletions.
7 changes: 7 additions & 0 deletions zoo/reasoning/rap_gsm8k/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# This code is adapted from the `examples/rap_gsm8k` directory of https://github.com/maitrix-org/llm-reasoners

An example script:

```bash
python examples/rap_gsm8k/inference.py --base_lm exllama --exllama_model_dir $LLAMA2_CKPTS --exllama_lora_dir None --exllama_mem_map '[16,22]' --n_action 1 --n_confidence 1 --n_iters 1 --temperature 0.0
```
2 changes: 2 additions & 0 deletions zoo/reasoning/rap_gsm8k/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from world_model import GSM8kState, GSM8kAction, GSM8kWorldModel
from search_config import GSM8kConfig
37 changes: 37 additions & 0 deletions zoo/reasoning/rap_gsm8k/aggregate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import pickle
from typing import Optional
import glob
import os

from tqdm import tqdm
from datasets import load_dataset

from reasoners.algorithm import MCTSAggregation, MCTSResult

import utils


def aggregate_rap_gsm8k(log_dir: str,
                        start: int = 0):
    """Re-score saved MCTS results against the GSM8K test split.

    Every ``<index>.pkl`` file under ``{log_dir}/algo_output`` whose numeric
    index is at least ``start`` is unpickled, its ``tree_state`` is passed to
    an ``MCTSAggregation`` instance to obtain one answer, and that answer is
    judged against the reference answer from the dataset. A progress line
    with the running accuracy is written via ``tqdm.write`` for each case.

    Args:
        log_dir: Directory that contains the ``algo_output`` folder.
        start: Smallest file index to include.
    """
    aggregator = MCTSAggregation(utils.retrieve_answer, weight_policy='edge')
    pickle_paths = glob.glob(f'{log_dir}/algo_output/*.pkl')
    # File names look like "<index>.pkl"; dropping the 4-character suffix
    # leaves the integer index.
    found_indices = (int(os.path.basename(path)[:-4]) for path in pickle_paths)
    indices = sorted(idx for idx in found_indices if idx >= start)
    dataset = load_dataset("gsm8k", "main", split='test')
    correct_count = 0
    for i, index in enumerate(tqdm(indices)):
        with open(f'{log_dir}/algo_output/{index}.pkl', 'rb') as f:
            result: MCTSResult = pickle.load(f)
        output = aggregator(result.tree_state)
        # The dataset row for file ``index`` is looked up at ``index - 1``.
        reference = dataset[index - 1]['answer']
        answer = utils.retrieve_answer_from_dataset(reference)
        correct = utils.judge_answer(output, answer)

        correct_count += correct
        accuracy = correct_count / (i + 1)
        tqdm.write(
            f'Case #{i + 1}({index}): {correct=}, {output=}, {answer=} ; {accuracy=:.3f} ({correct_count}/{i+1})'
        )


if __name__ == '__main__':
    # Expose the parameters of aggregate_rap_gsm8k as command-line flags.
    from fire import Fire
    Fire(aggregate_rap_gsm8k)
104 changes: 104 additions & 0 deletions zoo/reasoning/rap_gsm8k/get_traces.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import pickle
# add path
import sys
sys.path.append('..')
import os
# print(os.cwd())
from reasoners.visualization import visualize
from reasoners.visualization.tree_snapshot import NodeData
from reasoners.algorithm.mcts import MCTSNode
from datasets import load_dataset
import pandas as pd
import fire
df = pd.DataFrame(columns=['question', 'cot'])
# data = load_dataset('gsm8k','main','test')
from datasets import Dataset
def data_reader(dataset, dataset_path, split=None, sample_size=None):
    """Load a multiple-choice JSON dataset into question/answer/options columns.

    Reads ``{dataset_path}/{dataset}.json`` with ``Dataset.from_json``. Each
    record must provide the keys ``question``, ``options`` and ``correct``.
    The options are appended to the question text as
    ``"<question> Options: <options>."`` with a space inserted after each
    ``A)`` .. ``E)`` marker.

    Args:
        dataset: Base name of the JSON file (without the ``.json`` suffix).
        dataset_path: Directory that contains the JSON file.
        split: Optional ``(start, end)`` pair. Only rows in that half-open
            range are kept; bounds follow normal slice rules (``None`` and
            negative values are allowed, out-of-range values are clamped).
        sample_size: Accepted for signature compatibility; currently unused.

    Returns:
        A ``Dataset`` with the columns ``question``, ``answer`` and ``options``.

    Raises:
        ValueError: If a record is not a ``dict``.
    """
    filename = os.path.join(dataset_path, f'{dataset}.json')
    lines = Dataset.from_json(filename)
    if split is not None:
        start, end = split
        # ``lines[start:end]`` on a Dataset yields a dict of columns rather
        # than rows, which broke the row loop below. ``select`` keeps a
        # row-wise Dataset instead.
        row_range = range(*slice(start, end).indices(len(lines)))
        lines = lines.select(row_range)

    questions = []
    answers = []
    options = []
    for data in lines:
        if not isinstance(data, dict):
            raise ValueError("Unexpected data format")
        options_list = data['options']
        options_text = " ".join(options_list)
        # Put a space after each option marker, e.g. "A)12" -> "A) 12".
        for letter in 'ABCDE':
            options_text = options_text.replace(f'{letter})', f'{letter}) ')
        questions.append(data['question'] + " Options: " + options_text + ".")
        answers.append(data['correct'])
        options.append(options_list)
    return Dataset.from_dict({"question": questions, "answer": answers, "options": options})


def get_trace_gsm8k(
        log_dir: str = '/data/haotian/RAP_tune/llm-reasoners/logs/gsm8k_unknown/02292024-025642',
        output_name: str = 'cot1.json'):
    """Export step-numbered chain-of-thought traces for the GSM8K test split.

    For each test question ``i`` (1-based) the pickle
    ``{log_dir}/algo_output/{i}.pkl`` is loaded and its first element is
    taken as the generated text. Everything from the first ``'Q:'`` onward is
    dropped, the remainder is split on ``'. '`` and each piece is rewritten as
    ``"Step k: <piece>."``. The (question, trace) pairs are stored in the
    module-level ``df`` and written to ``{log_dir}/{output_name}`` as JSON.

    Args:
        log_dir: Run directory holding the ``algo_output`` pickles. Defaults
            to the path this script was originally hard-coded with.
        output_name: File name of the JSON written inside ``log_dir``.
    """
    # ``split`` must be passed by keyword: the third positional parameter of
    # ``load_dataset`` is ``data_dir``, not the split name.
    data = load_dataset('gsm8k', 'main', split='test')
    for i, example in enumerate(data, start=1):
        # NOTE: pickle can execute arbitrary code; only load trusted log files.
        with open(os.path.join(log_dir, 'algo_output', f'{i}.pkl'), 'rb') as f:
            mcts_result = pickle.load(f)
        question = example['question']
        # Keep only the text before the model starts a new "Q:" block.
        cot = mcts_result[0].split('Q:')[0]
        print(cot)
        cot_steps = cot.split('. ')
        cot_final = ''.join(
            f'Step {j}: {step}.\n' for j, step in enumerate(cot_steps, start=1)
        ).rstrip('\n')
        df.loc[i - 1] = [question, cot_final]

    df.to_json(os.path.join(log_dir, output_name))

def get_trace_sq(
        data_path: str = '/data/haotian/RAP_tune/llm-reasoners/examples/rap_strategyQA/data/strategyqa_test.json',
        log_dir: str = '/data/haotian/RAP_tune/llm-reasoners/logs/strategyqa_cot/03062024-052230_anthropic',
        output_name: str = 'cot.json'):
    """Export step-numbered chain-of-thought traces for StrategyQA.

    Questions are read from the JSON file at ``data_path``. For question ``i``
    (1-based) the pickle ``{log_dir}/algo_output/{i}.pkl`` is loaded and used
    directly as the generated text. The text is cut at the first ``'Q:'``,
    reduced to its first line, split on ``'. '`` and each piece is rewritten
    as ``"Step k: <piece>."``. The (question, trace) pairs are stored in the
    module-level ``df`` and written to ``{log_dir}/{output_name}`` as JSON.

    Args:
        data_path: JSON file with the questions; each entry needs a
            ``question`` key.
        log_dir: Run directory holding the ``algo_output`` pickles.
        output_name: File name of the JSON written inside ``log_dir``.
    """
    import json
    with open(data_path, 'r') as f:
        data = json.load(f)
    for i, example in enumerate(data, start=1):
        # NOTE: pickle can execute arbitrary code; only load trusted log files.
        with open(os.path.join(log_dir, 'algo_output', f'{i}.pkl'), 'rb') as f:
            mcts_result = pickle.load(f)
        question = example['question']
        # Drop any follow-up "Q:" block, then keep only the first line.
        cot = mcts_result.split('Q:')[0].split('\n')[0]
        cot_steps = cot.split('. ')

        print(cot)
        cot_final = ''.join(
            f'Step {j}: {step}.\n' for j, step in enumerate(cot_steps, start=1)
        ).rstrip('\n')
        df.loc[i - 1] = [question, cot_final]

    df.to_json(os.path.join(log_dir, output_name))


def get_trace_aqua(
        dataset_path: str = '/data/haotian/RAP_tune/llm-reasoners/dataset/AQuA',
        log_dir: str = '/data/haotian/RAP_tune/llm-reasoners/logs/AQuAcot/03052024-051243_anthropic',
        output_name: str = 'cot.json'):
    """Export chain-of-thought traces for the AQuA dataset.

    Questions come from ``data_reader('AQuA', dataset_path)``. For question
    ``i`` (1-based) the pickle ``{log_dir}/algo_output/{i}.pkl`` is loaded and
    its first element is taken as the generated text. The text is cut at the
    first ``'Q:'`` and trailing newlines are removed; unlike the other
    exporters it is not split into numbered steps. The (question, trace)
    pairs are stored in the module-level ``df`` and written to
    ``{log_dir}/{output_name}`` as JSON.

    Args:
        dataset_path: Directory containing ``AQuA.json``.
        log_dir: Run directory holding the ``algo_output`` pickles.
        output_name: File name of the JSON written inside ``log_dir``.
    """
    data = data_reader('AQuA', dataset_path)
    for i, example in enumerate(data, start=1):
        # NOTE: pickle can execute arbitrary code; only load trusted log files.
        with open(os.path.join(log_dir, 'algo_output', f'{i}.pkl'), 'rb') as f:
            mcts_result = pickle.load(f)
        question = example['question']
        cot = mcts_result[0]
        print('------------', cot)
        # Keep only the text before the model starts a new "Q:" block.
        cot = cot.split('Q:')[0]
        print(cot)
        df.loc[i - 1] = [question, cot.rstrip('\n')]

    df.to_json(os.path.join(log_dir, output_name))


# Guarded so that importing this module does not start the CLI.
# To export a different dataset, pass get_trace_sq or get_trace_aqua instead.
if __name__ == '__main__':
    fire.Fire(get_trace_gsm8k)
158 changes: 158 additions & 0 deletions zoo/reasoning/rap_gsm8k/inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
from typing import Type, Callable, Optional, Literal

import numpy as np

from reasoners.benchmark import GSM8KEvaluator

from reasoners import LanguageModel, Reasoner, SearchAlgorithm
from reasoners.algorithm import MCTS, MCTSNode, MCTSAggregation

from world_model import GSM8kWorldModel, GSM8kState, GSM8kAction, GSM8kPromptDict
from search_config import GSM8kConfig, GSM8kUsefulPrompt
import utils


def node_visualizer(x: MCTSNode):
    """Summarise a search node by the last step recorded in its state.

    Returns an empty dict when the node's state is falsy; otherwise a dict
    with the ``sub_question`` and ``sub_answer`` of the final state entry
    under the keys ``"question"`` and ``"answer"``.
    """
    state = x.state
    if state:
        last_step = state[-1]
        return {"question": last_step.sub_question, "answer": last_step.sub_answer}
    return {}

def rap_gsm8k(base_model: LanguageModel,
              prompt: GSM8kPromptDict,
              useful_prompt: GSM8kUsefulPrompt,
              search_algo: Type[SearchAlgorithm] = MCTS,
              resume: int = 0,
              # n_action: int = 4, # TODO
              # n_confidence: int = 8,
              n_action: int = 1,
              n_confidence: int = 1,
              depth_limit: int = 5,
              force_terminating_on_depth_limit: bool = True,
              batch_size: int = 2,
              # temperature: float = 0.8,# TODO
              temperature: float = 0.,
              early_stop_base: int = 2,
              early_stop_threshold: float = 0.5,
              reward_alpha: float = 0.5,
              reward_confidence_default: float = 0.8,
              cum_reward: Callable[[list[float]], float] = np.mean,
              calc_q: Callable[[list[float]], float] = max,
              log_dir: Optional[str] = None,
              disable_log: bool = False,
              disable_tqdm: bool = False,
              output_trace_in_each_iter: bool = True,
              aggregate: bool = True,
              **search_algo_params):
    """Assemble the GSM8K reasoner and evaluate it, printing the accuracy.

    Builds a ``GSM8kWorldModel`` and a ``GSM8kConfig`` on top of
    ``base_model``, instantiates ``search_algo`` with the search-related
    keyword arguments, wraps everything in a ``Reasoner`` and runs it through
    ``GSM8KEvaluator.evaluate`` with ``num_shot=4``. The value returned by the
    evaluator is printed; the function itself returns ``None``.

    Args:
        base_model: Language model shared by the world model and the config.
        prompt: Prompt dict handed to the evaluator as ``init_prompt``.
        useful_prompt: Prompt handed to ``GSM8kConfig``.
        search_algo: Search algorithm class to instantiate.
        resume: Forwarded to ``evaluator.evaluate``.
        n_action: Passed to ``GSM8kConfig`` as ``n_actions``.
        n_confidence, early_stop_base, early_stop_threshold: Forwarded to
            ``GSM8kWorldModel``.
        depth_limit, force_terminating_on_depth_limit, reward_alpha,
            reward_confidence_default: Forwarded to ``GSM8kConfig``.
        batch_size, temperature: Forwarded to both the world model and the
            config.
        cum_reward, calc_q, output_trace_in_each_iter: Forwarded to the
            search algorithm.
        log_dir: Forwarded to ``evaluator.evaluate``.
        disable_log: Forwarded to the evaluator.
        disable_tqdm: Forwarded to both the search algorithm and the evaluator.
        aggregate: When true, an ``MCTSAggregation`` is passed to the search
            algorithm as ``aggregator``; otherwise ``None`` is passed.
        **search_algo_params: Extra keyword arguments for ``search_algo``.
    """
    aggregator = (
        MCTSAggregation(utils.retrieve_answer, weight_policy='edge')
        if aggregate else None
    )

    # Explicitly named options override same-named entries in the extras.
    search_algo_params.update(
        cum_reward=cum_reward,
        calc_q=calc_q,
        disable_tqdm=disable_tqdm,
        output_trace_in_each_iter=output_trace_in_each_iter,
        node_visualizer=node_visualizer,
        aggregator=aggregator,
    )

    gsm8k_world = GSM8kWorldModel(
        base_model=base_model,
        n_confidence=n_confidence,
        batch_size=batch_size,
        temperature=temperature,
        early_stop_base=early_stop_base,
        early_stop_threshold=early_stop_threshold,
    )
    gsm8k_config = GSM8kConfig(
        base_model=base_model,
        useful_prompt=useful_prompt,
        n_actions=n_action,
        batch_size=batch_size,
        temperature=temperature,
        reward_alpha=reward_alpha,
        reward_confidence_default=reward_confidence_default,
        force_terminating_on_depth_limit=force_terminating_on_depth_limit,
        depth_limit=depth_limit,
    )
    algo_instance = search_algo(**search_algo_params)
    reasoner = Reasoner(
        world_model=gsm8k_world,
        search_config=gsm8k_config,
        search_algo=algo_instance,
    )

    evaluator = GSM8KEvaluator(
        output_extractor=utils.retrieve_answer,
        answer_extractor=utils.retrieve_answer_from_dataset,
        init_prompt=prompt,
        sample_prompt_type="rap",
        disable_log=disable_log,
        disable_tqdm=disable_tqdm,
    )

    print(evaluator.evaluate(reasoner, num_shot=4, resume=resume, log_dir=log_dir))


if __name__ == '__main__':
    import os
    import sys
    import json
    import warnings
    import fire
    import random

    # Checkpoint locations default to environment variables when they are set.
    llama_ckpts = os.environ.get("LLAMA_CKPTS", None)
    llama_2_ckpts = os.environ.get("LLAMA_2_CKPTS", None)

    # In multi-process runs only rank 0 reports: other ranks have stdout
    # redirected to the null device and warnings suppressed.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if local_rank != 0:
        sys.stdout = open(os.devnull, 'w')
        warnings.filterwarnings('ignore')

    def main(base_lm: Literal['llama', 'llama.cpp', 'llama-2', 'hf', 'exllama'] = 'hf',
             llama_ckpts: Optional[str] = llama_ckpts,
             llama_2_ckpts: Optional[str] = llama_2_ckpts,
             llama_size: str = '7B',
             llama_cpp_path: Optional[str] = None,
             llama_cpp_n_batch: int = 512,
             hf_path: str = '/mnt/afs/niuyazhe/data/llama-2-7b-hf',  # TODO: machine-specific default
             hf_peft_path: Optional[str] = None,
             hf_quantized: Optional[Literal['awq', 'int8', 'fp4', 'nf4']] = None,
             hf_load_awq_path: Optional[str] = None,
             exllama_model_dir: str = 'WizardMath-13B-V1.0-GPTQ',
             exllama_lora_dir: Optional[str] = None,
             exllama_mem_map: Optional[str] = None,
             batch_size: int = 1,
             useful_prompt: str = '/mnt/afs/niuyazhe/code/llm-reasoners/examples/rap_gsm8k/prompts/useful_examples.json',
             prompt: str = '/mnt/afs/niuyazhe/code/llm-reasoners/examples/rap_gsm8k/prompts/prompt_pool.json',
             disable_log: bool = False,
             disable_tqdm: bool = False,
             **kwargs):
        """Load the prompts, build the selected language model and run rap_gsm8k.

        Args:
            base_lm: Which backend to instantiate.
            llama_ckpts, llama_2_ckpts, llama_size: Used by the ``llama`` and
                ``llama-2`` backends.
            llama_cpp_path, llama_cpp_n_batch: Used by the ``llama.cpp`` backend.
            hf_path, hf_peft_path, hf_quantized, hf_load_awq_path: Used by the
                ``hf`` backend; ``hf_path`` is passed as both of ``HFModel``'s
                first two positional arguments.
            exllama_model_dir, exllama_lora_dir, exllama_mem_map: Used by the
                ``exllama`` backend.
            batch_size: Maximum batch size of the model; also forwarded to
                ``rap_gsm8k``.
            useful_prompt: Path of the JSON file loaded as ``useful_prompt``.
            prompt: Path of the JSON file loaded as ``prompt``.
            disable_log, disable_tqdm: Forced on for non-zero local ranks.
            **kwargs: Forwarded unchanged to ``rap_gsm8k``.

        Raises:
            ValueError: If ``base_lm`` is not one of the supported backends.
        """
        with open(useful_prompt) as f:
            useful_prompt = json.load(f)
        with open(prompt) as f:
            prompt = json.load(f)

        # Seed every RNG for the llama backends. The original check compared
        # against 'llama2', which never matched the accepted value 'llama-2'.
        if base_lm in ('llama', 'llama-2'):
            import torch
            import torch.backends.cudnn
            np.random.seed(0)
            random.seed(0)
            torch.manual_seed(0)
            torch.cuda.manual_seed(0)
            torch.backends.cudnn.deterministic = True

        # Backends are imported lazily so only the selected one must be installed.
        if base_lm == 'llama':
            from reasoners.lm import LlamaModel
            base_model = LlamaModel(llama_ckpts, llama_size, max_batch_size=batch_size)
        elif base_lm == 'llama.cpp':
            from reasoners.lm import LlamaCppModel
            base_model = LlamaCppModel(llama_cpp_path, n_batch=llama_cpp_n_batch)
        elif base_lm == 'llama-2':
            from reasoners.lm import Llama2Model
            base_model = Llama2Model(llama_2_ckpts, llama_size, max_batch_size=batch_size)
        elif base_lm == 'hf':
            from reasoners.lm import HFModel
            base_model = HFModel(hf_path, hf_path, max_batch_size=batch_size, max_new_tokens=512,
                                 peft_pth=hf_peft_path, quantized=hf_quantized, load_awq_pth=hf_load_awq_path)
        elif base_lm == 'exllama':
            from reasoners.lm import ExLlamaModel
            base_model = ExLlamaModel(exllama_model_dir, exllama_lora_dir, mem_map=exllama_mem_map,
                                      max_batch_size=batch_size, max_new_tokens=200, max_seq_length=3072)
        else:
            # Raised explicitly: an ``assert`` would be stripped under ``python -O``.
            raise ValueError(f'cannot resolve {base_lm=}')

        rap_gsm8k(base_model=base_model,
                  useful_prompt=useful_prompt,
                  prompt=prompt,
                  batch_size=batch_size,
                  disable_log=disable_log or local_rank != 0,
                  disable_tqdm=disable_tqdm or local_rank != 0,
                  **kwargs)

    fire.Fire(main)
8 changes: 8 additions & 0 deletions zoo/reasoning/rap_gsm8k/prompts/interactive_examples.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"input": "Given a question, please decompose it into sub-questions. For each sub-question, please answer it in a complete sentence, ending with \"The answer is\". When the original question is answerable, please start the subquestion with \"Now we can answer the question: \".\n\nQuestion 1: Four years ago, Kody was only half as old as Mohamed. If Mohamed is currently twice as 30 years old, how old is Kody?\nQuestion 1.1: How old is Mohamed?\nAnswer 1.1: He is currently 30 * 2 = 60 years old. The answer is 60.\nQuestion 1.2: How old was Mohamed four years ago?\nAnswer 1.2: Four years ago, he must have been 60 - 4 = 56 years old. The answer is 56.\nQuestion 1.3: How old was Kody four years ago?\nAnswer 1.3: Kody was half as old as Mohamed four years ago. Thus, Kody was 56 / 2 = 28 years old. The answer is 28.\nQuestion 1.4: Now we can answer the question: How old is Kody?\nAnswer 1.4: She is currently 28 + 4 = 32 years old. The answer is 32.\n\nQuestion 2: On a moonless night, three fireflies danced in the evening breeze. They were joined by four less than a dozen more fireflies before two of the fireflies flew away. How many fireflies remained?\nQuestion 2.1: How many fireflies joined?\nAnswer 2.1: The fireflies were joined by four less than a dozen more fireflies, which are 12 - 4 = 8 fireflies. The answer is 8.\nQuestion 2.2: Now we can answer the question: How many fireflies remained?\nAnswer 2.2: Three fireflies were dancing originally. They were joined by 8 fireflies before two of them flew away. So there were 3 + 8 - 2 = 9 remaining. The answer is 9.\n\nQuestion 3: Ali has four $10 bills and six $20 bills that he saved after working for Mr. James on his farm. Ali gives her sister half of the total money he has and uses 3/5 of the remaining amount of money to buy dinner. Calculate the amount of money he has after buying the dinner.\nQuestion 3.1: How much money does Ali have in total?\nAnswer 3.1: Ali has four $10 bills and six $20 bills. 
So he has 4 * 10 + 6 * 20 = 160 dollars. The answer is 160.\nQuestion 3.2: How much money does Ali give to his sister?\nAnswer 3.2: Ali gives half of the total money he has to his sister. So he gives 160 / 2 = 80 dollars to his sister. The answer is 80.\nQuestion 3.3: How much money does Ali have after giving his sister the money?\nAnswer 3.3: After giving his sister the money, Ali has 160 - 80 = 80 dollars left. The answer is 80.\nQuestion 3.4: How much money does Ali use to buy dinner?\nAnswer 3.4: Ali uses 3/5 of the remaining amount of money to buy dinner. So he uses 80 * 3/5 = 48 dollars to buy dinner. The answer is 48.\nQuestion 3.5: Now we can answer the question: How much money does Ali have after buying the dinner?\nAnswer 3.5: After buying the dinner, Ali has 80 - 48 = 32 dollars left. The answer is 32.\n\nQuestion 4: A car is driving through a tunnel with many turns. After a while, the car must travel through a ring that requires a total of 4 right-hand turns. After the 1st turn, it travels 5 meters. After the 2nd turn, it travels 8 meters. After the 3rd turn, it travels a little further and at the 4th turn, it immediately exits the tunnel. If the car has driven a total of 23 meters around the ring, how far did it have to travel after the 3rd turn?\nQuestion 4.1: How far did the car travel except for the 3rd turn?\nAnswer 4.1: It travels 5 meters after the 1st, 8 meters after the 2nd, and 0 meters after the 4th turn. It's a total of 5 + 8 + 0 = 13 meters. The answer is 13.\nQuestion 4.2: Now we can answer the question: How far did the car have to travel after the 3rd turn?\nAnswer 4.2: The car has driven a total of 23 meters around the ring. It travels 13 meters except for the 3rd turn. So it has to travel 23 - 13 = 10 meters after the 3rd turn. The answer is 10.\n\n",
"question_prefix": "Question 5: ",
"subquestion_prefix": "Question 5.{}:",
"overall_question_prefix": "Now we can answer the question:",
"answer_prefix": "Answer 5.{}:",
"index": 5
}

0 comments on commit 1f97c7e

Please sign in to comment.