Skip to content

Commit

Permalink
Add size method to TSDataset class (#238)
Browse files Browse the repository at this point in the history
* size_functional

* size_functional

* minor changes

* update changelog and fix codestyle

* final fixes

* fixed changelog
  • Loading branch information
yellowssnake committed Feb 12, 2024
1 parent e978ad0 commit 491de74
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 1 deletion.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add dataset integrity check using hash for internal datasets ([#151](https://github.com/etna-team/etna/pull/151))
- Create page about internal datasets in documentation ([#175](https://github.com/etna-team/etna/pull/175))
- Add usage example of internal datasets in `101-get_started.ipynb` and `305-classification.ipynb` tutorials ([#202](https://github.com/etna-team/etna/pull/202))
-
- Add size method to `TSDataset` class ([#238](https://github.com/etna-team/etna/pull/238))

### Changed
- Add `relevance_aggregation_mode` and `redundancy_aggregation_mode` into `MRMRFeatureSelectionTransform.params_to_tune` ([#212](https://github.com/etna-team/etna/pull/212))
Expand Down
19 changes: 19 additions & 0 deletions etna/datasets/tsdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1651,3 +1651,22 @@ def to_torch_dataset(
ts_samples = [samples for df_segment in ts_segments for samples in make_samples(df_segment)]

return _TorchDataset(ts_samples=ts_samples)

def size(self) -> Tuple[int, int, Optional[int]]:
    """Return size of TSDataset.

    The order of sizes is (length of the dataset index, i.e. number of timestamps,
    number of segments, number of features). The number of features is reported
    only if it is the same in every segment; otherwise it is ``None``.

    Returns
    -------
    :
        Tuple of TSDataset sizes
    """
    # Collect the number of distinct feature names of every segment. A set collapses
    # equal counts, so more than one entry means that the segments disagree.
    feature_counts = {
        len(self.df[segment].columns.get_level_values("feature").unique()) for segment in self.segments
    }
    number_of_features: Optional[int]
    if len(feature_counts) > 1:
        number_of_features = None
    elif feature_counts:
        number_of_features = feature_counts.pop()
    else:
        # No segments to inspect: keep reporting 0 features, as before.
        number_of_features = 0
    return len(self.index), len(self.segments), number_of_features
29 changes: 29 additions & 0 deletions tests/test_datasets/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,35 @@ def test_dataset_segment_conversion_during_init(df_segments_int):
assert np.all(ts.columns.get_level_values("segment") == ["1", "2"])


def test_size_with_diff_number_of_features():
    """Check that ``size`` returns ``None`` as the feature count when it differs between segments."""
    df_main = generate_ar_df(start_time="2023-01-01", periods=30, n_segments=2, freq="D")
    # Exogenous data is generated for one segment only, while the main data has two.
    df_exog = generate_ar_df(start_time="2023-01-01", periods=30, n_segments=1, freq="D")
    df_exog = df_exog.rename({"target": "target_exog"}, axis=1)
    dataset = TSDataset(
        df=TSDataset.to_dataset(df_main),
        df_exog=TSDataset.to_dataset(df_exog),
        freq="D",
    )

    n_timestamps, n_segments, n_features = dataset.size()

    assert n_timestamps == len(df_exog)
    assert n_segments == 2
    assert n_features is None


def test_size_target_only():
    """Check ``size`` on a dataset built from the generated dataframe alone, without exogenous data."""
    segments_count = 3
    df_main = generate_ar_df(start_time="2023-01-01", periods=40, n_segments=segments_count, freq="D")
    dataset = TSDataset(df=TSDataset.to_dataset(df_main), freq="D")

    n_timestamps, n_segments, n_features = dataset.size()

    # The test expects the long dataframe to hold an equal share of rows per segment.
    assert n_timestamps == len(df_main) / segments_count
    assert n_segments == segments_count
    assert n_features == 1


def test_size_simple():
    """Check ``size`` when both segments have the same three features.

    The function name carries the ``test_`` prefix so that pytest's default discovery collects it.
    """
    df_temp = generate_ar_df(start_time="2023-01-01", periods=30, n_segments=2, freq="D")
    df_exog_temp = generate_ar_df(start_time="2023-01-01", periods=30, n_segments=2, freq="D")
    df_exog_temp = df_exog_temp.rename({"target": "target_exog"}, axis=1)
    # Second exogenous column: together with "target" and "target_exog" this gives 3 features.
    df_exog_temp["other_feature"] = 1
    ts_temp = TSDataset(df=TSDataset.to_dataset(df_temp), df_exog=TSDataset.to_dataset(df_exog_temp), freq="D")

    n_timestamps, n_segments, n_features = ts_temp.size()

    assert n_timestamps == len(df_exog_temp) / 2
    assert n_segments == 2
    assert n_features == 3


# Former (never collected) name, kept so that any existing reference keeps working.
simple_test_size_ = test_size_simple


@pytest.mark.xfail
def test_make_future_raise_error_on_diff_endings(ts_diff_endings):
with pytest.raises(ValueError, match="All segments should end at the same timestamp"):
Expand Down

0 comments on commit 491de74

Please sign in to comment.