# Setup

In [2]:
from dotenv import load_dotenv

# Populate os.environ from a local .env file (python-dotenv). Presumably this
# supplies the API keys needed by the `llms` wrappers imported later — confirm.
load_dotenv()

True

In [3]:
import requests
import wikienv
import wrappers

# Base wiki environment, restricted to the FEVER dev split, with every step logged.
env = wrappers.LoggingWrapper(
    wrappers.FeverWrapper(wikienv.WikiEnv(), split="dev")
)

def step(env, action, max_attempts=10):
    """Call ``env.step(action)``, retrying when it times out.

    ``env.step`` may raise ``requests.exceptions.Timeout`` (presumably from
    network lookups made by the wiki environment); each timeout is retried
    until ``max_attempts`` tries have been made in total.

    Args:
        env: object exposing ``step(action)``.
        action: action string forwarded unchanged to ``env.step``.
        max_attempts: total number of tries before giving up (default 10).

    Returns:
        Whatever ``env.step(action)`` returns.

    Raises:
        ValueError: if ``max_attempts`` is less than 1.
        requests.exceptions.Timeout: if every attempt timed out. The error is
            re-raised rather than returning ``None``, because callers unpack
            the result and would otherwise fail with a misleading TypeError.
    """
    if max_attempts < 1:
        raise ValueError(f"max_attempts must be >= 1, got {max_attempts}")
    for attempt in range(1, max_attempts + 1):
        try:
            return env.step(action)
        except requests.exceptions.Timeout:
            if attempt == max_attempts:
                raise

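# ReAct

Run the ReAct agent on FEVER claims. At each step the model writes a Thought and an Action; the Action is executed in `env`, and the resulting Observation is appended to the prompt before the next step.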

In [4]:
import json

# Few-shot prompts for FEVER, keyed by prompt name; the ReAct agent below
# uses the 'webthink_simple3' entry as its default prompt.
folder = './prompts/'
prompt_file = 'fever.json'
# Decode explicitly as UTF-8 (the JSON interchange encoding): without
# `encoding`, open() falls back to the platform's locale encoding, which can
# mis-decode non-ASCII prompt text.
with open(folder + prompt_file, 'r', encoding='utf-8') as f:
    prompt_dict = json.load(f)

webthink_prompt = prompt_dict['webthink_simple3']

def webthink(llm_func, idx=None, prompt=webthink_prompt, to_print=True, max_steps=7):
    """Run one ReAct episode (Thought / Action / Observation loop) on a claim.

    Args:
        llm_func: completion function called as ``llm_func(text, stop=[...])``;
            assumed to return the generated continuation as a string.
        idx: claim index passed to ``env.reset``.
        prompt: few-shot prefix. The claim and every step are appended to a
            local copy, so the default prompt is never modified.
        to_print: echo the claim, every step and the final info dict.
        max_steps: maximum number of Thought/Action/Observation steps before
            the episode is force-finished with ``finish[]`` (default 7).

    Returns:
        ``(r, info)``: the reward from the last ``env`` step, and the env's info
        dict extended with ``n_calls``, ``n_badcalls`` and ``traj`` (the full
        prompt including all steps).
    """
    question = env.reset(idx=idx)
    if to_print:
        print(idx, question)
    prompt += question + "\n"
    n_calls, n_badcalls = 0, 0
    done = False  # ensures the finish[] fallback runs even if max_steps < 1
    for i in range(1, max_steps + 1):
        n_calls += 1
        thought_action = llm_func(prompt + f"Thought {i}:", stop=[f"\nObservation {i}:"])
        try:
            thought, action = thought_action.strip().split(f"\nAction {i}: ")
        except ValueError:
            # The model did not emit exactly one "Action i:" line. Keep the first
            # line as the thought and request the action with a separate call
            # that stops at the end of the line.
            print('ohh...', thought_action)
            n_badcalls += 1
            n_calls += 1
            thought = thought_action.strip().split('\n')[0]
            action = llm_func(prompt + f"Thought {i}: {thought}\nAction {i}:", stop=["\n"]).strip()
        # Lower-case the first character (e.g. "Search[...]" -> "search[...]");
        # slicing instead of indexing keeps an empty action from raising IndexError.
        obs, r, done, info = step(env, action[:1].lower() + action[1:])
        # Drop literal backslash-n sequences (two characters), not real newlines.
        obs = obs.replace('\\n', '')
        step_str = f"Thought {i}: {thought}\nAction {i}: {action}\nObservation {i}: {obs}\n"
        prompt += step_str
        if to_print:
            print(step_str)
        if done:
            break
    if not done:
        # Step budget exhausted without an answer: submit an empty one.
        obs, r, done, info = step(env, "finish[]")
    if to_print:
        print(info, '\n')
    info.update({'n_calls': n_calls, 'n_badcalls': n_badcalls, 'traj': prompt})
    return r, info

In [33]:
import random
from tqdm import tqdm
from llms import (GPT3, GPT4, Claude3Opus)

# Evaluate the ReAct agent on a fixed random sample of claims.
DEBUG = False  # True makes webthink print every Thought/Action/Observation step
num_claims = 100
llms = {"GPT-4": GPT4}

# NOTE(review): 7405 is presumably the number of claims in the FEVER dev split — confirm.
idxs = list(range(7405))
random.Random(0).shuffle(idxs)  # fixed seed -> the same claim subset on every run

for llm_name, llm_func in llms.items():
    rewards, infos = [], []
    pbar = tqdm(idxs[:num_claims])
    for i in pbar:
        r, info = webthink(llm_func, i, to_print=DEBUG)
        rewards.append(info['em'])
        infos.append(info)
        running_acc = sum(rewards) / len(rewards)
        pbar.set_description(f"LLM: {llm_name}, Accuracy: {running_acc:.3f}")

  logger.warn(
  logger.warn(
  0%|          | 0/100 [00:07<?, ?it/s]


KeyboardInterrupt: 

In [36]:
def webthink_cot(llm_func, idx=None, prompt=webthink_prompt, to_print=True):
    """Answer one claim with a single chain-of-thought completion (no tool use).

    The model's completion is split on "Answer:"; the text after it becomes
    the label submitted to the environment via ``finish[...]``.

    Args:
        llm_func: completion function called as ``llm_func(text, stop=[...])``;
            assumed to return the generated continuation as a string.
        idx: claim index passed to ``env.reset``.
        prompt: few-shot prefix (a local copy is extended with the claim).
        to_print: echo the claim and the final info dict.

    Returns:
        ``(r, info)``: the reward from the ``finish`` step, and the env's info
        dict with an added boolean ``bad_call``. ``bad_call`` is True when the
        completion did not contain exactly one "Answer:"; in that case
        "NOT ENOUGH INFO" is submitted.
    """
    question = env.reset(idx=idx)
    if to_print:
        print(idx, question)

    prompt += "\n" + question + "\n"
    # A list of stop sequences, matching how webthink calls the same llm_func.
    response = llm_func(prompt, stop=["Action:"])
    bad_call = False

    try:
        _reasoning, ans = response.split("Answer:")
    except ValueError:
        # Zero or several "Answer:" markers: fall back to a fixed label.
        print(f"Failed to produce answer on {idx}\n{prompt}{response}")
        ans = "NOT ENOUGH INFO"
        bad_call = True

    obs, r, done, info = step(env, f"finish[{ans.strip()}]")
    info["bad_call"] = bad_call

    if to_print:
        print(info, '\n')

    return r, info

In [37]:
# Chain of Thought

import random
from tqdm import tqdm
from llms import (GPT3, GPT4, Claude3Opus)

# Evaluate the chain-of-thought baseline on the same seeded claim sample.
DEBUG = False  # True makes webthink_cot print each claim and result
num_claims = 100
llms = {"GPT-4": GPT4}

# NOTE(review): 7405 is presumably the number of claims in the FEVER dev split — confirm.
idxs = list(range(7405))
random.Random(0).shuffle(idxs)  # same seed as the ReAct run -> identical claims

for llm_name, llm_func in llms.items():
    rewards, infos = [], []
    bad_calls = 0  # completions that could not be parsed into an answer
    pbar = tqdm(idxs[:num_claims])
    for i in pbar:
        r, info = webthink_cot(llm_func, i, to_print=DEBUG, prompt=prompt_dict["cotqa_simple3"])
        rewards.append(info['em'])
        infos.append(info)
        bad_calls += info["bad_call"]
        running_acc = sum(rewards) / len(rewards)
        pbar.set_description(f"LLM: {llm_name}, Accuracy: {running_acc:.3f}")

    print(f"Bad calls: {bad_calls}")

  0%|          | 0/3 [00:00<?, ?it/s]

  logger.warn(
  logger.warn(
LLM: GPT-4, Accuracy: 1.000: 100%|██████████| 3/3 [00:05<00:00,  1.78s/it]
