# Testing different versions of the two-parameter DSO rule, WIP

Inference results analysis cells are below.

In [None]:
import numpy as np
import pickle

from consbi.simulators import (
    DistanceRuleSimulator, 
    RuleSimulator, 
    poisson_glm, 
    peters_rule_subcellular, 
    two_param_poisson_glm,
    default_rule_linear, 
    two_param_rule_dependent,
    dso_linear_two_param,
)
from consbi.simulators.poisson_glm import *
from consbi.simulators.rule_simulator import *
from consbi.simulators import default_rule
from torch.distributions import MultivariateNormal

from sbi.inference import prepare_for_sbi, simulate_for_sbi
from sbi.utils import BoxUniform

import torch
import matplotlib.pyplot as plt
plt.style.use('plotting_settings.mplstyle')
%matplotlib inline

In [None]:
# Box-uniform prior on [0, prior_scale] in each of the num_dim dimensions.
# A Gaussian alternative is kept below (commented out) for reference.
num_dim = 3
prior_scale = 3
# prior = MultivariateNormal(torch.ones(num_dim), 0.5 * torch.eye(num_dim))
prior = BoxUniform(torch.zeros(num_dim), prior_scale * torch.ones(num_dim))
# Observed data point: a single row with seven entries.
xo = np.array([[0.4300, 0.4300, 0.4200, 0.6400, 0.1700, 0.4400, 0.0900]])

## Poisson GLM

The problem is reduced to just 10 neuron-pair-subvolume combinations without summary statistics, i.e., the simulator returns the raw synapse counts sampled from a Poisson distribution. Thus, it is fast and tractable. 

In [None]:
# Draw parameters from the prior and pass them through the Poisson GLM.
num_prior_samples = 10000
theta = prior.sample((num_prior_samples,))
x = poisson_glm(theta)

In [None]:
# Histogram of the simulated counts (raw bin counts, no density normalisation).
counts = x.numpy()
plt.figure(figsize=(18, 5))
plt.hist(counts, density=False)
plt.xlabel("Counts")
# plt.legend([f"neuron-pair-volume combi {ii}" for ii in range(10)]);

In [None]:
# Location of the structural model handed to the simulator.
path_to_model = "../data/structural_model"
# Number of neuron pairs sampled from the connectome to mimic experimental
# settings, e.g., 50.
num_subsampling_pairs = 50
simulator_kwargs = dict(
    rule=default_rule_linear,
    verbose=True,
    num_subsampling_pairs=num_subsampling_pairs,
    prelocate_postall_offset=True,
)
simulator = RuleSimulator(path_to_model, **simulator_kwargs)

In [None]:
# Simulation budget and Gaussian prior (mean one, covariance scaling * I).
num_simulations = 1000
num_dim = 3
scaling = 0.5
# prior = BoxUniform(0.7 * torch.ones(num_dim), 1.6 * torch.ones(num_dim))
prior_mean = torch.ones(num_dim)
prior_cov = scaling * torch.eye(num_dim)
prior = MultivariateNormal(prior_mean, prior_cov)

# Use `default_rule` for the simulations in this section.
simulator.rule = default_rule


def batch_loop_simulator(theta):
    """Simulate each parameter set in ``theta`` one after another.

    ``theta`` must have more than one dimension; iteration runs over its
    first dimension. The individual simulator outputs are stacked and then
    reshaped to ``(theta.shape[0], -1)``.
    """
    assert theta.ndim > 1, "Theta must have a batch dimension."
    batch_size = theta.shape[0]
    outputs = [simulator(single_theta) for single_theta in theta]
    return torch.stack(outputs).reshape(batch_size, -1)


# wrap the simulator to handle batches of parameters
batch_simulator, prior = prepare_for_sbi(batch_loop_simulator, prior)

In [None]:
# Simulate `num_simulations` parameter/data pairs from the prior.
theta, x = simulate_for_sbi(
    batch_simulator,
    prior,
    num_simulations=num_simulations,
    num_workers=20,
    simulation_batch_size=50,
)


In [None]:
# Prior predictive check: one histogram of simulated values per data
# dimension, with the corresponding entry of `xo` drawn as a vertical line.
data_labels = [r"L4", r"L4SEP", r"L4SP", r"L4SS", r"L5IT", r"L5PT", r"L6"]
colors = ['#377eb8', '#ff7f00', '#4daf4a',
          '#f781bf', '#a65628', '#984ea3',
          '#999999', '#e41a1c', '#dede00']
