Segments data for fitting models to different data across fish.


The way we segment data is very basic. The user first specifies a set of groups (e.g., omr left). Then, for each fish, we break the data for each group into chunks of a fixed size (e.g., 5 time points). We order these chunks by the swim power in each chunk and then assign each chunk to its own segment (so in this special case, segments and chunks are the same thing). Ordering the chunks by swim power lets us try to balance swim power when breaking the data up into train/validation sets in the notebook form_folds_for_across_cond_analysis.
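
As a rough illustration of the chunk-then-order idea described above, here is a minimal sketch on synthetic data. The helper `chunk_and_rank` is hypothetical and is not part of `ahrens_wbo`. Dropping leftover samples at the end of a period and sorting in ascending order are assumptions made for the sketch, not confirmed behavior of `segment_dataset_with_constant_segment_sizes`.

```python
import numpy as np


def chunk_and_rank(period_slices, values, chunk_size, value_fnc=np.max):
    """Break each period into fixed-size chunks and order the chunks by value.

    Hypothetical helper for illustration only (not the ahrens_wbo implementation).
    """
    chunks = []
    for sl in period_slices:
        # Keep only whole chunks; leftover samples at the end of a period
        # are dropped (an assumption of this sketch)
        for start in range(sl.start, sl.stop - chunk_size + 1, chunk_size):
            chunks.append(slice(start, start + chunk_size))

    # Reduce each chunk to a single number, e.g. max swim power in the chunk
    chunk_values = np.array([value_fnc(values[c]) for c in chunks])

    # Order chunks from lowest to highest value (ascending is an assumption)
    order = np.argsort(chunk_values, kind='stable')
    return [chunks[i] for i in order], chunk_values[order]


# Toy data: 20 samples of synthetic swim power, two periods of 10 samples each
rng = np.random.default_rng(0)
swim_power = rng.random(20)
periods = [slice(0, 10), slice(10, 20)]

ranked_chunks, ranked_values = chunk_and_rank(periods, swim_power, chunk_size=5)
```

With the chunks ranked this way, a later step can deal them out to train and validation sets in turn so that each set sees a similar range of swim power.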

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from collections import OrderedDict
from pathlib import Path
import pickle

import numpy as np

from ahrens_wbo.annotations import label_subperiods
from ahrens_wbo.data_processing import SegmentTable
from ahrens_wbo.data_processing import segment_dataset_with_constant_segment_sizes
from ahrens_wbo.raw_data_processing import load_processed_data

## Parameters go here

In [3]:
ps = dict()
ps['data_dir'] = r'/groups/bishop/bishoplab/projects/ahrens_wbo/data'

# Specify subjects
ps['subjects'] =  [8, 9, 10, 11]

# Specify the size of the chunks we break data up into - we form sets out of chunks
ps['chunk_size'] = 5

# Specify which behavioral channels we will use for calculating values to associate with each sample point
ps['value_chs'] = [3, 4]

# Specify how we will group the data
ps['groups'] = OrderedDict([('omr_l_ns', [{'period': 'omr_left', 'shock': False}]),
                            ('omr_r_ns', [{'period': 'omr_right', 'shock': False}]),
                            ('omr_f_ns', [{'period': 'omr_forward', 'shock': False}])])                                                                                                      
# Specify value function
ps['value_fnc'] = 'max'
    
# Specify where we should save the segment information
ps['save_folder'] = r'/groups/bishop/bishoplab/projects/probabilistic_model_synthesis/results/real_data/'
ps['save_name'] = r'omr_l_r_f_ns_across_cond_segments_8_9_10_11.pkl'

## Get list of labels for each dataset and swim values 

In [4]:
n_subjects = len(ps['subjects'])
labels = dict()
values = dict()
for subject_id in ps['subjects']:
    dataset = load_processed_data(Path(ps['data_dir']) / ('subject_' + str(subject_id)), subject_id)
    labels[subject_id] = label_subperiods(dataset.ts_data['stim']['vls'][:])
    values[subject_id] = dataset.ts_data['behavior']['vls'][:, ps['value_chs']]
    print('Done loading subject ' + str(subject_id))

Done loading subject 8
Done loading subject 9
Done loading subject 10
Done loading subject 11


## Segment datasets

In [5]:
# Pick the function used to reduce the values in each chunk to a single number
if ps['value_fnc'] == 'mean':
    value_fnc = np.mean
elif ps['value_fnc'] == 'max':
    value_fnc = np.max
else:
    raise ValueError("ps['value_fnc'] is not recognized")

In [15]:
segment_tables = OrderedDict()
for s_n in ps['subjects']:
    subj_values = np.mean(values[s_n], axis=1)
    segment_tables[s_n] = segment_dataset_with_constant_segment_sizes(period_lbls=labels[s_n], 
                                                                      groups=ps['groups'], 
                                                                      chunk_size=ps['chunk_size'],
                                                                      n_segment_chunks=1, 
                                                                      vls=subj_values, vl_fnc=value_fnc, 
                                                                      random_vl_assignment=False)
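
To make the value calculation concrete, the sketch below uses a made-up two-channel behavior array for a single 5-sample chunk. It first averages across channels for each time point, as the cell above does, and then reduces the chunk to one number. That the reducing function is applied once per chunk is an assumption about how `vl_fnc` is used inside the library.

```python
import numpy as np

# Made-up behavior values for one chunk: 5 time points x 2 channels
behavior = np.array([[0.0, 2.0],
                     [1.0, 3.0],
                     [4.0, 0.0],
                     [1.0, 1.0],
                     [0.0, 0.0]])

# Average across channels for each time point (same as np.mean(values[s_n], axis=1))
per_sample = np.mean(behavior, axis=1)

# Reduce the chunk to a single value with each supported choice of value function
chunk_value_max = np.max(per_sample)
chunk_value_mean = np.mean(per_sample)
```

Here `per_sample` is `[1, 2, 2, 1, 0]`, so the `'max'` setting gives 2.0 for this chunk and the `'mean'` setting gives 1.2.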

## Save the segment tables

In [7]:
for s_n in ps['subjects']:
    segment_tables[s_n] = segment_tables[s_n].to_dict()

In [8]:
save_path = Path(ps['save_folder']) / ps['save_name']
rs = {'ps': ps, 'segment_tables': segment_tables}
with open(save_path, 'wb') as f:
    pickle.dump(rs, f)

In [9]:
print('Segment tables saved to: ' + str(save_path))

Segment tables saved to: /groups/bishop/bishoplab/projects/probabilistic_model_synthesis/results/real_data/omr_l_r_f_ns_across_cond_segments_8_9_10_11.pkl
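
For reference, a downstream notebook could read the results back with `pickle.load`. The sketch below round-trips a small stand-in dictionary through a temporary file rather than touching the real results path. The stand-in contents and the file name `demo_segments.pkl` are made up, and only the top-level keys `'ps'` and `'segment_tables'` mirror what is saved above.

```python
import pickle
import tempfile
from pathlib import Path

# Stand-in for the saved results: same top-level keys, made-up contents
demo_rs = {'ps': {'subjects': [8, 9], 'chunk_size': 5},
           'segment_tables': {8: {}, 9: {}}}

with tempfile.TemporaryDirectory() as tmp_dir:
    demo_path = Path(tmp_dir) / 'demo_segments.pkl'

    # Save in the same way as the cell above
    with open(demo_path, 'wb') as f:
        pickle.dump(demo_rs, f)

    # Load the results back
    with open(demo_path, 'rb') as f:
        loaded = pickle.load(f)

# The segment tables are keyed by subject id
loaded_subjects = sorted(loaded['segment_tables'].keys())
```

Because the segment tables are converted with `to_dict()` before saving, what comes back for each subject is a plain dictionary rather than a `SegmentTable` object.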


## Debug code goes here

In [55]:
labels[8]['omr_right']

[{'slice': slice(1030, 1060, None), 'shock': False},
 {'slice': slice(1180, 1210, None), 'shock': False},
 {'slice': slice(1330, 1360, None), 'shock': False},
 {'slice': slice(1480, 1510, None), 'shock': False},
 {'slice': slice(1630, 1660, None), 'shock': False},
 {'slice': slice(1780, 1810, None), 'shock': False},
 {'slice': slice(1930, 1960, None), 'shock': False},
 {'slice': slice(2080, 2110, None), 'shock': False},
 {'slice': slice(2230, 2260, None), 'shock': False},
 {'slice': slice(3580, 3610, None), 'shock': True},
 {'slice': slice(3730, 3760, None), 'shock': True},
 {'slice': slice(3880, 3910, None), 'shock': True},
 {'slice': slice(4030, 4060, None), 'shock': True},
 {'slice': slice(4180, 4210, None), 'shock': True},
 {'slice': slice(4330, 4360, None), 'shock': True},
 {'slice': slice(4480, 4510, None), 'shock': True},
 {'slice': slice(4630, 4660, None), 'shock': True},
 {'slice': slice(4780, 4810, None), 'shock': True},
 {'slice': slice(6250, 6280, None), 'shock': False},
 {

In [51]:
segment_tables[10].find_all({'omr_l_ns': segment_tables[10].segments})

[slice(600, 620, 1),
 slice(720, 740, 1),
 slice(840, 860, 1),
 slice(960, 980, 1),
 slice(1080, 1100, 1),
 slice(1200, 1220, 1)]

In [58]:
segment_tables[10].grp_segment_slices[0][0] = None

In [60]:
segment_tables[10].grp_segment_slices[0][1]

[slice(685, 690, None)]