# Environment initialization

In [2]:
import os

import wandb
from solvers import ZeroShotGPT
from structs import DataSet

In [3]:
# Solver under evaluation; the same instance is used for both datasets below.
solver = ZeroShotGPT()

In [4]:
# Silence wandb's console messages; set before wandb.init() so it takes effect.
os.environ["WANDB_SILENT"] = "true"

# Derive the solver name from the instance actually being evaluated so the
# logged config cannot drift from it if the solver class is swapped out.
# (Assumes ZeroShotGPT is a class, so its instances report "ZeroShotGPT".)
wandb.init(
    project="brainteasers",
    config={
        "solver": type(solver).__name__,
    },
)

# Evaluation on the Sentence Puzzle dataset

In [5]:
sp = DataSet.from_file("../data/SP-train.pkl")
sp_instances = list(sp)
assert sp_instances, "the Sentence Puzzle dataset is empty"

sp_answers = solver.solve(sp)
# zip() silently truncates to the shorter input, so make sure the solver
# returned exactly one answer per puzzle before scoring.
assert len(sp_answers) == len(sp_instances), (
    f"expected {len(sp_instances)} Sentence Puzzle answers, got {len(sp_answers)}"
)

sp_are_answers_correct = [
    instance.is_choice_correct(answer)
    for instance, answer in zip(sp_instances, sp_answers)
]
# Fraction of puzzles answered correctly (is_choice_correct results are
# summed as 0/1 — presumably booleans).
sp_accuracy = sum(sp_are_answers_correct) / len(sp_are_answers_correct)

print(f"Accuracy on the Sentence Puzzle dataset: {sp_accuracy:.4f}")

  0%|          | 0/507 [00:00<?, ?it/s]

Accuracy on the Sentence Puzzle dataset:  0.2465


# Evaluation on the Word Puzzle dataset

In [7]:
wp = DataSet.from_file("../data/WP-train.pkl")
wp_instances = list(wp)
assert wp_instances, "the Word Puzzle dataset is empty"

wp_answers = solver.solve(wp)
# zip() silently truncates to the shorter input, so make sure the solver
# returned exactly one answer per puzzle before scoring.
assert len(wp_answers) == len(wp_instances), (
    f"expected {len(wp_instances)} Word Puzzle answers, got {len(wp_answers)}"
)

wp_are_answers_correct = [
    instance.is_choice_correct(answer)
    for instance, answer in zip(wp_instances, wp_answers)
]
# Fraction of puzzles answered correctly (is_choice_correct results are
# summed as 0/1 — presumably booleans).
wp_accuracy = sum(wp_are_answers_correct) / len(wp_are_answers_correct)

print(f"Accuracy on the Word Puzzle dataset: {wp_accuracy:.4f}")

Accuracy on the Word Puzzle dataset:  0.1995


# Logging metrics to Weights & Biases

In [8]:
# Overall accuracy pools the per-puzzle results of both datasets. This equals
# the size-weighted average of sp_accuracy and wp_accuracy, but uses exactly
# the same denominators those accuracies were computed with (the scored
# puzzles), so the weights cannot disagree with them.
total_cardinality = len(sp_are_answers_correct) + len(wp_are_answers_correct)
total_accuracy = (
    sum(sp_are_answers_correct) + sum(wp_are_answers_correct)
) / total_cardinality

In [9]:
# Collect the per-dataset and overall accuracies under "accuracy/" keys,
# send them to the active wandb run, then close the run.
accuracy_metrics = {
    "accuracy/overall": total_accuracy,
    "accuracy/sp": sp_accuracy,
    "accuracy/wp": wp_accuracy,
}
wandb.log(accuracy_metrics)

wandb.finish(quiet=True)