In [19]:
import os
import openai
import dotenv
 

# Credentials come from the environment (never hard-coded in the notebook).
openai.api_key = os.environ.get('OPENAI_API_KEY')
# Model name that `llm()` reads to decide which OpenAI endpoint to call.
model = "text-davinci-003"


def _is_chat_model(name):
    """Return True if `name` must be called through the chat-completions endpoint."""
    if name.startswith("gpt-4"):
        return True
    # gpt-3.5-turbo is chat-only, except the "-instruct" variants which use Completion.
    return name.startswith("gpt-3.5-turbo") and "instruct" not in name


def llm(prompt, stop=None):
    """Return the model's generated text for `prompt` using deterministic sampling.

    The endpoint is picked from the module-level `model` name: chat models go
    through `openai.ChatCompletion`, all others through `openai.Completion`.

    Args:
        prompt: prompt text. For chat models a plain string is wrapped into a
            single user message; a list is passed through as an already-built
            message list.
        stop: list of stop sequences; None means "stop at the first newline".

    Returns:
        The generated text (the message content for chat models).
    """
    if stop is None:
        # Fresh list per call instead of a shared mutable default.
        stop = ["\n"]
    sampling = dict(
        temperature=0,
        max_tokens=100,
        top_p=1,
        frequency_penalty=0.0,
        presence_penalty=0.0,
        stop=stop,
    )
    if _is_chat_model(model):
        # The chat endpoint needs a list of role/content messages, not a bare string.
        messages = prompt if isinstance(prompt, list) else [{"role": "user", "content": prompt}]
        response = openai.ChatCompletion.create(model=model, messages=messages, **sampling)
        return response['choices'][0]['message']['content']

    response = openai.Completion.create(model=model, prompt=prompt, **sampling)
    return response["choices"][0]["text"]

In [20]:
from consts import * 
from consts_actonly import *
from consts_thoughtonly import *
from consts_standard import *
import sqlenv, wrappers
import requests

# Environment stack, innermost first: SQLEnv -> WikiSQLWrapper -> LoggingWrapper (outermost).
env = wrappers.LoggingWrapper(wrappers.WikiSQLWrapper(sqlenv.SQLEnv()))

def step(env, action, max_attempts=5):
    """Call `env.step(action)`, retrying when the request times out.

    Args:
        env: object exposing `step(action)`.
        action: action string, forwarded unchanged.
        max_attempts: total number of tries before giving up (default 5).

    Returns:
        Whatever `env.step` returns (callers unpack it as obs, r, done, info).

    Raises:
        requests.exceptions.Timeout: the last timeout, if every attempt timed out.
        ValueError: if `max_attempts` is less than 1.
    """
    if max_attempts < 1:
        raise ValueError(f"max_attempts must be >= 1, got {max_attempts}")
    last_error = None
    for _ in range(max_attempts):
        try:
            return env.step(action)
        except requests.exceptions.Timeout as err:
            last_error = err
    # Surface the real cause instead of returning None, which would make the
    # caller's tuple-unpacking fail with an unrelated TypeError.
    raise last_error

Found cached dataset wikisql (/Users/byronzhang/.cache/huggingface/datasets/wikisql/default/0.1.0/7037bfe6a42b1ca2b6ac3ccacba5253b1825d31379e9cc626fc79a620977252d)
100%|██████████| 3/3 [00:00<00:00, 97.87it/s]


In [22]:
# Prompt preambles. Each is defined at column 0 and ends in exactly one newline,
# so every mode joins its few-shot examples the same way. (An indented
# triple-quoted literal would carry its indentation into the prompt.)
instruction = """You are given the header of a SQL table named df and a question you need to find an answer to. Use logical step-by-step Thought, Action, and Observation to get to the correct answer. In your actions, you can issue a SQL query on the table to fetch the next observation. Return your answer in finish[<answer>].
"""
instruction_cot = """You are given the header of a SQL table named df and a question you need to find an answer to. Use logical step-by-step Thought to get to the correct answer. Return your answer in finish[<answer>].
"""
instruction_standard = """You are given the header of a SQL table named df and a question you need to find an answer to. Return your answer in finish[<answer>].
"""

# change the index to the mode you want to be in
modes = ["react", "act_only", "thought_only", "standard"][0]

# Per mode: the instruction, the few-shot examples, and ACT_ONLY. When ACT_ONLY
# is True, `wikisql` prompts without a "Thought i:" prefix and parses only
# "Action i:" from the reply.
if modes == "react":
    examples = [EXAMPLE_AND, EXAMPLE_COUNT, EXAMPLE_EMPTY, EXAMPLE_AVG]
    ACT_ONLY = False
elif modes == "act_only":
    examples = [EXAMPLE_REGULAR_A, EXAMPLE_AND_A, EXAMPLE_COUNT_A, EXAMPLE_EMPTY]
    ACT_ONLY = True
elif modes == "thought_only":
    instruction = instruction_cot
    examples = [EXAMPLE_REGULAR_T, EXAMPLE_AND_T, EXAMPLE_COUNT_T, EXAMPLE_EMPTY]
    ACT_ONLY = False
elif modes == "standard":
    instruction = instruction_standard
    examples = [EXAMPLE_REGULAR_S, EXAMPLE_AND_S, EXAMPLE_COUNT_S, EXAMPLE_EMPTY]
    ACT_ONLY = True
else:
    raise ValueError(
        f"Unknown mode {modes!r}; expected one of 'react', 'act_only', 'thought_only', 'standard'"
    )

# Few-shot prompt: instruction followed by the examples, numbered from 1.
wikisql_prompt = instruction + "".join(
    example.format(i=number) for number, example in enumerate(examples, start=1)
)

