A notebook for the development of initial GNLDR models

In [1]:
# Load IPython's autoreload extension and select mode 2 so that edits to imported project modules are picked up
# without restarting the kernel
%load_ext autoreload
%autoreload 2

In [2]:
import copy

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.nn import Identity

from janelia_core.math.basic_functions import optimal_orthonormal_transform
from janelia_core.ml.extra_torch_modules import ConstantRealFcn
from janelia_core.ml.extra_torch_modules import DenseLNLNet
from janelia_core.ml.extra_torch_modules import PWLNNFcn
from janelia_core.ml.extra_torch_modules import QuadSurf
from janelia_core.ml.utils import list_torch_devices
from janelia_core.ml.utils import torch_mod_to_fcn
from janelia_core.visualization.image_generation import generate_2d_fcn_image
from janelia_core.visualization.matrix_visualization import cmp_n_mats

from probabilistic_model_synthesis.gaussian_nonlinear_dim_reduction import align_intermediate_spaces
from probabilistic_model_synthesis.gaussian_nonlinear_dim_reduction import Fitter
from probabilistic_model_synthesis.gaussian_nonlinear_dim_reduction import GNLDRMdl
from probabilistic_model_synthesis.gaussian_nonlinear_dim_reduction import generate_basic_posteriors
from probabilistic_model_synthesis.gaussian_nonlinear_dim_reduction import generate_simple_prior_collection
from probabilistic_model_synthesis.gaussian_nonlinear_dim_reduction import VICollection
from probabilistic_model_synthesis.math import MeanFcnTransformer
from probabilistic_model_synthesis.math import StdFcnTransformer
from probabilistic_model_synthesis.visualization import assign_colors_to_pts
from probabilistic_model_synthesis.visualization import plot_three_dim_pts
from probabilistic_model_synthesis.visualization import plot_torch_dist

In [3]:
# Select the interactive "notebook" matplotlib backend for all figures created below
%matplotlib notebook

## Parameters go here

In [4]:
# ---- Size of the simulated data set ----

# How many individuals we simulate observing data from
n_individuals = 5

# Inclusive [min, max] range for the number of variables observed from an individual; each individual's actual
# count is drawn uniformly from this range
n_var_range = [100, 120]

# Inclusive [min, max] range for the number of samples observed from an individual; each individual's actual
# count is drawn uniformly from this range
n_smps_range = [10000, 15000]

# ---- Model structure ----

# Number of latent variables in the model
n_latent_vars = 2

# ---- Fitting hardware ----

# Set to True to fit on GPUs when they are available
use_gpus = True

# ---- True scales ----

# Mean and standard deviation used when generating the true scales
s_mn = 1.0
s_std = 1e-4

# ---- True shared m-module ----

# Settings for generating the true shared m-module
m_n_layers = 2
m_growth_rate = 5
n_intermediate_latent_vars = 3

In [5]:
## Determine which devices we use for fitting

In [6]:
# Fall back to a single CPU device unless GPU use was requested in the parameters cell
if not use_gpus:
    devices = [torch.device('cpu')]
else:
    # Only the first value returned by list_torch_devices is used as the device list
    devices, _ = list_torch_devices()

No GPUs found.


## Create the true prior distributions that relate parameters in the model to variable (e.g., neuron) properties

In [7]:
# The true m-module shared by all individuals, which maps latents into the intermediate latent space
# NOTE(review): the meaning of the two tensor arguments is defined by QuadSurf - confirm against its documentation
m_true = QuadSurf(torch.tensor([.5, .5]), torch.tensor([.2, -.2]))

In [8]:
# Render the true m-module as an image over the square [-20, 20] x [-20, 20] of its two inputs
# NOTE(review): vis_dim=2 presumably selects which output dimension is drawn - confirm against generate_2d_fcn_image
plt.figure()
im, _, _ = generate_2d_fcn_image(torch_mod_to_fcn(m_true), dim_0_range=[-20, 20], dim_1_range=[-20, 20], vis_dim=2)
plt.imshow(im)
# The trailing semicolon stops the notebook from echoing the Colorbar object's repr beneath the figure
plt.colorbar();

<IPython.core.display.Javascript object>

<matplotlib.colorbar.Colorbar at 0x7f9c789d5950>

