We simulate a small amount of data from many individual FA models, then fit the models both together and alone, and measure the benefit (in terms of the likelihood of held-out test data) of fitting the models together vs. fitting them alone.


For comparison, when fitting FA models individually we use a standard FA fitting package to obtain point estimates of the model parameters. When evaluating models that have been fit together, we use the modes of posterior distributions as point estimates.


In [1]:
# Enable IPython's autoreload extension; mode 2 reloads all modules before each cell is executed,
# so edits to the imported packages are picked up without restarting the kernel
%load_ext autoreload
%autoreload 2

In [2]:
import sklearn.decomposition
import matplotlib.pyplot as plt
import numpy as np
import torch

from janelia_core.ml.utils import list_torch_devices

from probabilistic_model_synthesis.fa import FAMdl
from probabilistic_model_synthesis.fa import Fitter
from probabilistic_model_synthesis.fa import generate_basic_posteriors
from probabilistic_model_synthesis.fa import generate_simple_prior_collection
from probabilistic_model_synthesis.fa import VICollection

## Parameters go here

In [3]:
# Number of individuals we simulate observing data from 
n_individuals = 10

# Range of the number of variables we observe from each individual - the actual number of variables we observe from an
# individual will be pulled uniformly from this range (inclusive)
n_var_range = [100, 120]

# Range of the number of samples we observe for fitting from each individual - the actual number we observe
# from each individual will be drawn uniformly from this range (inclusive)
n_fitting_smps_range = [10, 15]

# Number of latent variables in the model
n_latent_vars = 3

# Number of samples we generate when testing each model
n_test_smps = 1000

## Create the true prior distributions that relate parameters in the model to variable (e.g., neuron) properties

In [4]:
# Prior collection that the ground-truth FA models are sampled from.  The explicit init values below
# set how the priors over loadings (lm), means (mn) and private variances (psi) are initialized.
true_priors = generate_simple_prior_collection(
    n_prop_vars=2,
    n_latent_vars=n_latent_vars,
    lm_mn_w_init_std=1.0,
    lm_std_w_init_std=0.1,
    mn_mn_w_init_std=1.0,
    mn_std_w_init_std=1.0,
    psi_conc_f_w_init_std=2.0,
    psi_rate_f_w_init_std=1.0,
    psi_conc_bias_mn=10.0,
    psi_rate_bias_mn=5.0,
)

## Generate properties

In [5]:
# Number of observed variables and number of fitting samples for each individual; randint excludes its
# upper bound, hence the +1 to make the configured ranges inclusive
ind_n_vars = np.random.randint(low=n_var_range[0], high=n_var_range[1] + 1, size=n_individuals)
ind_n_smps = np.random.randint(low=n_fitting_smps_range[0], high=n_fitting_smps_range[1] + 1,
                               size=n_individuals)

# Two properties per observed variable, each drawn uniformly from [0, 1)
ind_props = [torch.rand(size=[n_obs_vars, 2]) for n_obs_vars in ind_n_vars]

## Generate true FA models

In [6]:
# Build one ground-truth FA model per individual by sampling its loading matrix, mean and psi from the
# true priors, conditioned on that individual's variable properties.  No gradients are needed here.
with torch.no_grad():
    ind_true_fa_mdls = [
        FAMdl(
            lm=true_priors.lm_prior.sample(var_props),
            mn=true_priors.mn_prior.sample(var_props).squeeze(),
            psi=true_priors.psi_prior.sample(var_props).squeeze(),
        )
        for var_props in ind_props
    ]

## Generate data for fitting from each model

In [7]:
# Draw each individual's fitting data from its own true model, using that individual's sample count
with torch.no_grad():
    ind_train_data = [
        true_mdl.sample(n_fit_smps)
        for true_mdl, n_fit_smps in zip(ind_true_fa_mdls, ind_n_smps)
    ]

## Fit FA models together

In [8]:
# Look up the torch devices available for fitting; the second return value is not used in this notebook
devices, _ = list_torch_devices()

Found 3 GPUs


In [9]:
# Priors to be learned during fitting; unlike the true priors, no init values are passed, so the
# function's defaults are used
fit_priors = generate_simple_prior_collection(n_prop_vars=2, n_latent_vars=n_latent_vars)