def _lowercase_first(action):
    """Lower-case the first character of `action`; an empty string is returned unchanged."""
    return action[:1].lower() + action[1:]


def wikisql(idx=1, prompt=wikisql_prompt, to_print=True, act_only=ACT_ONLY, cot=False, max_steps=4):
    """Run one Thought/Action/Observation episode on question `idx` and return its info dict.

    For each step i the LLM continues the running prompt. Its reply is split into
    a thought and an action, the action is executed with `step(env, ...)`, and the
    observation is appended to the prompt. If the episode is not done after
    `max_steps` steps, `finish[]` is sent to close it.

    Args:
        idx: question index passed to `env.reset`.
        prompt: few-shot prefix. The default is the `wikisql_prompt` that existed
            when this cell ran (Python evaluates defaults at definition time).
        to_print: print the question, every step, and the final info.
        act_only: if True, the model is expected to emit only "Action i:" lines.
        cot: unused; kept for backward compatibility.
        max_steps: maximum number of rounds (default 4).

    Returns:
        The `info` dict from the final env step, extended with `n_calls`,
        `n_badcalls`, and the full prompt trajectory under `traj`.
    """
    question = env.reset(idx=idx)
    if to_print:
        print(idx, question)
    prompt += question
    n_calls, n_badcalls = 0, 0
    done = False
    for i in range(1, max_steps + 1):
        n_calls += 1
        if act_only:
            thought_action = llm(prompt, stop=[f"\nObservation {i}:"])
        else:
            thought_action = llm(prompt + f"Thought {i}:", stop=[f"\nObservation {i}:"])
        separator = f"Action {i}: " if act_only else f"\nAction {i}: "
        try:
            thought, action = thought_action.strip().split(separator)
        except ValueError:
            # The reply did not contain exactly one action marker: keep its first
            # line as the thought and re-prompt for the action alone.
            print('ohh...', thought_action)
            n_badcalls += 1
            n_calls += 1
            thought = thought_action.strip().split('\n')[0]
            # NOTE(review): in act_only mode this re-prompt still inserts a
            # "Thought i:" line, which act-only prompts otherwise never contain.
            action = llm(prompt + f"Thought {i}: {thought}\nAction {i}:", stop=["\n"]).strip()
        obs, r, done, info = step(env, _lowercase_first(action))
        obs = obs.replace('\\n', '')
        if act_only:
            step_str = f"Action {i}: {action}\nObservation {i}: \n{obs}\n"
        else:
            step_str = f"Thought {i}: {thought}\nAction {i}: {action}\nObservation {i}: \n{obs}\n"
        prompt += step_str
        if to_print:
            print(step_str)
        if done:
            break
    if not done:
        obs, r, done, info = step(env, "finish[]")
    if to_print:
        print(info, '\n')
    info.update({'n_calls': n_calls, 'n_badcalls': n_badcalls, 'traj': prompt})
    return info

In [23]:
# Evaluate 100 episodes; each result records the matching entry of
# RANDOM_QUESTION_INDICES under "question_idx".
# NOTE(review): the episode is started with the loop *position* as `idx`, while
# "question_idx" stores RANDOM_QUESTION_INDICES[position]. This is only consistent
# if the env maps positions to those indices internally; otherwise `idx=question_idx`
# was probably intended. Confirm against WikiSQLWrapper.
infos = []
for position, question_idx in zip(range(100), RANDOM_QUESTION_INDICES):
    episode_info = wikisql(idx=position, to_print=False, act_only=ACT_ONLY)
    episode_info["question_idx"] = question_idx
    infos.append(episode_info)

In [28]:
import json

# NOTE(review): the output stem is hard-coded and does not follow `modes` or
# `model`; presumably it should be updated when those change, or earlier results
# will be overwritten.
result_name = "./results/results_react_no_simple_003"
# Make sure the output directory exists so the dump works on a fresh checkout.
os.makedirs(os.path.dirname(result_name), exist_ok=True)

with open(f"{result_name}.json", 'w') as f:
    json.dump(infos, f)

In [25]:
import pickle
# Persist `infos` as a pickle under the same `result_name` stem as the JSON dump.
with open(result_name + ".pkl", mode='wb') as f:
    pickle.dump(infos, f)

In [29]:
# Reload the pickled results written above. pickle.load can execute arbitrary
# code, so only use it on files this notebook produced itself.
with open(result_name + ".pkl", mode='rb') as f:
    infos_loaded = pickle.load(f)

In [30]:
# Inspect one episode record: the env's final info dict, plus the 'n_calls',
# 'n_badcalls' and 'traj' (full prompt trajectory) keys added by `wikisql`, and
# 'question_idx' added by the evaluation loop.
infos_loaded[0]

{'steps': 3,
 'answer': '<name>',
 'n_calls': 3,
 'n_badcalls': 0,
 'traj': 'You are given the header of a SQL table named df and a question you need to find an answer to. Use logical step-by-step Thought, Action, and Observation to get to the correct answer. In your actions, you can issue a SQL query on the table to fetch the next observation. Return your answer in finish[<answer>].\n\nEXAMPLE 1:\nHeader: [\'Scheme\', \'Tariff code\', \'BTs retail price (regulated)\', \'Approx premium\', \'Prefixes\']\nQuestion: [What prefixes are priced at pence per minute, fixed at all times with a premium of 3p/min?]\nThought 1: I need to find the rows in the table where the "Scheme" is "Pence per minute, fixed at all times," and the "Approx premium" is "3p/min".\nAction 1: sql[SELECT * FROM df WHERE UPPER([Scheme]) = UPPER("Pence per minute, fixed at all times") AND UPPER([Approx premium]) = UPPER("3p/min")]\nObservation 1: \n| Scheme                               | Tariff code   | BTs retail pric