num_bins = 15
x = x.squeeze()

fig, ax = plt.subplots(1, len(data_labels), figsize=(12, 3), gridspec_kw=dict(wspace=0.15))
for ii, label in enumerate(data_labels):
    axis = ax[ii]
    plt.sca(axis)
    hist_out = plt.hist(x[:, ii].numpy(), color=colors[6], bins=np.linspace(0, 1, num_bins), alpha=0.9)
    simulated_bars = hist_out[2]
    measured_line = plt.axvline(xo[0, ii], color="k")
    plt.xticks(np.linspace(0, 1, 2))
    plt.xlim(0, 1)
    for side in ("right", "top", "left"):
        axis.spines[side].set_visible(False)
    plt.yticks([])
    plt.xlabel(label)
# Pass the handles explicitly: with a labels-only call the pairing of labels
# and artists depends on the order in which matplotlib enumerates them.
plt.legend([measured_line, simulated_bars[0]], ["measured", "simulated \nfrom prior"],
           bbox_to_anchor=(.11, .69), handlelength=1.0)

In [None]:
# Imported in this cell because the other visible imports of `marginal_plot`
# only happen in cells further down the notebook.
from sbi.analysis import marginal_plot

# Marginal histograms of samples drawn from the prior.
marginal_plot(prior.sample((10000,)), figsize=(18, 5));

In [None]:
# Save prior predictive samples (parameters and simulated data) to disk.
prior_predictive_path = f"../results/prior_predictive_samples_dso_two_param_N{num_simulations}.p"
prior_predictive = dict(theta=theta, x=x)
with open(prior_predictive_path, "wb") as fh:
    pickle.dump(prior_predictive, fh)

In [None]:
from sbi.analysis import pairplot, marginal_plot
pairplot(x.numpy()[:10000,:], points=xo[0], upper="scatter", figsize=(18, 12), 
         points_offdiag=dict(marker="+"), 
         points_colors=["k"]
         
        );

In [None]:
plt.savefig("prior_predictive.png", dpi=300, bbox_inches='tight')

## Inference

In [None]:
# Load stored inference results.
# NOTE: unpickling executes code from the file; only load trusted results.
filename = "../results/npe_dso_constrained_2p_uniform0-3_n500000r2x100k.p"
with open(filename, "rb") as fh:
    results = pickle.load(fh)
#     prior, de, *_ = results.values()
# The stored dict is unpacked by insertion order and must hold exactly seven
# entries.
prior, de, posteriors, thos, xos, kwargs, seed = results.values()

In [None]:
# Replace the loaded prior with a two-dimensional box-uniform on [0.2, 1.6].
num_params = 2
prior_low = 0.2 * torch.ones(num_params)
prior_high = 1.6 * torch.ones(num_params)
prior = BoxUniform(prior_low, prior_high)

In [None]:
# posteriors = _[0]
# Draw 10000 samples from every posterior in `posteriors`.
sampless = []
for round_posterior in posteriors:
    sampless.append(round_posterior.sample((10000,), show_progress_bars=False))

In [None]:
from sbi.inference import SNPE
from sbi.analysis import pairplot, sbc
from sbi.simulators.simutils import simulate_in_batches

# Build a posterior from the loaded density estimator `de` and sample it
# conditioned on the observation `xo`.
inference = SNPE(prior)
posterior = inference.build_posterior(de)
samples = posterior.sample((10000,), x=xo)

In [None]:
# Overlay samples from the prior (gray) with every sample set in `sampless`,
# and mark the MAP estimate of the last posterior as a point.
num_params = samples.shape[1]
fig, ax = pairplot(
    [
        prior.sample((10000,)),
        # sampless[-1][:10000],
        # samples
    ] + sampless,
    diag="hist",
    upper="contour",
    contour_offdiag=dict(levels=[0.99]),
    hist_diag=dict(bins=20, density=True),
    samples_colors=["gray", "C0", "C1", "C2", "C3", "C4"],
    kde_offdiag=dict(bw_method=0.3, num_bins=100),
    scatter_offdiag=dict(s=1, alpha=0.5),
    points=posteriors[-1].map(),
    # limits=[[low, high] for low, high in zip(prior.base_dist.low, prior.base_dist.high)],
    # One [low, high] pair per parameter. The previous `[[-1, 4] * n]` built a
    # single flat list of length 2 * n instead of n pairs.
    limits=[[-1, 4]] * num_params,
    figsize=(8, 8),
);
# fig, ax = pairplot([    
# #     prior.sample((10000,)),
# #     sampless[0][:10000],
#     samples
#     ],
#                    diag="kde", 
#                    upper="kde", 
#                    contour_offdiag=dict(levels=[0.99]),
#                    hist_diag=dict(bins=25, density=True), 
#                    samples_colors=[
# #                        "gray", 
#                                    "C0", "C1", "C2", "C3", "C4"], 
#                    kde_offdiag=dict(bw_method=0.3, num_bins=100),
#     scatter_offdiag=dict(s=1, alpha=0.5),
#                    fig=fig, axes=ax, 
#                    limits=[[0.2, 1.6]*samples.shape[1]],
#                   );

