In [None]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("../src")

In [None]:
import logging
import random
import time

import matplotlib.pyplot as plt

from benchmark import load_dataset, load_model, load_prompt_method
from utils_ext.tools import setup_logging

# Turn off matplotlib interactive mode.
plt.ioff()
# Configure logging via the project helper. NOTE(review): exact handlers/format live in utils_ext.tools.
setup_logging()

logger = logging.getLogger(__name__)

# Root directory for files written by this notebook (e.g. the batch request/response files).
PATH_OUTPUT = "../results"

# Seed for the `random` module, which picks the questions in the cells below.
# None keeps the default OS-entropy seeding (a different draw on every run);
# set an int to make a run reproducible.
SEED = None
random.seed(SEED)

In [None]:
# Session-wide cache handed to load_dataset. It starts empty; presumably
# load_dataset stores loaded datasets here so they can be reused — confirm in benchmark.
DATASET_CACHE = dict()

## Load the dataset, model, and prompt method

In [None]:
# Benchmark questions, loaded through the shared DATASET_CACHE.
# NOTE(review): "arc-c" presumably names the ARC-Challenge split; confirm in benchmark.
dataset = load_dataset(
    "arc-c",
    dataset_cache=DATASET_CACHE,
)

In [None]:
# Model under evaluation; the identifier string is resolved by benchmark.load_model.
model = load_model(
    "gemma1.1-2b-it",
)

In [None]:
# Prompting strategy; later cells call its generate_response and extract_answer.
prompt_method = load_prompt_method(
    "basic_1s",
)

## Run single-prompt inference

In [None]:
# Choose the question to run: set QUESTION_ID to a dataset id to inspect a
# specific question, or leave it as None to draw one at random.
QUESTION_ID = None

if QUESTION_ID is None:
    prompt = random.choice(dataset)
else:
    prompt = dataset[dataset.id2index[QUESTION_ID]]

# Time the full pipeline for one question: generation, answer extraction and grading.
start = time.perf_counter()
responses, statistics = prompt_method.generate_response(model, prompt, verbose=True)
answer, confidence = prompt_method.extract_answer(responses)
is_correct = dataset.evaluate_answer(answer, prompt["correct_answer"])
end = time.perf_counter()

print()
print("id:             ", prompt["id"])
print("correct_answer: ", prompt["correct_answer"])
print("answer:         ", answer)
print("confidence:     ", confidence)
print("is_correct:     ", is_correct)
print("time:           ", f"{end - start:.2f}s")
print("statistics:     ", statistics)

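## Run batch API inference

Queue several randomly chosen questions into one batch, submit it, then grade each returned response.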

In [None]:
# Number of distinct questions to send through the batch API.
N_BATCH_PROMPTS = 5

# Time this batch end to end (queueing + submission). Earlier cells' timings do not apply here.
start = time.perf_counter()
model.create_batch(
    path_requests=f"{PATH_OUTPUT}/batch/batch_requests.jsonl",
    path_responses=f"{PATH_OUTPUT}/batch/batch_responses.jsonl",
)
# Sample distinct indices: results are keyed by question id, so drawing the
# same question twice (possible with repeated random.choice) would collide on
# that key. NOTE(review): confirm how submit_batch treats duplicate ids.
batch_indices = random.sample(range(len(dataset)), min(N_BATCH_PROMPTS, len(dataset)))
for index in batch_indices:
    # The return value is discarded here; presumably, with a batch open, the
    # request is queued rather than answered immediately — confirm in benchmark.
    prompt_method.generate_response(model, dataset[index])
results = model.submit_batch(verbose=True)
end = time.perf_counter()

# compl_costs is scaled by 100 for display in cents.
total_cost_cents = sum(stats["compl_costs"] * 100 for _, stats in results.values())
print("Total cost: ¢ ", total_cost_cents)
print("Total time:  ", f"{end - start:.2f}s")
for question_id, (response, stats) in results.items():
    prompt = dataset[dataset.id2index[question_id]]
    answer, confidence = prompt_method.extract_answer([response])
    is_correct = dataset.evaluate_answer(answer, prompt["correct_answer"])
    print()
    print("id:             ", prompt["id"])
    print("correct_answer: ", prompt["correct_answer"])
    print("answer:         ", answer)
    print("confidence:     ", confidence)
    print("is_correct:     ", is_correct)
    print("statistics:     ", stats)