A notebook to generate plots from the post-processed results of an across-condition analysis.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from pathlib import Path
import pickle
import re

import numpy as np

from ahrens_wbo.annotations import label_subperiods
from ahrens_wbo.data_processing import load_and_preprocess_data
from janelia_core.stats.regression import r_squared
from probabilistic_model_synthesis.gnlr_ahrens_tools import find_period_time_points

## Parameters go here

In [3]:
# Top-level directory holding the results. Its immediate sub-directories are
# treated as the training conditions when results are loaded.
base_dir = r'/groups/bishop/bishoplab/projects/probabilistic_model_synthesis/results/real_data/gnlr/across_cond_transfer_analysis/v6'

# Name used for all files containing post-processed results
pp_file = r'pp_test_results.pkl'

# Specify the type of models we assess performance for - can be 'sp' or 'ip'.
# This value is used as a key into each loaded results dictionary.
mdl_type = 'ip'

# Specify conditions we test on; each entry is passed as the `period` argument
# of find_period_time_points when test performance is assessed.
test_conds = ['omr_forward', 'omr_right', 'omr_left']



## Define helper functions

In [4]:
def rmse(x0, x1):
    """Root-mean-squared error between x0 and x1, reduced over axis 1.

    Args:
        x0: First array of values.
        x1: Second array of values; it is subtracted from x0.

    Returns:
        The square root of the mean squared difference taken along axis 1
        (for 2-d inputs this is one value per row).
    """
    diff = x0 - x1
    mean_sq_err = np.mean(np.square(diff), axis=1)
    return np.sqrt(mean_sq_err)

## Load all results

In [5]:
# Each sub-directory of base_dir holds the results for one training condition
train_conds = [cond.name for cond in Path(base_dir).iterdir() if cond.is_dir()]

# Subject numbers are parsed from the names of the sub-directories under the first
# training condition. The pattern is a raw string so that \d is not read as an
# (invalid) string escape. Entries that are not directories, or whose names do not
# contain _<digits>, are skipped instead of raising a TypeError on a failed match.
subj_name_re = re.compile(r'.*_(\d+)')
subj_matches = [subj_name_re.search(subj.name)
                for subj in (Path(base_dir) / train_conds[0]).iterdir() if subj.is_dir()]
test_subjs = np.sort([int(m[1]) for m in subj_matches if m is not None])

In [6]:
# Fit types we compare; each name is also the sub-directory, under a subject's
# directory, from which that fit type's post-processed results are loaded.
FIT_TYPES = ['single_cond', 'multi_cond']

In [7]:
# Results are stored as rs[train condition][subject][fit type]
rs = {}
for cond in train_conds:
    cond_rs = rs[cond] = {}
    for subj in test_subjs:
        subj_rs = cond_rs[subj] = {}
        for fit_type in FIT_TYPES:
            pp_path = Path(base_dir) / cond / ('subj_' + str(subj)) / fit_type / pp_file

            # NOTE: pickle can execute arbitrary code when loading - only point
            # base_dir at results we produced ourselves
            with open(pp_path, 'rb') as pp_f:
                fit_rs = pickle.load(pp_f)
            subj_rs[fit_type] = fit_rs

            # Print some diagnostic information
            header = ('Results for subject ' + str(subj) + ', train condition: ' + cond +
                      ', fit type: ' + fit_type + ' ***')
            print(header)
            print('Best CP Ind: ' + str(fit_rs[mdl_type]['early_stopping']['best_cp_ind']))
            
      

Results for subject 8, train condition: omr_f_ns, fit type: single_cond ***
Best CP Ind: 2
Results for subject 8, train condition: omr_f_ns, fit type: multi_cond ***
Best CP Ind: 1
Results for subject 9, train condition: omr_f_ns, fit type: single_cond ***
Best CP Ind: 28
Results for subject 9, train condition: omr_f_ns, fit type: multi_cond ***
Best CP Ind: 12
Results for subject 11, train condition: omr_f_ns, fit type: single_cond ***
Best CP Ind: 0
Results for subject 11, train condition: omr_f_ns, fit type: multi_cond ***
Best CP Ind: 12
Results for subject 8, train condition: omr_r_ns, fit type: single_cond ***
Best CP Ind: 1
Results for subject 8, train condition: omr_r_ns, fit type: multi_cond ***
Best CP Ind: 8
Results for subject 9, train condition: omr_r_ns, fit type: single_cond ***
Best CP Ind: 8
Results for subject 9, train condition: omr_r_ns, fit type: multi_cond ***
Best CP Ind: 19
Results for subject 11, train condition: omr_r_ns, fit type: single_cond ***
Best CP Ind:

## Load raw data - we need this to label the different periods of test data 

In [8]:
# Use the first loaded result as an example; the raw data is loaded with the
# fit parameters and subject order stored in it
ex_rs = rs[train_conds[0]][test_subjs[0]][FIT_TYPES[0]]
ex_fit_ps = ex_rs['fit_ps']
subject_order = ex_rs['subject_order']

subject_data, subject_neuron_locs = load_and_preprocess_data(
    data_folder=ex_fit_ps['data_dir'],
    subjects=subject_order,
    neural_gain=ex_fit_ps['neural_gain'],
    z_ratio=ex_fit_ps['z_ratio'],
)

Done loading data for subject subject_8.
Done loading data for subject subject_9.
Done loading data for subject subject_11.


## Get labels for all moments in time

