# Train/Validation/Test Splits

This notebook is used to generate the train, validation, and test splits from the full dataset.

The key here is to ensure that all intervals generated for each dataset are distinct and contain no overlap with intervals in the other datasets. This ensures we are not at risk for data leakage when training and testing our model.

In [2]:
%matplotlib widget
from collections import defaultdict, Counter
import datetime as dt
import glob 

import matplotlib.dates as mdates
import matplotlib.pyplot as plt
plt.style.use('ggplot')
import numpy as np
np.random.seed(12345)
import pandas as pd
from tqdm.notebook import tqdm

In [3]:
# Load the labelled interval table; the first row is the header and the
# start_time / stop_time columns are parsed as datetimes.
df = pd.read_csv('../utils/sta_dataset_labels.txt', header=0, parse_dates=['start_time', 'stop_time'])

In [4]:
# Class balance: fraction of all rows carrying each label value.
df.label.value_counts() / len(df)

0    0.868661
1    0.131339
Name: label, dtype: float64

In [5]:
# Keep only the positive rows (label == 1); the split sizes below are derived
# from the number of rows in this subset.
icmes = df[df.label.eq(1)]

In [6]:
# Fraction of the positive (label == 1) rows requested for each split, and the
# corresponding rounded row counts.
# NOTE(review): 0.15 + 0.70 + 0.10 = 0.95, so the three counts generally fall
# short of icmes.shape[0] -- confirm whether leaving 5% unassigned is intended.
validation_size = 0.15
nval = round(validation_size * icmes.shape[0])
train_size = 0.70
ntrain = round(train_size * icmes.shape[0])
test_size = 0.1
ntest = round(test_size * icmes.shape[0])

In [7]:
# Sanity check: do the three split sizes account for every positive row?
# The split fractions above sum to 0.95, which is why this evaluates to False.
sum([nval, ntrain, ntest]) == icmes.shape[0]

False

In [68]:
# Total number of positive rows requested across the three splits.
sum([ntrain, ntest, nval])

340

In [31]:
def plot_interval_span(
    ax,
    df,
    c='r',
    hatch='*',
    label='',
):
    """Shade one vertical span on ``ax`` for every row of ``df``.

    Parameters
    ----------
    ax : object exposing ``axvspan`` (a matplotlib Axes in this notebook)
        Target the spans are drawn on.
    df : pandas.DataFrame
        Each row contributes one span running from its ``start_time`` to its
        ``stop_time``.
    c : colour used as the span face colour.
    hatch : hatch pattern forwarded to ``axvspan``.
    label : legend label; attached to the first span only, so the legend
        gets a single entry per call.

    Returns
    -------
    The ``ax`` object that was passed in.
    """
    for position, (_, row) in enumerate(df.iterrows()):
        span_kwargs = {'facecolor': c, 'alpha': 0.3, 'hatch': hatch}
        if position == 0:
            # Only the first span carries the label.
            span_kwargs['label'] = label
        ax.axvspan(row['start_time'], row['stop_time'], **span_kwargs)
    return ax


