Notebook for development and testing of code for the second version of fitting latent regression models across multiple subjects with variational inference.  The main advance in version 2.0 of the code is the ability to support distributions across additional model parameters (not just the modes). 

In particular we generate models of how one neural population drives another as follows:

1) The user specifies a number of subjects and how many neurons are in each population for each of those subjects. Neuron locations for each subject are then randomly drawn from a uniform distribution on the unit square. 

2) Our models include only neural dynamics (no stimulus input or behavioral output), and we use an identity mapping in the low-dimensional space.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import copy
import matplotlib.pyplot as plt
import numpy as np
import torch

from janelia_core.ml.datasets import TimeSeriesDataset
from janelia_core.ml.latent_regression.group_maps import IdentityMap
from janelia_core.ml.latent_regression.subject_models import LatentRegModel
from janelia_core.ml.latent_regression.vi import MultiSubjectVIFitter
from janelia_core.ml.latent_regression.vi import PriorCollection
from janelia_core.ml.latent_regression.vi import SubjectVICollection
from janelia_core.ml.latent_regression.vi import predict_with_truth
from janelia_core.ml.torch_distributions import CondGaussianDistribution
from janelia_core.ml.torch_distributions import CondMatrixHypercubePrior
from janelia_core.ml.torch_distributions import CondMatrixProductDistribution
from janelia_core.ml.torch_distributions import MatrixGaussianProductDistribution
from janelia_core.ml.utils import torch_mod_to_fcn
from janelia_core.ml.utils import list_torch_devices
from janelia_core.visualization.image_generation import generate_dot_image_3d
from janelia_core.visualization.image_visualization import visualize_2d_function


In [3]:
%matplotlib notebook

## Parameters and model specification go here

In [4]:
# One tuple per subject, so the length of this list sets the number of subjects.
# Each tuple gives the number of neurons in each of the two populations for that subject.
n_subj_neurons = [
    (10000, 10000),
    (9000, 9000),
    (11000, 11000),
]

# How many samples of data are generated for every subject
n_smps = 20000

# Set to True to use posteriors that are shared among subjects
use_shared_posts = False


### Parameters for creating hypercube functions

In [5]:
# Settings used for both the mean and the standard deviation hypercube functions.
# The ranges extend slightly beyond the unit square in both dimensions.
hc_fcn_params = dict(
    n_divisions_per_dim=[50, 50],
    dim_ranges=np.asarray([[-.1, 1.1], [-.1, 1.1]]),
    n_div_per_hc_side_per_dim=[1, 1],
)

### Here we specify the mean and standard deviation functions for the different parameters of the models

#### Specify some helper functions

In [6]:
class exp2d(torch.nn.Module):
    """Exponential bump: gain * exp(-sum_d (x_d - ctr_d)**2 / std_d) + offset.

    The squared distance is divided by ``std`` itself (not ``std**2``). All four
    constructor arguments are tensors and are stored as module parameters.
    """

    def __init__(self, ctr, std, gain, offset):
        super().__init__()
        # Register in a fixed order: ctr, std, gain, offset
        for name, vl in (('ctr', ctr), ('std', std), ('gain', gain), ('offset', offset)):
            setattr(self, name, torch.nn.Parameter(vl))

    def forward(self, x):
        """Evaluate the bump at each row of x, returning a column of shape [x.shape[0], 1]."""
        scaled_sq_dist = torch.sum((x - self.ctr)**2/self.std, dim=1)
        bump = self.gain*torch.exp(-1*scaled_sq_dist) + self.offset
        return bump.unsqueeze(1)

class constantF(torch.nn.Module):
    """Module which returns the same constant value for every row of its input.

    Args:
        vl: The constant value to return.
    """

    def __init__(self, vl):
        super().__init__()
        self.vl = vl

    def forward(self, x):
        """Return a [x.shape[0], 1] tensor filled with the constant value.

        Only the number of rows and the device of x are used.
        """
        # Allocate the output on the same device as the input so the module also
        # works when x lives on a GPU (a default-device tensor would cause a
        # device mismatch further downstream).
        return self.vl*torch.ones([x.shape[0], 1], device=x.device)

