In [None]:
from numpyro import optim
import numpyro

import jax.numpy as jnp
import jax
import yaml
import os
import sys


# Put the parent directory on the module search path so that the project-local
# `loader` module can be imported (presumably it lives one level up - confirm).
sys.path.insert(1,'..')
import loader

# Reload edited modules automatically before each cell execution, so changes
# to `loader` and other imported project modules are picked up without a
# kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:

# Configuration handed to loader.IBLDataLoader. The meaning of the individual
# keys is defined by the loader, which is not visible from this notebook.
dataset_params = {
    # Data root. Defaults to the original cluster location; set the
    # IBL_DATA_DIR environment variable to run on another machine without
    # editing the notebook.
    'file': os.environ.get('IBL_DATA_DIR', '/mnt/home/anejatbakhsh/ceph/ibl'),
    'tag': '2022_Q2_IBL_et_al_RepeatedSite',
    'probe': 'probe00',
    'sessions': [0,1,2],
    'areas': ['CA1','DG','LP','PO','VISa'],
    'props':{'train':.6,'test':.2,'validation':.2},
    'seeds':{'train':0,'test':1,'validation':2},
    'n_neurons': None, # all neurons
    'n_trials': None, # all trials
    'pre_time':0,
    'post_time':.2,
    'align_to': 'responses',
    'train_trial_prop':.9,
    'train_condition_prop':1,
    'seed':0
}

dataloader = loader.IBLDataLoader(
    dataset_params
)

# Four parallel per-session sequences (later cells index them as xs[0], ys[0],
# rs[0], cs[0]); each entry of `ys` is a 3-D array that the next cell reports
# as (trials, conditions, neurons).
xs,ys,rs,cs = dataloader.load_train_data()

In [None]:
# Report the size of every loaded session. A plain loop is used instead of a
# list comprehension so the cell does not also echo a list of `None` values
# (the return values of `print`) as its output.
for session_y in ys:
    print('Trials:{}, Conditions:{}, Neurons:{}'.format(*session_y.shape[:3]))

In [None]:
# Run-level settings. `seed` drives the PRNG key and the numpyro seed handler
# in later cells; `save` and `file` are not read anywhere in the visible
# notebook (presumably kept for exporting results - confirm).
seed, save = 0, False
file = '../results/wishart/'

# Model configuration. The `gp_kernel*` and `wp_kernel*` entries are passed to
# utils.get_kernel in the fitting cell; `nu`, `optimize_L` and
# `wp_sample_diag` configure the Wishart process there, and `likelihood`
# names the class looked up on the `models` module.
model_params = dict(
    prior='WishartLRDProcess',
    seed=0,
    nu=0,
    gp_kernel_diag=0.001,
    gp_kernel=[
        dict(type='RBF', scale=1, sigma=20., normalizer=100),
    ],
    wp_sample_diag=1.,
    optimize_L=False,
    wp_kernel_diag=0.001,
    wp_kernel=[
        dict(type='RBF', scale=1, sigma=50., normalizer=100),
    ],
    likelihood='NormalConditionalLikelihood',
)

# Variational inference configuration: `guide` names a class on the
# `inference` module, `optimizer.type` names a class on `numpyro.optim`, and
# `n_iter` / `num_particles` are forwarded to the guide's `infer` call.
variational_params = dict(
    guide='VariationalNormal',
    num_particles=1,
    n_iter=50000,
    optimizer=dict(type='Adam', step_size=0.001),
)

In [None]:
# Needs Wishart installed and added to the path
# https://github.com/neurostatslab/wishart-process
path = '/mnt/home/anejatbakhsh/Desktop/Projects/'
sys.path.insert(1,path+'Wishart-Process/codes/')

# These modules are presumably provided by the Wishart-Process checkout added
# to sys.path above, so they have to be imported after the path insert.
import models
import inference
import utils
import evaluation
import visualizations

# Fit a single session: the first entry of each per-session sequence.
x,y,r,c = xs[0],ys[0],rs[0],cs[0]

