In [2]:
%load_ext autoreload
%autoreload 2

import math

import scipy.special as sps
import scipy.stats as scstats
import numpy as np

import torch
from torch import nn

import matplotlib.pyplot as plt
from matplotlib import patches
from matplotlib import transforms
import daft

import sys
sys.path.append("../../neuroprob")
sys.path.append("../scripts/") # access to scripts


import os

# Results are written under ./output. exist_ok=True makes this a single
# race-free call that is also safe to re-run.
os.makedirs('./output', exist_ok=True)
    

from neuroprob import utils

import model_utils

import pickle


plt.style.use(['paper.mplstyle'])

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
def render(pgm):
    """
    Draw every element of a daft PGM into its (already initialised) context.

    Plates are drawn first, then edges, then nodes (presumably so that nodes
    end up above edges). This relies on daft's private attributes
    ``_plates``, ``_edges``, ``_nodes`` and ``_ctx``.
    """
    ctx = pgm._ctx
    elements = [*pgm._plates, *pgm._edges, *pgm._nodes.values()]
    for element in elements:
        element.render(ctx)



def init_figax(pgm, fig, ax):
    """
    Bind a daft PGM's rendering context to an existing matplotlib figure/axes.

    The function:
    - hides the axis frame;
    - sets the axes limits to the PGM extent (origin to origin + shape, mapped
      through the context's ``convert``);
    - fixes a 1:1 aspect ratio;
    - stores ``fig``/``ax`` on the private context so a later render draws
      into them.
    """
    ctx = pgm._ctx
    ctx._figure = fig
    ax.axis('off')

    # Corners of the PGM: origin and origin + shape, converted by the context.
    lower = ctx.convert(*ctx.origin)
    upper = ctx.convert(*(ctx.origin + ctx.shape))
    ax.set_xlim(lower[0], upper[0])
    ax.set_ylim(lower[1], upper[1])
    ax.set_aspect(1)

    ctx._ax = ax

### Stochastic process over point process conditionals

In [None]:
# Toy Gaussian draw, 10 x 100 with mean 10 and std 5 (`a` is reassigned in a later cell).
a = 10.0 + 5.0 * np.random.randn(10, 100)
dt = 0.001  # time step used for the cumulative sums below

In [None]:
Tl = 3000  # number of time bins in the simulated trace
sample_bin = 0.001  # bin width (same value as dt in the cell above)


# RBF lengthscale: 30 bins converted to time via sample_bin; shape (1, 1)
l = sample_bin*np.array([30.0])[None, :]
dn = l.shape[1]  # number of latent GP dimensions (1 here)
v = np.ones(dn)  # unit kernel variance per latent dimension



# generate GP trajectories
# NOTE(review): `GP` is never imported in this notebook, so this cell raises
# NameError on a fresh kernel. TODO: add the (presumably neuroprob) import.
kernel_tuples = [('variance', v), 
                 ('RBF', 'euclid', l)]

with torch.no_grad():
    kernel, _, _ = GP.kernels.create_kernel(kernel_tuples, 'softplus', torch.double)

    # Time grid of Tl points spaced by sample_bin, shaped (1, 1, Tl, 1).
    inp = torch.arange(Tl)[None, None, :, None]*sample_bin
    # Presumably (dn, Tl, Tl) after dropping the leading axis. TODO: confirm
    # against create_kernel.
    K = kernel(inp, inp)[0, ...]
    # Add 1e-6 jitter to the diagonal of each Tl x Tl block. Stride Tl+1 in the
    # flattened view hits the diagonal, assuming K is a contiguous
    # (dn, Tl, Tl) tensor. The jitter keeps K safely positive definite for the
    # Cholesky factorisation below.
    K.view(dn, -1)[:, ::Tl+1] += 1e-6


# Draw mc sample paths from the GP prior: x = L @ xi, with K = L L^T and xi ~ N(0, I).
# A batched matmul is used instead of broadcasting L against xi and summing.
# That form would allocate an (mc, dn, Tl, Tl) temporary (~720 MB in float64
# for mc=10, dn=1, Tl=3000). This form needs only O(mc * dn * Tl) extra memory.
# Assumes L is (dn, Tl, Tl); it broadcasts over the mc sample axis.
L = torch.linalg.cholesky(K)
mc = 10
xi = np.random.randn(mc, dn, Tl)
# xi is float64, so the product is float64, as with the broadcast form.
x = torch.matmul(L.double(), torch.from_numpy(xi)[..., None])[..., 0]  # (mc, dn, Tl)



