DeepAR native implementation #114

Merged
merged 30 commits into from
Nov 21, 2023
Changes from 9 commits
309 changes: 309 additions & 0 deletions diff.txt
@@ -0,0 +1,309 @@
diff --git a/etna/models/nn/rnn.py b/etna/models/nn/deepar_new/deepar.py
index 49bc37f..74a40cc 100644
--- a/etna/models/nn/rnn.py
+++ b/etna/models/nn/deepar_new/deepar.py
@@ -13,24 +13,30 @@ from etna.distributions import FloatDistribution
from etna.distributions import IntDistribution
from etna.models.base import DeepBaseModel
from etna.models.base import DeepBaseNet
+from etna.models.nn.deepar_new import GaussianLoss
+from etna.models.nn.deepar_new import NegativeBinomialLoss
+from etna.models.nn.deepar_new import SamplerWrapper

if SETTINGS.torch_required:
import torch
import torch.nn as nn
+ from torch.utils.data.sampler import RandomSampler
+ from torch.utils.data.sampler import WeightedRandomSampler


-class RNNBatch(TypedDict):
- """Batch specification for RNN."""
+class DeepARBatchNew(TypedDict):
+ """Batch specification for DeepAR."""

encoder_real: "torch.Tensor"
decoder_real: "torch.Tensor"
encoder_target: "torch.Tensor"
decoder_target: "torch.Tensor"
segment: "torch.Tensor"
+ weight: "torch.Tensor"


-class RNNNet(DeepBaseNet):
- """RNN based Lightning module with LSTM cell."""
+class DeepARNetNew(DeepBaseNet):
+ """DeepAR based Lightning module with LSTM cell."""

def __init__(
self,
@@ -38,10 +44,12 @@ class RNNNet(DeepBaseNet):
num_layers: int,
hidden_size: int,
lr: float,
+ scale: bool,
+ n_samples: Optional[int], # TODO
loss: "torch.nn.Module",
optimizer_params: Optional[dict],
) -> None:
- """Init RNN based on LSTM cell.
+ """Init DeepAR.

Parameters
----------
@@ -53,8 +61,6 @@ class RNNNet(DeepBaseNet):
size of the hidden state
lr:
learning rate
- loss:
- loss function
optimizer_params:
parameters for optimizer for Adam optimizer (api reference :py:class:`torch.optim.Adam`)
"""
@@ -63,15 +69,18 @@ class RNNNet(DeepBaseNet):
self.num_layers = num_layers
self.input_size = input_size
self.hidden_size = hidden_size
- self.loss = torch.nn.MSELoss() if loss is None else loss
self.rnn = nn.LSTM(
num_layers=self.num_layers, hidden_size=self.hidden_size, input_size=self.input_size, batch_first=True
)
- self.projection = nn.Linear(in_features=self.hidden_size, out_features=1)
+ self.linear_1 = nn.Linear(in_features=self.hidden_size, out_features=1)
+ self.linear_2 = nn.Linear(in_features=self.hidden_size, out_features=1)
self.lr = lr
+ self.scale = scale
+ self.n_samples = n_samples
self.optimizer_params = {} if optimizer_params is None else optimizer_params
+ self.loss = loss

- def forward(self, x: RNNBatch, *args, **kwargs): # type: ignore
+ def forward(self, x: DeepARBatchNew, *args, **kwargs): # type: ignore
"""Forward pass.

Parameters
@@ -88,22 +97,49 @@ class RNNNet(DeepBaseNet):
decoder_real = x["decoder_real"].float() # (batch_size, decoder_length, input_size)
decoder_target = x["decoder_target"].float() # (batch_size, decoder_length, 1)
decoder_length = decoder_real.shape[1]
- output, (h_n, c_n) = self.rnn(encoder_real)
+ weights = x["weight"]
+ _, (h_n, c_n) = self.rnn(encoder_real)
forecast = torch.zeros_like(decoder_target) # (batch_size, decoder_length, 1)

for i in range(decoder_length - 1):
output, (h_n, c_n) = self.rnn(decoder_real[:, i, None], (h_n, c_n))
- forecast_point = self.projection(output[:, -1]).flatten()
+ distribution_params = self._count_distribution_params(output[:, -1], weights)
+ forecast_point = self.loss.sample(**distribution_params).flatten()
forecast[:, i, 0] = forecast_point
- decoder_real[:, i + 1, 0] = forecast_point
+ decoder_real[:, i + 1, 0] = forecast_point # TODO: could be done with an if

# Last point is computed out of the loop because `decoder_real[:, i + 1, 0]` would cause index error
- output, (h_n, c_n) = self.rnn(decoder_real[:, decoder_length - 1, None], (h_n, c_n))
- forecast_point = self.projection(output[:, -1]).flatten()
+ output, (_, _) = self.rnn(decoder_real[:, decoder_length - 1, None], (h_n, c_n))
+ distribution_params = self._count_distribution_params(output[:, -1], weights)
+ forecast_point = self.loss.sample(**distribution_params).flatten()
forecast[:, decoder_length - 1, 0] = forecast_point
return forecast

- def step(self, batch: RNNBatch, *args, **kwargs): # type: ignore
+ def _count_distribution_params(self, output, weight):
+ if isinstance(self.loss, GaussianLoss):
+ mean = self.linear_1(output)
+ std = nn.Softplus()(self.linear_2(output))
+ if self.scale:
+ reshaped = [-1] + [1] * (output.dim() - 1)
+ weight = weight.reshape(reshaped).expand(mean.shape)
+ mean *= weight
+ std *= weight.abs()
+ params = {"mean": mean, "std": std}
+ elif isinstance(self.loss, NegativeBinomialLoss):
+ mean = nn.Softplus()(self.linear_1(output))
+ alpha = nn.Softplus()(self.linear_2(output))
+ if self.scale:
+ reshaped = [-1] + [1] * (output.dim() - 1)
+ weight = weight.reshape(reshaped).expand(alpha.shape)
+ alpha *= torch.sqrt(torch.tensor(weight))
+ total_count = 1 / alpha
+ probs = 1 / (alpha * mean + 1)
+ params = {"total_count": total_count, "probs": probs}
+ else:
+ raise NotImplementedError()
+ return params
+
+ def step(self, batch: DeepARBatchNew, *args, **kwargs): # type: ignore
"""Step for loss computation for training or validation.

Parameters
@@ -121,19 +157,20 @@ class RNNNet(DeepBaseNet):

encoder_target = batch["encoder_target"].float() # (batch_size, encoder_length-1, 1)
decoder_target = batch["decoder_target"].float() # (batch_size, decoder_length, 1)
-
- decoder_length = decoder_real.shape[1]
+ weights = batch["weight"]
+ target = torch.cat((encoder_target, decoder_target), dim=1)

output, (_, _) = self.rnn(torch.cat((encoder_real, decoder_real), dim=1))
-
- target_prediction = output[:, -decoder_length:]
- target_prediction = self.projection(target_prediction) # (batch_size, decoder_length, 1)
-
- loss = self.loss(target_prediction, decoder_target)
- return loss, decoder_target, target_prediction
+ distribution_params = self._count_distribution_params(output, weights)
+ target_prediction = self.loss.sample(**distribution_params)
+ distribution_params.update({"inputs": target})
+ loss = self.loss(**distribution_params)
+ return loss, target, target_prediction

def make_samples(self, df: pd.DataFrame, encoder_length: int, decoder_length: int) -> Iterator[dict]:
"""Make samples from segment DataFrame."""
+ segment = df["segment"].values[0]
+ values_target = df["target"].values
values_real = (
df.select_dtypes(include=[np.number])
.assign(target_shifted=df["target"].shift(1))
@@ -141,8 +178,6 @@ class RNNNet(DeepBaseNet):
.pipe(lambda x: x[["target_shifted"] + [i for i in x.columns if i != "target_shifted"]])
.values
)
- values_target = df["target"].values
- segment = df["segment"].values[0]

def _make(
values_real: np.ndarray,
@@ -159,12 +194,30 @@ class RNNNet(DeepBaseNet):
"encoder_target": list(),
"decoder_target": list(),
"segment": None,
+ "weight": None,
}
total_length = len(values_target)
total_sample_length = encoder_length + decoder_length

if total_sample_length + start_idx > total_length:
return None
+ # if start_idx < 0:
+ # sample["decoder_real"] = values_real[start_idx + encoder_length : start_idx + total_sample_length]
+ #
+ # # Get shifted target and concatenate it with real values features
+ # sample["encoder_real"] = values_real[: start_idx + encoder_length]
+ # sample["encoder_real"] = sample["encoder_real"][1:]
+ #
+ # target = values_target[: start_idx + total_sample_length].reshape(-1, 1)
+ # sample["encoder_target"] = target[1 : start_idx + encoder_length]
+ # sample["decoder_target"] = target[start_idx + encoder_length :]
+ #
+ # sample["encoder_real"] = np.pad(
+ # sample["encoder_real"], ((-start_idx, 0), (0, 0)), "constant", constant_values=0
+ # )
+ # sample["encoder_target"] = np.pad(
+ # sample["encoder_target"], ((-start_idx, 0), (0, 0)), "constant", constant_values=0
+ # )

# Get shifted target and concatenate it with real values features
sample["decoder_real"] = values_real[start_idx + encoder_length : start_idx + total_sample_length]
@@ -173,14 +226,16 @@ class RNNNet(DeepBaseNet):
sample["encoder_real"] = values_real[start_idx : start_idx + encoder_length]
sample["encoder_real"] = sample["encoder_real"][1:]

- target = values_target[start_idx : start_idx + encoder_length + decoder_length].reshape(-1, 1)
+ target = values_target[start_idx : start_idx + total_sample_length].reshape(-1, 1)
sample["encoder_target"] = target[1:encoder_length]
sample["decoder_target"] = target[encoder_length:]

sample["segment"] = segment

+ sample["weight"] = 1 + sample["encoder_real"].mean() if self.scale else float("nan")
return sample

+ # start_idx = -(encoder_length - 2) # TODO: is this a good choice?
start_idx = 0
while True:
batch = _make(
@@ -202,8 +257,8 @@ class RNNNet(DeepBaseNet):
return optimizer


-class RNNModel(DeepBaseModel):
- """RNN based model on LSTM cell.
+class DeepARModelNew(DeepBaseModel):
+ """DeepAR based model on LSTM cell.

Note
----
@@ -219,6 +274,8 @@ class RNNModel(DeepBaseModel):
num_layers: int = 2,
hidden_size: int = 16,
lr: float = 1e-3,
+ scale: bool = True,
+ n_samples: Optional[int] = None,
loss: Optional["torch.nn.Module"] = None,
train_batch_size: int = 16,
test_batch_size: int = 16,
@@ -228,6 +285,7 @@ class RNNModel(DeepBaseModel):
test_dataloader_params: Optional[dict] = None,
val_dataloader_params: Optional[dict] = None,
split_params: Optional[dict] = None,
+ sampler: Optional[str] = None,
):
"""Init RNN model based on LSTM cell.

@@ -245,8 +303,6 @@ class RNNModel(DeepBaseModel):
size of the hidden state
lr:
learning rate
- loss:
- loss function, MSELoss by default
train_batch_size:
batch size for training
test_batch_size:
@@ -273,24 +329,38 @@ class RNNModel(DeepBaseModel):
self.num_layers = num_layers
self.hidden_size = hidden_size
self.lr = lr
- self.loss = loss
+ self.scale = scale
+ self.n_samples = n_samples
self.optimizer_params = optimizer_params
+ self.loss = loss
+ self.sampler = sampler
+ if sampler == "weighted":
+ sampler = SamplerWrapper(WeightedRandomSampler)
+ else:
+ sampler = SamplerWrapper(RandomSampler)
+ self.train_dataloader_params = (
+ train_dataloader_params if train_dataloader_params is not None else {"sampler": sampler}
+ )
+ self.val_dataloader_params = val_dataloader_params if val_dataloader_params is not None else {}
+ self.test_dataloader_params = test_dataloader_params if test_dataloader_params is not None else {}
super().__init__(
- net=RNNNet(
+ net=DeepARNetNew(
input_size=input_size,
num_layers=num_layers,
hidden_size=hidden_size,
lr=lr,
- loss=nn.MSELoss() if loss is None else loss,
+ scale=scale,
+ n_samples=n_samples,
optimizer_params=optimizer_params,
+ loss=GaussianLoss() if loss is None else loss,
),
decoder_length=decoder_length,
encoder_length=encoder_length,
train_batch_size=train_batch_size,
test_batch_size=test_batch_size,
- train_dataloader_params=train_dataloader_params,
- test_dataloader_params=test_dataloader_params,
- val_dataloader_params=val_dataloader_params,
+ train_dataloader_params=self.train_dataloader_params,
+ test_dataloader_params=self.test_dataloader_params,
+ val_dataloader_params=self.val_dataloader_params,
trainer_params=trainer_params,
split_params=split_params,
)
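
DeepARNetNew delegates both sampling and loss computation to the loss object: self.loss.sample(**distribution_params) draws a forecast point during autoregressive decoding, and self.loss(**distribution_params) with the extra "inputs" key returns the training loss. GaussianLoss itself is not part of this diff, so the following is only a hypothetical sketch of an interface that would be consistent with those calls:

# Hypothetical sketch of a GaussianLoss matching the calls in DeepARNetNew;
# the real etna.models.nn.deepar_new implementation is not shown in this diff.
import torch
from torch import nn
from torch.distributions import Normal


class GaussianLoss(nn.Module):
    """Gaussian negative log-likelihood with a sampling helper."""

    def forward(self, inputs: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
        # Average negative log-likelihood over batch and time steps.
        return -Normal(loc=mean, scale=std).log_prob(inputs).mean()

    def sample(self, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
        # One draw per distribution, used as the next decoder input.
        return Normal(loc=mean, scale=std).sample()

And a minimal instantiation sketch based on the constructor added here (values are illustrative; input_size, encoder_length and decoder_length are assumed to keep the same meaning as in RNNModel):

from etna.models.nn import DeepARModelNew

model = DeepARModelNew(
    input_size=1,        # only the shifted target as a feature
    encoder_length=14,
    decoder_length=7,
    num_layers=2,
    hidden_size=16,
    lr=1e-3,
    scale=True,          # rescale distribution parameters by the per-sample weight
    loss=None,           # defaults to GaussianLoss()
    sampler="weighted",  # picks SamplerWrapper(WeightedRandomSampler)
    trainer_params={"max_epochs": 10},
)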
18 changes: 10 additions & 8 deletions etna/models/base.py
@@ -527,9 +527,9 @@
self.decoder_length = decoder_length
self.train_batch_size = train_batch_size
self.test_batch_size = test_batch_size
- self.train_dataloader_params = {} if train_dataloader_params is None else train_dataloader_params
- self.test_dataloader_params = {} if test_dataloader_params is None else test_dataloader_params
- self.val_dataloader_params = {} if val_dataloader_params is None else val_dataloader_params
+ self.train_dataloader_params = {"shuffle": True} if train_dataloader_params is None else train_dataloader_params
+ self.test_dataloader_params = {"shuffle": False} if test_dataloader_params is None else test_dataloader_params
+ self.val_dataloader_params = {"shuffle": False} if val_dataloader_params is None else val_dataloader_params

self.trainer_params = {} if trainer_params is None else trainer_params
self.split_params = {} if split_params is None else split_params
self.trainer: Optional[Trainer] = None
@@ -593,15 +593,19 @@
],
generator=self.split_params.get("generator"),
)
if "sampler" in self.train_dataloader_params:
self.train_dataloader_params["sampler"] = self.train_dataloader_params["sampler"](train_dataset)

train_dataloader = DataLoader(
- train_dataset, batch_size=self.train_batch_size, shuffle=True, **self.train_dataloader_params
+ train_dataset, batch_size=self.train_batch_size, **self.train_dataloader_params
)
val_dataloader: Optional[DataLoader] = DataLoader(
val_dataset, batch_size=self.test_batch_size, shuffle=False, **self.val_dataloader_params
)
else:
if "sampler" in self.train_dataloader_params:
self.train_dataloader_params["sampler"] = self.train_dataloader_params["sampler"](torch_dataset)

train_dataloader = DataLoader(
- torch_dataset, batch_size=self.train_batch_size, shuffle=True, **self.train_dataloader_params
+ torch_dataset, batch_size=self.train_batch_size, **self.train_dataloader_params
)
val_dataloader = None

@@ -627,9 +631,7 @@
:
Dictionary with predictions
"""
- test_dataloader = DataLoader(
- torch_dataset, batch_size=self.test_batch_size, shuffle=False, **self.test_dataloader_params
- )
+ test_dataloader = DataLoader(torch_dataset, batch_size=self.test_batch_size, **self.test_dataloader_params)


predictions_dict = dict()
self.net.eval()
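
How the sampler hook above is meant to work: DeepARModelNew stores a SamplerWrapper under the "sampler" key of train_dataloader_params, and the training code here calls it with the dataset (self.train_dataloader_params["sampler"](train_dataset)) to build the actual torch Sampler once the data exists. Since a DataLoader cannot take both shuffle=True and a sampler, shuffling now lives in the default dataloader params instead of being hard-coded. SamplerWrapper itself is not included in this diff; a hypothetical sketch consistent with that usage:

# Hypothetical SamplerWrapper sketch (not part of this diff): defer sampler
# construction until the dataset is available, as the DataLoader setup above expects.
from torch.utils.data import Dataset, Sampler
from torch.utils.data.sampler import WeightedRandomSampler


class SamplerWrapper:
    def __init__(self, sampler_cls, **sampler_kwargs):
        self.sampler_cls = sampler_cls
        self.sampler_kwargs = sampler_kwargs

    def __call__(self, dataset: Dataset) -> Sampler:
        if self.sampler_cls is WeightedRandomSampler:
            # Per-sample weights come from the "weight" key produced by make_samples.
            weights = [float(dataset[i]["weight"]) for i in range(len(dataset))]
            return WeightedRandomSampler(weights=weights, num_samples=len(weights), **self.sampler_kwargs)
        # RandomSampler and similar samplers only need the data source.
        return self.sampler_cls(dataset, **self.sampler_kwargs)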
1 change: 1 addition & 0 deletions etna/models/nn/__init__.py
@@ -2,6 +2,7 @@

if SETTINGS.torch_required:
from etna.models.nn.deepar import DeepARModel
+ from etna.models.nn.deepar_new import DeepARModelNew

from etna.models.nn.deepstate.deepstate import DeepStateModel
from etna.models.nn.mlp import MLPModel
from etna.models.nn.nbeats import NBeatsGenericModel
4 changes: 4 additions & 0 deletions etna/models/nn/deepar_new/__init__.py
@@ -0,0 +1,4 @@
from etna import SETTINGS


if SETTINGS.torch_required:
from etna.models.nn.deepar_new.deepar import DeepARModelNew
