In [16]:
import pandas as pd
import numpy as np
import datetime
from dateutil.relativedelta import relativedelta

# Load the raw status table (path is relative to the notebook's working directory).
status_df = pd.read_csv('./data/status.csv')

# Build a proper datetime column straight from the numeric year/month/day
# columns. pd.to_datetime assembles a date from a frame with exactly these
# column names, which avoids formatting every row as a zero-padded string
# and then parsing that string back.
status_df['date'] = pd.to_datetime(status_df[['year', 'month', 'day']])
status_df.head()

Unnamed: 0,id,year,month,day,hour,station_id,bikes_available,predict,date
0,0,2013,9,1,0,0,11.0,0,2013-09-01
1,1,2013,9,1,1,0,11.0,0,2013-09-01
2,2,2013,9,1,2,0,11.0,0,2013-09-01
3,3,2013,9,1,3,0,11.0,0,2013-09-01
4,4,2013,9,1,4,0,11.0,0,2013-09-01


In [92]:
class TimeSeriesSplitGenerator:
    """Walk-forward train/valid/test splitter driven by a ``date`` column.

    The 12-month test period that starts at ``test_day_after`` is cut into
    ``n_split`` consecutive folds, each ``12 // n_split`` months wide. For every
    fold the generator yields three boolean masks over the rows of ``X``:

    * test  -- rows whose date falls inside the fold's window,
    * valid -- rows in the window of the same width immediately before it,
    * train -- rows strictly before the validation window (optionally limited
      to the 12 months preceding it when ``slide`` is true).

    Parameters
    ----------
    n_split : int, default 12
        Number of folds. Must satisfy ``1 <= n_split <= 12``. When it does not
        divide 12 evenly, the fold width is rounded down and the trailing
        months of the test period are not covered by any fold.
    test_day_after : str or datetime-like, default "2014-09-01"
        First day of the first test window (parsed with ``pd.to_datetime``).
    slide : bool, default False
        If true, the training set is a sliding 12-month window; otherwise it
        expands to include every row before the validation window.

    Raises
    ------
    ValueError
        If ``n_split`` is outside ``1..12`` (the fold width would be zero or
        undefined).
    """

    def __init__(self, n_split = 12, test_day_after = "2014-09-01", slide = False):
        self.test_month_period = 12
        # Previously n_split was stored but ignored; reject values that cannot
        # produce a window of at least one month instead of silently misbehaving.
        if not 1 <= n_split <= self.test_month_period:
            raise ValueError(
                "n_split must be between 1 and %d, got %r"
                % (self.test_month_period, n_split)
            )
        self.test_day_after = pd.to_datetime(test_day_after)
        self.n_split = n_split
        # Width of one fold in whole months (1 for the default n_split=12).
        self.fold_months = self.test_month_period // n_split
        self.month = pd.DateOffset(months=self.fold_months)
        # Length of the training window used when slide=True.
        self.train_window_months = 12
        self.slide = slide

    def split(self, X):
        """Yield ``(train_index, valid_index, test_index)`` boolean masks per fold.

        ``X`` must expose a datetime-like ``date`` column; each mask is the
        result of comparing that column against the fold boundaries, so it can
        be used directly as ``X[mask]``.
        """
        dates = X["date"]  # loop-invariant: look the column up once
        train_window = pd.DateOffset(months=self.train_window_months)
        for fold in range(self.n_split):
            # Offsets are always taken from the anchor date (not accumulated)
            # so month-end clipping cannot drift from fold to fold.
            test_start = self.test_day_after + pd.DateOffset(months=fold * self.fold_months)
            test_end = test_start + self.month
            valid_start = test_start - self.month

            test_index = (test_start <= dates) & (dates < test_end)
            valid_index = (valid_start <= dates) & (dates < test_start)
            train_index = dates < valid_start
            if self.slide:
                # Keep only the most recent training window before validation.
                train_index = train_index & (valid_start - train_window <= dates)
            yield train_index, valid_index, test_index


In [93]:
# Sanity-check the fold boundaries by printing the date span of each subset.
# The splitter class defined in this notebook is TimeSeriesSplitGenerator; the
# previous name (CustomTimeSeriesSplitter) only resolved through stale kernel
# state and raises NameError after Restart & Run All.
splitter = TimeSeriesSplitGenerator(slide = True)
for train_index, valid_index, test_index in splitter.split(status_df):
    train = status_df[train_index]
    test = status_df[test_index]
    valid = status_df[valid_index]
    print("-------------------------------------------")
    print("train:", train.date.min(), train.date.max())
    print("valid:", valid.date.min(), valid.date.max())
    print("test: ", test.date.min(), test.date.max())

-------------------------------------------
train: 2013-09-01 00:00:00 2014-07-31 00:00:00
valid: 2014-08-01 00:00:00 2014-08-31 00:00:00
test:  2014-09-01 00:00:00 2014-09-30 00:00:00
-------------------------------------------
train: 2013-09-01 00:00:00 2014-08-31 00:00:00
valid: 2014-09-01 00:00:00 2014-09-30 00:00:00
test:  2014-10-01 00:00:00 2014-10-31 00:00:00
-------------------------------------------
train: 2013-10-01 00:00:00 2014-09-30 00:00:00
valid: 2014-10-01 00:00:00 2014-10-31 00:00:00
test:  2014-11-01 00:00:00 2014-11-30 00:00:00
-------------------------------------------
train: 2013-11-01 00:00:00 2014-10-31 00:00:00
valid: 2014-11-01 00:00:00 2014-11-30 00:00:00
test:  2014-12-01 00:00:00 2014-12-31 00:00:00
-------------------------------------------
train: 2013-12-01 00:00:00 2014-11-30 00:00:00
valid: 2014-12-01 00:00:00 2014-12-31 00:00:00
test:  2015-01-01 00:00:00 2015-01-31 00:00:00
-------------------------------------------
train: 2014-01-01 00:00:00 2014