# linear Poisson model
# NOTE(review): `mdl` is never imported in this notebook, so this cell raises
# NameError on a fresh kernel. TODO: add the (presumably neuroprob) import.
neurons = 100
w_len = dn  # one readout weight per latent dimension
GPFA = mdl.parametrics.GLM(w_len, neurons, w_len, 'exp', bias=True)

# Random readout weights (neurons x w_len) and biases.
w = np.random.randn(neurons, w_len)
bias = np.random.randn(neurons)
GPFA.set_params(w, bias)


# NOTE: `likelihood` is reassigned by the Gamma ISI cell further down.
likelihood = mdl.likelihoods.Poisson(sample_bin, neurons, 'exp')
#likelihood.set_Y(rc_t, batch_size=5000, filter_len=1) 

In [None]:
# Inspect one GP sample path: sample index 1, latent dimension 0.
sample_path = x[1, 0, :].numpy()
plt.plot(sample_path)

In [None]:
# Exponentiate the GP draws into positive rates.
# NOTE: this rebinds `x`, so re-running the cell exponentiates again; re-run the
# sampling cell first.
x = np.exp(x)
rates = x.numpy()[:, 0, :]  # latent dimension 0, one row per sample path
# First-event density in natural time: p = rate * exp(-cumsum(rate) * dt)
p = rates * np.exp(-np.cumsum(rates, axis=1) * dt)

In [None]:
# Map natural time tau onto real time t = tau_0 * (exp(tau) - 1).
# dtau_dt = 1 / (t + tau_0) is the Jacobian used to transform densities.
tau_0 = 0.001
tau = dt * np.arange(Tl)
t = tau_0 * (np.exp(tau) - 1)
dtau_dt = 1 / (tau_0 + t)

In [None]:
# Density over natural time tau: mean across sample paths, then each path faintly.
mean_density = p.mean(0).T
plt.plot(tau, mean_density)
plt.plot(np.repeat(tau[:, None], p.shape[0], axis=1), p.T, alpha=0.3)

In [None]:
# Same densities mapped to real time t by multiplying with the Jacobian dtau/dt.
mean_density_t = p.mean(0).T * dtau_dt
path_densities_t = p.T * dtau_dt[:, None]
plt.plot(t, mean_density_t)
plt.plot(np.repeat(t[:, None], p.shape[0], axis=1), path_densities_t, alpha=0.3)

In [None]:
# Riemann-sum probability mass of each sample's first-event density over the
# window. Values near 1 mean the window captures almost all of that mass.
np.sum(p * dt, axis=1)

In [None]:
# ISI density check for a Gamma renewal likelihood. exp(-nll) from the
# likelihood object is overlaid with the density from its ISI_dist on a grid
# of intervals.
# NOTE(review): `mdl` and `rate_model` are not defined anywhere in this
# notebook, so this cell raises NameError on a fresh kernel. TODO: import or
# define them (the Poisson cell above passes the string 'exp' where
# `rate_model.inv_link` is used here).
neurons = 1
hist_couple = None  # not used in this cell
shape_t = 5.0*np.ones(neurons)  # presumably the Gamma shape parameter per neuron
likelihood = mdl.likelihoods.Gamma(sample_bin, neurons, rate_model.inv_link, shape_t)
# Alternative likelihoods and model wiring, kept for reference:
#sigma_t = 0.5*np.ones(neurons)
#likelihood = mdl.likelihoods.logNormal(sample_bin, neurons, inv_link, sigma_t)

#mu_t = 5.0*np.ones(neurons)
#likelihood = mdl.likelihoods.invGaussian(neurons, inv_link, mu_t)

#hist_len = 99 # 100 steps of spiketrain, no instantaneous element
#hist_couple = mdl.filters.raised_cosine_bumps(a=1., c=1., phi=phi_h, w=w_h, timesteps=hist_len)
#likelihood = mdl.likelihoods.Bernoulli(neurons, inv_link)

#input_group = mdl.inference.input_group(3, [(None, None, None, 1)]*3)
#input_group.set_XZ(covariates, track_samples, batch_size=track_samples, trials=trials)

#glm = mdl.inference.VI_optimized(input_group, gauss_rate, likelihood)
#glm.to(dev)
    

bb_isi = np.linspace(0.001, 5.0, 100) # ISI evaluate: 100 ISI values in [0.001, 5.0]

