In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from datetime import timedelta

In [3]:
from sklego.model_selection import TimeGapSplit

In [4]:
# Synthetic demo data: 30 rows of random integers in [0, 30) for the feature
# columns A, B, C and the target column y.
# A locally seeded generator keeps the frame identical across kernel restarts
# without touching NumPy's global random state.
rng = np.random.RandomState(42)
target_set_pd = pd.DataFrame(rng.randint(0, 30, size=(30, 4)), columns=list('ABCy'))
# One row per calendar day from 2018-01-01 to 2018-01-30, assigned in reverse
# (latest day first) so the frame starts out in descending date order.
target_set_pd['date'] = pd.date_range(start='2018-01-01', end='2018-01-30')[::-1]

In [5]:
# Put the rows in chronological order (earliest date first); row labels are kept as they were.
target_set_pd = target_set_pd.sort_values(by='date', ascending=True)

In [6]:
# Preview the first five rows to check the chronological ordering.
target_set_pd.head(n=5)

Unnamed: 0,A,B,C,y,date
29,10,23,7,15,2018-01-01
28,10,15,25,27,2018-01-02
27,22,28,10,0,2018-01-03
26,18,11,17,6,2018-01-04
25,3,5,11,2,2018-01-05


In [7]:
# Keep the 25 earliest rows for cross-validation; the last 5 rows of the 30-row frame are left out.
train = target_set_pd.iloc[:25]

In [8]:
# Features are columns A, B and C; the target is column y.
X_train = train.loc[:, ['A', 'B', 'C']]
y_train = train.loc[:, 'y']

In [9]:
# Time-based splitter. It is given the full frame and the name of its date
# column; the folds themselves are generated later from X_train / y_train.
# Per the fold summary printed by this cell: each training window spans 4 days,
# validation spans 3 days starting 2 days after the last training day, and
# consecutive folds start 3 days apart.
cv = TimeGapSplit(
    df=target_set_pd,
    date_col='date',
    train_duration=timedelta(days=5),
    valid_duration=timedelta(days=3),
    gap_duration=timedelta(days=1),
)
def printSplitInfo(X, indicies, org_pd):
    """Print a short date-range summary for one side (train or valid) of a fold.

    Parameters
    ----------
    X : pandas.DataFrame
        Frame whose rows are selected positionally by ``indicies``; its index
        labels are used to look the rows up in ``org_pd``.
    indicies : sequence of int
        Positional row indices into ``X``.
    org_pd : pandas.DataFrame
        Frame with a ``'date'`` column, indexed by the same labels as ``X``.

    Prints three lines: a summary (number of distinct dates, time span,
    number of samples), then the earliest and the latest date of the rows.
    """
    # Dates of the selected rows, looked up once by label.
    split_dates = org_pd.loc[X.iloc[indicies].index, 'date']
    mindate, maxdate = split_dates.min(), split_dates.max()

    # Distinct dates are counted over every org_pd row that falls inside
    # [mindate, maxdate], not only over the selected rows.
    dates = org_pd.loc[org_pd['date'].between(mindate, maxdate), 'date']

    # NOTE(review): the '%Y%m%d' format presumably targets date columns stored
    # as YYYYMMDD values; confirm whether it is still needed for datetime columns.
    span = pd.to_datetime(maxdate, format='%Y%m%d') - pd.to_datetime(mindate, format='%Y%m%d')

    summary = "{} unique days, {}, nbr_samples: {}".format(dates.nunique(), span, len(indicies))
    print(summary)
    print(mindate)
    print(maxdate)

# Count the folds first, then walk through them and summarise each side.
n_folds = sum(1 for _ in cv.split(X_train, y_train))
print("Nbr folds: {}\n".format(n_folds))
for train_idx, valid_idx in cv.split(X_train, y_train):
    print("Train:")
    printSplitInfo(X_train, train_idx, target_set_pd)
    print("Valid:")
    printSplitInfo(X_train, valid_idx, target_set_pd)
    print()

Nbr folds: 6

