Here we fit models across fish and look at the latent representations of the different conditions for all fish in the same 
low-dimensional space.

In [1]:
# Reload imported modules automatically before each cell runs, so edits to project code are picked up
%load_ext autoreload
%autoreload 2

In [2]:
import copy
from itertools import chain
from pathlib import Path

from matplotlib import cm
import matplotlib.pyplot as plt
import numpy as np
from numpy.linalg import svd
import torch

import sklearn.decomposition as decomposition

from ahrens_wbo.annotations import label_subperiods
from ahrens_wbo.data_processing import load_and_preprocess_data
from janelia_core.ml.utils import list_torch_devices
from janelia_core.visualization.image_generation import max_project_pts
from janelia_core.visualization.image_generation import generate_dot_image_3d
from janelia_core.visualization.image_generation import scalar_3d_max_project
from janelia_core.visualization.matrix_visualization import cmp_n_mats
from probabilistic_model_synthesis.fa import FAMdl
from probabilistic_model_synthesis.fa import Fitter
from probabilistic_model_synthesis.fa import generate_basic_posteriors
from probabilistic_model_synthesis.fa import generate_hypercube_prior_collection
from probabilistic_model_synthesis.fa import generate_simple_prior_collection
from probabilistic_model_synthesis.fa import initialize_basic_posteriors
from probabilistic_model_synthesis.fa import orthonormalize
from probabilistic_model_synthesis.fa import PosteriorCollection
from probabilistic_model_synthesis.fa import VICollection

In [3]:
import warnings
# Reset the warnings filter list, discarding any filters that were installed earlier
warnings.resetwarnings()

In [4]:
# Select the interactive notebook backend for the figures created below
%matplotlib notebook

  and should_run_async(code)


## Parameters go here

In [5]:
# Every user-tunable setting for this notebook lives in this one dictionary
ps = {
    # Directory where all the raw data is found
    'data_dir': r'/groups/bishop/bishoplab/projects/ahrens_wbo/data',

    # Subjects we want to fit (keys) and the conditions we want to fit on for each subject (values)
    'fit_specs': {8: ['omr_forward'],
                  9: ['omr_left'],
                  11: ['omr_right']},

    # Only periods whose 'shock' flag equals this value are used for fitting
    'shock': False,

    # Number of latent variables in the FA models
    'n_latent_vars': 10,
}

## See which devices are available for fitting

In [6]:
# Query the available torch devices; only the first returned value is kept and later passed to the fitters
devices, _ = list_torch_devices()

Found 1 GPUs


## Load the data for each subject

In [7]:
# Subject ids are the keys of the fit specification, in the order they are listed there
subjects = [s_n for s_n in ps['fit_specs']]

In [8]:
# Load and preprocess the recordings for every requested subject; this returns the datasets together
# with the neuron locations that are later passed as `props` to the models
datasets, neuron_locs = load_and_preprocess_data(
    data_folder=ps['data_dir'],
    subjects=subjects,
    neural_gain=10,
)

Done loading data for subject subject_8.
Done loading data for subject subject_9.
Done loading data for subject subject_11.


## Form the fitting data for each subject

In [9]:
# Collect every condition requested for any subject.  The conditions are sorted before labels are
# assigned: iterating a raw set of strings can yield a different order in each Python session
# (string hashing is randomized per process), which would change which integer label -- and
# therefore which plot color -- a condition receives from one kernel restart to the next.
all_subperiods = sorted(set(chain(*ps['fit_specs'].values())))

# Map each condition name to a stable integer label
label_map = {sp: sp_i for sp_i, sp in enumerate(all_subperiods)}