In [9]:
# For each subject, keep the stimulus time stamps together with the sub-period
# labels computed from the stimulus values
labels = {}
for s_n in subject_order:
    stim = subject_data[s_n].ts_data['stim']
    labels[s_n] = {'ts': stim['ts'],
                   'labels': label_subperiods(stim['vls'][:])}

## Now assess test performance 

In [10]:
# Function used to score predictions; it is called as metric(y, y_hat) and the
# mean of what it returns is recorded as the performance value.
metric = r_squared

In [11]:
# Sizes of the three axes of the performance arrays
n_subjs, n_train_conds, n_test_conds = len(test_subjs), len(train_conds), len(test_conds)

In [12]:
# Performance arrays are indexed as [subject, train condition, test condition]. Entries
# stay NaN when no test time points are found for a test condition.
perf_shape = (n_subjs, n_train_conds, n_test_conds)
single_perf = np.full(perf_shape, np.nan)
multi_perf = np.full(perf_shape, np.nan)

# Ordered to match FIT_TYPES, so that ft_i selects the array for each fit type
perf_arrays = [single_perf, multi_perf]

for s_i, subj in enumerate(test_subjs):
    for tr_i, train_cond in enumerate(train_conds):
        for ft_i, fit_type in enumerate(FIT_TYPES):
            # The test predictions do not depend on the test condition, so look
            # them up once rather than once per test condition
            test_preds = rs[train_cond][subj][fit_type][mdl_type]['preds'][subj]['test']
            y = test_preds['y']
            y_hat = test_preds['y_hat']
            t = test_preds['t']

            for ts_i, test_cond in enumerate(test_conds):
                # Indices of the test time points that fall in this test condition
                test_cond_inds = find_period_time_points(cand_ts=t, period=test_cond,
                                                         shock=False, labels=labels[subj])

                # Only score conditions that actually occur in the test data
                if len(test_cond_inds) > 0:
                    perf_arrays[ft_i][s_i, tr_i, ts_i] = np.mean(
                        metric(y[test_cond_inds, :], y_hat[test_cond_inds, :]))

## Now look at raw performance values 

In [13]:
# Print the raw performance values, grouped by test condition and then by subject
for ts_i, test_cond in enumerate(test_conds):
    print('**** Test Condition: ' + test_cond + ' ****')
    for s_i, subj in enumerate(test_subjs):
        print('Subject: ' + str(subj))
        print('Single Performance, Multi Performance, Delta, Training Condition')
        for tr_i, train_cond in enumerate(train_conds):
            single_vl = single_perf[s_i, tr_i, ts_i]
            multi_vl = multi_perf[s_i, tr_i, ts_i]

            # Rows without a single-condition performance value (NaN) are skipped
            if not np.isnan(single_vl):
                # Delta is single minus multi. Exactly one ', ' separates each field; an
                # extra adjacent ', ' literal previously produced ', , trained on: '.
                print(str(single_vl) + ', ' + str(multi_vl) + ', ' +
                      str(single_vl - multi_vl) +
                      ', trained on: ' + train_cond)

**** Test Condition: omr_forward ****
Subject: 8
Single Performance, Multi Performance, Delta, Training Condition
-0.6557765007019043, -1.0993976593017578, 0.4436211585998535, , trained on: omr_r_ns
0.2537192106246948, 0.29001662135124207, -0.03629741072654724, , trained on: omr_l_ns
Subject: 9
Single Performance, Multi Performance, Delta, Training Condition
-2.9953720569610596, -1.3695883750915527, -1.6257836818695068, , trained on: omr_r_ns
-0.6418294906616211, -1.2445290088653564, 0.6026995182037354, , trained on: omr_l_ns
Subject: 11
Single Performance, Multi Performance, Delta, Training Condition
0.10378637909889221, 0.024032682180404663, 0.07975369691848755, , trained on: omr_r_ns
0.179833322763443, 0.21798691153526306, -0.03815358877182007, , trained on: omr_l_ns
**** Test Condition: omr_right ****
Subject: 8
Single Performance, Multi Performance, Delta, Training Condition
-3.684431552886963, -9.254469871520996, 5.570038318634033, , trained on: omr_f_ns
-2.078763008117676, -1.05

## Debug code goes here 

In [14]:
# Display the fit parameters saved with the single-condition results for the
# first training condition and first test subject
rs[train_conds[0]][test_subjs[0]]['single_cond']['fit_ps']

{'note': 'Initial testing. p=20, less hypercubes',
 'param_filename': 'transfer_params.pkl',
 'param_save_dir': '/groups/bishop/bishoplab/projects/probabilistic_model_synthesis/results/real_data/gnlr/across_cond_transfer_analysis/v6',
 'results_dir': '/groups/bishop/bishoplab/projects/probabilistic_model_synthesis/results/real_data/gnlr/across_cond_transfer_analysis/v6/omr_f_ns/subj_8/single_cond',
 'save_file': 'test_results.pt',
 'sp_cp_dir': '/groups/bishop/bishoplab/projects/probabilistic_model_synthesis/results/real_data/gnlr/across_cond_transfer_analysis/v6/omr_f_ns/subj_8/single_cond/sp_cp',
 'ip_cp_dir': '/groups/bishop/bishoplab/projects/probabilistic_model_synthesis/results/real_data/gnlr/across_cond_transfer_analysis/v6/omr_f_ns/subj_8/single_cond/ip_cp',
 'data_dir': '/groups/bishop/bishoplab/projects/ahrens_wbo/data',
 'segment_table_dir': '/groups/bishop/bishoplab/projects/probabilistic_model_synthesis/results/real_data',
 'segment_table_file': 'omr_l_r_f_ns_across_cond_s