Notebook for initial development and testing of synthesizing factor analysis (FA) models.

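The user specifies the number of individuals we observe data from. For each of these individuals, we create a random number of observed variables (e.g., neurons) and associate with each of these variables some random properties (e.g., position or genetic information) drawn uniformly from the unit square. We then sample a true FA model for each individual from known priors that depend on these properties, simulate data from each model, and fit new priors and posteriors to all individuals together so they can be compared against the ground truth.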

In [1]:
# Reload edited modules (e.g., the janelia_core and probabilistic_model_synthesis code imported below) before
# executing each cell, so library changes are picked up without restarting the kernel
%load_ext autoreload
%autoreload 2

In [2]:
# Use the interactive "notebook" matplotlib backend; figures below render as interactive JavaScript outputs
%matplotlib notebook

In [5]:
import matplotlib.pyplot as plt
import numpy as np
import sklearn.decomposition
import torch

from janelia_core.math.basic_functions import optimal_orthonormal_transform
from janelia_core.ml.torch_distributions import CondGaussianDistribution
from janelia_core.ml.torch_distributions import CondGammaDistribution
from janelia_core.ml.extra_torch_modules import ConstantBoundedFcn
from janelia_core.ml.extra_torch_modules import Tanh
from janelia_core.visualization.image_generation import generate_2d_fcn_image
from janelia_core.visualization.matrix_visualization import cmp_n_mats
from janelia_core.ml.utils import list_torch_devices
from janelia_core.ml.utils import torch_mod_to_fcn

from probabilistic_model_synthesis.fa import FAMdl
from probabilistic_model_synthesis.fa import Fitter
from probabilistic_model_synthesis.fa import generate_basic_posteriors
from probabilistic_model_synthesis.fa import generate_simple_prior_collection
from probabilistic_model_synthesis.fa import VICollection
from probabilistic_model_synthesis.math import MeanFcnTransformer
from probabilistic_model_synthesis.math import StdFcnTransformer
from probabilistic_model_synthesis.visualization import plot_torch_dist

## Parameters go here

In [6]:
# Number of individuals we simulate observing data from 
n_individuals = 5

# Range of the number of variables we observe from each individual - the actual number of variables we observe from an
# individual will be pulled uniformly from this range (inclusive)
n_var_range = [1000, 1200]

# Range of the number of samples we observe from each individual - the actual number we observe from each individual
# will be pulled uniformly from this range (inclusive)
n_smps_range = [1000, 1500]

# Number of latent variables in the model
n_latent_vars = 3

# True if we should use GPUs for fitting if they are available
use_gpus = True

# Seed for the numpy and torch random number generators. Everything below (true priors, properties, data and the
# initialization of the fit) is random, so seeding here makes Restart & Run All reproducible.
seed = 0

# Each range is [low, high] with both ends inclusive; catch reversed ranges before any sampling happens
assert n_var_range[0] <= n_var_range[1], 'n_var_range must be [low, high] with low <= high'
assert n_smps_range[0] <= n_smps_range[1], 'n_smps_range must be [low, high] with low <= high'

# torch.manual_seed seeds all devices (CPU and CUDA); it is called before np.random.seed so the cell's last
# expression returns None and nothing extra is displayed
torch.manual_seed(seed)
np.random.seed(seed)

## Create the true prior distributions that relate parameters in the model to variable (e.g., neuron) properties

In [7]:
# Ground-truth priors relating the loading matrices (lm), means (mn) and private variances (psi) to the 2 properties
# of each variable. Keyword arguments are grouped by the parameter family they configure.
true_priors = generate_simple_prior_collection(
    n_prop_vars=2,
    n_latent_vars=n_latent_vars,
    # Loading matrix prior
    lm_mn_w_init_std=1.0,
    lm_std_w_init_std=1.0,
    # Mean prior
    mn_mn_w_init_std=1.0,
    mn_std_w_init_std=1.0,
    # Private variance prior
    psi_conc_f_w_init_std=2.0,
    psi_conc_bias_mn=10.0,
    psi_rate_f_w_init_std=1.0,
    psi_rate_bias_mn=3.0,
)

## Generate properties

In [8]:
# Per-individual counts of variables and samples. The configured ranges are inclusive, while randint's upper bound is
# exclusive, hence the + 1.
ind_n_vars = np.random.randint(low=n_var_range[0], high=n_var_range[1] + 1, size=n_individuals)
ind_n_smps = np.random.randint(low=n_smps_range[0], high=n_smps_range[1] + 1, size=n_individuals)

# Each variable gets 2 properties drawn uniformly from the unit square
ind_props = []
for n_vars_i in ind_n_vars:
    ind_props.append(torch.rand(size=(n_vars_i, 2)))

## Generate true FA models

In [9]:
# Sample one ground-truth FA model per individual from the true priors, conditioned on that individual's properties.
# The sampling order per individual (lm, then mn, then psi) matches the original, so the random stream is unchanged.
ind_true_fa_mdls = []
with torch.no_grad():
    for props in ind_props:
        smp_lm = true_priors.lm_prior.sample(props)
        smp_mn = true_priors.mn_prior.sample(props).squeeze()
        smp_psi = true_priors.psi_prior.sample(props).squeeze()
        ind_true_fa_mdls.append(FAMdl(lm=smp_lm, mn=smp_mn, psi=smp_psi))

## Generate data from each model

In [10]:
# Draw each individual's data set from its true model, using the per-individual sample count
ind_data = []
with torch.no_grad():
    for true_mdl_i, n_smps_i in zip(ind_true_fa_mdls, ind_n_smps):
        ind_data.append(true_mdl_i.sample(n_smps_i))

## Fit new models together

In [11]:
# Priors and posteriors to be fit, created with the default settings of the generator functions (unlike the true
# priors, no init arguments are passed here)
fit_priors = generate_simple_prior_collection(n_prop_vars=2, n_latent_vars=n_latent_vars)
fit_posteriors = generate_basic_posteriors(n_obs_vars=ind_n_vars, n_smps=ind_n_smps, n_latent_vars=n_latent_vars)

# One FA model per individual with no parameters set here (presumably they are supplied during fitting - TODO confirm)
fit_mdls = [FAMdl(lm=None, mn=None, psi=None) for _ in range(n_individuals)]

# zip() below silently truncates to the shortest input, so make sure every per-individual list lines up first
assert len(ind_data) == len(ind_props) == len(fit_mdls) == len(fit_posteriors) == n_individuals, \
    'Per-individual data, properties, models and posteriors must all have one entry per individual'

# Element [1] of each sampled data tuple is used as the observations (element [0] is used as the true latents when
# visualizing latent estimates at the end of the notebook)
vi_collections = [VICollection(data=data_i[1], props=props_i, mdl=mdl_i, posteriors=posteriors_i)
                  for data_i, props_i, mdl_i, posteriors_i in zip(ind_data, ind_props, fit_mdls, fit_posteriors)]

## Set initial values of posteriors

In [12]:
# Choose the devices used for fitting: GPUs when requested and CUDA is usable, otherwise the CPU
if use_gpus and torch.cuda.is_available():
    devices, _ = list_torch_devices()
else:
    devices = [torch.device('cpu')]

# Never give the Fitter an empty device list (e.g. if no GPUs were found); fall back to the CPU instead
if not devices:
    devices = [torch.device('cpu')]

fitter = Fitter(vi_collections=vi_collections, priors=fit_priors, devices=devices)

Found 1 GPUs


In [13]:
# Place data, posteriors and priors on the fitting devices, then run the optimization. Each round's log is kept in
# `logs`; a single round is run here.
fitter.distribute(distribute_data=True, devices=devices)

logs = []
for fit_round in range(1):
    round_log = fitter.fit(1000, init_lr=.1, milestones=[100], update_int=100,
                           skip_lm_kl=False, skip_mn_kl=False, skip_psi_kl=False)
    logs.append(round_log)


Obj: 6.87e+07
----------------------------------------
NELL: 1.14e+07, 1.47e+07, 1.58e+07, 1.35e+07, 1.33e+07
Latent KL: 3.93e+01, 5.44e+01, 5.46e+01, 4.76e+01, 1.43e+01
LM KL: 5.47e+03, 5.52e+03, 5.90e+03, 6.19e+03, 6.45e+03
Mn KL: 2.04e+03, 1.99e+03, 2.04e+03, 2.05e+03, 2.24e+03
Psi KL: 2.16e+02, 2.05e+02, 2.13e+02, 2.17e+02, 2.33e+02
----------------------------------------
LR: 0.1
Elapsed time (secs): 32.95114707946777
----------------------------------------
CPU cur memory used (GB): 3.04e+00
GPU_0 cur memory used (GB): 3.06e-02, max memory used (GB): 3.06e-02

Obj: 1.39e+07
----------------------------------------
NELL: 2.25e+06, 2.96e+06, 3.11e+06, 2.77e+06, 2.65e+06
Latent KL: 1.19e+04, 1.63e+04, 1.62e+04, 1.35e+04, 1.26e+04
LM KL: 1.19e+04, 1.20e+04, 1.26e+04, 1.28e+04, 1.33e+04
Mn KL: 3.20e+03, 3.21e+03, 3.31e+03, 3.22e+03, 3.48e+03
Psi KL: 2.90e+02, 2.77e+02, 2.97e+02, 2.96e+02, 3.32e+02
----------------------------------------
LR: 0.010000000000000002
Elapsed time (secs): 

## Move the VI collections and priors back to cpu

In [14]:
# Bring everything back to the CPU so the posteriors can be evaluated on the (CPU) property tensors and converted to
# numpy in the cells below
cpu_device = torch.device('cpu')
fitter.distribute(devices=[cpu_device])

## Examine logs of fitting performance

In [15]:
# Plot the optimization log of the first fitting round
first_fit_log = logs[0]
fitter.plot_log(first_fit_log)

[True, True, True, True, True, True]


<IPython.core.display.Javascript object>

## Look at model fits

In [16]:
# Index of the individual whose fitted model is compared with its true model below
exam_mdl = 0
assert 0 <= exam_mdl < n_individuals, 'exam_mdl must index one of the simulated individuals'

In [17]:
# Build an FA model from point estimates of the example individual's posteriors. lm and mn use the values returned by
# calling their posteriors on the properties (presumably the posterior means - TODO confirm); psi uses the mode of
# its posterior.
exam_posts = vi_collections[exam_mdl].posteriors
exam_props = ind_props[exam_mdl]

fit_lm = exam_posts.lm_post(exam_props).detach().squeeze()
fit_mn = exam_posts.mn_post(exam_props).detach().squeeze()
fit_psi = exam_posts.psi_post.mode(exam_props).detach().squeeze()
# An alternative psi to compare is the one stored on the fitted model itself: vi_collections[exam_mdl].mdl.psi

true_mdl = ind_true_fa_mdls[exam_mdl]
cmp_mdl = FAMdl(lm=fit_lm, mn=fit_mn, psi=fit_psi)

plt.figure()
true_mdl.compare_models(true_mdl, cmp_mdl)

<IPython.core.display.Javascript object>

(1088, 3)


### Visualize parameters of the true prior distributions over the loading matrices

In [18]:
# One figure per latent dimension, plotting the true loading-matrix prior's mean and std functions for that dimension
for dim in range(n_latent_vars):
    plt.figure(figsize=(9, 3))
    plot_torch_dist(mn_f=true_priors.lm_prior.mn_f, std_f=true_priors.lm_prior.std_f, vis_dim=dim,
                    extra_title_str=f', d={dim}')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

### Visualize parameters of the fit prior distributions over the loading matrices

In [19]:
# Learn an orthonormal transform that puts the fit loading-matrix prior into the same latent space as the true one,
# by evaluating both priors at random property values drawn from the unit square.
# The transform gets its own name (not `o`) so it cannot be confused with the latent-alignment transform computed
# later in the notebook.
n_align_pts = 1000  # number of random property values used to learn the alignment

# Only forward evaluations are needed here, so skip building an autograd graph
with torch.no_grad():
    rnd_vls = torch.rand(n_align_pts, 2)
    lm_prior_o = optimal_orthonormal_transform(true_priors.lm_prior(rnd_vls).detach().numpy(),
                                               fit_priors.lm_prior(rnd_vls).detach().numpy())

# Wrap the fit prior's mean and std functions so their outputs are expressed in the aligned space
fit_lm_mn_fcn = MeanFcnTransformer(o=lm_prior_o.transpose(), f=fit_priors.lm_prior.mn_f)
fit_lm_std_fcn = StdFcnTransformer(o=lm_prior_o.transpose(), f=fit_priors.lm_prior.std_f)

In [20]:
# One figure per latent dimension, plotting the aligned fit loading-matrix prior's mean and std functions
for dim in range(n_latent_vars):
    plt.figure(figsize=(9, 3))
    plot_torch_dist(mn_f=fit_lm_mn_fcn, std_f=fit_lm_std_fcn, vis_dim=dim,
                    extra_title_str=f', d={dim}')

<IPython.core.display.Javascript object>

torch.Size([1000000, 3])


<IPython.core.display.Javascript object>

torch.Size([1000000, 3])


<IPython.core.display.Javascript object>

torch.Size([1000000, 3])


### Visualize parameters of the true prior distribution over the means

In [21]:
# Mean and std functions of the true prior over the means
true_mn_prior = true_priors.mn_prior
plt.figure(figsize=(9, 3))
plot_torch_dist(std_f=true_mn_prior.std_f, mn_f=true_mn_prior.mn_f)

<IPython.core.display.Javascript object>

### Visualize parameters of the fit prior distribution over the means

In [22]:
# Mean and std functions of the fit prior over the means
fit_mn_prior = fit_priors.mn_prior
plt.figure(figsize=(9, 3))
plot_torch_dist(std_f=fit_mn_prior.std_f, mn_f=fit_mn_prior.mn_f)

<IPython.core.display.Javascript object>

### Visualize parameters of the true prior distribution over private variances

In [23]:
# True prior over private variances: its forward method is plotted as the mean function and its std method as the
# standard deviation function
true_psi_prior = true_priors.psi_prior
plt.figure(figsize=(9, 3))
plot_torch_dist(std_f=true_psi_prior.std, mn_f=true_psi_prior.forward)

<IPython.core.display.Javascript object>

### Visualize parameters of the fit prior distribution over private variances

In [24]:
# Fit prior over private variances: its forward method is plotted as the mean function and its std method as the
# standard deviation function
fit_psi_prior = fit_priors.psi_prior
plt.figure(figsize=(9, 3))
plot_torch_dist(std_f=fit_psi_prior.std, mn_f=fit_psi_prior.forward)

<IPython.core.display.Javascript object>

## Visualize latent estimates for an example model

In [25]:
# Index of the individual whose latent estimates are visualized below
ex_s = 0
assert 0 <= ex_s < n_individuals, 'ex_s must index one of the simulated individuals'

In [26]:
# Learn a transformation that puts the estimated latents into the same space as the true latents, by aligning the
# example individual's estimated loading matrix to its true loading matrix
with torch.no_grad():
    # Tensor.numpy() refuses tensors that require grad even inside no_grad() (e.g. if a loading matrix is stored as
    # a parameter), so detach first - consistent with the other posterior evaluations in this notebook
    true_lm = ind_true_fa_mdls[ex_s].lm.detach().numpy()
    est_lm = fit_posteriors[ex_s].lm_post(ind_props[ex_s]).detach().numpy()
    o = optimal_orthonormal_transform(true_lm, est_lm)

In [27]:
# Map the example individual's latent posterior means (`mns`) into the true latent space using the transform
# learned from the loading matrices
est_latent_mns = fit_posteriors[ex_s].latent_post.mns.detach().numpy()
est_latents = np.matmul(est_latent_mns, o)

In [28]:
# Visualize latents

In [29]:
# Element [0] of each sampled data tuple holds the true latents (element [1] was used as the observed data for fitting)
ex_smp = ind_data[ex_s]
true_latents = ex_smp[0].numpy()

In [30]:
# Plot true latents (blue circles) against the aligned estimated latents (red dots), one row per latent dimension.
# squeeze=False keeps `axes` 2-D so this also works when n_latent_vars == 1.
fig, axes = plt.subplots(n_latent_vars, 1, sharex=True, squeeze=False)
for l_i, ax in enumerate(axes[:, 0]):
    ax.plot(true_latents[:, l_i], 'bo', label='true')
    ax.plot(est_latents[:, l_i], 'r.', label='estimated')
    ax.set_ylabel(f'latent {l_i}')
axes[0, 0].legend()
fig.suptitle('True vs. estimated latents')
axes[-1, 0].set_xlabel('sample');
    

<IPython.core.display.Javascript object>