Segments data for fitting models to the same behavior across fish. Here we break the data up into a given number of disjoint sets that can later be combined to form train, validation and test sets.

We have in mind applications where we do cross-validation, but where it is performed in a non-standard way. In particular:

1) We break the data up into T disjoint folds for testing. 

2) However, unlike standard cross-validation, the data outside a fold's test set is not necessarily all used for training and validation. We want to look at model performance as the amount of data used for training and validation changes, so only some of the data not used for testing may be used for training and validation.

To facilitate (1) and (2) above, the user specifies a number of disjoint "sets" (e.g., 40) to break the data for a subject into. Each set will be roughly balanced in the different types of behaviors that are present as well as in swim vigor. At model fitting time, the user can then combine the sets into larger test sets for cross-validation, and use a subset of the remaining sets for training and validation.
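The way sets might be combined at model fitting time can be sketched as follows. This is a minimal illustration, not part of `ahrens_wbo`: the function `form_cv_splits` and its arguments `n_folds`, `n_train_sets` and `n_validation_sets` are hypothetical names introduced here, and the split is done purely on set labels.

```python
import numpy as np


def form_cv_splits(set_labels, n_folds, n_train_sets, n_validation_sets, seed=0):
    """Group set labels into disjoint test folds; draw train/validation sets from the rest.

    Hypothetical helper for illustration only.
    """
    rng = np.random.default_rng(seed)
    # Shuffle the labels once so folds are not tied to the original set order
    shuffled = [set_labels[i] for i in rng.permutation(len(set_labels))]
    # Split positions into n_folds nearly equal, disjoint groups
    fold_inds = np.array_split(np.arange(len(shuffled)), n_folds)

    splits = []
    for inds in fold_inds:
        test_sets = [shuffled[i] for i in inds]
        test_lookup = set(test_sets)
        remaining = [s for s in shuffled if s not in test_lookup]
        if n_train_sets + n_validation_sets > len(remaining):
            raise ValueError('Not enough sets left over for training and validation.')
        # Use only part of the remaining sets, so the amount of training data can be varied
        train_sets = remaining[:n_train_sets]
        validation_sets = remaining[n_train_sets:n_train_sets + n_validation_sets]
        splits.append({'test': test_sets, 'train': train_sets,
                       'validation': validation_sets})
    return splits


set_labels = ['set_' + str(i) for i in range(42)]
splits = form_cv_splits(set_labels, n_folds=3, n_train_sets=10, n_validation_sets=4)
```

Rerunning with a different `n_train_sets` while keeping `n_folds` and `seed` fixed leaves the test folds unchanged, which is what makes performance comparable across training set sizes.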


In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from collections import OrderedDict
from pathlib import Path
import pickle

import numpy as np

from ahrens_wbo.annotations import label_periods
from ahrens_wbo.data_processing import SegmentTable
from ahrens_wbo.data_processing import segment_dataset
from ahrens_wbo.raw_data_processing import load_processed_data

## Parameters go here

In [3]:
ps = dict()
ps['data_dir'] = r'/groups/bishop/bishoplab/projects/ahrens_wbo/data'

# Specify subjects
ps['subjects'] =  [1, 2, 5, 6, 8, 9, 10, 11]

# Specify the number of sets we form
ps['n_sets'] = 42

# Specify the size of the chunks we break the data up into - we form sets out of chunks
ps['chunk_size'] = 5

# Specify which behavioral channels we will use for calculating values to associate with each sample point
ps['value_chs'] = [3, 4]

# Specify how we will group the data
ps['groups'] = OrderedDict([('phototaxis_ns', [{'period': 'phototaxis', 'shock': False}])])
                                                                                                          
# Specify value function
ps['value_fnc'] = 'max'
ps['random_vl_assignment'] = True
    
# Specify where we should save the segment information
ps['save_folder'] = r'/groups/bishop/bishoplab/projects/probabilistic_model_synthesis/results/real_data/'
ps['save_name'] = r'phototaxis_ns_subjects_1_2_5_6_8_9_10_11_v2.pkl'

In [4]:
# Specify segment ratios
ps['segment_labels'] = ['set_' + str(i) for i in range(ps['n_sets'])]
ps['segment_ratios'] = [1]*ps['n_sets']

## Get list of labels for each dataset and swim values 

In [5]:
n_subjects = len(ps['subjects'])
labels = dict()
values = dict()
for subject_id in ps['subjects']:
    dataset = load_processed_data(Path(ps['data_dir']) / ('subject_' + str(subject_id)), subject_id)
    labels[subject_id] = label_periods(dataset.ts_data['stim']['vls'][:])
    values[subject_id] = dataset.ts_data['behavior']['vls'][:, ps['value_chs']]
    print('Done loading subject ' + str(subject_id))

Done loading subject 1
Done loading subject 2
Done loading subject 5
Done loading subject 6
Done loading subject 8
Done loading subject 9
Done loading subject 10
Done loading subject 11


## Segment datasets

In [6]:
if ps['value_fnc'] == 'mean':
    value_fnc = lambda x: np.mean(x)
elif ps['value_fnc'] == 'max':
    value_fnc = lambda x: np.max(x)
else:
    raise ValueError('value_fnc is not recognized')

In [7]:
segment_tables = OrderedDict()
for s_n in ps['subjects']:
    subj_values = np.mean(values[s_n], axis=1)
    segment_tables[s_n] = segment_dataset(period_lbls=labels[s_n], groups=ps['groups'],
                                          chunk_size=ps['chunk_size'],
                                          segment_labels=ps['segment_labels'],
                                          segment_ratios=ps['segment_ratios'],
                                          vls=subj_values, vl_fnc=value_fnc,
                                          random_vl_assignment=ps['random_vl_assignment'])
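To give a feel for what balancing sets by swim vigor could look like, here is a toy sketch. It is an assumption for illustration and not the algorithm inside `segment_dataset`: the helper `assign_chunks_to_sets` is hypothetical. It cuts a value trace into chunks, summarizes each chunk with a value function, and deals the chunks out to sets in order of their value so every set sees a similar range of values.

```python
import numpy as np


def assign_chunks_to_sets(vls, chunk_size, n_sets, vl_fnc=np.max):
    """Assign fixed-size chunks to sets so chunk values are spread evenly.

    Hypothetical helper for illustration only.
    """
    # Drop any trailing samples that do not fill a whole chunk
    n_chunks = len(vls) // chunk_size
    chunks = vls[:n_chunks * chunk_size].reshape(n_chunks, chunk_size)
    # Summarize each chunk with the value function (e.g., max or mean)
    chunk_vls = np.array([vl_fnc(c) for c in chunks])
    # Rank chunks by value and deal them round-robin across the sets
    order = np.argsort(chunk_vls)
    assignments = np.empty(n_chunks, dtype=int)
    assignments[order] = np.arange(n_chunks) % n_sets
    return chunk_vls, assignments


rng = np.random.default_rng(0)
toy_vls = rng.random(1000)  # toy stand-in for a subject's swim values
chunk_vls, assignments = assign_chunks_to_sets(toy_vls, chunk_size=5, n_sets=4)
```

Because neighboring ranks go to different sets, each set ends up with both low-value and high-value chunks.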

## Save the segment tables

In [8]:
for s_n in ps['subjects']:
    segment_tables[s_n] = segment_tables[s_n].to_dict()

In [9]:
save_path = Path(ps['save_folder']) / ps['save_name']
rs = {'ps': ps, 'segment_tables': segment_tables}
with open(save_path, 'wb') as f:
    pickle.dump(rs, f)

In [10]:
print('Segment tables saved to: ' + str(save_path))

Segment tables saved to: /groups/bishop/bishoplab/projects/probabilistic_model_synthesis/results/real_data/phototaxis_ns_subjects_1_2_5_6_8_9_10_11_v2.pkl
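A later notebook could read the saved results back along these lines. This is a minimal sketch: `load_segment_results` is a hypothetical helper, and the round trip below uses toy stand-in data in a temporary folder rather than the real file. The only thing taken from this notebook is the layout of the saved dictionary, with keys `'ps'` and `'segment_tables'`.

```python
import pickle
import tempfile
from pathlib import Path


def load_segment_results(path):
    """Load the parameters and segment tables from a saved results file.

    Hypothetical helper for illustration only.
    """
    with open(path, 'rb') as f:
        rs = pickle.load(f)
    return rs['ps'], rs['segment_tables']


# Round trip with toy stand-in data that mimics the saved layout
toy_rs = {'ps': {'n_sets': 2}, 'segment_tables': {1: {'placeholder': 'toy table'}}}
with tempfile.TemporaryDirectory() as tmp_dir:
    toy_path = Path(tmp_dir) / 'toy_segments.pkl'
    with open(toy_path, 'wb') as f:
        pickle.dump(toy_rs, f)
    loaded_ps, loaded_tables = load_segment_results(toy_path)
```

Saving `ps` next to the tables means the parameters used for segmentation travel with the results.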


In [12]:
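# segment_tables[1] now holds the output of to_dict() (see In [8]),
# so the (0, 0) lookup below raises a KeyError
segment_tables[1][0,0]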

KeyError: (0, 0)