In [10]:
# Build, for each subject, a dict of fitting data and a dict of matching numerical labels, both keyed
# by subperiod (condition) name.
fit_data = dict()
fit_labels = dict()
for s_n, dataset in datasets.items():

    data_n = dataset.ts_data['dff']['vls'][:]

    # Label the subperiods for this subject
    subperiods = label_subperiods(dataset.ts_data['stim']['vls'][:])

    # Down select to only the subperiods we want to fit on for this subject
    subperiods = {k: v for k, v in subperiods.items() if k in ps['fit_specs'][s_n]}

    # Down select to the shock condition we want to fit
    subperiods = {k: [sp_i for sp_i in v if sp_i['shock'] == ps['shock']] for k, v in subperiods.items()}

    # np.concatenate raises an uninformative error on an empty list, so say exactly which
    # condition has no periods left after the shock filter
    empty_sps = [k for k, v in subperiods.items() if len(v) == 0]
    if len(empty_sps) > 0:
        raise ValueError('Subject ' + str(s_n) + ' has no periods with shock == ' + str(ps['shock']) +
                         ' for subperiod(s): ' + str(empty_sps))

    # Pull out the fitting data for this subject
    fit_data[s_n] = {k: np.concatenate([data_n[sl['slice'], :] for sl in v], axis=0)
                     for k, v in subperiods.items()}

    # Generate numerical labels for each data point.  The label count is taken from the rows actually
    # extracted, so labels and data stay aligned even if a slice runs past the end of the recording
    # or has an open-ended start/stop.
    fit_labels[s_n] = {k: np.full(data_k.shape[0], label_map[k], dtype=float)
                       for k, data_k in fit_data[s_n].items()}
    

In [11]:
# Stack each subject's per-condition arrays, and the matching labels, into one array per subject
fit_data_conc = {s_n: np.concatenate(list(sp_data.values()), axis=0) for s_n, sp_data in fit_data.items()}
fit_labels_conc = {s_n: np.concatenate(list(sp_lbls.values()), axis=0) for s_n, sp_lbls in fit_labels.items()}

## Prepare things we will need for eventual plotting

In [12]:
n_subperiods = len(label_map)

# One color per condition, spread evenly over the Set1 colormap.  The denominator is clamped to 1 so
# that fitting on a single condition does not raise a ZeroDivisionError.
subperiod_clrs = {k: cm.Set1(k_i / max(n_subperiods - 1, 1)) for k_i, k in enumerate(label_map.keys())}

# Marker used for each subject in the scatter plots
subject_markers = {8: 'o', 9: 'v', 10: 's', 11: 'X'}

## Setup everything for fitting FA models with shared posteriors

In [13]:
# Priors shared by all subjects; hypercube parameters and initial values are passed through unchanged
sp_priors = generate_hypercube_prior_collection(n_latent_vars=ps['n_latent_vars'],
                                                hc_params={'n_divisions_per_dim': [100, 100, 25],
                                                           'dim_ranges': np.asarray([[0, 990.0],
                                                                                      [0, 610.0],
                                                                                      [0, 350.0]]),
                                                           'n_div_per_hc_side_per_dim': [1, 1, 1]},
                                                min_gaussian_std=.0000001,
                                                lm_mn_init=0.0,
                                                lm_std_init=1.0,
                                                mn_mn_init=0.0,
                                                mn_std_init=1.0,
                                                psi_conc_vl_init=.2,
                                                psi_rate_vl_init=.1,
                                                min_gamma_conc_vl=.001,
                                                min_gamma_rate_vl=.000001)

# One posterior collection per subject, generated in the iteration order of fit_data_conc
sp_posteriors = generate_basic_posteriors(n_obs_vars=[data_n.shape[1] for data_n in fit_data_conc.values()],
                                           n_smps=[data_n.shape[0] for data_n in fit_data_conc.values()],
                                           n_latent_vars=ps['n_latent_vars'])

# Key the posteriors by subject using the same dict that fixed their order above.  Enumerating a
# separately built subject list would silently pair a posterior with the wrong subject's dimensions
# if the two orders ever differed.
sp_posteriors = {s_n: sp_posteriors[s_i] for s_i, s_n in enumerate(fit_data_conc.keys())}

# Empty FA models, one per subject
sp_fit_mdls = {s_n: FAMdl(lm=None, mn=None, psi=None) for s_n in fit_data_conc.keys()}

# Bundle data, neuron locations (as properties), model and posteriors for each subject
sp_vi_collections = [VICollection(data=torch.tensor(fit_data_conc[s_n]),
                                  props=neuron_locs[s_n],
                                  mdl=sp_fit_mdls[s_n],
                                  posteriors=sp_posteriors[s_n])
                     for s_n in fit_data_conc.keys()]


## Tie the posteriors to the priors

