Here we fit models across fish, looking at the latent representations of different conditions for all fish in the same 
low-d space.  We first fit models with shared posteriors, use the results to initialize models with individual posteriors, and then use early stopping to prevent overfitting when fitting with the individual posteriors.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import copy
from itertools import chain
from pathlib import Path

from matplotlib import cm
import matplotlib.pyplot as plt
import numpy as np
from numpy.linalg import svd
import torch

import sklearn.decomposition as decomposition

from ahrens_wbo.annotations import label_subperiods
from ahrens_wbo.data_processing import load_and_preprocess_data
from janelia_core.ml.utils import list_torch_devices
from janelia_core.visualization.image_generation import max_project_pts
from janelia_core.visualization.image_generation import generate_dot_image_3d
from janelia_core.visualization.image_generation import scalar_3d_max_project
from janelia_core.visualization.matrix_visualization import cmp_n_mats
from probabilistic_model_synthesis.fa import FAMdl
from probabilistic_model_synthesis.fa import Fitter
from probabilistic_model_synthesis.fa import evaluate_check_points
from probabilistic_model_synthesis.fa import generate_basic_posteriors
from probabilistic_model_synthesis.fa import generate_hypercube_prior_collection
from probabilistic_model_synthesis.fa import generate_simple_prior_collection
from probabilistic_model_synthesis.fa import infer_latents
from probabilistic_model_synthesis.fa import initialize_basic_posteriors
from probabilistic_model_synthesis.fa import orthonormalize
from probabilistic_model_synthesis.fa import PosteriorCollection
from probabilistic_model_synthesis.fa import PriorCollection
from probabilistic_model_synthesis.fa import VICollection

In [3]:
%matplotlib notebook

## Parameters go here

In [4]:
# All user-settable parameters for this notebook live in this one dictionary
ps = {
    # Directory where all the raw data is found
    'data_dir': r'/groups/bishop/bishoplab/projects/ahrens_wbo/data',

    # Subjects we want to fit (keys) and the conditions we want to fit on for each subject (values)
    'fit_specs': {8: ['omr_forward'],
                  9: ['omr_left'],
                  11: ['omr_right']},

    # Whether we fit in conditions with shock (True) or without shock (False)
    'shock': False,

    # Number of latent variables in the FA models
    'n_latent_vars': 10,

    # Number of individual slices of each type of subperiod we use for validation
    'n_validation_slices': 2,

    # True if we should use early stopping
    'use_early_stopping': True,

    # Folder where checkpoints will be saved
    'cp_save_folder': r'/groups/bishop/bishoplab/projects/probabilistic_model_synthesis/results/testing',
}

## See which devices are available for fitting

In [5]:
# Query the torch devices we can fit on; the second return value is not used in this notebook
devices, _ = list_torch_devices()

Found 2 GPUs


## Load the data for each subject

In [6]:
# The subjects to load are exactly the keys of the fit specifications, in the order listed there
subjects = list(ps['fit_specs'])

In [7]:
# Load and preprocess the data for the requested subjects.  Both returned objects are indexed
# by subject number in the cells below.
# NOTE(review): the meaning of neural_gain=10 is not visible here - confirm against
# load_and_preprocess_data
datasets, neuron_locs = load_and_preprocess_data(data_folder=ps['data_dir'], 
                                                 subjects=subjects, 
                                                 neural_gain=10)

Done loading data for subject subject_8.
Done loading data for subject subject_9.
Done loading data for subject subject_11.


## Form the fitting and validation data for each subject

In [8]:
# Every subperiod type that at least one subject is fit on
all_subperiods = set(chain(*[v for v in ps['fit_specs'].values()]))

# Give each subperiod type an integer label.  We enumerate the *sorted* names because the
# iteration order of a set of strings can differ from one Python session to the next (string
# hash randomization), which would silently change the labels - and the plot colors derived
# from them - between runs of this notebook.
label_map = {sp: sp_i for sp_i, sp in enumerate(sorted(all_subperiods))}

In [9]:
# Per-subject dictionaries; each maps subperiod type -> array (data or numerical labels)
fit_data, validation_data = {}, {}
fit_labels, validation_labels = {}, {}


def _stack_slices(all_data, slice_specs):
    """Concatenate, along the first axis, the rows of all_data selected by each spec's 'slice'."""
    return np.concatenate([all_data[spec['slice'], :] for spec in slice_specs], axis=0)


