In [13]:
import getpass
import glob
import itertools
import json
import os
import random
import re
from threading import Thread, BoundedSemaphore, Lock

import openai
import torch
from tqdm import tqdm

# Prefer the OPENAI_API_KEY environment variable; otherwise prompt with getpass so the
# key is not echoed into the cell output the way a plain input() prompt would be.
openai.api_key = os.environ.get("OPENAI_API_KEY") or getpass.getpass("Enter your OpenAI API key: ")

In [2]:
# Maps a 1-based option index (1..10) or a bare option letter ("A".."J") to the
# parenthesised label form "(A)".."(J)". Integer keys come first, then letter keys.
option_map = {
    **{position: f"({letter})" for position, letter in enumerate("ABCDEFGHIJ", start=1)},
    **{letter: f"({letter})" for letter in "ABCDEFGHIJ"},
}

def get_option(ans):
    """Normalise an answer to its "(X)" option label when possible.

    Returns ``option_map[ans]`` when ``ans`` is a key of ``option_map``;
    any other value is returned unchanged.
    """
    return option_map.get(ans, ans)

In [3]:
def generate(prompt, max_tokens=256, temperature=0.0, model="gpt-3.5-turbo"):
    """Send ``prompt`` to an OpenAI model and return the generated text.

    "gpt-3.5-turbo" and "gpt-4" are called through ``openai.ChatCompletion`` with the
    prompt as a single user message. Any other model name is called through
    ``openai.Completion`` with ``max_tokens`` capped at 1024.

    Args:
        prompt: Text sent to the model.
        max_tokens: Upper bound on generated tokens (capped at 1024 on the completion path).
        temperature: Sampling temperature passed to the API.
        model: Model name; selects which endpoint is used.

    Returns:
        The generated text of the first choice.

    Raises:
        Exception: "Failed to generate" when all three attempts fail, on either
            endpoint. The last underlying error is chained as ``__cause__``.
    """
    if model in ["gpt-3.5-turbo", "gpt-4"]:
        params = {
            "model": model,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "messages": [{"role": "user", "content": prompt}]
        }

        def request():
            return openai.ChatCompletion.create(**params)["choices"][0]["message"]["content"]
    else:
        # Older models use the completion API, where max_tokens is capped at 1024.
        params = {
            "model": model,
            "max_tokens": min(max_tokens, 1024),
            "temperature": temperature,
            "prompt": prompt
        }

        def request():
            return openai.Completion.create(**params)["choices"][0]["text"]

    last_error = None
    for _ in range(3):
        try:
            return request()
        except Exception as error:
            # Catch Exception (not a bare except) so KeyboardInterrupt still stops the run.
            last_error = error
    # Both endpoints now fail loudly instead of the completion path returning None.
    raise Exception("Failed to generate") from last_error

In [4]:
def get_task(task):
    """Load ``bbh/{task}.json`` and split its examples into train/val/test.

    Each example becomes ``{'question': <input>, 'answer': <target>}``. The first
    five examples form the train split, the next five the validation split, and
    everything after that the test split.

    Returns:
        A ``(train, val, test)`` tuple of lists of sample dicts.
    """
    with open(f"bbh/{task}.json") as f:
        data = json.load(f)

    # dyck_languages only: drop the spaces from the bracket sequence (and from the
    # target) to avoid unnecessary tokenization issues.
    if task == "dyck_languages":
        for example in data["examples"]:
            prefix, sequence = example["input"].split("Input: ")
            example["input"] = f"{prefix}Input: {sequence.replace(' ', '')}"
            example["target"] = example["target"].replace(" ", "")

    samples = [
        {'question': example['input'], 'answer': example['target']}
        for example in data['examples']
    ]
    return samples[:5], samples[5:10], samples[10:]

In [5]:
def get_tool(task):
    """Load the tool prompt for ``task`` from ``tools/{task}.json``.

    The file holds a list of chat messages. The content of the last message is
    the wrapper prompt, and the first ```python fenced block inside it is the
    tool's source code.

    Returns:
        ``(wrapper, func)``: the full prompt text and the extracted tool source.

    Raises:
        IndexError: if the last message contains no ```python fenced block.
    """
    # Use a context manager so the file handle is closed instead of leaked.
    with open(f"tools/{task}.json") as f:
        message = json.load(f)
    wrapper = message[-1]["content"]
    func = re.findall(r"```python\n(.*?)\n```", wrapper, re.DOTALL)[0]
    return wrapper, func