In [9]:
# Priors over model parameters, conditioned on two properties per variable; the true models are sampled from these
true_priors = generate_simple_prior_collection(
    n_prop_vars=2,
    n_intermediate_latent_vars=n_intermediate_latent_vars,
    # Loading matrix prior
    lm_mn_w_init_std=1.0,
    lm_std_w_init_std=.001,
    # Mean vector prior
    mn_mn_w_init_std=2.0,
    mn_std_w_init_std=1.0,
    # Private variance (psi) prior
    psi_conc_f_w_init_std=.001,
    psi_rate_f_w_init_std=.001,
    psi_conc_bias_mn=1.0,
    psi_rate_bias_mn=10.0,
    # Scale prior
    s_mn=s_mn,
    s_std=s_std)

## Generate properties

In [10]:
# Draw how many variables and how many samples are observed for each individual; randint's upper bound is
# exclusive, so 1 is added to make the configured ranges inclusive
ind_n_vars = np.random.randint(low=n_var_range[0], high=n_var_range[1] + 1, size=n_individuals)
ind_n_smps = np.random.randint(low=n_smps_range[0], high=n_smps_range[1] + 1, size=n_individuals)

# Two properties per variable, drawn with torch.rand (uniform on [0, 1))
ind_props = [torch.rand(size=[n_ind_vars, 2]) for n_ind_vars in ind_n_vars]

## Generate true models

In [11]:
# Build one true model per individual: every model shares the same true m-module, while the loading matrix,
# mean, private variances and scales are sampled from the true priors given that individual's properties
with torch.no_grad():
    ind_true_mdls = [
        GNLDRMdl(
            n_latent_vars=n_latent_vars,
            m=m_true,
            lm=true_priors.lm_prior.sample(cur_props),
            mn=true_priors.mn_prior.sample(cur_props).squeeze(),
            psi=true_priors.psi_prior.sample(cur_props).squeeze(),
            s=true_priors.s_prior.sample(cur_props).squeeze())
        for cur_props in ind_props
    ]

## Generate data from each model

In [12]:
# Sample from each individual's true model, using that individual's own sample count
with torch.no_grad():
    ind_data = [true_mdl_i.sample(n_smps_i) for true_mdl_i, n_smps_i in zip(ind_true_mdls, ind_n_smps)]

## Set up everything for fitting sp models

In [13]:
# The m-module we fit: a dense ReLU network followed by a linear readout into the intermediate latent space
sp_dense_net = DenseLNLNet(nl_class=torch.nn.ReLU,
                           d_in=n_latent_vars,
                           n_layers=m_n_layers,
                           growth_rate=m_growth_rate,
                           bias=True)

# The readout's input width is n_latent_vars + m_n_layers*m_growth_rate
sp_readout = torch.nn.Linear(in_features=n_latent_vars + m_n_layers*m_growth_rate,
                             out_features=n_intermediate_latent_vars,
                             bias=True)

sp_m_fit = torch.nn.Sequential(sp_dense_net, sp_readout)

# Alternative m-module (disabled): quadratic surface
#sp_m_fit = QuadSurf(torch.tensor([.1, .1]), torch.tensor([.1, .1]))

# Alternative m-module (disabled): piecewise-linear nearest-neighbor function
#n_fcns = 25
#sp_m_fit = PWLNNFcn(init_centers=torch.randn([n_fcns,n_latent_vars]), 
#                    init_weights=.01*torch.randn([n_fcns,n_latent_vars, n_intermediate_latent_vars]),
#                    init_offsets=torch.zeros([n_fcns,n_intermediate_latent_vars]), 
#                    k=10, n_used_fcns=12)

In [14]:
# Priors that will be fit across individuals
sp_priors = generate_simple_prior_collection(n_prop_vars=2, n_intermediate_latent_vars=n_intermediate_latent_vars,
                                             lm_mn_w_init_std=1.0, lm_std_w_init_std=1.0,
                                             s_mn=s_mn, s_std=s_std)

# One set of posteriors per individual.  The initial mean for the scale posteriors uses s_mn (rather than a
# hard-coded 1.0) so it stays in step with the scale prior above and with the ip posteriors created later.
sp_posteriors = generate_basic_posteriors(n_obs_vars=ind_n_vars, n_smps=ind_n_smps, n_latent_vars=n_latent_vars,
                                          n_intermediate_latent_vars=n_intermediate_latent_vars,
                                          s_opts={'mn_mn': s_mn, 'mn_std': .00000001, 'std_iv': .0001})

# One model per individual; every model holds the same sp_m_fit object, so the m-module is shared.  The
# remaining parameters are left as None here.
sp_fit_mdls = [GNLDRMdl(n_latent_vars=n_latent_vars, m=sp_m_fit, lm=None, mn=None, psi=None, s=None)
               for _ in range(n_individuals)]