def _constant_labels(label_vl, slice_specs):
    """Return label_vl repeated once per sample, counting stop - start samples for each slice."""
    n_smps = np.sum([spec['slice'].stop - spec['slice'].start for spec in slice_specs])
    return label_vl*np.ones(n_smps)


for s_n, dataset in datasets.items():

    data_n = dataset.ts_data['dff']['vls'][:]

    # Label the subperiods for this subject, keeping only the subperiod types we want to fit on
    # for this subject and, within each type, only the slices with the requested shock condition
    subperiods = {sp_key: [sl for sl in sp_slices if sl['shock'] == ps['shock']]
                  for sp_key, sp_slices in label_subperiods(dataset.ts_data['stim']['vls'][:]).items()
                  if sp_key in ps['fit_specs'][s_n]}

    # Randomly hold out slices of each subperiod type for validation; the rest are used for fitting
    fit_subperiods, validation_subperiods = {}, {}
    for sp_key, sp_slices in subperiods.items():
        validation_inds = np.random.choice(len(sp_slices), ps['n_validation_slices'], replace=False)
        held_out = set(validation_inds.tolist())

        validation_subperiods[sp_key] = [sl for s_i, sl in enumerate(sp_slices) if s_i in held_out]
        fit_subperiods[sp_key] = [sl for s_i, sl in enumerate(sp_slices) if s_i not in held_out]

    # Pull out the fitting and validation data for this subject
    fit_data[s_n] = {k: _stack_slices(data_n, v) for k, v in fit_subperiods.items()}
    validation_data[s_n] = {k: _stack_slices(data_n, v) for k, v in validation_subperiods.items()}

    # Generate numerical labels for each data point
    fit_labels[s_n] = {k: _constant_labels(label_map[k], v) for k, v in fit_subperiods.items()}
    validation_labels[s_n] = {k: _constant_labels(label_map[k], v)
                              for k, v in validation_subperiods.items()}
    

In [10]:
def _concat_over_subperiods(per_subject):
    """For each subject, concatenate the arrays of all its subperiod types along the first axis."""
    return {s_n: np.concatenate(list(by_subperiod.values()), axis=0)
            for s_n, by_subperiod in per_subject.items()}


# Collapse the subject -> subperiod -> array dictionaries into subject -> array dictionaries
fit_data_conc = _concat_over_subperiods(fit_data)
fit_labels_conc = _concat_over_subperiods(fit_labels)

validation_data_conc = _concat_over_subperiods(validation_data)
validation_labels_conc = _concat_over_subperiods(validation_labels)

## Prepare things we will need for eventual plotting

In [11]:
n_subperiods = len(label_map)

# One color per subperiod type, spread evenly over the Set1 colormap.  The max(..., 1) guards
# against a division by zero when only a single subperiod type is being fit; for two or more
# types the colors are unchanged.
subperiod_clrs = {k: cm.Set1(k_i/max(n_subperiods - 1, 1)) for k_i, k in enumerate(label_map.keys())}

# Marker used when plotting each subject (keyed by subject number)
subject_markers = {8: 'o', 9: 'v', 10: 's', 11:'X'}

## Setup everything for fitting FA models with shared posteriors

In [12]:
# Priors shared by all subjects
sp_priors = generate_hypercube_prior_collection(n_latent_vars=ps['n_latent_vars'], 
                                                hc_params={'n_divisions_per_dim': [100, 60, 35], 
                                                           'dim_ranges': np.asarray([[0, 990.0], 
                                                                                      [0, 610.0], 
                                                                                      [0, 350.0]]),
                                                           'n_div_per_hc_side_per_dim': [1, 1, 1]}, 
                                                min_gaussian_std=.0000001,
                                                lm_mn_init=0.0,
                                                lm_std_init=1.0, 
                                                mn_mn_init=0.0, 
                                                mn_std_init=1.0, 
                                                psi_conc_vl_init=.2,
                                                psi_rate_vl_init=.1,
                                                min_gamma_conc_vl=.001,
                                                min_gamma_rate_vl=.000001)

# One set of posteriors per subject, sized from that subject's fitting data
sp_posteriors = generate_basic_posteriors(n_obs_vars=[data_n.shape[1] for data_n in fit_data_conc.values()],
                                           n_smps=[data_n.shape[0] for data_n in fit_data_conc.values()],
                                           n_latent_vars=ps['n_latent_vars'])

