In [1]:
import os
import json
import itertools
import argparse
import numpy as np
from functools import partial
import sys
sys.path.append('../')
sys.path.append('../../')
from models import gpt, gpt_usage
# from model_llama import llama, llama_usage
import model_llama
from tasks import get_task # get_task is a function defined in tasks/__init__.py, where it imports a task class from e.g.: tasks/text.py and calls a constructor for that class to create an object, and returns it. 
from run import get_value, get_values, get_votes, get_proposals, get_samples




  from .autonotebook import tqdm as notebook_tqdm


In [2]:
%%time

# Replaced with llama
global LLM 

# load llama --- it can take a few minutes
LLM = model_llama.LLM(model_name='llama-7B')

Loading checkpoint shards: 100%|██████████| 2/2 [00:06<00:00,  3.07s/it]

CPU times: user 1min 40s, sys: 6.35 s, total: 1min 46s
Wall time: 1min 24s





In [3]:
def get_value(task, x, y, n_evaluate_sample, cache_value=True):
    """Score one partial output ``y`` for input ``x`` with the local LLaMA model.

    This redefines the ``get_value`` imported from ``run`` so that value
    prompts are answered by ``LLM.llama`` instead of ``gpt``.

    Args:
        task: task object providing ``value_prompt_wrap``,
            ``value_outputs_unwrap`` and a ``value_cache`` dict.
        x: the task input.
        y: the partial output to score.
        n_evaluate_sample: number of beams / returned sequences requested
            from the model.
        cache_value: when True, look up and store results in
            ``task.value_cache``, keyed by the wrapped prompt.

    Returns:
        Whatever ``task.value_outputs_unwrap`` produces for the model outputs.
    """
    prompt = task.value_prompt_wrap(x, y)

    # Reuse a previously computed score for an identical prompt.
    if cache_value and prompt in task.value_cache:
        return task.value_cache[prompt]

    outputs = LLM.llama(prompt, max_tokens=100, do_sample=False,
                        beams=n_evaluate_sample, n=n_evaluate_sample)
    score = task.value_outputs_unwrap(x, y, outputs)

    if cache_value:
        task.value_cache[prompt] = score
    return score

def get_values(task, x, ys, n_evaluate_sample, cache_value=True):
    """Score every candidate partial output in ``ys`` via ``get_value``.

    A candidate that already appeared earlier in ``ys`` is not re-evaluated.
    It receives a value of 0 so that duplicates do not take extra slots
    during selection.

    Returns:
        A list of values aligned index-for-index with ``ys``.
    """
    already_scored = set()
    scores = []
    for candidate in ys:
        if candidate in already_scored:
            # Repeated candidate: score 0 instead of evaluating it again.
            scores.append(0)
            continue
        scores.append(get_value(task, x, candidate, n_evaluate_sample,
                                cache_value=cache_value))
        already_scored.add(candidate)
    return scores

def get_votes(task, x, ys, n_evaluate_sample):
    """Have the LLaMA model vote among the candidate partial outputs ``ys``.

    Args:
        task: task object providing ``vote_prompt_wrap`` and
            ``vote_outputs_unwrap``.
        x: the task input.
        ys: candidate partial outputs to vote on.
        n_evaluate_sample: number of beams / returned sequences requested.

    Returns:
        The per-candidate values from ``task.vote_outputs_unwrap``, which is
        told how many candidates there are.
    """
    prompt = task.vote_prompt_wrap(x, ys)
    outputs = LLM.llama(prompt, max_tokens=100, do_sample=False,
                        beams=n_evaluate_sample, n=n_evaluate_sample)
    return task.vote_outputs_unwrap(outputs, len(ys))

def get_proposals(task, x, y):
    """Propose next steps for partial output ``y`` using the LLaMA model.

    The model is queried once (one beam, one returned sequence) with the
    task's propose prompt. The generation is split on newlines, and every
    resulting line, including empty ones, becomes a candidate: ``y`` + line
    + a trailing newline.

    Returns:
        List of extended partial outputs, one per line of the generation.
    """
    prompt = task.propose_prompt_wrap(x, y)
    completion = LLM.llama(prompt, max_tokens=100, do_sample=False, beams=1, n=1)[0]
    candidates = []
    for step_text in completion.split('\n'):
        candidates.append(y + step_text + '\n')
    return candidates

