In [14]:
import gluonts
import nnts
import nnts.data
import nnts.experiments.plotting
import nnts.torch.preprocessing
import nnts.torch.models
import nnts.metrics
import nnts.torch.datasets
import nnts.loggers
import nnts.datasets
from nnts import utils
import gluonts.time_feature
import nnts.torch.utils
import nnts.torch.trainers
import nnts.metrics

import torch
torch.set_printoptions(precision=8, sci_mode=False)
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Load the dataset 
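Any of the Monash forecasting datasets can be loaded by name with `nnts.datasets.load_dataset`, which returns the data frame together with its metadata (frequency, context and prediction lengths, seasonality).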

In [15]:
# Name of the Monash dataset to load; change it to run the notebook on another dataset.
DATASET_NAME = "tourism_monthly"
df, metadata = nnts.datasets.load_dataset(DATASET_NAME)

### Set the Hyperparameters
You need to provide an optimizer and a loss function; all other hyperparameters are optional.

In [16]:
# Required settings: the optimizer class and the loss function.
# Everything after them is optional tuning.
params = utils.Hyperparams(
    optimizer=torch.optim.Adam,
    loss_fn=nnts.torch.models.deepar.distr_nll,  # presumably the NLL of the predicted distribution
    training_method=utils.TrainingMethod.TEACHER_FORCING,
    batch_size=32,
    batches_per_epoch=50,
    model_file_path="logs",  # relative path; presumably where the trained model is saved — confirm
)

### Create covariate features
For a DeepAR model, we recommend using lag features as covariates. The lags below are chosen from the dataset's frequency using GluonTS.

In [17]:
# GluonTS suggests candidate lag indices appropriate for the series frequency.
candidate_lags = gluonts.time_feature.lag.get_lags_for_frequency(metadata.freq)
# Drop lag 1 and shift the remaining lags down by one.
# NOTE(review): presumably nnts lag offsets are counted from the previous step,
# which the model already receives as input — confirm against LagProcessor.
lag_seq = [lag - 1 for lag in candidate_lags if lag > 1]
lag_processor = nnts.torch.preprocessing.LagProcessor(lag_seq)

### Data Processing
Next, we specify how to split the data, which transformations to apply, and how to batch and sample the data during training.

In [18]:
# The split keeps max(lag_seq) extra steps of history beyond the model's context window.
# This is presumably so that lag features can be built for the earliest context step — TODO confirm.
# `default=0` avoids a ValueError from max() when no lags survive the filter above
# (possible for some frequencies), in which case no extra history is needed.
context_length = metadata.context_length + max(lag_seq, default=0)
dataset_options = {
    # The dataset windows themselves use the model's own context length (no lag padding).
    "context_length": metadata.context_length,
    "prediction_length": metadata.prediction_length,
    "conts": [],  # presumably continuous covariate columns; none are used here
    "lag_seq": lag_seq,
}

trn_dl, test_dl = nnts.torch.utils.create_dataloaders(
    df,
    nnts.datasets.split_test_train_last_horizon,
    context_length,
    metadata.prediction_length,
    Dataset=nnts.torch.datasets.TimeseriesLagsDataset,
    dataset_options=dataset_options,
    Sampler=nnts.torch.datasets.TimeSeriesSampler,
)

### Create the model
Create a PyTorch DeepAR model with a Student's t output distribution.

In [19]:

# Seed torch's global RNG immediately before building the model. This makes weight
# initialisation and dropout reproducible when this cell and the training cell are re-run.
# It also covers batch sampling, assuming the sampler draws from torch's RNG — TODO confirm.
# If nnts also uses NumPy/`random`, seed those too.
SEED = 42
torch.manual_seed(SEED)

net = nnts.torch.models.DistrDeepAR(
    nnts.torch.models.deepar.StudentTHead,  # Student's t output distribution head
    params,
    nnts.torch.preprocessing.masked_mean_abs_scaling,  # target scaling function
    1,  # NOTE(review): unlabelled positional — presumably the number of target/input features; confirm
    lag_processor=lag_processor,
    scaled_features=[],
    context_length=metadata.context_length,
    cat_idx=None,
    seq_cat_idx=None,
)


### Train the model

In [20]:

# Fit the network. `train` returns the evaluator used in the Evaluate section below.
trner = nnts.torch.trainers.TorchEpochTrainer(net, params, metadata)
evaluator = trner.train(trn_dl)

DistrDeepAR(
  (decoder): UnrolledLSTMDecoder(
    (rnn): LSTM(16, 40, num_layers=2, batch_first=True, dropout=0.1)
  )
  (distribution): StudentTHead(
    (main): ModuleList(
      (0-2): 3 x Linear(in_features=40, out_features=1, bias=True)
    )
  )
)
Epoch 1 Train Loss: 9.171257019042969
Epoch 2 Train Loss: 8.696732521057129
Epoch 3 Train Loss: 8.45849323272705
Epoch 4 Train Loss: 8.411606788635254
Epoch 5 Train Loss: 8.359128952026367
Epoch 6 Train Loss: 8.29736614227295
Epoch 7 Train Loss: 8.151010513305664
Epoch 8 Train Loss: 8.030811309814453
Epoch 9 Train Loss: 7.937575817108154
Epoch 10 Train Loss: 7.896162509918213
Epoch 11 Train Loss: 7.840521812438965
Epoch 12 Train Loss: 7.8140106201171875
Epoch 13 Train Loss: 7.816997051239014
Epoch 14 Train Loss: 7.800825119018555
Epoch 15 Train Loss: 7.800720691680908
Epoch 16 Train Loss: 7.823911190032959
Epoch 17 Train Loss: 7.763418197631836
Epoch 18 Train Loss: 7.770650863647461
Epoch 19 Train Loss: 7.765244007110596
Epoch 20 Train

### Evaluate

In [11]:
# Forecast the test windows with the trained model.
y_hat, y = evaluator.evaluate(
    test_dl, metadata.prediction_length, metadata.context_length
)

# The seasonal error (the usual MASE scaling term) is measured on the
# training loader, not on the test data.
seasonal_error = nnts.metrics.calculate_seasonal_error(trn_dl, metadata.seasonality)
test_metrics = nnts.metrics.calc_metrics(y_hat, y, seasonal_error)
test_metrics

{'mse': 53367568.0,
 'abs_error': 16483498.0,
 'abs_target_sum': 166958480.0,
 'abs_target_mean': 19007.11328125,
 'mase': 1.4142117500305176,
 'mape': 0.2124391347169876,
 'smape': 0.18161988258361816,
 'nd': 0.17456156015396118,
 'mae': 1876.5364990234375,
 'rmse': 2358.773193359375,
 'seasonal_error': 1543.2191162109375}

In [22]:
# Plot the forecasts against the actual test values.
forecast_fig = nnts.experiments.plotting.plotly_forecasts_vs_actuals(y, y_hat)
forecast_fig

### Forecast