In [None]:
# Reload edited project modules (e.g. src.baselines.baseline_utils) automatically
# before each cell executes, so changes to helper code are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [None]:
import argparse
import asyncio
import json
import os
import sys
from time import sleep
from types import SimpleNamespace

import nltk
import numpy as np
from dotenv import load_dotenv
from langchain.llms import OpenAI

from src.baselines import baseline_utils
from src.baselines.baseline_utils import load_jsonl

load_dotenv()

In [None]:
# Experiment configuration. Paths are relative to the notebook's working directory;
# ckpt_path (save_dir/exp_label) is where the run log is written.
args = SimpleNamespace(
    data_dir='../../data/gsm_data',
    save_dir='models',
    debug=False,
    exp_label='default',
    task='pot_gsm',
    model='gpt-3.5-turbo',
    max_tokens=2048,
    temperature=0.0,
)
args.ckpt_path = os.path.join(args.save_dir, args.exp_label)

In [None]:
# Create the checkpoint directory (and with it args.save_dir, its parent) up front.
# NOTE(review): this covers the case where baseline_utils.Logger does not create
# missing directories itself; results are also written into args.save_dir later.
os.makedirs(args.ckpt_path, exist_ok=True)
logger = baseline_utils.Logger(os.path.join(args.ckpt_path, 'log.txt'))
# Global progress counter: incremented per answered problem, reset at the start of each run.
completed_rounds = 0

In [None]:
async def async_generate_answer(llm, prompt_template, problem, retry_delay=30, max_retries=None):
    """Query the LLM for one problem and store the completion in ``problem['output']``.

    The prompt is ``prompt_template.format(context=problem['context'])``. Any
    exception raised by ``llm.agenerate`` is logged and the request is retried
    after ``retry_delay`` seconds.

    Args:
        llm: LLM object exposing an async ``agenerate(list_of_prompts)`` method.
        prompt_template: object with a ``format(context=...)`` method.
        problem: dict with a ``'context'`` key; mutated in place (``'output'`` is set).
        retry_delay: seconds to wait before retrying a failed request (default 30).
        max_retries: maximum number of retries after the first failure; ``None``
            (default) keeps retrying indefinitely.

    Raises:
        The last exception from ``llm.agenerate`` once ``max_retries`` is exhausted.
    """
    global completed_rounds
    inp = prompt_template.format(context=problem['context'])
    failures = 0
    while True:
        try:
            output = await llm.agenerate([inp])
            break
        except Exception as e:
            failures += 1
            logger.write(str(e))
            if max_retries is not None and failures > max_retries:
                raise
            logger.write(f'API server overloaded. Waiting for {retry_delay} seconds...')
            # Non-blocking wait: time.sleep would freeze the event loop and stall every
            # other request running concurrently via asyncio.gather.
            await asyncio.sleep(retry_delay)
    problem['output'] = output.generations[0][0].text
    completed_rounds += 1
    print(f"Completed {completed_rounds} rounds")


In [None]:
async def async_generate_answers(llm, prompt_template, problems):
    """Answer every problem in ``problems`` concurrently.

    One ``async_generate_answer`` coroutine is created per problem and all of them
    are awaited together with ``asyncio.gather``. Each problem dict is updated in
    place; nothing is returned.
    """
    await asyncio.gather(
        *(async_generate_answer(llm, prompt_template, problem) for problem in problems)
    )


In [None]:
async def gsm_run(prompt_template, llm, data, step=5):
    """Generate model outputs for every GSM example in ``data``.

    Examples are sent in batches of ``step`` concurrent requests; each batch is
    awaited before the next one starts, which bounds the number of in-flight API calls.

    Args:
        prompt_template: template forwarded to ``async_generate_answers``.
        llm: LLM forwarded to ``async_generate_answers``.
        data: iterable of dicts with ``'input'`` and ``'target'`` keys.
        step: batch size, i.e. number of concurrent requests (default 5).

    Returns:
        List of ``{'context', 'target'}`` dicts (in ``data`` order), each filled with
        ``'output'`` by the generation helpers.
    """
    global completed_rounds
    completed_rounds = 0
    problems = [{'context': d['input'], 'target': d['target']} for d in data]
    total = len(problems)
    for start in range(0, total, step):
        # Slicing clamps at the end of the list, so no explicit min() is needed here.
        await async_generate_answers(llm, prompt_template, problems[start:start + step])
        # Report the real count; start + step overshoots on the final partial batch.
        print(f"Completed {min(start + step, total)} problems")
    return problems

In [None]:
def calc_accuracy(problems):
    """Return the fraction of ``problems`` marked correct.

    Args:
        problems: sequence of problem dicts. A truthy ``'correct'`` value counts as
            a hit; a missing ``'correct'`` key counts as incorrect.

    Returns:
        Accuracy in [0, 1]. An empty sequence yields ``0.0`` instead of raising
        ``ZeroDivisionError``.
    """
    if not problems:
        return 0.0
    return sum(1 for p in problems if p.get('correct')) / len(problems)

In [None]:
def parse_answers(filename, task):
    """Grade the saved model outputs in ``filename`` and rewrite the file in place.

    For each problem, ``baseline_utils.parse_answer(problem['output'], task)`` is
    stored under ``'final_answer'``, and ``'correct'`` records whether it equals
    ``'target'``. The updated list is written back as a single JSON line.

    Returns:
        The updated list of problem dicts.
    """
    with open(filename, 'r') as src:
        graded = json.load(src)
    for problem in graded:
        answer = baseline_utils.parse_answer(problem['output'], task)
        problem['final_answer'] = answer
        problem['correct'] = answer == problem['target']
    with open(filename, 'w') as dst:
        dst.write(json.dumps(graded) + '\n')
    return graded
    