#gt_field.append(compute_rate(model.rate_model[0], [0])[0])
# (100, 1) column of ISIs as a CPU tensor, wrapped in a list.
ISI = [torch.tensor(bb_isi[:, None], device='cpu')]#dev)]
scale = likelihood.shape.data  # detached shape parameter, used to rescale the ISIs below
#scale = torch.exp(-model.likelihood.sigma.data**2/2.)
#scale = 1./model.likelihood.mu
likelihood.trials = 1
# NOTE: rebinds `p` and `a`, which earlier cells used for other quantities.
# Presumably exp(-nll) is the ISI density implied by the likelihood. TODO:
# confirm nll's argument convention (log-rate input, rescaled ISIs, neuron list).
p = likelihood.nll(torch.log(scale)*torch.ones((len(bb_isi), 1), device='cpu'), [[ISI[0]*scale]], [0])#.data[:, 0].cpu().numpy()
a = np.exp(-p)

# Density from the likelihood's ISI distribution for neuron index 0, on the same grid.
b = likelihood.ISI_dist([0]).prob_dens(bb_isi)

# Overlay for visual comparison: solid = exp(-nll), red dashed = prob_dens.
plt.plot(bb_isi, a)
plt.plot(bb_isi, b, 'r--')
plt.show()

In [None]:
### monotonic RQ splines ###
def searchsorted(bin_locations, inputs, eps=1e-6):
    """
    Return the index of the bin each input falls into.

    :param torch.tensor bin_locations: sorted knot positions along the last dim
        (not modified)
    :param torch.tensor inputs: values to locate; ``inputs[..., None]`` is
        broadcast against ``bin_locations``
    :param float eps: nudges the last knot upward so that an input equal to the
        last knot is assigned to the final bin rather than one past it
    :returns: for each input, (number of knots <= input) - 1
    """
    # Work on a copy so the caller's knot tensor is left untouched.
    locations = bin_locations.clone()
    locations[..., -1] += eps
    return torch.sum(inputs[..., None] >= locations, dim=-1) - 1


def RQS(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    left=0.0,
    right=1.0,
    bottom=0.0,
    top=1.0,
    min_bin_width=1e-3,
    min_bin_height=1e-3,
    min_derivative=1e-3,
):
    """
    Monotonic rational-quadratic spline mapping [left, right] -> [bottom, top].

    The last dimension of the parameter tensors indexes the bins. The leading
    dimensions are (batch,) or (batch, dims) and match ``inputs``.
    ``unnormalized_derivatives`` needs num_bins + 1 entries (one per knot),
    because both the bin index and the bin index + 1 are read.

    Based on implementation in https://github.com/bayesiains/nsf

    :returns: (outputs, log|det J|) in the forward direction, or
        (outputs, -log|det J| of the forward map) when ``inverse=True``
    :raises ValueError: if an input lies outside [left, right] (this check is
        applied even when inverting), or if the minimal bin width/height is too
        large for the number of bins
    """
    if torch.min(inputs) < left or torch.max(inputs) > right:
        raise ValueError("Input outside domain")

    num_bins = unnormalized_widths.shape[-1]

    if min_bin_width * num_bins > 1.0:
        raise ValueError("Minimal bin width too large for the number of bins")
    if min_bin_height * num_bins > 1.0:
        raise ValueError("Minimal bin height too large for the number of bins")

    def knot_grid(unnormalized, min_size, lo, hi):
        # Softmax gives bin fractions, floored at min_size. Their running sum,
        # rescaled to [lo, hi], gives the knot positions; the end knots are
        # pinned exactly to lo and hi.
        fractions = nn.functional.softmax(unnormalized, dim=-1)
        fractions = min_size + (1 - min_size * num_bins) * fractions
        knots = torch.cumsum(fractions, dim=-1)
        knots = nn.functional.pad(knots, pad=(1, 0), mode="constant", value=0.0)
        knots = (hi - lo) * knots + lo
        knots[..., 0] = lo
        knots[..., -1] = hi
        return knots, knots[..., 1:] - knots[..., :-1]

    cumwidths, widths = knot_grid(unnormalized_widths, min_bin_width, left, right)
    cumheights, heights = knot_grid(unnormalized_heights, min_bin_height, bottom, top)
    derivatives = min_derivative + nn.functional.softplus(unnormalized_derivatives)

    # Locate each input on the knot grid of the space it lives in.
    bin_idx = searchsorted(cumheights if inverse else cumwidths, inputs)[..., None]

    def at_bin(values):
        return values.gather(-1, bin_idx)[..., 0]

    input_cumwidths = at_bin(cumwidths)
    input_bin_widths = at_bin(widths)
    input_cumheights = at_bin(cumheights)
    input_heights = at_bin(heights)
    input_delta = at_bin(heights / widths)
    input_derivatives = at_bin(derivatives)
    input_derivatives_plus_one = at_bin(derivatives[..., 1:])

    # Shared term d_k + d_{k+1} - 2 * delta_k of the rational-quadratic form.
    curvature = input_derivatives + input_derivatives_plus_one - 2 * input_delta

    if inverse:
        # Solve a * theta^2 + b * theta + c = 0 for the in-bin position theta,
        # using the root in the form 2c / (-b - sqrt(b^2 - 4ac)).
        offset = inputs - input_cumheights
        a = offset * curvature + input_heights * (input_delta - input_derivatives)
        b = input_heights * input_derivatives - offset * curvature
        c = -input_delta * offset

        discriminant = b.pow(2) - 4 * a * c
        assert (discriminant >= 0).all()

        theta = (2 * c) / (-b - torch.sqrt(discriminant))
        outputs = theta * input_bin_widths + input_cumwidths
    else:
        theta = (inputs - input_cumwidths) / input_bin_widths

    theta_one_minus_theta = theta * (1 - theta)
    denominator = input_delta + curvature * theta_one_minus_theta

    if not inverse:
        numerator = input_heights * (
            input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
        )
        outputs = input_cumheights + numerator / denominator

    # log |d forward / d input|, evaluated at the in-bin position theta.
    derivative_numerator = input_delta.pow(2) * (
        input_derivatives_plus_one * theta.pow(2)
        + 2 * input_delta * theta_one_minus_theta
        + input_derivatives * (1 - theta).pow(2)
    )
    logdetjac = torch.log(derivative_numerator) - 2 * torch.log(denominator)
    return (outputs, -logdetjac) if inverse else (outputs, logdetjac)