In [6]:
def get_ans(sample, model="gpt-3.5-turbo", print_func_call=False, task=None):
    """Ask the model for a tool call, execute it, and return the normalised answer.

    Reads the module-level ``wrapper`` (tool prompt) and ``func`` (tool source).
    The model's reply must contain a ```python fenced block that assigns the
    result to a variable whose name starts with ``ans``.

    Args:
        sample: Dict with at least a "question" string.
        model: Model name forwarded to ``generate``.
        print_func_call: When true, print the extracted call before running it.
        task: Task name; "schedule_meeting" gets extra answer formatting.

    Returns:
        The answer passed through ``get_option``.

    Raises:
        IndexError: if the reply has no ```python block or no ``ans... =`` assignment.
        Exception: anything raised by ``generate`` or by the executed code.
    """
    func_call = generate(wrapper + "\n\nQuestion: " + sample["question"], model=model)
    func_call = re.findall(r"```python\n(.*?)\n```", func_call, re.DOTALL)[0]

    if print_func_call:
        print("=== Function call ===")
        print(func_call)
        print("=====================")

    # SECURITY: this executes model-generated code without sandboxing; only run it
    # in an environment where that is acceptable.
    # Execute in a per-call copy of the globals. The code can still see the notebook's
    # imports and helpers, but its assignments stay private to this call. Concurrent
    # calls from worker threads can therefore no longer read each other's answer.
    namespace = dict(globals())
    exec(func + "\n" + func_call, namespace)
    answer_var = re.findall(r"(ans.*?) =", func_call, re.DOTALL)[-1]
    ans = namespace[answer_var]

    # For multiple-choice questions, map a free-text answer to the label of the first
    # option line that contains it.
    if "Options:" in sample["question"] and ans not in option_map.keys() and ans not in option_map.values():
        options = re.findall(r"Options:(.*)", sample["question"], re.DOTALL)[0].strip().split("\n")
        for option in options:
            if ans in option:
                ans = option.split(" ")[0]
                break

    if task == "schedule_meeting":
        if ans is None:
            ans = "No time slot works."
        elif isinstance(ans, list) or isinstance(ans, tuple):
            ans = f"{ans[0]} - {ans[1]}"
    return get_option(ans)

### Test *Tool User*

In [7]:
task = "word_sorting"
# Other task names that can be substituted above:
#   "tracking_shuffled_objects_five_objects"
#   "tracking_shuffled_objects_seven_objects"
#   "logical_deduction_five_objects"
#   "logical_deduction_seven_objects"
#   "dyck_languages"
#   "word_sorting"
#   "chinese_remainder_theorem"
# get_task returns the (train, val, test) splits for the selected task.
train, val, test = get_task(task)
# `wrapper` (tool prompt) and `func` (tool source) are read from module scope by get_ans.
wrapper, func = get_tool(task)

In [None]:
# Evaluate the tool user on the test split, with at most 8 requests in flight.
model = "gpt-3.5-turbo"

tot = 0
correct = 0

pool = BoundedSemaphore(8)   # caps the number of concurrently running worker threads
lock = Lock()                # guards the shared counters and the progress-bar text
pbar = tqdm(test)

def run(sample):
    """Answer one sample and fold the outcome into the running accuracy."""
    global tot, correct
    try:
        try:
            ans = get_ans(sample, model=model, task=task)
        except Exception:
            # Best effort: any failure (API error, bad tool call, ...) counts as a wrong answer.
            ans = "Error"
        with lock:
            tot += 1
            if str(ans) == str(sample["answer"]):
                correct += 1
            acc = correct / tot
            pbar.set_description(f"Accuracy: {acc:.4f}")
    finally:
        # Always free the slot, otherwise the submitting loop would block forever.
        pool.release()

threads = []
for sample in pbar:
    pool.acquire()
    thread = Thread(target=run, args=(sample,))
    threads.append(thread)
    thread.start()

# Wait for the in-flight workers so the reported accuracy covers every sample.
for thread in threads:
    thread.join()
if tot:
    print(f"Final accuracy: {correct / tot:.4f} ({correct}/{tot})")

### Test *Dispatcher*

In [101]:
tasks = ["word_sorting", "dyck_languages", "chinese_remainder_theorem", "logical_deduction_five_objects", "tracking_shuffled_objects_five_objects", "schedule_meeting"]