# Key the posteriors by subject using the same iteration order (that of fit_data_conc) that was
# used to generate them, so each subject always receives the posteriors sized for its own data
sp_posteriors = {s_n: sp_posteriors[s_i] for s_i, s_n in enumerate(fit_data_conc.keys())}

# One (empty) FA model per subject
sp_fit_mdls = {s_n: FAMdl(lm=None, mn=None, psi=None) for s_n in fit_data_conc}

# Bundle data, neuron properties, model and posteriors for each subject
sp_vi_collections = [VICollection(data=torch.tensor(fit_data_conc[s_n]), 
                                  props=neuron_locs[s_n],
                                  mdl=sp_fit_mdls[s_n],
                                  posteriors=sp_posteriors[s_n])
                     for s_n in fit_data_conc.keys()]


## Tie the posteriors to the priors

In [13]:
# Point every subject's loading-matrix and mean posteriors at the prior objects themselves, so
# all subjects share these distributions
for vi_coll in sp_vi_collections:
    tied_posts = vi_coll.posteriors
    tied_posts.lm_post = sp_priors.lm_prior
    tied_posts.mn_post = sp_priors.mn_prior

## Fit the sp models

In [14]:
# Create the fitter for the shared-posterior (sp) models from their VI collections and the priors
sp_fitter = Fitter(vi_collections=sp_vi_collections, priors=sp_priors)

In [15]:
# Distribute across the available devices (distribute_data=True presumably moves the data too)
sp_fitter.distribute(distribute_data=True, devices=devices)
# NOTE(review): presumably the first argument is the number of epochs and milestones are
# learning-rate milestones; if so, none of the milestones below is reached - confirm this is intended
sp_log = sp_fitter.fit(100, milestones=[300, 500, 700, 1200], update_int=100, init_lr=.1, 
                       skip_lm_kl=False, skip_mn_kl=False, skip_psi_kl=False)
# Redistribute onto the CPU only, now that fitting is done
sp_fitter.distribute(devices=[torch.device('cpu')])


Obj: 4.06e+08
----------------------------------------
NELL: 2.52e+08, 6.85e+07, 8.49e+07
Latent KL: 2.83e+02, 6.94e+01, 7.43e+01
LM KL: 0.00e+00, 0.00e+00, 0.00e+00
Mn KL: 0.00e+00, 0.00e+00, 0.00e+00
Psi KL: 9.62e+04, 1.11e+05, 1.41e+05
----------------------------------------
LR: 0.1
Elapsed time (secs): 2.569316864013672
----------------------------------------
CPU cur memory used (GB): 2.74e+01
GPU_0 cur memory used (GB): 2.52e-01, max memory used (GB): 2.52e-01
GPU_1 cur memory used (GB): 3.89e-02, max memory used (GB): 3.89e-02


## Infer latents on validation data with the sp models

In [16]:
# For each subject, infer latents for the held-out validation data with that subject's fit sp
# model, keeping only the first value returned by infer_latents
sp_validation_latent_posts = {}
for s_i, s_n in enumerate(subjects):
    inferred = infer_latents(vi_collection=sp_vi_collections[s_i],
                             data=torch.tensor(validation_data_conc[s_n]),
                             fit_opts=dict(n_epochs=1000, init_lr=.1, update_int=None, milestones=[500]),
                             device=devices[0])
    sp_validation_latent_posts[s_n] = inferred[0]

## Look at the log for fitting the sp model

In [17]:
# Plot the log returned by the sp fit
sp_fitter.plot_log(sp_log)

[True, True, True, True, True, True]


<IPython.core.display.Javascript object>

## View latents estimated with the sp models

In [18]:
# Stack the neuron locations of all subjects and evaluate the sp loading-matrix prior at them
props_conc = torch.cat(list(neuron_locs.values()), dim=0)
sp_prior_lm_conc = sp_priors.lm_prior(props_conc).detach().numpy()

# True if we should view validation latents
view_validaton = False

In [19]:
# Pick which latent posteriors to view: those inferred on validation data or those from fitting
sp_view_latents_post = (sp_validation_latent_posts if view_validaton
                        else {s_n: posts.latent_post for s_n, posts in sp_posteriors.items()})

# For each subject, keep the posterior means of the latents both as estimated and as returned by
# orthonormalize when given the concatenated sp prior loading matrix
sp_latents = {}
for s_n, latent_post_n in sp_view_latents_post.items():
    latents_n = latent_post_n.mns.detach().numpy()
    _, latents_o = orthonormalize(sp_prior_lm_conc, latents_n, unit_len_columns=False)
    sp_latents[s_n] = dict(latents=latents_n, latents_o=latents_o)