In [None]:
async def gsm_baseline(model, task, n_splits=3, variants=('original', 'irc')):
    """Run ``model`` on every GSM-IC split/variant and log the accuracy of each.

    For each split ``i`` in ``range(n_splits)`` and each variant, data is read from
    ``args.data_dir/gsmic_mixed_{i}_{variant}.jsonl``. Outputs are saved to
    ``args.save_dir/gsmic_mixed_{i}_{variant}_output_{model}_{task}.json`` and then
    graded in place by ``parse_answers``. A split is skipped when that output file,
    or its hand-graded ``hand_``-prefixed copy, already exists, so reruns only query
    the API for missing results.

    Args:
        model: model name passed to the LangChain ``OpenAI`` wrapper.
        task: key passed to ``baseline_utils.create_prompt_template`` and ``parse_answers``.
        n_splits: number of ``gsmic_mixed_{i}`` splits to evaluate (default 3).
        variants: dataset variants evaluated for each split.
    """
    prompt_template = baseline_utils.create_prompt_template(task)
    llm = OpenAI(
        model_name=model,
        max_tokens=args.max_tokens,
        # NOTE(review): as written, '\\n\\n' is a literal backslash-n sequence rather
        # than a blank line; confirm that is the intended stop string.
        stop=['\\n\\n', 'A:', 'Q:'],
        temperature=args.temperature,
        openai_api_key=os.getenv('OPEN_AI_API_KEY'),
    )
    # Output files are opened for writing below; make sure their directory exists.
    os.makedirs(args.save_dir, exist_ok=True)
    for i in range(n_splits):
        for variant in variants:
            split_name = f'gsmic_mixed_{i}_{variant}'
            output_name = f'{split_name}_output_{model}_{task}.json'
            output_file = os.path.join(args.save_dir, output_name)
            hand_file = os.path.join(args.save_dir, f'hand_{output_name}')
            # Check for existing results before loading the data, so finished splits cost nothing.
            if os.path.exists(output_file) or os.path.exists(hand_file):
                continue
            data = baseline_utils.load_gsm_data(os.path.join(args.data_dir, f'{split_name}.jsonl'))
            problems = await gsm_run(prompt_template, llm, data)
            with open(output_file, 'w') as f:
                f.write(json.dumps(problems) + '\n')
            problems = parse_answers(output_file, task)
            logger.write(f'Accuracy for {split_name} = {calc_accuracy(problems)}')

In [None]:
print(args.model, args.task)
await gsm_baseline(args.model, args.task)

In [None]:
# Manual sanity check: take example index 2 of split 0 (original variant) and build a
# standalone 'pot_gsm' prompt template plus LLM using the notebook configuration.
test_question = baseline_utils.load_gsm_data(
    os.path.join(args.data_dir, 'gsmic_mixed_0_original.jsonl')
)[2]

prompt_template = baseline_utils.create_prompt_template('pot_gsm')
llm = OpenAI(
    model_name=args.model,
    temperature=args.temperature,
    max_tokens=args.max_tokens,
    stop=['\\n\\n', 'A:', 'Q:'],
    openai_api_key=os.getenv('OPEN_AI_API_KEY'),
)

In [None]:
# Show the exact prompt sent to the model for the sample question.
print(prompt_template.format(context=test_question["input"]))

In [None]:

tp = {**test_question, 'context': test_question['input']}
await async_generate_answers(llm, prompt_template, [tp])

In [None]:
# Run the 'pot_gsm' answer parser over the sample's generated output and show the result.
print(baseline_utils.parse_answer(tp["output"], "pot_gsm"))

In [None]:
def _hand_graded_path(filepath):
    """Return ``filepath`` with ``hand_`` prefixed to its file name, directory kept."""
    directory, name = os.path.split(filepath)
    return os.path.join(directory, 'hand_' + name)


def grade(filepath, rewrite=False):
    """Re-grade a saved output file and print its accuracy.

    A problem is marked correct when its ``'final_answer'`` equals its ``'target'``.
    Existing ``'correct': True`` flags are never cleared (presumably so manual
    overrides survive re-grading — confirm). Problems without a ``'correct'`` key get
    ``False`` so the accuracy computation can count them.

    Args:
        filepath: JSON file holding a list of problem dicts.
        rewrite: if True, write the result next to the input as ``hand_<name>``
            instead of overwriting ``filepath``.
    """
    with open(filepath, 'r') as f:
        problems = json.load(f)
    for p in problems:
        if p['target'] == p['final_answer']:
            p['correct'] = True
        else:
            p.setdefault('correct', False)
    out_path = _hand_graded_path(filepath) if rewrite else filepath
    with open(out_path, 'w') as f:
        f.write(json.dumps(problems) + '\n')
    print(calc_accuracy(problems))


In [None]:


# Re-grade every hand-corrected output file for the configured model and print each accuracy.
for task in ['0cot', '1cot_gsm', 'pot_gsm']:
    for i in range(3):
        for variant in ['original', 'irc']:
            filepath = os.path.join(args.save_dir, f'hand_gsmic_mixed_{i}_{variant}_output_{args.model}_{task}.json')
            print(filepath)
            # Not every split necessarily has a hand-graded file; skip it instead of
            # aborting the whole loop with FileNotFoundError.
            if not os.path.exists(filepath):
                print('  missing, skipped')
                continue
            # grade() prints the accuracy itself and returns None, so it is not wrapped in print().
            grade(filepath)
            