# Use wrapped prompts to generate new samples from LLM
# TODO: add support for other sampling methods
def get_samples(task, x, y, n_generate_sample, prompt_sample, stop):
    # breakpoint()
    if prompt_sample == 'standard':
        prompt = task.standard_prompt_wrap(x, y)
    elif prompt_sample == 'cot':
        prompt = task.cot_prompt_wrap(x, y)
    else:
        raise ValueError(f'prompt_sample {prompt_sample} not recognized')
    # samples = gpt(prompt, n=n_generate_sample, stop=stop)

    # Replaced with llama
    samples = LLM.llama(prompt, max_tokens = 100, do_sample = False, beams = n_generate_sample, n = n_generate_sample)

    return [y + _ for _ in samples]

def _generate_candidates(task, x, ys, step, method_generate, n_generate_sample, prompt_sample):
    """Expand every frontier node in ``ys`` and flatten all children into one list."""
    if method_generate == 'sample':
        children = [get_samples(task, x, y, n_generate_sample,
                                prompt_sample=prompt_sample, stop=task.stops[step])
                    for y in ys]
    else:  # 'propose' (validated by solve)
        children = [get_proposals(task, x, y) for y in ys]
    return list(itertools.chain.from_iterable(children))


def _score_candidates(task, x, new_ys, method_evaluate, n_evaluate_sample):
    """Score candidates with the configured evaluator ('vote' or 'value')."""
    if not new_ys:
        # Nothing to score; avoid an LLM call on an empty candidate list.
        return []
    if method_evaluate == 'vote':
        return get_votes(task, x, new_ys, n_evaluate_sample)
    return get_values(task, x, new_ys, n_evaluate_sample)


def _select_ids(n_candidates, values, n_select_sample, method_select):
    """Return indices of the candidates kept for the next search step."""
    ids = list(range(n_candidates))
    if not ids:
        return []
    if method_select == 'sample':
        total = sum(values)
        if total == 0:
            # values / 0 would give NaN probabilities, which np.random.choice
            # rejects; with no signal to go on, draw uniformly instead.
            ps = np.full(len(ids), 1.0 / len(ids))
        else:
            ps = np.array(values) / total
        # Draw n_select_sample ids (with replacement) proportionally to value.
        return np.random.choice(ids, size=n_select_sample, p=ps).tolist()
    # 'greedy': highest value first; sorted() stays stable with reverse=True,
    # so ties keep candidate order.
    return sorted(ids, key=lambda i: values[i], reverse=True)[:n_select_sample]