def plot_train_test_val_intervals(
    val_df,
    test_df,
    train_df,
    xlim=(dt.datetime(2009,1,1), dt.datetime(2010,1, 1))
):
    """Plot the validation, test and train intervals as shaded spans on one axis.

    Validation spans are drawn in red, test spans in blue and train spans in
    green, so any overlap between the sets shows up as mixed colours. The
    figure is also written to ``interval_check.jpg`` (250 dpi) in the current
    working directory.

    Parameters
    ----------
    val_df, test_df, train_df : pandas.DataFrame
        Splits to draw; each needs ``start_time`` and ``stop_time`` columns.
    xlim : tuple of datetime
        Time window shown on the x axis. Defaults to calendar year 2009.

    Returns
    -------
    matplotlib.figure.Figure
        The figure that was created, so callers can keep, display or re-save
        it (previously the function returned ``None``).
    """
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(6, 3))
    plot_interval_span(ax, val_df, c='r', label='val', hatch=None)
    plot_interval_span(ax, test_df, c='b', label='test', hatch=None)
    plot_interval_span(ax, train_df, c='g', label='train', hatch=None)
    # The y axis carries no information here, so hide its ticks and labels.
    ax.tick_params(which='both', axis='y', labelleft=False, left=False)
    ax.xaxis.set_major_formatter(
        mdates.ConciseDateFormatter(
            locator=mdates.MonthLocator(interval=2),
            formats=['%Y', '%b', '%d', '%H:%M', '%H:%M', '%S.%f'],
            offset_formats=[
                '%Y',
                '%b',
                '%b %d, %Y',
                '%b %d, %Y',
                '%b %d, %Y',
                '%b %d, %Y'
            ],
            zero_formats=['', '%Y', '%b', '%b-%d', '%H:%M', '%H:%M'],
            show_offset=False
        )
    )
    # Place the legend just outside the right edge of the axes.
    ax.legend(loc='upper left', bbox_to_anchor=(1.01, 0.9), edgecolor='k')
    ax.set_xlim(xlim)
    fig.savefig('interval_check.jpg', format='jpg', dpi=250, bbox_inches='tight')
    return fig

In [32]:
def get_pos_samples(positive_class, df, nsamples, used_intervals=None, for_train=False):
    """Randomly collect pairs of consecutive positive rows that avoid used intervals.

    A row is drawn at random from ``positive_class`` and paired with the row
    that follows it in ``df``; if that pair is not acceptable, the pair formed
    with the preceding row is tried instead. A pair is accepted when

    * both rows exist,
    * their ``start_time`` values are exactly 0.75 days apart,
    * both rows have ``label == 1``, and
    * neither row overlaps any interval in ``used_intervals``.

    Compared with the earlier version, the overlap test now covers *both* rows
    of the pair (previously only the sampled row was tested, so its partner
    could overlap an interval already assigned to another split), a sample on
    the last row of ``df`` still gets its "previous row" pair tried, and
    candidates that cannot form any acceptable pair are removed from the pool
    so the loop always terminates.

    Parameters
    ----------
    positive_class : pandas.DataFrame
        Candidate rows; its index labels must also be present in ``df``.
    df : pandas.DataFrame
        Full table used to look up neighbouring rows via ``df.loc``.
        Assumes an integer index in which ``label + 1`` is the next row --
        TODO confirm against the data file.
    nsamples : int
        Minimum number of rows to collect. Rows are added two at a time, so
        the result can hold one more row than requested.
    used_intervals : mapping with ``start_time`` / ``stop_time`` lists, optional
        Intervals that selected rows must not overlap.
    for_train : bool
        Accepted for backward compatibility; currently has no effect.

    Returns
    -------
    collections.defaultdict of list
        Keys ``fname``, ``label``, ``start_time``, ``stop_time`` and ``index``.
        Accepted rows are not removed from the pool, so the same pair can
        appear more than once.

    Raises
    ------
    ValueError
        If the candidate pool is exhausted before ``nsamples`` rows are found.
    """
    if used_intervals is not None:
        intervals_to_exclude = list(
            zip(used_intervals['start_time'], used_intervals['stop_time'])
        )
    else:
        intervals_to_exclude = []

    def _usable_pair(first_label):
        # Return rows first_label and first_label + 1 of df when they form an
        # acceptable pair, otherwise None.
        pair = df.loc[first_label:first_label + 1]
        if len(pair) < 2:
            # Ran off the start or end of the table.
            return None
        tdiff = (
            pair['start_time'].iloc[1] - pair['start_time'].iloc[0]
        ) / np.timedelta64(1, 'D')
        if tdiff != 0.75 or not pair.label.eq(1).all():
            return None
        # Every row of the pair must be clear of the excluded intervals.
        for position in range(len(pair)):
            if not check_used_intervals(intervals_to_exclude, pair.iloc[[position]]):
                return None
        return pair

    dout = defaultdict(list)
    while len(dout['fname']) < nsamples:
        if positive_class.empty:
            raise ValueError(
                f"Only {len(dout['fname'])} of {nsamples} positive samples could be "
                "collected: no usable candidate pairs remain."
            )

        pos_sample = positive_class.sample(1)
        anchor = pos_sample.index[0]

        # Prefer (sample, next row); fall back to (previous row, sample).
        pair = _usable_pair(anchor)
        if pair is None:
            pair = _usable_pair(anchor - 1)

        if pair is None:
            # The exclusion list is fixed for this call, so this candidate can
            # never succeed; drop it so it is not drawn again.
            positive_class = positive_class.drop(index=pos_sample.index)
            continue

        for i, row in pair.iterrows():
            dout['fname'].append(row['fname'])
            dout['label'].append(row['label'])
            dout['start_time'].append(row['start_time'])
            dout['stop_time'].append(row['stop_time'])
            dout['index'].append(i)

    return dout

