# Setup

In [1]:
import os
from openai import OpenAI
from dotenv import load_dotenv

# TODO: write a generic LLM interface and add support for other LLMs.
 
# Load variables from a .env file into os.environ (python-dotenv) so that
# OPENAI_API_KEY is available to llm().
load_dotenv()

# One shared OpenAI client, created on first use by llm(). Building a new
# client for every completion (hundreds per evaluation run) is wasted work.
# Re-running this cell resets the cache, so a changed API key is picked up.
_openai_client = None


def llm(prompt, stop=("\n",)):
    """Return the model's continuation of ``prompt``.

    Calls the OpenAI completions endpoint with ``gpt-3.5-turbo-instruct``,
    greedy decoding (``temperature=0``, ``top_p=1``), no frequency/presence
    penalties, and at most 100 generated tokens.

    Args:
        prompt: Text for the model to continue.
        stop: Stop sequence(s) forwarded to the API. Defaults to a single
            newline. A tuple is converted to a list; any other value (list,
            str, ``None``) is passed through unchanged.

    Returns:
        ``response.choices[0].text``, i.e. the generated text.

    Raises:
        KeyError: if ``OPENAI_API_KEY`` is not set in the environment when the
            shared client is first created.
    """
    global _openai_client
    if _openai_client is None:
        _openai_client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])

    # The default is a tuple to avoid a shared mutable default argument;
    # hand the API a plain list as before.
    if isinstance(stop, tuple):
        stop = list(stop)

    response = _openai_client.completions.create(
        model="gpt-3.5-turbo-instruct",
        prompt=prompt,
        temperature=0,
        max_tokens=100,
        top_p=1,
        frequency_penalty=0.0,
        presence_penalty=0.0,
        stop=stop,
    )
    return response.choices[0].text

In [2]:
import requests
import wikienv
import wrappers

# Environment stack, innermost first: WikiEnv, then FeverWrapper restricted
# to the "dev" split, then LoggingWrapper on the outside.
env = wrappers.LoggingWrapper(
    wrappers.FeverWrapper(
        wikienv.WikiEnv(),
        split="dev",
    )
)

def step(env, action, max_attempts=10):
    """Call ``env.step(action)``, retrying when it raises a requests timeout.

    ``env.step`` may perform HTTP requests (it can raise
    ``requests.exceptions.Timeout``), so a transient timeout should not abort
    a long evaluation run.

    Args:
        env: Environment exposing ``step(action)``.
        action: Action string forwarded unchanged to ``env.step``.
        max_attempts: Total number of tries before giving up (default 10).

    Returns:
        Whatever ``env.step`` returns. Callers unpack it as
        ``obs, reward, done, info``.

    Raises:
        ValueError: if ``max_attempts`` is less than 1.
        requests.exceptions.Timeout: if every attempt timed out. The last
            timeout is re-raised instead of silently returning ``None``.
    """
    if max_attempts < 1:
        raise ValueError(f"max_attempts must be >= 1, got {max_attempts}")
    for attempt in range(1, max_attempts + 1):
        try:
            return env.step(action)
        except requests.exceptions.Timeout:
            if attempt == max_attempts:
                raise

# ReAct

In [3]:
import json

# Load the FEVER prompt collection and select the 'webthink_simple3' entry,
# which webthink() below uses as its default prompt prefix.
folder = './prompts/'
prompt_file = 'fever.json'
# JSON text is UTF-8 by specification; don't depend on the platform's default
# encoding (e.g. cp1252 on Windows), which can mis-decode non-ASCII prompts.
with open(os.path.join(folder, prompt_file), 'r', encoding='utf-8') as f:
    prompt_dict = json.load(f)

webthink_prompt = prompt_dict['webthink_simple3']

