Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add electricity to internal datasets #60

Merged
merged 15 commits into from
Aug 29, 2023
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased
### Added
-
- Add `electricity` to internal datasets ([#60](https://github.com/etna-team/etna/pull/60))
-
-
-
Expand Down
135 changes: 135 additions & 0 deletions etna/datasets/internal_datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import tempfile
import urllib.request
import warnings
import zipfile
from pathlib import Path
from typing import Dict

import pandas as pd

from etna.datasets import TSDataset

_DOWNLOAD_PATH = Path.home() / ".etna"


def _check_dataset_local(dataset_path: Path) -> bool:
"""
Check dataset is local.

Parameters
----------
dataset_path:
path to dataset
"""
return dataset_path.exists()


def _download_dataset_zip(url: str, file_name: str, **kwargs) -> pd.DataFrame:
"""
Download zipped csv file.

Parameters
----------
url:
url of the dataset
file_name:
csv file name in zip archive

Returns
-------
result:
dataframe with data

Raises
------
Exception:
any error during downloading, saving and reading dataset from url
"""
try:
with tempfile.TemporaryDirectory() as td:
temp_path = Path(td) / "temp.zip"
urllib.request.urlretrieve(url, temp_path)
with zipfile.ZipFile(temp_path) as f:
f.extractall(td)
df = pd.read_csv(Path(td) / file_name, **kwargs)
except Exception as err:
raise Exception(f"Error during downloading and reading dataset. Reason: {repr(err)}")
return df


def load_dataset(name: str, download_path: Path = _DOWNLOAD_PATH, rebuild_dataset: bool = False) -> TSDataset:
    """
    Load internal dataset.

    Parameters
    ----------
    name:
        Name of the dataset.
    download_path:
        The path for saving dataset locally.
    rebuild_dataset:
        Whether to rebuild the dataset from the original source. If ``rebuild_dataset=False`` and the dataset was
        saved locally, then it would be loaded from disk.

    Returns
    -------
    result:
        internal dataset

    Raises
    ------
    NotImplementedError:
        if name not from available list of dataset names
    """
    if name not in datasets_dict:
        raise NotImplementedError(f"Dataset {name} is not available. You can use one from: {sorted(datasets_dict)}.")

    dataset_dir = download_path / name
    dataset_path = dataset_dir / f"{name}.csv"

    # Access the registry by explicit keys instead of unpacking ``dict.values()``,
    # which silently depends on the key insertion order of the registry entry.
    dataset_params = datasets_dict[name]
    get_dataset_function = dataset_params["get_dataset_function"]
    freq = dataset_params["freq"]

    if not _check_dataset_local(dataset_path) or rebuild_dataset:
        ts = get_dataset_function(dataset_dir)
    else:
        data = pd.read_csv(dataset_path)
        ts = TSDataset(TSDataset.to_dataset(data), freq=freq)
    return ts


def get_electricity_dataset(dataset_dir: Path) -> TSDataset:
    """
    Download, save and load electricity dataset.

    The electricity dataset is a 15 minutes time series of electricity consumption (in kW)
    of 370 customers.

    Parameters
    ----------
    dataset_dir:
        The path for saving dataset locally.

    Returns
    -------
    result:
        electricity dataset in TSDataset format

    References
    ----------
    .. [1] https://archive.ics.uci.edu/ml/datasets/ElectricityLoadDiagrams20112014
    """
    url = "https://archive.ics.uci.edu/static/public/321/electricityloaddiagrams20112014.zip"
    dataset_dir.mkdir(exist_ok=True, parents=True)
    # Suppress only the expected DtypeWarning from ``pd.read_csv``: value columns use
    # decimal commas, so pandas infers mixed types. Other warnings should stay visible.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=pd.errors.DtypeWarning)
        data = _download_dataset_zip(url=url, file_name="LD2011_2014.txt", sep=";")
    data = data.rename({"Unnamed: 0": "timestamp"}, axis=1)
    data["timestamp"] = pd.to_datetime(data["timestamp"])
    # Values are stored with "," as the decimal separator; convert them to proper floats.
    data.loc[:, data.columns != "timestamp"] = (
        data.loc[:, data.columns != "timestamp"].replace(",", ".", regex=True).astype(float)
    )
    data = data.melt("timestamp", var_name="segment", value_name="target")
    data.to_csv(dataset_dir / "electricity.csv", index=False)
    ts = TSDataset(TSDataset.to_dataset(data), freq="15T")
    return ts


# Registry of available internal datasets: maps dataset name to its loading function
# and the pandas frequency string of the resulting time series.
datasets_dict: Dict[str, Dict] = {"electricity": {"get_dataset_function": get_electricity_dataset, "freq": "15T"}}
54 changes: 54 additions & 0 deletions tests/test_datasets/test_internal_datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import shutil

import numpy as np
import pandas as pd
import pytest

from etna.datasets import TSDataset
from etna.datasets.internal_datasets import _DOWNLOAD_PATH
from etna.datasets.internal_datasets import datasets_dict
from etna.datasets.internal_datasets import load_dataset


def get_custom_dataset(dataset_dir):
    """Generate a small deterministic dataset, save it as csv and return it as ``TSDataset``."""
    np.random.seed(1)
    dataset_dir.mkdir(exist_ok=True, parents=True)
    wide_df = pd.DataFrame(np.random.normal(0, 10, size=(30, 5)))
    wide_df["timestamp"] = pd.date_range("2021-01-01", periods=30, freq="D")
    long_df = wide_df.melt("timestamp", var_name="segment", value_name="target")
    long_df.to_csv(dataset_dir / "custom_internal_dataset.csv", index=False)
    return TSDataset(TSDataset.to_dataset(long_df), freq="D")


def update_dataset_dict(dataset_name, get_dataset_function, freq):
    """Register a dataset in the module-level ``datasets_dict`` registry."""
    entry = {
        "get_dataset_function": get_dataset_function,
        "freq": freq,
    }
    datasets_dict[dataset_name] = entry


def test_not_present_dataset():
    """Requesting an unknown dataset name should raise ``NotImplementedError``."""
    with pytest.raises(NotImplementedError, match="is not available."):
        load_dataset(name="not_implemented_dataset")


def test_load_custom_dataset():
    """Loading from source, from the local cache and with rebuild should all give the same data."""
    update_dataset_dict(dataset_name="custom_internal_dataset", get_dataset_function=get_custom_dataset, freq="D")
    dataset_path = _DOWNLOAD_PATH / "custom_internal_dataset"
    if dataset_path.exists():
        shutil.rmtree(dataset_path)
    ts_init = load_dataset("custom_internal_dataset", rebuild_dataset=False)
    ts_local = load_dataset("custom_internal_dataset", rebuild_dataset=False)
    ts_rebuild = load_dataset("custom_internal_dataset", rebuild_dataset=True)
    shutil.rmtree(dataset_path)
    # ``pd.util.testing`` is deprecated and removed in pandas 2.0; use the public ``pd.testing`` API.
    pd.testing.assert_frame_equal(ts_init.to_pandas(), ts_local.to_pandas())
    pd.testing.assert_frame_equal(ts_init.to_pandas(), ts_rebuild.to_pandas())


@pytest.mark.parametrize(
    "dataset_name, expected_shape, expected_min_date, expected_max_date",
    [("electricity", (140256, 370), pd.to_datetime("2011-01-01 00:15:00"), pd.to_datetime("2015-01-01 00:00:00"))],
)
def test_dataset_statistics(dataset_name, expected_shape, expected_min_date, expected_max_date):
    # Smoke-check the shape and timestamp range of a loaded internal dataset.
    # NOTE(review): this downloads the real dataset over the network — consider a slow/network marker.
    ts = load_dataset(dataset_name)
    assert ts.df.shape == expected_shape
    assert ts.index.min() == expected_min_date
    assert ts.index.max() == expected_max_date