A notebook where we synthesize regression models across conditions.  We examine the performance of the models when they are fit together and individually. 

Specifically, we synthesize regression models that predict a single output variable. Each model projects its input data down to a single variable, which is then passed through a complicated non-linear function. When we say we synthesize models "across conditions," we mean that we simulate data so that each model's data projects into a given range of the low-d space. For any one individual, we therefore only see the behavior of the non-linear function within that range.
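
As a rough illustration of what "projecting into a given range" means, here is a minimal numpy sketch. It assumes a simplified model y = m(x·w) + noise (the notebook's models also include input and output scales and biases), uses the notebook's nonlinearity m(z) = z + sin(3z), and draws inputs whose projection onto w lands in a chosen interval by sampling the projection directly and adding noise orthogonal to w. The helper `sample_in_interval` is hypothetical; it is not the `sample_proj_data_from_interval` function from `probabilistic_model_synthesis`.

```python
import numpy as np

def m(z):
    # Nonlinearity used in this notebook: m(z) = z + sin(3z)
    return z + np.sin(3 * z)

def sample_in_interval(w, interval, n_smps, rng):
    """Hypothetical sampler: draws x whose projection x @ w is uniform on interval.

    Each sample is x = (z / ||w||^2) w + e, where z ~ Uniform(interval) and e is
    standard normal noise with its component along w removed, so x @ w equals z
    (up to rounding)."""
    z = rng.uniform(interval[0], interval[1], size=n_smps)
    e = rng.standard_normal((n_smps, w.size))
    e -= np.outer(e @ w / (w @ w), w)  # project e onto the orthogonal complement of w
    return np.outer(z / (w @ w), w) + e

rng = np.random.default_rng(0)
w = rng.standard_normal(50)
x = sample_in_interval(w, interval=[-1.0, 0.5], n_smps=1000, rng=rng)
z = x @ w                                     # projections all lie in [-1.0, 0.5]
y = m(z) + 0.1 * rng.standard_normal(z.size)  # noisy outputs only probe m on this range
```

Any fit to (z, y) from this one sample constrains m only on [-1.0, 0.5]; an individual with a different interval constrains a different piece of m.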

Some details:

1) Each input variable in the model is associated with a 2-d position (position is the measurable property in this example).  Variable positions are sampled uniformly from the unit square. 

2) We draw the weights for each model from a prior whose mean and std., conditioned on position, are sums of bump functions (so they are truly smooth). When fitting models, we use piecewise-constant hypercube functions for the mean and std. (so there is model mismatch here). This can make correctly learning the std. functions tricky: if the true mean function varies a lot within a single hypercube, we will learn a standard deviation that is elevated in that region.

3) We simulate data with output scales very close to 1, input and output biases very close to 0, and input scales very close to 1/sqrt(n), where n is the lower end of `n_input_var_range`. This keeps things simpler. When fitting, we use priors which assume scales and biases are concentrated near these values.

4) We draw noise variances from a Gamma distribution when simulating data, and the fit models likewise use Gamma priors and posteriors over noise variances.
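
The model-mismatch point in item 2 can be checked numerically in one dimension. The sketch below builds a smooth mean function as a sum of Gaussian bumps (an assumption; the bump shape used by `generate_sum_of_bump_fcns_dist` may differ), draws weights around it with a constant true std., and then estimates the mean and std. separately in each of 10 equal-width bins, a 1-d stand-in for hypercube functions. The bump count, width, and bin count are illustrative choices, not the notebook's settings.

```python
import numpy as np

rng = np.random.default_rng(0)

# Smooth mean function: a sum of Gaussian bumps on [0, 1] (illustrative parameters)
ctrs = rng.uniform(0, 1, size=10)
mags = rng.normal(0, 1.0, size=10)
bump_w = 0.05

def true_mean(pos):
    return (mags * np.exp(-((pos[:, None] - ctrs) ** 2) / (2 * bump_w ** 2))).sum(axis=1)

true_std = 0.1
pos = rng.uniform(0, 1, size=20000)
w = true_mean(pos) + true_std * rng.standard_normal(pos.size)

# Piecewise-constant ("hypercube") fit: a single std. estimate per bin
n_bins = 10
bin_idx = np.minimum((pos * n_bins).astype(int), n_bins - 1)
est_std = np.array([w[bin_idx == b].std() for b in range(n_bins)])

# Every bin's estimate is at least roughly the true std., and bins in which the
# mean changes a lot come out well above it
print(np.round(est_std, 3))
```

Within a bin, the estimated variance is approximately the true variance plus the variance of the mean function across that bin, so a piecewise-constant std. can only absorb the mismatch by inflating.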

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import copy
import math

import matplotlib
from matplotlib import cm
import matplotlib.pyplot as plt
import numpy as np
import numpy.linalg
import torch

from janelia_core.math.basic_functions import list_grid_pts
from janelia_core.ml.extra_torch_modules import DenseLNLNet
from janelia_core.ml.extra_torch_modules import QuadSurf
from janelia_core.ml.fitting import match_torch_module
from janelia_core.ml.utils import list_torch_devices


from probabilistic_model_synthesis.gaussian_nonlinear_regression import compare_weight_prior_dists
from probabilistic_model_synthesis.gaussian_nonlinear_regression import align_low_d_spaces
from probabilistic_model_synthesis.gaussian_nonlinear_regression import fit_with_hypercube_priors
from probabilistic_model_synthesis.gaussian_nonlinear_regression import Fitter
from probabilistic_model_synthesis.gaussian_nonlinear_regression import generate_hypercube_prior_collection
from probabilistic_model_synthesis.gaussian_nonlinear_regression import GNLRMdl
from probabilistic_model_synthesis.simulation import generate_sum_of_bump_fcns_dist
from probabilistic_model_synthesis.simulation import efficient_cone_and_projected_interval_sample
from probabilistic_model_synthesis.simulation import sample_proj_data_from_interval
from probabilistic_model_synthesis.visualization import assign_colors_to_pts
from probabilistic_model_synthesis.visualization import plot_three_dim_pts

In [3]:
%matplotlib notebook
plt.style.use('dark_background')
matplotlib.rcParams.update({'font.size': 20})

## Parameters go here

In [72]:
# The number of example systems we generate data for - must be >= 4, because the first four systems' intervals
# tile the full range of projected values
n_ex_systems = 100

# Range of the number of input variables we observe from each example system - the actual number of variables we 
# observe from an example system will be pulled uniformly from this range (inclusive)
n_input_var_range = [10000, 11000] #[5000, 5100]

# Range of the number of samples we observe from each example system - the actual number we observe from each 
# example system will be drawn uniformly from this range (inclusive)
n_smps_range = [10000, 11000] #[12000, 15000]


# ===============================================================================================
# Parameters for the true priors 

# Options for the prior distribution on weights
true_w_prior_opts = {'n_bump_fcns': 50, 'd_in': 2, 'p': 1, 'mn_m_std': 1.0, 'std_m_std': .1, 'bump_w': .2}


# Options for the prior distributions on the scales, biases and psi
#true_s_in_prior_opts = {'mn_mn': 1.0/np.sqrt(n_input_var_range[0]), 'mn_std': .00000001, 
#                     'std_iv': .00001, 'std_lb': .0000001}

true_s_in_prior_opts = {'mn_mn': 1.0/np.sqrt(n_input_var_range[0]), 'mn_std': .00000001, 
                        'std_iv': .0000001, 'std_ub': .0001, 'std_lb': .00000001}

true_s_out_prior_opts = {'mn_mn': 1.0, 'mn_std': .00000001, 
                     'std_iv': .00001, 'std_ub': .0001, 'std_lb': .0000001}

true_b_prior_opts = {'mn_mn': 0.0, 'mn_std': .00000001, 'std_iv': .001}
true_psi_prior_opts = {'conc_iv': 10.0, 'rate_iv': 1000.0, 'rate_ub': 10000.0}

# ===============================================================================================
# Parameters for the fit models

# Hypercube parameters, shared by the mean and std. functions of the fit weight prior
fit_hc_params = {'n_divisions_per_dim': [100, 100], 
                 'dim_ranges': np.asarray([[-.1, 1.1],
                                       [-.1, 1.1]]),
                 'n_div_per_hc_side_per_dim': [2, 2]}

# The full options for setting up the prior on weights
fit_w_prior_opts = {'mn_hc_params': fit_hc_params, 'std_hc_params': fit_hc_params, 
                     'min_std': .000001, 'mn_init': 0.0, 'std_init': .3} #1

# Options for the prior distributions on the scales, biases and psi
fit_s_in_prior_opts = true_s_in_prior_opts
fit_s_out_prior_opts = true_s_out_prior_opts
fit_b_prior_opts = true_b_prior_opts
fit_psi_prior_opts = true_psi_prior_opts

# Options for posterior distributions
#s_in_post_opts = {'mn_mn': 1.0/np.sqrt(n_input_var_range[0]), 'mn_std': .000001, 
#               'std_iv': .00001, 'std_lb': .0000001}
s_in_post_opts = {'mn_mn': 1.0/np.sqrt(n_input_var_range[0]), 'mn_std': .00000001, 
                        'std_iv': .0000001, 'std_ub': .0001, 'std_lb': .00000001}
s_out_post_opts = {'mn_mn': 1.0, 'mn_std': .00000001, 
                     'std_iv': .00001, 'std_ub': .0001, 'std_lb': .0000001}

b_post_opts = {'mn_mn': 0.0, 'mn_std': .01}
psi_post_opts = {'conc_iv': 10.0, 'rate_iv': 1.0, 'rate_ub': 100000.0}

# Options for the densenet which makes up the shared-m module
#dense_net_opts = {'n_layers': 10, 'growth_rate': 5, 'bias': True} # These work well 
dense_net_opts = {'n_layers': 2, 'growth_rate': 40, 'bias': True}

# ======================================================================================================
# Parameters for fitting - should be entered as lists, each entry corresponding to one round of fitting

# Parameters when fitting combined models
comb_sp_fit_opts = [{'n_epochs': 500, 'milestones': None, 'update_int': 100, 'init_lr': .01, 'n_batches': 2}]
comb_ip_fit_opts = [{'n_epochs': 1000, 'milestones': [1000], 'update_int': 100, 'init_lr': .1, 'n_batches': 2},
                    {'n_epochs': 1000, 'milestones': [1000], 'update_int': 100, 'init_lr': .01, 'n_batches': 2},
                    {'n_epochs': 1000, 'milestones': [1000], 'update_int': 100, 'init_lr': .001, 'n_batches': 2}]
                    
                    

single_sp_fit_opts = comb_sp_fit_opts
single_ip_fit_opts = comb_ip_fit_opts


# ======================================================================================================
# Specify the number of intermediate variables and number of variables we predict
p = 1
d_pred = 1
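
A quick check of what the Gamma settings above imply, assuming `conc_iv` and `rate_iv` are the initial shape (concentration) and rate of a Gamma distribution over the noise variance psi. Under that assumption, the true prior centers noise variances near 10/1000 = 0.01, while the posterior initialization (concentration 10, rate 1) starts at a mean of 10. The helper `gamma_summary` is hypothetical. Note that numpy's sampler takes a scale (1/rate), whereas `torch.distributions.Gamma` takes a rate.

```python
import numpy as np

def gamma_summary(conc, rate):
    """Mean and std. of Gamma(shape=conc, rate=rate): mean = conc/rate, std = sqrt(conc)/rate."""
    return conc / rate, np.sqrt(conc) / rate

# Values from true_psi_prior_opts and psi_post_opts (assumed to be shape and rate)
prior_mean, prior_std = gamma_summary(10.0, 1000.0)
post_init_mean, _ = gamma_summary(10.0, 1.0)

# Monte Carlo check; numpy's gamma is parameterized by scale = 1 / rate
rng = np.random.default_rng(0)
psi_smps = rng.gamma(shape=10.0, scale=1.0 / 1000.0, size=100000)
print(prior_mean, prior_std, post_init_mean)
print(psi_smps.mean(), psi_smps.std())
```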

## Create true distributions that govern how systems under study are generated

In [5]:
# The true weight prior is not a hypercube prior (it is replaced below by a sum-of-bump-functions prior), so we
# provide placeholder hypercube parameters here just to create the prior collection

temp_hc_params = {'n_divisions_per_dim': [100, 100], 
                  'dim_ranges': np.asarray([[-.1, 1.1],
                                       [-.1, 1.1]]),
                   'n_div_per_hc_side_per_dim': [2, 2]}

temp_w_prior_opts = {'mn_hc_params': temp_hc_params, 'std_hc_params': temp_hc_params, 
                     'min_std': .000001, 'mn_init': 0.0, 'std_init': .3}

true_priors = generate_hypercube_prior_collection(p=p, d_pred=d_pred, 
                                                  w_prior_opts=temp_w_prior_opts, 
                                                  s_in_prior_opts=true_s_in_prior_opts, 
                                                  b_in_prior_opts=true_b_prior_opts, 
                                                  s_out_prior_opts=true_s_out_prior_opts,
                                                  b_out_prior_opts=true_b_prior_opts, 
                                                  psi_prior_opts=true_psi_prior_opts)

In [6]:
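true_priors.w_prior.dists[0].mn_f.to('cpu')
# Replace the placeholder hypercube weight prior with the true (smooth) sum-of-bump-functions prior
true_priors.w_prior = generate_sum_of_bump_fcns_dist(**true_w_prior_opts)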

## Define the true non-linear function relating projections of input variables to the mean of output variables

In [7]:
class Quad(torch.nn.Module):
    """True non-linearity m(x) = x + sin(3x). (Despite the class name, this is not a quadratic.)"""

    def forward(self, x):
        return x + torch.sin(3*x)

In [8]:
m_true = Quad()

## Generate data

### Generate properties

In [9]:
ind_n_vars = np.random.randint(n_input_var_range[0], n_input_var_range[1]+1, n_ex_systems)
ind_props = [torch.rand(size=[n_vars,2]) for n_vars in ind_n_vars]

### Generate true models for each individual

In [10]:
with torch.no_grad():
    ind_true_mdls = [GNLRMdl(m=m_true, 
                     w=true_priors.w_prior.form_standard_sample(true_priors.w_prior.sample(props)),
                     s_in=true_priors.s_in_prior.form_standard_sample(true_priors.s_in_prior.sample(props)).squeeze(axis=1),
                     b_in=true_priors.b_in_prior.form_standard_sample(true_priors.b_in_prior.sample(props)).squeeze(axis=1),
                     s_out=true_priors.s_out_prior.form_standard_sample(true_priors.s_out_prior.sample(props)).squeeze(axis=1),
                     b_out=true_priors.b_out_prior.form_standard_sample(true_priors.b_out_prior.sample(props)).squeeze(axis=1),
                     psi=true_priors.psi_prior.form_standard_sample(true_priors.psi_prior.sample(props)).squeeze(axis=1))
                     for props in ind_props]
    
    if d_pred > 1:
        for mdl in ind_true_mdls:
            mdl.s_out.data = mdl.s_out.data.squeeze()
            mdl.b_out.data = mdl.b_out.data.squeeze()
            mdl.psi.data = mdl.psi.data.squeeze()
                             

### Generate observations from true models

In [11]:
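# Angle ranges passed to the sampler: each range spans pi radians, and consecutive systems' ranges are
# rotated by pi/2 (system 0 starts at angle 0)
ang_ranges = np.ones(n_ex_systems)
ang_ranges[0] = 0
ang_ranges *= .5*np.pi
ang_ranges = np.cumsum(ang_ranges)
ang_ranges = [[a, a+np.pi] for a in ang_ranges]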
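
The sampler `efficient_cone_and_projected_interval_sample` is not shown in this notebook. Judging only from its arguments (`locs`, `ctr`, `ang_range`, `big_std`, `small_std`), one plausible reading is that variables whose position lies in an angular sector around `ctr` receive a large standard deviation and the rest a small one. The sketch below implements that guess with a hypothetical helper `in_cone`; it is an assumption about the geometry, not the library's implementation. With ranges spanning pi radians, each "cone" is a half-plane through the center of the unit square.

```python
import numpy as np

def in_cone(locs, ctr, ang_range):
    """Hypothetical helper: True for points whose angle about ctr lies in ang_range.

    Angles are compared modulo 2*pi, so ranges extending past 2*pi (e.g. [3*pi/2, 5*pi/2])
    wrap around correctly. Assumes the range spans at most 2*pi."""
    d = locs - ctr
    ang = np.arctan2(d[:, 1], d[:, 0]) % (2 * np.pi)  # angle of each point in [0, 2*pi)
    start = ang_range[0] % (2 * np.pi)
    width = ang_range[1] - ang_range[0]
    return (ang - start) % (2 * np.pi) <= width

# Reproduce the first four angle ranges built in the cell above
n_systems = 4
starts = np.cumsum(np.r_[0.0, np.full(n_systems - 1, 0.5 * np.pi)])
ang_ranges = [[a, a + np.pi] for a in starts]

rng = np.random.default_rng(0)
locs = rng.uniform(0, 1, size=(1000, 2))
ctr = np.array([0.5, 0.5])

# Assumed convention: variables in the cone get big_std, the others small_std
big_std, small_std = 1.0, 0.0
stds = [np.where(in_cone(locs, ctr, r), big_std, small_std) for r in ang_ranges]
```

If this reading is right, systems 0 and 2 (and likewise 1 and 3) use complementary halves of property space.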

In [12]:
ind_n_smps = np.random.randint(n_smps_range[0], n_smps_range[1]+1, n_ex_systems)
ind_data = [None]*n_ex_systems

min_proj_vl = -1.96*np.sqrt(n_input_var_range[0])
max_proj_vl = 1.96*np.sqrt(n_input_var_range[0])
interval_span = (max_proj_vl - min_proj_vl)/4

ind_intervals = [None]*n_ex_systems
for i in range(n_ex_systems):
    
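    # The first four systems tile [min_proj_vl, max_proj_vl] with contiguous intervals; later systems get a
    # randomly placed interval of the same width
    if i < 4: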
        start_interval = min_proj_vl + interval_span*i
    else:
        start_interval = np.random.uniform(low=min_proj_vl, high=max_proj_vl-interval_span)
        
    stop_interval = start_interval + interval_span
     
    cur_interval = [start_interval, stop_interval]
    ind_intervals[i] = cur_interval
    x_i_new = efficient_cone_and_projected_interval_sample(n_smps=ind_n_smps[i], 
                                             locs=ind_props[i],
                                             ctr=torch.tensor([.5, .5]), 
                                             ang_range=ang_ranges[i],
                                             w=ind_true_mdls[i].w.detach(),
                                             interval=cur_interval, 
                                             big_std=1.0, 
                                             small_std=0,
                                             device=torch.device('cuda'))
    x_i_new = x_i_new.cpu()
    
    with torch.no_grad():
        y_i_new = ind_true_mdls[i].sample(x=x_i_new)
    ind_data[i] = (x_i_new, y_i_new)
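
Assuming the models apply the nonlinearity to s_in·(x·w) + b_in (a reading of the `GNLRMdl` arguments, not confirmed in this notebook), the projection intervals above can be translated into the range of inputs that m actually sees. The sketch below does this for the first four systems, using the mean of the true s_in prior, 1/sqrt(10000) = 0.01, and ignoring b_in, which is concentrated near 0.

```python
import numpy as np

n_vars_min = 10000                   # lower end of n_input_var_range
s_in = 1.0 / np.sqrt(n_vars_min)     # mean of the true s_in prior ('mn_mn')

# Interval construction from the cell above (first four systems only)
lo, hi = -1.96 * np.sqrt(n_vars_min), 1.96 * np.sqrt(n_vars_min)
span = (hi - lo) / 4
proj_intervals = np.array([[lo + span * i, lo + span * (i + 1)] for i in range(4)])

# Assumption: m acts on s_in * (x @ w) + b_in, with b_in close to 0
m_intervals = s_in * proj_intervals

def m(z):
    return z + np.sin(3 * z)

# Evaluate m across each individual's input range to see which piece of it that individual observes
for a, b in m_intervals:
    z = np.linspace(a, b, 5)
    print(f"[{a:+.2f}, {b:+.2f}]", np.round(m(z), 2))
```

With these numbers the first four individuals see m on [-1.96, -0.98], [-0.98, 0], [0, 0.98] and [0.98, 1.96]. Each range is about 0.98 wide, a bit under half a period of sin(3z) (period 2π/3 ≈ 2.09), so no single individual observes a full oscillation of the nonlinearity.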

In [13]:
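# Per-variable sample std. of the inputs for example system i, min-max scaled to [0, 1] for coloring
i = 5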
stds = np.std(np.asarray(ind_data[i][0]), axis=0)

min_smp_std = np.min(stds)
max_smp_std = np.max(stds)
smp_std_span = max_smp_std - min_smp_std

smp_stds_scaled = (stds - min_smp_std)/smp_std_span

In [14]:
plot_std_clrs = cm.viridis(smp_stds_scaled)

#plt.figure(figsize=(7, 7))
#ax = plt.subplot(1,1,1)
#ax.scatter(ind_props[i][:,0], ind_props[i][:,1], marker='.', color=plot_std_clrs)

## Synthesize models (fit models to the data together)

In [73]:
comb_fit_rs = fit_with_hypercube_priors(data=ind_data[0:100], props=ind_props[0:100], p=p, 
                                   w_prior_opts=fit_w_prior_opts, 
                                   s_in_prior_opts=fit_s_in_prior_opts,
                                   b_in_prior_opts=fit_b_prior_opts,
                                   s_out_prior_opts=fit_s_out_prior_opts,
                                   b_out_prior_opts=fit_b_prior_opts,
                                   psi_prior_opts=fit_psi_prior_opts,
                                   s_in_post_opts=s_in_post_opts,
                                   b_in_post_opts=b_post_opts,
                                   s_out_post_opts=s_out_post_opts,
                                   b_out_post_opts=b_post_opts,
                                   psi_post_opts=psi_post_opts,
                                   dense_net_opts=dense_net_opts, 
                                   sp_fit_opts=comb_sp_fit_opts, 
                                   ip_fit_opts=comb_ip_fit_opts, 
                                   sp_fixed_var=True)
                        

Found 1 GPUs
Fixing variance of sp w_prior.
Beginning SP fitting.

Obj: 2.50e+11
----------------------------------------
NELL: 2.15e+04, 2.31e+04, 2.23e+04, 2.14e+04, 2.15e+04, 2.28e+04, 2.38e+04, 2.28e+04, 1.97e+04, 2.16e+04, 2.27e+04, 2.27e+04, 2.15e+04, 2.12e+04, 2.35e+04, 2.22e+04, 2.18e+04, 2.14e+04, 2.28e+04, 2.12e+04, 2.16e+04, 2.10e+04, 2.12e+04, 2.19e+04, 2.03e+04, 2.04e+04, 2.09e+04, 2.17e+04, 2.16e+04, 2.28e+04, 2.20e+04, 1.96e+04, 2.28e+04, 2.35e+04, 2.30e+04, 2.09e+04, 2.22e+04, 2.25e+04, 2.01e+04, 2.24e+04, 2.33e+04, 2.15e+04, 2.08e+04, 2.10e+04, 2.12e+04, 2.19e+04, 2.06e+04, 2.17e+04, 2.14e+04, 2.24e+04, 2.24e+04, 2.19e+04, 2.26e+04, 2.20e+04, 2.28e+04, 2.18e+04, 2.30e+04, 2.19e+04, 2.18e+04, 2.20e+04, 2.40e+04, 2.10e+04, 2.05e+04, 2.19e+04, 2.14e+04, 2.07e+04, 2.31e+04, 2.35e+04, 2.17e+04, 2.10e+04, 2.25e+04, 2.29e+04, 2.22e+04, 2.06e+04, 2.13e+04, 2.17e+04, 2.42e+04, 2.26e+04, 2.17e+04, 2.29e+04, 2.19e+04, 2.24e+04, 2.23e+04, 2.11e+04, 2.09e+04, 2.06e+04, 2.14e+04, 2.


Obj: 3.14e+05
----------------------------------------
NELL: -8.93e+02, 2.14e+03, 3.41e+03, 9.09e+02, 3.04e+02, -1.09e+03, 5.92e+03, -2.39e+02, 1.67e+03, -4.58e+02, 5.49e+03, -5.30e+02, 4.75e+01, 3.45e+03, -1.02e+03, 4.23e+03, -1.52e+03, 8.54e+04, 1.96e+03, 4.22e+03, 3.59e+03, 8.97e+03, 4.51e+02, 2.82e+03, 2.76e+03, 3.67e+03, 5.55e+03, 4.28e+03, 2.47e+03, -1.10e+03, 4.96e+03, 5.67e+03, -1.56e+03, -1.53e+03, 2.54e+03, -9.18e+02, 3.73e+03, 9.37e+03, 8.17e+02, 2.06e+04, -6.14e+02, 3.36e+03, 8.54e+02, 4.70e+03, -3.51e+03, 3.06e+03, -7.23e+01, 7.15e+03, -1.65e+03, 5.28e+03, -5.63e+02, 4.86e+03, 1.20e+03, -3.38e+02, 5.05e+03, -3.49e+02, -8.25e+02, -2.08e+03, 5.67e+03, 4.76e+03, -2.97e+02, 3.94e+03, 4.29e+03, 6.79e+02, 3.99e+03, 2.42e+03, -6.86e+02, 5.34e+03, 1.59e+03, -1.68e+03, -5.55e+02, 1.48e+03, 7.71e+02, -1.05e+02, 3.20e+03, 3.85e+02, 3.80e+03, 4.43e+03, 5.51e+03, 5.60e+03, 3.42e+03, 2.24e+03, 5.52e+02, 4.91e+03, -1.73e+03, 1.91e+03, 2.82e+03, 5.73e+03, 3.43e+03, -1.66e+03, 1.83e+03, 2


Obj: 3.93e+04
----------------------------------------
NELL: -1.04e+03, -6.14e+01, 3.03e+03, -1.18e+03, -1.23e+03, -2.07e+03, 3.37e+03, -2.53e+03, -3.51e+03, -4.11e+03, 4.91e+03, 4.57e+02, -1.85e+03, 1.12e+03, -3.19e+03, 5.39e+03, -1.96e+03, 1.76e+03, 1.60e+03, 4.43e+03, 2.57e+03, 2.02e+03, -8.67e+02, 5.99e+02, 2.39e+03, 8.55e+02, -2.15e+03, 3.97e+03, 1.92e+03, -1.53e+03, 4.51e+03, 3.92e+03, -3.13e+03, 1.85e+02, 1.45e+03, -6.17e+02, 1.71e+03, -1.85e+03, -1.70e+03, -8.95e+02, -5.62e+01, 3.08e+03, 1.04e+03, 2.84e+03, -3.28e+03, 2.24e+03, -1.29e+03, 4.89e+03, -4.71e+03, 2.72e+03, -2.14e+03, 5.43e+03, -4.46e+03, -2.17e+03, 4.14e+03, -1.18e+03, -2.82e+03, -1.69e+03, 3.89e+03, 3.49e+03, -1.34e+03, 3.69e+03, 3.19e+03, -2.46e+02, 1.04e+03, -3.32e+03, -2.23e+03, 2.80e+03, -3.56e+03, -1.93e+03, -1.58e+01, 8.23e+02, -1.30e+02, -3.24e+03, 2.15e+03, -1.34e+03, 6.74e+03, -2.31e+03, 2.36e+03, 5.57e+03, 1.09e+03, 1.88e+03, -1.95e+03, 4.38e+03, -4.35e+03, 1.82e+03, 2.16e+03, 5.53e+02, 2.75e+03, -2.48e


Obj: 1.62e+06
----------------------------------------
NELL: 7.76e+03, 4.02e+03, 1.14e+04, 1.00e+04, 1.00e+04, 7.94e+03, 1.20e+04, 3.79e+03, 6.64e+03, 3.45e+03, 1.41e+04, 3.34e+03, 2.66e+03, 1.02e+04, 6.84e+03, 1.23e+04, 4.42e+03, 8.86e+03, 1.33e+04, 1.74e+04, 1.25e+04, 1.31e+04, 1.45e+04, 1.08e+04, 1.66e+04, 1.21e+04, 1.32e+04, 9.82e+03, 1.80e+04, 9.70e+03, 1.89e+04, 1.72e+04, 9.93e+03, 1.32e+04, 1.91e+04, 5.58e+03, 1.44e+04, 8.47e+03, 1.00e+04, 3.54e+03, 1.18e+04, 9.45e+03, 1.60e+04, 1.81e+04, 5.31e+03, 1.52e+04, 1.22e+04, 1.37e+04, 3.77e+03, 1.28e+04, 1.12e+04, 1.13e+04, 1.03e+04, 4.77e+03, 1.82e+04, 4.42e+03, 1.32e+04, 8.89e+03, 1.10e+04, 1.18e+04, 5.65e+03, 1.72e+04, 1.48e+04, 5.29e+03, 1.33e+04, 7.34e+03, 8.52e+03, 1.23e+04, 1.46e+04, 8.24e+03, 1.62e+04, 1.22e+04, 5.97e+03, 5.41e+03, 9.67e+03, 4.50e+03, 1.44e+04, 3.58e+03, 7.88e+03, 1.61e+04, 1.16e+04, 9.71e+03, 4.84e+03, 2.15e+04, 5.91e+03, 1.51e+04, 8.53e+03, 1.48e+04, 1.18e+04, 3.26e+03, 1.30e+04, 7.20e+03, 5.14e+03, 9.57e+03


Obj: 1.20e+06
----------------------------------------
NELL: 9.14e+03, 5.11e+03, 7.46e+03, 1.03e+04, 3.86e+03, 7.03e+03, 9.36e+03, 3.57e+03, 4.29e+03, 2.53e+03, 1.26e+04, 4.80e+03, 5.13e+03, 7.49e+03, 5.15e+03, 1.07e+04, 4.86e+03, 5.58e+03, 1.36e+04, 1.39e+04, 9.51e+03, 8.67e+03, 1.38e+04, 8.95e+03, 1.24e+04, 1.01e+04, 1.43e+04, 1.06e+04, 1.37e+04, 4.79e+03, 1.53e+04, 1.42e+04, 8.36e+03, 1.48e+04, 1.75e+04, 4.79e+03, 1.13e+04, 7.87e+03, 8.59e+03, 4.76e+03, 1.08e+04, 9.98e+03, 1.66e+04, 1.61e+04, 6.38e+03, 1.13e+04, 1.42e+04, 1.12e+04, 3.26e+03, 8.30e+03, 8.95e+03, 9.02e+03, 7.91e+03, 4.83e+03, 9.79e+03, 4.54e+03, 1.32e+04, 6.27e+03, 8.40e+03, 9.00e+03, 6.87e+03, 1.14e+04, 1.20e+04, 6.34e+03, 1.29e+04, 8.27e+03, 6.77e+03, 9.80e+03, 1.18e+04, 6.89e+03, 1.50e+04, 1.21e+04, 6.73e+03, 4.66e+03, 7.73e+03, 6.83e+03, 1.14e+04, 2.66e+03, 8.48e+03, 1.18e+04, 9.29e+03, 8.16e+03, 5.53e+03, 9.30e+03, 6.09e+03, 1.39e+04, 7.02e+03, 1.43e+04, 9.57e+03, 3.31e+03, 1.12e+04, 8.58e+03, 6.79e+03, 7.65e+03


Obj: 9.55e+05
----------------------------------------
NELL: 7.98e+03, 4.19e+03, 9.01e+03, 1.07e+04, 4.08e+03, 8.63e+03, 7.21e+03, 3.58e+03, 4.32e+03, 5.67e+02, 9.52e+03, 5.99e+03, 3.05e+03, 7.40e+03, 5.12e+03, 8.00e+03, 3.97e+03, 6.41e+03, 9.00e+03, 7.92e+03, 7.67e+03, 6.28e+03, 1.27e+04, 8.71e+03, 8.70e+03, 7.38e+03, 1.26e+04, 9.73e+03, 7.73e+03, 3.73e+03, 8.74e+03, 7.72e+03, 6.44e+03, 1.24e+04, 1.51e+04, 6.02e+03, 9.93e+03, 7.45e+03, 7.37e+03, 5.15e+03, 8.25e+03, 8.81e+03, 1.66e+04, 1.31e+04, 7.50e+03, 7.53e+03, 1.44e+04, 8.24e+03, 4.18e+03, 3.20e+03, 9.78e+03, 8.03e+03, 8.62e+03, 1.56e+03, 2.49e+03, 4.73e+03, 1.33e+04, 3.42e+03, 8.30e+03, 7.81e+03, 2.71e+03, 5.61e+03, 7.87e+03, 6.93e+03, 8.55e+03, 6.64e+03, 4.46e+03, 1.04e+04, 1.08e+04, 8.80e+03, 1.60e+04, 1.15e+04, 6.09e+03, 1.71e+03, 6.53e+03, 5.99e+03, 8.17e+03, 1.87e+03, 7.11e+03, 7.07e+03, 6.68e+03, 8.40e+03, 4.49e+03, 4.19e+00, 7.49e+03, 8.07e+03, 7.54e+03, 1.18e+04, 6.96e+03, 3.96e+03, 9.60e+03, 8.08e+03, 5.31e+03, 8.33e+03


Obj: 1.35e+06
----------------------------------------
NELL: 7.46e+03, 4.77e+03, 8.74e+03, 1.02e+04, 3.27e+03, 9.01e+03, 8.43e+03, 3.04e+03, 3.88e+03, 1.24e+03, 6.13e+03, 7.40e+03, 3.60e+03, 4.80e+03, 2.91e+03, 4.29e+03, 4.85e+03, 7.18e+03, 9.34e+03, 4.26e+03, 6.35e+03, 1.19e+03, 9.33e+03, 6.75e+03, 2.81e+03, 3.55e+03, 1.09e+04, 7.82e+03, 2.89e+03, 2.98e+03, 3.55e+03, 3.53e+03, 6.18e+03, 1.33e+04, 1.35e+04, 4.98e+03, 7.52e+03, 7.26e+03, 8.02e+03, 4.19e+03, 7.03e+03, 6.77e+03, 1.38e+04, 1.27e+04, 5.91e+03, 2.95e+02, 1.25e+04, 4.39e+03, 5.41e+03, 1.08e+03, 8.63e+03, 4.01e+03, 8.03e+03, 3.36e+03, -1.73e+03, 5.71e+03, 1.13e+04, 5.08e+03, 5.05e+03, 5.90e+03, 3.85e+03, -6.85e+02, 2.06e+03, 5.34e+03, 7.67e+03, 5.82e+03, 4.34e+03, 7.83e+03, 1.06e+04, 6.72e+03, 1.37e+04, 9.62e+03, 5.71e+03, 3.76e+03, 6.91e+03, 6.92e+03, 6.63e+03, 3.00e+03, 5.81e+03, 1.12e+03, 3.98e+03, 6.48e+03, 4.31e+03, -3.33e+03, 3.74e+03, 6.52e+03, 6.93e+03, 9.66e+03, 4.87e+03, 1.92e+03, 1.01e+04, 7.66e+03, 7.21e+03, 7.81e


Obj: 8.66e+05
----------------------------------------
NELL: 6.14e+03, 4.23e+03, 8.78e+03, 9.14e+03, 3.90e+03, 6.20e+03, 5.35e+03, 4.70e+03, 4.58e+03, 1.46e+03, 5.08e+03, 6.84e+03, 4.49e+03, 4.61e+03, 3.35e+03, 4.61e+02, 4.48e+03, 7.66e+03, 5.16e+03, -4.95e+02, 2.88e+03, -9.21e+02, 9.78e+03, 4.40e+03, -3.80e+02, 2.45e+02, 1.18e+04, 6.80e+03, -5.87e+02, 4.55e+03, -9.09e+01, -5.06e+02, 7.37e+03, 1.41e+04, 8.88e+03, 5.40e+03, 5.01e+03, 7.67e+03, 8.26e+03, 7.00e+03, 6.39e+03, 5.47e+03, 1.50e+04, 9.25e+03, 5.41e+03, -2.59e+03, 1.14e+04, 1.90e+03, 2.34e+03, -2.33e+03, 8.40e+03, 2.43e+03, 7.64e+03, 3.38e+03, -4.03e+03, 5.60e+03, 1.05e+04, 5.29e+03, 3.46e+03, 4.69e+03, 6.43e+03, -4.18e+03, -6.42e+02, 5.77e+03, 4.39e+03, 6.14e+03, 3.48e+03, 5.31e+03, 1.09e+04, 5.95e+03, 1.35e+04, 8.16e+03, 4.60e+03, 4.20e+03, 5.34e+03, 4.36e+03, 3.11e+03, 1.95e+03, 5.81e+03, -1.24e+03, 2.84e+03, 6.66e+03, 5.36e+03, -4.83e+03, 6.18e+03, -7.80e+02, 7.90e+03, 5.92e+03, 2.63e+03, 4.74e+03, 6.74e+03, 6.06e+03, 6.74


Obj: -5.86e+04
----------------------------------------
NELL: 2.54e+03, -7.13e+03, -2.88e+03, 3.40e+03, -4.86e+03, -4.46e+03, -3.19e+03, -1.20e+03, -2.79e+02, -5.43e+03, -2.41e+03, 2.54e+03, 5.24e+03, -3.85e+03, -2.41e+03, -3.00e+03, -4.33e+03, -3.63e+03, -3.40e+03, -2.70e+03, -1.01e+03, -4.21e+03, -8.43e+02, -1.94e+03, -2.68e+03, -1.95e+03, -4.42e+02, -3.64e+03, -3.36e+03, -2.58e+03, -2.85e+03, -2.92e+03, -2.67e+03, -3.93e+03, -2.77e+03, -1.59e+03, -8.52e+02, 4.12e+03, 1.28e+03, -3.39e+03, -5.28e+03, -3.57e+03, 2.31e+03, -6.02e+03, 7.84e+02, -2.40e+03, -2.40e+03, -3.10e+03, -4.71e+03, -3.40e+03, -2.59e+03, -3.35e+03, -3.12e+03, 4.16e+03, -6.04e+03, -1.19e+03, -2.19e+03, -1.52e+03, -3.81e+03, -1.21e+03, 1.85e+03, -4.94e+03, -3.55e+03, -2.72e+03, -3.68e+03, 1.80e+03, 6.78e+02, -2.51e+03, 2.54e+02, -6.07e+03, 4.80e+03, -2.14e+03, -3.89e+03, -1.96e+03, -1.65e+03, 8.50e+01, -1.31e+03, -8.74e+02, -1.84e+03, -2.62e+03, -3.82e+03, -6.50e+03, 2.50e+03, -4.40e+03, -4.39e+03, -2.61e+03, -3.63e+


Obj: -3.08e+05
----------------------------------------
NELL: -3.47e+03, -1.10e+04, -5.57e+03, -4.01e+03, -7.33e+03, -6.25e+03, -6.92e+03, -5.38e+03, -3.30e+03, -1.02e+04, -5.50e+03, -2.95e+03, -2.75e+03, -7.63e+03, -5.77e+03, -6.88e+03, -6.38e+03, -6.21e+03, -8.04e+03, -6.61e+03, -4.90e+03, -7.65e+03, -5.11e+03, -4.60e+03, -3.74e+03, -5.28e+03, -4.07e+03, -8.64e+03, -6.69e+03, -6.21e+03, -6.91e+03, -6.13e+03, -5.11e+03, -5.66e+03, -8.29e+03, -7.48e+03, -3.69e+03, -3.90e+03, -4.33e+03, -6.70e+03, -6.75e+03, -4.58e+03, -3.23e+03, -1.05e+04, -5.74e+03, -7.14e+03, -4.85e+03, -7.47e+03, -9.31e+03, -5.57e+03, -4.37e+03, -7.73e+03, -5.00e+03, -2.61e+03, -9.47e+03, -5.59e+03, -2.91e+03, -6.32e+03, -8.18e+03, -7.18e+03, -3.58e+03, -6.58e+03, -6.70e+03, -7.10e+03, -7.50e+03, -5.24e+03, -3.65e+03, -7.38e+03, -4.74e+03, -7.72e+03, -4.92e+03, -4.55e+03, -7.41e+03, 4.71e+03, -6.40e+03, -5.56e+03, -7.41e+03, -7.25e+03, -6.29e+03, -6.09e+03, -7.09e+03, -7.96e+03, -4.44e+03, -7.15e+03, -9.64e+03, -7.


Obj: -4.16e+05
----------------------------------------
NELL: -6.60e+03, -1.15e+04, -6.45e+03, -7.03e+03, -8.32e+03, -7.42e+03, -7.71e+03, -6.85e+03, -5.81e+03, -1.05e+04, -6.15e+03, -6.60e+03, -6.03e+03, -8.36e+03, -9.26e+03, -7.03e+03, -6.57e+03, -7.97e+03, -9.51e+03, -7.75e+03, -5.75e+03, -9.65e+03, -6.67e+03, -6.32e+03, -5.10e+03, -7.10e+03, -6.80e+03, -8.87e+03, -7.75e+03, -8.17e+03, -8.44e+03, -5.96e+03, -6.69e+03, -6.72e+03, -8.73e+03, -8.19e+03, -4.50e+03, -6.18e+03, -6.16e+03, -6.94e+03, -9.95e+03, -5.19e+03, -5.48e+03, -1.14e+04, -7.71e+03, -7.29e+03, -5.67e+03, -8.03e+03, -9.47e+03, -3.86e+03, -6.67e+03, -7.88e+03, -8.47e+03, -6.07e+03, -9.00e+03, -5.68e+03, -9.22e+03, -8.54e+03, -9.64e+03, -8.93e+03, -5.00e+03, -7.18e+03, -8.31e+03, -7.23e+03, -8.26e+03, -7.11e+03, -7.77e+03, -8.49e+03, -7.90e+03, -8.76e+03, -5.59e+03, -5.10e+03, -8.08e+03, -8.59e+03, -6.94e+03, -5.75e+03, -7.15e+03, -8.08e+03, -6.67e+03, -7.39e+03, -7.86e+03, -9.07e+03, -5.88e+03, -7.10e+03, -1.07e+04, -7


Obj: -4.28e+05
----------------------------------------
NELL: -7.18e+03, -1.23e+04, -6.60e+03, -8.36e+03, -9.21e+03, -7.06e+03, -8.16e+03, -7.01e+03, -6.33e+03, -1.09e+04, -6.24e+03, -8.26e+03, -6.54e+03, -8.49e+03, -9.27e+03, -7.54e+03, -6.62e+03, -8.72e+03, -9.76e+03, -8.26e+03, -5.70e+03, -1.04e+04, -6.83e+03, -6.01e+03, -5.13e+03, -7.27e+03, -7.58e+03, -9.52e+03, -7.85e+03, -8.44e+03, -8.62e+03, -7.32e+03, -6.52e+03, -6.81e+03, -9.04e+03, -7.82e+03, -5.78e+03, -6.40e+03, -7.22e+03, -7.09e+03, -9.95e+03, -4.90e+03, -5.41e+03, -1.20e+04, -7.97e+03, -8.08e+03, -6.30e+03, -8.49e+03, -9.65e+03, -6.10e+03, -6.81e+03, -8.74e+03, -7.26e+03, -6.08e+03, -1.12e+04, -6.31e+03, -9.20e+03, -8.62e+03, -9.78e+03, -9.22e+03, -6.32e+03, -6.87e+03, -7.90e+03, -7.98e+03, -8.56e+03, -7.21e+03, -9.36e+03, -8.83e+03, -7.85e+03, -8.70e+03, -5.75e+03, -5.73e+03, -8.45e+03, -8.71e+03, -6.92e+03, -5.58e+03, -7.71e+03, -8.23e+03, -6.90e+03, -6.71e+03, -7.99e+03, -7.33e+03, -6.66e+03, -8.55e+03, -1.03e+04, -7


Obj: -4.43e+05
----------------------------------------
NELL: -7.28e+03, -1.21e+04, -6.66e+03, -7.99e+03, -8.94e+03, -8.01e+03, -7.84e+03, -7.22e+03, -6.51e+03, -1.11e+04, -6.23e+03, -9.14e+03, -7.24e+03, -8.73e+03, -9.17e+03, -7.65e+03, -7.00e+03, -7.83e+03, -9.83e+03, -8.17e+03, -5.61e+03, -1.02e+04, -6.97e+03, -6.18e+03, -5.52e+03, -7.44e+03, -7.49e+03, -9.55e+03, -8.14e+03, -8.40e+03, -8.84e+03, -7.54e+03, -6.54e+03, -7.29e+03, -9.41e+03, -8.94e+03, -5.69e+03, -7.03e+03, -7.15e+03, -7.79e+03, -1.00e+04, -5.13e+03, -5.81e+03, -1.21e+04, -8.71e+03, -8.00e+03, -6.21e+03, -8.09e+03, -1.01e+04, -6.93e+03, -6.95e+03, -8.61e+03, -8.92e+03, -7.02e+03, -1.10e+04, -6.56e+03, -8.28e+03, -8.28e+03, -9.77e+03, -9.52e+03, -7.64e+03, -7.49e+03, -8.77e+03, -8.13e+03, -8.41e+03, -8.54e+03, -8.95e+03, -8.69e+03, -7.03e+03, -8.92e+03, -6.27e+03, -5.66e+03, -8.23e+03, -8.63e+03, -7.24e+03, -6.58e+03, -8.20e+03, -8.70e+03, -7.05e+03, -7.58e+03, -7.77e+03, -9.19e+03, -6.27e+03, -8.92e+03, -1.08e+04, -8


Obj: -4.65e+05
----------------------------------------
NELL: -8.04e+03, -1.26e+04, -6.88e+03, -9.03e+03, -9.30e+03, -7.96e+03, -8.23e+03, -7.45e+03, -6.78e+03, -1.12e+04, -6.27e+03, -9.37e+03, -6.83e+03, -8.59e+03, -9.73e+03, -7.51e+03, -7.03e+03, -8.84e+03, -1.00e+04, -8.54e+03, -5.69e+03, -1.02e+04, -6.96e+03, -6.39e+03, -5.51e+03, -7.46e+03, -8.26e+03, -9.61e+03, -8.42e+03, -8.59e+03, -8.44e+03, -6.44e+03, -7.14e+03, -7.11e+03, -9.38e+03, -8.80e+03, -5.79e+03, -7.50e+03, -7.47e+03, -7.73e+03, -1.03e+04, -5.23e+03, -5.91e+03, -1.25e+04, -8.05e+03, -8.53e+03, -6.39e+03, -8.58e+03, -1.03e+04, -6.85e+03, -7.06e+03, -8.99e+03, -9.58e+03, -7.80e+03, -1.06e+04, -6.67e+03, -9.31e+03, -8.89e+03, -9.92e+03, -9.66e+03, -7.49e+03, -7.59e+03, -8.71e+03, -8.07e+03, -8.61e+03, -8.98e+03, -9.59e+03, -8.93e+03, -8.16e+03, -9.09e+03, -6.51e+03, -5.94e+03, -8.44e+03, -8.88e+03, -7.27e+03, -6.51e+03, -8.28e+03, -8.80e+03, -7.20e+03, -7.49e+03, -7.77e+03, -9.16e+03, -6.01e+03, -9.19e+03, -1.09e+04, -8


Obj: -4.70e+05
----------------------------------------
NELL: -7.86e+03, -1.25e+04, -6.79e+03, -9.10e+03, -8.00e+03, -7.75e+03, -8.36e+03, -7.10e+03, -6.87e+03, -1.11e+04, -6.43e+03, -9.14e+03, -6.82e+03, -8.75e+03, -9.86e+03, -7.80e+03, -7.07e+03, -8.65e+03, -1.01e+04, -8.97e+03, -5.52e+03, -1.04e+04, -7.65e+03, -6.46e+03, -5.65e+03, -7.62e+03, -8.13e+03, -9.84e+03, -8.43e+03, -8.57e+03, -9.01e+03, -7.72e+03, -6.98e+03, -7.26e+03, -9.32e+03, -8.95e+03, -5.66e+03, -7.56e+03, -7.76e+03, -7.65e+03, -1.02e+04, -5.23e+03, -5.95e+03, -1.24e+04, -9.07e+03, -8.22e+03, -6.39e+03, -8.50e+03, -1.03e+04, -7.12e+03, -6.81e+03, -8.98e+03, -9.19e+03, -7.73e+03, -1.16e+04, -6.26e+03, -9.50e+03, -8.97e+03, -9.85e+03, -9.81e+03, -7.59e+03, -7.75e+03, -9.09e+03, -8.38e+03, -8.89e+03, -9.06e+03, -9.54e+03, -9.01e+03, -8.25e+03, -9.20e+03, -6.38e+03, -5.67e+03, -8.46e+03, -9.00e+03, -7.02e+03, -6.47e+03, -8.36e+03, -8.79e+03, -7.09e+03, -7.90e+03, -8.20e+03, -9.23e+03, -6.74e+03, -8.88e+03, -1.11e+04, -8


Obj: -4.68e+05
----------------------------------------
NELL: -7.90e+03, -1.23e+04, -6.69e+03, -9.25e+03, -9.35e+03, -7.93e+03, -8.44e+03, -7.51e+03, -7.04e+03, -1.15e+04, -6.44e+03, -8.79e+03, -7.63e+03, -8.69e+03, -9.73e+03, -7.69e+03, -6.86e+03, -8.81e+03, -1.01e+04, -8.48e+03, -5.81e+03, -1.06e+04, -7.65e+03, -6.52e+03, -5.55e+03, -7.48e+03, -7.89e+03, -9.90e+03, -8.40e+03, -8.66e+03, -9.07e+03, -7.59e+03, -7.19e+03, -7.54e+03, -9.50e+03, -8.86e+03, -5.59e+03, -7.58e+03, -7.58e+03, -8.06e+03, -1.03e+04, -5.40e+03, -5.91e+03, -1.24e+04, -8.55e+03, -8.35e+03, -6.82e+03, -8.63e+03, -1.01e+04, -7.14e+03, -6.99e+03, -9.05e+03, -9.50e+03, -6.69e+03, -1.16e+04, -6.53e+03, -9.56e+03, -8.89e+03, -1.01e+04, -9.66e+03, -7.99e+03, -7.55e+03, -8.74e+03, -8.42e+03, -8.80e+03, -9.00e+03, -9.74e+03, -9.07e+03, -8.46e+03, -9.07e+03, -6.58e+03, -5.95e+03, -8.42e+03, -9.41e+03, -7.03e+03, -6.28e+03, -8.51e+03, -8.85e+03, -7.21e+03, -7.59e+03, -8.16e+03, -9.39e+03, -6.59e+03, -8.95e+03, -1.08e+04, -8


Obj: -4.66e+05
----------------------------------------
NELL: -8.09e+03, -1.27e+04, -6.98e+03, -9.12e+03, -9.04e+03, -7.83e+03, -8.36e+03, -7.58e+03, -6.81e+03, -1.14e+04, -6.61e+03, -8.29e+03, -7.72e+03, -8.87e+03, -9.89e+03, -7.55e+03, -6.98e+03, -8.75e+03, -1.03e+04, -8.91e+03, -5.94e+03, -1.05e+04, -7.57e+03, -6.60e+03, -3.73e+03, -7.54e+03, -8.23e+03, -9.73e+03, -8.04e+03, -8.79e+03, -9.27e+03, -7.58e+03, -7.07e+03, -7.48e+03, -9.52e+03, -9.11e+03, -5.86e+03, -7.63e+03, -7.90e+03, -7.95e+03, -1.04e+04, -5.45e+03, -5.49e+03, -1.25e+04, -8.48e+03, -8.55e+03, -5.98e+03, -8.60e+03, -1.02e+04, -6.83e+03, -7.08e+03, -8.78e+03, -9.72e+03, -7.88e+03, -1.17e+04, -6.79e+03, -9.66e+03, -8.97e+03, -1.02e+04, -9.83e+03, -5.73e+03, -7.51e+03, -9.13e+03, -8.15e+03, -8.92e+03, -9.10e+03, -9.70e+03, -8.99e+03, -8.45e+03, -9.16e+03, -6.32e+03, -4.56e+03, -8.63e+03, -9.50e+03, -7.17e+03, -6.67e+03, -8.37e+03, -8.74e+03, -7.44e+03, -7.63e+03, -8.16e+03, -9.42e+03, -6.70e+03, -9.00e+03, -1.11e+04, -8


Obj: -4.66e+05
----------------------------------------
NELL: -8.26e+03, -1.26e+04, -6.79e+03, -9.31e+03, -9.49e+03, -7.88e+03, -8.31e+03, -7.50e+03, -7.03e+03, -1.13e+04, -6.47e+03, -9.15e+03, -7.30e+03, -8.89e+03, -9.60e+03, -7.36e+03, -6.93e+03, -8.76e+03, -9.90e+03, -8.92e+03, -5.53e+03, -1.05e+04, -7.40e+03, -6.19e+03, -5.75e+03, -7.61e+03, -8.43e+03, -9.87e+03, -8.11e+03, -8.73e+03, -9.15e+03, -7.75e+03, -7.18e+03, -7.50e+03, -9.45e+03, -9.19e+03, -5.95e+03, -7.70e+03, -7.52e+03, -8.17e+03, -1.03e+04, -5.32e+03, -6.52e+03, -1.24e+04, -9.51e+03, -8.51e+03, -6.80e+03, -8.79e+03, -9.97e+03, -7.15e+03, -7.00e+03, -9.06e+03, -9.38e+03, -8.22e+03, -1.17e+04, -6.75e+03, -9.56e+03, -8.83e+03, -1.01e+04, -9.93e+03, -7.13e+03, -7.91e+03, -9.29e+03, -8.40e+03, -8.95e+03, -9.19e+03, -9.78e+03, -9.19e+03, -8.36e+03, -9.20e+03, -6.28e+03, -5.85e+03, -8.37e+03, -9.53e+03, -7.38e+03, -6.57e+03, -8.53e+03, -8.96e+03, -7.44e+03, -7.54e+03, -8.11e+03, -9.46e+03, -6.63e+03, -9.27e+03, -1.11e+04, -8
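
As a toy illustration of why fitting individuals together can help when each one covers only part of the nonlinearity, the sketch below fits a shared function to data from four non-overlapping input ranges, and compares it with a fit to one range alone. It makes strong simplifications that are not part of `fit_with_hypercube_priors`: the projections are taken as known, a degree-9 least-squares polynomial (`numpy.polynomial.Polynomial.fit`) stands in for the DenseNet shared-m module, and there are no priors, scales, or biases.

```python
import numpy as np

rng = np.random.default_rng(0)

def m(z):
    return z + np.sin(3 * z)

# Four individuals, each observing m on one quarter of [-1.96, 1.96] with noise std. 0.1
edges = np.linspace(-1.96, 1.96, 5)
data = []
for a, b in zip(edges[:-1], edges[1:]):
    z = rng.uniform(a, b, 500)
    data.append((z, m(z) + 0.1 * rng.standard_normal(z.size)))

# Stand-in for the shared nonlinearity: a least-squares polynomial
deg = 9
z_test = np.linspace(-1.96, 1.96, 200)

single = np.polynomial.Polynomial.fit(*data[0], deg)
pooled = np.polynomial.Polynomial.fit(np.concatenate([d[0] for d in data]),
                                      np.concatenate([d[1] for d in data]), deg)

# Root-mean-square error against the true m over the full range
err_single = np.sqrt(np.mean((single(z_test) - m(z_test)) ** 2))
err_pooled = np.sqrt(np.mean((pooled(z_test) - m(z_test)) ** 2))
print(f"RMSE over full range: single={err_single:.3g}, pooled={err_pooled:.3g}")
```

The single-individual fit matches m only on its own range and extrapolates badly elsewhere, while the pooled fit covers the full range.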

## Fit a model to one example individual

In [74]:
#single_fit_ind = 0

In [75]:
#single_fit_rs = fit_with_hypercube_priors(data=[ind_data[single_fit_ind]], 
#                                          props=[ind_props[single_fit_ind]], p=p, 
#                                          w_prior_opts=fit_w_prior_opts, 
#                                          s_in_prior_opts=fit_s_in_prior_opts,
#                                          b_in_prior_opts=fit_b_prior_opts,
#                                          s_out_prior_opts=fit_s_out_prior_opts,
#                                          b_out_prior_opts=fit_b_prior_opts,
#                                          psi_prior_opts=fit_psi_prior_opts,
#                                          s_in_post_opts=s_in_post_opts,
#                                          b_in_post_opts=b_post_opts,
#                                          s_out_post_opts=s_out_post_opts,
#                                          b_out_post_opts=b_post_opts,
#                                          psi_post_opts=psi_post_opts,
#                                          dense_net_opts=dense_net_opts, 
#                                          sp_fit_opts=single_sp_fit_opts, 
#                                          ip_fit_opts=single_ip_fit_opts)

## Plot some example data

In [76]:
n_plot_vars = 1000 # Number of variables whose locations we plot
ex_ind = 0

In [77]:
# Helper formatting function: square axes spanning the unit square of property space
def format_box(ax):
    ax.set_aspect('equal', 'box')
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_xlabel('property dim 1 (a.u.)')
    ax.set_ylabel('property dim 2 (a.u.)')

### Show the location of some example variables in property space

In [78]:
plt.figure(figsize=(7, 7))
ax = plt.subplot(1,1,1)
ax.plot(ind_props[ex_ind][0:n_plot_vars,0], ind_props[ex_ind][0:n_plot_vars,1], 'r.')
format_box(ax)

### Plot the mean and standard deviation of the true weight distribution

In [79]:
overlay_pts = True # True if we want to overlay variable locations on plots

In [80]:
pts, dim_pts = list_grid_pts(grid_limits=np.asarray([[0, 1.0], [0, 1.0]]), n_pts_per_dim=[100,100])
pts = torch.tensor(pts, dtype=torch.float)

true_w_mn = true_priors.w_prior(pts).detach().cpu().numpy()
true_w_std = np.concatenate([d.std_f(pts).detach().cpu().numpy() for d in true_priors.w_prior.dists], axis=1)
true_w_mn = true_w_mn.reshape([100,100]).transpose()
true_w_std = true_w_std.reshape([100,100]).transpose()

plt.figure(figsize=(7, 7))
ax = plt.subplot(1,1,1)
im = ax.imshow(true_w_mn, origin='lower', extent=[0, 1.0, 0, 1.0])
mn_vmin, mn_vmax = im.get_clim() # Keep track of color limits
plt.colorbar(im)
format_box(ax)
plt.title('Mean')
if overlay_pts:
    ax.plot(ind_props[ex_ind][0:n_plot_vars,0], ind_props[ex_ind][0:n_plot_vars,1], 'r.')

plt.figure(figsize=(7, 7))
ax = plt.subplot(1,1,1)
im = ax.imshow(true_w_std, origin='lower', extent=[0, 1.0, 0, 1.0])
std_vmin, std_vmax = im.get_clim() # Keep track of color limits
plt.colorbar(im)
format_box(ax)
plt.title('Standard Deviation')
if overlay_pts:
    ax.plot(ind_props[ex_ind][0:n_plot_vars,0], ind_props[ex_ind][0:n_plot_vars,1], 'r.')
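
The `reshape([100,100]).transpose()` step above depends on the order in which `list_grid_pts` lists grid points. The sketch below assumes (this is not confirmed here) that the first property dimension varies slowest, and uses `np.meshgrid(..., indexing='ij')` as a hypothetical stand-in for `list_grid_pts`. Under that assumption the transpose puts property dim 2 on the rows, which is what `imshow(..., origin='lower')` draws along the vertical axis.

```python
import numpy as np

n = 4  # small grid so the layout is easy to inspect
g = np.linspace(0.0, 1.0, n)

# Hypothetical stand-in for list_grid_pts: dim 1 varies slowest (indexing='ij')
d1, d2 = np.meshgrid(g, g, indexing='ij')
pts = np.stack([d1.ravel(), d2.ravel()], axis=1)  # shape (n*n, 2)

# A function of position that depends only on dim 1, evaluated at every point
vals = pts[:, 0]

# Same reshape/transpose as in the notebook
img = vals.reshape([n, n]).transpose()

# After the transpose, rows index dim 2 and columns index dim 1, so a function
# of dim 1 alone is constant down each column (imshow's x axis is dim 1)
print(img)
```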


### Plot the sampled weights for the example individual

In [81]:
plot_w_vls = ind_true_mdls[ex_ind].w.detach().numpy()[0:n_plot_vars]
# Rescale weights into the color range of the true-mean plot so colors are comparable
plot_w_vls_scaled = (plot_w_vls - mn_vmin)/(mn_vmax - mn_vmin)
plot_w_clrs = cm.viridis(plot_w_vls_scaled)

plt.figure(figsize=(7, 7))
ax = plt.subplot(1,1,1)
ax.scatter(ind_props[ex_ind][0:n_plot_vars,0], ind_props[ex_ind][0:n_plot_vars,1], marker='.', color=plot_w_clrs)
format_box(ax)
plt.title('Sampled Weights')
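
The manual rescaling above maps each weight into the color range of the true-mean image. An equivalent, slightly more idiomatic sketch uses `matplotlib.colors.Normalize`; the weight values and color limits below are made-up placeholders for `plot_w_vls`, `mn_vmin` and `mn_vmax`.

```python
import numpy as np
from matplotlib import cm
from matplotlib.colors import Normalize

# Placeholder values standing in for plot_w_vls, mn_vmin and mn_vmax
w_vls = np.array([-1.0, 0.0, 0.5, 2.0])
vmin, vmax = -1.0, 1.0

# Manual rescaling, as in the cell above
manual_clrs = cm.viridis((w_vls - vmin) / (vmax - vmin))

# Same mapping via Normalize; values above vmax get the colormap's "over" color
norm = Normalize(vmin=vmin, vmax=vmax)
norm_clrs = cm.viridis(norm(w_vls))

print(norm_clrs.shape)  # one RGBA row per weight
```

Using a `Normalize` object also makes it easy to pass the same `norm` to `ax.scatter(..., c=w_vls, norm=norm, cmap='viridis')` and attach a matching colorbar.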

## Plot the true low-d function and the true data from the example individual

In [82]:
n_plot_smps = 100

In [83]:
f_x_vls = torch.linspace(-2, 2, 100)
true_f_vls = ind_true_mdls[ex_ind].m(f_x_vls).detach().numpy()
f_x_vls = f_x_vls.detach().numpy()

ex_proj_x_vls = ind_true_mdls[ex_ind].project(x=ind_data[ex_ind][0]).detach().numpy()[0:n_plot_smps]
ex_y_vls = ind_data[ex_ind][1].detach().numpy()[0:n_plot_smps]

In [84]:
plt.figure(figsize=(9, 6.5))
plt.plot(ex_proj_x_vls, ex_y_vls, 'ro')
plt.plot(f_x_vls, true_f_vls, color='w')
plt.xlabel('Projected x (a.u.)')
plt.ylabel('y (a.u.)')

### Plot the true low-d function and the predictions from the example individual

In [85]:
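fit_type = 'comb'  # 'comb' (individuals fit together) or 'single' (one individual fit alone)
exam_type = 'ip'  # 'sp' or 'ip': which stage of the fit results to examine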

In [86]:
if fit_type == 'single':
    fit_rs = single_fit_rs
else:
    fit_rs = comb_fit_rs

# Pull out the results for the requested fitting stage
exam_rs = fit_rs['sp'] if exam_type == 'sp' else fit_rs['ip']
exam_priors = exam_rs['priors']
exam_vi_collections = exam_rs['vi_collections']
exam_posts = [coll.posteriors for coll in exam_vi_collections]
exam_mdls = [coll.mdl for coll in exam_vi_collections]
exam_logs = exam_rs['logs']

In [87]:
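# Parameter values from the example individual's posteriors, evaluated at its
# variable properties (for the noise variances psi we take the posterior mode)
exam_w = exam_posts[ex_ind].w_post(ind_props[ex_ind])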
exam_s_in = exam_posts[ex_ind].s_in_post(ind_props[ex_ind]).squeeze(axis=1)
exam_b_in = exam_posts[ex_ind].b_in_post(ind_props[ex_ind]).squeeze(axis=1)
exam_s_out = exam_posts[ex_ind].s_out_post(ind_props[ex_ind]).squeeze(axis=1)
exam_b_out = exam_posts[ex_ind].b_out_post(ind_props[ex_ind]).squeeze(axis=1)
exam_psi = exam_posts[ex_ind].psi_post.dists[0].mode(ind_props[ex_ind]).squeeze(axis=1)

In [88]:
n_exam_smps = 1000

In [89]:
x_exam = torch.tensor(sample_proj_data_from_interval(n_smps=n_exam_smps, 
                                                      w=ind_true_mdls[ex_ind].w.detach().numpy(), 
                                                      interval=[min_proj_vl, max_proj_vl]),
                       dtype=torch.float)
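
`sample_proj_data_from_interval` is defined earlier in the notebook and is not shown in this section. As a hedged sketch of the idea only, one way to draw inputs whose projection onto `w` falls in a given interval is to sample a Gaussian point and then shift it along `w` so that its projection equals a target drawn uniformly from the interval. The function name `sample_in_interval_sketch` and the Gaussian and uniform choices are assumptions, not the notebook's implementation.

```python
import numpy as np

def sample_in_interval_sketch(n_smps, w, interval, rng=None):
    """Hypothetical sampler: rows x satisfy interval[0] <= x @ w < interval[1]."""
    rng = np.random.default_rng() if rng is None else rng
    w = np.asarray(w, dtype=float).ravel()
    x = rng.standard_normal((n_smps, w.size))                    # random starting points
    target = rng.uniform(interval[0], interval[1], size=n_smps)  # desired projections
    # Move each point along w so its projection equals the target:
    # (x + c w) @ w = x @ w + c ||w||^2  =>  c = (target - x @ w) / ||w||^2
    c = (target - x @ w) / (w @ w)
    return x + c[:, None] * w[None, :]

rng = np.random.default_rng(0)
w = rng.standard_normal(50)
x = sample_in_interval_sketch(1000, w, interval=[-0.5, 1.0], rng=rng)
proj = x @ w
print(proj.min(), proj.max())
```

Shifting only along `w` leaves the components of `x` orthogonal to `w` Gaussian, so the inputs still vary in directions the model ignores.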

In [90]:
with torch.no_grad():
    
    # Determine which of our test points are in and out of the training distribution for this individual 
    x_true_proj = ind_true_mdls[ex_ind].project(x_exam).numpy()
    x_true_proj_for_int = ind_true_mdls[ex_ind].project(x_exam, apply_scales_and_biases=False).numpy().squeeze()
    
    x_within_train_dist = np.logical_and(x_true_proj_for_int >= ind_intervals[ex_ind][0], 
                                         x_true_proj_for_int < ind_intervals[ex_ind][1])
    x_outside_train_dist = np.logical_not(x_within_train_dist)
    
    
    # Get true mean and predicted mean for each data point for this individual 
    true_mns = ind_true_mdls[ex_ind].m(ind_true_mdls[ex_ind].project(x_exam)).numpy()
    pred_mns = exam_mdls[ex_ind].cond_mean(x=x_exam, w=exam_w, s_in=exam_s_in, b_in=exam_b_in, 
                                             s_out=exam_s_out, b_out=exam_b_out)
    x_pred_proj = exam_mdls[ex_ind].project(x=x_exam, w=exam_w, s_in=exam_s_in, b_in=exam_b_in)
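
The cell above labels each test point by whether its unscaled projection lands in the half-open training interval of the individual, then evaluates true and predicted conditional means. A minimal numpy sketch of the same computation, assuming (not confirmed by this section) that the mean has the form `s_out * m(s_in * (x @ w) + b_in) + b_out`, with `np.tanh` standing in for the learned non-linearity `m` and scales and biases set to 1 and 0 as in the simulation:

```python
import numpy as np

def cond_mean_sketch(x, w, s_in, b_in, s_out, b_out, m=np.tanh):
    """Hypothetical single-output mean: s_out * m(s_in * (x @ w) + b_in) + b_out."""
    z_raw = x @ w                 # projection without scale/bias
    z = s_in * z_raw + b_in       # projection with input scale and bias
    return s_out * m(z) + b_out, z_raw

rng = np.random.default_rng(1)
x = rng.standard_normal((200, 10))
w = rng.standard_normal(10) / np.sqrt(10)
y_mn, z_raw = cond_mean_sketch(x, w, s_in=1.0, b_in=0.0, s_out=1.0, b_out=0.0)

# Half-open interval test, mirroring x_within_train_dist above
train_interval = [-0.5, 0.5]
within = (z_raw >= train_interval[0]) & (z_raw < train_interval[1])
outside = ~within
```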

In [91]:
align_vls = align_low_d_spaces(w_0=ind_true_mdls[ex_ind].w.detach().numpy(),
                               s_in_0=ind_true_mdls[ex_ind].s_in.detach().numpy(),
                               b_in_0=ind_true_mdls[ex_ind].b_in.detach().numpy(),
                               w_1=exam_w.detach().numpy(),
                               s_in_1=exam_s_in.detach().numpy(),
                               b_in_1=exam_b_in.detach().numpy(),
                               z_1=x_pred_proj.numpy())

x_pred_proj_aligned = align_vls[-1]
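
The projections are aligned before comparison because the data cannot pin down the low-d coordinate beyond an affine map. As a hedged argument, assuming the mean has the form $s_{out}\, m(s_{in} w^\top x + b_{in}) + b_{out}$ with a flexibly learned $m$ (an assumption about the model class; `align_low_d_spaces` itself is not shown here), for any $a \neq 0$ and $c$:

$$
\begin{aligned}
z &= s_{in}\, w^\top x + b_{in}, \qquad z' = a z + c = (a\, s_{in})\, w^\top x + (a\, b_{in} + c), \\
m'(z') &= m\!\left(\frac{z' - c}{a}\right) = m(z) \quad\Longrightarrow\quad s_{out}\, m'(z') + b_{out} = s_{out}\, m(z) + b_{out}.
\end{aligned}
$$

Both parameter settings give identical predictions, so a fitted projection $z_1$ can at best match the true projection $z_0$ up to some map $z_0 \approx a z_1 + c$.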

In [92]:
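# Colors for true values (not used in the plot below): red outside the training interval, black inside
true_clrs = np.zeros([n_exam_smps,4])
true_clrs[:,-1] = 1.0
true_clrs[x_outside_train_dist, 0] = 1.0

# Colors for predictions: gray outside the training interval, dark red inside
pred_clrs = np.zeros([n_exam_smps,4])
pred_clrs[:, -1] = 1.0
pred_clrs[x_outside_train_dist, 0:3] = .5
#pred_clrs[x_outside_train_dist, -1] = 0 
pred_clrs[x_within_train_dist, 0] = .7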


plt.figure(figsize=(9, 6.5))
a = plt.subplot(1,1,1)
a.scatter(x_true_proj, pred_mns, color=pred_clrs)
plt.ylim([-2.5, 2])

plt.plot(f_x_vls, true_f_vls, color='w')
plt.xlabel('Projected x (a.u.)')
plt.ylabel('y (a.u.)')


### Plot the learned priors

In [93]:
pred_w_mn = exam_priors.w_prior(pts).detach().cpu().numpy()
pred_w_std = np.concatenate([d.std_f(pts).detach().cpu().numpy() for d in exam_priors.w_prior.dists], axis=1)

# Re-evaluate the true mean on the grid (the earlier copy was reshaped for plotting)
true_w_mn = true_priors.w_prior(pts).detach().cpu().numpy()

# Align the learned mean to the true mean: fit true_w_mn ~ a * pred_w_mn + c
# by least squares and apply that affine map
pred_mn_supp = np.concatenate([pred_w_mn, np.ones([pred_w_mn.shape[0], 1])], axis=1)
t = numpy.linalg.lstsq(pred_mn_supp, true_w_mn, rcond=None)[0]
pred_mn_al = np.matmul(pred_mn_supp, t)
pred_mn_al = pred_mn_al.reshape([100,100]).transpose()

# Under the same map the standard deviation scales by |a|; the offset c has no effect
t_std = t[0]
pred_w_std_al = np.abs(t_std)*pred_w_std
pred_w_std_al = pred_w_std_al.reshape([100,100]).transpose()
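
The cell above maps the learned prior mean onto the true one with an affine least-squares fit, then rescales the learned standard deviation by the absolute slope, since the standard deviation of $a w + c$ is $|a|$ times that of $w$. A self-contained numpy sketch with synthetic stand-ins for `pred_w_mn`, `true_w_mn` and `pred_w_std`:

```python
import numpy as np

rng = np.random.default_rng(2)
n_pts = 500

# Synthetic stand-ins: learned mean/std on a grid, true mean = a * learned + c
pred_mn = rng.standard_normal((n_pts, 1))
pred_std = rng.uniform(0.1, 0.5, size=(n_pts, 1))
a_true, c_true = -2.0, 0.3
true_mn = a_true * pred_mn + c_true

# Least-squares fit of true_mn ~ a * pred_mn + c (column of ones for the offset)
design = np.concatenate([pred_mn, np.ones((n_pts, 1))], axis=1)
coef, *_ = np.linalg.lstsq(design, true_mn, rcond=None)  # coef has shape (2, 1)
a_hat, c_hat = coef[0, 0], coef[1, 0]

pred_mn_al = design @ coef                # aligned mean
pred_std_al = np.abs(a_hat) * pred_std    # std scales by |a|; c does not affect it
```

Taking the absolute value matters when the fit flips sign: a negative slope reverses the mean but cannot produce a negative standard deviation.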

In [94]:
pred_w_mn = pred_w_mn.reshape([100,100]).transpose()

In [95]:
plt.figure(figsize=(7, 7))
ax = plt.subplot(1,1,1)
im = ax.imshow(pred_mn_al, origin='lower', extent=[0, 1.0, 0, 1.0], vmin=mn_vmin, vmax=mn_vmax)
plt.colorbar(im)
format_box(ax)
plt.title('Estimated Mean')

plt.figure(figsize=(7, 7))
ax = plt.subplot(1,1,1)
im = ax.imshow(pred_w_std_al, origin='lower', extent=[0, 1.0, 0, 1.0], vmin=std_vmin, vmax=std_vmax)
plt.colorbar(im)
format_box(ax)
plt.title('Estimated Standard Deviation')


## Compare true and predicted weights

In [96]:
w_true = ind_true_mdls[ex_ind].w.detach().numpy()
w_est = exam_w.detach().numpy()
plt.figure()
plt.plot(w_true, w_est, 'r.')
plt.xlabel('True weight')
plt.ylabel('Estimated weight')
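
If, as in the alignment discussion, the learned weights can differ from the true ones by an overall scale or sign, a scale-free summary of this scatter plot is the Pearson correlation, alongside a fitted slope. A short sketch with synthetic weights standing in for `w_true` and `w_est`:

```python
import numpy as np

rng = np.random.default_rng(3)
w_true = rng.standard_normal(300)
# Synthetic estimate: flipped and rescaled copy of the true weights plus noise
w_est = -1.5 * w_true + 0.1 * rng.standard_normal(300)

r = np.corrcoef(w_true.ravel(), w_est.ravel())[0, 1]        # sign reveals a flip
slope = np.polyfit(w_true.ravel(), w_est.ravel(), deg=1)[0]  # overall scale
print(f"r = {r:.3f}, slope = {slope:.3f}")
```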

In [69]:
exam_w

tensor([[-8.1492],
        [ 0.2353],
        [-4.4932],
        ...,
        [ 1.6472],
        [ 6.2511],
        [-1.1024]], grad_fn=<CatBackward0>)

In [70]:
exam_priors.w_prior.dists[0].std_f(torch.tensor([[.5, .5]]))

tensor([[0.1225]], grad_fn=<AddBackward0>)

## Look at fitting logs

In [98]:
for log in exam_logs:
    Fitter.plot_log(log)

## Look at individual model parameters

In [97]:
exam_mdl = GNLRMdl(m=exam_vi_collections[ex_ind].mdl.m, w=exam_w, s_in=exam_s_in, b_in=exam_b_in,
                   s_out=exam_s_out, b_out=exam_b_out, psi=exam_psi)

plt.figure()
GNLRMdl.compare_mdls(ind_true_mdls[ex_ind], exam_mdl)

[[0.21239606]]


In [43]:
list(exam_posts[ex_ind].w_post.dists[0].std_f.parameters())

[Parameter containing:
 tensor([-1.9713, -1.8874, -1.2698,  ..., -2.5187, -1.3224, -2.3785],
        requires_grad=True)]

In [44]:
exam_posts[ex_ind].w_post.dists[0].std_f(ind_props[ex_ind])

tensor([[0.1903],
        [0.2243],
        [0.7312],
        ...,
        [0.0645],
        [0.6631],
        [0.0852]], grad_fn=<TBackward0>)

In [45]:
exam_mdl.m[1].eval()

Linear(in_features=81, out_features=1, bias=True)