Skip to content
This repository was archived by the owner on Sep 11, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions nowcasting_dataset/dataset/split/method.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,35 @@ def split_method(
test = datetimes[datetimes_period.isin(test_periods)]

return train, validation, test


def split_by_dates(
    datetimes: pd.DatetimeIndex,
    train_validation_datetime_split: pd.Timestamp,
    validation_test_datetime_split: pd.Timestamp,
) -> (List[pd.Timestamp], List[pd.Timestamp], List[pd.Timestamp]):
    """
    Split datetimes into train, validation and test by two specific datetime splits

    Note that the 'train_validation_datetime_split' must not be later than the
    'validation_test_datetime_split'. If the two are equal, the validation set is empty.

    Args:
        datetimes: the datetimes to split
        train_validation_datetime_split: the datetime which will split the train and validation
            datetimes. For example if this is '2021-01-01' then the train datetimes will be strictly
            before '2021-01-01' and the validation datetimes will start at '2021-01-01' (inclusive).
        validation_test_datetime_split: the datetime which will split the validation and test
            datetimes. Validation datetimes are strictly before it; test datetimes start at it
            (inclusive).

    Returns: train, validation and test datetimes, each selected from 'datetimes' by boolean mask

    Raises:
        AssertionError: if 'train_validation_datetime_split' is not less than or equal to
            'validation_test_datetime_split'.

    """
    # Raised explicitly (rather than via an `assert` statement) so the check is still enforced
    # when Python runs with -O. The exception type is kept as AssertionError for compatibility.
    if not (train_validation_datetime_split <= validation_test_datetime_split):
        raise AssertionError(
            f"train_validation_datetime_split ({train_validation_datetime_split}) must be "
            f"less than or equal to validation_test_datetime_split "
            f"({validation_test_datetime_split})"
        )

    # The three masks are mutually exclusive and together cover every datetime.
    is_train = datetimes < train_validation_datetime_split
    is_test = datetimes >= validation_test_datetime_split
    is_validation = (datetimes >= train_validation_datetime_split) & (
        datetimes < validation_test_datetime_split
    )

    train = datetimes[is_train]
    validation = datetimes[is_validation]
    test = datetimes[is_test]

    return train, validation, test
67 changes: 64 additions & 3 deletions nowcasting_dataset/dataset/split/split.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

import logging
from enum import Enum
from typing import List, Tuple, Union
from typing import List, Tuple, Union, Optional

import pandas as pd

from nowcasting_dataset.dataset.split.method import split_method
from nowcasting_dataset.dataset.split.method import split_method, split_by_dates
from nowcasting_dataset.dataset.split.model import (
TrainValidationTestSpecific,
default_train_test_validation_specific,
Expand All @@ -16,15 +16,18 @@


class SplitMethod(Enum):
    """Different split methods"""

    # Split at two explicit datetimes (dispatched to `split_by_dates`).
    DATE = "date"
    DAY = "day"
    DAY_RANDOM = "day_random"
    DAY_SPECIFIC = "day_specific"
    WEEK = "week"
    WEEK_RANDOM = "week_random"
    YEAR_SPECIFIC = "year_specific"
    SAME = "same"
    # Test set is chosen by specific year; train and validation are split by random day.
    DAY_RANDOM_TEST_YEAR = "day_random_test_year"
    # Test set is everything from a given date onwards; train and validation are split by
    # random day from the datetimes before that date.
    DAY_RANDOM_TEST_DATE = "day_random_test_date"


def split_data(
Expand All @@ -34,6 +37,7 @@ def split_data(
train_test_validation_specific: TrainValidationTestSpecific = (
default_train_test_validation_specific
),
train_validation_test_datetime_split: Optional[List[pd.Timestamp]] = None,
seed: int = 1234,
) -> (List[pd.Timestamp], List[pd.Timestamp], List[pd.Timestamp]):
"""
Expand All @@ -46,6 +50,7 @@ def split_data(
seed: random seed used to permutate the data for the 'random' method
train_test_validation_specific: pydantic class of 'train', 'validation' and 'test'.
These specify which data goes into which dataset.
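train_validation_test_datetime_split: two datetimes, [train/validation split, validation/test split],
used by the date-based methods to split train, validation and test on specific dates.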

Returns: train, validation and test dataset
"""
Expand Down Expand Up @@ -101,6 +106,62 @@ def split_data(
seed=seed,
train_test_validation_specific=train_test_validation_specific,
)

elif method == SplitMethod.DATE:
train_datetimes, validation_datetimes, test_datetimes = split_by_dates(
datetimes=datetimes,
train_validation_datetime_split=train_validation_test_datetime_split[0],
validation_test_datetime_split=train_validation_test_datetime_split[1],
)

elif method in [SplitMethod.DAY_RANDOM_TEST_YEAR, SplitMethod.DAY_RANDOM_TEST_DATE]:
if method == SplitMethod.DAY_RANDOM_TEST_YEAR:
# This method splits
# 1. test set to be in one year, using 'train_test_validation_specific'
# 2. train and validation by random day, using 'train_test_validation_split' on ratio how to split it
#
# This allows us to create a test set for 2021, and train and validation for random days not in 2021

# create test set
train_datetimes, validation_datetimes, test_datetimes = split_method(
datetimes=datetimes,
train_test_validation_split=train_test_validation_split,
method="specific",
freq="Y",
seed=seed,
train_test_validation_specific=train_test_validation_specific,
)
elif method == SplitMethod.DAY_RANDOM_TEST_DATE:
# This method splits
# 1. test set from one date onwards
# 2. train and validation by random day, using 'train_test_validation_split' on ratio how to split it
#
# This allows us to create a test set from a specific date e.g. 2020-07-01, and train and validation
# for random days before that date

# create test set
train_datetimes, validation_datetimes, test_datetimes = split_by_dates(
datetimes=datetimes,
train_validation_datetime_split=train_validation_test_datetime_split[0],
validation_test_datetime_split=train_validation_test_datetime_split[1],
)

# join train and validation together, so they can then be split by random day.
train_and_validation_datetimes = train_datetimes.append(validation_datetimes)

# set split ratio to only be on train and validation
train_validation_split = list(train_test_validation_split)
train_validation_split[2] = 0
train_validation_split = tuple(train_validation_split)

# get train and validation methods
train_datetimes, validation_datetimes, _ = split_method(
datetimes=train_and_validation_datetimes,
train_test_validation_split=train_validation_split,
method="random",
seed=seed,
)

else:
raise ValueError(f"{method} for splitting day is not implemented")

Expand Down
122 changes: 110 additions & 12 deletions tests/dataset/test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,9 @@ def test_split_day_random():
validation_df = pd.DatetimeIndex(validation)
test_df = pd.DatetimeIndex(test)

train_validation_overlap = [t for t in train_df if t in validation_df]
train_test_overlap = [t for t in train_df if t in test_df]
validation_test_overlap = [t for t in validation_df if t in test_df]
train_validation_overlap = train_df.join(validation_df, how="inner")
train_test_overlap = train_df.join(test_df, how="inner")
validation_test_overlap = validation_df.join(test_df, how="inner")

assert len(train_validation_overlap) == 0
assert len(train_test_overlap) == 0
Expand All @@ -116,9 +116,9 @@ def test_split_year():
validation_df = pd.DatetimeIndex(validation)
test_df = pd.DatetimeIndex(test)

train_validation_overlap = [t for t in train_df if t in validation_df]
train_test_overlap = [t for t in train_df if t in test_df]
validation_test_overlap = [t for t in validation_df if t in test_df]
train_validation_overlap = train_df.join(validation_df, how="inner")
train_test_overlap = train_df.join(test_df, how="inner")
validation_test_overlap = validation_df.join(test_df, how="inner")

assert len(train_validation_overlap) == 0
assert len(train_test_overlap) == 0
Expand Down Expand Up @@ -154,9 +154,9 @@ def test_split_day_specific():
validation_df = pd.DatetimeIndex(validation)
test_df = pd.DatetimeIndex(test)

train_validation_overlap = [t for t in train_df if t in validation_df]
train_test_overlap = [t for t in train_df if t in test_df]
validation_test_overlap = [t for t in validation_df if t in test_df]
train_validation_overlap = train_df.join(validation_df, how="inner")
train_test_overlap = train_df.join(test_df, how="inner")
validation_test_overlap = validation_df.join(test_df, how="inner")

assert len(train_validation_overlap) == 0
assert len(train_test_overlap) == 0
Expand Down Expand Up @@ -215,9 +215,9 @@ def test_split_week_random():
validation_df = pd.DatetimeIndex(validation)
test_df = pd.DatetimeIndex(test)

train_validation_overlap = [t for t in train_df if t in validation_df]
train_test_overlap = [t for t in train_df if t in test_df]
validation_test_overlap = [t for t in validation_df if t in test_df]
train_validation_overlap = train_df.join(validation_df, how="inner")
train_test_overlap = train_df.join(test_df, how="inner")
validation_test_overlap = validation_df.join(test_df, how="inner")

assert len(train_validation_overlap) == 0
assert len(train_test_overlap) == 0
Expand All @@ -228,3 +228,101 @@ def test_split_week_random():
week = train[0].week
for t in train[0:3]:
assert t.week == week


def test_split_random_day_test_specific():
    """DAY_RANDOM_TEST_YEAR: test set is one whole year, train/validation are random days."""
    datetimes = pd.date_range("2020-01-01", "2022-01-01", freq="1D")

    train, validation, test = split_data(
        datetimes=datetimes, method=SplitMethod.DAY_RANDOM_TEST_YEAR
    )

    assert len(train) == 274  # 75% of days of 2020
    assert len(validation) == 92  # 25% of days of 2020
    assert len(test) == 365  # every day of 2021

    # Normalise all three results to DatetimeIndex so the checks below do not depend on the
    # exact container type returned by split_data.
    train_df = pd.DatetimeIndex(train)
    validation_df = pd.DatetimeIndex(validation)
    test_df = pd.DatetimeIndex(test)

    # the three sets must be pairwise disjoint
    train_validation_overlap = train_df.join(validation_df, how="inner")
    train_test_overlap = train_df.join(test_df, how="inner")
    validation_test_overlap = validation_df.join(test_df, how="inner")

    assert len(train_validation_overlap) == 0
    assert len(train_test_overlap) == 0
    assert len(validation_test_overlap) == 0

    # check all train and validation are in 2020, and all test are in 2021
    assert (train_df.year == 2020).all()
    assert (validation_df.year == 2020).all()
    assert (test_df.year == 2021).all()


def test_split_date():
    """DATE: train, validation and test are contiguous ranges cut at two explicit datetimes."""
    datetimes = pd.date_range("2020-01-01", "2022-01-01", freq="1D")
    train_validation_split = pd.Timestamp("2020-07-01")
    validation_test_split = pd.Timestamp("2021-01-01")

    train, validation, test = split_data(
        datetimes=datetimes,
        method=SplitMethod.DATE,
        train_validation_test_datetime_split=[train_validation_split, validation_test_split],
    )

    assert len(train) == 182  # first half of 2020
    assert len(validation) == 184  # second half of 2020
    assert len(test) == 366  # all 365 days of 2021 plus 2022-01-01

    # Normalise all three results to DatetimeIndex so the checks below do not depend on the
    # exact container type returned by split_data.
    train_df = pd.DatetimeIndex(train)
    validation_df = pd.DatetimeIndex(validation)
    test_df = pd.DatetimeIndex(test)

    # the three sets must be pairwise disjoint
    train_validation_overlap = train_df.join(validation_df, how="inner")
    train_test_overlap = train_df.join(test_df, how="inner")
    validation_test_overlap = validation_df.join(test_df, how="inner")

    assert len(train_validation_overlap) == 0
    assert len(train_test_overlap) == 0
    assert len(validation_test_overlap) == 0

    # check datetimes are in the correct sections
    assert (train_df < train_validation_split).all()
    assert (
        (validation_df >= train_validation_split) & (validation_df < validation_test_split)
    ).all()
    assert (test_df >= validation_test_split).all()


def test_split_day_random_test_date():
    """DAY_RANDOM_TEST_DATE: test set starts at a date, train/validation are random days before it."""
    datetimes = pd.date_range("2020-01-01", "2022-01-01", freq="1D")
    validation_test_split = pd.Timestamp("2021-07-01")
    train_validation_test_datetime_split = [pd.Timestamp("2020-07-01"), validation_test_split]

    train, validation, test = split_data(
        datetimes=datetimes,
        method=SplitMethod.DAY_RANDOM_TEST_DATE,
        train_validation_test_datetime_split=train_validation_test_datetime_split,
    )

    assert len(train) == 410  # 75% of days of 2020 and half of 2021 (~365*1.5*0.75)
    assert len(validation) == 137  # 25% of days of 2020 and half of 2021 (~365*1.5*0.25)
    assert len(test) == 185  # second half of 2021 plus 2022-01-01

    # Normalise all three results to DatetimeIndex so the checks below do not depend on the
    # exact container type returned by split_data.
    train_df = pd.DatetimeIndex(train)
    validation_df = pd.DatetimeIndex(validation)
    test_df = pd.DatetimeIndex(test)

    # the three sets must be pairwise disjoint
    train_validation_overlap = train_df.join(validation_df, how="inner")
    train_test_overlap = train_df.join(test_df, how="inner")
    validation_test_overlap = validation_df.join(test_df, how="inner")

    assert len(train_validation_overlap) == 0
    assert len(train_test_overlap) == 0
    assert len(validation_test_overlap) == 0

    # check datetimes are in the correct sections: only the validation/test split date
    # separates the sets, because train and validation are re-split by random day
    assert (train_df < validation_test_split).all()
    assert (validation_df < validation_test_split).all()
    assert (test_df >= validation_test_split).all()