In [1]:

import torch
import argparse
from pathlib import Path
import numpy as np
from src.model_utils import build_model
import src.datasets as datasets
import src.evaluation as evaluation
from src.torch_utils import torch2numpy
from reevaluate import get_test_dataset
from tqdm import tqdm


  "Using `json`-module for json-handling. "


In [2]:
def get_quantile_forecast(pred, quantile):
    # shape is [num_samples, num_series, prediction_length]
    # return the quantile of the samples for each series
    return np.quantile(pred, quantile, axis=0)


In [26]:
ckpt_file = "./Checkpoints/bouncing_ball.pt"
ckpt = torch.load(ckpt_file, map_location="cpu")

config = ckpt["config"]
model = build_model(config=config)
model.load_state_dict(ckpt["model"])

test_dataset = get_test_dataset(config)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=80)
extra_args = dict()
extra_args = {"dur_temperature": 1.0}
device = 'cpu'

In [47]:
def inference(ckpt_file, dataset_path):

    ckpt = torch.load(ckpt_file, map_location="cpu")

    config = ckpt["config"]
    model = build_model(config=config)
    model.load_state_dict(ckpt["model"])

    #test_dataset = get_test_dataset(config)
    test_dataset = datasets.BouncingBallDataset(path=dataset_path)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=20)

    #do predictions autoregressively. 
    context_length = config['context_length']
    prediction_length = config['prediction_length']
    ground_truth = []
    all_mean = []
    # calculate coverage and width of prediction intervals
    for test_batch, test_label in tqdm(test_loader):
        true = test_batch[:, -prediction_length:]
        autoregressive_mean = []
        for t in range(prediction_length):
            pred = model.predict(test_batch[:,t:t+context_length], num_samples=1, pred_one_step=True)
            pred_y = pred['rec_n_forecast'][:,:,-1]
            autoregressive_mean.append(pred_y)
        ground_truth.append(true)
        all_mean.append(np.concatenate(autoregressive_mean, -1).squeeze())

    ground_truth = np.concatenate(ground_truth, 0)
    all_mean = np.concatenate(all_mean, 0)

    return ground_truth, all_mean


In [48]:
ckpt_bb = "./Checkpoints/bouncing_ball.pt"
ckpt_bbnoisy = "./Checkpoints/bb_noisy.pt"


In [49]:
ground_truth, all_mean = inference(ckpt_bb, "./data/bouncing_ball_calibration.npz")
np.savez("./results/bouncing_ball_calibration.npz", ground_truth=ground_truth, mean=all_mean)


  indices = indices // dim
100%|██████████| 150/150 [08:30<00:00,  3.40s/it]


In [72]:
ground_truth, all_mean = inference(ckpt_bbnoisy, "./data/bouncing_ball_noisy_calibration.npz")
np.savez("./results/bouncing_ball_noisy_calibration.npz", ground_truth=ground_truth, mean=all_mean)


  indices = indices // dim
100%|██████████| 150/150 [08:17<00:00,  3.32s/it]


In [50]:
#calculate errors on the calibration dataset

preds = np.load("./results/bouncing_ball_calibration.npz", allow_pickle=True)
ground_truth = preds["ground_truth"]
mean = preds["mean"]


In [52]:
mses = (ground_truth.squeeze() - mean) ** 2
mses.shape

(3000, 50)

In [56]:
# vanilla cp
nonconformity = mses.flatten()

# get quantiles of nonconformity scores
alpha = 0.1
thresh = np.quantile(nonconformity, 1-alpha)


In [57]:
# inference on test set
ground_truth, all_mean = inference(ckpt_bb, "./data/bouncing_ball_test.npz")
np.savez("./results/bouncing_ball_test.npz", ground_truth=ground_truth, mean=all_mean)


  indices = indices // dim
100%|██████████| 50/50 [02:39<00:00,  3.18s/it]


In [66]:
test_mses = (ground_truth.squeeze() - all_mean) ** 2
test_nonconformity = test_mses.flatten()
coverage = np.mean(test_nonconformity < thresh)


In [67]:
coverage, thresh

(0.90404, 0.23606456667184836)

## Noisey

In [73]:
#calculate errors on the calibration dataset

preds = np.load("./results/bouncing_ball_noisy_calibration.npz", allow_pickle=True)
ground_truth = preds["ground_truth"]
mean = preds["mean"]


In [74]:
mean[0].shape

(50,)

In [75]:
mses = (ground_truth.squeeze() - mean) ** 2
mses.shape

(3000, 50)

In [76]:
# vanilla cp
nonconformity = mses.flatten()

# get quantiles of nonconformity scores
alpha = 0.1
thresh = np.quantile(nonconformity, 1-alpha)


In [77]:
# inference on test set
ground_truth, all_mean = inference(ckpt_bbnoisy, "./data/bouncing_ball_noisy_test.npz")
np.savez("./results/bouncing_ball_noisy_test.npz", ground_truth=ground_truth, mean=all_mean)


  0%|          | 0/50 [00:00<?, ?it/s]

100%|██████████| 50/50 [02:40<00:00,  3.20s/it]


In [78]:
test_mses = (ground_truth.squeeze() - all_mean) ** 2
test_nonconformity = test_mses.flatten()
coverage = np.mean(test_nonconformity < thresh)


In [79]:
coverage, thresh

(0.90096, 0.18580201715230943)

## Forcasting

In [None]:

pred_segs = []
true_segs = []
for test_batch, test_label in tqdm(test_loader):
    test_batch = test_batch.to(device)
    pred = model.predict(test_batch, num_samples=100)
    true = test_batch[:, config['context_length']:]
    pred_segs.append(torch2numpy(torch.argmax(pred['z_emp_probs'], dim=-1)))
    true_segs.append(torch2numpy(test_label[:, config['context_length']:]))


In [106]:
seg_error = evaluation.evaluate_segmentation(np.concatenate(true_segs, 0), np.concatenate(pred_segs, 0), K=config["num_categories"])

seg_error

{'nmi_score': 0.10380844559277072,
 'ari_score': 0.140397887557092,
 'accuracy': 0.68736,
 'f1_score': 0.6873676161666078}

In [84]:
np.mean(mses)

0.08520873

In [97]:
pred_segs

tensor([[0, 0, 0,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1]])

## Calibration and conformal prediction