In [20]:
# First of the three consecutive latent dimensions to plot
start_dim = 0

plt.figure()
ax = plt.axes(projection='3d')

# Labels must match the data the viewed latents were estimated from
labels_by_subject = validation_labels_conc if view_validaton else fit_labels_conc

for s_n, latents_n in sp_latents.items():
    plot_latents = latents_n['latents_o']
    plot_labels = labels_by_subject[s_n]

    # One scatter call per subperiod type: color encodes subperiod, marker encodes subject
    for sp, sp_lbl in label_map.items():
        sp_inds = plot_labels == sp_lbl
        x, y, z = (plot_latents[sp_inds, start_dim + d] for d in range(3))
        ax.scatter(x, y, z, color=subperiod_clrs[sp], marker=subject_markers[s_n])
    plt.title(str(s_n))

<IPython.core.display.Javascript object>

## Visualize loading matrices across space for the sp models

In [22]:
# Index of the column of the rotated loading matrices to visualize
vis_comp = 2

In [23]:
# Right singular vectors of the concatenated sp prior loading matrix; used to rotate each
# subject's loading matrix before viewing
o_vis = svd(sp_prior_lm_conc, full_matrices=False)[2]

mode_imgs = {}
for s_n in subjects:
    neuron_locs_n = neuron_locs[s_n]

    # Loading matrix for this subject, rotated by o_vis
    lm_n = sp_posteriors[s_n].lm_post(neuron_locs_n).detach().numpy() @ o_vis.T

    # Max project the chosen column, using the first two coordinates of each neuron location
    xy_positions = neuron_locs_n.numpy()[:, [0, 1]].astype('float')
    mode_imgs[s_n], _ = max_project_pts(dot_positions=xy_positions,
                                        dot_vls=lm_n[:, vis_comp],
                                        box_position=np.asarray([[0, 0], [990, 610]]),
                                        n_divisions=np.asarray([990, 610]),
                                        dot_dim_width=np.asarray([5, 5]))

plt.figure()
cmp_n_mats(list(mode_imgs.values()))

<IPython.core.display.Javascript object>

[<AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>]

## Setup everything for fitting models with individual posteriors

In [None]:
# Start the ip priors from a copy of the fit sp priors, so fitting the ip models leaves the sp
# priors untouched
ip_priors = copy.deepcopy(sp_priors)

# One set of posteriors per subject, sized from that subject's fitting data
ip_posteriors = generate_basic_posteriors(n_obs_vars=[data_n.shape[1] for data_n in fit_data_conc.values()],
                                           n_smps=[data_n.shape[0] for data_n in fit_data_conc.values()],
                                           n_latent_vars=ps['n_latent_vars'], 
                                           mn_opts={'std_lb': .00000001, 'std_ub': 10.0}, 
                                           lm_opts={'std_lb': .00000001, 'std_ub': 10.0},
                                           psi_opts={'alpha_lb': .1, 'beta_lb': .000001})

# Key the posteriors by subject using the same iteration order (that of fit_data_conc) that was
# used to generate them, so each subject always receives the posteriors sized for its own data
ip_posteriors = {s_n: ip_posteriors[s_i] for s_i, s_n in enumerate(fit_data_conc.keys())}

In [None]:
# Initialize each subject's individual posteriors from the results of the sp fit
for s_n, posteriors in ip_posteriors.items():
    
    # Initialize the posteriors for the mean vectors from the sp prior over means, evaluated at
    # this subject's neuron locations
    with torch.no_grad():
        mn_prior_mn = sp_priors.mn_prior(neuron_locs[s_n]).squeeze()
        mn_prior_std = sp_priors.mn_prior.std_f(neuron_locs[s_n]).squeeze()
    
        posteriors.mn_post.dists[0].mn_f.f.vl.data = copy.deepcopy(mn_prior_mn)
        posteriors.mn_post.dists[0].std_f.f.set_value(copy.deepcopy(mn_prior_std.numpy()))
        
    # Initialize the posteriors for the loading matrices, one latent dimension at a time, from the
    # corresponding sp prior distribution evaluated at this subject's neuron locations
    with torch.no_grad():
        for d_i in range(ps['n_latent_vars']):
            lm_prior_mn = sp_priors.lm_prior.dists[d_i](neuron_locs[s_n]).squeeze()
            lm_prior_std = sp_priors.lm_prior.dists[d_i].std_f(neuron_locs[s_n]).squeeze()
            
            posteriors.lm_post.dists[d_i].mn_f.f.vl.data = copy.deepcopy(lm_prior_mn)
            posteriors.lm_post.dists[d_i].std_f.f.set_value(copy.deepcopy(lm_prior_std.numpy()))
        
    # Initialize the posteriors for the private variances with a copy of the sp posteriors
    posteriors.psi_post = copy.deepcopy(sp_posteriors[s_n].psi_post)
    
    # Initialize the posteriors for the latents with a copy of the sp posteriors
    with torch.no_grad():
        posteriors.latent_post = copy.deepcopy(sp_posteriors[s_n].latent_post)