#### Specify the distributions over p and u modes

Here we implicitly define the number of modes by the number of distributions we define

In [7]:
def _bump_mode_dists(ctrs):
    """Build one conditional Gaussian per center and wrap them in a matrix product distribution.

    Each mode has an exp2d mean function located at its center (std [1, 1], gain 1, offset 0)
    and a constant standard deviation of .1.
    """
    mode_dists = []
    for c in ctrs:
        mn_f = exp2d(ctr=c, std=torch.tensor([1.0, 1.0]), gain=torch.tensor(1.0),
                     offset=torch.tensor(0.0))
        mode_dists.append(CondGaussianDistribution(mn_f=mn_f, std_f=constantF(.1)))
    return CondMatrixProductDistribution(mode_dists)


# The number of centers listed for a population sets its number of modes
p_ctrs = [torch.tensor([.1, .1]), torch.tensor([.9, .9])]
true_p_dists = _bump_mode_dists(p_ctrs)

u_ctrs = [torch.tensor([.1, .1]), torch.tensor([.9, .9])]
true_u_dists = _bump_mode_dists(u_ctrs)


#### Specify the distributions over scales and offsets and direct connections

In [8]:
def _bump_dist(ctr, width, gain, offset, std):
    """Conditional Gaussian with an exp2d mean function and a constant standard deviation.

    ctr, width, gain and offset are passed (as tensors) to exp2d's ctr, std, gain and offset;
    std is the constant standard deviation of the distribution.
    """
    mn_f = exp2d(ctr=torch.tensor(ctr), std=torch.tensor(width),
                 gain=torch.tensor(gain), offset=torch.tensor(offset))
    return CondGaussianDistribution(mn_f=mn_f, std_f=constantF(std))


true_scale_dist = _bump_dist(ctr=[.5, .5], width=[.5, .5], gain=10.0, offset=0.0, std=.1)

true_offset_dist = _bump_dist(ctr=[.5, .5], width=[1.0, 1.0], gain=10.0, offset=0.0, std=.1)

# psi gets a positive offset and a small standard deviation
true_psi_dist = _bump_dist(ctr=[.5, .5], width=[6.0, 6.0], gain=.2, offset=.1, std=.01)

true_direct_map_dist = _bump_dist(ctr=[.8, .5], width=[6.0, 6.0], gain=.2, offset=.1, std=.01)

## Here we generate our true subject models and data

In [9]:
# The number of modes is the number of mode distributions defined for p
n_modes = len(true_p_dists.dists)
n_subjs = len(n_subj_neurons)
true_subj_models = [None]*n_subjs
true_data = [None]*n_subjs