In [14]:
# Point each subject's loading-matrix and mean posteriors at the shared prior objects, so all
# subjects use the very same distributions
for coll in sp_vi_collections:
    coll.posteriors.lm_post = sp_priors.lm_prior
    coll.posteriors.mn_post = sp_priors.mn_prior
    # Tying the psi posteriors is left disabled:
    #coll.posteriors.psi_post = sp_priors.psi_prior

## Fit the sp models

In [15]:
# Build the fitter from every subject's VI collection and the shared priors
sp_fitter = Fitter(
    vi_collections=sp_vi_collections,
    priors=sp_priors,
)

In [16]:
# Hand the fitter the devices found earlier; with distribute_data=True the data is presumably placed on
# them as well -- confirm against Fitter.distribute
sp_fitter.distribute(distribute_data=True, devices=devices)
# Fit with all three KL terms retained (none of the skip_*_kl flags is set).
# NOTE(review): if the first argument is the iteration count, the 1200 milestone lies beyond the 1000
# iterations requested and presumably never takes effect -- confirm against Fitter.fit
sp_log = sp_fitter.fit(1000, milestones=[300, 500, 700, 1200], update_int=100, init_lr=.1, 
                       skip_lm_kl=False, skip_mn_kl=False, skip_psi_kl=False)


Obj: 4.82e+08
----------------------------------------
NELL: 2.79e+08, 8.64e+07, 1.16e+08
Latent KL: 3.22e+02, 9.54e+01, 9.67e+01
LM KL: 0.00e+00, 0.00e+00, 0.00e+00
Mn KL: 0.00e+00, 0.00e+00, 0.00e+00
Psi KL: 9.62e+04, 1.11e+05, 1.41e+05
----------------------------------------
LR: 0.1
Elapsed time (secs): 1.3523216247558594
----------------------------------------
CPU cur memory used (GB): 2.67e+01
GPU_0 cur memory used (GB): 3.86e-01, max memory used (GB): 3.86e-01

Obj: 7.66e+07
----------------------------------------
NELL: 4.63e+07, 1.38e+07, 1.63e+07
Latent KL: 2.01e+04, 5.52e+03, 4.91e+03
LM KL: 0.00e+00, 0.00e+00, 0.00e+00
Mn KL: 0.00e+00, 0.00e+00, 0.00e+00
Psi KL: 3.63e+04, 3.93e+04, 4.28e+04
----------------------------------------
LR: 0.1
Elapsed time (secs): 40.873656272888184
----------------------------------------
CPU cur memory used (GB): 2.67e+01
GPU_0 cur memory used (GB): 3.86e-01, max memory used (GB): 3.86e-01

Obj: 8.62e+07
-------------------------------------

In [17]:
# Redistribute with a CPU-only device list -- presumably so the .numpy() calls below work; confirm
sp_fitter.distribute(devices=[torch.device('cpu')])

## Look at the log for fitting the sp model

In [18]:
# Plot the log returned by sp_fitter.fit
sp_fitter.plot_log(sp_log)

[True, True, True, True, True, True]


<IPython.core.display.Javascript object>

## View latents estimated with the sp models

In [19]:
# Stack the neuron locations of all subjects, then evaluate the shared loading-matrix prior at every
# one of them and convert the result to a numpy array
props_conc = torch.cat(list(neuron_locs.values()), dim=0)
sp_prior_lm_conc = sp_priors.lm_prior(props_conc).detach().numpy()

  and should_run_async(code)


In [20]:
# For every subject keep the raw latent estimates and a second copy returned by orthonormalize
sp_latents = dict()
for s_n, posteriors_n in sp_posteriors.items():

    # Latents are read directly from the means of the latent posterior.  (The ip models instead
    # infer them through sklearn's FactorAnalysis.transform; that alternative is not used here.)
    latents_n = posteriors_n.latent_post.mns.detach().numpy()

    # Second return value of orthonormalize, computed from the prior loading matrix evaluated at all
    # neurons; the first return value is not needed
    _, latents_o = orthonormalize(sp_prior_lm_conc, latents_n, unit_len_columns=False)

    sp_latents[s_n] = dict(latents=latents_n, latents_o=latents_o)

In [21]:
# First of the three consecutive latent dimensions to plot
start_dim = 0

plt.figure()
ax = plt.axes(projection='3d')

