In [None]:
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from torch.utils.data import DataLoader

from common.dataset import TimeSeriesDataset
from models.lstm_seq2seq import VDEncoderDecoder

sns.set_style('whitegrid')
matplotlib.rcParams.update({'font.size': 20})

In [None]:
# Reproducibility: seed NumPy (the noise below) and torch (presumably model weight
# init / dropout masks in later cells) so Restart & Run All gives the same figures.
SEED = 0
rng = np.random.default_rng(SEED)
torch.manual_seed(SEED)

# Synthetic task: sin(x) plus Gaussian noise, split in time so the test segment
# lies strictly after the training segment.
NOISE_STD = 0.2  # same spread as the former `2 * randn * 0.1`
x_train = np.arange(0, 30, 0.05)
x_test = np.arange(30, 60, 0.05)
y_train = np.sin(x_train) + rng.normal(0.0, NOISE_STD, size=x_train.size)
y_test = np.sin(x_test) + rng.normal(0.0, NOISE_STD, size=x_test.size)

# Scale both splits by the *training* maximum so no test statistics leak into training.
train_max = np.max(y_train)
y_train_scaled = y_train / train_max
y_test_scaled = y_test / train_max

fig, ax = plt.subplots()
ax.plot(x_train, y_train_scaled, label='Train')
ax.plot(x_test, y_test_scaled, label='Test')
ax.set(title='Noisy sine wave (scaled)', xlabel='x', ylabel='y / max(y_train)')
ax.legend();

In [None]:
# Windowed training set built from the scaled series reshaped to one column, shape (n_steps, 1).
# NOTE(review): the two 100s are presumably the input-window length and the forecast
# horizon — confirm against TimeSeriesDataset and keep them in sync with the test set.
train_dataset = TimeSeriesDataset(y_train_scaled.reshape(-1, 1), 100, 100)
# Shuffle so a mini-batch is not made only of neighbouring windows. Those windows are
# presumably overlapping slices of the same series, so unshuffled batches would be highly
# correlated, which hurts SGD.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

In [None]:
# Construct the encoder-decoder forecaster. The "VD" prefix presumably means variational dropout — TODO confirm.
# NOTE(review): the positional args look like (n_features=1, hidden_size=32,
# output_len=100, dropout=0.1, learning_rate=0.001) — confirm against the
# VDEncoderDecoder signature and switch to keyword arguments once verified.
model = VDEncoderDecoder(1, 32, 100, 0.1, 0.001)

In [None]:
# Fit the model on the windowed training data.
# NOTE(review): 100 is presumably the number of epochs — confirm against
# VDEncoderDecoder.learn. Consider %%time here so readers know the re-run cost.
model.learn(train_loader, 100)

In [None]:
# Held-out windows, built the same way as the training set from the series
# scaled by the training maximum, reshaped to a single column (n_steps, 1).
test_dataset = TimeSeriesDataset(
    y_test_scaled.reshape(-1, 1),
    100,
    100,
)
# One window per batch, so a single window can be forecast and plotted on its own.
test_loader = DataLoader(test_dataset, batch_size=1)

In [None]:
# Take only the first batch of the held-out loader (one window, since batch_size=1)
# as the worked example: `x` is the model input, `y` the ground-truth continuation.
x, y = next(iter(test_loader))

In [None]:
# Model (epistemic) uncertainty via Monte Carlo sampling: run the stochastic forward
# pass many times on the same input and measure the spread of the forecasts per step.
# NOTE(review): assumes model.predict keeps dropout active at inference — TODO confirm;
# if predict is deterministic, this variance is identically zero.
N_MC_SAMPLES = 100
mc_preds = np.stack([np.squeeze(model.predict(x)) for _ in range(N_MC_SAMPLES)])
mean_preds = mc_preds.mean(axis=0)
# Axis 0 indexes the MC draws (not time), so this is a per-step variance across samples,
# not a scalar computed from a single draw.
model_unc = np.mean((mc_preds - mean_preds) ** 2, axis=0)

In [None]:
# Aleatoric (noise) uncertainty: mean squared residual of the model, averaged over
# all training windows (the batch axis), used as an estimate of the noise variance.
# TODO(review): residuals on the *training* set tend to underestimate the noise;
# a held-out validation split would give a less biased estimate.
# NOTE(review): predict is called once; if it samples dropout, this estimate also
# absorbs some model variance.
full_train_loader = DataLoader(train_dataset, batch_size=len(train_dataset))
# Distinct names so the raw x_train / y_train arrays from the data cell are not
# silently replaced by windowed tensors.
x_train_windows, y_train_windows = next(iter(full_train_loader))
y_hat_train = model.predict(x_train_windows)
squared_errors = (y_hat_train - y_train_windows.numpy()) ** 2
aleatoric_unc = np.squeeze(np.mean(squared_errors, axis=0))

In [None]:
# Predictive standard deviation. The noise and model variance estimates are summed,
# which treats the two sources as independent, and the square root is taken.
total_unc_std = np.sqrt(aleatoric_unc + model_unc)

In [None]:
# Forecast for the first test window with a ±1.96σ band (≈95% under a Gaussian
# assumption). σ combines the model and aleatoric variance estimates.
Z_95 = 1.96
# New names instead of reassigning `x`: the model-uncertainty cell passes `x` to
# model.predict, so squeezing it in place would break a re-run of that cell.
x_hist = np.squeeze(x)
y_true = np.squeeze(y)
horizon = np.arange(len(mean_preds))

fig, ax = plt.subplots(figsize=(15, 5))
ax.plot(np.arange(-len(x_hist), 0), x_hist, label="Input")
ax.plot(horizon, mean_preds, label='Forecast')
ax.plot(np.arange(len(y_true)), y_true, label='Ground-truth', linestyle='--')
ax.fill_between(
    horizon,
    mean_preds - Z_95 * total_unc_std,
    mean_preds + Z_95 * total_unc_std,
    alpha=0.3,
    color='tab:orange',
    label='95% interval',
)
ax.axvline(0, color='grey', linewidth=1)  # forecast origin
ax.set(
    title='Forecast with uncertainty band',
    xlabel='Time step (0 = forecast start)',
    ylabel='Scaled value',
)
ax.legend();