# For each subject: draw neuron locations, sample all model parameters from the true
# distributions conditioned on those locations, build the ground-truth model and
# generate data from it.
for s_i in range(n_subjs):
    
    # Gradients are not needed while building ground-truth models and data
    with torch.no_grad():
        # Generate neuron locations
        p_neuron_locs = torch.rand(size=[n_subj_neurons[s_i][0], 2])
        u_neuron_locs = torch.rand(size=[n_subj_neurons[s_i][1], 2])
    
        # Generate modes
        p_modes = true_p_dists.form_standard_sample(true_p_dists.sample(p_neuron_locs))
        u_modes = true_u_dists.form_standard_sample(true_u_dists.sample(u_neuron_locs))
        
        # Generate scales and offsets
        scales = true_scale_dist.form_standard_sample(true_scale_dist.sample(u_neuron_locs)).squeeze()
        offsets = true_offset_dist.form_standard_sample(true_offset_dist.sample(u_neuron_locs)).squeeze()
        
        # Generate direct maps
        direct_mappings = true_direct_map_dist.form_standard_sample(true_direct_map_dist.sample(u_neuron_locs)).squeeze()
        
        # Generate psi
        psi = true_psi_dist.form_standard_sample(true_psi_dist.sample(u_neuron_locs)).squeeze()
        # Sanity check that every sampled psi value is strictly positive
        assert(torch.all(psi > 0))
    
        s_mdl = LatentRegModel(d_in = [n_subj_neurons[s_i][0]], d_out = [n_subj_neurons[s_i][1]], 
                               d_proj=[n_modes], d_trans=[n_modes], 
                               m=IdentityMap(),
                               s=[torch.nn.Identity()], 
                               use_scales=True,
                               use_offsets=True,
                               direct_pairs=[(0,0)], 
                               assign_direct_pair_mappings=True)
    
        # Overwrite the model's parameters with the values sampled above
        s_mdl.u[0].data = u_modes
        s_mdl.p[0].data = p_modes
        s_mdl.offsets[0].data = offsets
        s_mdl.scales[0].data = scales
        s_mdl.psi[0].data = psi
        s_mdl.direct_mappings[0].data = direct_mappings
    
        # Keep the neuron locations with the model; they are needed later for fitting and plotting
        true_subj_models[s_i] = {'mdl': s_mdl, 'p_neuron_locs': p_neuron_locs, 'u_neuron_locs': u_neuron_locs}
    
        
        # Inputs are standard normal noise; outputs are generated by the true model
        p_data = [torch.randn(size=[n_smps, n_subj_neurons[s_i][0]])]
        u_data = s_mdl.generate(p_data)
        
        # Delay u data with respect to p data (since we model u_{t+1} as a function of p_t).
        # Dropping the first row of p and the last row of u leaves n_smps - 1 samples in each.
        p_data[0] = p_data[0][1:,:]
        u_data[0] = u_data[0][0:-1, :]
        
        
        
        true_data[s_i] = (p_data, u_data)

## Now we set things up for fitting with variational inference

### Define prior distributions

In [10]:
# Keyword arguments common to every hypercube prior
_prior_kwargs = dict(mn_hc_params=hc_fcn_params, std_hc_params=hc_fcn_params, min_std=.00001)

# Priors over modes have one column per mode
p_prior = CondMatrixHypercubePrior(n_cols=n_modes, **_prior_kwargs)

u_prior = CondMatrixHypercubePrior(n_cols=n_modes, **_prior_kwargs)

# Priors over scales, offsets and direct mappings have a single column
scales_prior = CondMatrixHypercubePrior(n_cols=1, mn_init=1.0, **_prior_kwargs)

offsets_prior = CondMatrixHypercubePrior(n_cols=1, mn_init=0.0, **_prior_kwargs)

direct_mappings_prior = CondMatrixHypercubePrior(n_cols=1, mn_init=0.0, **_prior_kwargs)

# No prior distribution is supplied for psi
prior_collection = PriorCollection(
    p_dists=[p_prior],
    u_dists=[u_prior],
    psi_dists=[None],
    scale_dists=[scales_prior],
    offset_dists=[offsets_prior],
    direct_mapping_dists=[direct_mappings_prior],
)

### Define subject models and posteriors for each subject

