Skip to content

Commit

Permalink
Fix PatchTSModel failing when using additional features (#376)
Browse files Browse the repository at this point in the history
* fit bug

* chore: update changelog

---------

Co-authored-by: Egor Baturin <egoriyaa@github.com>
  • Loading branch information
egoriyaa and Egor Baturin committed May 31, 2024
1 parent d08e6fa commit b1f6877
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 37 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fix FordA download url in classification notebook ([#309](https://github.com/etna-team/etna/pull/309))
- Allow `seaborn` dependency to have higher version ([#319](https://github.com/etna-team/etna/pull/319))
- Fix `MRMRFeatureSelectionTransform` to correctly handle less-is-better `relevance_table` ([#308](https://github.com/etna-team/etna/issues/308))
-
- Fix `PatchTSModel` failing when using additional features ([#376](https://github.com/etna-team/etna/issues/376))
-
-
- Fix `101-get-started` notebook to be rendered correctly ([#340](https://github.com/etna-team/etna/pull/340))
Expand Down
48 changes: 17 additions & 31 deletions etna/models/nn/patchts.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
class PatchTSBatch(TypedDict):
"""Batch specification for PatchTS."""

encoder_real: "torch.Tensor"
decoder_real: "torch.Tensor"
encoder_target: "torch.Tensor"
decoder_target: "torch.Tensor"
segment: "torch.Tensor"
Expand All @@ -46,10 +44,10 @@ def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
self.register_buffer("pe", pe)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""x: Tensor, shape [batch_size, input_size, patch_num, embedding_dim]."""
"""x: Tensor, shape [batch_size, 1, patch_num, embedding_dim]."""
x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))
# x.shape == (batch_size * input_size, patch_num, embedding_dim)
x = x.permute(1, 0, 2) # (patch_num, batch_size * input_size, embedding_dim)
# x.shape == (batch_size * 1, patch_num, embedding_dim)
x = x.permute(1, 0, 2) # (patch_num, batch_size * 1, embedding_dim)
x = x + self.pe[: x.size(0)] # type: ignore
return self.dropout(x)

Expand Down Expand Up @@ -135,11 +133,10 @@ def forward(self, x: PatchTSBatch, *args, **kwargs): # type: ignore
:
forecast with shape (batch_size, decoder_length, 1)
"""
encoder_real = x["encoder_real"].float() # (batch_size, encoder_length, input_size)
decoder_real = x["decoder_real"].float() # (batch_size, decoder_length, input_size)
decoder_length = decoder_real.shape[1]
encoder_target = x["encoder_target"].float() # (batch_size, encoder_length, 1)
decoder_length = x["decoder_target"].shape[1]
outputs = []
current_input = encoder_real
current_input = encoder_target
for _ in range(decoder_length):
pred = self._get_prediction(current_input)
outputs.append(pred)
Expand All @@ -151,14 +148,12 @@ def forward(self, x: PatchTSBatch, *args, **kwargs): # type: ignore
return forecast

def _get_prediction(self, x: torch.Tensor) -> torch.Tensor:
x = x.permute(0, 2, 1) # (batch_size, input_size, encoder_length)
x = x.permute(0, 2, 1) # (batch_size, 1, encoder_length)
# do patching
x = x.unfold(
dimension=-1, size=self.patch_len, step=self.stride
) # (batch_size, input_size, patch_num, patch_len)
x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride) # (batch_size, 1, patch_num, patch_len)

y = self.model(x)
y = y.permute(1, 0, 2) # (batch_size, hidden_size, patch_num)
y = self.model(x) # (patch_num, batch_size, hidden_size)
y = y.permute(1, 0, 2) # (batch_size, patch_num, hidden_size)

return self.projection(y) # (batch_size, 1)

Expand All @@ -175,19 +170,16 @@ def step(self, batch: PatchTSBatch, *args, **kwargs): # type: ignore
:
loss, true_target, prediction_target
"""
encoder_real = batch["encoder_real"].float() # (batch_size, encoder_length, input_size)
decoder_real = batch["decoder_real"].float() # (batch_size, decoder_length, input_size)

encoder_target = batch["encoder_target"].float() # (batch_size, encoder_length, 1)
decoder_target = batch["decoder_target"].float() # (batch_size, decoder_length, 1)

decoder_length = decoder_real.shape[1]
decoder_length = decoder_target.shape[1]

outputs = []
x = encoder_real
x = encoder_target
for i in range(decoder_length):
pred = self._get_prediction(x)
outputs.append(pred)
x = torch.cat((x[:, 1:, :], torch.unsqueeze(decoder_real[:, i, :], dim=1)), dim=1)
x = torch.cat((x[:, 1:, :], torch.unsqueeze(decoder_target[:, i, :], dim=1)), dim=1)

target_prediction = torch.cat(outputs, dim=1)
target_prediction = torch.unsqueeze(target_prediction, dim=2)
Expand All @@ -197,12 +189,10 @@ def step(self, batch: PatchTSBatch, *args, **kwargs): # type: ignore

def make_samples(self, df: pd.DataFrame, encoder_length: int, decoder_length: int) -> Iterator[dict]:
"""Make samples from segment DataFrame."""
values_real = df.drop(["segment", "timestamp"], axis=1).select_dtypes(include=[np.number]).values
values_target = df["target"].values
segment = df["segment"].values[0]

def _make(
values_real: np.ndarray,
values_target: np.ndarray,
segment: str,
start_idx: int,
Expand All @@ -211,8 +201,6 @@ def _make(
) -> Optional[dict]:

sample: Dict[str, Any] = {
"encoder_real": list(),
"decoder_real": list(),
"encoder_target": list(),
"decoder_target": list(),
"segment": None,
Expand All @@ -223,9 +211,6 @@ def _make(
if total_sample_length + start_idx > total_length:
return None

sample["decoder_real"] = values_real[start_idx + encoder_length : start_idx + total_sample_length]
sample["encoder_real"] = values_real[start_idx : start_idx + encoder_length]

target = values_target[start_idx : start_idx + encoder_length + decoder_length].reshape(-1, 1)
sample["encoder_target"] = target[:encoder_length]
sample["decoder_target"] = target[encoder_length:]
Expand All @@ -238,7 +223,6 @@ def _make(
while True:
batch = _make(
values_target=values_target,
values_real=values_real,
segment=segment,
start_idx=start_idx,
encoder_length=encoder_length,
Expand All @@ -256,7 +240,9 @@ def configure_optimizers(self) -> "torch.optim.Optimizer":


class PatchTSModel(DeepBaseModel):
"""PatchTS model using PyTorch layers.
"""PatchTS model using PyTorch layers. For more details read the `paper <https://arxiv.org/abs/2211.14730>`_.
The model uses only the `target` column; all other columns are ignored.
Note
----
Expand Down
6 changes: 1 addition & 5 deletions tests/test_models/test_nn/test_patchts.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,11 @@ def test_patchts_make_samples(df_name, request):
num_samples_check = 2
for i in range(num_samples_check):
expected_sample = {
"encoder_real": df[["target", "regressor_float", "regressor_int"]].iloc[i : encoder_length + i].values,
"decoder_real": df[["target", "regressor_float", "regressor_int"]]
.iloc[encoder_length + i : encoder_length + decoder_length + i]
.values,
"encoder_target": df[["target"]].iloc[i : encoder_length + i].values,
"decoder_target": df[["target"]].iloc[encoder_length + i : encoder_length + decoder_length + i].values,
}

assert ts_samples[i].keys() == {"encoder_real", "decoder_real", "encoder_target", "decoder_target", "segment"}
assert ts_samples[i].keys() == {"encoder_target", "decoder_target", "segment"}
assert ts_samples[i]["segment"] == "segment_1"
for key in expected_sample:
np.testing.assert_equal(ts_samples[i][key], expected_sample[key])
Expand Down

0 comments on commit b1f6877

Please sign in to comment.