# All subjects share one 3d axis: color encodes the condition, marker encodes the subject
for s_n in sp_latents.keys():
    plot_latents = sp_latents[s_n]['latents_o']
    plot_labels = fit_labels_conc[s_n]

    for sp, sp_lbl in label_map.items():
        sp_inds = plot_labels == sp_lbl

        # Skip conditions this subject was not fit on, so they add no empty legend entries
        if not np.any(sp_inds):
            continue

        ax.scatter(plot_latents[sp_inds, start_dim+0], plot_latents[sp_inds, start_dim+1],
                   plot_latents[sp_inds, start_dim+2], color=subperiod_clrs[sp],
                   marker=subject_markers[s_n], label='subject ' + str(s_n) + ': ' + str(sp))

# The title used to be set inside the subject loop, so only the last subject's id survived on the
# shared axis; set one title for the whole figure and identify subjects in the legend instead.
ax.set_title('sp latents, dims ' + str(start_dim) + '-' + str(start_dim + 2))
if len(ax.get_legend_handles_labels()[0]) > 0:
    ax.legend()

<IPython.core.display.Javascript object>

## Visualize loading matrices across space for the sp models

In [57]:
# Column of the rotated loading matrix to visualize in the next cell
vis_comp = 2

In [58]:
# Third factor of the reduced SVD of the prior loading matrix (evaluated at all neurons); its
# transpose is applied to every subject's loading matrix before plotting
_, _, o_vis = svd(sp_prior_lm_conc, full_matrices=False)

mode_imgs = dict()
for s_n in subjects:
    neuron_locs_n = neuron_locs[s_n]

    # Loading matrix of this subject at its own neurons, rotated by o_vis
    lm_n = sp_posteriors[s_n].lm_post(neuron_locs_n).detach().numpy() @ o_vis.T

    # Project the chosen component using the first two location coordinates of each neuron
    mode_imgs[s_n], _ = max_project_pts(
        dot_positions=neuron_locs_n.numpy()[:, [0, 1]].astype('float'),
        dot_vls=lm_n[:, vis_comp],
        box_position=np.asarray([[0, 0], [990, 610]]),
        n_divisions=np.asarray([990, 610]),
        dot_dim_width=np.asarray([5, 5]),
    )

# Show the images of all subjects side by side
plt.figure()
cmp_n_mats(list(mode_imgs.values()))

<IPython.core.display.Javascript object>

[<AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>]

## Setup everything for fitting models with individual posteriors

In [36]:
# The ip models start from an independent copy of the fitted sp priors
ip_priors = copy.deepcopy(sp_priors)

# One posterior collection per subject, generated in the iteration order of fit_data_conc
ip_posteriors = generate_basic_posteriors(n_obs_vars=[data_n.shape[1] for data_n in fit_data_conc.values()],
                                           n_smps=[data_n.shape[0] for data_n in fit_data_conc.values()],
                                           n_latent_vars=ps['n_latent_vars'],
                                           mn_opts={'std_lb': .00000001, 'std_ub': 10.0},
                                           lm_opts={'std_lb': .00000001, 'std_ub': 10.0},
                                           psi_opts={'alpha_lb': .1, 'beta_lb': .000001})

# Key the posteriors by subject using the same dict that fixed their order above.  Enumerating a
# separately built subject list would silently pair a posterior with the wrong subject's dimensions
# if the two orders ever differed.
ip_posteriors = {s_n: ip_posteriors[s_i] for s_i, s_n in enumerate(fit_data_conc.keys())}