In [11]:
# Build, for every subject, the model to fit, the posterior distributions over its
# parameters and the collection object handed to the fitter.
vi_collections = [None]*n_subjs
for s_i in range(n_subjs):
    n_p_neurons, n_u_neurons = n_subj_neurons[s_i]

    # Create subject model for fitting
    s_mdl = LatentRegModel(d_in=[n_p_neurons], d_out=[n_u_neurons],
                           d_proj=[n_modes], d_trans=[n_modes], m=IdentityMap(), s=[torch.nn.Identity()],
                           use_scales=True, use_offsets=True, direct_pairs=[(0,0)],
                           assign_p_modes=False, assign_u_modes=False, assign_scales=False, assign_offsets=False,
                           assign_direct_pair_mappings=False,
                           assign_psi=True) # We will fit point estimates for psi (and not distributions)

    # Create posterior distributions
    if use_shared_posts:
        # Shared posteriors are functions of neuron location, so they are created once
        # (for the first subject) and the same objects are reused for all later subjects.
        if s_i == 0:
            hc_post_kwargs = dict(mn_hc_params=hc_fcn_params, std_hc_params=hc_fcn_params, min_std=.00001)
            p_post = CondMatrixHypercubePrior(n_cols=n_modes, mn_init=.1, **hc_post_kwargs)
            u_post = CondMatrixHypercubePrior(n_cols=n_modes, mn_init=.1, **hc_post_kwargs)
            # Scales, offsets and direct mappings have a single value per neuron, so these
            # posteriors have one column, matching the corresponding priors and the
            # non-shared posteriors below (the modes are the only quantities with n_modes columns).
            scale_post = CondMatrixHypercubePrior(n_cols=1, mn_init=1.0, **hc_post_kwargs)
            offset_post = CondMatrixHypercubePrior(n_cols=1, mn_init=0.0, **hc_post_kwargs)
            direct_mappings_post = CondMatrixHypercubePrior(n_cols=1, mn_init=0.0, **hc_post_kwargs)
    else:
        # Subject-specific posteriors with one independent Gaussian per matrix entry
        p_post = MatrixGaussianProductDistribution(shape=[n_p_neurons, n_modes], mn_mn=.01, mn_std=.001)
        u_post = MatrixGaussianProductDistribution(shape=[n_u_neurons, n_modes], mn_mn=.01, mn_std=.001)
        scale_post = MatrixGaussianProductDistribution(shape=[n_u_neurons, 1], mn_mn=1.0, mn_std=.001)
        offset_post = MatrixGaussianProductDistribution(shape=[n_u_neurons, 1], mn_mn=0.0, mn_std=.001)
        direct_mappings_post = MatrixGaussianProductDistribution(shape=[n_u_neurons, 1], mn_mn=0.0, mn_std=.001)

    # Package data
    data = TimeSeriesDataset([true_data[s_i][0][0], true_data[s_i][1][0]])[:]

    # props holds the neuron locations; the *_props arguments index into it, so p
    # distributions are conditioned on p neuron locations (index 0) and everything
    # else on u neuron locations (index 1).
    vi_collections[s_i] = SubjectVICollection(s_mdl=s_mdl, p_dists=[p_post], u_dists=[u_post], psi_dists=[None],
                                        scale_dists=[scale_post],
                                        offset_dists=[offset_post],
                                        direct_mappings_dists=[direct_mappings_post],
                                        data=data, input_grps=[0], output_grps=[1],
                                        props=[true_subj_models[s_i]['p_neuron_locs'],
                                               true_subj_models[s_i]['u_neuron_locs']],
                                        p_props = [0], u_props=[1], psi_props=[None],
                                        scale_props=[1], offset_props=[1],
                                        direct_mapping_props=[1], min_var=[.0001])

## Create the fitter 

In [12]:
# The fitter is given the per-subject VI collections and the collection of priors
fitter = MultiSubjectVIFitter(s_collections=vi_collections, prior_collection=prior_collection)

## Fit the model

In [13]:
# Find the torch devices available for fitting (the second return value is not used here)
devices, _ = list_torch_devices()

Found 1 GPUs


In [14]:
# Distribute the fitter, including the data, across the devices found above
fitter.distribute(devices, distribute_data=True)

In [15]:
# Two rounds of fitting: the first with a learning rate of .1, the second with .01.
# Priors are enforced only when posteriors are not shared among subjects.
logs0 = fitter.fit(n_epochs=1000, n_batches=2, update_int=100, learning_rates=.1,
                   enforce_priors=not use_shared_posts)
logs1 = fitter.fit(n_epochs=1000, n_batches=2, update_int=100, learning_rates=.01,
                   enforce_priors=not use_shared_posts)