# Load every task's tool prompt and source up front, keyed by task name.
# The loop uses its own variable names so it does not overwrite the notebook-level
# `task`, `wrapper` and `func` that get_ans and the Tool User evaluation read.
wrappers = {}
funcs = {}
for task_name in tasks:
    task_wrapper, task_func = get_tool(task_name)
    wrappers[task_name] = task_wrapper
    funcs[task_name] = task_func

In [126]:
# Draw two independent random samples of 4 tasks each. This is not a disjoint split:
# the two samples can overlap. `train_tasks` provide the router's example questions and
# `test_tasks` provide the evaluation questions.
# NOTE(review): no random seed is set in the visible cells, so the draw changes per run.
train_tasks = random.sample(tasks, 4)
test_tasks = random.sample(tasks, 4)
# Tasks that appear in both samples
print(set(train_tasks) & set(test_tasks))

{'schedule_meeting', 'logical_deduction_five_objects'}


In [127]:
# Build the router prompt: one example question per training task.
router = "Here are several questions from different tasks:\n\n"

# The loop variable is deliberately not named `task`, and the prompt is not assigned to
# `wrapper`, so the notebook-level `task`/`wrapper` used by get_ans stay intact.
for train_task in train_tasks:
    # Keep only the first example question of the tool prompt, without its solution/answer.
    example = wrappers[train_task].split('Question: ')[1].split('Solution:')[0].split('Answer: ')[0].strip()
    question = "Question: " + example
    router += f"Task: {train_task}\n\n{question}\n\n" + "="*3 + "\n\n"

In [128]:
# Collect the test split of every task in `test_tasks`, tagging each sample with its task.
# The loop variable is not named `task` so the notebook-level `task` is left untouched.
dataset = []
for test_task in test_tasks:
    held_out = get_task(test_task)[2]  # index 2 is the test split
    dataset.extend({**sample, "task": test_task} for sample in held_out)

In [129]:
# Shuffle in place so the first 100 samples evaluated below are a mix of the test tasks.
# NOTE(review): no random seed is set in the visible cells, so the order changes per run.
random.shuffle(dataset)

In [130]:
# Instruction appended after the router examples. `{question}` is filled in per sample;
# the doubled braces render a literal "{task}" placeholder in the final prompt.
router_usage_template = (
    "Classify the following question into one task "
    "(classify as unknown if cannot be classified into any existing task):\n\n"
    "{question}\n\n"
    "Reply in the format:\n"
    "Task: {{task}}"
)

In [None]:
# Evaluate the dispatcher: classify each question into a known task or "unknown",
# with at most 8 requests in flight.
model = "gpt-3.5-turbo"

tot = 0
correct = 0
task_correct = 0

pool = BoundedSemaphore(8)   # caps the number of concurrently running worker threads
lock = Lock()                # guards the shared counters, prints and progress-bar text
pbar = tqdm(dataset[:100])

def run(sample):
    """Route one sample and fold the outcome into the running task accuracy.

    A prediction is correct when it names the sample's task (if that task is in
    ``train_tasks``) or is "unknown" (if the task is not in ``train_tasks``).
    """
    global tot, correct, task_correct
    try:
        try:
            response = generate(router+router_usage_template.format(question=sample["question"]), model=model)
            # Strip so trailing whitespace in the reply does not turn a valid label into "unknown".
            task = re.findall(r"Task: (.*)", response)[0].strip()
        except Exception:
            # Best effort: an API failure or unparsable reply counts as "unknown".
            task = "unknown"
        if task not in tasks:
            task = "unknown"
        expected = sample["task"] if sample["task"] in train_tasks else "unknown"
        with lock:
            tot += 1
            if task == expected:
                task_correct += 1
            else:
                print(sample["question"])
                print(task)
                print("="*20)
            task_acc = task_correct / tot
            pbar.set_description(f"Task accuracy: {task_acc:.4f}")
    finally:
        # Always free the slot, otherwise the submitting loop would block forever.
        pool.release()

threads = []
for sample in pbar:
    pool.acquire()
    thread = Thread(target=run, args=(sample,))
    threads.append(thread)
    thread.start()

# Wait for the in-flight workers so the reported accuracy covers every sample.
for thread in threads:
    thread.join()
if tot:
    print(f"Final task accuracy: {task_correct / tot:.4f} ({task_correct}/{tot})")