def solve(method_generate, n_generate_sample,
          prompt_sample, method_evaluate,
          method_select, n_select_sample,
          task, idx, to_print=True, n_evaluate_sample=None):
    """Run breadth-first Tree-of-Thoughts search on puzzle ``idx`` of ``task``.

    At each of ``task.steps`` steps, every partial output in the frontier is
    expanded (``method_generate``: 'sample' or 'propose'). All candidates are
    then scored (``method_evaluate``: 'vote' or 'value'), and
    ``n_select_sample`` of them are kept as the next frontier
    (``method_select``: 'sample' or 'greedy').

    Args:
        method_generate, method_evaluate, method_select: strategy names above.
        n_generate_sample: samples per node for 'sample' generation.
        prompt_sample: 'standard' or 'cot' prompt for 'sample' generation.
        n_select_sample: frontier size kept after each step.
        task: task object (from ``get_task``).
        idx: puzzle index passed to ``task.get_input``.
        to_print: print per-step candidates, values and choices.
        n_evaluate_sample: beams / sequences per evaluation call. When None,
            falls back to the notebook-level ``n_evaluate_sample`` from the
            config cell, which earlier versions of this function read implicitly.

    Returns:
        ``(ys, {'steps': infos})``, where ``ys`` is the final frontier and
        ``infos`` holds one record per step.

    Raises:
        ValueError: on an unrecognised method name, or when no
            ``n_evaluate_sample`` is given or defined globally.
    """
    # Fail fast instead of hitting an unbound local mid-search.
    if method_generate not in ('sample', 'propose'):
        raise ValueError(f'method_generate {method_generate} not recognized')
    if method_evaluate not in ('vote', 'value'):
        raise ValueError(f'method_evaluate {method_evaluate} not recognized')
    if method_select not in ('sample', 'greedy'):
        raise ValueError(f'method_select {method_select} not recognized')
    if n_evaluate_sample is None:
        # Backward compatibility with callers that do not pass it explicitly.
        n_evaluate_sample = globals().get('n_evaluate_sample')
        if n_evaluate_sample is None:
            raise ValueError('n_evaluate_sample was not passed and no notebook-level '
                             'n_evaluate_sample is defined')

    x = task.get_input(idx)  # e.g. '4 5 6 10' for game24
    ys = ['']  # frontier: start from a single empty partial output
    infos = []

    # Breadth: n_generate_sample per node ('sample') or one child per proposed line
    # ('propose'). Depth: task.steps, set per task in tasks/{task}.py.
    for step in range(task.steps):
        new_ys = _generate_candidates(task, x, ys, step, method_generate,
                                      n_generate_sample, prompt_sample)
        values = _score_candidates(task, x, new_ys, method_evaluate, n_evaluate_sample)
        select_ids = _select_ids(len(new_ys), values, n_select_sample, method_select)
        select_new_ys = [new_ys[i] for i in select_ids]

        if to_print:
            # Show candidates ordered by value (safe for an empty candidate list).
            ranked = sorted(zip(new_ys, values), key=lambda pair: pair[1], reverse=True)
            sorted_new_ys = tuple(cand for cand, _ in ranked)
            sorted_values = tuple(val for _, val in ranked)
            print(f'-- new_ys --: {sorted_new_ys}\n-- sol values --: {sorted_values}\n-- choices --: {select_new_ys}\n')

        # Per-step record that ends up in the JSON log.
        infos.append({'step': step, 'x': x, 'ys': ys, 'new_ys': new_ys, 'values': values, 'select_new_ys': select_new_ys})
        ys = select_new_ys

    if to_print:
        print(ys)

    return ys, {'steps': infos}

def naive_solve(n_generate_sample, prompt_sample, task, idx, to_print=True):
    """Baseline without tree search: sample complete outputs directly.

    Draws ``n_generate_sample`` generations for puzzle ``idx`` using the
    ``prompt_sample`` ('standard' or 'cot') prompt, with no stop sequence.
    ``to_print`` is accepted but currently unused.

    Returns:
        ``(generations, {})``: the sampled outputs and an empty info dict.
    """
    puzzle = task.get_input(idx)
    generations = get_samples(task, puzzle, '', n_generate_sample, prompt_sample, stop=None)
    return generations, {}

In [4]:
# --- Model ---------------------------------------------------------------
backend = 'llama-7B'
# NOTE: only used in the log file name; the LLaMA calls in this notebook
# pass do_sample=False, so this temperature is not applied to generation.
temperature = 0.7

# --- Task ----------------------------------------------------------------
task = 'game24'            # 'game24' | 'text' | 'crosswords'
task_file_path = '24.csv'
task_start_index = 100     # first puzzle index (inclusive)
task_end_index = 101       # end of the puzzle range (exclusive)

# --- Search --------------------------------------------------------------
naive_run = False            # True: naive_solve (plain sampling); False: ToT solve
method_generate = 'propose'  # 'sample' | 'propose'
method_evaluate = 'value'    # 'vote' | 'value'
method_select = 'greedy'     # 'sample' | 'greedy'
n_generate_sample = 100      # samples per node; used by 'sample' and naive runs, not 'propose'
n_evaluate_sample = 3        # beams / sequences per evaluation call
n_select_sample = 5          # frontier size kept after each step
prompt_sample = 'cot'        # 'standard' | 'cot'

In [5]:

# Build the task object (e.g. a Game24Task from tasks/game24.py) from the
# identifier string in the config cell. The string `task` is deliberately left
# untouched. This keeps the cell re-runnable (get_task always receives the
# name), and the log path below then uses the task name rather than str() of
# the task object.
task_obj = get_task(task, task_file_path)
logs, cnt_avg, cnt_any = [], 0, 0

