In [1]:
import joblib
from react_cls import ReactAgent
from mocks import DocStoreExplorerMock, LLMMock

In [3]:
def summarize_trial(agents):
    """Split ``agents`` into outcome groups for reporting.

    Args:
        agents: Objects exposing ``is_correct()``, ``is_halted()`` and
            ``is_finished()``.

    Returns:
        A ``(correct, incorrect, halted)`` tuple of lists, each in the same
        relative order as ``agents``:

        * ``correct``   -- agents whose ``is_correct()`` is truthy.
        * ``incorrect`` -- agents that are finished but not correct.
        * ``halted``    -- agents whose ``is_halted()`` is truthy.

    Each list comes from its own independent pass over ``agents``; nothing
    here enforces that the groups are disjoint or that every agent lands in
    one of them.
    """
    def _finished_but_wrong(agent):
        return agent.is_finished() and not agent.is_correct()

    correct = list(filter(lambda agent: agent.is_correct(), agents))
    halted = list(filter(lambda agent: agent.is_halted(), agents))
    incorrect = list(filter(_finished_but_wrong, agents))
    return correct, incorrect, halted

def log_trial(agents, trial_n):
    """Render a plain-text report of one trial.

    The report starts with a banner carrying ``trial_n`` and the number of
    correct / incorrect / halted agents (as grouped by ``summarize_trial``),
    followed by one section per group. Every agent in a group contributes its
    ``question``, ``scratchpad`` and ``key`` (the expected answer).

    Args:
        agents: The agents to report on.
        trial_n: Trial number shown in the banner.

    Returns:
        The full report as a single string.
    """
    correct, incorrect, halted = summarize_trial(agents)

    banner = f"""
########################################
BEGIN TRIAL {trial_n}
Trial summary: Correct: {len(correct)}, Incorrect: {len(incorrect)}, Halted: {len(halted)}
#######################################
"""

    # (section heading, agents listed under it), in output order.
    sections = (
        ('------------- BEGIN CORRECT AGENTS -------------\n\n', correct),
        ('------------- BEGIN INCORRECT AGENTS -----------\n\n', incorrect),
        ('------------- BEGIN HALTED AGENTS --------------\n\n', halted),
    )

    pieces = [banner]
    for heading, group in sections:
        pieces.append(heading)
        pieces.extend(
            f'Question: {agent.question}{agent.scratchpad}\nCorrect answer: {agent.key}\n\n'
            for agent in group
        )

    return ''.join(pieces)

In [4]:
# Load the question sample used for every trial in this notebook.
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary code,
# so only load sample files that come from a trusted source.
# reset_index(drop=True) discards the index stored with the sample and replaces it
# with a fresh 0..n-1 one (presumably this is a pandas DataFrame -- it is iterated
# with .iterrows() in the next cell).
hotpot = joblib.load('data/hotpot-qa-distractor-sample.joblib').reset_index(drop = True)

In [5]:
# One ReactAgent per sample row, built from that row's question and answer.
agents = [
    ReactAgent(question, answer)
    for question, answer in zip(hotpot['question'], hotpot['answer'])
]

In [6]:
# Running state for the trial cells below:
#   trial -- number of trials completed so far
#   log   -- text of all trial reports, appended to after each trial
trial, log = 0, ''

In [21]:
# Position of the next agent to run inside the trial cell's `agents_to_run` list.
# The trial cell itself never resets q, so this cell must be executed before each
# new trial; otherwise the trial cell starts from wherever q was left.
q = 0

In [22]:
# Run one trial: every agent that is not yet correct gets another run.
agents_to_run = [a for a in agents if not a.is_correct()]

# q is not reset in this cell (it is set to 0 in the cell above), so re-executing
# this cell after an interruption continues from index q instead of from 0.
# NOTE(review): agents_to_run is rebuilt on every execution; if agents that already
# ran in the interrupted pass became correct, the list is shorter and q may no
# longer point at the same agent -- confirm this is acceptable when resuming.
while q < len(agents_to_run):
    print(f'Trial: {trial} ({q}/{len(agents_to_run)})')
    agents_to_run[q].run()
    q += 1

# The counter is incremented before logging, so the trial number written to the log
# and printed in the "Finished Trial" line is one higher than the number shown in
# the progress lines above.
trial += 1

# Append this trial's report to the running log and print a one-line summary.
log += log_trial(agents, trial)
correct, incorrect, halted = summarize_trial(agents)
print(f'Finished Trial {trial}, Correct: {len(correct)}, Incorrect: {len(incorrect)}, Halted: {len(halted)}')

Trial: 4 (0/66)
Trial: 4 (1/66)
Trial: 4 (2/66)
Trial: 4 (3/66)
Trial: 4 (4/66)
Trial: 4 (5/66)
Trial: 4 (6/66)
Trial: 4 (7/66)
Trial: 4 (8/66)
Trial: 4 (9/66)
Trial: 4 (10/66)
Trial: 4 (11/66)
Trial: 4 (12/66)
Trial: 4 (13/66)
Trial: 4 (14/66)
Trial: 4 (15/66)
Trial: 4 (16/66)
Trial: 4 (17/66)
Trial: 4 (18/66)
Trial: 4 (19/66)
Trial: 4 (20/66)
Trial: 4 (21/66)
Trial: 4 (22/66)
Trial: 4 (23/66)
Trial: 4 (24/66)
Trial: 4 (25/66)
Trial: 4 (26/66)
Trial: 4 (27/66)
Trial: 4 (28/66)
Trial: 4 (29/66)
Trial: 4 (30/66)
Trial: 4 (31/66)
Trial: 4 (32/66)
Trial: 4 (33/66)
Trial: 4 (34/66)
Trial: 4 (35/66)
Trial: 4 (36/66)
Trial: 4 (37/66)
Trial: 4 (38/66)
Trial: 4 (39/66)
Trial: 4 (40/66)
Trial: 4 (41/66)
Trial: 4 (42/66)
Trial: 4 (43/66)
Trial: 4 (44/66)
Trial: 4 (45/66)
Trial: 4 (46/66)
Trial: 4 (47/66)
Trial: 4 (48/66)
Trial: 4 (49/66)
Trial: 4 (50/66)
Trial: 4 (51/66)
Trial: 4 (52/66)
Trial: 4 (53/66)
Trial: 4 (54/66)
Trial: 4 (55/66)
Trial: 4 (56/66)
Trial: 4 (57/66)
Trial: 4 (58/66)
Trial: 

In [23]:
# Persist the accumulated trial log as a text file (overwriting any previous file).
# The encoding is pinned to UTF-8: without it, open() falls back to the platform's
# locale encoding, and writing question/scratchpad text containing characters outside
# that encoding raises UnicodeEncodeError (or yields a file that reads back differently
# on another machine).
# The 'output/base_react' directory must already exist; open() does not create it.
with open('output/base_react/100_questions_5_trials.txt', 'w', encoding='utf-8') as f:
    f.write(log)

In [26]:
# Snapshot every agent as a plain {attribute name: str(attribute value)} dict so the
# dump below contains only strings rather than the agent objects themselves.
dicts = [
    {attr: str(value) for attr, value in vars(agent).items()}
    for agent in agents
]

joblib.dump(dicts, 'output/base_react_dicts.joblib')

['output/base_react_dicts.joblib']