*****************************************************
Epoch 0 complete.  Obj: 1.60e+14, LR: [0.1]
Model NLLs:  s_0: 5.16e+13, s_1: 4.27e+13, s_2: 6.54e+13
Subj P KLs:  s_0: 5.24e+05, s_1: 4.25e+05, s_2: 5.97e+05
Subj U KLs:  s_0: 9.15e+03, s_1: 8.22e+03, s_2: 1.00e+04
Subj Psi KLs:  s_0: 0.00e+00, s_1: 0.00e+00, s_2: 0.00e+00
Subj Scale KLs:  s_0: 8.13e+05, s_1: 7.29e+05, s_2: 8.94e+05
Subj Offsets KLs:  s_0: 7.65e+05, s_1: 6.96e+05, s_2: 8.81e+05
Subj Direct Mappings KLs:  s_0: 8.50e+05, s_1: 7.76e+05, s_2: 9.38e+05
Device memory allocated:  d_0: 7.46e+09
Device max memory allocated:  d_0: 9.42e+09
Elapsed time: 0.5644152164459229
*****************************************************
Epoch 100 complete.  Obj: 1.24e+10, LR: [0.1]
Model NLLs:  s_0: 4.00e+09, s_1: 3.47e+09, s_2: 4.96e+09
Subj P KLs:  s_0: 9.01e+04, s_1: 8.16e+04, s_2: 1.14e+05
Subj U KLs:  s_0: 9.11e+04, s_1: 8.20e+04, s_2: 1.06e+05
Subj Psi KLs:  s_0: 0.00e+00, s_1: 0.00e+00, s_2: 0.00e+00
Subj Scale KLs:  s_0: 4.03e+04

*****************************************************
Epoch 300 complete.  Obj: 8.79e+08, LR: [0.01]
Model NLLs:  s_0: 2.92e+08, s_1: 2.69e+08, s_2: 3.18e+08
Subj P KLs:  s_0: 1.34e+05, s_1: 1.21e+05, s_2: 1.54e+05
Subj U KLs:  s_0: 1.43e+05, s_1: 1.29e+05, s_2: 1.66e+05
Subj Psi KLs:  s_0: 0.00e+00, s_1: 0.00e+00, s_2: 0.00e+00
Subj Scale KLs:  s_0: 5.84e+04, s_1: 5.26e+04, s_2: 6.41e+04
Subj Offsets KLs:  s_0: 2.74e+04, s_1: 2.48e+04, s_2: 2.95e+04
Subj Direct Mappings KLs:  s_0: 5.37e+03, s_1: 4.97e+03, s_2: 5.66e+03
Device memory allocated:  d_0: 7.46e+09
Device max memory allocated:  d_0: 9.42e+09
Elapsed time: 94.38117504119873
*****************************************************
Epoch 400 complete.  Obj: 8.74e+08, LR: [0.01]
Model NLLs:  s_0: 2.88e+08, s_1: 2.39e+08, s_2: 3.46e+08
Subj P KLs:  s_0: 1.38e+05, s_1: 1.25e+05, s_2: 1.58e+05
Subj U KLs:  s_0: 1.47e+05, s_1: 1.32e+05, s_2: 1.70e+05
Subj Psi KLs:  s_0: 0.00e+00, s_1: 0.00e+00, s_2: 0.00e+00
Subj Scale KLs:  s_0: 6.05e

In [16]:
# Plot the first entry of the log returned by the second (learning rate .01) round of fitting
fitter.plot_log(logs1[0])

<IPython.core.display.Javascript object>

## Move everything to cpu

In [17]:
# Move the fitter to the CPU before the analysis below
fitter.to('cpu')

## Look at predictions the models make on training data

In [18]:
# One entry per subject, computed on the data stored in that subject's collection;
# each entry is indexed with 'truth' and 'pred' when plotting below
s_preds = [predict_with_truth(s_coll, s_coll.data) for s_coll in vi_collections]

In [19]:
plt_s_i = 2  # Subject to plot
plot_v_i = 3  # Output variable (column) to plot
smp_inds = slice(0, 100)  # Samples to plot

# Truth in blue, prediction in red
subj_truth = s_preds[plt_s_i]['truth'][0]
subj_pred = s_preds[plt_s_i]['pred'][0]

plt.figure()
plt.plot(subj_truth[smp_inds, plot_v_i], 'b-')
plt.plot(subj_pred[smp_inds, plot_v_i], 'r-')

<IPython.core.display.Javascript object>

[<matplotlib.lines.Line2D at 0x2b988f3f7da0>]

## Look at true and fit offset and scale distributions

In [53]:
# Select which parameter to visualize by leaving exactly one pair of assignments
# below uncommented. The fit distribution is taken from the fitted prior.

# Offsets
#true_dist = true_offset_dist 
#fit_dist = offsets_prior.dists[0] 

