Add electricity to internal datasets #60

Merged: 15 commits, Aug 29, 2023
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -7,7 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased
### Added
-
- Add `electricity` to internal datasets ([#60](https://github.com/etna-team/etna/pull/60))
-
-
-
1 change: 1 addition & 0 deletions etna/datasets/__init__.py
@@ -4,6 +4,7 @@
from etna.datasets.datasets_generation import generate_hierarchical_df
from etna.datasets.datasets_generation import generate_periodic_df
from etna.datasets.hierarchical_structure import HierarchicalStructure
from etna.datasets.internal_datasets import load_dataset
from etna.datasets.tsdataset import TSDataset
from etna.datasets.utils import duplicate_data
from etna.datasets.utils import set_columns_wide
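
With `load_dataset` exported from `etna.datasets`, the new loader is one import away. A minimal usage sketch (assuming the PR as merged; the first call downloads the data, later calls reuse the local copy under `~/.etna`):

    from etna.datasets import load_dataset

    # First call downloads and caches the dataset; later calls read ~/.etna/electricity/electricity.csv.
    ts = load_dataset("electricity")
    print(ts.df.shape)  # expected (140256, 370), per the statistics test in this PR
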
134 changes: 134 additions & 0 deletions etna/datasets/internal_datasets.py
@@ -0,0 +1,134 @@
import tempfile
import urllib.request
import zipfile
from pathlib import Path
from typing import Dict

import pandas as pd

from etna.datasets.tsdataset import TSDataset

_DOWNLOAD_PATH = Path.home() / ".etna"


def _check_dataset_local(dataset_path: Path) -> bool:
    """
    Check whether the dataset is available locally.

    Parameters
    ----------
    dataset_path:
        path to the dataset file
    """
    return dataset_path.exists()


def _download_dataset_zip(url: str, file_name: str, **kwargs) -> pd.DataFrame:
    """
    Download a zipped CSV file and read it into a dataframe.

    Parameters
    ----------
    url:
        url of the dataset
    file_name:
        name of the csv file inside the zip archive
    **kwargs:
        additional keyword arguments forwarded to ``pd.read_csv``

    Returns
    -------
    result:
        dataframe with data

    Raises
    ------
    Exception:
        any error during downloading, saving, or reading the dataset from the url
    """
    try:
        with tempfile.TemporaryDirectory() as td:
            # Download the archive to a temporary location, extract it, and read the CSV.
            temp_path = Path(td) / "temp.zip"
            urllib.request.urlretrieve(url, temp_path)
            with zipfile.ZipFile(temp_path) as f:
                f.extractall(td)
            df = pd.read_csv(Path(td) / file_name, **kwargs)
    except Exception as err:
        raise Exception(f"Error during downloading and reading dataset. Reason: {repr(err)}") from err
    return df
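
# A minimal usage sketch for reviewers (the url and file name mirror the call
# in get_electricity_dataset below; extra kwargs are passed through to pd.read_csv):
#
#     df = _download_dataset_zip(
#         url="https://archive.ics.uci.edu/static/public/321/electricityloaddiagrams20112014.zip",
#         file_name="LD2011_2014.txt",
#         sep=";",
#         dtype=str,
#     )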


def load_dataset(name: str, download_path: Path = _DOWNLOAD_PATH, rebuild_dataset: bool = False) -> TSDataset:
    """
    Load internal dataset.

    Parameters
    ----------
    name:
        Name of the dataset.
    download_path:
        The path for saving dataset locally.
    rebuild_dataset:
        Whether to rebuild the dataset from the original source. If ``rebuild_dataset=False`` and the dataset was saved
        locally, then it would be loaded from disk. If ``rebuild_dataset=True``, then the dataset will be downloaded and
        saved locally.

    Returns
    -------
    result:
        internal dataset

    Raises
    ------
    NotImplementedError:
        if the dataset name is not in the list of available datasets
    """
    if name not in datasets_dict:
        raise NotImplementedError(f"Dataset {name} is not available. You can use one from: {sorted(datasets_dict)}.")

    dataset_dir = download_path / name
    dataset_path = dataset_dir / f"{name}.csv"

    # Look up the builder function and frequency by key instead of relying on dict value order.
    dataset_params = datasets_dict[name]
    get_dataset_function, freq = dataset_params["get_dataset_function"], dataset_params["freq"]
    if not _check_dataset_local(dataset_path) or rebuild_dataset:
        ts = get_dataset_function(dataset_dir)
    else:
        data = pd.read_csv(dataset_path)
        ts = TSDataset(TSDataset.to_dataset(data), freq=freq)
    return ts


def get_electricity_dataset(dataset_dir) -> TSDataset:
    """
    Download, save, and load the electricity dataset.

    The electricity dataset contains 15-minute time series of electricity consumption (in kW)
    of 370 customers.

    Parameters
    ----------
    dataset_dir:
        The path for saving dataset locally.

    Returns
    -------
    result:
        electricity dataset in TSDataset format

    References
    ----------
    .. [1] https://archive.ics.uci.edu/ml/datasets/ElectricityLoadDiagrams20112014
    """
    url = "https://archive.ics.uci.edu/static/public/321/electricityloaddiagrams20112014.zip"
    dataset_dir.mkdir(exist_ok=True, parents=True)
    data = _download_dataset_zip(url=url, file_name="LD2011_2014.txt", sep=";", dtype=str)
    data = data.rename({"Unnamed: 0": "timestamp"}, axis=1)
    data["timestamp"] = pd.to_datetime(data["timestamp"])
    # The source file uses decimal commas; convert all value columns to floats.
    data.loc[:, data.columns != "timestamp"] = (
        data.loc[:, data.columns != "timestamp"].replace(",", ".", regex=True).astype(float)
    )
    # Reshape to long format (timestamp, segment, target) and cache to disk.
    data = data.melt("timestamp", var_name="segment", value_name="target")
    data.to_csv(dataset_dir / "electricity.csv", index=False)
    ts = TSDataset(TSDataset.to_dataset(data), freq="15T")
    return ts


datasets_dict: Dict[str, Dict] = {"electricity": {"get_dataset_function": get_electricity_dataset, "freq": "15T"}}
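
The `datasets_dict` registry maps a dataset name to its builder function and frequency, so adding a dataset is a single entry plus a builder. A hypothetical sketch of the extension pattern (the `traffic` name and builder below are illustrative, not part of this PR):

    def get_traffic_dataset(dataset_dir) -> TSDataset:
        ...  # download the data, save it as dataset_dir / "traffic.csv", return a TSDataset

    datasets_dict["traffic"] = {"get_dataset_function": get_traffic_dataset, "freq": "H"}
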
55 changes: 55 additions & 0 deletions tests/test_datasets/test_internal_datasets.py
@@ -0,0 +1,55 @@
import shutil

import numpy as np
import pandas as pd
import pytest

from etna.datasets import TSDataset
from etna.datasets import load_dataset
from etna.datasets.internal_datasets import _DOWNLOAD_PATH
from etna.datasets.internal_datasets import datasets_dict


def get_custom_dataset(dataset_dir):
    np.random.seed(1)
    dataset_dir.mkdir(exist_ok=True, parents=True)
    df = pd.DataFrame(np.random.normal(0, 10, size=(30, 5)))
    df["timestamp"] = pd.date_range("2021-01-01", periods=30, freq="D")
    df = df.melt("timestamp", var_name="segment", value_name="target")
    df.to_csv(dataset_dir / "custom_internal_dataset.csv", index=False)
    ts = TSDataset(TSDataset.to_dataset(df), freq="D")
    return ts


def update_dataset_dict(dataset_name, get_dataset_function, freq):
    datasets_dict[dataset_name] = {"get_dataset_function": get_dataset_function, "freq": freq}


def test_not_present_dataset():
    with pytest.raises(NotImplementedError, match="is not available."):
        _ = load_dataset(name="not_implemented_dataset")


def test_load_custom_dataset():
    update_dataset_dict(dataset_name="custom_internal_dataset", get_dataset_function=get_custom_dataset, freq="D")
    dataset_path = _DOWNLOAD_PATH / "custom_internal_dataset"
    if dataset_path.exists():
        shutil.rmtree(dataset_path)
    ts_init = load_dataset("custom_internal_dataset", rebuild_dataset=False)
    ts_local = load_dataset("custom_internal_dataset", rebuild_dataset=False)
    ts_rebuild = load_dataset("custom_internal_dataset", rebuild_dataset=True)
    shutil.rmtree(dataset_path)
    # pd.util.testing is deprecated; use the public pd.testing module instead.
    pd.testing.assert_frame_equal(ts_init.to_pandas(), ts_local.to_pandas())
    pd.testing.assert_frame_equal(ts_init.to_pandas(), ts_rebuild.to_pandas())


@pytest.mark.skip(reason="Dataset is too large for testing in GitHub.")
@pytest.mark.parametrize(
    "dataset_name, expected_shape, expected_min_date, expected_max_date",
    [("electricity", (140256, 370), pd.to_datetime("2011-01-01 00:15:00"), pd.to_datetime("2015-01-01 00:00:00"))],
)
def test_dataset_statistics(dataset_name, expected_shape, expected_min_date, expected_max_date):
    ts = load_dataset(dataset_name)
    assert ts.df.shape == expected_shape
    assert ts.index.min() == expected_min_date
    assert ts.index.max() == expected_max_date
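
Since the statistics test is skipped in CI, a quick local smoke check can exercise the same expectations outside pytest (a sketch; the download is large and goes to `~/.etna` by default):

    from etna.datasets import load_dataset

    ts = load_dataset("electricity")
    assert ts.df.shape == (140256, 370)
    print(ts.index.min(), ts.index.max())  # 2011-01-01 00:15:00 / 2015-01-01 00:00:00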