This notebook is designed to be used in conjunction with the notebook "segment_periods_for_transfer_analysis." That notebook breaks data for each subject into different sets, which can then be combined to form training, validation and testing data in this notebook.

We break up the data this way because we want to run multiple analyses, each of the following form: 

    1) We identify a "target" fish: the fish to which we want to transfer model structure learned from other fish.  We also identify a "target condition" under which we observe this fish. 
    
    2) We identify a number of "transfer" fish: fish we observe in conditions different from the one in which we observe the target fish. We want to transfer what we learn about model structure under these conditions to the target fish. 
    
    3) We form our training data in two ways.  In the first way, the training data for each fish consists of a different condition.  In the second way, the training data for all fish consists of the same condition as the condition for the target fish.  We make sure the total amount of training data used in both cases is the same. 
    
    4) We then test the performance of the models for the target fish on the conditions outside of its training data. We hope to see that the model synthesized for the target fish performs better when the training data for the other fish comes from the other conditions than when all fish are trained on the same condition.  This would show that our framework can transfer model structure across fish even when those fish are observed in different conditions. 
    
To this end, this notebook generates and saves two fold structures for each target fish.  "Folds" here is used loosely, and should just be understood as an assignment of training, validation and testing data.  We generate such an assignment (such a "fold") for each type of condition we observe in the target fish (e.g., if we have OMR L, R and F data for our fish, we can do three separate analyses where we assume we only observe one of these conditions in the target fish and then the other two in the transfer fish). The keys of the folds correspond to the condition we observe in the target fish (so the different folds are referred to by string rather than by number). We generate a set of folds for each target fish (1) when we observe different conditions across fish and (2) when we observe the same condition in all fish. 

See the documentation of the function form_subj_group_fold() below for more details on how we assign data within a fold. 
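
To make the layout concrete, here is a minimal sketch of what one saved fold structure might look like for target fish 8 when different conditions are observed across fish. The set names and counts are invented for illustration; only the nesting (fish, then target condition, then train/validation/test, then group) is meant to mirror what this notebook produces.

```python
# Illustrative only: set names and counts are invented, not taken from real data.
example_fold = {
    8: {"omr_l_ns": {"train": {"omr_l_ns": ["set_2", "set_0"]},
                     "validation": {"omr_l_ns": ["set_1"]},
                     "test": {"omr_r_ns": ["set_0", "set_1"],
                              "omr_f_ns": ["set_0", "set_1"]}}},
    9: {"omr_l_ns": {"train": {"omr_r_ns": ["set_0", "set_1"]},
                     "validation": None, "test": None}},
    11: {"omr_l_ns": {"train": {"omr_f_ns": ["set_0", "set_1"]},
                      "validation": None, "test": None}},
}

# The second-level key is the condition observed in the target fish (a string).
target_fold = example_fold[8]["omr_l_ns"]
train_sets = set(target_fold["train"]["omr_l_ns"])
validation_sets = set(target_fold["validation"]["omr_l_ns"])
```

In this toy case every fish trains on two sets, the target fish's train and validation sets do not overlap, and only the target fish has validation and test data.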


In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import copy

from pathlib import Path
import pickle

import numpy as np
import numpy.random as random

from ahrens_wbo.data_processing import SegmentTable

## Parameters go here

In [3]:
ps = dict()

# Specify where the segment table created by segment_ahrens_data_for_across_cond_analysis.ipynb is saved
ps['segment_table_folder'] = r'/groups/bishop/bishoplab/projects/probabilistic_model_synthesis/results/real_data/'
ps['segment_table_file'] = r'omr_l_r_f_ns_across_cond_segments_8_9_10_11.pkl'

# Specify the different conditions we want to test - there should be three of these
ps['test_groups'] = ['omr_l_ns', 'omr_r_ns', 'omr_f_ns']

# Specify the different subjects we use in the analysis - there should be three of these
ps['subjects'] = [8, 9, 11]

# Specify the fraction of data for each target subject we use for training.  (Because we 
# balance data across fish, we may use less than this, so this is an upper bound.)
ps['tgt_subj_train_percentage'] = .8

# Specify the folder we should save the fold structures in
ps['save_folder'] = r'/groups/bishop/bishoplab/projects/probabilistic_model_synthesis/results/real_data/'

# Specify a base string to prepend to file names with the fold structures
ps['save_str'] = 'ac_an'



## Load the segment tables and get basic information we need

In [4]:
segment_table_path = Path(ps['segment_table_folder']) / ps['segment_table_file']
with open(segment_table_path, 'rb') as f:
    seg_table_data = pickle.load(f)
    
seg_tables = {s_n: SegmentTable.from_dict(seg_table) for s_n, seg_table in seg_table_data['segment_tables'].items()}

## Define helper functions here

In [5]:
def form_subj_group_fold(seg_tables, tgt_subj, tgt_subj_train_group, tgt_subj_train_percentage, 
                         trans_subjs, trans_train_groups, fold_types = 'multi_cond'):
    """ Assigns training, validation and testing data for a transfer analysis for one target fish. 
    
    
    Here, the user specifies a "target" fish (a fish we want to transfer model structure to) and the condition
    that we get to see in the training data for that fish (e.g., OMR Left).  The user also specifies "transfer"
    fish and the conditions we get to see in each of those fish (e.g., OMR Right, OMR Forward).  This function 
    then assigns training data in one of two ways: 
    
        'multi_cond': We see a different condition in each fish. In particular, we see the target condition in the
        target fish and one of the specified transfer conditions in each of the transfer fish. 
        
        'single_cond': We see the same condition (the condition for the target fish) in all fish. 
        
    When assigning the training data, this function will ensure the amount of training data in the multi_cond and
    single_cond cases is the same, to enable comparison of fit model performance. 
    
    This function will also assign validation data for the target fish (but it will not assign validation data for the 
    transfer fish, since we expect to be doing early stopping based only on the target fish).  The validation
    data is the same condition as the train data, and the amount of data will be roughly equal to 
    1-tgt_subj_train_percentage (only roughly equal due to the discrete amount of data that is available) of
    the amount of data available for the condition.  We try to roughly ensure balance in the swimming strength in
    training and validation data by randomly assigning the top sets in the segment table to the train and validation 
    data (to see how we sort sets in the segment tables by swimming strength, see the notes below).
    
    Finally, this function will assign test data for the target fish.  The test data is simply all of the target 
    fish's data in each of the train conditions for the transfer fish. 
    
    A few important final notes: 
    
    1) The segment table input is expected to be sorted by swimming strength (the notebook, 
    segment_ahrens_data_for_across_cond_analysis does this).  In many cases, we may not be able to use all the 
    data of a condition for training or validation (see below).  In that case, we use the data with the strongest 
    swimming signals (specifically, we use the top sets in the segment table, when sets are sorted by swimming
    strength).
    
    2) Due to the need to balance the amount of training data across fish and across the multi_cond and single_cond
    cases, we may not use all the available training data for a condition in a fish.  
    
    3) In some cases, there may be much more data for the target condition and target fish than there is for the 
    transfer conditions and transfer fish. In this case, we do not adjust the amount of validation data used (it is
    always 1-tgt_subj_train_percentage of the original amount of available data).  We could have decided 
    instead to increase the amount of validation data by simply assigning all data not used for training to validation.
    However, it would then become much harder to make sure that behavior is balanced between the train and validation
    data; thus we decided not to do this. 
    
    """
    
    n_transfer_subjs = len(trans_subjs)
    
    # See how many segments are available for training in each subject, accounting for the data we need
    # for validation in our target subject
    n_tgt_subj_segs = seg_tables[tgt_subj].n_group_segments(tgt_subj_train_group)
    n_tgt_subj_train_segs = int(np.floor(n_tgt_subj_segs*tgt_subj_train_percentage))
    n_tgt_subj_validation_segs = n_tgt_subj_segs - n_tgt_subj_train_segs
    
    n_trans_subj_train_segs = [seg_tables[s_n].n_group_segments(grp) 
                               for s_n, grp in zip(trans_subjs, trans_train_groups)]
    
    # Determine the number of segments used for training - this is the min number available across all subjects
    n_train_segs = np.min(n_trans_subj_train_segs + [n_tgt_subj_train_segs])
    
    # Form our fold structure for the target fish here
    n_tgt_fish_segs = n_train_segs + n_tgt_subj_validation_segs
    tgt_seg_nums = random.permutation(n_tgt_fish_segs)
    tgt_train_seg_nums = tgt_seg_nums[0:n_train_segs]
    tgt_validation_seg_nums = tgt_seg_nums[n_train_segs:]
    
    tgt_fish_fold = {'train': {tgt_subj_train_group: ['set_' + str(n) for n in tgt_train_seg_nums]},
                     'validation': {tgt_subj_train_group: ['set_' + str(n) for n in tgt_validation_seg_nums]},
                     'test': {trans_train_groups[s_i]: ['set_' + str(n) 
                                                        for n in range(seg_tables[tgt_subj].n_group_segments(trans_train_groups[s_i]))]
                              for s_i in range(n_transfer_subjs)}}
    
    # Form our fold structure for the transfer fish
    if fold_types == 'multi_cond':
        transfer_fish_folds = [{'train': {trans_train_groups[s_i]: ['set_' + str(n) for n in range(n_train_segs)]},
                                'validation': None, 
                                'test': None}
                                for s_i, s_n in enumerate(trans_subjs)]
    elif fold_types == 'single_cond':
        transfer_fish_folds = [{'train': {tgt_subj_train_group: ['set_' + str(n) for n in range(n_train_segs)]},
                                'validation': None, 
                                'test': None} for _ in trans_subjs]
    else:
        raise ValueError('fold_types ' + str(fold_types) + ' is not recognized')
    
    # Key each fold by its subject; this works for any number of transfer subjects
    folds = {tgt_subj: tgt_fish_fold}
    folds.update(zip(trans_subjs, transfer_fish_folds))
    return folds
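
The counting logic in form_subj_group_fold() can be hard to follow from the code alone, so here is a small standalone sketch of it with made-up segment counts. The helper name split_target_sets and its arguments are hypothetical and not part of this notebook, and it uses a seeded NumPy Generator rather than the global numpy.random state used above, purely so the example is repeatable.

```python
import numpy as np

def split_target_sets(n_tgt_segs, train_fraction, n_trans_segs, rng):
    """Hypothetical helper mirroring the counting logic of form_subj_group_fold."""
    n_tgt_train = int(np.floor(n_tgt_segs * train_fraction))
    n_validation = n_tgt_segs - n_tgt_train
    # The training size is capped by the subject with the least data
    n_train = min(list(n_trans_segs) + [n_tgt_train])
    # Shuffle the top (strongest swimming) sets, then split them
    order = rng.permutation(n_train + n_validation)
    train = ["set_" + str(n) for n in order[:n_train]]
    validation = ["set_" + str(n) for n in order[n_train:]]
    return train, validation

rng = np.random.default_rng(0)  # fixed seed so the split is repeatable
train, validation = split_target_sets(n_tgt_segs=20, train_fraction=0.8,
                                      n_trans_segs=[12, 15], rng=rng)
print(len(train), len(validation))  # prints "12 4"
```

With 20 target segments and a training fraction of 0.8, 16 segments could be used for training, but the transfer fish with only 12 segments caps training at 12. Validation stays at 4, so only the top 16 sets of the target fish are used and the remaining 4 are left out.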

    

    

## Form fold structures

In [6]:
multi_cond_fold_strs = dict()
single_cond_fold_strs = dict()
for tgt_subj in ps['subjects']:
    fish_multi_cond_folds = dict()
    fish_single_cond_folds = dict()
    for tgt_cond in ps['test_groups']:
        trans_subjs = [s_n for s_n in ps['subjects'] if s_n != tgt_subj]
        trans_conds = [cond for cond in ps['test_groups']  if cond != tgt_cond]
        
        multi_cond_folds = form_subj_group_fold(seg_tables=seg_tables, 
                                                tgt_subj=tgt_subj, 
                                                tgt_subj_train_group=tgt_cond, 
                                                tgt_subj_train_percentage=ps['tgt_subj_train_percentage'], 
                                                trans_subjs=trans_subjs, 
                                                trans_train_groups=trans_conds, 
                                                fold_types='multi_cond')
        
        single_cond_folds = form_subj_group_fold(seg_tables=seg_tables, 
                                                tgt_subj=tgt_subj, 
                                                tgt_subj_train_group=tgt_cond, 
                                                tgt_subj_train_percentage=ps['tgt_subj_train_percentage'], 
                                                trans_subjs=trans_subjs, 
                                                trans_train_groups=trans_conds, 
                                                fold_types='single_cond')
        
        fish_multi_cond_folds[tgt_cond] = multi_cond_folds
        fish_single_cond_folds[tgt_cond] = single_cond_folds
    
    multi_cond_fold_strs[tgt_subj] = fish_multi_cond_folds
    single_cond_fold_strs[tgt_subj] = fish_single_cond_folds
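
As a rough illustration of how the loop above pairs transfer fish with transfer conditions, the sketch below enumerates the pairings for the subjects and conditions listed in the parameters. It assumes nothing beyond the list order used above; the pairings dictionary is a hypothetical name introduced only for this example.

```python
subjects = [8, 9, 11]
conds = ["omr_l_ns", "omr_r_ns", "omr_f_ns"]

# For each (target fish, target condition), record which transfer fish
# is trained on which transfer condition in the multi_cond case.
pairings = {}
for tgt_subj in subjects:
    for tgt_cond in conds:
        trans_subjs = [s for s in subjects if s != tgt_subj]
        trans_conds = [c for c in conds if c != tgt_cond]
        pairings[(tgt_subj, tgt_cond)] = dict(zip(trans_subjs, trans_conds))

print(pairings[(9, "omr_r_ns")])  # prints {8: 'omr_l_ns', 11: 'omr_f_ns'}
```

Three fish and three conditions give nine analyses in total. Because the pairing depends on list order, reordering the subjects or conditions in the parameters changes which transfer fish sees which condition.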

Rearrange the fold structures so that fish is the first level and fold is the second level. This allows the fold structures to be used seamlessly with our standard fitting code.

In [7]:
new_multi_cond_fold_strs = dict()
new_single_cond_fold_strs = dict()
for tgt_subj in ps['subjects']:
    new_multi_cond_fold_strs[tgt_subj] = dict()
    new_single_cond_fold_strs[tgt_subj] = dict()
    for s_n in ps['subjects']:
        new_multi_cond_fold_strs[tgt_subj][s_n] = dict()
        new_single_cond_fold_strs[tgt_subj][s_n] = dict()
        for tgt_cond in ps['test_groups']:
            new_multi_cond_fold_strs[tgt_subj][s_n][tgt_cond] = multi_cond_fold_strs[tgt_subj][tgt_cond][s_n]
            new_single_cond_fold_strs[tgt_subj][s_n][tgt_cond] = single_cond_fold_strs[tgt_subj][tgt_cond][s_n]
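
The rearrangement above amounts to swapping two levels of a nested dictionary. A minimal sketch of that operation on toy data follows; swap_inner_levels is a hypothetical helper and the string values stand in for real fold dictionaries.

```python
def swap_inner_levels(nested):
    """Turn nested[a][b] into swapped[b][a] (hypothetical helper for illustration)."""
    swapped = {}
    for outer_key, inner in nested.items():
        for inner_key, value in inner.items():
            swapped.setdefault(inner_key, {})[outer_key] = value
    return swapped

# Toy input organized as condition -> fish, as produced for one target fish
by_cond = {"omr_l_ns": {8: "fold_a", 9: "fold_b"},
           "omr_r_ns": {8: "fold_c", 9: "fold_d"}}
by_fish = swap_inner_levels(by_cond)
print(by_fish[8]["omr_r_ns"])  # prints fold_c
```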

## Save fold structures 

In [8]:
for tgt_subj in ps['subjects']:
    multi_cond_folds = new_multi_cond_fold_strs[tgt_subj]
    single_cond_folds = new_single_cond_fold_strs[tgt_subj]
    
    multi_cond_file_name = ps['save_str'] + '_tgt_' + str(tgt_subj) + '_multi_cond_folds.pkl'
    single_cond_file_name = ps['save_str'] + '_tgt_' + str(tgt_subj) + '_single_cond_folds.pkl'
    
    multi_cond_path = Path(ps['save_folder']) / multi_cond_file_name
    single_cond_path = Path(ps['save_folder']) / single_cond_file_name
    
    with open(multi_cond_path, 'wb') as f:
        pickle.dump(multi_cond_folds, f)
    with open(single_cond_path, 'wb') as f:
        pickle.dump(single_cond_folds, f)    
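
After saving, it may be worth reloading a file and checking that it round-trips and that train and validation sets do not overlap. The sketch below does this with a toy fold structure written to a temporary directory; the set names are invented, and the file name only imitates the naming pattern used above.

```python
import pickle
import tempfile
from pathlib import Path

# Toy fold structure (invented set names), saved the same way as above
toy_folds = {8: {"omr_l_ns": {"train": {"omr_l_ns": ["set_1", "set_0"]},
                              "validation": {"omr_l_ns": ["set_2"]},
                              "test": {"omr_r_ns": ["set_0"]}}}}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = Path(tmp_dir) / "ac_an_tgt_8_multi_cond_folds.pkl"
    with open(path, "wb") as f:
        pickle.dump(toy_folds, f)
    with open(path, "rb") as f:
        reloaded = pickle.load(f)

# Train and validation sets for the target fish should be disjoint
fold = reloaded[8]["omr_l_ns"]
overlap = set(fold["train"]["omr_l_ns"]) & set(fold["validation"]["omr_l_ns"])
```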
    
    