# The log file name encodes the run configuration.
if naive_run:
    file = f'logs/{task}/{backend}_{temperature}_naive_{prompt_sample}_sample_{n_generate_sample}_start{task_start_index}_end{task_end_index}.json'
else:
    file = f'logs/{task}/{backend}_{temperature}_{method_generate}{n_generate_sample}_{method_evaluate}{n_evaluate_sample}_{method_select}{n_select_sample}_start{task_start_index}_end{task_end_index}.json'
os.makedirs(os.path.dirname(file), exist_ok=True)

for i in range(task_start_index, task_end_index):
    # Solve: naive (standard / CoT sampling) or Tree-of-Thoughts search.
    if naive_run:
        ys, info = naive_solve(n_generate_sample, prompt_sample, task_obj, idx=i, to_print=True)
    else:
        ys, info = solve(method_generate, n_generate_sample,
                         prompt_sample, method_evaluate,
                         method_select, n_select_sample,
                         task_obj, idx=i, to_print=True)

    # Check every final output with the task's test_output() (defined in tasks/*).
    infos = [task_obj.test_output(i, y) for y in ys]
    info.update({'idx': i, 'ys': ys, 'infos': infos, 'usage_so_far': LLM.llama_usage(backend)})
    logs.append(info)
    # Rewritten after every puzzle, so the file always holds all results so far.
    with open(file, 'w') as f:
        json.dump(logs, f, indent=4)

    # Main metric: average of the outputs' 'r' scores, and whether any output scored.
    accs = [result['r'] for result in infos]
    if accs:  # a puzzle with no outputs contributes 0 instead of dividing by zero
        cnt_avg += sum(accs) / len(accs)
    cnt_any += any(accs)
    print(i, 'sum(accs)', sum(accs), 'cnt_avg', cnt_avg, 'cnt_any', cnt_any, '\n')

n = task_end_index - task_start_index
if n > 0:
    print(cnt_avg / n, cnt_any / n)
else:
    print('No puzzles in range; nothing to average.')

print('usage_so_far', LLM.llama_usage(backend))

file 24.csv
-- new_ys --: ('4 + 5 = 9 (left: 9 11 12)\n', '5 - 4 = 1 (left: 1 5 11)\n', '11 + 1 = 12 (left: 12 1 11)\n', '12 - 1 = 11 (left: 11 12 1)\n', '11 /  2 = 5.5 (left: 5.5\n')
-- sol values --: (0.002, 0.0, 0.0, 0.0, 0.0)
-- choices --: ['4 + 5 = 9 (left: 9 11 12)\n', '5 - 4 = 1 (left: 1 5 11)\n', '11 + 1 = 12 (left: 12 1 11)\n', '12 - 1 = 11 (left: 11 12 1)\n', '11 /  2 = 5.5 (left: 5.5\n']

-- new_ys --: ('5 - 4 = 1 (left: 1 5 11)\n11 + 1 = 12 (left: 12 11)\n', '5 - 4 = 1 (left: 1 5 11)\n11 - 1 = 10 (left: 10 11)\n', '12 - 1 = 11 (left: 11 12 1)\n11 + 12 = 23 (left: 12 23 1)\n', '4 + 5 = 9 (left: 9 11 12)\n11\n', '5 - 4 = 1 (left: 1 5 11)\n11 - 1 =\n', '11 + 1 = 12 (left: 12 1 11)\n12 /  1 = 12 (\n', '12 - 1 = 11 (left: 11 12 1)\n11 / 12 =\n', '11 /  2 = 5.5 (left: 5.5\n5.5 + 5.5 = 11.5 (left: 11.5 5.5)\n', '11 /  2 = 5.5 (left: 5.5\n1 +\n', '4 + 5 = 9 (left: 9 11 12)\n9 + 11 = 20 (left: 12 20)\n', '4 + 5 = 9 (left: 9 11 12)\n11 - 9 = 2 (left: 20 2)\n', '4 + 5 = 9 (left: 9 11

TODO: formatting in jupyter notebook