def unconstrained_RQS(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    tail_bound=1.0,
    min_bin_width=1e-3,
    min_bin_height=1e-3,
    min_derivative=1e-3,
):
    """
    Rational-quadratic spline on [-tail_bound, tail_bound] with identity tails.

    Inputs outside the interval pass through unchanged, with zero log-Jacobian.
    Inputs inside the interval are transformed by ``RQS``.

    ``unnormalized_derivatives`` holds only the interior knot derivatives. Both
    boundary knots get a fixed value whose softplus plus ``min_derivative``
    equals 1, so the spline slope at the edges matches the identity tails.

    Based on implementation in https://github.com/bayesiains/nsf

    :returns: (outputs, logdetjac), both shaped like ``inputs``
    """
    # softplus(boundary_logit) + min_derivative == 1 at the two end knots.
    boundary_logit = math.log(math.exp(1 - min_derivative) - 1)
    padded_derivatives = nn.functional.pad(unnormalized_derivatives, pad=(1, 1))
    padded_derivatives[..., 0] = boundary_logit
    padded_derivatives[..., -1] = boundary_logit

    inside = (inputs >= -tail_bound) & (inputs <= tail_bound)
    outside = ~inside

    outputs = torch.zeros_like(inputs)
    logdetjac = torch.zeros_like(inputs)

    # Identity tails.
    if torch.any(outside):
        outputs[outside] = inputs[outside]
        logdetjac[outside] = 0

    if torch.any(inside):
        spline_outputs, spline_logdet = RQS(
            inputs=inputs[inside],
            unnormalized_widths=unnormalized_widths[inside, :],
            unnormalized_heights=unnormalized_heights[inside, :],
            unnormalized_derivatives=padded_derivatives[inside, :],
            inverse=inverse,
            left=-tail_bound,
            right=tail_bound,
            bottom=-tail_bound,
            top=tail_bound,
            min_bin_width=min_bin_width,
            min_bin_height=min_bin_height,
            min_derivative=min_derivative,
        )
        outputs[inside] = spline_outputs
        logdetjac[inside] = spline_logdet

    return outputs, logdetjac