# Scales
#true_dist = true_scale_dist
#fit_dist = scales_prior.dists[0]

# Direct mappings
true_dist = true_direct_map_dist
fit_dist = direct_mappings_prior.dists[0]

In [55]:
plt.figure()
# Left panel: mean function of the true distribution
visualize_2d_function(torch_mod_to_fcn(true_dist.mn_f), ax=plt.subplot(1,2, 1))
#plt.gca().get_images()[0].set_clim(0.0, 10.0)
# Right panel: mean function of the fit distribution, with both dimension ranges set to [0, .99]
visualize_2d_function(torch_mod_to_fcn(fit_dist.mn_f), ax=plt.subplot(1,2, 2), 
                      dim_0_range=[0, .99], dim_1_range=[0, .99])
#plt.gca().get_images()[0].set_clim(0.0, 10.0)

<IPython.core.display.Javascript object>

## Look at true and fit offset values compared to posteriors on a single neuron basis

In [37]:
vis_s_i = 2  # Subject to visualize
e_shape = [21, 21, 1]  # Ellipse shape passed to generate_dot_image_3d below

# NOTE(review): both the Offsets and the Scales assignments below are active, so the
# Offsets values are immediately overwritten and the Scales are what is visualized.
# Comment out the pair that is not wanted.

# Offsets
true_vls = true_subj_models[vis_s_i]['mdl'].offsets[0]
fit_dist = vi_collections[vis_s_i].offset_dists[0]

# Scales
true_vls = true_subj_models[vis_s_i]['mdl'].scales[0]
fit_dist = vi_collections[vis_s_i].scale_dists[0]

# Direct mappings
#true_vls = true_subj_models[vis_s_i]['mdl'].direct_mappings[0]
#fit_dist = vi_collections[vis_s_i].direct_mapping_dists[0]

In [38]:
# NOTE: despite the "offsets" in the variable names, these hold values for whichever
# parameter was selected in the previous cell.
vis_neuron_locs = true_subj_models[vis_s_i]['u_neuron_locs']
vis_true_offsets = true_vls.detach().numpy()
# Evaluate the first distribution of the fit posterior at the neuron locations
vis_fit_offsets = fit_dist.dists[0](vis_neuron_locs).detach().numpy()
# Append a column of zeros as a third coordinate and scale the locations by 1000
vis_neuron_locs = 1000*np.concatenate([vis_neuron_locs.numpy(), np.zeros([vis_neuron_locs.shape[0], 1])], axis=1)

In [39]:
def _neuron_dot_image(dot_vls):
    """Dot image of one value per neuron, drawn at the scaled neuron locations."""
    return generate_dot_image_3d(image_shape=[1001, 1001, 1], dot_ctrs=vis_neuron_locs,
                                 dot_vls=dot_vls, ellipse_shape=e_shape)


true_image = _neuron_dot_image(vis_true_offsets)

fit_image = _neuron_dot_image(vis_fit_offsets)

In [40]:
# True values in the left panel, fit values in the right panel, each with its own colorbar
plt.figure()
for panel_i, panel_image in enumerate((true_image, fit_image), start=1):
    plt.subplot(1, 2, panel_i)
    plt.imshow(np.squeeze(panel_image))
    plt.colorbar()

<IPython.core.display.Javascript object>

<matplotlib.colorbar.Colorbar at 0x2b9892eb5fd0>

## Look at true and estimated distributions over modes

In [30]:
# Index of the mode (column of the mode matrices) to visualize
vis_m = 1

#### Learn a transformation to align modes

In [31]:
# Evaluate the true and the fitted prior distributions over modes at the neuron
# locations of the first subject. Commented-out lines are alternative sources for
# the true or estimated modes.
u_neuron_locs = true_subj_models[0]['u_neuron_locs']
true_u_modes = true_u_dists(u_neuron_locs).detach().cpu().numpy()
est_u_modes = prior_collection.u_dists[0](u_neuron_locs).detach().cpu().numpy()
#est_u_modes = vi_collections[0].u_dists[0](u_neuron_locs).detach().cpu().numpy()
# Append a column of zeros as a third coordinate and scale the locations by 1000
u_neuron_locs = 1000*np.concatenate([u_neuron_locs.numpy(), np.zeros([u_neuron_locs.shape[0], 1])], axis=1)

