This notebook is designed to be used in conjunction with the notebook "segment_periods_for_transfer_analysis". That notebook breaks the data for each subject into different sets, which this notebook then combines to form training, validation and testing data.

The idea behind the way we break up the data is that we want to run multiple analyses, each of the following form: 

    1) We identify a "target" fish: the fish to which we want to transfer the model structure learned from other fish.  We also identify a "target condition" under which we observe this fish. 
    
    2) We identify a number of "transfer" fish, which we observe in conditions different from the target fish's condition.  We want to transfer what we learn about model structure under these conditions to the target fish. 
    
    3) We form our training and validation data in two ways.  In the first way, each fish's training and validation data come from a different condition: the target fish is observed in the target condition and each transfer fish in one of the other conditions.  In the second way, the training and validation data for all fish come from the target condition.  We make sure the total amount of training and validation data is the same in both cases. 
    
    4) We then test the performance of the models for the target fish on the conditions outside its training data (as well as on the condition in its training data).  We hope to see that, on the conditions outside the target fish's training condition, the model synthesized for the target fish performs better when the other fish are trained on those other conditions than when all fish are trained on the same condition.  This would show that our framework can transfer model structure across fish even when those fish are observed in different conditions. 
    
To this end, this notebook generates and saves two assignments of data to training, validation and test sets for each target fish and target condition.  We generate one such pair of assignments for each condition we observe in the target fish (e.g., if we have OMR L, R and F data for our fish, we can do three separate analyses, in each of which we assume we observe only one of these conditions in the target fish and the other two in the transfer fish).  The two assignments correspond to (1) observing different conditions across fish and (2) observing the same condition in all fish. 
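As a minimal sketch of the bookkeeping this implies, the snippet below enumerates every (target fish, target condition) analysis for the subjects and conditions used in this notebook and records which condition each fish trains on under the two assignments. The function name `enumerate_analyses` is hypothetical; pairing the i-th transfer fish with the i-th remaining condition mirrors the `zip` in `assign_for_one_tgt_subj` below.

```python
from itertools import product

def enumerate_analyses(subjects, conditions):
    """List every (target fish, target condition) analysis and the training
    condition of each fish under the across- and within-condition assignments."""
    analyses = []
    for tgt_subj, tgt_cond in product(subjects, conditions):
        trans_subjs = [s for s in subjects if s != tgt_subj]
        trans_conds = [c for c in conditions if c != tgt_cond]
        # Across-condition fitting: transfer fish i is trained on transfer condition i.
        across = dict(zip([tgt_subj] + trans_subjs, [tgt_cond] + trans_conds))
        # Within-condition fitting: every fish is trained on the target condition.
        within = {s: tgt_cond for s in [tgt_subj] + trans_subjs}
        analyses.append({'tgt_subj': tgt_subj, 'tgt_cond': tgt_cond,
                         'across': across, 'within': within})
    return analyses

analyses = enumerate_analyses([8, 9, 11], ['omr_l_ns', 'omr_r_ns', 'omr_f_ns'])
print(len(analyses))           # 9 = 3 target fish x 3 target conditions
print(analyses[0]['across'])   # {8: 'omr_l_ns', 9: 'omr_r_ns', 11: 'omr_f_ns'}
```

Each target fish gets one saved file per assignment type, holding all three of its target-condition folds.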


In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import copy

from pathlib import Path
import pickle

import numpy as np
import numpy.random as random

from ahrens_wbo.data_processing import SegmentTable

## Parameters go here

In [3]:
ps = dict()

# Specify where the segment table created by segment_ahrens_data_for_across_cond_analsysis.ipynb is saved
ps['segment_table_folder'] = r'/groups/bishop/bishoplab/projects/probabilistic_model_synthesis/results/real_data/'
ps['segment_table_file'] = r'omr_l_r_f_ns_across_cond_segments_8_9_10_11.pkl'

# Specify the different conditions we want to test - there should be three of these
ps['test_groups'] = ['omr_l_ns', 'omr_r_ns', 'omr_f_ns']

# Specify the different subjects we use in the analysis - there should be three of these
ps['subjects'] = [8, 9, 11]

# Specify the fraction of each subject's data (for a given condition) we use for training.  (Note: because we 
# balance data across fish, we may use less than this; this is the maximum we can use.)
ps['train_percentage'] = .7

# Specify the fraction of each subject's data (for a given condition) we use for validation.  (Note: because we 
# balance data across fish, we may use less than this; this is the maximum we can use.)
ps['validation_percentage'] = .15

# Specify the folder we should save the fold structures in
ps['save_folder'] = r'/groups/bishop/bishoplab/projects/probabilistic_model_synthesis/results/real_data/'

# Specify a base string to prepend to the names of the files holding the fold structures
ps['save_str'] = 'ac_an'



## Load the segment tables and get basic information we need

In [4]:
segment_table_path = Path(ps['segment_table_folder']) / ps['segment_table_file']
with open(segment_table_path, 'rb') as f:
    seg_table_data = pickle.load(f)
    
seg_tables = {s_n: SegmentTable.from_dict(seg_table) for s_n, seg_table in seg_table_data['segment_tables'].items()}

## Define helper functions here

In [5]:
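def randomly_assign_chunks(n_max_chunks, n_train_chunks, n_validation_chunks):
    """ Randomly splits the chunk indices 0, ..., n_max_chunks - 1 into train, validation and test chunks.
    
    The first n_train_chunks indices of a random permutation go to training, the next n_validation_chunks
    go to validation, and all remaining indices go to test.
    """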
    chunk_order = random.permutation(n_max_chunks)
    train_chunks = chunk_order[:n_train_chunks]
    validation_chunks = chunk_order[n_train_chunks:n_train_chunks+n_validation_chunks]
    test_chunks = chunk_order[n_train_chunks+n_validation_chunks:]
    
    return train_chunks, validation_chunks, test_chunks
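The split performed by `randomly_assign_chunks` is a random partition of the chunk indices. The sketch below is a hypothetical seeded variant (`randomly_assign_chunks_seeded` is not part of the original code) that takes an explicit `numpy.random.Generator`, which makes the split reproducible, and checks that the three splits are disjoint and together cover every chunk. The notebook itself draws from the global `numpy.random` state without setting a seed, so its assignments differ from run to run.

```python
import numpy as np

def randomly_assign_chunks_seeded(n_max_chunks, n_train_chunks, n_validation_chunks, rng):
    """Same split logic as randomly_assign_chunks, but the permutation comes
    from an explicit numpy Generator so the result is reproducible."""
    chunk_order = rng.permutation(n_max_chunks)
    train = chunk_order[:n_train_chunks]
    validation = chunk_order[n_train_chunks:n_train_chunks + n_validation_chunks]
    test = chunk_order[n_train_chunks + n_validation_chunks:]
    return train, validation, test

rng = np.random.default_rng(0)
train, validation, test = randomly_assign_chunks_seeded(10, 6, 2, rng)

# The three splits are disjoint and together cover chunks 0..9 exactly once.
all_chunks = np.concatenate([train, validation, test])
assert sorted(all_chunks.tolist()) == list(range(10))
print(len(train), len(validation), len(test))  # 6 2 2
```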

In [6]:
def assign_for_one_tgt_subj(seg_tables, tgt_subj, tgt_subj_train_group, train_percentage, validation_percentage, trans_subjs, 
                            trans_train_groups):
    """ Assigns train, validation & test data for a given target subject and target condition. 
    
    This function will generate two assignments: one for across condition fitting and one for within condition fitting.
    The across condition fitting is when the conditions in the training data for each subject are as listed in
    the tgt_subj_train_group and trans_train_groups inputs, which can be different.  The within condition fitting
    is when the training data for each subject is all of the tgt_subj_train_group condition. 
    
    This function will ensure the total amount of training and validation data for each subject is the same in
    both types of assignments.  It does this by finding the subject and condition with the smallest number of
    chunks, and then:
        1) Setting train_percentage of that min number of chunks as the number of training chunks 
           for all conditions and subjects
        2) Setting validation_percentage of that min number of chunks as the number of validation chunks
           for all conditions and subjects
    
    For each subject, the testing data will always be all chunks across all conditions not in its train and validation
    data. 
    
    To ensure a rough balance of swim vigor, we randomly pick the particular chunks assigned to train, validation and 
    test.
    
    Args:
    
        seg_tables: The segment tables for all subjects
        
        tgt_subj: The target subject
        
        tgt_subj_train_group: The target condition
        
        train_percentage: The max percentage of training data we can use for any subject 
        
        validation_percentage: The max percentage of validation data we can use for any subject
        
        trans_subjs: The base or transfer subjects
        
        trans_train_groups: The conditions in the training data for the transfer subjects
        
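    Returns: 
    
        across_assignments: The assignments for the across condition analyses.  A dictionary with keys 
        corresponding to subjects.  Each entry is itself a dictionary with the keys 'train', 'validation' 
        and 'test'.  Each of these is another dictionary with keys corresponding to groups and values corresponding
        to the sets (which are the same as chunks in this case) that should be pulled from that group. 
        
        within_assignments: The assignments for the within condition analyses, in the same format.  The entry for
        the target subject is a copy of its across condition entry.  Each transfer subject is trained and validated
        on tgt_subj_train_group and tested on its remaining tgt_subj_train_group chunks plus all chunks of the
        trans_train_groups conditions.
    
    """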

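    # Each transfer subject is paired with exactly one training condition (see the zip calls below)
    assert len(trans_subjs) == len(trans_train_groups), 'Need one training condition per transfer subject.'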

    all_subjs = [tgt_subj] + trans_subjs
    all_across_cond_train_conds = [tgt_subj_train_group] + trans_train_groups

    n_max_across_cond_train_chunks = [np.floor(train_percentage*seg_tables[s_n].n_group_segments(grp)) 
                                      for s_n, grp in zip(all_subjs, all_across_cond_train_conds)]

    n_max_across_cond_validation_chunks = [np.floor(validation_percentage*seg_tables[s_n].n_group_segments(grp)) 
                                           for s_n, grp in zip(all_subjs, all_across_cond_train_conds)]

    n_max_within_cond_train_chunks = [np.floor(train_percentage*seg_tables[s_n].n_group_segments(tgt_subj_train_group)) 
                                      for s_n in all_subjs]

    n_max_within_cond_validation_chunks = [np.floor(validation_percentage*seg_tables[s_n].n_group_segments(tgt_subj_train_group)) 
                                           for s_n in all_subjs]

    n_train_chunks = int(np.min(np.stack([n_max_across_cond_train_chunks, n_max_within_cond_train_chunks])))
    n_validation_chunks = int(np.min(np.stack([n_max_across_cond_validation_chunks, 
                                               n_max_within_cond_validation_chunks])))

    # Form across condition assignments
    across_assignments = dict()
    for s_n, grp in zip(all_subjs, all_across_cond_train_conds):
        num_segs_n = seg_tables[s_n].n_group_segments(grp)
    
        train_chunks_n, validation_chunks_n, test_chunks_n = randomly_assign_chunks(n_max_chunks=num_segs_n, 
                                                                            n_train_chunks=n_train_chunks,
                                                                            n_validation_chunks=n_validation_chunks)
    
        across_assignments[s_n] = {'train': {grp: ['set_' + str(i) for i in train_chunks_n]},
                                   'validation': {grp: ['set_' + str(i) for i in validation_chunks_n]},
                                   'test': {grp: ['set_' + str(i) for i in test_chunks_n] }}
    
        other_test_grps = list(set(all_across_cond_train_conds) - set([grp]))
        for other_grp in other_test_grps:
            other_grp_n_chunks = seg_tables[s_n].n_group_segments(other_grp)
            across_assignments[s_n]['test'][other_grp] = ['set_' + str(i) for i in range(other_grp_n_chunks)]
    
    # Form within condition assignments
    within_assignments = dict()
    within_assignments[tgt_subj] = copy.deepcopy(across_assignments[tgt_subj])
    for s_n in trans_subjs:
        num_segs_n = seg_tables[s_n].n_group_segments(tgt_subj_train_group)
        train_chunks_n, validation_chunks_n, test_chunks_n = randomly_assign_chunks(n_max_chunks=num_segs_n, 
                                                                                n_train_chunks=n_train_chunks,
                                                                                n_validation_chunks=n_validation_chunks)
    
        within_assignments[s_n] = {'train': {tgt_subj_train_group: ['set_' + str(i) for i in train_chunks_n]},
                                   'validation': {tgt_subj_train_group: ['set_' + str(i) for i in validation_chunks_n]},
                                   'test': {tgt_subj_train_group: ['set_' + str(i) for i in test_chunks_n] }}
    
        for other_grp in trans_train_groups:
            other_grp_n_chunks = seg_tables[s_n].n_group_segments(other_grp)
            within_assignments[s_n]['test'][other_grp] = ['set_' + str(i) for i in range(other_grp_n_chunks)]
            
    return across_assignments, within_assignments
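The balancing rule in the docstring reduces to a small piece of arithmetic: floor each fraction times each group's segment count, then take the minimum over every (subject, condition) pair that could be trained on, covering both each subject's across-condition group and the target condition for every subject. A sketch of that calculation follows; `balanced_chunk_counts` is a hypothetical helper and the segment counts are made up purely for illustration.

```python
import math

def balanced_chunk_counts(n_segments, train_fraction, validation_fraction):
    """Return (n_train, n_validation) shared by every subject and condition.

    n_segments maps (subject, condition) -> number of segments. It should list every
    group that is trained on in either assignment: each subject's across-condition
    training group and each subject's target-condition group.
    """
    n_train = min(math.floor(train_fraction * n) for n in n_segments.values())
    n_validation = min(math.floor(validation_fraction * n) for n in n_segments.values())
    return n_train, n_validation

# Hypothetical segment counts (target fish 8, target condition omr_f_ns).
counts = {(8, 'omr_f_ns'): 36,                         # target fish, target condition
          (9, 'omr_l_ns'): 40, (11, 'omr_r_ns'): 50,   # across-condition training groups
          (9, 'omr_f_ns'): 45, (11, 'omr_f_ns'): 38}   # within-condition training groups
print(balanced_chunk_counts(counts, 0.7, 0.15))  # (25, 5)
```

Flooring before taking the minimum guarantees that every group has at least `n_train + n_validation` segments, so each call to `randomly_assign_chunks` can fill both splits.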


## Form assignments

In [7]:
multi_cond_fold_strs = dict()
single_cond_fold_strs = dict()
for tgt_subj in ps['subjects']:
    fish_multi_cond_folds = dict()
    fish_single_cond_folds = dict()
    for tgt_cond in ps['test_groups']:
        trans_subjs = [s_n for s_n in ps['subjects'] if s_n != tgt_subj]
        trans_conds = [cond for cond in ps['test_groups']  if cond != tgt_cond]
        
        multi_cond_folds, single_cond_folds = assign_for_one_tgt_subj(seg_tables=seg_tables, 
                                                tgt_subj=tgt_subj, 
                                                tgt_subj_train_group=tgt_cond, 
                                                train_percentage=ps['train_percentage'], 
                                                validation_percentage=ps['validation_percentage'],
                                                trans_subjs=trans_subjs, 
                                                trans_train_groups=trans_conds)
        
        fish_multi_cond_folds[tgt_cond] = multi_cond_folds
        fish_single_cond_folds[tgt_cond] = single_cond_folds
    
    multi_cond_fold_strs[tgt_subj] = fish_multi_cond_folds
    single_cond_fold_strs[tgt_subj] = fish_single_cond_folds
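Before saving, it can be worth checking that each subject's assignment is a valid partition of its sets. The helper below is a hypothetical sketch, not part of the original notebook; it assumes the `'set_<i>'` naming used above and that a group's sets are numbered 0 through n - 1, as `randomly_assign_chunks` and the `range(other_grp_n_chunks)` loops produce.

```python
def check_subject_assignment(assignment, n_segments):
    """Check one subject's {'train': ..., 'validation': ..., 'test': ...} assignment.

    n_segments maps each group name to the number of segments this subject has in that
    group. Within every group, the splits must not share a set and must together cover
    'set_0' ... 'set_<n-1>' exactly once. Raises ValueError otherwise.
    """
    for grp, n in n_segments.items():
        assigned = [s for split in ('train', 'validation', 'test')
                    for s in assignment.get(split, {}).get(grp, [])]
        if len(assigned) != len(set(assigned)):
            raise ValueError(f'{grp}: some set is assigned more than once')
        if set(assigned) != {'set_' + str(i) for i in range(n)}:
            raise ValueError(f'{grp}: assigned sets do not match set_0 ... set_{n - 1}')
    return True

# Toy assignment in the same format as the notebook's fold structures.
toy = {'train': {'omr_f_ns': ['set_2', 'set_0']},
       'validation': {'omr_f_ns': ['set_3']},
       'test': {'omr_f_ns': ['set_1'], 'omr_l_ns': ['set_0', 'set_1']}}
print(check_subject_assignment(toy, {'omr_f_ns': 4, 'omr_l_ns': 2}))  # True
```

In this notebook it could be called as, e.g., `check_subject_assignment(multi_cond_fold_strs[8]['omr_f_ns'][8], {g: seg_tables[8].n_group_segments(g) for g in ps['test_groups']})`.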

Rearrange the fold structures so that, for each target fish, the subject is the first level and the fold (indexed by target condition) is the second level.  This lets the fold structures be used directly with our standard fitting code.

In [8]:
new_multi_cond_fold_strs = dict()
new_single_cond_fold_strs = dict()
for tgt_subj in ps['subjects']:
    new_multi_cond_fold_strs[tgt_subj] = dict()
    new_single_cond_fold_strs[tgt_subj] = dict()
    for s_n in ps['subjects']:
        new_multi_cond_fold_strs[tgt_subj][s_n] = dict()
        new_single_cond_fold_strs[tgt_subj][s_n] = dict()
        for tgt_cond in ps['test_groups']:
            new_multi_cond_fold_strs[tgt_subj][s_n][tgt_cond] = multi_cond_fold_strs[tgt_subj][tgt_cond][s_n]
            new_single_cond_fold_strs[tgt_subj][s_n][tgt_cond] = single_cond_fold_strs[tgt_subj][tgt_cond][s_n]
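The loop above swaps the condition and subject levels of each target fish's fold structure. The same operation can be written as a generic helper; `swap_outer_levels` is a hypothetical name, and the sketch assumes every inner dictionary is keyed by the same subjects, as it is here.

```python
def swap_outer_levels(nested):
    """Turn {outer: {inner: value}} into {inner: {outer: value}}."""
    swapped = {}
    for outer_key, inner_dict in nested.items():
        for inner_key, value in inner_dict.items():
            swapped.setdefault(inner_key, {})[outer_key] = value
    return swapped

# Toy version of multi_cond_fold_strs[tgt_subj]: condition -> subject -> assignment
by_condition = {'omr_l_ns': {8: 'A', 9: 'B'}, 'omr_r_ns': {8: 'C', 9: 'D'}}
print(swap_outer_levels(by_condition))
# {8: {'omr_l_ns': 'A', 'omr_r_ns': 'C'}, 9: {'omr_l_ns': 'B', 'omr_r_ns': 'D'}}
```

With it, `{t: swap_outer_levels(multi_cond_fold_strs[t]) for t in ps['subjects']}` should produce the same structure as `new_multi_cond_fold_strs`.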

## Save fold structures 

In [9]:
for tgt_subj in ps['subjects']:
    multi_cond_folds = new_multi_cond_fold_strs[tgt_subj]
    single_cond_folds = new_single_cond_fold_strs[tgt_subj]
    
    multi_cond_file_name = ps['save_str'] + '_tgt_' + str(tgt_subj) + '_multi_cond_folds.pkl'
    single_cond_file_name = ps['save_str'] + '_tgt_' + str(tgt_subj) + '_single_cond_folds.pkl'
    
    multi_cond_path = Path(ps['save_folder']) / multi_cond_file_name
    single_cond_path = Path(ps['save_folder']) / single_cond_file_name
    
    with open(multi_cond_path, 'wb') as f:
        pickle.dump(multi_cond_folds, f)
    with open(single_cond_path, 'wb') as f:
        pickle.dump(single_cond_folds, f)    
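The saved files follow the naming pattern `<save_str>_tgt_<target subject>_<multi_cond|single_cond>_folds.pkl`. A minimal sketch for loading one back in a fitting script, assuming only that pattern and the standard `pickle` module; `fold_path` and `load_folds` are hypothetical helpers.

```python
import pickle
import tempfile
from pathlib import Path

def fold_path(save_folder, save_str, tgt_subj, kind):
    """Path of a saved fold structure; kind is 'multi_cond' or 'single_cond'."""
    return Path(save_folder) / f'{save_str}_tgt_{tgt_subj}_{kind}_folds.pkl'

def load_folds(save_folder, save_str, tgt_subj, kind):
    """Load one target fish's fold structure: subject -> target condition -> assignment."""
    with open(fold_path(save_folder, save_str, tgt_subj, kind), 'rb') as f:
        return pickle.load(f)

# Round trip through a temporary folder with a toy fold structure.
with tempfile.TemporaryDirectory() as tmp:
    folds = {8: {'omr_f_ns': {'train': {'omr_f_ns': ['set_0']}}}}
    with open(fold_path(tmp, 'ac_an', 8, 'multi_cond'), 'wb') as f:
        pickle.dump(folds, f)
    loaded = load_folds(tmp, 'ac_an', 8, 'multi_cond')
print(loaded == folds)  # True
```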
    
    

## Debug

In [10]:
new_multi_cond_fold_strs[8][8]['omr_f_ns']

{'train': {'omr_f_ns': ['set_13',
   'set_51',
   'set_87',
   'set_48',
   'set_2',
   'set_45',
   'set_4',
   'set_36',
   'set_77',
   'set_73',
   'set_43',
   'set_7',
   'set_24',
   'set_75',
   'set_106',
   'set_97',
   'set_58',
   'set_62',
   'set_11',
   'set_76',
   'set_5',
   'set_81']},
 'validation': {'omr_f_ns': ['set_6', 'set_55', 'set_41', 'set_71']},
 'test': {'omr_f_ns': ['set_100',
   'set_35',
   'set_86',
   'set_64',
   'set_28',
   'set_44',
   'set_70',
   'set_20',
   'set_95',
   'set_56',
   'set_90',
   'set_80',
   'set_84',
   'set_8',
   'set_69',
   'set_57',
   'set_89',
   'set_93',
   'set_38',
   'set_68',
   'set_40',
   'set_9',
   'set_25',
   'set_32',
   'set_82',
   'set_91',
   'set_61',
   'set_0',
   'set_88',
   'set_59',
   'set_66',
   'set_92',
   'set_67',
   'set_98',
   'set_17',
   'set_34',
   'set_33',
   'set_27',
   'set_78',
   'set_63',
   'set_83',
   'set_19',
   'set_30',
   'set_105',
   'set_49',
   'set_47',
   'set
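Printing a full fold structure produces long lists of set names. A hypothetical helper such as `summarize_fold` (not part of the original notebook) reports only how many sets each split draws from each group, which is often enough to confirm that the balancing worked.

```python
def summarize_fold(fold):
    """Map split -> group -> number of sets for one subject's fold assignment."""
    return {split: {grp: len(sets) for grp, sets in groups.items()}
            for split, groups in fold.items()}

toy_fold = {'train': {'omr_f_ns': ['set_2', 'set_0']},
            'validation': {'omr_f_ns': ['set_3']},
            'test': {'omr_f_ns': ['set_1'], 'omr_l_ns': ['set_0', 'set_1']}}
print(summarize_fold(toy_fold))
# {'train': {'omr_f_ns': 2}, 'validation': {'omr_f_ns': 1}, 'test': {'omr_f_ns': 1, 'omr_l_ns': 2}}
```

For the run shown above, `summarize_fold(new_multi_cond_fold_strs[8][8]['omr_f_ns'])` would report 22 training and 4 validation sets for `omr_f_ns`.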

In [11]:
multi_cond_path

PosixPath('/groups/bishop/bishoplab/projects/probabilistic_model_synthesis/results/real_data/ac_an_tgt_11_multi_cond_folds.pkl')