In [1]:
import sys

import pandas as pd
import torch

sys.path.append("..")
from heat.model import HeatModel
from NGS.data import NGSDataset, preprocess
from NGS.ema import EMA
from NGS.experiment import rollout
from NGS.hyperparameter import HyperParameter
from path import DATA_DIR, RESULT_DIR

# Device the model is moved to and that rollouts are run on.
device = torch.device("cuda")
# Dataset pickles (under DATA_DIR) that evaluate() rolls out; results reuse the same names.
file_names = [f"heat_{split}" for split in ("train", "test_int", "test_ext")]

In [2]:
def evaluate(missing: float, noise: float) -> None:
    """Roll out the EMA weights of a trained heat model on every dataset and save the results.

    Loads the experiment stored in ``RESULT_DIR / f"heat_p{missing}_s{noise}"``. It first
    checks that the ``hyperparameter.yaml`` found there matches the requested ``missing``
    and ``noise`` values, then restores the EMA state from ``checkpoint.pth``. For each name
    in the module-level ``file_names`` it rolls the model out on the test part returned by
    ``preprocess`` for ``DATA_DIR / f"{file_name}.pkl"``. The predicted trajectories,
    runtimes and function-evaluation counts are pickled to ``result_dir / f"{file_name}.pkl"``.

    Args:
        missing: Expected value of ``hp.missing`` for the stored experiment.
        noise: Expected value of ``hp.noise`` for the stored experiment.

    Raises:
        ValueError: If the stored hyperparameters disagree with ``missing``/``noise``,
            i.e. the directory name and its contents are inconsistent.
    """
    exp_id = f"heat_p{missing}_s{noise}"
    result_dir = RESULT_DIR / exp_id

    # Check validity of the result directory. Explicit raises (instead of bare `assert`)
    # are not stripped under `python -O` and report which value mismatched.
    hp = HyperParameter.from_yaml(result_dir / "hyperparameter.yaml")
    if hp.missing != missing:
        raise ValueError(
            f"{result_dir}: hyperparameter.yaml has missing={hp.missing}, expected {missing}"
        )
    if hp.noise != noise:
        raise ValueError(
            f"{result_dir}: hyperparameter.yaml has noise={hp.noise}, expected {noise}"
        )

    # Load model. The checkpoint is mapped to CPU first so it loads regardless of the
    # device it was saved from; the model itself lives on `device`.
    # NOTE(review): torch.load unpickles the file, so only use it on trusted checkpoints.
    checkpoint = torch.load(result_dir / "checkpoint.pth", map_location="cpu")
    model = HeatModel(hp.emb_dim, hp.depth).to(device)
    ema = EMA(model)
    ema.load_state_dict(checkpoint["ema"])

    for file_name in file_names:
        # Load data; only the second (test) part returned by preprocess is evaluated.
        # NOTE(review): window=-1 presumably means "use whole trajectories" — TODO confirm.
        data_df = pd.read_pickle(DATA_DIR / f"{file_name}.pkl")
        _, test = preprocess(data_df)
        dataset = NGSDataset(**test, window=-1)

        # Rollout inside the ema() context, presumably with the EMA weights swapped into
        # `model` for the duration — TODO confirm against NGS.ema.EMA.
        with ema():
            pred_trajectories, nfevs, runtimes = rollout(model, dataset, device)

        # Save results next to the checkpoint, one pickle per evaluated dataset.
        pd.DataFrame(
            {"trajectories": pred_trajectories, "runtime": runtimes, "nfev": nfevs}
        ).to_pickle(result_dir / f"{file_name}.pkl")

In [3]:
# Evaluate the stored experiment trained with missing=0.1 and noise=0.001.
evaluate(missing=0.1, noise=0.001)