In [1]:
import sys

import numpy as np
import pandas as pd
import torch

sys.path.append("..")
from kuramoto.model import KuramotoModel
from NGS.data import NGSDataset, preprocess
from NGS.ema import EMA
from NGS.experiment import rollout
from NGS.hyperparameter import HyperParameter
from path import DATA_DIR, RESULT_DIR

# All evaluation runs on the GPU. `evaluate` reads this module-level name, so it is
# not switched to CPU here; rollout runtimes are recorded and should stay comparable.
device = torch.device(type="cuda")

In [4]:
def evaluate(threshold: int) -> None:
    """Roll out the trained Kuramoto model on the test split and pickle the results.

    Loads the hyperparameters and EMA checkpoint of experiment ``kuramoto_th{threshold}``
    and checks that the stored threshold matches the one implied by ``threshold``. It then
    rolls out the EMA-averaged model on the test split produced by ``preprocess`` from
    ``kuramoto_train.pkl``. Predicted trajectories, runtimes and function-evaluation counts
    are saved as a DataFrame in the experiment's result directory.

    Args:
        threshold: Experiment tag. ``2`` denotes a threshold of ``2.0``; any other
            value ``k`` denotes a threshold of ``pi / k``.

    Raises:
        AssertionError: if ``hyperparameter.yaml`` has no threshold, or if its threshold
            does not match the value implied by ``threshold``.
    """
    exp_id = f"kuramoto_th{threshold}"
    result_dir = RESULT_DIR / exp_id

    # Threshold implied by the experiment tag. It is computed once so that the validity
    # check below and the value assigned to the model cannot drift apart.
    expected_threshold = 2.0 if threshold == 2 else np.pi / threshold

    # Check validity of result directory
    hp = HyperParameter.from_yaml(result_dir / "hyperparameter.yaml")
    assert hp.threshold is not None, f"{exp_id}: hyperparameter.yaml has no threshold"
    assert np.isclose(hp.threshold, expected_threshold), (
        f"{exp_id}: stored threshold {hp.threshold} != expected {expected_threshold}"
    )

    # Load model (uses the notebook-level `device`)
    checkpoint = torch.load(result_dir / "checkpoint.pth", map_location="cpu")
    model = KuramotoModel(hp.emb_dim, hp.depth, hp.dropout).to(device)
    model.threshold = expected_threshold
    ema = EMA(model)
    ema.load_state_dict(checkpoint["ema"])
    # Inference only: disable dropout (the model was built with hp.dropout).
    # NOTE(review): `rollout` may already switch to eval mode; this is harmless if so.
    model.eval()

    # Load data: only the test split of the training pickle is evaluated.
    data_df = pd.read_pickle(DATA_DIR / "kuramoto_train.pkl")
    _, test = preprocess(data_df)
    dataset = NGSDataset(**test, window=-1)

    # Rollout with the EMA weights active
    with ema():
        pred_trajectories, nfevs, runtimes = rollout(model, dataset, device)

    # Save results
    # NOTE(review): test-split results are written to "kuramoto_train2.pkl";
    # confirm this filename is intended.
    pd.DataFrame(
        {"trajectories": pred_trajectories, "runtime": runtimes, "nfev": nfevs}
    ).to_pickle(result_dir / "kuramoto_train2.pkl")

In [5]:
# Evaluate the experiment tagged th6, i.e. threshold = pi / 6.
evaluate(threshold=6)