# Posteriors, sized from each individual's number of observed variables and fitting samples
fit_posteriors = generate_basic_posteriors(n_obs_vars=ind_n_vars, n_smps=ind_n_smps,
                                           n_latent_vars=n_latent_vars)

# One FA model per individual with no parameters set
fit_mdls = [FAMdl(lm=None, mn=None, psi=None) for _ in range(n_individuals)]

# Bundle each individual's training data, properties, model and posteriors.  Element 1 of each sampled
# training tuple is used as the data (presumably the observed variables - confirm against FAMdl.sample).
vi_collections = [
    VICollection(data=train_smps[1], props=var_props, mdl=fa_mdl, posteriors=posts)
    for train_smps, var_props, fa_mdl, posts in zip(ind_train_data, ind_props, fit_mdls, fit_posteriors)
]

In [10]:
# Single fitter over all individuals' VI collections, so they are fit together under the shared priors
fitter = Fitter(vi_collections=vi_collections, priors=fit_priors)

In [11]:
# Distribute the fitter across the available devices; distribute_data=True presumably also moves the
# training data onto those devices - confirm against Fitter.distribute
fitter.distribute(distribute_data=True, devices=devices)

In [12]:
# Run a single round of fitting and keep whatever fit() returns for that round in a list.
# None of the KL terms (loading matrix, mean, psi) is skipped.
logs = [
    fitter.fit(1000, milestones=[100], update_int=100, init_lr=0.1,
               skip_lm_kl=False, skip_mn_kl=False, skip_psi_kl=False)
    for _ in range(1)
]


Obj: 1.12e+05
----------------------------------------
NELL: 6.63e+03, 9.74e+03, 1.61e+04, 9.90e+03, 1.53e+04, 1.03e+04, 8.47e+03, 1.80e+04, 1.11e+04, 7.76e+03
Latent KL: 5.10e-01, 4.41e-01, 2.22e-01, 3.23e-01, 2.73e-01, 3.23e-01, 6.84e-01, 5.46e-01, 3.78e-01, 2.62e-01
LM KL: 5.75e+02, 6.19e+02, 6.04e+02, 6.24e+02, 6.72e+02, 5.80e+02, 5.52e+02, 5.10e+02, 6.01e+02, 6.37e+02
Mn KL: 2.22e+02, 2.20e+02, 2.06e+02, 2.02e+02, 2.09e+02, 2.01e+02, 1.92e+02, 1.71e+02, 2.12e+02, 2.17e+02
Psi KL: 2.37e+01, 2.45e+01, 2.46e+01, 2.38e+01, 2.48e+01, 2.24e+01, 2.18e+01, 2.12e+01, 2.44e+01, 2.52e+01
----------------------------------------
LR: 0.1
Elapsed time (secs): 2.4620072841644287
----------------------------------------
CPU cur memory used (GB): 4.53e+00
GPU_0 cur memory used (GB): 1.79e-04, max memory used (GB): 1.79e-04
GPU_1 cur memory used (GB): 1.09e-04, max memory used (GB): 1.09e-04
GPU_2 cur memory used (GB): 1.11e-04, max memory used (GB): 1.11e-04

Obj: 3.16e+04
-----------------------


Obj: 2.85e+04
----------------------------------------
NELL: 2.76e+03, 2.48e+03, 2.80e+03, 2.60e+03, 2.35e+03, 2.40e+03, 2.76e+03, 2.69e+03, 2.98e+03, 2.05e+03
Latent KL: 7.28e+01, 7.46e+01, 9.00e+01, 7.33e+01, 6.99e+01, 7.95e+01, 8.03e+01, 8.26e+01, 8.46e+01, 6.01e+01
LM KL: 3.17e+01, 2.51e+01, 1.14e+02, 3.82e+01, 9.68e+01, 1.88e+01, 3.50e+01, 6.65e+01, 3.01e+01, 1.38e+01
Mn KL: 1.26e+02, 1.26e+02, 1.32e+02, 1.33e+02, 1.32e+02, 1.36e+02, 1.27e+02, 1.12e+02, 1.32e+02, 1.20e+02
Psi KL: 2.04e+01, 2.11e+01, 2.37e+01, 1.84e+01, 1.98e+01, 1.82e+01, 1.97e+01, 2.31e+01, 2.54e+01, 2.50e+01
----------------------------------------
LR: 0.010000000000000002
Elapsed time (secs): 223.74925136566162
----------------------------------------
CPU cur memory used (GB): 4.53e+00
GPU_0 cur memory used (GB): 1.79e-04, max memory used (GB): 1.79e-04
GPU_1 cur memory used (GB): 1.09e-04, max memory used (GB): 1.09e-04
GPU_2 cur memory used (GB): 1.11e-04, max memory used (GB): 1.11e-04


