In [1]:
import nnts
import nnts.data
import nnts.plotting
import nnts.torch.preprocessing
import nnts.torch.models
import nnts.metrics
import nnts.torch.datasets
import nnts.loggers
import nnts.datasets
from nnts import utils
import nnts.torch.utils
import nnts.torch.trainers
import nnts.torch
import nnts.torch.models.dlinear
import nnts.lags
import torch
# Print tensors with 8 decimal places and without scientific notation.
torch.set_printoptions(precision=8, sci_mode=False)
# Load IPython's autoreload extension; mode 2 picks up edits to imported
# modules live, so local changes do not require a kernel restart.
%load_ext autoreload
%autoreload 2

In [10]:
import nnts.torch.models.nlinear

### Load the dataset 
You can easily load any of the Monash datasets using the `load_dataset` function.

In [11]:
# Load the "tourism_monthly" dataset. `metadata` is read by later cells for
# freq, context_length, prediction_length and seasonality.
df, metadata = nnts.datasets.load_dataset("tourism_monthly")

In [12]:
# Start from the library's lags for this frequency, keep only those greater
# than 1, and shift each kept lag down by one.
lag_seq = [
    lag_value - 1
    for lag_value in nnts.lags.get_lags_for_frequency(metadata.freq)
    if lag_value > 1
]


In [13]:
# Extend the context window by the largest lag (presumably so the longest
# lagged value fits inside the window -- confirm).
# NOTE(review): `metadata` is updated in place, so running this cell a second
# time adds max(lag_seq) again; re-run the dataset-loading cell first.
metadata.context_length = metadata.context_length + max(lag_seq)
metadata.context_length

51

### Set the Hyperparameters
You will need to provide an optimizer and a loss function. Other hyperparameters are optional.

In [14]:
# Collect the settings for this run in one place, then build the Hyperparams
# object from them. The optimizer is passed as a class, the loss as an instance.
hyperparam_settings = dict(
    optimizer=torch.optim.Adam,
    loss_fn=torch.nn.L1Loss(),
    batch_size=32,
    batches_per_epoch=50,
    training_method=utils.TrainingMethod.TEACHER_FORCING,
    model_file_path="logs",
)
params = utils.Hyperparams(**hyperparam_settings)

### Data Processing
We specify how we want to split the data, what transformations we want to apply and how we want to batch and sample the data during training.  

### Create, Train and Evaluate the model
Build the data loaders, create a PyTorch model, then train it and evaluate it on the test set.

In [15]:

# Seed before anything else is built so the run starts from a fixed state.
nnts.torch.utils.seed_everything(42)

# Options handed to the dataset class via create_dataloaders.
# "conts" is left empty (presumably: no extra continuous features -- confirm).
dataset_options = dict(
    context_length=metadata.context_length,
    prediction_length=metadata.prediction_length,
    conts=[],
)

# Split the data with split_test_train_last_horizon and wrap both parts in loaders.
trn_dl, test_dl = nnts.torch.utils.create_dataloaders(
    df,
    nnts.datasets.split_test_train_last_horizon,
    metadata.context_length,
    metadata.prediction_length,
    Dataset=nnts.torch.datasets.TimeseriesDataset,
    dataset_options=dataset_options,
    Sampler=nnts.torch.datasets.TimeSeriesSampler,
)

# Model and trainer; train() hands back the evaluator used below.
net = nnts.torch.models.nlinear.NLinear(metadata)
trner = nnts.torch.trainers.TorchEpochTrainer(net, params, metadata)
evaluator = trner.train(trn_dl)

# Predictions and targets for the test loader.
y_hat, y = evaluator.evaluate(
    test_dl,
    metadata.prediction_length,
    metadata.context_length,
)

# The seasonal error is computed from the training loader and passed to calc_metrics.
seasonal_error = nnts.metrics.calculate_seasonal_error(trn_dl, metadata.seasonality)
test_metrics = nnts.metrics.calc_metrics(y_hat, y, seasonal_error)
test_metrics

NLinear(
  (Linear): ModuleList(
    (0): Linear(in_features=51, out_features=24, bias=True)
  )
)
saving model
Epoch 1 Train Loss: 4612.0126953125
saving model
Epoch 2 Train Loss: 4431.8623046875
saving model
Epoch 3 Train Loss: 3911.188232421875
saving model
Epoch 4 Train Loss: 3755.31787109375
saving model
Epoch 5 Train Loss: 3266.650390625
saving model
Epoch 6 Train Loss: 2979.943115234375
saving model
Epoch 7 Train Loss: 2777.298828125
saving model
Epoch 8 Train Loss: 2751.07958984375
saving model
Epoch 9 Train Loss: 2701.814697265625
saving model
Epoch 10 Train Loss: 2482.814697265625
saving model
Epoch 11 Train Loss: 2432.57177734375
saving model
Epoch 12 Train Loss: 2223.023681640625
saving model
Epoch 13 Train Loss: 2212.460693359375
saving model
Epoch 14 Train Loss: 2083.672607421875
saving model
Epoch 15 Train Loss: 2002.1507568359375
Epoch 16 Train Loss: 2060.4013671875
Epoch 17 Train Loss: 2191.999267578125
Epoch 18 Train Loss: 2088.48583984375
saving model
Epoch 19 Train 

{'mse': 63871748.0,
 'abs_error': 16272740.0,
 'abs_target_sum': 166958480.0,
 'abs_target_mean': 19007.11328125,
 'mase': 1.4483788013458252,
 'mape': 0.229703888297081,
 'smape': 0.19521769881248474,
 'msmape': 0.19519780576229095,
 'nd': 0.17888924479484558,
 'mae': 1852.543212890625,
 'rmse': 2364.0576171875,
 'seasonal_error': 1543.2191162109375}

### Evaluate

Plot the model's forecasts against the actual values from the test set.

In [5]:
# Plot the test-set targets (y) against the model's predictions (y_hat)
# returned by evaluator.evaluate in the training cell.
nnts.plotting.plotly_forecasts_vs_actuals(y, y_hat)

### Forecast

In [28]:
# Rebind `net` to the network held by the evaluator returned from training
# (presumably the saved checkpoint rather than the last-epoch weights -- confirm).
net = evaluator.net

# Wrap the network in a forecaster for producing predictions.
forecaster = nnts.torch.trainers.TorchForecaster(net)

In [29]:
# Forecast over the test loader with the same prediction and context lengths
# that were passed to evaluator.evaluate.
forecast = forecaster.forecast(test_dl, metadata.prediction_length, metadata.context_length)

In [30]:
# Display the forecast tensor. The repr is long; `forecast.shape` is a more
# compact check when only the dimensions are of interest.
forecast

tensor([[[  6641.45605469],
         [  4133.40673828],
         [  3067.53247070],
         ...,
         [  3548.06396484],
         [  4189.00341797],
         [  6918.93017578]],

        [[194433.23437500],
         [159917.76562500],
         [152778.92187500],
         ...,
         [134823.09375000],
         [146866.96875000],
         [201028.07812500]],

        [[ 95321.10937500],
         [ 86995.16406250],
         [103553.07812500],
         ...,
         [133744.62500000],
         [125787.48437500],
         [103492.85937500]],

        ...,

        [[  7711.12939453],
         [ 14208.24902344],
         [ 10858.39062500],
         ...,
         [  7101.75195312],
         [  6812.04296875],
         [  8448.37890625]],

        [[  7796.66650391],
         [  7452.81250000],
         [ 11486.20019531],
         ...,
         [ 14187.00683594],
         [ 10841.85546875],
         [ 10177.20507812]],

        [[  2008.92102051],
         [  1647.43176270],
         [