In [37]:
# Initialize every subject's individual posteriors from the results of the sp fit
for s_n, posteriors in ip_posteriors.items():

    with torch.no_grad():
        # Posterior over the mean vector: copy the sp prior's mean and standard deviation evaluated
        # at this subject's neurons
        mn_prior_mn = sp_priors.mn_prior(neuron_locs[s_n]).squeeze()
        mn_prior_std = sp_priors.mn_prior.std_f(neuron_locs[s_n]).squeeze()

        mn_dist = posteriors.mn_post.dists[0]
        mn_dist.mn_f.f.vl.data = copy.deepcopy(mn_prior_mn)
        mn_dist.std_f.f.set_value(copy.deepcopy(mn_prior_std.numpy()))

        # Posterior over the loading matrix: same procedure, one distribution per latent variable
        for d_i in range(ps['n_latent_vars']):
            lm_prior_dist = sp_priors.lm_prior.dists[d_i]
            lm_prior_mn = lm_prior_dist(neuron_locs[s_n]).squeeze()
            lm_prior_std = lm_prior_dist.std_f(neuron_locs[s_n]).squeeze()

            lm_dist = posteriors.lm_post.dists[d_i]
            lm_dist.mn_f.f.vl.data = copy.deepcopy(lm_prior_mn)
            lm_dist.std_f.f.set_value(copy.deepcopy(lm_prior_std.numpy()))

    # Posterior over the private variances: independent copy of the sp posterior
    posteriors.psi_post = copy.deepcopy(sp_posteriors[s_n].psi_post)

    # Posterior over the latents: independent copy of the sp posterior
    with torch.no_grad():
        posteriors.latent_post = copy.deepcopy(sp_posteriors[s_n].latent_post)

  init_v = np.arctanh(y)


In [38]:
# Empty FA models, one per subject
ip_fit_mdls = {s_n: FAMdl(lm=None, mn=None, psi=None) for s_n in fit_data_conc}

# Bundle data, neuron locations (as properties), model and individual posteriors for each subject
ip_vi_collections = [
    VICollection(
        data=torch.tensor(fit_data_conc[s_n]),
        props=neuron_locs[s_n],
        mdl=ip_fit_mdls[s_n],
        posteriors=ip_posteriors[s_n],
    )
    for s_n in fit_data_conc
]


## Fit the ip models

In [39]:
# Build the fitter from every subject's VI collection and the copied priors
ip_fitter = Fitter(
    vi_collections=ip_vi_collections,
    priors=ip_priors,
)

In [40]:
# Hand the fitter the devices found earlier; with distribute_data=True the data is presumably placed on
# them as well -- confirm against Fitter.distribute
ip_fitter.distribute(distribute_data=True, devices=devices)
# Fit with all three KL terms retained (none of the skip_*_kl flags is set); the initial learning rate
# here (.01) is smaller than the one used for the sp fit (.1)
ip_log = ip_fitter.fit(1000, milestones=[300, 500, 700], update_int=100, init_lr=.01, 
                       skip_lm_kl=False, skip_mn_kl=False, skip_psi_kl=False)


Obj: 8.78e+07
----------------------------------------
NELL: 4.41e+07, 1.28e+07, 1.45e+07
Latent KL: 2.14e+04, 6.91e+03, 7.02e+03
LM KL: 8.01e+03, 1.00e+04, 1.45e+04
Mn KL: 4.80e+05, 4.94e+05, 1.51e+07
Psi KL: 4.13e+04, 4.09e+04, 4.77e+04
----------------------------------------
LR: 0.01
Elapsed time (secs): 0.4161238670349121
----------------------------------------
CPU cur memory used (GB): 2.74e+01
GPU_0 cur memory used (GB): 4.89e-01, max memory used (GB): 4.89e-01

Obj: 6.89e+07
----------------------------------------
NELL: 4.22e+07, 1.21e+07, 1.34e+07
Latent KL: 2.39e+04, 7.59e+03, 8.07e+03
LM KL: 4.13e+05, 1.93e+05, 2.61e+05
Mn KL: 8.09e+04, 7.09e+04, 9.39e+04
Psi KL: 3.45e+04, 3.59e+04, 3.99e+04
----------------------------------------
LR: 0.01
Elapsed time (secs): 36.85405611991882
----------------------------------------
CPU cur memory used (GB): 2.74e+01
GPU_0 cur memory used (GB): 4.89e-01, max memory used (GB): 4.89e-01

Obj: 6.78e+07
------------------------------------

In [41]:
# Redistribute with a CPU-only device list -- presumably so the .numpy() calls below work; confirm
ip_fitter.distribute(devices=[torch.device('cpu')])

## View fitting log for ip models

In [48]:
# Plot the log returned by ip_fitter.fit
ip_fitter.plot_log(ip_log)

[True, True, True, True, True, True]


<IPython.core.display.Javascript object>

## View latents estimated with the ip models

In [49]:
# Evaluate the ip loading-matrix prior at the stacked neuron locations of all subjects
ip_prior_lm_conc = ip_priors.lm_prior(props_conc).detach().numpy()

  and should_run_async(code)


