In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt

from torchts.nn.loss import quantile_loss
from torchts.nn.model import TimeSeriesModel
from torchts.nn.models.seq2seq import Encoder, Decoder, Seq2Seq 

# Generate dataset

In [2]:
# Generate a noisy linear series: y = 2x + 1 + eps, with eps ~ Normal(mean=0, std=2).
# A seeded generator is used so that Restart Kernel -> Run All reproduces the same data.
SEED = 0
rng = np.random.default_rng(SEED)

# x is a float32 column vector of 100 evenly spaced points in [-10, 10].
x = np.linspace(-10, 10, 100).reshape(-1, 1).astype(np.float32)
# Noise is drawn directly with x's shape, so no extra reshape is required.
noise = rng.normal(0, 2, x.shape).astype(np.float32)
y = 2 * x + 1 + noise

plt.plot(x.flatten(), y.flatten())
plt.xlabel("x")
plt.ylabel("y")
plt.title("Noisy linear series")
plt.show()

# Enable uncertainty quantification in LSTM model

In [3]:
class LSTM(TimeSeriesModel):
    """Single-step LSTM regressor: one ``LSTMCell`` followed by a linear head.

    Args:
        input_size: Number of input features given to the LSTM cell.
        output_size: Number of values produced by the linear head.
        optimizer: Passed through unchanged to ``TimeSeriesModel.__init__``.
        hidden_size: Width of the LSTM hidden and cell states.
        batch_size: Fallback batch size used by ``init_hidden`` when it is
            called without an explicit size. ``forward`` does not rely on it.
        **kwargs: Forwarded to ``TimeSeriesModel.__init__``.
    """

    def __init__(self, input_size, output_size, optimizer, hidden_size=8, batch_size=10, **kwargs):
        super().__init__(optimizer, **kwargs)
        self.hidden_size = hidden_size
        self.batch_size = batch_size

        self.lstm = torch.nn.LSTMCell(input_size, hidden_size)
        self.linear = torch.nn.Linear(hidden_size, output_size)

    def init_hidden(self, batch_size=None, device=None, dtype=None):
        """Return zero-filled ``(hidden_state, cell_state)`` tensors.

        Args:
            batch_size: Number of rows in each state tensor. Defaults to the
                ``batch_size`` given to the constructor.
            device: Device for the state tensors (torch default when ``None``).
            dtype: Dtype for the state tensors (torch default when ``None``).

        Returns:
            A tuple of two tensors, each of shape ``(batch_size, hidden_size)``.
        """
        if batch_size is None:
            batch_size = self.batch_size
        shape = (batch_size, self.hidden_size)
        return (torch.zeros(shape, device=device, dtype=dtype),
                torch.zeros(shape, device=device, dtype=dtype))

    def forward(self, x, y=None, batches_seen=None):
        """Run one LSTM step on ``x`` and project the hidden state.

        Args:
            x: Input tensor whose first dimension is the batch.
            y: Unused here; accepted for interface compatibility.
            batches_seen: Unused here; accepted for interface compatibility.

        Returns:
            The linear head applied to the new hidden state.
        """
        # Size the initial state from the actual input so that batches whose
        # length differs from the constructor's batch_size (e.g. a short final
        # batch) work, and keep the state on the same device/dtype as x.
        hc = self.init_hidden(x.size(0), device=x.device, dtype=x.dtype)

        # LSTMCell returns (hidden_state, cell_state); only the former is used.
        hidden, _ = self.lstm(x, hc)
        return self.linear(hidden)

In [4]:
# Model and training configuration.
inputDim = 1
outputDim = 1
optimizer_args = {"lr": 0.01}
# One model is trained per quantile.
quantiles = [0.025, 0.5, 0.975]

batch_size = 10

# Seed torch so the initial weights are the same on every fresh run.
torch.manual_seed(0)

# Pass batch_size explicitly so the models stay consistent with the value used
# for training and inference instead of silently relying on the class default.
models = {
    quantile: LSTM(
        inputDim,
        outputDim,
        torch.optim.Adam,
        batch_size=batch_size,
        criterion=quantile_loss,
        criterion_args={"quantile": quantile},
        optimizer_args=optimizer_args,
    )
    for quantile in quantiles
}

In [5]:
# Fit each quantile model on the full series; the quantile keys are not needed here.
features = torch.from_numpy(x)
targets = torch.from_numpy(y)
for model in models.values():
    model.fit(features, targets, max_epochs=100, batch_size=batch_size)

In [6]:
# Inference: run every quantile model over x in chunks of batch_size,
# collecting the per-batch predictions for each quantile.
pred_chunks = {}
for x_batch in torch.split(torch.from_numpy(x), batch_size):
    for q, model in models.items():
        batch_pred = model.predict(x_batch).detach().numpy()
        pred_chunks.setdefault(q, []).append(batch_pred)

# Stitch the per-batch arrays back together into one array per quantile.
y_preds = {q: np.concatenate(chunks) for q, chunks in pred_chunks.items()}

In [8]:
# Plot one curve per trained quantile, driven by the keys actually present in
# y_preds so the figure stays correct if the quantile list is changed.
for q, y_pred in y_preds.items():
    plt.plot(x.flatten(), y_pred.flatten(), label=f"p={q}")
plt.plot(x.flatten(), y.flatten(), label="y_true")
plt.xlabel("x")
plt.ylabel("y")
plt.title("Quantile predictions vs. observed series")
plt.legend()
plt.show()