In [None]:
# One (empty) FA model per subject
ip_fit_mdls = {s_n: FAMdl(lm=None, mn=None, psi=None) for s_n in fit_data_conc}

# Bundle data, neuron properties, model and individual posteriors for each subject
ip_vi_collections = [VICollection(data=torch.tensor(data_n),
                                  props=neuron_locs[s_n],
                                  mdl=ip_fit_mdls[s_n],
                                  posteriors=ip_posteriors[s_n])
                     for s_n, data_n in fit_data_conc.items()]


## Fit the ip models

In [None]:
# Create the fitter for the individual-posterior (ip) models from their VI collections and priors
ip_fitter = Fitter(vi_collections=ip_vi_collections, priors=ip_priors)

In [None]:
# Distribute across the available devices (distribute_data=True presumably moves the data too)
ip_fitter.distribute(distribute_data=True, devices=devices)
# Fit while saving check points to ps['cp_save_folder'] at the epochs listed in cp_epochs
# (every 10th value from 0 to 990).
# NOTE(review): presumably the first argument is the number of epochs; if so, the milestones and
# the later cp_epochs lie beyond the end of this run - confirm this is intended
ip_log = ip_fitter.fit(200, milestones=[300, 500, 700], update_int=100, init_lr=0.001,
                       skip_lm_kl=False, skip_mn_kl=False, skip_psi_kl=False, 
                       cp_save_folder=ps['cp_save_folder'], cp_epochs=np.arange(0, 1000, 10))

## Evaluate performance across check points on validation data

In [None]:
# Evaluate every saved check point on the validation data.  The data and props lists are both
# built from the key order of validation_data_conc, so the i-th entry of each list always
# belongs to the same subject (previously props followed the key order of neuron_locs, which
# is only guaranteed to match if the two dictionaries happen to be ordered identically).
cp_epochs, cp_elbo = evaluate_check_points(cp_folder=ps['cp_save_folder'], 
                              data=[torch.tensor(validation_data_conc[k]) for k in validation_data_conc.keys()], 
                              props=[neuron_locs[k] for k in validation_data_conc.keys()], 
                              n_smps=100, device=devices[0])

## View validation performance across checkpoints

In [None]:
# Plot the check point values in cp_elbo, averaged over axis 1 (presumably over subjects -
# confirm against evaluate_check_points), against check point epoch
plt.figure()
plt.plot(cp_epochs, np.mean(cp_elbo, axis=1))

## Pick the best ip model based on early stopping

In [None]:
if not ps['use_early_stopping']:
    # No early stopping: evaluate with the models exactly as they were left by fitting
    ip_eval_vi_collections = ip_vi_collections
    ip_eval_priors = ip_priors
else:
    # Early stopping: use the check point with the largest mean value in cp_elbo
    best_cp_ind = np.argmax(np.mean(cp_elbo, axis=1))
    best_cp_epoch = cp_epochs[best_cp_ind].astype('int')
    print(f'Best epoch is {best_cp_epoch}.')

    best_cp = torch.load(Path(ps['cp_save_folder']) / f'cp_{best_cp_epoch}.pt')

    # Rebuild the VI collections from the check point; their props are set from the
    # collections we fit with
    ip_eval_vi_collections = [VICollection.from_checkpoint(coll_d)
                              for coll_d in best_cp['vi_collections']]
    for c_i, coll in enumerate(ip_eval_vi_collections):
        coll.props = ip_vi_collections[c_i].props

    ip_eval_priors = PriorCollection.from_checkpoint(best_cp['priors'])