gp_kernel = utils.get_kernel(
    model_params['gp_kernel'],
    model_params['gp_kernel_diag']
)
D = y.shape[2]  # size of the last axis of y (neurons, per the print below)

print('Trials, Conditions, Neurons: ', y.shape)


gp = models.GaussianProcess(kernel=gp_kernel,num_dims=D)

# Single D x D covariance: subtract the mean over axis 0 (trials), flatten
# trials and conditions into one sample axis, then take the covariance across
# the last axis.
empirical = jnp.cov((y - y.mean(0)[None]).reshape(y.shape[0]*y.shape[1],y.shape[2]).T)

wp_kernel = utils.get_kernel(
    model_params['wp_kernel'],
    model_params['wp_kernel_diag']
)


# Scale matrix for the Wishart process: pooled covariance plus a diagonal term.
V = empirical+model_params['wp_sample_diag']*jnp.eye(D)

# Fix: `diag_scale` previously referenced the undefined name `wp_sample_diag`
# (NameError on a fresh kernel); the value lives in `model_params`.
wp = models.WishartLRDProcess(
    kernel=wp_kernel,nu=model_params['nu'] ,
    V=V,optimize_L=model_params['optimize_L'],
    diag_scale=model_params['wp_sample_diag']
)

# Look the configured classes up by attribute name instead of eval()-ing a
# concatenated string: same result for plain class names, no code execution.
likelihood = getattr(models, model_params['likelihood'])(D)
joint = models.JointGaussianWishartProcess(gp,wp,likelihood)

# Quick check of the largest kernel value on the training inputs.
print(gp.evaluate_kernel(x,x).max())
print(wp.evaluate_kernel(x,x).max())

# Baseline covariance estimators; later code reads the 'empirical' and 'lw'
# entries, each indexed as [:, :, condition].
compared = evaluation.compare(y)
# Add the pooled covariance, repeated along a new last axis once per condition.
compared['grand-empirical'] = jnp.repeat(empirical[:,:,None],y.shape[1],2)


# %%

# Visualize the per-condition trial means and (scaled) empirical covariances.
mu_empirical = y.mean(0)
sigma_empirical = compared['empirical'].transpose(2,0,1)
visualizations.visualize_pc(
    mu_empirical[:,None],
    .1*sigma_empirical,
    pc=y.reshape(y.shape[0]*y.shape[1],-1),
    dotsize=500,
    linewidth=2,
    fontsize=30
)

# %%
# Initialize the site named 'G' from the per-condition trial mean
# (presumably the GP mean function - confirm against the model).
init = {'G':y.mean(0).T[:,None]}

varfam = getattr(inference, variational_params['guide'])(
    joint.model,init=init
)
optimizer = getattr(optim, variational_params['optimizer']['type'])(
    variational_params['optimizer']['step_size']
)
key = jax.random.PRNGKey(seed)

varfam.infer(
    optimizer,x,y,
    n_iter=variational_params['n_iter'],key=key,
    num_particles=variational_params['num_particles']
)
joint.update_params(varfam.posterior)



In [None]:
# %%
# NOTE(review): this cell reads `mus`, `sigmas`, `index` and `y_test`, none of
# which are defined in the visible notebook. On a fresh kernel it raises
# NameError until they are provided (e.g. per-session containers, the session
# index being fitted, and held-out trials from the data loader) - confirm the
# intended source before relying on Restart & Run All.
posterior = models.NormalGaussianWishartPosterior(joint,varfam,x)
with numpyro.handlers.seed(rng_seed=seed):
    mu_hat, sigma_hat, F_hat = posterior.mode(x)
    mu_prime, sigma_prime = posterior.derivative(x)
    log_pf, log_pg = posterior.posterior.log_prob(F=F_hat,G=mu_hat.T[:,None])

# Store this session's posterior mode so the distance cell can compare sessions.
mus[index] = mu_hat.copy()
sigmas[index] = sigma_hat.copy()