In [50]:
# For every subject keep the inferred latents and a second copy returned by orthonormalize
ip_latents = dict()
for s_n, posteriors_n in ip_posteriors.items():

    # Infer the latents with sklearn: populate an unfitted FactorAnalysis object with this subject's
    # posterior mean, private variances (mode of the psi posterior) and loading matrix, then call
    # transform on the fitting data.  An earlier read of posteriors_n.latent_post.mns was overwritten
    # by this result before it was ever used, so that dead assignment has been removed.
    mdl = decomposition.FactorAnalysis()
    mdl.mean_ = posteriors_n.mn_post(neuron_locs[s_n]).detach().numpy().squeeze()
    mdl.noise_variance_ = posteriors_n.psi_post.mode(neuron_locs[s_n]).detach().numpy().squeeze()
    mdl.components_ = posteriors_n.lm_post(neuron_locs[s_n]).detach().numpy().transpose()
    latents_n = mdl.transform(fit_data_conc[s_n])

    # Second return value of orthonormalize, computed from the ip prior loading matrix evaluated at
    # all neurons; the first return value is not needed
    _, latents_o = orthonormalize(ip_prior_lm_conc, latents_n, unit_len_columns=False)

    ip_latents[s_n] = {'latents': latents_n, 'latents_o': latents_o}

In [51]:
# First of the three consecutive latent dimensions to plot
start_dim = 0

plt.figure()
ax = plt.axes(projection='3d')

# All subjects share one 3d axis: color encodes the condition, marker encodes the subject.
# Iterate over ip_latents itself; the loop previously took its keys from sp_latents, which made this
# cell depend on the sp cell having been run with the same subjects.
for s_n, latents_n in ip_latents.items():
    plot_latents = latents_n['latents_o']
    plot_labels = fit_labels_conc[s_n]

    for sp, sp_lbl in label_map.items():
        sp_inds = plot_labels == sp_lbl

        # Skip conditions this subject was not fit on, so they add no empty legend entries
        if not np.any(sp_inds):
            continue

        ax.scatter(plot_latents[sp_inds, start_dim+0], plot_latents[sp_inds, start_dim+1],
                   plot_latents[sp_inds, start_dim+2], color=subperiod_clrs[sp],
                   marker=subject_markers[s_n], label='subject ' + str(s_n) + ': ' + str(sp))

# The title used to be set inside the subject loop, so only the last subject's id survived on the
# shared axis; set one title for the whole figure and identify subjects in the legend instead.
ax.set_title('ip latents, dims ' + str(start_dim) + '-' + str(start_dim + 2))
if len(ax.get_legend_handles_labels()[0]) > 0:
    ax.legend()

  and should_run_async(code)


<IPython.core.display.Javascript object>

## Visualize loading matrices across space for the ip models

In [55]:
# Column of the rotated loading matrix to visualize in the next cell
vis_comp = 2

In [66]:
# Third factor of the reduced SVD of the ip prior loading matrix (evaluated at all neurons); its
# transpose is applied to every subject's loading matrix before plotting
_, _, o_vis = svd(ip_prior_lm_conc, full_matrices=False)

mode_imgs = dict()
for s_n in subjects:
    neuron_locs_n = neuron_locs[s_n]

    # NOTE(review): this evaluates the ip *prior* at the subject's neurons, whereas the sp cell uses
    # each subject's posterior (sp_posteriors[s_n].lm_post) -- confirm the prior is what should be
    # shown for the ip models
    lm_n = ip_priors.lm_prior(neuron_locs_n).detach().numpy() @ o_vis.T

    # Project the chosen component using the first two location coordinates of each neuron
    mode_imgs[s_n], _ = max_project_pts(
        dot_positions=neuron_locs_n.numpy()[:, [0, 1]].astype('float'),
        dot_vls=lm_n[:, vis_comp],
        box_position=np.asarray([[0, 0], [990, 610]]),
        n_divisions=np.asarray([990, 610]),
        dot_dim_width=np.asarray([5, 5]),
    )

# Show the images of all subjects side by side
plt.figure()
cmp_n_mats(list(mode_imgs.values()))

<IPython.core.display.Javascript object>

[<AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>]