In [8]:
import joblib
from react_cls import ReactReflectAgent, format_reflections
from mocks import DocStoreExplorerMock, LLMMock

In [15]:
def summarize_trial(agents):
    """Split agents into those that are correct and those that finished wrong.

    Args:
        agents: Sequence of objects exposing ``is_correct()`` and
            ``is_finished()``.

    Returns:
        A ``(correct, incorrect)`` tuple of lists. ``correct`` holds every
        agent whose ``is_correct()`` is truthy; ``incorrect`` holds every
        agent that reports ``is_finished()`` but not ``is_correct()``.
        Agents that are neither finished nor correct appear in neither list.
    """
    solved = []
    for agent in agents:
        if agent.is_correct():
            solved.append(agent)

    failed = []
    for agent in agents:
        # Unfinished agents are not counted as failures.
        if not agent.is_finished():
            continue
        if not agent.is_correct():
            failed.append(agent)

    return solved, failed

def remove_fewshot(prompt: str) -> str:
    """Return ``prompt`` with its few-shot examples section removed.

    The section starts at the first ``'Here are some examples:'`` and ends at
    the first ``'(END OF EXAMPLES)'`` that follows it. The text before and
    after the section is whitespace-stripped and joined with one newline.

    Args:
        prompt: Full agent prompt text.

    Returns:
        The prompt without the examples section. If the prompt does not
        contain a complete section (either marker is missing), it is returned
        unchanged instead of raising ``IndexError``.
    """
    start_marker = 'Here are some examples:'
    end_marker = '(END OF EXAMPLES)'

    # partition() splits on the first occurrence only, so any later
    # occurrence of the end marker stays in the suffix instead of being cut.
    prefix, found_start, remainder = prompt.partition(start_marker)
    _examples, found_end, suffix = remainder.partition(end_marker)
    if not (found_start and found_end):
        return prompt

    return prefix.strip() + '\n' + suffix.strip()

def log_trial(agents, trial_n):
    """Build the text log for one trial.

    The log consists of a banner with the trial number and the counts from
    ``summarize_trial``, followed by a section for correct agents and a
    section for incorrect agents. Each agent entry is its prompt (with the
    few-shot examples removed) followed by a ``Correct answer:`` line showing
    ``agent.key``.

    Args:
        agents: Agents to report on; passed to ``summarize_trial``.
        trial_n: Trial number shown in the banner.

    Returns:
        The log text as a single string.
    """
    correct, incorrect = summarize_trial(agents)

    def render(agent):
        # Build the prompt before reading agent.key, as plain left-to-right
        # concatenation would.
        prompt = remove_fewshot(agent._build_agent_prompt())
        return prompt + f'\nCorrect answer: {agent.key}\n\n'

    banner = f"""
########################################
BEGIN TRIAL {trial_n}
Trial summary: Correct: {len(correct)}, Incorrect: {len(incorrect)}
#######################################
"""

    pieces = [banner, '------------- BEGIN CORRECT AGENTS -------------\n\n']
    pieces.extend(render(agent) for agent in correct)
    pieces.append('------------- BEGIN INCORRECT AGENTS -----------\n\n')
    pieces.extend(render(agent) for agent in incorrect)
    return ''.join(pieces)


In [11]:
# Load the sampled HotpotQA data and renumber its rows from 0, discarding the
# stored index.
# NOTE(review): joblib.load unpickles the file, so only load trusted data.
hotpot = joblib.load('data/hotpot-qa-distractor-sample.joblib')
hotpot = hotpot.reset_index(drop=True)

In [12]:
# One ReactReflectAgent per row, built from the row's 'question' and 'answer'.
agents = [
    ReactReflectAgent(sample['question'], sample['answer'])
    for _row_index, sample in hotpot.iterrows()
]

In [13]:
# Experiment state shared by the cells below:
#   trial        - number of trials completed so far
#   log          - accumulated text log across trials
#   last_correct - previous trial's correct agents; starts as 0 so that it
#                  differs from any list on the first comparison
trial, log, last_correct = 0, '', 0

In [None]:
# Run one trial with reflect_strategy='last_attempt' on every agent that is
# not yet correct. The pending list is built before any agent runs.
pending = [candidate for candidate in agents if not candidate.is_correct()]
for agent in pending:
    agent.run(reflect_strategy='last_attempt')
    print(f'Answer: {agent.key}')

# Record the trial: bump the counter, extend the log, refresh the tallies.
trial += 1
log = log + log_trial(agents, trial)
correct, incorrect = summarize_trial(agents)
summary = f'Finished Trial {trial}, Correct: {len(correct)}, Incorrect: {len(incorrect)}'
print(summary)

In [17]:
# Snapshot each incorrect agent's attributes as a dict of strings (every value
# passed through str()) and save the list with joblib.
# NOTE(review): presumably stringified so the dump holds no live agent
# objects - confirm.
dicts = [
    {name: str(value) for name, value in agent.__dict__.items()}
    for agent in incorrect
]

joblib.dump(dicts, 'output/last_trial_react/react_incorrect_dicts_trial_0.joblib')

['output/last_trial_react/react_incorrect_dicts_trial_0.joblib']

In [None]:
# Repeat 'last_attempt' trials until a trial leaves the list of correct
# agents unchanged from the list taken just before that trial.
while last_correct != correct:
    last_correct = summarize_trial(agents)[0]

    pending = [candidate for candidate in agents if not candidate.is_correct()]
    for agent in pending:
        agent.run(reflect_strategy='last_attempt')
        print(f'Answer: {agent.key}')

    # Record the trial: bump the counter, extend the log, refresh the tallies.
    trial += 1
    log = log + log_trial(agents, trial)
    correct, incorrect = summarize_trial(agents)
    summary = f'Finished Trial {trial}, Correct: {len(correct)}, Incorrect: {len(incorrect)}'
    print(summary)

In [None]:
# Run one trial with reflect_strategy='last_attempt + reflexion' on every
# agent that is not yet correct. The pending list is built before any agent
# runs.
pending = [candidate for candidate in agents if not candidate.is_correct()]
for agent in pending:
    agent.run(reflect_strategy='last_attempt + reflexion')
    print(f'Answer: {agent.key}')

# Record the trial: bump the counter, extend the log, refresh the tallies.
trial += 1
log = log + log_trial(agents, trial)
correct, incorrect = summarize_trial(agents)
summary = f'Finished Trial {trial}, Correct: {len(correct)}, Incorrect: {len(incorrect)}'
print(summary)

In [18]:
# Write the accumulated trial log to disk. The encoding is pinned to UTF-8
# because the platform default can fail with UnicodeEncodeError if the log
# contains characters outside the locale's code page.
with open('output/last_trial_react/100_questions_5_trials.txt', 'w', encoding='utf-8') as f:
    f.write(log)

In [18]:
# Snapshot each correct agent's attributes as a dict of strings (every value
# passed through str()) and save the list with joblib.
dicts = [
    {name: str(value) for name, value in agent.__dict__.items()}
    for agent in correct
]

joblib.dump(dicts, 'output/reflect/react_reflect_50_correct_dicts.joblib')

['output/reflect/react_reflect_50_correct_dicts.joblib']