Merge pull request #122 from etna-team/issue-115
added params_to_tune and test_params_to_tune
Ama16 committed Nov 1, 2023
2 parents 2969754 + 8db0ab9 commit 5ea18ee
Showing 3 changed files with 47 additions and 7 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -7,7 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased
### Added
--
+- Add params_to_tune for DeepStateModel ([#115](https://github.com/etna-team/etna/issues/115))
-
-
-
38 changes: 32 additions & 6 deletions etna/models/nn/deepstate/deepstate.py
@@ -9,6 +9,9 @@
from torch import Tensor
from typing_extensions import TypedDict

from etna.distributions import BaseDistribution
from etna.distributions import FloatDistribution
from etna.distributions import IntDistribution
from etna.models.base import DeepBaseModel
from etna.models.base import DeepBaseNet
from etna.models.nn.deepstate import LDS
@@ -304,14 +307,20 @@ def __init__(
            * **generator**: (*Optional[torch.Generator]*) - generator for reproducible train-test splitting
            * **torch_dataset_size**: (*Optional[int]*) - number of samples in dataset, in case of dataset not implementing ``__len__``
        """
+        self.ssm = ssm
+        self.input_size = input_size
+        self.num_layers = num_layers
+        self.n_samples = n_samples
+        self.lr = lr
+        self.optimizer_params = optimizer_params
        super().__init__(
            net=DeepStateNet(
-                ssm=ssm,
-                input_size=input_size,
-                num_layers=num_layers,
-                n_samples=n_samples,
-                lr=lr,
-                optimizer_params=optimizer_params,
+                ssm=self.ssm,
+                input_size=self.input_size,
+                num_layers=self.num_layers,
+                n_samples=self.n_samples,
+                lr=self.lr,
+                optimizer_params=self.optimizer_params,
            ),
            encoder_length=encoder_length,
            decoder_length=decoder_length,
@@ -323,3 +332,20 @@ def __init__(
            trainer_params=trainer_params,
            split_params=split_params,
        )

    def params_to_tune(self) -> Dict[str, BaseDistribution]:
        """Get default grid for tuning hyperparameters.

        This grid tunes parameters: ``lr``, ``num_layers``, ``encoder_length``.
        Other parameters are expected to be set by the user.

        Returns
        -------
        :
            Grid to tune.
        """
        return {
            "num_layers": IntDistribution(low=1, high=3),
            "lr": FloatDistribution(low=1e-5, high=1e-2, log=True),
            "encoder_length": IntDistribution(low=1, high=20),
        }
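
As a usage sketch (not part of the diff itself): the grid returned by the new method can be inspected directly. The constructor arguments below mirror the test added further down; the import paths for DeepStateModel, CompositeSSM, and WeeklySeasonalitySSM are assumptions about etna's public API rather than something this commit shows.

# Sketch only: build the model as in the new test and print its default tuning grid.
# Import paths are assumed, not taken from this diff.
from etna.models.nn import DeepStateModel
from etna.models.nn.deepstate import CompositeSSM
from etna.models.nn.deepstate import WeeklySeasonalitySSM

model = DeepStateModel(
    ssm=CompositeSSM(seasonal_ssms=[WeeklySeasonalitySSM()], nonseasonal_ssm=None),
    input_size=0,
    encoder_length=14,
    decoder_length=14,
    trainer_params=dict(max_epochs=1),
)

# Expected keys: "num_layers", "lr", "encoder_length", each mapped to an
# IntDistribution or FloatDistribution describing its search space.
for name, distribution in model.params_to_tune().items():
    print(name, distribution)

A tuning utility (for example an optuna study, or etna's own tuning helpers) would sample concrete values from these distributions: num_layers and encoder_length come from integer ranges, while lr is sampled log-uniformly between 1e-5 and 1e-2.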
14 changes: 14 additions & 0 deletions tests/test_models/test_nn/test_deepstate.py
@@ -6,6 +6,7 @@
from etna.models.nn.deepstate import WeeklySeasonalitySSM
from etna.transforms import StandardScalerTransform
from tests.test_models.utils import assert_model_equals_loaded_original
from tests.test_models.utils import assert_sampling_is_valid


@pytest.mark.parametrize(
@@ -55,3 +56,16 @@ def test_save_load(example_tsds):
        trainer_params=dict(max_epochs=1),
    )
    assert_model_equals_loaded_original(model=model, ts=example_tsds, transforms=[], horizon=3)


def test_params_to_tune(example_tsds):
    ts = example_tsds
    model = DeepStateModel(
        ssm=CompositeSSM(seasonal_ssms=[WeeklySeasonalitySSM()], nonseasonal_ssm=None),
        input_size=0,
        encoder_length=14,
        decoder_length=14,
        trainer_params=dict(max_epochs=1),
    )
    assert len(model.params_to_tune()) > 0
    assert_sampling_is_valid(model=model, ts=ts)
