In [40]:
import os
import openai
import dotenv
 

# Load variables from a local .env file (if one exists) into the process
# environment before reading the key. Without this call the `dotenv` import is
# unused and a key stored only in .env is never picked up. Variables that are
# already set in the environment are left untouched by load_dotenv().
dotenv.load_dotenv()
openai.api_key = os.getenv('OPENAI_API_KEY')

def llm(prompt, stop=["\n"]):
    """Request one completion for ``prompt`` and return its text.

    Args:
        prompt: Text sent to the completion endpoint.
        stop: Stop sequences forwarded to the API; the default is a single
            newline.

    Returns:
        The ``text`` field of the first entry in the response's ``choices``.
    """
    # Fixed request settings: temperature 0, at most 100 new tokens.
    request = dict(
        model="text-davinci-002",
        prompt=prompt,
        temperature=0,
        max_tokens=100,
        top_p=1,
        frequency_penalty=0.0,
        presence_penalty=0.0,
        stop=stop,
    )
    completion = openai.Completion.create(**request)
    first_choice = completion["choices"][0]
    return first_choice["text"]

In [2]:
from consts import * 
import sqlenv, wrappers
import requests

# Build the environment inside-out: the base SQL environment is wrapped by the
# WikiSQL wrapper, which is in turn wrapped by the logging wrapper.
env = wrappers.LoggingWrapper(
    wrappers.WikiSQLWrapper(
        sqlenv.SQLEnv()
    )
)

def step(env, action, max_attempts=5):
    """Call ``env.step(action)``, retrying when the request times out.

    Args:
        env: Object exposing a ``step(action)`` method.
        action: Action forwarded unchanged to ``env.step``.
        max_attempts: Total number of tries before giving up (default 5).

    Returns:
        Whatever ``env.step(action)`` returns on the first try that succeeds.

    Raises:
        ValueError: If ``max_attempts`` is smaller than 1.
        requests.exceptions.Timeout: If every attempt timed out. The earlier
            version fell out of its loop and returned ``None`` in that case,
            which callers that unpack the result could not handle.
    """
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")
    for attempt in range(1, max_attempts + 1):
        try:
            return env.step(action)
        except requests.exceptions.Timeout:
            # Swallow the timeout only while retries remain; on the final
            # attempt let it propagate so the failure is explicit.
            if attempt == max_attempts:
                raise

  from .autonotebook import tqdm as notebook_tqdm
Found cached dataset wikisql (/Users/byronzhang/.cache/huggingface/datasets/wikisql/default/0.1.0/7037bfe6a42b1ca2b6ac3ccacba5253b1825d31379e9cc626fc79a620977252d)
100%|██████████| 3/3 [00:00<00:00, 209.47it/s]


In [3]:
# Task description that opens every prompt sent to the model.
instruction = """You are given the header of a SQL table named df and a query you need to find an answer to. Use logical step-by-step Thought, Action, and Observation to get to the correct answer. In your actions, you can issue a SQL query on the table to fetch the next observation. Return your answer in finish[<answer>].
"""

# Few-shot demonstrations (presumably provided by `from consts import *`).
examples = [
    EXAMPLE_REGULAR,
    EXAMPLE_AND,
    EXAMPLE_MAX,
    EXAMPLE_EMPTY,
]

# Append each demonstration to the instruction, filling its `i` placeholder
# with a 1-based example number.
wikisql_prompt = instruction
for j, example in enumerate(examples):
    wikisql_prompt = wikisql_prompt + example.format(i=j + 1)

def wikisql(idx=1, prompt=wikisql_prompt, to_print=True, max_steps=4):
    """Run one Thought/Action/Observation episode for a single question.

    Args:
        idx: Index passed to ``env.reset`` to select the question.
        prompt: Text placed in front of the question; defaults to the
            instruction plus numbered examples held in ``wikisql_prompt``.
        to_print: When true, print the question, every step and the final
            ``info`` dict.
        max_steps: Maximum number of Thought/Action rounds before the episode
            is closed with an empty ``finish[]`` action (default 4).

    Returns:
        The ``info`` dict returned by the last environment step, extended
        with ``n_calls`` (number of LLM calls), ``n_badcalls`` (replies that
        could not be split into a thought and an action) and ``traj`` (the
        prompt including the question and every step taken).
    """
    question = env.reset(idx=idx)
    if to_print:
        print(idx, question)
    prompt += question
    n_calls, n_badcalls = 0, 0
    # Defined up front so the post-loop check works even if no round runs.
    done = False
    for i in range(1, max_steps + 1):
        n_calls += 1
        thought_action = llm(prompt + f"Thought {i}:", stop=[f"\nObservation {i}:"])
        try:
            thought, action = thought_action.strip().split(f"\nAction {i}: ")
        except ValueError:
            # The reply did not split into exactly one thought and one action.
            # Only the unpacking error is caught, so KeyboardInterrupt and
            # unrelated failures are no longer swallowed. Keep the first line
            # as the thought and ask the model again for the action alone.
            print('ohh...', thought_action)
            n_badcalls += 1
            n_calls += 1
            thought = thought_action.strip().split('\n')[0]
            action = llm(prompt + f"Thought {i}: {thought}\nAction {i}:", stop=["\n"]).strip()
        # Lower-case only the first character of the action. Slicing (rather
        # than indexing) keeps this from raising IndexError on an empty action.
        obs, r, done, info = step(env, action[:1].lower() + action[1:])
        # Drop literal backslash-n sequences from the observation text.
        obs = obs.replace('\\n', '')
        step_str = f"Thought {i}: {thought}\nAction {i}: {action}\nObservation {i}: {obs}\n"
        prompt += step_str
        if to_print:
            print(step_str)
        if done:
            break
    if not done:
        # Step budget exhausted: close the episode with an empty answer.
        obs, r, done, info = step(env, "finish[]")
    if to_print:
        print(info, '\n')
    info.update({'n_calls': n_calls, 'n_badcalls': n_badcalls, 'traj': prompt})
    return info

