In [None]:
import torch

from add_thin.metrics import forecast_wasserstein
from add_thin.evaluate_utils import get_task, get_run_data

In [None]:
# Set run id and paths
RUN_ID = "id"

WANDB_DIR = "path/to/wandb/logging/directory"
PROJECT_ROOT = "path/to/project"  # should include data folder

# Forecast sampling below is stochastic; seed torch's RNG (all devices) so that
# Restart & Run All reproduces the same Wasserstein/MAPE numbers.
SAMPLING_SEED = 0
torch.manual_seed(SAMPLING_SEED);

In [None]:
def sample_model(task, datamodule):
    """
    Sample forecasts from model.
    """
    samples = []
    targets = []
    mapes = []
    with torch.no_grad():
        for batch in datamodule.test_dataloader():
            batch.to(task.device)
            # Sample 50 forecasts
            for _ in range(10):
                # Set history
                future, tmax, tmin = task.set_history(
                    batch.concat(batch, batch, batch, batch)
                )  # Note that we are using the same batch 5 times to get 5 different histories

                # Sample forecasts from model
                sample = task.model.sample(
                    len(future),
                    tmax=future.tmax,
                )

                # Rescale and shift to right forecast window
                sample.time = (sample.time / future.tmax) * (tmax - tmin)[
                    :, None
                ] + tmin[:, None]

                # Calculate Absolute Percentage Error
                mapes.append(
                    (
                        torch.abs(future.mask.sum(-1) - sample.mask.sum(-1))
                        / (future.mask.sum(-1) + 1)
                    )
                    .detach()
                    .cpu()
                )

                samples = samples + sample.to_time_list()
                targets = targets + future.to_time_list()

    return samples, targets, mapes

In [None]:
# Get run data
data_name, seed, run_path = get_run_data(RUN_ID, WANDB_DIR)

# Get task and datamodule
task, datamodule = get_task(run_path, density=False, data_root=PROJECT_ROOT)

# Sample forecasts (sample_model defines how many forecasts per test sequence)
samples, targets, mapes = sample_model(task, datamodule)

# Wasserstein distance between forecasts and targets; datamodule.tmax is passed
# as a plain Python float
wasserstein_distance = forecast_wasserstein(
    samples,
    targets,
    datamodule.tmax.detach().cpu().item(),
)
# Mean absolute percentage error of the event counts over all forecasts,
# converted to a Python float like the Wasserstein value above
MAPE = torch.cat(mapes).mean().item()

# Print rounded results for data and seed
print("ADD and Thin forecast evaluation:")
print("================================")
print(
    f"{data_name} (Seed: {seed}): Wasserstein: {wasserstein_distance:.3f}, MAPE: {MAPE:.3f}"
)