This notebook is designed to be used in conjunction with the notebook "segment_periods_for_transfer_analysis." That notebook breaks data for each subject into different sets, which can then be combined to form training, validation and testing data in this notebook.

We break up the data this way because we want to run multiple analyses, each of the following form: 

    1) We identify a "target" fish - this is the fish to which we want to transfer model structure learned from other fish.  We also identify a "target condition" we observe this fish under. 
    
    2) We identify a number of "transfer" fish - these are fish we will observe in conditions different from those we observe the target fish in. We want to transfer what we learn about model structure under these conditions to the target fish. 
    
    3) We will form our training and validation data in two ways.  In the first way, the training and validation data for each fish come from a different condition.  In the second way, the training and validation data for all fish come from the same condition as the target fish.  We make sure the total amount of training and validation data used is the same in both cases. 
    
    4) We then test the performance of the models for the target fish on the conditions outside of its training data (as well as on the condition in its training data). We hope to see that, on the conditions outside of the target fish's training condition, the model for the target fish performs better when the other fish are trained on those other conditions than when all fish are trained on the same condition.  This would show our framework is able to transfer model structure across fish even when those fish are observed in different conditions. 
    
To this end, this script will generate and save two assignments of data to training, validation and testing for each target fish and target fish condition.   We generate such an assignment for each type of condition we observe in the target fish (e.g., if we have OMR L, R and F data for our fish, we can do three separate analyses where we assume we only observe one of these conditions in the target fish and then the other two in the transfer fish).  The two assignments correspond to when we observe (1) different conditions across fish and (2) the same condition in all fish. 


In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import copy

from pathlib import Path
import pickle

import numpy as np
import numpy.random as random

from ahrens_wbo.data_processing import SegmentTable

## Parameters go here

In [3]:
ps = dict()

# Specify where the segment table created by segment_ahrens_data_for_across_cond_analysis.ipynb is saved
ps['segment_table_folder'] = r'/groups/bishop/bishoplab/projects/probabilistic_model_synthesis/results/real_data/'
ps['segment_table_file'] = r'omr_l_r_f_ns_across_cond_segments_8_9_10_11.pkl'

# Specify the different conditions we want to test - there should be three of these
ps['test_groups'] = ['omr_l_ns', 'omr_r_ns', 'omr_f_ns']

# Specify the different subjects we use in the analysis - there should be three of these
ps['subjects'] = [8, 9, 11]

# Specify the percentage of data for each target subject we use for training.  (Note because we 
# balance data across fish, we may not use this much, so this is the max we can use.)
ps['tgt_subj_train_percentage'] = .7

# Specify the percentage of data for each target subject train condition we use for validation. (Note because we 
# balance data across fish, we may not use this much, so this is the max we can use.)
ps['tgt_subj_validation_percentage'] = .15

# Specify the folder we should save the fold structures in
ps['save_folder'] = r'' #r'/groups/bishop/bishoplab/projects/probabilistic_model_synthesis/results/real_data/'

# Specify a base string to prepend to file names with the fold structures
ps['save_str'] = 'ac_an'



## Load the segment tables and get basic information we need

In [4]:
segment_table_path = Path(ps['segment_table_folder']) / ps['segment_table_file']
with open(segment_table_path, 'rb') as f:
    seg_table_data = pickle.load(f)
    
seg_tables = {s_n: SegmentTable.from_dict(seg_table) for s_n, seg_table in seg_table_data['segment_tables'].items()}

## Define helper functions here

In [59]:
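def randomly_assign_chunks(n_max_chunks, n_train_chunks, n_validation_chunks):
    """ Randomly splits chunk indices 0, ..., n_max_chunks - 1 into train, validation and test sets.
    
    The first n_train_chunks indices of a random permutation go to training, the next n_validation_chunks
    go to validation, and all remaining indices go to testing.
    """
    chunk_order = random.permutation(n_max_chunks)
    train_chunks = chunk_order[:n_train_chunks]
    validation_chunks = chunk_order[n_train_chunks:n_train_chunks+n_validation_chunks]
    test_chunks = chunk_order[n_train_chunks+n_validation_chunks:]
    
    return train_chunks, validation_chunks, test_chunks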
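
As a quick check of the helper above, the following minimal sketch confirms that the three returned arrays partition the chunk indices without overlap. The chunk counts and the seed are arbitrary values chosen for illustration, not values from the analysis, and the helper is repeated only so the example runs on its own.

```python
import numpy as np

def randomly_assign_chunks(n_max_chunks, n_train_chunks, n_validation_chunks):
    # Copy of the helper above, repeated so this sketch is self-contained
    chunk_order = np.random.permutation(n_max_chunks)
    train_chunks = chunk_order[:n_train_chunks]
    validation_chunks = chunk_order[n_train_chunks:n_train_chunks + n_validation_chunks]
    test_chunks = chunk_order[n_train_chunks + n_validation_chunks:]
    return train_chunks, validation_chunks, test_chunks

np.random.seed(0)  # arbitrary seed, only to make the illustration repeatable
train, validation, test = randomly_assign_chunks(n_max_chunks=10,
                                                 n_train_chunks=6,
                                                 n_validation_chunks=2)

# Every chunk index should appear exactly once across the three splits
all_chunks = np.concatenate([train, validation, test])
is_partition = sorted(all_chunks.tolist()) == list(range(10))
print(len(train), len(validation), len(test), is_partition)
```

Whatever is left after the train and validation chunks are drawn goes to testing, so the test set grows whenever balancing across fish reduces the number of train chunks.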

In [57]:
tgt_subj = 8
tgt_subj_train_group = 'omr_l_ns'
train_percentage = .7
validation_percentage = .15
trans_subjs = [9, 11]
trans_train_groups = ['omr_r_ns', 'omr_f_ns']

In [61]:
n_transfer_subjs = len(trans_subjs)

all_subjs = [tgt_subj] + trans_subjs
all_across_cond_train_conds = [tgt_subj_train_group] + trans_train_groups

n_max_across_cond_train_chunks = [np.floor(train_percentage*seg_tables[s_n].n_group_segments(grp)) 
                                  for s_n, grp in zip(all_subjs, all_across_cond_train_conds)]

n_max_across_cond_validation_chunks = [np.floor(validation_percentage*seg_tables[s_n].n_group_segments(grp)) 
                                  for s_n, grp in zip(all_subjs, all_across_cond_train_conds)]

n_max_within_cond_train_chunks = [np.floor(train_percentage*seg_tables[s_n].n_group_segments(tgt_subj_train_group)) 
                                  for s_n in all_subjs]

n_max_within_cond_validation_chunks = [np.floor(validation_percentage*seg_tables[s_n].n_group_segments(tgt_subj_train_group)) 
                                  for s_n in all_subjs]

n_train_chunks = int(np.min(np.stack([n_max_across_cond_train_chunks, n_max_within_cond_train_chunks])))
n_validation_chunks = int(np.min(np.stack([n_max_across_cond_validation_chunks, n_max_within_cond_validation_chunks])))

# Form across condition assignments
assignments = dict()
for s_n, grp in zip(all_subjs, all_across_cond_train_conds):
    num_segs_n = seg_tables[s_n].n_group_segments(grp)
    
    train_chunks_n, validation_chunks_n, test_chunks_n = randomly_assign_chunks(n_max_chunks=num_segs_n, 
                                                                                n_train_chunks=n_train_chunks,
                                                                                n_validation_chunks=n_validation_chunks)
    
    assignments[s_n] = {'train': {grp: ['set_' + str(i) for i in train_chunks_n]},
                     'validation': {grp: ['set_' + str(i) for i in validation_chunks_n]},
                     'test': {grp: ['set_' + str(i) for i in test_chunks_n] }}
    
    # TODO: Test data needs to include other conditions as well



# Form within condition assignments
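
The chunk-count balancing in the cell above can be illustrated with a small standalone sketch. The segment counts below are hypothetical, made up for illustration only and not taken from the real segment tables, and `max_chunks` is a helper name introduced just for this example. The idea is that the number of train (and validation) chunks is capped by the scarcest subject and condition pair across both the across-condition and within-condition assignments.

```python
import numpy as np

# Hypothetical segment counts keyed by (subject, condition); NOT real data
n_segs = {(8, 'omr_l_ns'): 108, (9, 'omr_l_ns'): 41, (11, 'omr_l_ns'): 73,
          (9, 'omr_r_ns'): 31, (11, 'omr_f_ns'): 45}

train_percentage = .7
validation_percentage = .15
all_subjs = [8, 9, 11]
across_conds = ['omr_l_ns', 'omr_r_ns', 'omr_f_ns']  # one condition per fish
tgt_cond = 'omr_l_ns'                                # condition of the target fish

def max_chunks(percentage, pairs):
    # Largest whole number of chunks each (subject, condition) pair can supply
    return [int(np.floor(percentage * n_segs[p])) for p in pairs]

across_pairs = list(zip(all_subjs, across_conds))
within_pairs = [(s_n, tgt_cond) for s_n in all_subjs]

# Take the minimum over both assignment types so they use the same amount of data
n_train_chunks = min(max_chunks(train_percentage, across_pairs + within_pairs))
n_validation_chunks = min(max_chunks(validation_percentage, across_pairs + within_pairs))
print(n_train_chunks, n_validation_chunks)
```

With these made-up counts, subject 9's 31 segments of `omr_r_ns` are the bottleneck for both the train and validation counts.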


In [62]:
assignments

{8: {'train': {'omr_l_ns': ['set_6',
    'set_9',
    'set_80',
    'set_35',
    'set_23',
    'set_71',
    'set_43',
    'set_85',
    'set_34',
    'set_102',
    'set_20',
    'set_45',
    'set_97',
    'set_37',
    'set_31',
    'set_1',
    'set_32',
    'set_63',
    'set_91',
    'set_25',
    'set_98']},
  'validation': {'omr_l_ns': ['set_4', 'set_60', 'set_87', 'set_92']},
  'test': {'omr_l_ns': ['set_0',
    'set_100',
    'set_11',
    'set_49',
    'set_82',
    'set_83',
    'set_99',
    'set_107',
    'set_22',
    'set_13',
    'set_21',
    'set_77',
    'set_33',
    'set_104',
    'set_28',
    'set_106',
    'set_66',
    'set_12',
    'set_74',
    'set_94',
    'set_86',
    'set_24',
    'set_89',
    'set_105',
    'set_2',
    'set_40',
    'set_3',
    'set_70',
    'set_8',
    'set_73',
    'set_52',
    'set_57',
    'set_75',
    'set_78',
    'set_46',
    'set_42',
    'set_27',
    'set_26',
    'set_103',
    'set_48',
    'set_54',
    'set_64',
 

In [44]:
n_validation_chunks

4.0

In [5]:
def form_subj_group_fold(seg_tables, tgt_subj, tgt_subj_train_group, tgt_subj_train_percentage,  
                         tgt_subj_validation_percentage, trans_subjs, trans_train_groups):
    """ Assigns training, validation and testing data for a transfer analysis for one target fish. 
    
    Here, the user specifies a "target" fish (a fish we want to transfer model structure to) and the condition
    that we get to see in the training data for that fish (e.g., OMR Left).  The user also specifies "transfer"
    fish and the conditions we get to see in each of those fish (e.g., OMR Right, OMR Forward).  This function 
    then assigns training data so that we see a different condition in each fish. In particular, we see the target 
    condition in the target fish and one of the specified conditions in each of the transfer fish. 
           
    When assigning the training data, this function will ensure the amount of training data is the same for all
    training conditions and fish. 
    
    This function will also assign validation and test data for the target fish (but it will not 
    assign validation and test data for the transfer fish, since we expect to be doing early stopping and testing 
    based only on the target fish).  Test data consists of all of the target fish's data in each of the train 
    conditions for the transfer fish and a percentage (roughly 1 - tgt_subj_train_percentage - 
    tgt_subj_validation_percentage) of the target fish's data for its train condition.  The validation data is from the same condition as 
    the train data, and the amount of data will be roughly equal to tgt_subj_validation_percentage (only roughly equal
    due to the discrete amount of data that is available) of the amount of data available for the condition.  We try 
    to roughly ensure balance in the swimming strength in training, testing and validation for the target condition 
    data by randomly assigning the top sets in the segment table to the train, test and validation 
    data (to see how we sort sets in the segment tables by swimming strength, see the notes below).
    
    A couple of important final notes: 
    
    1) The segment table input is expected to be sorted by swimming strength (the notebook, 
    segment_ahrens_data_for_across_cond_analysis does this).  In many cases, we may not be able to use all the 
    data of a condition for training, testing and validation (see below).  In that case, we use the data with the 
    strongest swimming signals (specifically, we use the top sets in the segment table, when sets are sorted by
    swimming strength).
    
    2) Due to the need to balance the amount of training data across fish, we may not use all the available 
    training data for a condition in a fish.  
    
    3) In some cases, there may be much more data for the target condition and target fish than there is for the 
    transfer conditions and transfer fish. In this case, we can't use the full tgt_subj_train_percentage percent
    of data for the training data, and we assign the unused data to testing. 
       
    """
    
    n_transfer_subjs = len(trans_subjs)
    
    # See how many segments are available for training and validation in each subject
    
    n_tgt_subj_segs = seg_tables[tgt_subj].n_group_segments(tgt_subj_train_group)
    n_tgt_subj_train_segs = int(np.floor(n_tgt_subj_segs*tgt_subj_train_percentage))
    n_tgt_subj_validation_segs = int(np.floor(n_tgt_subj_segs*tgt_subj_validation_percentage))

    n_trans_subj_train_segs = [seg_tables[s_n].n_group_segments(grp) 
                               for s_n, grp in zip(trans_subjs, trans_train_groups)]
    
    # Determine the number of segments used for training - this is the min number available across all subjects
    n_train_segs = np.min(n_trans_subj_train_segs + [n_tgt_subj_train_segs])
    
    # Determine number of segments for testing the train condition in the target fish
    n_tgt_subj_train_cond_test_segs = n_tgt_subj_segs - n_train_segs - n_tgt_subj_validation_segs
    
    # Form our fold structure for the target fish here
    #n_tgt_fish_segs = n_train_segs + n_tgt_subj_validation_segs
    tgt_seg_nums = random.permutation(n_tgt_subj_segs)
    tgt_train_seg_nums = tgt_seg_nums[0:n_train_segs]
    tgt_validation_seg_nums = tgt_seg_nums[n_train_segs:n_train_segs+n_tgt_subj_validation_segs]
    tgt_train_cond_test_seg_nums = tgt_seg_nums[n_train_segs+n_tgt_subj_validation_segs:]
      
    tgt_fish_fold = {'train': {tgt_subj_train_group: ['set_' + str(n) for n in tgt_train_seg_nums]},
                     'validation': {tgt_subj_train_group: ['set_' + str(n) for n in tgt_validation_seg_nums]},
                     'test': {tgt_subj_train_group: ['set_' + str(n) for n in tgt_train_cond_test_seg_nums] }}
    
    for s_i in range(n_transfer_subjs):
        tgt_fish_fold['test'][trans_train_groups[s_i]] = ['set_' + str(n) 
                                    for n in range(seg_tables[tgt_subj].n_group_segments(trans_train_groups[s_i]))]
                     
    # Form our fold structure for the transfer fish
    transfer_fish_folds = [{'train': {trans_train_groups[s_i]: ['set_' + str(n) for n in range(n_train_segs)]},
                            'validation': None, 
                            'test': None}
                           for s_i, s_n in enumerate(trans_subjs)]

    # Key each fold by subject; pairing with zip supports any number of transfer fish
    folds = {tgt_subj: tgt_fish_fold}
    folds.update(zip(trans_subjs, transfer_fish_folds))
    return folds
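
To make the size bookkeeping in the function above concrete, here is a minimal sketch that reproduces only the segment-count arithmetic. `ToySegmentTable` and `sketch_fold_sizes` are hypothetical names introduced for this example. The only thing assumed about the real `SegmentTable` is the `n_group_segments` method used above, and the counts are made up.

```python
import numpy as np

class ToySegmentTable:
    """Hypothetical stand-in exposing only n_group_segments, the one method needed here."""
    def __init__(self, counts):
        self.counts = counts

    def n_group_segments(self, grp):
        return self.counts[grp]

def sketch_fold_sizes(seg_tables, tgt_subj, tgt_grp, train_pct, val_pct, trans_subjs, trans_grps):
    # Mirrors the count arithmetic of form_subj_group_fold, without building the folds
    n_tgt = seg_tables[tgt_subj].n_group_segments(tgt_grp)
    n_tgt_train = int(np.floor(n_tgt * train_pct))
    n_val = int(np.floor(n_tgt * val_pct))
    n_trans = [seg_tables[s].n_group_segments(g) for s, g in zip(trans_subjs, trans_grps)]
    n_train = min(n_trans + [n_tgt_train])  # scarcest fish sets the train size
    n_test = n_tgt - n_train - n_val        # unused target data falls through to test
    return n_train, n_val, n_test

# Made-up counts: the target fish has far more data than the transfer fish
toy_tables = {8: ToySegmentTable({'omr_l_ns': 108}),
              9: ToySegmentTable({'omr_r_ns': 31}),
              11: ToySegmentTable({'omr_f_ns': 45})}

sizes = sketch_fold_sizes(toy_tables, 8, 'omr_l_ns', .7, .15, [9, 11], ['omr_r_ns', 'omr_f_ns'])
print(sizes)
```

In this toy case the target fish could supply 75 train segments, but only 31 are used because of subject 9, and the remainder is assigned to testing as described in note 3 of the docstring.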

    

    

## Form fold structures

In [6]:
multi_cond_fold_strs = dict()
single_cond_fold_strs = dict()
for tgt_subj in ps['subjects']:
    fish_multi_cond_folds = dict()
    fish_single_cond_folds = dict()
    for tgt_cond in ps['test_groups']:
        trans_subjs = [s_n for s_n in ps['subjects'] if s_n != tgt_subj]
        trans_conds = [cond for cond in ps['test_groups']  if cond != tgt_cond]
        
        multi_cond_folds = form_subj_group_fold(seg_tables=seg_tables, 
                                                tgt_subj=tgt_subj, 
                                                tgt_subj_train_group=tgt_cond, 
                                                tgt_subj_train_percentage=ps['tgt_subj_train_percentage'], 
                                                tgt_subj_validation_percentage=ps['tgt_subj_validation_percentage'],
                                                trans_subjs=trans_subjs, 
                                                trans_train_groups=trans_conds)
        
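        # The single-condition folds reuse the target fish's fold, but every transfer fish is trained on the
        # target condition with the same number of training sets as in the multi-condition folds.  Note this 
        # assumes each transfer fish has at least n_train_segs sets of the target condition. 
        single_cond_folds = copy.deepcopy(multi_cond_folds)
        n_train_segs = len(single_cond_folds[tgt_subj]['train'][tgt_cond])
        for transfer_subj in trans_subjs:
            single_cond_folds[transfer_subj]['train'] = {tgt_cond: ['set_' + str(i) for i in range(n_train_segs)]}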

        fish_multi_cond_folds[tgt_cond] = multi_cond_folds
        fish_single_cond_folds[tgt_cond] = single_cond_folds
    
    multi_cond_fold_strs[tgt_subj] = fish_multi_cond_folds
    single_cond_fold_strs[tgt_subj] = fish_single_cond_folds
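
One property worth checking on the structures built above is that the multi-condition and single-condition folds assign the same number of training sets to every fish. The sketch below shows such a check on small hand-written folds. The set names are invented for illustration, and `n_train_sets` is a hypothetical helper, not part of the original code.

```python
import copy

def n_train_sets(fold):
    # Number of training sets in one subject's fold, summed over conditions
    return sum(len(sets) for sets in fold['train'].values())

# Toy multi-condition folds for target fish 8 with target condition omr_l_ns
multi = {8: {'train': {'omr_l_ns': ['set_3', 'set_0']},
             'validation': {'omr_l_ns': ['set_1']},
             'test': {'omr_l_ns': ['set_2']}},
         9: {'train': {'omr_r_ns': ['set_0', 'set_1']}, 'validation': None, 'test': None},
         11: {'train': {'omr_f_ns': ['set_0', 'set_1']}, 'validation': None, 'test': None}}

# Single-condition folds: transfer fish are trained on the target condition instead
single = copy.deepcopy(multi)
for transfer_subj in [9, 11]:
    single[transfer_subj]['train'] = {'omr_l_ns': ['set_0', 'set_1']}

balanced = all(n_train_sets(multi[s]) == n_train_sets(single[s]) for s in multi)
print(balanced)
```

The same check could be run on `multi_cond_fold_strs[tgt_subj][tgt_cond]` and `single_cond_fold_strs[tgt_subj][tgt_cond]`. It verifies only the counts, not whether the named sets actually exist in each segment table.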

Rearrange the fold structures for each target fish so that the fish is the first level and the fold (target condition) is the second level. This allows the fold structures to be used seamlessly with our standard fitting code.

In [7]:
new_multi_cond_fold_strs = dict()
new_single_cond_fold_strs = dict()
for tgt_subj in ps['subjects']:
    new_multi_cond_fold_strs[tgt_subj] = dict()
    new_single_cond_fold_strs[tgt_subj] = dict()
    for s_n in ps['subjects']:
        new_multi_cond_fold_strs[tgt_subj][s_n] = dict()
        new_single_cond_fold_strs[tgt_subj][s_n] = dict()
        for tgt_cond in ps['test_groups']:
            new_multi_cond_fold_strs[tgt_subj][s_n][tgt_cond] = multi_cond_fold_strs[tgt_subj][tgt_cond][s_n]
            new_single_cond_fold_strs[tgt_subj][s_n][tgt_cond] = single_cond_fold_strs[tgt_subj][tgt_cond][s_n]
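
The rearrangement above swaps the second and third levels of a nested dictionary. The following toy sketch uses placeholder strings instead of real folds, and a reduced list of subjects and conditions, to show the index swap.

```python
subjects = [8, 9]
conds = ['omr_l_ns', 'omr_r_ns']

# Original layout: target fish -> target condition -> fish -> fold (a placeholder string here)
old = {t: {c: {s: 'fold(tgt=%d, cond=%s, subj=%d)' % (t, c, s) for s in subjects}
           for c in conds}
       for t in subjects}

# Rearranged layout: target fish -> fish -> target condition -> fold
new = {t: {s: {c: old[t][c][s] for c in conds} for s in subjects} for t in subjects}

print(new[8][9]['omr_l_ns'])
```

Only the order of the keys changes. The fold objects themselves are shared between the two layouts, not copied.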

## Save fold structures 

In [None]:
for tgt_subj in ps['subjects']:
    multi_cond_folds = new_multi_cond_fold_strs[tgt_subj]
    single_cond_folds = new_single_cond_fold_strs[tgt_subj]
    
    multi_cond_file_name = ps['save_str'] + '_tgt_' + str(tgt_subj) + '_multi_cond_folds.pkl'
    single_cond_file_name = ps['save_str'] + '_tgt_' + str(tgt_subj) + '_single_cond_folds.pkl'
    
    multi_cond_path = Path(ps['save_folder']) / multi_cond_file_name
    single_cond_path = Path(ps['save_folder']) / single_cond_file_name
    
    with open(multi_cond_path, 'wb') as f:
        pickle.dump(multi_cond_folds, f)
    with open(single_cond_path, 'wb') as f:
        pickle.dump(single_cond_folds, f)    
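
For reference, a saved fold file can be read back with `pickle` and indexed by fish, then target condition, then split. The sketch below round-trips a toy structure through a temporary folder so that it runs anywhere. The file name follows the pattern used above with `save_str = 'ac_an'`, and the contents are invented for illustration.

```python
import pickle
import tempfile
from pathlib import Path

# Toy structure in the saved layout: fish -> target condition -> split -> condition -> sets
toy_folds = {8: {'omr_l_ns': {'train': {'omr_l_ns': ['set_0']},
                              'validation': None,
                              'test': None}}}

with tempfile.TemporaryDirectory() as save_folder:
    path = Path(save_folder) / 'ac_an_tgt_8_multi_cond_folds.pkl'
    with open(path, 'wb') as f:
        pickle.dump(toy_folds, f)
    with open(path, 'rb') as f:
        loaded_folds = pickle.load(f)

# Training sets for fish 8 when omr_l_ns is the target condition
train_sets = loaded_folds[8]['omr_l_ns']['train']['omr_l_ns']
print(train_sets)
```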
    
    

In [13]:
new_multi_cond_fold_strs[8][9]

{'omr_l_ns': {'train': {'omr_r_ns': ['set_0',
    'set_1',
    'set_2',
    'set_3',
    'set_4',
    'set_5',
    'set_6',
    'set_7',
    'set_8',
    'set_9',
    'set_10',
    'set_11',
    'set_12',
    'set_13',
    'set_14',
    'set_15',
    'set_16',
    'set_17',
    'set_18',
    'set_19',
    'set_20',
    'set_21',
    'set_22',
    'set_23',
    'set_24',
    'set_25',
    'set_26',
    'set_27',
    'set_28',
    'set_29',
    'set_30',
    'set_31']},
  'validation': None,
  'test': None},
 'omr_r_ns': {'train': {'omr_l_ns': ['set_0',
    'set_1',
    'set_2',
    'set_3',
    'set_4',
    'set_5',
    'set_6',
    'set_7',
    'set_8',
    'set_9',
    'set_10',
    'set_11',
    'set_12',
    'set_13',
    'set_14',
    'set_15',
    'set_16',
    'set_17',
    'set_18',
    'set_19',
    'set_20',
    'set_21',
    'set_22',
    'set_23',
    'set_24',
    'set_25',
    'set_26',
    'set_27',
    'set_28',
    'set_29',
    'set_30',
    'set_31']},
  'validation':