In [15]:
# Manual check: reset the environment with idx=2, then send one SQL action
# straight to env.step (this bypasses the retrying `step` helper).
question = env.reset(idx=2)
env.step('sql[SELECT * FROM df WHERE UPPER([Location]) = UPPER("Aspen, USA")]')

    Season         Date                   Location    Discipline Place
0     2009  20 Nov 2008                Aspen , USA  Giant Slalom   1st
1     2010  12 Dec 2009               Åre , Sweden  Giant Slalom   1st
2     2011  27 Nov 2010                 Aspen, USA  Giant Slalom   1st
3     2011  12 Dec 2010    St.Moritz , Switzerland  Giant Slalom   1st
4     2011  28 Dec 2010        Semmering , Austria  Giant Slalom   1st
5     2012  28 Dec 2011            Lienz , Austria  Giant Slalom   3rd
6     2012  21 Jan 2012   Kranjska Gora , Slovenia  Giant Slalom   1st
7     2012  12 Feb 2012           Soldeu , Andorra  Giant Slalom   1st
8     2013   9 Dec 2012    St. Moritz, Switzerland  Giant Slalom   3rd
9     2013  16 Dec 2012        Courchevel , France  Giant Slalom   3rd
10    2013  28 Dec 2012         Semmering, Austria  Giant Slalom   3rd
11    2013  17 Mar 2013  Lenzerheide , Switzerland  Giant Slalom   2nd


('|   Season | Date        | Location   | Discipline   | Place   |\n|---------:|:------------|:-----------|:-------------|:--------|\n|     2011 | 27 Nov 2010 | Aspen, USA | Giant Slalom | 1st     |',
 0,
 False,
 {'steps': 1, 'answer': None})

In [5]:
# Run up to 100 episodes with printing disabled and collect one info dict per
# episode.
# NOTE(review): each episode is reset with the loop position `i`, while the
# value stored under "question_idx" is `idx` from RANDOM_QUESTION_INDICES —
# confirm whether `wikisql(idx=idx, ...)` was intended.
infos = []
for i, idx in zip(range(100), RANDOM_QUESTION_INDICES):
    info = wikisql(idx=i, to_print=False)
    info.update(question_idx=idx)
    infos.append(info)

In [9]:
import json

# Persist the collected episode results as JSON text.
with open("results_react.json", 'w') as f:
    f.write(json.dumps(infos))

In [11]:
import pickle
# Persist the same results in pickle format as well.
with open("results_react.pkl", 'wb') as f:
    pickle.Pickler(f).dump(infos)

In [37]:
# Display the stored result at list position 25 (steps, answer, call counts
# and the full trajectory text).
infos[25]

{'steps': 3,
 'answer': '6',
 'n_calls': 3,
 'n_badcalls': 0,
 'traj': 'You are given the header of a SQL table named df and a query you need to find an answer to. Use logical step-by-step Thought, Action, and Observation to get to the correct answer. In your actions, you can issue a SQL query on the table to fetch the next observation. Return your answer in finish[<answer>].\n\nEXAMPLE 1:\nHeader: [\'Player\', \'No.\', \'Nationality\', \'Position\', \'Years in Toronto\', \'School/Club Team\']\nQuestion: [What is terrence ross\' nationality]\nThought 1: I need to find the information related to terrence ross.\nAction 1: sql[SELECT * FROM df WHERE UPPER([Player]) = UPPER("Terrence Ross")]\nObservation 1: \n| Player        |   No. | Nationality   | Position   | Years in Toronto   | School/Club Team   |\n|:--------------|------:|:--------------|:-----------|:-------------------|:-------------------|\n| Terrence Ross |    31 | United States | Guard      | 2012-present       | Washington   