A notebook to generate plots from the post processed results of an across condition analysis

In [1]:
%load_ext autoreload
%autoreload 2

In [14]:
from pathlib import Path
import pickle
import re

import numpy as np
import torch

from janelia_core.stats.regression import r_squared

## Parameters go here 

In [12]:
# Top-level directory holding the results
base_dir = r'/groups/bishop/bishoplab/projects/probabilistic_model_synthesis/results/testing/gnldr/across_cond_results/v1'

# Name used for all files containing post-processed results
pp_file = r'pp_test_results.pt'

# Specify the type of models we assess performance for - can be 'sp' or 'ip'
mdl_type = 'ip'

# Specify conditions we test on 
test_conds = ['omr_forward', 'omr_right', 'omr_left']



## Load all results

In [5]:
train_conds = [cond.name for cond in Path(base_dir).iterdir() if cond.is_dir()]
test_subjs = np.sort([int(re.search('.*_(\d+)', subj.name)[1]) 
                      for subj in (Path(base_dir) / train_conds[0]).iterdir()])

In [8]:
FIT_TYPES = ['single_cond', 'multi_cond']

In [13]:
rs = dict()
for cond in train_conds:
    rs[cond] = dict()
    for subj in test_subjs:
        rs[cond][subj] = dict()
        for fit_type in FIT_TYPES:
            fit_type_dir = Path(base_dir) / cond / ('subj_' + str(subj)) / fit_type
            fit_type_file = fit_type_dir / pp_file
            
            #with open(fit_type_file, 'rb') as f:
            rs[cond][subj][fit_type] = torch.load(fit_type_file)
            
            # Print some diagonstic information 
            print('Results for subject ' + str(subj) + ', train condition: ' + cond + 
                  ', fit type: ' + fit_type + ' ***')
            print('Best CP Ind: ' + str(rs[cond][subj][fit_type][mdl_type]['early_stopping']['best_cp_ind']))
            
      

Results for subject 8, train condition: omr_f_ns, fit type: single_cond ***
Best CP Ind: 1
Results for subject 8, train condition: omr_f_ns, fit type: multi_cond ***
Best CP Ind: 2
Results for subject 9, train condition: omr_f_ns, fit type: single_cond ***
Best CP Ind: 1
Results for subject 9, train condition: omr_f_ns, fit type: multi_cond ***
Best CP Ind: 1
Results for subject 11, train condition: omr_f_ns, fit type: single_cond ***
Best CP Ind: 1
Results for subject 11, train condition: omr_f_ns, fit type: multi_cond ***
Best CP Ind: 1
Results for subject 8, train condition: omr_r_ns, fit type: single_cond ***
Best CP Ind: 1
Results for subject 8, train condition: omr_r_ns, fit type: multi_cond ***
Best CP Ind: 2
Results for subject 9, train condition: omr_r_ns, fit type: single_cond ***
Best CP Ind: 1
Results for subject 9, train condition: omr_r_ns, fit type: multi_cond ***
Best CP Ind: 1
Results for subject 11, train condition: omr_r_ns, fit type: single_cond ***
Best CP Ind: 1
R

## Now assess test performance 

In [52]:
for ts_i, test_cond in enumerate(test_conds):
    print('**** Test Condition: ' + test_cond + ' ****')
    for s_i, subj in enumerate(test_subjs):
        print('Subject: ' + str(subj))
        print('Single ELBO, Multi ELBO, Delta, Training Condition')
        for tr_i, train_cond in enumerate(train_conds):
            single_period_elbo_vls = rs[train_cond][subj]['single_cond'][mdl_type]['period_elbo_vls'][subj]
            multi_period_elbo_vls = rs[train_cond][subj]['multi_cond'][mdl_type]['period_elbo_vls'][subj]
            
            if single_period_elbo_vls[test_cond] is not None:
                single_elbo = single_period_elbo_vls[test_cond]['elbo'].item()
                multi_elbo = multi_period_elbo_vls[test_cond]['elbo'].item()
                delta_elbo = multi_elbo - single_elbo
                print('{:.2E}'.format(single_elbo) + ', ' + '{:.2E}'.format(multi_elbo) + ', ' + 
                      '{:.2E}'.format(delta_elbo) + ', ' + train_cond)

**** Test Condition: omr_forward ****
Subject: 8
Single ELBO, Multi ELBO, Delta, Training Condition
-7.13E+07, -6.72E+07, 4.16E+06, omr_r_ns
-6.73E+07, -6.52E+07, 2.06E+06, omr_l_ns
Subject: 9
Single ELBO, Multi ELBO, Delta, Training Condition
-3.81E+07, -3.56E+07, 2.48E+06, omr_r_ns
-3.17E+07, -3.40E+07, -2.24E+06, omr_l_ns
Subject: 11
Single ELBO, Multi ELBO, Delta, Training Condition
-1.81E+07, -1.81E+07, 5.59E+04, omr_r_ns
-1.89E+07, -1.88E+07, 5.57E+04, omr_l_ns
**** Test Condition: omr_right ****
Subject: 8
Single ELBO, Multi ELBO, Delta, Training Condition
-6.65E+07, -6.74E+07, -8.37E+05, omr_f_ns
-8.78E+07, -7.99E+07, 7.91E+06, omr_l_ns
Subject: 9
Single ELBO, Multi ELBO, Delta, Training Condition
-1.94E+07, -1.87E+07, 6.52E+05, omr_f_ns
-2.16E+07, -2.04E+07, 1.21E+06, omr_l_ns
Subject: 11
Single ELBO, Multi ELBO, Delta, Training Condition
-2.17E+07, -2.12E+07, 5.19E+05, omr_f_ns
-2.40E+07, -2.22E+07, 1.80E+06, omr_l_ns
**** Test Condition: omr_left ****
Subject: 8
Single ELBO,

## Debug code goes here 

In [54]:
rs[train_conds[0]][test_subjs[0]]['multi_cond']['fit_ps']

{'note': 'Initial testing.',
 'param_filename': 'transfer_params.pkl',
 'param_save_dir': '/groups/bishop/bishoplab/projects/probabilistic_model_synthesis/results/testing/gnldr/across_cond_results/v1',
 'results_dir': '/groups/bishop/bishoplab/projects/probabilistic_model_synthesis/results/testing/gnldr/across_cond_results/v1/omr_f_ns/subj_8/multi_cond',
 'save_file': 'test_results.pt',
 'sp_cp_dir': '/groups/bishop/bishoplab/projects/probabilistic_model_synthesis/results/testing/gnldr/across_cond_results/v1/omr_f_ns/subj_8/multi_cond/sp_cp',
 'ip_cp_dir': '/groups/bishop/bishoplab/projects/probabilistic_model_synthesis/results/testing/gnldr/across_cond_results/v1/omr_f_ns/subj_8/multi_cond/ip_cp',
 'data_dir': '/groups/bishop/bishoplab/projects/ahrens_wbo/data',
 'segment_table_dir': '/groups/bishop/bishoplab/projects/probabilistic_model_synthesis/results/real_data',
 'segment_table_file': 'omr_l_r_f_ns_across_cond_segments_8_9_10_11.pkl',
 'fold_str_dir': '/groups/bishop/bishoplab/pr