# In [None]:
import os
import glob

import wandb
import numpy as np
from tqdm.auto import tqdm
from sklearn.metrics import classification_report

from valerie.datasets import Phase2Dataset
from valerie.modeling import SequenceClassificationModel, SequenceClassificationExample

# Copied from the training script; keep the two in sync.
def generate_sequence_classification_examples(claims):
    """Build one single-text ``SequenceClassificationExample`` per claim.

    Each claim supplies its ``id`` as the guid, its ``claim`` text as
    ``text_a`` and its ``label``; ``text_b`` is always ``None``. Iteration is
    wrapped in a tqdm progress bar labelled "generating examples".
    """
    return [
        SequenceClassificationExample(
            guid=c.id, text_a=c.claim, text_b=None, label=c.label,
        )
        for c in tqdm(claims, desc="generating examples")
    ]


def _numeric_suffix_key(path):
    """Return a sort key ordering ``<name>-<int>`` paths by their integer suffix.

    Plain string sorting would put ``checkpoint-1000`` before ``checkpoint-500``.
    Basenames without a numeric ``-<int>`` suffix sort before numbered ones;
    ties are broken by the basename itself.
    """
    name = os.path.basename(path)
    suffix = name.rsplit("-", 1)[-1]
    return (int(suffix) if suffix.isdecimal() else -1, name)


def _sorted_subdirs(parent_dir, pattern):
    """Return directories directly under ``parent_dir`` matching glob ``pattern``.

    ``glob.glob`` returns entries in arbitrary, filesystem-dependent order and
    also matches plain files, so non-directories are dropped and the result is
    sorted with ``_numeric_suffix_key`` to make the evaluation order reproducible.
    """
    matches = glob.glob(os.path.join(parent_dir, pattern))
    return sorted(
        (path for path in matches if os.path.isdir(path)),
        key=_numeric_suffix_key,
    )


def _print_banner(title):
    """Print ``title`` centred in a 50-column banner of dashes."""
    print("-" * 50)
    print(title.center(50, "-"))
    print("-" * 50)


def _evaluate_and_report(examples, predict_batch_size=8, **from_pretrained_kwargs):
    """Load a model, predict on ``examples`` and print a classification report.

    Args:
        examples: examples to score, passed to ``model.create_dataset``.
        predict_batch_size: batch size forwarded to ``model.predict``.
        **from_pretrained_kwargs: forwarded unchanged to
            ``SequenceClassificationModel.from_pretrained``.

    The model is only referenced inside this call, so it is released when the
    call returns, before the caller loads the next one.
    """
    model = SequenceClassificationModel.from_pretrained(**from_pretrained_kwargs)
    predict_dataset = model.create_dataset(examples)
    predict_output = model.predict(
        predict_dataset, predict_batch_size=predict_batch_size
    )
    # Presumably one row of class scores per example (TODO confirm); take the
    # argmax along the last axis to get one predicted class index per row.
    predictions = np.argmax(np.asarray(predict_output.predictions), axis=-1)
    print(classification_report(predict_output.label_ids, predictions))
    print()


def thing1(
    run_dir="models/fnc/combined_dataset_first_probe/combined_dataset_first_probe-0",
    metadata_path="data/phase2-trial/metadata.json",
):
    """Print a trial-set classification report for every checkpoint of one run.

    Args:
        run_dir: training output directory. Its ``checkpoint*`` sub-directories
            are evaluated in numeric step order, followed by a final evaluation
            with ``checkpoint_dir=""``.
        metadata_path: metadata file passed to ``Phase2Dataset.from_raw``.
    """
    trial_dataset = Phase2Dataset.from_raw(metadata_path)
    examples = generate_sequence_classification_examples(trial_dataset.claims)

    checkpoint_dirs = [
        os.path.basename(path) for path in _sorted_subdirs(run_dir, "checkpoint*")
    ]
    # "" presumably makes from_pretrained load the model saved directly in
    # run_dir (the final model) -- confirm against SequenceClassificationModel.
    checkpoint_dirs.append("")

    for checkpoint_dir in checkpoint_dirs:
        _print_banner("checkpoint: {}".format(checkpoint_dir))
        _evaluate_and_report(
            examples,
            pretrained_model_name_or_path=run_dir,
            checkpoint_dir=checkpoint_dir,
        )

    # TODO: record the trial metrics on the matching wandb run. An earlier
    # draft used wandb.Api().run("jaymody/valerie/4k5mz571") and updated
    # run.summary; the unused wandb.Api() client is no longer created here.


def thing2(
    group_dir="models/fnc/initial_test_run",
    metadata_path="data/phase2-trial/metadata.json",
):
    """Print a trial-set classification report for every run in a run group.

    Args:
        group_dir: directory whose ``initial_test*`` sub-directories are each
            loaded with ``SequenceClassificationModel.from_pretrained``, ordered
            by their numeric ``-<int>`` suffix.
        metadata_path: metadata file passed to ``Phase2Dataset.from_raw``.
    """
    trial_dataset = Phase2Dataset.from_raw(metadata_path)
    examples = generate_sequence_classification_examples(trial_dataset.claims)

    for run_dir in _sorted_subdirs(group_dir, "initial_test*"):
        print()
        print()
        _print_banner("run: {}".format(run_dir))
        _evaluate_and_report(examples, pretrained_model_name_or_path=run_dir)

    # TODO: record the trial metrics on the matching wandb run (an earlier
    # draft sketched run.save() / run.summary updates here).


# Script entry point: runs the run-group sweep (thing2). Call thing1() instead
# to evaluate every checkpoint of a single run.
if __name__ == "__main__":
    thing2()