Train:
4 unique days, 3 days 00:00:00, nbr_samples: 4
2018-01-01 00:00:00
2018-01-04 00:00:00
Valid:
3 unique days, 2 days 00:00:00, nbr_samples: 3
2018-01-06 00:00:00
2018-01-08 00:00:00

Train:
4 unique days, 3 days 00:00:00, nbr_samples: 4
2018-01-04 00:00:00
2018-01-07 00:00:00
Valid:
3 unique days, 2 days 00:00:00, nbr_samples: 3
2018-01-09 00:00:00
2018-01-11 00:00:00

Train:
4 unique days, 3 days 00:00:00, nbr_samples: 4
2018-01-07 00:00:00
2018-01-10 00:00:00
Valid:
3 unique days, 2 days 00:00:00, nbr_samples: 3
2018-01-12 00:00:00
2018-01-14 00:00:00

Train:
4 unique days, 3 days 00:00:00, nbr_samples: 4
2018-01-10 00:00:00
2018-01-13 00:00:00
Valid:
3 unique days, 2 days 00:00:00, nbr_samples: 3
2018-01-15 00:00:00
2018-01-17 00:00:00

Train:
4 unique days, 3 days 00:00:00, nbr_samples: 4
2018-01-13 00:00:00
2018-01-16 00:00:00
Valid:
3 unique days, 2 days 00:00:00, nbr_samples: 3
2018-01-18 00:00:00
2018-01-20 00:00:00

Train:
4 unique days, 3 days 00:00:00, nb

In [10]:
# Plotly (offline mode) and its figure factory are used for the fold timeline chart in the next cell.
import plotly.offline as py
import plotly.figure_factory as ff
# Set up Plotly for inline rendering in this notebook.
# NOTE(review): connected=True presumably loads plotly.js from the web instead of
# embedding it in the notebook — confirm against the Plotly documentation.
py.init_notebook_mode(connected=True)


In [11]:
# Build a Gantt chart with one row per CV fold, showing its Train, Gap and Valid periods.

# Materialise the folds once so the number of chart rows is known up front.
folds = list(cv.split(X_train, y_train))
n_folds = len(folds)

# A period ends at 23:59 of its last day so that day is drawn as covered.
end_of_day = timedelta(hours=23, minutes=59)

plot_list = []
for i, (train_idx, valid_idx) in enumerate(folds):
    # Dates of the rows on each side of the fold, looked up once by label.
    train_dates = target_set_pd.loc[X_train.iloc[train_idx].index, 'date']
    valid_dates = target_set_pd.loc[X_train.iloc[valid_idx].index, 'date']

    train_mindate = train_dates.min()
    train_maxdate = train_dates.max() + end_of_day
    valid_mindate = valid_dates.min()
    valid_maxdate = valid_dates.max() + end_of_day

    task = "CV {}".format(i)
    # Row position used for the text annotations: fold 0 gets the highest y.
    # The previous hard-coded `5 - i` was only right for exactly 6 folds.
    y_pos = n_folds - 1 - i

    plot_list.append(dict(Task=task, y=y_pos, Resource='Train',
                          Start=train_mindate, Finish=train_maxdate))
    plot_list.append(dict(Task=task, y=y_pos, Resource='Gap',
                          Start=train_maxdate, Finish=valid_mindate))
    plot_list.append(dict(Task=task, y=y_pos, Resource='Valid',
                          Start=valid_mindate, Finish=valid_maxdate))

# Bar colour per period type.
colors = {
    'Train': '#2579B2',
    'Gap': '#A9A9A9',
    'Valid': '#F27921',
}

fig = ff.create_gantt(plot_list, group_tasks=True, colors=colors, index_col='Resource')
# Label every bar with its period type, anchored at the bar's left edge.
fig['layout']['annotations'] = [
    dict(x=bar['Start'], y=bar['y'], text=bar['Resource'], showarrow=False,
         font=dict(color='white'), xanchor='left')
    for bar in plot_list
]

py.iplot(fig, filename='gantt-group-tasks-together')