# Posteriors of the models we evaluate with, keyed by subject
ip_eval_posteriors = {s_n: ip_eval_vi_collections[s_i].posteriors for s_i, s_n in enumerate(subjects)}

## Infer latents with the ip models for the validation data

In [None]:
# For each subject, infer latents for the held-out validation data with that subject's selected
# ip model, keeping only the first value returned by infer_latents
ip_validation_latent_posts = {}
for s_i, s_n in enumerate(subjects):
    inferred = infer_latents(vi_collection=ip_eval_vi_collections[s_i],
                             data=torch.tensor(validation_data_conc[s_n]),
                             fit_opts=dict(n_epochs=1000, init_lr=.1, update_int=None, milestones=[500]),
                             device=devices[0])
    ip_validation_latent_posts[s_n] = inferred[0]

In [None]:
# Redistribute the ip fitter onto the CPU only
ip_fitter.distribute(devices=[torch.device('cpu')])

## View fitting log for ip models

In [None]:
# Plot the log returned by the ip fit
ip_fitter.plot_log(ip_log)

## View latents estimated with the ip models

In [None]:
# Evaluate the loading-matrix prior of the selected ip priors at the concatenated neuron
# locations of all subjects, as a numpy array
ip_prior_lm_conc = ip_eval_priors.lm_prior(props_conc).detach().numpy()

In [None]:
# True if we should view validation latents; False to view the latents estimated during fitting
view_validaton = False

In [None]:
# Pick which latent posteriors to view: those inferred on validation data or those from fitting
ip_view_latents_post = (ip_validation_latent_posts if view_validaton
                        else {s_n: posts.latent_post for s_n, posts in ip_eval_posteriors.items()})

# For each subject, keep the posterior means of the latents both as estimated and as returned by
# orthonormalize when given the concatenated ip prior loading matrix
ip_latents = {}
for s_n, latent_post_n in ip_view_latents_post.items():
    latents_n = latent_post_n.mns.detach().numpy()
    _, latents_o = orthonormalize(ip_prior_lm_conc, latents_n, unit_len_columns=False)
    ip_latents[s_n] = dict(latents=latents_n, latents_o=latents_o)

In [None]:
# First of the three consecutive latent dimensions to plot
start_dim = 0

plt.figure()
ax = plt.axes(projection='3d')

# Loop over ip_latents (not sp_latents) so that this cell depends only on the ip results and
# plots exactly the subjects for which ip latents were computed
for s_i, s_n in enumerate(ip_latents.keys()):
    plot_latents = ip_latents[s_n]['latents_o']

    # Labels must match the data the viewed latents were estimated from
    if view_validaton:
        plot_labels = validation_labels_conc[s_n]
    else:
        plot_labels = fit_labels_conc[s_n]
    
    # One scatter call per subperiod type: color encodes subperiod, marker encodes subject
    for sp, sp_lbl in label_map.items():
        sp_clr = subperiod_clrs[sp]
        sp_inds = plot_labels == sp_lbl
        ax.scatter(plot_latents[sp_inds,start_dim+0], plot_latents[sp_inds,start_dim+1], 
                   plot_latents[sp_inds,start_dim+2], color=sp_clr, marker=subject_markers[s_n])
    plt.title(str(s_n))

## Visualize loading matrices across space for the ip models

In [None]:
# Index of the column of the rotated loading matrices to visualize
vis_comp = 2

In [None]:
# Right singular vectors of the concatenated ip prior loading matrix; used to rotate each
# subject's loading matrix before viewing
o_vis = svd(ip_prior_lm_conc, full_matrices=False)[2]

mode_imgs = {}
for s_n in subjects:
    neuron_locs_n = neuron_locs[s_n]

    # Loading matrix for this subject, rotated by o_vis
    lm_n = ip_eval_posteriors[s_n].lm_post(neuron_locs_n).detach().numpy() @ o_vis.T

    # Max project the chosen column, using the first two coordinates of each neuron location
    xy_positions = neuron_locs_n.numpy()[:, [0, 1]].astype('float')
    mode_imgs[s_n], _ = max_project_pts(dot_positions=xy_positions,
                                        dot_vls=lm_n[:, vis_comp],
                                        box_position=np.asarray([[0, 0], [990, 610]]),
                                        n_divisions=np.asarray([990, 610]),
                                        dot_dim_width=np.asarray([5, 5]))

plt.figure()
cmp_n_mats(list(mode_imgs.values()))