def get_neg_samples(negative_class, df, nsamples, used_intervals):
    """Randomly collect negative rows that do not overlap any used interval.

    One row at a time is drawn from ``negative_class``; it is kept when
    ``check_used_intervals`` reports it clear of every interval in
    ``used_intervals`` and discarded (and another draw made) otherwise.
    Rows are never removed from ``negative_class``, so the same row may be
    drawn more than once.

    Parameters
    ----------
    negative_class : pandas.DataFrame
        Pool of candidate rows.
    df : pandas.DataFrame
        Not used by this function.
    nsamples : int
        Number of rows to collect.
    used_intervals : mapping with ``start_time`` / ``stop_time`` lists
        Intervals the selected rows must avoid.

    Returns
    -------
    collections.defaultdict of list
        Keys ``fname``, ``label``, ``start_time``, ``stop_time`` and ``index``.
    """
    excluded = list(zip(used_intervals['start_time'], used_intervals['stop_time']))
    copied_columns = ('fname', 'label', 'start_time', 'stop_time')
    collected = defaultdict(list)

    while len(collected['fname']) < nsamples:
        candidate = negative_class.sample(1)
        if not check_used_intervals(excluded, candidate):
            # Overlaps something already taken -- draw again.
            continue
        for row_label, record in candidate.iterrows():
            for column in copied_columns:
                collected[column].append(record[column])
            collected['index'].append(row_label)

    return collected

def check_used_intervals(intervals_to_exclude, sample):
    """Report whether ``sample`` is clear of every excluded interval.

    Only the first row of ``sample`` is examined. It counts as clear of an
    interval when it ends strictly before the interval starts or starts
    strictly after the interval ends; identical or merely touching intervals
    therefore count as overlapping.

    Parameters
    ----------
    intervals_to_exclude : iterable of (start, stop) pairs
    sample : pandas.DataFrame
        Frame with ``start_time`` and ``stop_time`` columns.

    Returns
    -------
    bool
        ``True`` when the sample is clear of all intervals (including when
        there are no intervals to check), ``False`` otherwise.
    """
    def _is_clear(used_start, used_stop):
        sample_start = sample['start_time'].iloc[0]
        sample_stop = sample['stop_time'].iloc[0]
        return bool(sample_stop < used_start or sample_start > used_stop)

    verdicts = [_is_clear(start, stop) for start, stop in intervals_to_exclude]
    return all(verdicts)


def trim_df(df, indices_to_drop, N):
    """Return ``df`` without the rows labelled ``indices_to_drop``.

    Prints the reported sample count ``N`` and the shape of the frame before
    and after the rows are removed. The input frame itself is not modified.
    """
    print(f'Found {N} samples... trimming original df')
    trimmed = df.drop(index=indices_to_drop)
    print(f"{df.shape} --> {trimmed.shape}")
    return trimmed


