A notebook to generate plots from the post processed results of an across condition analysis

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from pathlib import Path
import pickle
import re

import numpy as np
import torch

from janelia_core.stats.regression import r_squared

## Parameters go here 

In [3]:
# Top-level directory holding the results
base_dir = r'/groups/bishop/bishoplab/projects/probabilistic_model_synthesis/results/real_data/gnldr/across_cond_transfer_analysis/v0'

# Name used for all files containing post-processed results
pp_file = r'pp_test_results.pt'

# Specify the type of models we assess performance for - can be 'sp' or 'ip'
mdl_type = 'ip'

# Specify conditions we test on 
test_conds = ['omr_forward', 'omr_right', 'omr_left']



## Load all results

In [4]:
train_conds = [cond.name for cond in Path(base_dir).iterdir() if cond.is_dir()]
test_subjs = np.sort([int(re.search('.*_(\d+)', subj.name)[1]) 
                      for subj in (Path(base_dir) / train_conds[0]).iterdir()])

In [5]:
FIT_TYPES = ['single_cond', 'multi_cond']

In [6]:
rs = dict()
for cond in train_conds:
    rs[cond] = dict()
    for subj in test_subjs:
        rs[cond][subj] = dict()
        for fit_type in FIT_TYPES:
            fit_type_dir = Path(base_dir) / cond / ('subj_' + str(subj)) / fit_type
            fit_type_file = fit_type_dir / pp_file
            
            #with open(fit_type_file, 'rb') as f:
            rs[cond][subj][fit_type] = torch.load(fit_type_file)
            
            # Print some diagonstic information 
            print('Results for subject ' + str(subj) + ', train condition: ' + cond + 
                  ', fit type: ' + fit_type + ' ***')
            print('Best CP Ind: ' + str(rs[cond][subj][fit_type][mdl_type]['early_stopping']['best_cp_ind']))
            
      

Results for subject 8, train condition: omr_f_ns, fit type: single_cond ***
Best CP Ind: 20
Results for subject 8, train condition: omr_f_ns, fit type: multi_cond ***
Best CP Ind: 20
Results for subject 9, train condition: omr_f_ns, fit type: single_cond ***
Best CP Ind: 2
Results for subject 9, train condition: omr_f_ns, fit type: multi_cond ***
Best CP Ind: 14
Results for subject 11, train condition: omr_f_ns, fit type: single_cond ***
Best CP Ind: 6
Results for subject 11, train condition: omr_f_ns, fit type: multi_cond ***
Best CP Ind: 7
Results for subject 8, train condition: omr_r_ns, fit type: single_cond ***
Best CP Ind: 3
Results for subject 8, train condition: omr_r_ns, fit type: multi_cond ***
Best CP Ind: 20
Results for subject 9, train condition: omr_r_ns, fit type: single_cond ***
Best CP Ind: 2
Results for subject 9, train condition: omr_r_ns, fit type: multi_cond ***
Best CP Ind: 7
Results for subject 11, train condition: omr_r_ns, fit type: single_cond ***
Best CP Ind:

## Now assess test performance 

In [7]:
for ts_i, test_cond in enumerate(test_conds):
    print('**** Test Condition: ' + test_cond + ' ****')
    for s_i, subj in enumerate(test_subjs):
        print('Subject: ' + str(subj))
        print('Single ELBO, Multi ELBO, Delta, Training Condition')
        for tr_i, train_cond in enumerate(train_conds):
            single_period_elbo_vls = rs[train_cond][subj]['single_cond'][mdl_type]['period_elbo_vls'][subj]
            multi_period_elbo_vls = rs[train_cond][subj]['multi_cond'][mdl_type]['period_elbo_vls'][subj]
            
            if single_period_elbo_vls[test_cond] is not None:
                single_elbo = single_period_elbo_vls[test_cond]['elbo'].item()
                multi_elbo = multi_period_elbo_vls[test_cond]['elbo'].item()
                delta_elbo = multi_elbo - single_elbo
                print('{:.2E}'.format(single_elbo) + ', ' + '{:.2E}'.format(multi_elbo) + ', ' + 
                      '{:.2E}'.format(delta_elbo) + ', ' + train_cond)

**** Test Condition: omr_forward ****
Subject: 8
Single ELBO, Multi ELBO, Delta, Training Condition
-7.11E+07, -6.72E+07, 3.91E+06, omr_r_ns
-6.64E+07, -6.45E+07, 1.95E+06, omr_l_ns
Subject: 9
Single ELBO, Multi ELBO, Delta, Training Condition
-3.63E+07, -3.45E+07, 1.82E+06, omr_r_ns
-3.12E+07, -3.32E+07, -1.99E+06, omr_l_ns
Subject: 11
Single ELBO, Multi ELBO, Delta, Training Condition
-1.78E+07, -1.79E+07, -3.79E+04, omr_r_ns
-1.86E+07, -1.83E+07, 3.35E+05, omr_l_ns
**** Test Condition: omr_right ****
Subject: 8
Single ELBO, Multi ELBO, Delta, Training Condition
-6.63E+07, -6.54E+07, 9.24E+05, omr_f_ns
-8.53E+07, -7.77E+07, 7.63E+06, omr_l_ns
Subject: 9
Single ELBO, Multi ELBO, Delta, Training Condition
-1.89E+07, -1.87E+07, 2.03E+05, omr_f_ns
-2.11E+07, -1.99E+07, 1.24E+06, omr_l_ns
Subject: 11
Single ELBO, Multi ELBO, Delta, Training Condition
-2.11E+07, -2.06E+07, 5.78E+05, omr_f_ns
-2.34E+07, -2.14E+07, 2.01E+06, omr_l_ns
**** Test Condition: omr_left ****
Subject: 8
Single ELBO,

## Debug code goes here 

In [None]:
rs[train_conds[0]][test_subjs[0]]['multi_cond']['fit_ps']