# %%
visualizations.plot_loss(
    [varfam.losses],xlabel='Iteration',ylabel='ELBO',
    titlestr='Training Loss',colors=['k'],
)
# %%
visualizations.visualize_pc(
    mu_hat[:,None],.1*sigma_hat,
    pc_test=y_test.reshape(y_test.shape[0]*y_test.shape[1],-1),
    dotsize=500,
    linewidth=2,
    fontsize=30
)
# %%
# Leave-one-out variances: for each index i along axis 0 of y, the variance
# over axis 0 with entry i removed.
var_bootstrap = jnp.array([jnp.concatenate((y[:i],y[i+1:])).var(0) for i in range(y.shape[0])])

# %%
# Compare the diagonal (per-dimension variance) of each covariance estimate.
visualizations.plot_variance_smoothness(
    x[:,None],[
        y.var(0),
        jnp.array([jnp.diag(compared['lw'][:,:,i]) for i in range(len(x))]),
        jnp.array([jnp.diag(compared['grand-empirical'][:,:,i]) for i in range(len(x))]),
        jnp.array([jnp.diag(sigma_hat[i]) for i in range(len(x))]),
    ],
    yerr=var_bootstrap,
    methods=['empirical','lw','grand-empirical','wishart']
)
# %%
# Add the fitted covariances in the same [:, :, condition] layout as the
# other entries of `compared`.
compared['wishart'] = sigma_hat.transpose(1,2,0)
lpp = {}
mu_empirical = y.mean(0)

# The loop variable is deliberately not called `key`: that name holds the JAX
# PRNG key created in the fitting cell and must not be overwritten here.
for method in compared.keys():
    lpp[method] = likelihood.log_prob(y_test,mu_empirical,compared[method].transpose(2,0,1)).flatten()


visualizations.plot_box(
    lpp,titlestr='Log Posterior Predictive',
)

# Median log-probability per estimator (plain loop; print returns None, so a
# comprehension would only build a throwaway list).
for method in lpp.keys():
    print(method, jnp.median(lpp[method]))
# Median of exp(lpp difference) between the wishart and grand-empirical entries.
print(jnp.median(jnp.exp(lpp['wishart']-lpp['grand-empirical'])))

In [None]:
# NOTE(review): `rts` and `ccs` are not defined in the visible notebook
# (presumably per-session behavioural arrays, possibly the `rs` / `cs` returned
# by the data loader - confirm), and `mus` / `sigmas` must hold one entry per
# session in `ys` before this cell can run.

def _all_session_pairs(moments):
    """Build the input list for utils.ssd from per-session [mean, covariance] pairs.

    Parameters
    ----------
    moments : sequence of (mean, covariance)
        One entry per session.

    Returns
    -------
    list
        One ``[[mean_i, cov_i], [mean_j, cov_j]]`` item for every ordered pair
        (i, j) of sessions, with i varying slowest, including i == j.
    """
    return [
        [[mean_i, cov_i], [mean_j, cov_j]]
        for mean_i, cov_i in moments
        for mean_j, cov_j in moments
    ]


n_sessions = len(ys)

# Per-session moments are computed once here instead of once per (i, j) pair.
neural_moments = [(mus[i], sigmas[i]) for i in range(n_sessions)]
rt_moments = [
    (rts[i].mean(0)[:,None], rts[i].var(0)[:,None,None])
    for i in range(n_sessions)
]
cc_moments = [
    (ccs[i].mean(0)[:,None], ccs[i].var(0)[:,None,None])
    for i in range(n_sessions)
]

# Distances between the fitted neural means/covariances of every session pair.
dist_neural = utils.ssd(
    _all_session_pairs(neural_moments),
    alpha=2.,
    niter=1000
)

# Same comparison using the mean and variance over axis 0 of `rts`.
dist_rt = utils.ssd(
    _all_session_pairs(rt_moments),
    alpha=2.,
    niter=1000
)

# Same comparison using the mean and variance over axis 0 of `ccs`.
dist_cc = utils.ssd(
    _all_session_pairs(cc_moments),
    alpha=2.,
    niter=1000
)