class NRQS(nn.Module):
    """
    Neural rational quadratic spline flow, coupling layer

    [Durkan et al. 2019]

    ``mask1`` splits the feature dimensions into group 1 (True) and group 2
    (False).

    Forward direction:
    1. Group 2 is warped by a spline whose parameters ``f1`` predicts from
       ``x * mask1``.
    2. Group 1 is warped by a spline whose parameters ``f2`` predicts from the
       updated ``x * mask2``.

    ``reverse=True`` inverts the two steps in the opposite order.
    """

    def __init__(self, mask1, f1, f2, K, B):
        """
        f1 and f2 take in full input x and map out flattened (3 * K - 1) * dims

        :param torch.tensor mask1: mask over the feature dimensions; nonzero
            entries select group 1 (converted to bool before counting)
        :param nn.Module f1: outputs (3 * K - 1) * dim2 spline parameters per sample
        :param nn.Module f2: outputs (3 * K - 1) * dim1 spline parameters per sample
        :param int K: number of spline bins
        :param float B: tail bound; the spline acts on [-B, B] and is the
            identity outside
        """
        super().__init__()
        # Count on the boolean mask, as plain ints, so the counts agree with
        # what boolean indexing x[:, mask] selects (also for non-binary masks).
        mask1 = mask1.type(torch.bool)
        self.dim1 = int(mask1.sum())
        self.dim2 = mask1.numel() - self.dim1
        self.K = K
        self.B = B

        # Buffers follow the module across devices but are not trainable.
        self.register_buffer("mask1", mask1)
        self.register_buffer("mask2", ~mask1)
        self.f1 = f1  # output (3 * K - 1) * dim2
        self.f2 = f2  # output (3 * K - 1) * dim1

    def _compute_RQS(self, x, W, H, D, inverse):
        """
        Map raw network outputs to spline parameters and apply the spline to x.

        W and H are softmax-normalised over the bin axis and scaled to sum to
        2 * B (the width of [-B, B]); D goes through softplus.
        NOTE(review): RQS applies softmax/softplus to these again internally.
        Kept as-is for compatibility with existing trained models.
        """
        W = 2 * self.B * torch.softmax(W, dim=2)
        H = 2 * self.B * torch.softmax(H, dim=2)
        D = nn.functional.softplus(D)
        x, ld = unconstrained_RQS(x, W, H, D, inverse, tail_bound=self.B)
        return x, ld

    def _coupling_step(self, x, cond_mask, target_mask, net, target_dim, inverse):
        """Warp x[:, target_mask] in place, conditioned on x * cond_mask; return summed log-det."""
        params = net(x * cond_mask).reshape(-1, target_dim, 3 * self.K - 1)
        W, H, D = torch.split(params, self.K, dim=2)  # sizes K, K, K - 1
        target, ld = self._compute_RQS(x[:, target_mask], W, H, D, inverse=inverse)
        x[:, target_mask] = target
        return torch.sum(ld, dim=1)

    def forward(self, x, reverse=False, log_px=0):
        """
        Apply the coupling layer (or its inverse) and accumulate the log-Jacobian.

        :param torch.tensor x: input of shape (batch, dims)
        :param bool reverse: apply the inverse transform
        :param log_px: running term to which the per-sample log|det J| is added
            (forward), or the negated forward log|det J| (reverse)
        :returns: (transformed copy of x, updated log_px)
        """
        x = torch.clone(x)  # copy to avoid overwriting input

        if reverse:
            log_px = log_px + self._coupling_step(
                x, self.mask2, self.mask1, self.f2, self.dim1, inverse=True
            )
            log_px = log_px + self._coupling_step(
                x, self.mask1, self.mask2, self.f1, self.dim2, inverse=True
            )
        else:
            log_px = log_px + self._coupling_step(
                x, self.mask1, self.mask2, self.f1, self.dim2, inverse=False
            )
            log_px = log_px + self._coupling_step(
                x, self.mask2, self.mask1, self.f2, self.dim1, inverse=False
            )

        return x, log_px


### Schematic

In [None]:
fig = plt.figure(figsize=(8,5)) # plot fits
# NOTE(review): `ax` is not created in this cell or anywhere earlier in the
# notebook. The panel label is positioned with whatever `ax` is left in the
# kernel, or raises NameError on a fresh kernel. TODO: create the target axes
# first and place the label relative to it.
fig.text(-0.21, 1.16, 'A', transform=ax.transAxes, size=15)

# NOTE(review): `I_ext` is not defined in this notebook. TODO: load or define it.
time_bins = I_ext[0].shape[0]
tt = np.arange(time_bins)*dt  # time axis using the step `dt` from an earlier cell



# Single-column grid of 4 rows (height ratios 1:1:1:2) spanning the left 70%
# of the figure width.
widths = [1]
heights = [1, 1, 1, 2]
spec = fig.add_gridspec(ncols=1, nrows=4, width_ratios=widths, height_ratios=heights, 
                        left=0., right=0.7, bottom=0., top=1.0)