p_neuron_locs = true_subj_models[0]['p_neuron_locs']
true_p_modes = true_p_dists(p_neuron_locs).detach().cpu().numpy()
#true_p_modes = true_subj_models[0]['mdl'].p[0].detach().cpu().numpy()
est_p_modes = prior_collection.p_dists[0](p_neuron_locs).detach().cpu().numpy()
#est_p_modes = vi_collections[0].p_dists[0](p_neuron_locs).detach().cpu().numpy()
#est_p_modes = vi_collections[0].s_mdl.p[0].detach().cpu().numpy()
# Same treatment for the p neuron locations
p_neuron_locs = 1000*np.concatenate([p_neuron_locs.numpy(), np.zeros([p_neuron_locs.shape[0], 1])], axis=1)

In [32]:
# Least-squares solution T of est_u_modes @ T = true_u_modes; lstsq returns the
# solution as the first element of its result.
mode_t = np.linalg.lstsq(est_u_modes, true_u_modes, rcond=None)[0]

In [33]:
# Align the estimated u modes to the true ones with the learned transformation T
est_u_modes_t = np.matmul(est_u_modes, mode_t)
# The low-rank map u @ p.T must stay unchanged when u is replaced by u @ T:
#   (u @ T) @ p_t.T = u @ p.T   =>   p_t = p @ inv(T).T
# so the p modes are transformed by the transpose of inv(T), not by inv(T) itself.
est_p_modes_t = np.matmul(est_p_modes, np.linalg.inv(mode_t).T)

In [34]:
true_p_modes.shape

(10000, 2)

In [35]:
e_shape = [21, 21, 1]
# Neuron locations were scaled by 1000, so the same image shape is used for both populations
mode_image_shape = [1001, 1001, 1]

# Estimated p modes aligned consistently with the u modes: if u is mapped to u @ T,
# keeping u @ p.T unchanged requires p to be mapped to p @ inv(T).T. The u image
# below shows aligned modes, so the p image must as well for the comparison with
# the true modes to be meaningful.
est_p_modes_aligned = np.matmul(est_p_modes, np.linalg.inv(mode_t).T)

true_u_image = generate_dot_image_3d(image_shape=list(mode_image_shape), dot_ctrs=u_neuron_locs,
                                     dot_vls=true_u_modes[:,vis_m],
                                     ellipse_shape=e_shape)

fit_u_image = generate_dot_image_3d(image_shape=list(mode_image_shape), dot_ctrs=u_neuron_locs,
                                    dot_vls=est_u_modes_t[:,vis_m],
                                    ellipse_shape=e_shape)

true_p_image = generate_dot_image_3d(image_shape=list(mode_image_shape), dot_ctrs=p_neuron_locs,
                                     dot_vls=true_p_modes[:,vis_m],
                                     ellipse_shape=e_shape)

fit_p_image = generate_dot_image_3d(image_shape=list(mode_image_shape), dot_ctrs=p_neuron_locs,
                                    dot_vls=est_p_modes_aligned[:,vis_m],
                                    ellipse_shape=e_shape)

In [44]:
# One figure per population: true mode on the left, estimated mode on the right.
# Each entry is (true image, estimated image, population name, color limits of the estimate).
mode_figs = ((true_u_image, fit_u_image, 'u', [0, 1]),
             (true_p_image, fit_p_image, 'p', [0, 1.4]))

for true_mode_img, est_mode_img, pop_name, est_clim in mode_figs:
    plt.figure()

    plt.subplot(1, 2, 1)
    plt.imshow(np.squeeze(true_mode_img), clim=[0, 1])
    plt.colorbar()
    plt.title(f'True {pop_name} mode')

    plt.subplot(1, 2, 2)
    plt.imshow(np.squeeze(est_mode_img), clim=est_clim)
    plt.title(f'Est {pop_name} mode')
    plt.colorbar()



<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<matplotlib.colorbar.Colorbar at 0x2b98931ff080>