Skip to content

Commit

Permalink
Add ETT to internal datasets (#134)
Browse files Browse the repository at this point in the history
  • Loading branch information
ostreech1997 committed Nov 9, 2023
1 parent 2b56c37 commit 78a0de1
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 0 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add `traffic_2015` to internal datasets ([#100](https://github.com/etna-team/etna/pull/100))
- Add `tourism` to internal datasets ([#120](https://github.com/etna-team/etna/pull/120))
- Add `weather` to internal datasets ([#125](https://github.com/etna-team/etna/pull/125))
- Add `ETT` to internal datasets ([#134](https://github.com/etna-team/etna/pull/134))
-

### Changed
-
Expand Down
63 changes: 63 additions & 0 deletions etna/datasets/internal_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,49 @@ def get_weather_dataset(dataset_dir: Path) -> None:
df_test.to_csv(dataset_dir / f"weather_10T_test.csv.gz", index=True, compression="gzip")


def get_ett_dataset(dataset_dir: Path, dataset_type: str) -> None:
    """
    Download and save Electricity Transformer Datasets (small version).

    Dataset consists of four parts: ETTh1 (hourly freq), ETTh2 (hourly freq), ETTm1 (15 min freq), ETTm2 (15 min freq).
    This dataset is a collection of two years of data from two regions of a province of China. There are one target
    column ("oil temperature") and six different types of external power load features. We use the last 720 hours as
    prediction horizon.

    Parameters
    ----------
    dataset_dir:
        Directory where the ``{dataset_type}_full/train/test.csv.gz`` files are saved.
    dataset_type:
        Part of the dataset to download, one of ("ETTm1", "ETTm2", "ETTh1", "ETTh2").

    Raises
    ------
    NotImplementedError:
        If ``dataset_type`` is not one of the supported parts.

    References
    ----------
    .. [1] https://github.com/zhouhaoyi/ETDataset
    .. [2] https://arxiv.org/abs/2012.07436
    """
    url = (
        "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/"
        "1d16c8f4f943005d613b5bc962e9eeb06058cf07/ETT-small/{name}.csv"
    )

    # Horizon is 720 hours: 720 rows at hourly freq, 720 * 4 rows at 15-minute freq.
    # Validate before any filesystem/network work so a typo fails fast.
    if dataset_type in ("ETTm1", "ETTm2"):
        test_size = 720 * 4
    elif dataset_type in ("ETTh1", "ETTh2"):
        test_size = 720
    else:
        raise NotImplementedError(
            f"ETT dataset does not have '{dataset_type}' dataset_type. "
            "You can use one from: ('ETTm1', 'ETTm2', 'ETTh1', 'ETTh2')."
        )

    dataset_dir.mkdir(exist_ok=True, parents=True)

    data = pd.read_csv(url.format(name=dataset_type))
    # Bring raw wide data into the long (timestamp, segment, target) format expected by TSDataset.
    data = data.rename({"date": "timestamp"}, axis=1)
    data["timestamp"] = pd.to_datetime(data["timestamp"])
    data = data.melt("timestamp", var_name="segment", value_name="target")

    df_full = TSDataset.to_dataset(data)
    df_test = df_full.tail(test_size)
    df_train = df_full.head(len(df_full) - test_size)

    df_full.to_csv(dataset_dir / f"{dataset_type}_full.csv.gz", index=True, compression="gzip")
    df_train.to_csv(dataset_dir / f"{dataset_type}_train.csv.gz", index=True, compression="gzip")
    df_test.to_csv(dataset_dir / f"{dataset_type}_test.csv.gz", index=True, compression="gzip")


datasets_dict: Dict[str, Dict] = {
"electricity_15T": {
"get_dataset_function": get_electricity_dataset_15t,
Expand Down Expand Up @@ -709,4 +752,24 @@ def get_weather_dataset(dataset_dir: Path) -> None:
"parts": ("train", "test", "full"),
},
"weather_10T": {"get_dataset_function": get_weather_dataset, "freq": "10T", "parts": ("train", "test", "full")},
"ETTm1": {
"get_dataset_function": partial(get_ett_dataset, dataset_type="ETTm1"),
"freq": "15T",
"parts": ("train", "test", "full"),
},
"ETTm2": {
"get_dataset_function": partial(get_ett_dataset, dataset_type="ETTm2"),
"freq": "15T",
"parts": ("train", "test", "full"),
},
"ETTh1": {
"get_dataset_function": partial(get_ett_dataset, dataset_type="ETTh1"),
"freq": "H",
"parts": ("train", "test", "full"),
},
"ETTh2": {
"get_dataset_function": partial(get_ett_dataset, dataset_type="ETTh2"),
"freq": "H",
"parts": ("train", "test", "full"),
},
}
28 changes: 28 additions & 0 deletions tests/test_datasets/test_internal_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,34 @@ def test_not_present_part():
pd.to_datetime("2021-01-01 00:00:00"),
("train", "test"),
),
(
"ETTm1",
(66800 + 2880, 7),
pd.to_datetime("2016-07-01 00:00:00"),
pd.to_datetime("2018-06-26 19:45:00"),
("train", "test"),
),
(
"ETTm2",
(66800 + 2880, 7),
pd.to_datetime("2016-07-01 00:00:00"),
pd.to_datetime("2018-06-26 19:45:00"),
("train", "test"),
),
(
"ETTh1",
(16700 + 720, 7),
pd.to_datetime("2016-07-01 00:00:00"),
pd.to_datetime("2018-06-26 19:00:00"),
("train", "test"),
),
(
"ETTh2",
(16700 + 720, 7),
pd.to_datetime("2016-07-01 00:00:00"),
pd.to_datetime("2018-06-26 19:00:00"),
("train", "test"),
),
],
)
def test_dataset_statistics(dataset_name, expected_shape, expected_min_date, expected_max_date, dataset_parts):
Expand Down

0 comments on commit 78a0de1

Please sign in to comment.