# Bundle each individual's observed data (element 1 of its sample), properties, model and posteriors
sp_vi_collections = [VICollection(data=ind_data[s_i][1],
                                  props=ind_props[s_i],
                                  mdl=sp_fit_mdls[s_i],
                                  posteriors=sp_posteriors[s_i]) for s_i in range(n_individuals)]

# Use the shared priors themselves as the posteriors over loading matrices and means, so that every individual
# shares these two distributions
for vi_coll in sp_vi_collections:
    vi_coll.posteriors.lm_post = sp_priors.lm_prior
    vi_coll.posteriors.mn_post = sp_priors.mn_prior

## Fit the sp models

In [15]:
# Fitter that fits all sp VI collections and the sp priors on the selected devices
sp_fitter = Fitter(vi_collections=sp_vi_collections, priors=sp_priors, devices=devices)

In [None]:
# Distribute the fitter (including data) across the devices, then run one round of fitting and keep its log
sp_fitter.distribute(distribute_data=True, devices=devices)
sp_logs = [
    sp_fitter.fit(200, milestones=[100], update_int=100, init_lr=.1)
    for _ in range(1)
]


Obj: 1.80e+08
----------------------------------------
NELL: 8.68e+06, 6.79e+06, 7.51e+06, 8.07e+06, 6.49e+06
Latent KL: 3.14e+02, 2.60e+02, 3.09e+02, 3.18e+02, 2.43e+02
LM KL: 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00
Mn KL: 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00
Psi KL: 4.01e+01, 4.00e+01, 3.59e+01, 3.79e+01, 4.30e+01
S KL: 2.83e+07, 2.90e+07, 2.68e+07, 2.80e+07, 3.00e+07
----------------------------------------
LR: 0.1
Elapsed time (secs): 0.2905123233795166
----------------------------------------
CPU cur memory used (GB): 3.27e-01

Obj: 9.47e+06
----------------------------------------
NELL: 2.17e+06, 1.62e+06, 1.82e+06, 2.01e+06, 1.62e+06
Latent KL: 3.21e+04, 5.04e+04, 5.12e+04, 5.67e+04, 4.22e+04
LM KL: 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00
Mn KL: 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00
Psi KL: 5.94e+01, 7.00e+01, 6.40e+01, 6.65e+01, 6.52e+01
S KL: 2.10e+02, 2.47e+02, 1.71e+02, 9.93e+01, 1.02e+02
----------------------------------------
LR: 

## Examine logs of sp fitting performance

In [None]:
# Plot the log returned by each round of sp fitting
for log in sp_logs:
    sp_fitter.plot_log(log)

## Look at sp model fits

In [None]:
# Index of the individual whose sp fit is examined below
exam_mdl = 0

In [None]:
# Posteriors and properties of the individual under examination
sp_exam_posts = sp_vi_collections[exam_mdl].posteriors
sp_exam_props = ind_props[exam_mdl]

# Evaluate each posterior at the individual's properties (the mode is used for psi)
fit_lm = sp_exam_posts.lm_post(sp_exam_props).detach().squeeze()
fit_mn = sp_exam_posts.mn_post(sp_exam_props).detach().squeeze()
fit_psi = sp_exam_posts.psi_post.mode(sp_exam_props).detach().squeeze()
fit_s = sp_exam_posts.s_post(sp_exam_props).detach().squeeze()

# Assemble a model from the fit parameters and compare it with the true model for the same individual
cmp_mdl = GNLDRMdl(n_latent_vars=n_latent_vars, m=sp_m_fit,
                   lm=fit_lm, mn=fit_mn, psi=fit_psi, s=fit_s)
true_mdl = ind_true_mdls[exam_mdl]

plt.figure()
true_mdl.compare_models(true_mdl, cmp_mdl)

## Set up everything for fitting models with individual posteriors

In [None]:
# Start the ip m-module from an independent copy of the fit sp m-module, so ip fitting does not modify sp_m_fit
ip_m_fit = copy.deepcopy(sp_m_fit)

In [None]:
# The ip priors start as an independent copy of the sp priors
ip_priors = copy.deepcopy(sp_priors)

# Fresh per-individual posteriors; these are initialized from the sp fit in the next cell
ip_posteriors = generate_basic_posteriors(
    n_obs_vars=ind_n_vars,
    n_smps=ind_n_smps,
    n_latent_vars=n_latent_vars,
    n_intermediate_latent_vars=n_intermediate_latent_vars,
    s_opts={'mn_mn': s_mn, 'mn_std': .00000001, 'std_iv': .01})

