This notebook is designed to be used with the notebook "segment_subperiods_for_transfer_analysis". That notebook breaks the data for each subject into a number of sets, which can then be combined to form training, validation and testing data. This notebook specifies how those sets are combined, given the number of folds we want and how much data to use for training and validation in each fold.

These choices are made on a per-subject basis, which is useful when we want to use a lot of training data for some subjects but only a little for others. Specifically, the number of training sets and the number of validation sets can be set separately for each subject. Testing data is not configurable: each subject always uses all available testing data in each fold.
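As a rough illustration of what per-subject control means, the sketch below tallies how many sets one subject contributes to each split of a single fold, and what fraction of that subject's sets is used at all. It assumes 42 sets and 3 folds (the values implied by the printed fold structure at the end of this notebook); `split_sizes` is a hypothetical helper, not part of `ahrens_wbo`.

```python
def split_sizes(n_train, n_validation, n_total_sets, n_folds):
    """Sets used by one subject in one fold, as (train, validation, test, fraction used).

    Hypothetical helper: testing always uses the whole test block of the fold,
    while training and validation use only the requested number of sets.
    """
    n_block = n_total_sets // n_folds  # each fold's test block holds this many sets
    used = n_train + n_validation + n_block
    return n_train, n_validation, n_block, used / n_total_sets

# A subject using 14 training and 14 validation sets touches every set in each fold,
# while one using 4 and 4 touches 22 of the 42 sets (a fraction of about 0.52).
print(split_sizes(14, 14, 42, 3))
print(split_sizes(4, 4, 42, 3))
```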


In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np

from pathlib import Path
import pickle

from ahrens_wbo.data_processing import SegmentTable

## Parameters go here

In [3]:
ps = dict()

# Specify where the segment table created by segment_subperiods_for_transfer_analysis.ipynb is saved
ps['segment_table_folder'] = r'/groups/bishop/bishoplab/projects/probabilistic_model_synthesis/results/real_data/'
ps['segment_table_file'] = r'phototaxis_ns_subjects_1_2_5_6_8_9_10_11.pkl'

# Number of folds we want
ps['n_folds'] = 3

# Number of sets of data (as defined by the notebook segment_subperiods_for_transfer_analysis) 
# we use for training for each subject
ps['n_train_sets'] = {1: 14, 
                      2: 14, 
                      5: 14, 
                      6: 14,
                      8: 4, 
                      9: 4, 
                     10: 4,
                     11: 4}

# Number of sets of data we use for validation for each subject
ps['n_validation_sets'] = {1: 14, 
                           2: 14, 
                           5: 14, 
                           6: 14,
                           8: 4, 
                           9: 4, 
                          10: 4,
                          11: 4}

# Specify where we should save the fold structure
ps['save_folder'] = r'/groups/bishop/bishoplab/projects/probabilistic_model_synthesis/results/real_data'
ps['save_name'] = 'fold_str_base_14_tgt_4.pkl'
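Because the per-subject dictionaries above take only two distinct values, they could also be built from groups of subjects. The sketch below assumes subjects 1, 2, 5 and 6 form one group and subjects 8 through 11 another (a reading suggested by the save name `fold_str_base_14_tgt_4.pkl`, not stated elsewhere in this notebook); `base_subjects`, `target_subjects`, `n_base` and `n_target` are hypothetical names.

```python
# Hypothetical subject grouping; adjust to match the segment table
base_subjects = [1, 2, 5, 6]
target_subjects = [8, 9, 10, 11]
n_base, n_target = 14, 4

# Merge the two groups into one per-subject dictionary
n_train_sets = {**{s: n_base for s in base_subjects},
                **{s: n_target for s in target_subjects}}

# Validation uses the same per-subject counts here; copy rather than alias,
# so editing one dictionary later does not silently change the other
n_validation_sets = dict(n_train_sets)

print(n_train_sets)
```

The resulting dictionaries can be assigned to `ps['n_train_sets']` and `ps['n_validation_sets']` in place of the literal dictionaries.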

## Load the segment tables and get basic information we need

In [4]:
segment_table_path = Path(ps['segment_table_folder']) / ps['segment_table_file']
with open(segment_table_path, 'rb') as f:
    seg_table_data = pickle.load(f)
    
n_seg_table_sets = seg_table_data['ps']['n_sets']
seg_tables = [SegmentTable.from_dict(seg_table) for seg_table in seg_table_data['segment_tables'].values()]
groups = seg_tables[0].groups

## Specify the fold structure

Do some basic checks to make sure the parameters entered in this notebook are compatible with the way the segment tables were created.

In [5]:
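# The folds must evenly partition the sets in the segment table
if n_seg_table_sets % ps['n_folds'] != 0:
    raise ValueError('Number of folds does not evenly divide into the number of sets in the segment table.')

# Test, training and validation each take a different block, so at least 3 folds are needed
if ps['n_folds'] < 3:
    raise ValueError('At least 3 folds are needed so that test, training and validation sets do not overlap.')

# Each block holds this many sets, so no subject can request more than this for training or validation
n_test_sets = n_seg_table_sets // ps['n_folds']

for s_n, n_sets in ps['n_train_sets'].items():
    if n_sets > n_test_sets:
        raise ValueError('Subject ' + str(s_n) + ' requests more training sets than are available.')

for s_n, n_sets in ps['n_validation_sets'].items():
    if n_sets > n_test_sets:
        raise ValueError('Subject ' + str(s_n) + ' requests more validation sets than are available.')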
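The same checks can be wrapped in a function so they are easy to try on different parameter choices. This is a sketch only: `check_fold_params` is a hypothetical name, and the numbers below (42 sets, 3 folds) are illustrative.

```python
def check_fold_params(n_total_sets, n_folds, n_train_sets, n_validation_sets):
    """Validate fold parameters and return the number of sets in each fold's block.

    The folds must evenly partition the sets, and no subject may request more
    training or validation sets than a single block holds.
    """
    if n_total_sets % n_folds != 0:
        raise ValueError('Number of folds does not evenly divide into the number of sets in the segment table.')
    n_block = n_total_sets // n_folds
    for split, requested in (('training', n_train_sets), ('validation', n_validation_sets)):
        for s_n, n_sets in requested.items():
            if n_sets > n_block:
                raise ValueError(f'Subject {s_n} requests more {split} sets ({n_sets}) '
                                 f'than are available ({n_block}).')
    return n_block

print(check_fold_params(42, 3, {1: 14, 8: 4}, {1: 14, 8: 4}))  # 14

# Requesting 15 training sets when each block holds 14 is rejected
try:
    check_fold_params(42, 3, {1: 15}, {1: 14})
except ValueError as e:
    print(e)
```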

Specify the test sets for each fold, as well as the sets available for training and validation in each fold.

In [6]:
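# Specify which sets are in the testing data for each fold: fold i tests on the
# i-th contiguous block of n_test_sets sets
test_sets = [list(range(i, i + n_test_sets)) for i in np.linspace(0, n_seg_table_sets, ps['n_folds'] + 1, dtype=int)[0:-1]]

# Now specify which sets can be used for training and validation data for each fold.
# Rotating the list of blocks gives each fold the next block for training and the
# block after that for validation (with at least 3 folds, the roles never overlap)
def rotate(l, n):
    """Rotate list l left by n positions."""
    return l[n:] + l[:n]

possible_train_sets = rotate(test_sets, 1)
possible_validation_sets = rotate(test_sets, 2)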
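To see what the rotation produces, here is the same construction run on toy values: 6 sets and 3 folds, chosen only for illustration rather than taken from the segment table.

```python
import numpy as np

def rotate(l, n):
    """Rotate list l left by n positions."""
    return l[n:] + l[:n]

n_total_sets, n_folds = 6, 3  # toy values for illustration
n_block = n_total_sets // n_folds

# Contiguous blocks of set indices, one block per fold
test_sets = [list(range(i, i + n_block))
             for i in np.linspace(0, n_total_sets, n_folds + 1, dtype=int)[:-1]]
possible_train_sets = rotate(test_sets, 1)
possible_validation_sets = rotate(test_sets, 2)

for f_i in range(n_folds):
    print(f_i, test_sets[f_i], possible_train_sets[f_i], possible_validation_sets[f_i])
# 0 [0, 1] [2, 3] [4, 5]
# 1 [2, 3] [4, 5] [0, 1]
# 2 [4, 5] [0, 1] [2, 3]
```

Within every fold the three blocks are disjoint, and across folds each block takes each role exactly once. This scheme needs at least three folds: with two, `rotate(test_sets, 2)` returns `test_sets` unchanged, so each fold's validation block would equal its test block.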


For each subject and fold, form the dictionaries that can be used to get the corresponding slices directly from the segment table.

In [7]:
# fold_groups[subject][fold][split][group] -> list of set names, e.g. ['set_0', 'set_1', ...]
fold_groups = dict()
for s_n in ps['n_train_sets'].keys():

    s_fold_groups = dict()
    for f_i in range(ps['n_folds']):

        # Testing always uses every set in the fold's test block; training and
        # validation use only the first n sets requested for this subject
        fold_test_sets = ['set_' + str(s) for s in test_sets[f_i]]
        fold_train_sets = ['set_' + str(s) for s in possible_train_sets[f_i][:ps['n_train_sets'][s_n]]]
        fold_validation_sets = ['set_' + str(s) for s in possible_validation_sets[f_i][:ps['n_validation_sets'][s_n]]]

        # The same sets are used for every group in the segment table
        test_dict = {grp: fold_test_sets for grp in groups}
        train_dict = {grp: fold_train_sets for grp in groups}
        validation_dict = {grp: fold_validation_sets for grp in groups}

        s_fold_groups[f_i] = {'test': test_dict, 'train': train_dict, 'validation': validation_dict}

    fold_groups[s_n] = s_fold_groups
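Below is a condensed, self-contained version of the cell above, wrapped in a hypothetical function `build_fold_groups` so it can be run on toy inputs. It assumes 42 sets, 3 folds and a single group named `phototaxis_ns` (as in the printed output at the end of this notebook), and uses `range(0, 42, 14)` in place of `np.linspace`, which gives the same block starts when the folds divide the sets evenly.

```python
def build_fold_groups(test_sets, possible_train_sets, possible_validation_sets,
                      n_train_sets, n_validation_sets, groups):
    """Return {subject: {fold: {'test'|'train'|'validation': {group: [set names]}}}}."""
    def names(idx):
        return ['set_' + str(s) for s in idx]

    fold_groups = {}
    for s_n in n_train_sets:
        fold_groups[s_n] = {}
        for f_i in range(len(test_sets)):
            # Testing uses the whole block; training/validation take the first n sets
            splits = {'test': names(test_sets[f_i]),
                      'train': names(possible_train_sets[f_i][:n_train_sets[s_n]]),
                      'validation': names(possible_validation_sets[f_i][:n_validation_sets[s_n]])}
            fold_groups[s_n][f_i] = {k: {grp: v for grp in groups} for k, v in splits.items()}
    return fold_groups

test_sets = [list(range(i, i + 14)) for i in range(0, 42, 14)]
possible_train_sets = test_sets[1:] + test_sets[:1]       # rotate left by 1
possible_validation_sets = test_sets[2:] + test_sets[:2]  # rotate left by 2

fg = build_fold_groups(test_sets, possible_train_sets, possible_validation_sets,
                       {8: 4}, {8: 4}, ['phototaxis_ns'])
print(fg[8][0]['train']['phototaxis_ns'])       # ['set_14', 'set_15', 'set_16', 'set_17']
print(fg[8][0]['validation']['phototaxis_ns'])  # ['set_28', 'set_29', 'set_30', 'set_31']
```

A subject that requests 4 training sets gets the first 4 sets of its fold's training block. As in the original cell, every group shares the same list object, so mutating one group's list in place would change it for all groups.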
    

## Save the fold structure

In [None]:
save_path = Path(ps['save_folder']) / ps['save_name']
with open(save_path, 'wb') as f:
    pickle.dump(fold_groups, f)
    
print('Saved fold structure to: ' + str(save_path))
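Presumably, a downstream analysis reads the saved structure back with `pickle.load` and indexes it as subject, then fold, then split, then group. The sketch below round-trips a toy structure through a temporary file; the file name and the contents of `fold_groups` here are illustrative only.

```python
import pickle
import tempfile
from pathlib import Path

# Toy fold structure with the same nesting as fold_groups:
# subject -> fold -> split -> group -> list of set names
fold_groups = {8: {0: {'test': {'phototaxis_ns': ['set_0', 'set_1']},
                       'train': {'phototaxis_ns': ['set_2']},
                       'validation': {'phototaxis_ns': ['set_4']}}}}

with tempfile.TemporaryDirectory() as tmp:
    save_path = Path(tmp) / 'fold_str_example.pkl'
    with open(save_path, 'wb') as f:
        pickle.dump(fold_groups, f)
    with open(save_path, 'rb') as f:
        loaded = pickle.load(f)

print(loaded[8][0]['train']['phototaxis_ns'])  # ['set_2']
```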

In [8]:
fold_groups

{1: {0: {'test': {'phototaxis_ns': ['set_0',
     'set_1',
     'set_2',
     'set_3',
     'set_4',
     'set_5',
     'set_6',
     'set_7',
     'set_8',
     'set_9',
     'set_10',
     'set_11',
     'set_12',
     'set_13']},
   'train': {'phototaxis_ns': ['set_14',
     'set_15',
     'set_16',
     'set_17',
     'set_18',
     'set_19',
     'set_20',
     'set_21',
     'set_22',
     'set_23',
     'set_24',
     'set_25',
     'set_26',
     'set_27']},
   'validation': {'phototaxis_ns': ['set_28',
     'set_29',
     'set_30',
     'set_31',
     'set_32',
     'set_33',
     'set_34',
     'set_35',
     'set_36',
     'set_37',
     'set_38',
     'set_39',
     'set_40',
     'set_41']}},
  1: {'test': {'phototaxis_ns': ['set_14',
     'set_15',
     'set_16',
     'set_17',
     'set_18',
     'set_19',
     'set_20',
     'set_21',
     'set_22',
     'set_23',
     'set_24',
     'set_25',
     'set_26',
     'set_27']},
   'train': {'phototaxis_ns': ['set_28',
    