In [None]:
# Show the MAP estimate of the last entry in `posteriors`.
last_posterior = posteriors[-1]
last_posterior.map()

In [None]:
# Correlation matrix of the last sample set; columns are the variables.
last_samples = sampless[-1]
np.corrcoef(last_samples, rowvar=False)

In [None]:
from sbi.analysis import run_sbc, sbc_rank_plot, marginal_plot

In [None]:
# Load presimulated data. The stored dict is unpacked by insertion order into
# three values; the first one is not used under a meaningful name.
# NOTE: unpickling executes code from the file; only load trusted data.
with open("../data/presimulated_dso_constrained_2p_uniform0-3_n500000.p", "rb") as fh:
    presimulated = pickle.load(fh)
_, thetas, xs = presimulated.values()
xs = xs.squeeze()

In [None]:
# posterior predictive samples
simulator.rule = default_rule_constrained_two_param
# simulator.rule = default_rule


def batch_loop_simulator(theta):
    """Return a batch of simulations by looping over a batch of parameters.

    ``theta`` must have more than one dimension; iteration runs over its
    first dimension and the individual simulator outputs are stacked.
    """
    assert theta.ndim > 1, "Theta must have a batch dimension."
    outputs = [simulator(single_theta) for single_theta in theta]
    # Stack over batch to keep x_shape
    return torch.stack(outputs)


# Simulate from the first 1000 samples of the last sample set. The function
# defined in this cell is passed directly; previously the `batch_simulator`
# wrapper built in the prior-predictive section was passed, so this
# redefinition was never called.
xos = simulate_in_batches(
    batch_loop_simulator,
    sampless[-1][:1000],
    # samples[:1000],
    sim_batch_size=50,
    num_workers=20,
)

In [None]:
# Compare presimulated data (gray) with the posterior predictive simulations
# (C0); the entries of `xo` are marked as points.
marginal_plot(
    [
        xs[:1000],
        xos.squeeze(),
    ],
    hist_diag=dict(bins=10, histtype="step"),
    samples_colors=["gray", "C0"],
    points_colors=["k"],
    labels=data_labels,
    density=True,
    # One [low, high] pair per data dimension. The previous `[[0, 1] * 7]`
    # built a single flat list of 14 numbers instead of seven pairs.
    limits=[[0, 1]] * 7,
    points=xo.squeeze(),
    figsize=(18, 4),
);

## Save with posterior and predictive samples

In [None]:
_[1]

In [None]:
with open(filename, "wb") as fh:
    pickle.dump(dict(prior=prior, 
                     de = de, 
                     posteriors=posteriors, 
                     sampless = sampless,
#                      samples = samples,
                     xos = xos,
#                      kwargs=_[0],
                     kwargs=_[1],
                    ), fh)

In [None]:
# Run SBC on the first `num_sbc_samples` presimulated pairs, using
# `posterior.log_prob` as the reduce function.
num_sbc_samples = 1000
sbc_thetas = thetas[:num_sbc_samples]
sbc_xs = xs[:num_sbc_samples]
ranks, daps = run_sbc(
    sbc_thetas,
    sbc_xs,
    posterior=posterior,
    num_workers=1,
    sbc_batch_size=1,
    reduce_fns=posterior.log_prob,
)

In [None]:
# Histogram of the SBC ranks in 20 bins.
rank_values = ranks.numpy()
plt.hist(rank_values, bins=20);

In [None]:
# fig, ax = plt.subplots(1, 1, figsize=(12, 5))
# Extra options forwarded through the `kwargs` argument; `fig=fig` and
# `axes=ax` were previously tried here and are left out.
rank_plot_kwargs = dict(figsize=(10, 5))
sbc_rank_plot(
    ranks=ranks,
    num_posterior_samples=1000,
    plot_type="hist",
    num_bins=20,
    kwargs=rank_plot_kwargs,
)