In [None]:
# Initialize each individual's ip posteriors from the results of the sp fit
for s_i, posteriors in enumerate(ip_posteriors):

    # Initialize the posteriors for the mean vectors from the sp prior evaluated at this individual's properties
    with torch.no_grad():
        mn_prior_mn = sp_priors.mn_prior(ind_props[s_i]).squeeze()
        mn_prior_std = sp_priors.mn_prior.std_f(ind_props[s_i]).squeeze()

        posteriors.mn_post.dists[0].mn_f.f.vl.data = copy.deepcopy(mn_prior_mn)
        posteriors.mn_post.dists[0].std_f.f.set_value(copy.deepcopy(mn_prior_std.numpy()))

    # Initialize the posteriors for the loading matrices
    with torch.no_grad():
        lm_prior_mn = sp_priors.lm_prior(ind_props[s_i]).squeeze()
        lm_prior_std = sp_priors.lm_prior.std_f(ind_props[s_i]).squeeze().numpy()

        # Visit every entry of lm_post.dists, copying in the matching column of the prior.  Looping over
        # range(n_latent_vars) would stop short when the loading matrix has n_intermediate_latent_vars columns,
        # leaving the remaining columns at their default initialization.
        # NOTE(review): assumes lm_post.dists holds one distribution per loading-matrix column - confirm
        for d_i, lm_col_post in enumerate(posteriors.lm_post.dists):
            lm_col_post.mn_f.f.vl.data = copy.deepcopy(lm_prior_mn[:, d_i])
            lm_col_post.std_f.f.set_value(copy.deepcopy(lm_prior_std[:, d_i]))

    # Initialize the posteriors for the private variances with a copy of the fit sp posteriors
    posteriors.psi_post = copy.deepcopy(sp_posteriors[s_i].psi_post)

    # Initialize the posteriors for the scales (disabled - the freshly generated scale posteriors are kept)
    #posteriors.s_post = copy.deepcopy(sp_posteriors[s_i].s_post)

    # Initialize the posteriors for the latents with a copy of the fit sp posteriors
    with torch.no_grad():
        posteriors.latent_post = copy.deepcopy(sp_posteriors[s_i].latent_post)

In [None]:
# One model per individual; all of them hold the same ip_m_fit object, and the other parameters are left as None
ip_fit_mdls = [
    GNLDRMdl(n_latent_vars=n_latent_vars, m=ip_m_fit, lm=None, mn=None, psi=None, s=None)
    for _ in range(n_individuals)
]

# Pair each individual's observed data (element 1 of its sample) and properties with its model and posteriors
ip_vi_collections = [
    VICollection(data=ind_data[ind_i][1],
                 props=ind_props[ind_i],
                 mdl=ip_fit_mdls[ind_i],
                 posteriors=ip_posteriors[ind_i])
    for ind_i in range(n_individuals)
]

## Fit ip model

In [None]:
# Fitter that fits all ip VI collections and the ip priors on the selected devices
ip_fitter = Fitter(vi_collections=ip_vi_collections, priors=ip_priors, devices=devices)

In [None]:
# Distribute the fitter (including data) across the devices, then run one round of fitting and keep its log
ip_fitter.distribute(distribute_data=True, devices=devices)
ip_logs = [
    ip_fitter.fit(1000, milestones=[100, 500], update_int=100, init_lr=.1)
    for _ in range(1)
]

## Look at aligned ip model fits

In [None]:
# Index of the individual whose ip fit is examined in the cells below
exam_mdl = 0

In [None]:
# Evaluate each ip posterior at the examined individual's properties (the mode is used for psi)
fit_lm = ip_vi_collections[exam_mdl].posteriors.lm_post(ind_props[exam_mdl]).detach().squeeze()
fit_mn = ip_vi_collections[exam_mdl].posteriors.mn_post(ind_props[exam_mdl]).detach().squeeze()
fit_psi = ip_vi_collections[exam_mdl].posteriors.psi_post.mode(ind_props[exam_mdl]).detach().squeeze()
fit_s = ip_vi_collections[exam_mdl].posteriors.s_post(ind_props[exam_mdl]).detach().squeeze()

# NOTE(review): the sp comparison passes the fit m-module, whereas this one passes Identity() - confirm that
# dropping ip_m_fit here is intended
cmp_mdl = GNLDRMdl(n_latent_vars=n_latent_vars, m=Identity(), lm=fit_lm, mn=fit_mn, psi=fit_psi, s=fit_s)

