DeepAR native implementation #114

Merged (30 commits, Nov 21, 2023)
Changes from 6 commits
329 changes: 329 additions & 0 deletions diff.txt
@@ -0,0 +1,329 @@
diff --git a/etna/models/nn/rnn.py b/etna/models/nn/deepar_new.py
index 49bc37f..34551e6 100644
--- a/etna/models/nn/rnn.py
+++ b/etna/models/nn/deepar_new.py
@@ -1,7 +1,9 @@
+from collections import Counter
from typing import Any
from typing import Dict
from typing import Iterator
from typing import Optional
+from typing import Type

import numpy as np
import pandas as pd
@@ -17,20 +19,42 @@ from etna.models.base import DeepBaseNet
if SETTINGS.torch_required:
import torch
import torch.nn as nn
+ from torch.distributions import NegativeBinomial
+ from torch.distributions import Normal
+ from torch.utils.data.sampler import Sampler


-class RNNBatch(TypedDict):
- """Batch specification for RNN."""
+class DeepARSampler(Sampler):
+ """Select samples with probability proportional to 1 / (number of occurrences of the sample's segment)."""
+
+ def __init__(self, data):
+ self.data = data
+
+ def __iter__(self):
+ segments = [d["segment"] for d in self.data]
+ count_segments = Counter(segments)
+ p = torch.tensor([1 / count_segments[segment] for segment in segments])
+ num_samples = len(self.data) // len(set(segments))  # TODO: is this a reasonable choice?
+ idx = torch.multinomial(p, num_samples=num_samples)
+ return iter(idx)
+
+ def __len__(self):
+ return len(self.data)
+
+
+class DeepARBatchNew(TypedDict):
+ """Batch specification for DeepAR."""

encoder_real: "torch.Tensor"
decoder_real: "torch.Tensor"
encoder_target: "torch.Tensor"
decoder_target: "torch.Tensor"
segment: "torch.Tensor"
+ weight: "torch.Tensor"


-class RNNNet(DeepBaseNet):
- """RNN based Lightning module with LSTM cell."""
+class DeepARNetNew(DeepBaseNet):
+ """DeepAR based Lightning module with LSTM cell."""

def __init__(
self,
@@ -38,10 +62,10 @@ class RNNNet(DeepBaseNet):
num_layers: int,
hidden_size: int,
lr: float,
- loss: "torch.nn.Module",
+ loss: Type["torch.distributions.Distribution"],
optimizer_params: Optional[dict],
) -> None:
- """Init RNN based on LSTM cell.
+ """Init DeepAR.

Parameters
----------
@@ -53,8 +77,6 @@ class RNNNet(DeepBaseNet):
size of the hidden state
lr:
learning rate
- loss:
- loss function
optimizer_params:
parameters for optimizer for Adam optimizer (api reference :py:class:`torch.optim.Adam`)
"""
@@ -63,15 +85,16 @@ class RNNNet(DeepBaseNet):
self.num_layers = num_layers
self.input_size = input_size
self.hidden_size = hidden_size
- self.loss = torch.nn.MSELoss() if loss is None else loss
self.rnn = nn.LSTM(
num_layers=self.num_layers, hidden_size=self.hidden_size, input_size=self.input_size, batch_first=True
)
- self.projection = nn.Linear(in_features=self.hidden_size, out_features=1)
+ self.loc = nn.Linear(in_features=self.hidden_size, out_features=1)
+ self.scale = nn.Linear(in_features=self.hidden_size, out_features=1)
self.lr = lr
self.optimizer_params = {} if optimizer_params is None else optimizer_params
+ self.loss = loss

- def forward(self, x: RNNBatch, *args, **kwargs): # type: ignore
+ def forward(self, x: DeepARBatchNew, *args, **kwargs): # type: ignore
"""Forward pass.

Parameters
@@ -88,22 +111,46 @@ class RNNNet(DeepBaseNet):
decoder_real = x["decoder_real"].float() # (batch_size, decoder_length, input_size)
decoder_target = x["decoder_target"].float() # (batch_size, decoder_length, 1)
decoder_length = decoder_real.shape[1]
- output, (h_n, c_n) = self.rnn(encoder_real)
+ weights = x["weight"]
+ _, (h_n, c_n) = self.rnn(encoder_real)
forecast = torch.zeros_like(decoder_target) # (batch_size, decoder_length, 1)

for i in range(decoder_length - 1):
output, (h_n, c_n) = self.rnn(decoder_real[:, i, None], (h_n, c_n))
- forecast_point = self.projection(output[:, -1]).flatten()
+ distribution_class = self._count_distr_params(output[:, -1], weights)
+ forecast_point = distribution_class.sample().flatten()
forecast[:, i, 0] = forecast_point
- decoder_real[:, i + 1, 0] = forecast_point
+ decoder_real[:, i + 1, 0] = forecast_point  # TODO: could be done with an if instead

# Last point is computed out of the loop because `decoder_real[:, i + 1, 0]` would cause index error
- output, (h_n, c_n) = self.rnn(decoder_real[:, decoder_length - 1, None], (h_n, c_n))
- forecast_point = self.projection(output[:, -1]).flatten()
+ output, (_, _) = self.rnn(decoder_real[:, decoder_length - 1, None], (h_n, c_n))
+ distribution_class = self._count_distr_params(output[:, -1], weights)
+ forecast_point = distribution_class.sample().flatten()
forecast[:, decoder_length - 1, 0] = forecast_point
return forecast

- def step(self, batch: RNNBatch, *args, **kwargs): # type: ignore
+ def _count_distr_params(self, output, weight):
+ if issubclass(self.loss, Normal):
+ loc = self.loc(output)
+ scale = nn.Softplus()(self.scale(output))
+ reshaped = [-1] + [1] * (output.dim() - 1)
+ weight = weight.reshape(reshaped).expand(loc.shape)
+ loc = loc * weight
+ scale = scale * weight.abs()
+ distribution_class = self.loss(loc=loc, scale=scale)
+ elif issubclass(self.loss, NegativeBinomial):
+ mean = nn.Softplus()(self.loc(output))
+ alpha = nn.Softplus()(self.scale(output))
+ reshaped = [-1] + [1] * (output.dim() - 1)
+ weight = weight.reshape(reshaped).expand(mean.shape)
+ total_count = 1 / (torch.sqrt(weight) * alpha)
+ probs = 1 / (torch.sqrt(weight) * alpha * mean + 1)
+ distribution_class = self.loss(total_count=total_count, probs=probs)
+ else:
+ raise NotImplementedError()
+ return distribution_class
+
+ def step(self, batch: DeepARBatchNew, *args, **kwargs): # type: ignore
"""Step for loss computation for training or validation.

Parameters
@@ -121,19 +168,21 @@ class RNNNet(DeepBaseNet):

encoder_target = batch["encoder_target"].float() # (batch_size, encoder_length-1, 1)
decoder_target = batch["decoder_target"].float() # (batch_size, decoder_length, 1)
-
- decoder_length = decoder_real.shape[1]
+ weights = batch["weight"]
+ target = torch.cat((encoder_target, decoder_target), dim=1)

output, (_, _) = self.rnn(torch.cat((encoder_real, decoder_real), dim=1))
-
- target_prediction = output[:, -decoder_length:]
- target_prediction = self.projection(target_prediction) # (batch_size, decoder_length, 1)
-
- loss = self.loss(target_prediction, decoder_target)
- return loss, decoder_target, target_prediction
+ distribution_class = self._count_distr_params(output, weights)
+ target_prediction = distribution_class.sample()
+ loss = distribution_class.log_prob(target).sum()
+ return -loss, target, target_prediction

def make_samples(self, df: pd.DataFrame, encoder_length: int, decoder_length: int) -> Iterator[dict]:
"""Make samples from segment DataFrame."""
+ segment = df["segment"].values[0]
+ values_target = df["target"].values
+ weight = df["target"].mean()
+ df["target"] = df["target"] / weight
values_real = (
df.select_dtypes(include=[np.number])
.assign(target_shifted=df["target"].shift(1))
@@ -141,8 +190,6 @@ class RNNNet(DeepBaseNet):
.pipe(lambda x: x[["target_shifted"] + [i for i in x.columns if i != "target_shifted"]])
.values
)
- values_target = df["target"].values
- segment = df["segment"].values[0]

def _make(
values_real: np.ndarray,
@@ -151,6 +198,7 @@ class RNNNet(DeepBaseNet):
start_idx: int,
encoder_length: int,
decoder_length: int,
+ weight: float,
) -> Optional[dict]:

sample: Dict[str, Any] = {
@@ -159,29 +207,48 @@ class RNNNet(DeepBaseNet):
"encoder_target": list(),
"decoder_target": list(),
"segment": None,
+ "weight": None,
}
total_length = len(values_target)
total_sample_length = encoder_length + decoder_length

if total_sample_length + start_idx > total_length:
return None
+ if start_idx < 0:
+ sample["decoder_real"] = values_real[start_idx + encoder_length : start_idx + total_sample_length]

- # Get shifted target and concatenate it with real values features
- sample["decoder_real"] = values_real[start_idx + encoder_length : start_idx + total_sample_length]
+ # Get shifted target and concatenate it with real values features
+ sample["encoder_real"] = values_real[: start_idx + encoder_length]
+ sample["encoder_real"] = sample["encoder_real"][1:]

- # Get shifted target and concatenate it with real values features
- sample["encoder_real"] = values_real[start_idx : start_idx + encoder_length]
- sample["encoder_real"] = sample["encoder_real"][1:]
+ target = values_target[: start_idx + total_sample_length].reshape(-1, 1)
+ sample["encoder_target"] = target[1 : start_idx + encoder_length]
+ sample["decoder_target"] = target[start_idx + encoder_length :]

- target = values_target[start_idx : start_idx + encoder_length + decoder_length].reshape(-1, 1)
- sample["encoder_target"] = target[1:encoder_length]
- sample["decoder_target"] = target[encoder_length:]
+ sample["encoder_real"] = np.pad(
+ sample["encoder_real"], ((-start_idx, 0), (0, 0)), "constant", constant_values=0
+ )
+ sample["encoder_target"] = np.pad(
+ sample["encoder_target"], ((-start_idx, 0), (0, 0)), "constant", constant_values=0
+ )

- sample["segment"] = segment
+ else:
+ # Get shifted target and concatenate it with real values features
+ sample["decoder_real"] = values_real[start_idx + encoder_length : start_idx + total_sample_length]
+
+ # Get shifted target and concatenate it with real values features
+ sample["encoder_real"] = values_real[start_idx : start_idx + encoder_length]
+ sample["encoder_real"] = sample["encoder_real"][1:]

+ target = values_target[start_idx : start_idx + total_sample_length].reshape(-1, 1)
+ sample["encoder_target"] = target[1:encoder_length]
+ sample["decoder_target"] = target[encoder_length:]
+
+ sample["segment"] = segment
+ sample["weight"] = weight
return sample

- start_idx = 0
+ start_idx = -(encoder_length - 2)  # TODO: is this a reasonable choice?
while True:
batch = _make(
values_target=values_target,
@@ -190,6 +257,7 @@ class RNNNet(DeepBaseNet):
start_idx=start_idx,
encoder_length=encoder_length,
decoder_length=decoder_length,
+ weight=weight,
)
if batch is None:
break
@@ -202,8 +270,8 @@ class RNNNet(DeepBaseNet):
return optimizer


-class RNNModel(DeepBaseModel):
- """RNN based model on LSTM cell.
+class DeepARModelNew(DeepBaseModel):
+ """DeepAR based model on LSTM cell.

Note
----
@@ -219,7 +287,7 @@ class RNNModel(DeepBaseModel):
num_layers: int = 2,
hidden_size: int = 16,
lr: float = 1e-3,
- loss: Optional["torch.nn.Module"] = None,
+ loss: Optional[Any] = None,
train_batch_size: int = 16,
test_batch_size: int = 16,
optimizer_params: Optional[dict] = None,
@@ -245,8 +313,6 @@ class RNNModel(DeepBaseModel):
size of the hidden state
lr:
learning rate
- loss:
- loss function, MSELoss by default
train_batch_size:
batch size for training
test_batch_size:
@@ -273,22 +339,25 @@ class RNNModel(DeepBaseModel):
self.num_layers = num_layers
self.hidden_size = hidden_size
self.lr = lr
- self.loss = loss
self.optimizer_params = optimizer_params
+ self.loss = loss
+ self.train_dataloader_params = (
+ train_dataloader_params if train_dataloader_params is not None else {"sampler": DeepARSampler}
+ )
super().__init__(
- net=RNNNet(
+ net=DeepARNetNew(
input_size=input_size,
num_layers=num_layers,
hidden_size=hidden_size,
lr=lr,
- loss=nn.MSELoss() if loss is None else loss,
optimizer_params=optimizer_params,
+ loss=Normal if loss is None else loss,
),
decoder_length=decoder_length,
encoder_length=encoder_length,
train_batch_size=train_batch_size,
test_batch_size=test_batch_size,
- train_dataloader_params=train_dataloader_params,
+ train_dataloader_params=self.train_dataloader_params, # TODO fix
test_dataloader_params=test_dataloader_params,
val_dataloader_params=val_dataloader_params,
trainer_params=trainer_params,
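
For readers skimming the diff above: the new `_count_distr_params` head replaces the single `projection` layer with two linear heads (`self.loc`, `self.scale`) whose outputs parameterize a `Normal` or `NegativeBinomial` distribution, rescaled by the per-segment `weight` (the segment mean used for normalization in `make_samples`). Below is a minimal, self-contained sketch of that parameterization outside the Lightning module; the names `loc_layer`, `scale_layer`, `hidden` and `weight` are illustrative, not the PR's attributes.

```python
# Illustrative sketch of the per-segment rescaling in `_count_distr_params`;
# the layer and tensor names here are made up for the example.
import torch
import torch.nn as nn
from torch.distributions import NegativeBinomial, Normal

hidden_size, batch_size = 16, 4
loc_layer = nn.Linear(hidden_size, 1)    # stands in for `self.loc`
scale_layer = nn.Linear(hidden_size, 1)  # stands in for `self.scale`
hidden = torch.randn(batch_size, hidden_size)  # last LSTM output per series
weight = torch.rand(batch_size) + 0.5          # per-segment mean of the raw target

softplus = nn.Softplus()
w = weight.reshape(-1, 1)  # broadcast the segment weight over the output dimension

# Normal: predict loc/scale of the normalized series, then scale back by the segment mean.
loc = loc_layer(hidden) * w
scale = softplus(scale_layer(hidden)) * w.abs()
normal = Normal(loc=loc, scale=scale)

# NegativeBinomial: mean/alpha heads converted to total_count/probs,
# again folding the segment weight back in.
mean = softplus(loc_layer(hidden))
alpha = softplus(scale_layer(hidden))
total_count = 1 / (torch.sqrt(w) * alpha)
probs = 1 / (torch.sqrt(w) * alpha * mean + 1)
negbin = NegativeBinomial(total_count=total_count, probs=probs)

print(normal.sample().shape, negbin.sample().shape)  # torch.Size([4, 1]) for both
```
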
13 changes: 11 additions & 2 deletions etna/models/base.py
@@ -593,15 +593,24 @@
                ],
                generator=self.split_params.get("generator"),
            )
+            if "sampler" in self.train_dataloader_params:
+                self.train_dataloader_params["sampler"] = self.train_dataloader_params["sampler"](train_dataset)
+            else:
+                self.train_dataloader_params["shuffle"] = True
            train_dataloader = DataLoader(
-                train_dataset, batch_size=self.train_batch_size, shuffle=True, **self.train_dataloader_params
+                train_dataset, batch_size=self.train_batch_size, **self.train_dataloader_params
            )
            val_dataloader: Optional[DataLoader] = DataLoader(
                val_dataset, batch_size=self.test_batch_size, shuffle=False, **self.val_dataloader_params
            )
        else:
+            if "sampler" in self.train_dataloader_params:
+                self.train_dataloader_params["sampler"] = self.train_dataloader_params["sampler"](torch_dataset)
+            else:
+                self.train_dataloader_params["shuffle"] = True
+
            train_dataloader = DataLoader(
-                torch_dataset, batch_size=self.train_batch_size, shuffle=True, **self.train_dataloader_params
+                torch_dataset, batch_size=self.train_batch_size, **self.train_dataloader_params
            )
            val_dataloader = None

Codecov note: the added sampler-instantiation line (etna/models/base.py#L597) was not covered by tests.
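
The `base.py` change above only wires things together: when `train_dataloader_params` carries a sampler class, it is instantiated with the training dataset and handed to the `DataLoader` (and `shuffle` stays off); otherwise `shuffle=True` is set as before. Here is a small, self-contained sketch of that wiring with a condensed copy of `DeepARSampler` and a toy dataset; the toy data and identity `collate_fn` are illustrative, not from the PR.

```python
# Toy illustration (not from the PR's tests) of how the reworked base.py turns a
# sampler *class* in `train_dataloader_params` into a per-dataset sampler instance.
from collections import Counter

import torch
from torch.utils.data import DataLoader, Sampler


class DeepARSampler(Sampler):
    """Condensed copy of the PR's sampler: draw samples inversely to segment frequency."""

    def __init__(self, data):
        self.data = data

    def __iter__(self):
        segments = [d["segment"] for d in self.data]
        count_segments = Counter(segments)
        p = torch.tensor([1.0 / count_segments[segment] for segment in segments])
        num_samples = len(self.data) // len(set(segments))
        return iter(torch.multinomial(p, num_samples=num_samples).tolist())

    def __len__(self):
        return len(self.data)


# Toy dataset: segment "a" has 3 samples, "b" has 1, so each "b" sample is drawn
# with 3x the per-sample probability of an "a" sample.
dataset = [{"segment": "a", "x": i} for i in range(3)] + [{"segment": "b", "x": 3}]

train_dataloader_params = {"sampler": DeepARSampler}  # a class, as in DeepARModelNew
if "sampler" in train_dataloader_params:
    # base.py instantiates the class with the dataset before building the loader
    train_dataloader_params["sampler"] = train_dataloader_params["sampler"](dataset)
else:
    train_dataloader_params["shuffle"] = True

loader = DataLoader(dataset, batch_size=2, collate_fn=lambda batch: batch, **train_dataloader_params)
for batch in loader:
    print(batch)
```

With segment "a" appearing three times and "b" once, each "b" sample gets three times the per-sample probability of an "a" sample, which is exactly the rebalancing the sampler is meant to provide.
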
1 change: 1 addition & 0 deletions etna/models/nn/__init__.py
@@ -2,6 +2,7 @@

if SETTINGS.torch_required:
    from etna.models.nn.deepar import DeepARModel
+    from etna.models.nn.deepar_new import DeepARModelNew
    from etna.models.nn.deepstate.deepstate import DeepStateModel
    from etna.models.nn.mlp import MLPModel
    from etna.models.nn.nbeats import NBeatsGenericModel
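
Finally, a rough usage sketch of the new model. It is not taken from the PR's tests and assumes `DeepARModelNew` keeps the `RNNModel`-style constructor (`input_size`, `encoder_length`, `decoder_length`, `trainer_params`) and the standard etna `Pipeline` workflow.

```python
# Rough usage sketch (not from the PR): assumes DeepARModelNew mirrors RNNModel's
# constructor and plugs into the usual etna Pipeline flow.
import pandas as pd

from etna.datasets import TSDataset
from etna.models.nn import DeepARModelNew
from etna.pipeline import Pipeline

# toy single-segment series
df = pd.DataFrame(
    {
        "timestamp": pd.date_range("2023-01-01", periods=100, freq="D"),
        "segment": "segment_a",
        "target": range(100),
    }
)
ts = TSDataset(TSDataset.to_dataset(df), freq="D")

model = DeepARModelNew(
    input_size=1,            # only the lagged target is fed to the LSTM here
    encoder_length=14,
    decoder_length=7,
    trainer_params={"max_epochs": 5},
)
pipeline = Pipeline(model=model, horizon=7)
pipeline.fit(ts)
forecast = pipeline.forecast()
```
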