def webthink(idx=None, prompt=webthink_prompt, to_print=True, max_steps=7):
    """Run one ReAct episode (interleaved Thought / Action / Observation).

    Resets ``env`` to example ``idx``. At each step it asks ``llm`` for a
    thought and an action, executes the action with ``step``, and appends the
    resulting "Thought i / Action i / Observation i" text to the prompt. If the
    episode is not done after ``max_steps`` steps, it is closed with
    ``finish[]``.

    Args:
        idx: Example index forwarded to ``env.reset``.
        prompt: Prompt prefix. The question and every step are appended to a
            local copy, which becomes the returned trajectory.
        to_print: If True, print the question, each step, malformed model
            outputs, and the final info dict.
        max_steps: Maximum number of Thought/Action/Observation steps
            (default 7).

    Returns:
        ``(r, info)``. ``r`` is the reward from the last env step. ``info`` is
        the env's info dict, extended with ``n_calls`` (LLM calls made),
        ``n_badcalls`` (completions that needed a separate action query) and
        ``traj`` (the full prompt plus trajectory text).
    """
    question = env.reset(idx=idx)
    if to_print:
        print(idx, question)
    prompt += question + "\n"
    n_calls, n_badcalls = 0, 0
    done = False
    for i in range(1, max_steps + 1):
        n_calls += 1
        thought_action = llm(prompt + f"Thought {i}:", stop=[f"\nObservation {i}:"])
        try:
            thought, action = thought_action.strip().split(f"\nAction {i}: ")
        except ValueError:
            # The completion did not contain exactly one "Action {i}: " line.
            # Keep its first line as the thought and query the action separately.
            if to_print:
                print('ohh...', thought_action)
            n_badcalls += 1
            n_calls += 1
            thought = thought_action.strip().split('\n')[0]
            action = llm(prompt + f"Thought {i}: {thought}\nAction {i}:", stop=["\n"]).strip()
        # Lower-case the first character so e.g. "Search[...]" matches the
        # lower-case action names the env presumably expects (cf. "finish[]"
        # below). Slicing rather than action[0] keeps an empty action from
        # raising IndexError.
        obs, r, done, info = step(env, action[:1].lower() + action[1:])
        # Drop literal backslash-n sequences (two characters, not real
        # newlines) from the observation text.
        obs = obs.replace('\\n', '')
        step_str = f"Thought {i}: {thought}\nAction {i}: {action}\nObservation {i}: {obs}\n"
        prompt += step_str
        if to_print:
            print(step_str)
        if done:
            break
    if not done:
        # Step budget exhausted without finishing: close the episode with an
        # empty answer.
        obs, r, done, info = step(env, "finish[]")
    if to_print:
        print(info, '\n')
    info.update({'n_calls': n_calls, 'n_badcalls': n_badcalls, 'traj': prompt})
    return r, info

In [5]:
import random
import time
from tqdm import tqdm

# Evaluation configuration.
DEBUG = False           # True: print every trajectory from webthink
N_CANDIDATES = 7405     # size of the index pool (presumably the FEVER dev split size; confirm)
SHUFFLE_SEED = 233      # fixed seed so every run evaluates the same subset
N_EVAL = 500            # number of shuffled examples to evaluate

idxs = list(range(N_CANDIDATES))
random.Random(SHUFFLE_SEED).shuffle(idxs)

rewards = []
infos = []
# The context manager closes the bar even if the run is interrupted
# (e.g. KeyboardInterrupt). rewards/infos still hold the partial results.
with tqdm(idxs[:N_EVAL]) as pbar:
    for i in pbar:
        r, info = webthink(i, to_print=DEBUG)
        # Score by info['em'] (presumably exact match), not the returned reward r.
        rewards.append(info['em'])
        infos.append(info)
        running_acc = sum(rewards) / len(rewards)
        desc_str = f"Accuracy: {running_acc:.3f}"
        pbar.set_description(desc_str)

Accuracy: 0.606:  19%|█▉        | 94/500 [04:06<17:43,  2.62s/it]


KeyboardInterrupt: 