Skip to content

Commit

Permalink
DeepAR native implementation (#114)
Browse files Browse the repository at this point in the history
  • Loading branch information
egoriyaa committed Nov 21, 2023
1 parent 51e6d92 commit e68ad36
Show file tree
Hide file tree
Showing 15 changed files with 1,574 additions and 339 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Unreleased
### Added
- Add params_to_tune for DeepStateModel ([#115](https://github.com/etna-team/etna/issues/115))
-
-
- Handle new functionality for prediction intervals in the `plot_forecast` ([#130](https://github.com/etna-team/etna/pull/130))
- Add `get_historical_forecasts` to pipelines for forecast estimation at each fold on the historical dataset ([#143](https://github.com/etna-team/etna/pull/143))
-
- Add DeepARNativeModel ([#114](https://github.com/etna-team/etna/pull/114))
-
-
-
-
Expand Down
1 change: 1 addition & 0 deletions docs/source/api_reference/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ Native neural network models:
nn.NBeatsGenericModel
nn.NBeatsInterpretableModel
nn.PatchTSModel
nn.DeepARNativeModel

Utilities for :py:class:`~etna.models.nn.deepstate.deepstate.DeepStateModel`

Expand Down
1 change: 1 addition & 0 deletions etna/models/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

if SETTINGS.torch_required:
from etna.models.nn.deepar import DeepARModel
from etna.models.nn.deepar_native import DeepARNativeModel
from etna.models.nn.deepstate.deepstate import DeepStateModel
from etna.models.nn.mlp import MLPModel
from etna.models.nn.nbeats import NBeatsGenericModel
Expand Down
2 changes: 2 additions & 0 deletions etna/models/nn/deepar.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Sequence

import pandas as pd
from deprecated import deprecated

from etna import SETTINGS
from etna.datasets.tsdataset import TSDataset
Expand All @@ -27,6 +28,7 @@
from pytorch_lightning import Trainer


@deprecated(reason="DeepARModel is deprecated. Use DeepARNativeModel instead.", version="3.0")
class DeepARModel(
_DeepCopyMixin, PytorchForecastingMixin, SavePytorchForecastingMixin, PredictionIntervalContextRequiredAbstractModel
):
Expand Down
4 changes: 4 additions & 0 deletions etna/models/nn/deepar_native/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from etna import SETTINGS

if SETTINGS.torch_required:
from etna.models.nn.deepar_native.deepar import DeepARNativeModel

0 comments on commit e68ad36

Please sign in to comment.