def generate_test_train_split(df, ntest=20, ntrain=20, nval=20):
    """Draw validation, test and train sets from ``df``, in that order.

    For each set, positive rows (``label == 1``) are drawn first and negative
    rows (``label == 0``) second. After every draw the selected rows' start
    and stop times are appended to a running list of used intervals, which is
    handed to every later draw so that it can avoid them. The very first draw
    (positive validation rows) is made without any exclusion list.

    Parameters
    ----------
    df : pandas.DataFrame
        Full labelled table.
    ntest, ntrain, nval : int
        Number of rows requested per class for the test, train and validation
        sets respectively.

    Returns
    -------
    tuple of pandas.DataFrame
        ``(val_df, test_df, train_df)``; each holds the positive rows followed
        by the negative rows and is indexed by the rows' original labels.
    """
    positives = df[df.label.eq(1)]
    negatives = df[df.label.eq(0)]
    used_intervals = defaultdict(list)

    def _reserve(chosen):
        # Record the chosen rows' time spans so later draws steer clear of them.
        for bound in ('start_time', 'stop_time'):
            used_intervals[bound] += chosen[bound]
        return chosen

    def _merge(pos, neg):
        # Positive rows first, then negative rows, column by column.
        combined = defaultdict(list)
        for key in pos.keys():
            combined[key] += pos[key]
            combined[key] += neg[key]
        return pd.DataFrame(combined, index=combined['index'])

    print('Finding positive validation set')
    pos_val = _reserve(get_pos_samples(positives, df, nval))
    print('Finding negative validation set')
    neg_val = _reserve(get_neg_samples(negatives, df, nval, used_intervals))

    print('Finding positive test set')
    pos_test = _reserve(get_pos_samples(positives, df, ntest, used_intervals))
    print('Finding negative test set')
    neg_test = _reserve(get_neg_samples(negatives, df, ntest, used_intervals))

    print('Finding positive train set')
    pos_train = _reserve(
        get_pos_samples(positives, df, ntrain, used_intervals, for_train=True)
    )
    print('Finding negative train set')
    neg_train = _reserve(get_neg_samples(negatives, df, ntrain, used_intervals))

    val_df = _merge(pos_val, neg_val)
    test_df = _merge(pos_test, neg_test)
    train_df = _merge(pos_train, neg_train)
    return val_df, test_df, train_df
    
    

In [33]:
# Draw the validation, test and train sets (in that order) from the full table.
val_df, test_df, train_df = generate_test_train_split(
    df, nval=nval, ntrain=ntrain, ntest=ntest
)

Finding positive validation set
Finding negative validation set
Finding positive test set
Finding negative test set
Finding positive train set
Finding negative train set


Check to make sure there is no overlap between the train/validation/test sets

In [34]:
# Collect the file names present in each split.
# NOTE(review): this compares file names only; it does not test whether the
# time intervals of different splits overlap.
train_f = set(train_df['fname'])
test_f = set(test_df['fname'])
val_f = set(val_df['fname'])

In [35]:
# Each pairwise intersection should print as an empty set.
print(train_f.intersection(test_f))
print(train_f.intersection(val_f))
print(val_f.intersection(test_f))

set()
set()
set()


In [36]:
# Draw the three splits as shaded spans for a visual overlap check; the helper
# also writes the figure to interval_check.jpg.
fig = plot_train_test_val_intervals(val_df, test_df, train_df)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [255]:
# Add the matching image-file name for every time-series file name.
# The loop variable is deliberately not called `df`, so the full dataset held
# in `df` is not overwritten by the last split. regex=False makes the '.' in
# '.txt' a literal dot instead of a regex wildcard on pandas versions where
# str.replace defaults to regex matching.
for split_df in [val_df, test_df, train_df]:
    split_df['fname_img'] = (
        split_df.fname
        .str.replace('ts_interval', 'img_interval', regex=False)
        .str.replace('.txt', '.npy', regex=False)
    )

  


In [None]:
# Write the validation split to CSV with a header row and without the DataFrame index.
val_df.to_csv('../data/sta_validation_set.txt', header=True, index=False)

In [None]:
# Write the test split to CSV with a header row and without the DataFrame index.
test_df.to_csv('../data/sta_test_set.txt', header=True, index=False)

In [None]:
# Write the train split to CSV with a header row and without the DataFrame index.
train_df.to_csv('../data/sta_train_set.txt', header=True, index=False)