## Move combined models to CPU

In [18]:
# Redistribute the fitter onto the CPU only; distribute_data is left at its default here
fitter.distribute(devices=[torch.device('cpu')])

## Fit FA models individually

In [19]:
# Fit an independent sklearn factor analysis model to each individual's training data (element 1 of the
# sampled training tuple, converted to a numpy array).  fit() returns the fitted estimator itself.
alone_models = [
    sklearn.decomposition.FactorAnalysis(n_components=n_latent_vars).fit(ind_train_data[ind_i][1].numpy())
    for ind_i in range(n_individuals)
]

## Measure performance of the fit models on new test data

In [20]:
# Draw n_test_smps fresh samples from every true model to serve as held-out test data
with torch.no_grad():
    ind_test_data = [true_mdl.sample(n_test_smps) for true_mdl in ind_true_fa_mdls]

In [21]:
# For every individual, compute the average per-sample log-likelihood of its held-out test data under
# (a) the sklearn model fit alone and (b) the model fit together with the others.
ind_test_ll = []
with torch.no_grad():
    eval_mdl = FAMdl()  # Model object we use just for evaluation
    for ind_i in range(n_individuals):
        test_x = ind_test_data[ind_i][1]
        sk_mdl = alone_models[ind_i]
        props_i = ind_props[ind_i]
        posts_i = vi_collections[ind_i].posteriors

        # Point estimates from the model fit alone; sklearn's components_ is transposed before being
        # passed as the loading matrix
        alone_params = {
            'lm': torch.tensor(sk_mdl.components_.transpose()),
            'mn': torch.tensor(sk_mdl.mean_),
            'psi': torch.tensor(sk_mdl.noise_variance_),
        }
        alone_total = torch.sum(eval_mdl.log_prob(x=test_x, **alone_params))
        alone_ll = (alone_total / n_test_smps).numpy().item()

        # Point estimates from the model fit with the other models.
        # NOTE(review): lm_post and mn_post are called directly while psi_post uses .mode(); presumably
        # calling the posterior returns its mean (equal to the mode for a Gaussian) - confirm.
        comb_params = {
            'lm': posts_i.lm_post(props_i),
            'mn': posts_i.mn_post(props_i).squeeze(),
            'psi': posts_i.psi_post.mode(props_i).squeeze(),
        }
        comb_total = torch.sum(eval_mdl.log_prob(x=test_x, **comb_params))
        comb_ll = (comb_total / n_test_smps).numpy().item()

        ind_test_ll.append({'alone': alone_ll, 'comb': comb_ll})
            

In [22]:
# Display the per-individual average test log-likelihoods ('alone' vs. 'comb'); higher is better
ind_test_ll

[{'alone': -273.84177291126724, 'comb': -209.60745239257812},
 {'alone': -321.52913658438297, 'comb': -222.1905059814453},
 {'alone': -253.67458512634033, 'comb': -211.0919952392578},
 {'alone': -273.76329503560646, 'comb': -214.62393188476562},
 {'alone': -309.88419375405914, 'comb': -225.42991638183594},
 {'alone': -248.20801485929906, 'comb': -198.0729217529297},
 {'alone': -225.41397381433777, 'comb': -195.1164093017578},
 {'alone': -235.4733753351554, 'comb': -189.27928161621094},
 {'alone': -250.53135460611054, 'comb': -213.1034393310547},
 {'alone': -350.83425068866603, 'comb': -218.47979736328125}]