# Look up the true model in this cell instead of relying on a `true_mdl` left behind by an earlier cell, which
# may be missing on a fresh kernel or refer to a different individual if exam_mdl was changed
true_mdl = ind_true_mdls[exam_mdl]

plt.figure()
true_mdl.compare_models(true_mdl, cmp_mdl)

## Look at true and estimated intermediate latents

In [None]:
# Pass element 0 of the examined individual's sample through the true m-module
# NOTE(review): element 0 is presumably the sampled latents (element 1 is used as the observed data) - confirm
intermediate_z_true = m_true(ind_data[exam_mdl][0]).detach().numpy()
# Pass the `mns` of the fit latent posterior through the fit ip m-module
intermediate_z_fit = ip_m_fit(ip_posteriors[exam_mdl].latent_post.mns).detach().numpy()

In [None]:
# Align the fit intermediate space to the true one; argument set 0 is the true model and set 1 is the fit
aligned_mn, aligned_lm, w, aligned_intermediate_z = align_intermediate_spaces(
    lm0=ind_true_mdls[exam_mdl].lm.detach().numpy(),
    mn0=ind_true_mdls[exam_mdl].mn.detach().numpy(),
    s0=ind_true_mdls[exam_mdl].s.detach().numpy(),
    lm1=fit_lm.detach().numpy(),
    mn1=fit_mn.detach().numpy(),
    s1=fit_s.detach().numpy(),
    int_z0=intermediate_z_true,
    int_z1=intermediate_z_fit,
    align_by_params=True)

In [None]:
# One panel per intermediate dimension: true values on x, aligned fit values on y, with the identity line drawn
# over [-4, 4] for reference.  The panel count comes from the data rather than a hard-coded 3.
n_intermediate_dims = intermediate_z_true.shape[1]

plt.figure()
for i in range(n_intermediate_dims):
    plt.subplot(n_intermediate_dims, 1, i+1)
    plt.plot(intermediate_z_true[:, i], aligned_intermediate_z[:, i], 'r.')
    plt.plot([-4, 4], [-4, 4], 'k-')

In [None]:
# Plot the true intermediate latents, then pass the returned value as `a` so the aligned fit is drawn with them
# NOTE(review): `a` is presumably the axes object to draw into - confirm against plot_three_dim_pts
a = plot_three_dim_pts(intermediate_z_true)
plot_three_dim_pts(aligned_intermediate_z, a=a)

## Compare means conditioned on latents

In [None]:
# Conditional means under the true model, evaluated at element 0 of the examined individual's sample
true_mns = ind_true_mdls[exam_mdl].cond_mean(ind_data[exam_mdl][0]).detach().numpy()
# Conditional means under the fit ip model, evaluated at the latent posterior's `mns`; the fit parameters are
# passed explicitly because the model was created with lm, mn, psi and s set to None
fit_mns = ip_vi_collections[exam_mdl].mdl.cond_mean(z=ip_posteriors[exam_mdl].latent_post.mns, 
                                                    lm=fit_lm, 
                                                    mn=fit_mn, 
                                                    s=fit_s, 
                                                    psi=fit_psi).detach().numpy()

In [None]:
# Show the true means, the fit means and their difference (fit minus true) together, each with a colorbar
plt.figure()
cmp_n_mats([true_mns, fit_mns, fit_mns-true_mns], show_colorbars=True)

## View latents

In [None]:
# Element 0 of the examined individual's sample, as a numpy array
# NOTE(review): presumably the latents the data was generated from - confirm
orig_z = ind_data[exam_mdl][0].numpy()
# The `mns` of the fit ip latent posterior for the same individual
fit_z = ip_posteriors[exam_mdl].latent_post.mns.detach().numpy()

In [None]:
# Give every point a color based on its original latent value, so the same point can be identified in both plots
# NOTE(review): `lims` presumably gives the per-dimension range used for the color map - confirm
clrs = assign_colors_to_pts(orig_z, lims=np.asarray([[-2, 2], [-2, 2]]))

In [None]:
# Side-by-side view of the original and fit latents.  Row k of `clrs` colors point k in both panels, so matching
# colors mark the same sample.  Each panel is a single scatter call (instead of one plot call per point), which
# avoids creating one artist per sample.
plt.figure()
latent_panels = [('Original latents (orig_z)', orig_z), ('Fit latents (fit_z)', fit_z)]
for panel_i, (panel_title, panel_z) in enumerate(latent_panels):
    plt.subplot(1, 2, panel_i + 1)
    plt.scatter(panel_z[:, 0], panel_z[:, 1], c=clrs)
    plt.title(panel_title)