# Inference.py

In [None]:
# Credit: Functions here are essentially copies of those in the
# SSM package by Scott Linderman et al. https://github.com/lindermanlab/ssm

import numba
import numpy as np
import numpy.random as npr
from scipy.special import logsumexp as logsumexp_scipy
from scipy.special import gammaln


# Small additive guard against log(0). In the functions of this cell it is referenced only by
# commented-out reference lines; the active code takes log-probabilities directly.
LOG_EPS = 1e-16


@numba.jit(nopython=True, cache=True)
def logsumexp(x):
    """
    Numerically stable log(sum(exp(x))) for a 1-D array (numba-compiled).

    :param x: 1-D array of log-values.
    :return: log(sum_i exp(x[i])). Returns -inf for an empty array or when every entry
             is -inf, and +inf when any entry is +inf (as scipy.special.logsumexp does),
             instead of producing nan.
    """
    N = x.shape[0]

    # find the max; subtracting it keeps every exponential in [0, 1]
    m = -np.inf
    for i in range(N):
        m = max(m, x[i])

    # exp(x - m) would be exp(nan) when m is +/-inf, so return the limit directly
    if np.isinf(m):
        return m

    # sum the exponentials
    out = 0.0
    for i in range(N):
        out += np.exp(x[i] - m)

    return m + np.log(out)


@numba.jit(nopython=True, cache=True)
def dlse(a, out):
    """
    Write softmax(a) into ``out`` in place, i.e. the gradient of logsumexp(a) w.r.t. a.

    :param a: 1-D array of log-weights.
    :param out: preallocated 1-D output array (assumed to have at least len(a) entries).
    """
    log_norm = logsumexp(a)
    for idx in range(a.shape[0]):
        out[idx] = np.exp(a[idx] - log_norm)


@numba.jit(nopython=True, cache=True)
def forward_pass(log_pi0,
                 log_Ps,
                 log_likes,
                 alphas):
    """
    HMM forward (filtering) recursion in log space; fills ``alphas`` in place.

    alphas[t, k] = log p(observations 0..t, z_t = k).

    :param log_pi0: shape (K,), log initial state probabilities.
    :param log_Ps: shape (T-1, K, K) for time-varying transitions, or (1, K, K) for a single
                   shared matrix; log_Ps[t, i, j] = log p(z_{t+1} = j | z_t = i).
    :param log_likes: shape (T, K), log-likelihood of each observation under each state.
    :param alphas: preallocated array of shape (T, K); overwritten.
    :return: log marginal likelihood of the sequence, logsumexp(alphas[T-1]).
    """
    T = log_likes.shape[0]  # number of time steps
    K = log_likes.shape[1]  # number of discrete states

    assert log_Ps.shape[0] == T-1 or log_Ps.shape[0] == 1
    assert log_Ps.shape[1] == K
    assert log_Ps.shape[2] == K
    assert alphas.shape[0] == T
    assert alphas.shape[1] == K

    # Check if we have heterogeneous transition matrices.
    # If not, save memory by passing in log_Ps of shape (1, K, K)
    hetero = (log_Ps.shape[0] == T-1)

    # Exponentiate the transition matrices once instead of on every time step.
    Ps = np.exp(log_Ps)

    alphas[0] = log_pi0 + log_likes[0]
    for t in range(T-1):
        # Shift by the max so exp() cannot overflow; the shift is added back after the log.
        m = np.max(alphas[t])
        alphas[t+1] = np.log(np.dot(np.exp(alphas[t] - m), Ps[t * hetero])) + m + log_likes[t+1]

    return logsumexp(alphas[T-1])


@numba.jit(nopython=True, cache=True)
def backward_pass(log_Ps,
                  log_likes,
                  betas):
    """
    HMM backward recursion in log space; fills ``betas`` in place.

    betas[t, k] = log p(observations t+1..T-1 | z_t = k), with betas[T-1] = 0.

    :param log_Ps: shape (T-1, K, K) or (1, K, K);
                   log_Ps[t, i, j] = log p(z_{t+1} = j | z_t = i).
    :param log_likes: shape (T, K), observation log-likelihoods.
    :param betas: preallocated array of shape (T, K); overwritten.
    """
    T = log_likes.shape[0]  # number of time steps
    K = log_likes.shape[1]  # number of discrete states

    assert log_Ps.shape[0] == T-1 or log_Ps.shape[0] == 1
    assert log_Ps.shape[1] == K
    assert log_Ps.shape[2] == K
    assert betas.shape[0] == T
    assert betas.shape[1] == K

    # Check if we have heterogeneous transition matrices.
    # If not, save memory by passing in log_Ps of shape (1, K, K)
    hetero = (log_Ps.shape[0] == T-1)

    # Exponentiate the transition matrices once instead of on every time step.
    Ps = np.exp(log_Ps)

    # Initialize the last output
    betas[T-1] = 0
    for t in range(T-2, -1, -1):
        tmp = log_likes[t+1] + betas[t+1]
        # Shift by the max so exp() cannot overflow; the shift is added back after the log.
        m = np.max(tmp)
        betas[t] = np.log(np.dot(Ps[t * hetero], np.exp(tmp - m))) + m

def hmm_normalizer(log_pi0, log_Ps, ll):
    """
    Log marginal likelihood log p(observations) of an HMM.

    :param log_pi0: shape (K,), log initial state probabilities.
    :param log_Ps: shape (K, K) for a time-homogeneous chain, or (1, K, K) / (T-1, K, K).
                   A 2-D matrix is promoted to shape (1, K, K), as in hmm_expected_states.
    :param ll: shape (T, K), observation log-likelihoods.
    :return: log p(observations) as a float.
    """
    T, K = ll.shape

    # forward_pass indexes a 3-D stack of transition matrices.
    log_Ps = np.asarray(log_Ps)
    if log_Ps.ndim == 2:
        log_Ps = log_Ps[None, :, :]

    alphas = np.zeros((T, K))
    forward_pass(log_pi0, log_Ps, ll, alphas)
    return logsumexp_scipy(alphas[-1])


def hmm_expected_states(log_pi0, log_Ps, ll, filter=False):
    """
    Posterior state probabilities of an HMM, given observation log-likelihoods.

    :param log_pi0: shape (K,), log initial state probabilities.
    :param log_Ps: log transition matrix of shape (K, K) (time-homogeneous; promoted to
                   (1, K, K)), or a stack of shape (T-1, K, K) (or (1, K, K)).
    :param ll: shape (T, K), log-likelihood of each actual observation under each state.
    :param filter: if False (default), return smoothed posteriors p(z_t | n_{0..T-1});
                   if True, return filtered posteriors p(z_t | n_{0..t}), which only use
                   observations up to and including time t.
    :return:
    expected_states: array of shape (T, K); row t holds the posterior probabilities of
                     the K states at time t.
    normalizer: model log-likelihood, i.e. the log-probability of the whole observation
                sequence under the (implicit) model parameters.
    """
    n_steps, n_states = ll.shape

    # The numba kernels expect a 3-D stack of transition matrices.
    if log_Ps.ndim == 2:
        log_Ps = log_Ps[None, :, :]
    assert log_Ps.ndim == 3

    log_alpha = np.zeros((n_steps, n_states))
    forward_pass(log_pi0, log_Ps, ll, log_alpha)
    normalizer = logsumexp(log_alpha[-1])

    # With filter=True the backward messages stay at zero, which yields p(z_t | n_{0..t}).
    log_beta = np.zeros((n_steps, n_states))
    if not filter:
        backward_pass(log_Ps, ll, log_beta)

    # Normalise each row of alpha + beta in log space, then leave log space.
    log_post = log_alpha + log_beta
    log_post = log_post - logsumexp_scipy(log_post, axis=1, keepdims=True)
    return np.exp(log_post), normalizer


@numba.jit(nopython=True, cache=True)
def backward_sample(log_Ps, log_likes, alphas, us, zs):
    """
    Backward-sampling half of forward-filtering/backward-sampling; fills ``zs`` in place.

    Going from t = T-1 down to 0, z_t is drawn from a distribution proportional to
    exp(alphas[t] + log_Ps[t, :, z_{t+1}]) (at t = T-1 just exp(alphas[T-1])).

    :param log_Ps: (T-1, K, K) or (1, K, K) log transition matrices,
                   log_Ps[t, i, j] = log p(z_{t+1} = j | z_t = i).
    :param log_likes: (T, K) observation log-likelihoods; only its shape is used here.
    :param alphas: (T, K) forward messages, as filled by forward_pass.
    :param us: length-T array of uniform random numbers, one per time step
               (_hmm_sample passes npr.rand(T)); used for inverse-CDF sampling.
    :param zs: length-T output array, overwritten with the sampled state indices
               (_hmm_sample passes a float array).
    """
    T = log_likes.shape[0]
    K = log_likes.shape[1]
    assert log_Ps.shape[0] == T-1 or log_Ps.shape[0] == 1
    assert log_Ps.shape[1] == K
    assert log_Ps.shape[2] == K
    assert alphas.shape[0] == T
    assert alphas.shape[1] == K
    assert us.shape[0] == T
    assert zs.shape[0] == T

    # lpzp1: log transition potential into the already-sampled z[t+1]; zero at t = T-1.
    lpzp1 = np.zeros(K)
    lpz = np.zeros(K)  # overwritten on each iteration

    # Trick for handling time-varying transition matrices
    hetero = (log_Ps.shape[0] == T-1)

    for t in range(T-1,-1,-1):
        # unnormalized log p(z[t] = k | z[t+1], observations 0..t); Z is its normalizer
        lpz = lpzp1 + alphas[t]
        Z = logsumexp(lpz)

        # sample by inverse CDF: the first k whose cumulative probability exceeds us[t];
        # default to the last state in case rounding keeps acc below us[t]
        acc = 0
        zs[t] = K-1
        for k in range(K):
            acc += np.exp(lpz[k] - Z)
            if us[t] < acc:
                zs[t] = k
                break

        # set the transition potential
        if t > 0:
            # column of log p(z[t] = zs[t] | z[t-1] = i) for all i, used at step t-1
            # lpzp1 = np.log(Ps[(t-1) * hetero, :, int(zs[t])] + LOG_EPS)
            lpzp1 = log_Ps[(t-1) * hetero, :, int(zs[t])]


@numba.jit(nopython=True, cache=True)
def _hmm_sample(log_pi0, log_Ps, ll):
    """
    Draw one state sequence from p(z_{0..T-1} | observations) by forward filtering and
    backward sampling (numba-compiled; hmm_sample wraps it and casts the result to int).

    :param log_pi0: (K,) log initial state probabilities.
    :param log_Ps: (T-1, K, K) or (1, K, K) log transition matrices; must be 3-D, as
                   required by the shape assertions in forward_pass / backward_sample.
    :param ll: (T, K) observation log-likelihoods.
    :return: float array of shape (T,) holding the sampled state indices.
    """
    T, K = ll.shape

    # Forward pass gets the predicted state at time t given
    # observations up to and including those from time t
    alphas = np.zeros((T, K))
    forward_pass(log_pi0, log_Ps, ll, alphas)

    # Sample backward
    us = npr.rand(T)  # one uniform draw per time step for inverse-CDF sampling
    zs = -1 * np.ones(T)  # -1 placeholder; every entry is overwritten by backward_sample
    backward_sample(log_Ps, ll, alphas, us, zs)
    return zs


def hmm_sample(log_pi0, log_Ps, ll):
    """
    Draw one hidden-state sequence from p(z_{0..T-1} | observations).

    :param log_pi0: shape (K,), log initial state probabilities.
    :param log_Ps: shape (K, K) for a time-homogeneous chain, or (1, K, K) / (T-1, K, K).
                   A 2-D matrix is promoted to shape (1, K, K), as in hmm_expected_states.
    :param ll: shape (T, K), observation log-likelihoods.
    :return: int array of shape (T,) with the sampled state indices.
    """
    # The numba kernels require a 3-D stack of transition matrices.
    log_Ps = np.asarray(log_Ps)
    if log_Ps.ndim == 2:
        log_Ps = log_Ps[None, :, :]
    return _hmm_sample(log_pi0, log_Ps, ll).astype(int)


@numba.jit(nopython=True, cache=True)
def _viterbi(log_pi0, log_Ps, ll):
    """
    Most likely state sequence (Viterbi / max-sum) of an HMM, numba-compiled.

    This is modified from pyhsmm.internals.hmm_states
    by Matthew Johnson.

    :param log_pi0: (K,) log initial state probabilities.
    :param log_Ps: (T-1, K, K) or (1, K, K) log transition matrices,
                   log_Ps[t, i, j] = log p(z_{t+1} = j | z_t = i).
    :param ll: (T, K) observation log-likelihoods.
    :return: float array of shape (T,) with the MAP state indices (viterbi casts to int).
    """
    T, K = ll.shape

    # Check if the transition matrices are stationary or
    # time-varying (hetero)
    hetero = (log_Ps.shape[0] == T-1)
    if not hetero:
        assert log_Ps.shape[0] == 1

    # Pass max-sum messages backward:
    # scores[t, k] = best log-score of observations t+1..T-1 given z_t = k;
    # args[t+1, k] = the maximizing z_{t+1} given z_t = k.
    scores = np.zeros((T, K))
    args = np.zeros((T, K))
    for t in range(T-2,-1,-1):
        # vals[k, j] = log P(k -> j) + scores[t+1, j] + ll[t+1, j] (broadcast over rows k)
        # vals = np.log(Ps[t * hetero] + LOG_EPS) + scores[t+1] + ll[t+1]
        vals = log_Ps[t * hetero] + scores[t+1] + ll[t+1]

        for k in range(K):
            args[t+1, k] = np.argmax(vals[k])
            scores[t, k] = np.max(vals[k])

    # Now maximize forwards: choose z[0], then follow the stored argmax pointers
    z = np.zeros(T)
    # z[0] = (scores[0] + np.log(pi0 + LOG_EPS) + ll[0]).argmax()
    z[0] = (scores[0] + log_pi0 + ll[0]).argmax()
    for t in range(1, T):
        z[t] = args[t, int(z[t-1])]

    return z


def viterbi(log_pi0, log_Ps, ll):
    """
    Find the most likely state sequence.

    :param log_pi0: shape (K,), log initial state probabilities.
    :param log_Ps: shape (K, K) for a time-homogeneous chain, or (1, K, K) / (T-1, K, K).
                   A 2-D matrix is promoted to shape (1, K, K), as in hmm_expected_states.
    :param ll: shape (T, K), observation log-likelihoods.
    :return: int array of shape (T,) with the MAP state indices.
    """
    # _viterbi indexes a 3-D stack of transition matrices.
    log_Ps = np.asarray(log_Ps)
    if log_Ps.ndim == 2:
        log_Ps = log_Ps[None, :, :]
    return _viterbi(log_pi0, log_Ps, ll).astype(int)


def poisson_logpdf(counts, lambdas, mask=None):
    """
    Compute the log probability of a Poisson distribution.
    This will broadcast as long as data and lambdas have the same
    (or at least compatible) leading dimensions.
    Parameters
    ----------
    counts : array_like of shape (Ntrials, T) or (T,)
        array of integer counts (any signed or unsigned integer dtype) for which
        to evaluate the log probability
    lambdas : array_like of shape (K,)
        The rates (mean counts) of the Poisson distribution(s). Zero rates are
        replaced by 1e-8 in a copy; the caller's array is not modified.
    mask : unused; accepted for backward compatibility.
    Returns
    -------
    lps : array_like with shape (T, K), or (Ntrials, T, K) depending on
          the shape of 'counts'.
        Log probabilities under the Poisson distribution(s).
    """
    counts = np.asarray(counts)
    assert np.issubdtype(counts.dtype, np.integer)
    assert counts.ndim == 1 or counts.ndim == 2
    # Trailing axis so counts broadcast against the K rates.
    counts = counts[..., None]

    # Keep log(lambda) finite; np.where builds a new array, leaving the input untouched.
    lambdas = np.asarray(lambdas, dtype=float)
    lambdas = np.where(lambdas == 0, 1e-8, lambdas)

    # `+ 1.0` promotes to float so small integer dtypes (e.g. uint8 255) cannot overflow.
    lls = -gammaln(counts + 1.0) - lambdas + counts * np.log(lambdas)
    return lls


# models.py

In [None]:
from inference import *
import numpy as np
import numpy.random as npr
import scipy.stats as stats
import matplotlib.pyplot as plt

# Additive guard against log(0), used below as np.log(p0 + LOG_EPS). Redefined here with the
# same value as the one `from inference import *` brings into scope.
LOG_EPS = 1e-16


def lo_histogram(x, bins):
    """
    Left-open version of np.histogram with bins (left_edge, right_edge].

    np.histogram treats bins as right-open [left, right) with a closed last bin; here the
    roles are mirrored: every bin is (left, right], and the first bin is also closed on
    the left, i.e. [bins[0], bins[1]].

    :param x: array_like of sample values.
    :param bins: array_like of monotonically increasing bin edges.
    :return: (hist, bin_edges) like np.histogram, where bin_edges equals ``bins``.
    """
    x = np.asarray(x)
    bins = np.asarray(bins)
    # Negating the data and the reversed edges turns left-open bins into right-open ones.
    hist, neg_edges = np.histogram(-x, -bins[::-1])
    # Undo the mirroring for both outputs (previously the edges came back as a 1-tuple
    # of negated, reversed values).
    return hist[::-1], -neg_edges[::-1]


def gamma_isi_point_process(rate, shape):
    """
    Simulate one trial of a sub-Poisson point process (inter-spike intervals
    underdispersed relative to Poisson).

    Spikes are generated in "rescaled time" (cumulative expected count) with gamma ISIs
    of shape ``shape`` and scale 1/shape (mean 1), then binned back into time steps.

    :param rate: time series of mean spike counts per time bin (firing rate * dt);
                 it is flattened by the cumulative sum, so pass a single trial.
    :param shape: shape parameter of the gamma distribution of ISIs.
    :return: vector of spike counts, one per time bin of ``rate``.
    """
    edges = np.hstack((0, np.cumsum(rate)))
    total = edges[-1]

    spike_times = np.zeros(2)
    # Redraw until the simulated spike train reaches past the end of the trial.
    while spike_times[-1] < total:
        n_draws = 2 + int(2 * total)
        isis = npr.gamma(shape, 1 / shape, size=(n_draws,))
        spike_times = np.cumsum(isis)

    counts, _ = lo_histogram(spike_times, edges)
    return counts

def emit(dt, rate, GammaShape=None):
    """
    emit spikes based on rates
    :param dt: time-step duration in seconds; rate * dt is the expected count per time step.
    :param rate: firing rate sequence, r_t, shape (T,) for one trial or (Ntrials, T).
    :param GammaShape: None for Poisson emissions; otherwise the shape parameter of the
                       gamma-distributed inter-spike intervals (sub-Poisson emissions).
    :return: spike train, n_t: integer spike counts with the same shape as ``rate``.
    """
    expected = rate * dt
    if GammaShape is None:
        # poisson spike emissions
        return npr.poisson(expected)

    # sub-poisson/underdispersed spike emissions. gamma_isi_point_process simulates a
    # single trial (its cumulative sum flattens the input), so apply it trial by trial.
    expected = np.asarray(expected)
    if expected.ndim <= 1:
        return gamma_isi_point_process(expected, GammaShape)
    trials = expected.reshape(-1, expected.shape[-1])
    counts = np.array([gamma_isi_point_process(trial, GammaShape) for trial in trials])
    return counts.reshape(expected.shape)


class StepModel():
    """
    Simulator of the Stepping Model of Latimer et al. Science 2015.

    In every trial the firing rate is x0 * Rh before a random jump (step) time and Rh
    from the jump onward; spikes are emitted from that rate (Poisson, or with
    gamma-distributed inter-spike intervals when isi_gamma_shape is set).
    """
    def __init__(self, m=50, r=10, x0=0.2, Rh=50, isi_gamma_shape=None, Rl=None, dt=None):
        """
        Simulator of the Stepping Model of Latimer et al. Science 2015.
        :param m: mean jump time (in # of time-steps). This is the mean parameter of the Negative Binomial distribution
                  of jump (stepping) time
        :param r: parameter r ("# of successes") of the Negative Binomial (NB) distribution of jump (stepping) time
                  (Note that it is more customary to parametrise the NB distribution by its parameter p and r,
                  instead of m and r, where p is so-called "probability of success" (see Wikipedia). The two
                  parametrisations are equivalent and one can go back-and-forth via: m = r (1-p)/p and p = r / (m + r).)
        :param x0: determines the pre-jump firing rate, via  R_pre = x0 * Rh (see below for Rh)
        :param Rh: firing rate of the "up" state (the same as the post-jump state in most of the project tasks)
        :param isi_gamma_shape: shape parameter of the Gamma distribution of inter-spike intervals.
                            see https://en.wikipedia.org/wiki/Gamma_distribution
        :param Rl: firing rate of the post-jump "down" state (rarely used); stored as None when not given
        :param dt: real time duration of time steps in seconds (only used for converting rates to units of inverse time-step)
        """
        self.m = m
        self.r = r
        self.x0 = x0

        # NB "probability of success" equivalent to the (m, r) parametrisation
        self.p = r / (m + r)

        self.Rh = Rh
        # Always defined (None when not given) so that `fixed_params` cannot raise AttributeError
        self.Rl = Rl

        self.isi_gamma_shape = isi_gamma_shape
        self.dt = dt

    @property
    def params(self):
        """Model parameters as the tuple (m, r, x0)."""
        return self.m, self.r, self.x0

    @property
    def fixed_params(self):
        """The tuple (Rh, Rl); Rl is None when it was not provided."""
        return self.Rh, getattr(self, "Rl", None)

    def emit(self, rate):
        """
        emit spikes based on rates
        :param rate: firing rate sequence, r_t, shape (T,) for one trial or (Ntrials, T).
                     Rates are multiplied by self.dt to obtain expected counts per time step.
        :return: spike train, n_t: integer spike counts with the same shape as ``rate``.
        """
        expected = rate * self.dt
        if self.isi_gamma_shape is None:
            # poisson spike emissions
            return npr.poisson(expected)

        # sub-poisson/underdispersed spike emissions. gamma_isi_point_process simulates a
        # single trial (its cumulative sum flattens the input), so apply it trial by trial.
        expected = np.asarray(expected)
        if expected.ndim <= 1:
            return gamma_isi_point_process(expected, self.isi_gamma_shape)
        trials = expected.reshape(-1, expected.shape[-1])
        counts = np.array([gamma_isi_point_process(trial, self.isi_gamma_shape) for trial in trials])
        return counts.reshape(expected.shape)

    def _prepare_trial_timing(self, T, GammaShape):
        """Apply an optional ISI-shape override and set dt = 1/T, so a trial always lasts 1 second."""
        if GammaShape is not None:
            self.isi_gamma_shape = GammaShape
        self.dt = 1 / T

    def _step_rate(self, jump, T):
        """Rate series of length T: x0 * Rh before time step `jump`, Rh from `jump` onward."""
        rate = np.ones(T) * self.x0 * self.Rh  # =R0, the pre-jump rate
        rate[np.arange(T) >= jump] = self.Rh
        return rate

    @staticmethod
    def _sample_jump(p0, matrix_at, absorbing, states):
        """
        Sample a latent Markov chain into ``states`` until it first enters ``absorbing``.

        :param p0: probability vector of the state at t = 0.
        :param matrix_at: callable t -> transition matrix for the step t -> t+1
                          (rows indexed by the current state).
        :param absorbing: index of the post-jump (absorbing) state.
        :param states: int array of length T, overwritten; entries from the jump onward
                       are set to ``absorbing``.
        :return: the jump time (first t with states[t] == absorbing, possibly 0), or None
                 if the chain does not reach ``absorbing`` within the T time steps.
        """
        n_states = len(p0)
        for t in range(states.shape[0]):
            if t == 0:
                states[0] = np.random.choice(n_states, size=None, replace=True, p=p0)
            else:
                # The transition probabilities depend on the current state
                row = matrix_at(t - 1)[states[t - 1]]
                states[t] = np.random.choice(n_states, size=None, replace=True, p=row)
            if states[t] == absorbing:
                states[t:] = absorbing
                return t
        return None

    def simulate(self, Ntrials=1, T=100, get_rate=True, GammaShape = None):
        """
        :param Ntrials: (int) number of trials
        :param T: (int) duration of each trial in number of time-steps.
        :param get_rate: whether or not to return the rate time-series
        :param GammaShape: if given, sets self.isi_gamma_shape (gamma-ISI spike emissions)
        :return:
        spikes: shape = (Ntrial, T); spikes[j] gives the spike train, n_t, in trial j, as
                an array of spike counts in each time-bin (= time step)
        jumps:  shape = (Ntrials,) ; jumps[j] is the jump time (aka step time), tau, in trial j.
        rates:  shape = (Ntrial, T); rates[j] is the rate time-series, r_t, in trial j (returned only if get_rate=True)
        """
        self._prepare_trial_timing(T, GammaShape)

        spikes, jumps, rates = [], [], []
        for _ in range(Ntrials):
            # sample jump time (unit: 1/T s); values >= T mean no jump within the trial
            jump = npr.negative_binomial(self.r, self.p)
            jumps.append(jump)

            rate = self._step_rate(jump, T)
            rates.append(rate)
            spikes.append(self.emit(rate))

        if get_rate:
            return np.array(spikes), np.array(jumps), np.array(rates)
        return np.array(spikes), np.array(jumps)

    def simulate_HMM_inhomo(self, Ntrials=1, T=100, get_rate=True, GammaShape = None):
        """
        Simulate via a time-inhomogeneous 2-state chain (0 = pre-jump, 1 = post-jump) whose
        per-step jump probabilities are the NB(r, p) hazard.
        :param Ntrials: (int) number of trials
        :param T: (int) duration of each trial in number of time-steps.
        :param get_rate: whether or not to return the rate time-series
        :param GammaShape: if given, sets self.isi_gamma_shape (gamma-ISI spike emissions)
        :return:
        spikes: shape = (Ntrial, T); spikes[j] gives the spike train, n_t, in trial j, as
                an array of spike counts in each time-bin (= time step)
        jumps:  shape = (Ntrials,) ; jumps[j] is the jump time (aka step time), tau, in trial j
                (T when the chain did not jump within the trial).
        rates:  shape = (Ntrial, T); rates[j] is the rate time-series, r_t, in trial j (returned only if get_rate=True)
        log_trans_matrices: shape (T-1, 2, 2), log transition matrices.
        log_pi0: shape (2,), log distribution of the state at t=0.
        """
        self._prepare_trial_timing(T, GammaShape)

        ts = np.arange(T)
        # Inhomogeneous markov chain
        PMF_jump = stats.nbinom.pmf(ts, self.r, self.p)
        CMF_jump = stats.nbinom.cdf(ts, self.r, self.p)

        trans_matrices = np.empty((T-1, 2, 2))  # inhomogeneous transition matrix
        for t in range(T-1):
            if CMF_jump[t] == 1:
                # the jump has (numerically) surely happened already
                trans_matrices[t] = np.array([[0, 1], [0, 1]])
            else:
                # P(jump at t+1 | no jump up to t) = pmf[t+1] / (1 - cdf[t])
                survival = 1 - CMF_jump[t]
                trans_matrices[t] = np.array([[(1 - CMF_jump[t+1]) / survival, PMF_jump[t+1] / survival], [0, 1]])
            if np.any(trans_matrices[t] < 0):
                print("invalid probability: negative value!")
                print(trans_matrices[t])
                print(f"t={t},m={self.m}, r={self.r}")

        # distribution of the first state (t=0) and its log
        p0 = np.array([1 - PMF_jump[0], PMF_jump[0]])
        log_pi0 = np.log(p0)
        log_trans_matrices = np.log(trans_matrices)

        states = np.zeros(T, dtype=int)
        spikes, jumps, rates = [], [], []
        for _ in range(Ntrials):
            jump = self._sample_jump(p0, trans_matrices.__getitem__, 1, states)
            if jump is None:
                jump = T  # not jumped during 0:T-1; the rate stays at the pre-jump level
            jumps.append(jump)  # (unit: 1/T s )

            rate = self._step_rate(jump, T)
            rates.append(rate)
            spikes.append(self.emit(rate))

        if get_rate:
            return np.array(spikes), np.array(jumps), np.array(rates), log_trans_matrices, log_pi0
        return np.array(spikes), np.array(jumps), log_trans_matrices, log_pi0

    def simulate_HMM_homo(self, Ntrials=1, T=100, get_rate=True, GammaShape = None):
        """
        Simulate via a time-homogeneous chain on states 0..r (r = absorbing post-jump state).
        :param Ntrials: (int) number of trials
        :param T: (int) duration of each trial in number of time-steps.
        :param get_rate: whether or not to return the rate time-series
        :param GammaShape: if given, sets self.isi_gamma_shape (gamma-ISI spike emissions)
        :return:
        spikes: shape = (Ntrial, T); spikes[j] gives the spike train, n_t, in trial j, as
                an array of spike counts in each time-bin (= time step)
        jumps:  shape = (Ntrials,) ; jumps[j] is the jump time (aka step time), tau, in trial j
                (2T-1 when the chain did not jump within the trial).
        rates:  shape = (Ntrial, T); rates[j] is the rate time-series, r_t, in trial j (returned only if get_rate=True)
        log_trans_matrix: shape (r+1, r+1); log_pi0: shape (r+1,);
        states: latent state path (length T) of the last trial.
        """
        self._prepare_trial_timing(T, GammaShape)
        p = self.r / (self.m + self.r)

        # From state i, one time step moves to j in i..r-1 with prob p**(j-i) * (1-p),
        # or directly to the absorbing state r with prob p**(r-i).
        n_states = self.r + 1
        trans_matrix = np.zeros((n_states, n_states))
        for i in range(n_states):
            for j in range(i, n_states):
                trans_matrix[i, j] = (p**(j-i)) * (1-p) if j < self.r else p**(self.r-i)

        # The state at t=0 is one transition away from state 0
        pi = np.zeros(n_states)
        pi[0] = 1
        p0 = np.matmul(pi, trans_matrix)

        log_pi0 = np.log(p0 + LOG_EPS)
        log_trans_matrix = np.log(trans_matrix)

        states = np.zeros(T, dtype=int)
        spikes, jumps, rates = [], [], []
        for _ in range(Ntrials):
            jump = self._sample_jump(p0, lambda _t: trans_matrix, self.r, states)
            if jump is None:
                # NOTE(review): non-jumping trials are recorded as 2T-1 here (T in
                # simulate_HMM_inhomo); value kept for backward compatibility.
                jumps.append(2 * T - 1)
                jump = T  # the rate stays at the pre-jump level
            else:
                jumps.append(jump)

            rate = self._step_rate(jump, T)
            rates.append(rate)
            spikes.append(self.emit(rate))

        if get_rate:
            return np.array(spikes), np.array(jumps), np.array(rates), log_trans_matrix, log_pi0, states
        return np.array(spikes), np.array(jumps), log_trans_matrix, log_pi0, states

    def simulate_HMM_2states(self, Ntrials=1, T=100, get_rate=True, GammaShape = None):
        """
        Simulate via a homogeneous 2-state chain with constant per-step jump probability 1/m
        (geometric jump time).
        :param Ntrials: (int) number of trials
        :param T: (int) duration of each trial in number of time-steps.
        :param get_rate: whether or not to return the rate time-series
        :param GammaShape: if given, sets self.isi_gamma_shape (gamma-ISI spike emissions)
        :return:
        spikes: shape = (Ntrial, T); spikes[j] gives the spike train, n_t, in trial j, as
                an array of spike counts in each time-bin (= time step)
        jumps:  shape = (Ntrials,) ; jumps[j] is the jump time (aka step time), tau, in trial j
                (2T-1 when the chain did not jump within the trial).
        rates:  shape = (Ntrial, T); rates[j] is the rate time-series, r_t, in trial j (returned only if get_rate=True)
        log_trans_matrix: shape (2, 2); log_pi0: shape (2,);
        states: latent state path (length T) of the last trial.
        """
        self._prepare_trial_timing(T, GammaShape)

        trans_matrix = np.array([[1 - 1/self.m, 1/self.m], [0, 1]])

        # The state at t=0 is one transition away from state 0
        pi = [1, 0]
        p0 = np.matmul(pi, trans_matrix)

        log_pi0 = np.log(p0 + LOG_EPS)
        log_trans_matrix = np.log(trans_matrix)

        states = np.zeros(T, dtype=int)
        spikes, jumps, rates = [], [], []
        for _ in range(Ntrials):
            jump = self._sample_jump(p0, lambda _t: trans_matrix, 1, states)
            if jump is None:
                # NOTE(review): non-jumping trials are recorded as 2T-1 here (T in
                # simulate_HMM_inhomo); value kept for backward compatibility.
                jumps.append(2 * T - 1)
                jump = T  # the rate stays at the pre-jump level
            else:
                jumps.append(jump)

            rate = self._step_rate(jump, T)
            rates.append(rate)
            spikes.append(self.emit(rate))

        if get_rate:
            return np.array(spikes), np.array(jumps), np.array(rates), log_trans_matrix, log_pi0, states
        return np.array(spikes), np.array(jumps), log_trans_matrix, log_pi0, states
         
        
        
        
class RampModel():
    """
    Simulator of the Ramping Model (aka Drift-Diffusion Model) of Latimer et al., Science (2015).
    """
    def __init__(self, beta=0.5, sigma=0.2, x0=.2, Rh=50, isi_gamma_shape=None, Rl=None, dt=None):
        """
        Simulator of the Ramping Model of Latimer et al. Science 2015.
        :param beta: drift rate of the drift-diffusion process
        :param sigma: diffusion strength of the drift-diffusion process.
        :param x0: average initial value of latent variable x[0]
        :param Rh: the maximal firing rate obtained when x_t reaches 1 (corresponding to the same as the post-step
                   state in most of the project tasks)
        :param isi_gamma_shape: shape parameter of the Gamma distribution of inter-spike intervals.
                            see https://en.wikipedia.org/wiki/Gamma_distribution
        :param Rl: Not implemented. Ignore.
        :param dt: real time duration of time steps in seconds (only used for converting rates to units of inverse time-step)
        """
        self.beta = beta
        self.sigma = sigma
        self.x0 = x0

        self.Rh = Rh
        # Always define Rl (None when not given) so that `fixed_params` never raises AttributeError.
        self.Rl = Rl

        self.isi_gamma_shape = isi_gamma_shape
        self.dt = dt

    @property
    def params(self):
        """(beta, sigma, x0): the parameters of the drift-diffusion process."""
        return self.beta, self.sigma, self.x0

    @property
    def fixed_params(self):
        """(Rh, Rl): the rate parameters that are not fitted (Rl is None unless given)."""
        return self.Rh, self.Rl

    def f_io(self, xs, b=None):
        """
        Map latent values to firing rates.
        :param xs: latent variable values (array-like)
        :param b: if None, use the rectified-linear map Rh * max(0, x);
                  otherwise use the softplus map Rh * b * log(1 + exp(x / b)).
        :return: firing rates, same shape as xs
        """
        if b is None:
            return self.Rh * np.maximum(0, xs)
        # np.logaddexp(0, z) == log(1 + exp(z)), but it does not overflow to inf for large z.
        return self.Rh * b * np.logaddexp(0, xs / b)

    def emit(self, rate):
        """
        emit spikes based on rates
        :param rate: firing rate sequence, r_t, possibly in many trials. Shape: (Ntrials, T)
        :return: spike train, n_t, as an array of shape (Ntrials, T) containing integer spike counts in different
                 trials and time bins.
        """
        if self.isi_gamma_shape is None:
            # poisson spike emissions
            y = npr.poisson(rate * self.dt)
        else:
            # sub-poisson/underdispersed spike emissions
            y = gamma_isi_point_process(rate * self.dt, self.isi_gamma_shape)

        return y

    def simulate(self, Ntrials=1, T=100, get_rate=True, GammaShape = None):
        """
        :param Ntrials: (int) number of trials
        :param T: (int) duration of each trial in number of time-steps.
        :param get_rate: whether or not to return the rate time-series
        :param GammaShape: if not None, overrides self.isi_gamma_shape before simulating
        :return:
        spikes: shape = (Ntrial, T); spikes[j] gives the spike train, n_t, in trial j, as
                an array of spike counts in each time-bin (= time step)
        xs:     shape = (Ntrial, T); xs[j] is the latent variable time-series x_t in trial j
        rates:  shape = (Ntrial, T); rates[j] is the rate time-series, r_t, in trial j (returned only if get_rate=True)
        """
        if GammaShape is not None:
            self.isi_gamma_shape = GammaShape
        # set dt (time-step duration in seconds) such that trial duration is always 1 second, regardless of T.
        dt = 1 / T
        self.dt = dt

        # Simulate all trials in parallel by directly integrating the drift-diffusion updates
        # x[t+1] = x[t] + beta dt + sigma sqrt(dt) randn, with x[0] = x0 + sigma sqrt(dt) randn,
        # giving xs of shape (Ntrials, T).
        ts = np.arange(T)
        xs = self.x0 + self.beta * dt * ts + self.sigma * np.sqrt(dt) * np.cumsum(npr.randn(Ntrials, T), axis=1)
        # In each trial set x to 1 after the 1st passage through 1; padding xs with a column of 1s
        # guarantees a passage is found (possibly at index T, i.e. never within the trial).
        taus = np.argmax(np.hstack((xs, np.ones((xs.shape[0], 1)))) >= 1., axis=-1)
        xs = np.where(ts[None, :] >= taus[:, None], 1., xs)

        rates = self.f_io(xs)  # shape = (Ntrials, T)

        spikes = np.array([self.emit(rate) for rate in rates])  # shape = (Ntrial, T)

        if get_rate:
            return spikes, xs, rates
        else:
            return spikes, xs

    @staticmethod
    def _binned_normal_logprobs(grid, mu, std):
        """
        Log-probability that a N(mu, std**2) variable falls in the unit-width bin centred on each grid value.

        Bins run along the last axis of `grid`; the first and last bins of that axis are widened to
        absorb the lower and upper tails respectively.
        """
        log_p = np.log(stats.norm.cdf(grid + 0.5, mu, std) - stats.norm.cdf(grid - 0.5, mu, std))
        log_p[..., 0] = stats.norm.logcdf(grid[..., 0] + 0.5, mu, std)
        log_p[..., -1] = stats.norm.logsf(grid[..., -1] - 0.5, mu, std)  # sf = 1 - cdf
        return log_p

    def _hmm_chain(self, st, dt, K):
        """
        Build the discretized drift-diffusion chain on the integer lattice `st` (state s <-> x = s / (K-1)).

        :return: (normalized_log_trans_matrix, trans_matrix, pi, normalized_log_pi0), where pi is the
                 initial-state distribution and normalized_log_pi0 = log(pi @ trans_matrix).
        """
        scale = K - 1
        std = self.sigma * np.sqrt(dt) * scale + LOG_EPS  # To avoid error when sigma = 0

        # Entry [i, j] of the transition matrix is the probability of moving st[j] - st[i] lattice steps.
        s_grid = st - st.reshape(-1, 1)
        log_trans_matrix = self._binned_normal_logprobs(s_grid, self.beta * dt * scale, std)
        # The top state (x = 1) is absorbing.
        log_trans_matrix[-1] = -np.inf
        log_trans_matrix[-1, -1] = 0

        # Normalize each row in log space.
        log_row_sums = np.zeros(log_trans_matrix.shape[0])
        for i in range(log_trans_matrix.shape[0]):
            log_row_sums[i] = logsumexp(log_trans_matrix[i, :])
        normalized_log_trans_matrix = log_trans_matrix - log_row_sums[:, np.newaxis]
        trans_matrix = np.exp(normalized_log_trans_matrix)

        # Initial distribution, centred on x0 with the same spread as one diffusion step.
        log_pi = self._binned_normal_logprobs(st, self.x0 * scale, std)
        normalized_log_pi = log_pi - logsumexp(log_pi)
        pi = np.exp(normalized_log_pi)

        normalized_log_pi0 = np.log(np.matmul(pi, trans_matrix))
        return normalized_log_trans_matrix, trans_matrix, pi, normalized_log_pi0

    @staticmethod
    def _sample_states(st, pi, trans_matrix, Ntrials, T):
        """Sample an (Ntrials, T) array of states (values from `st`) from the chain (pi, trans_matrix)."""
        offset = st[0]  # the row of trans_matrix for state s is s - st[0]
        states = np.zeros((Ntrials, T), dtype=int)
        for n in range(Ntrials):
            states[n, 0] = np.random.choice(st, p=pi)
            for t in range(1, T):
                # The transition probabilities depend on the current state.
                states[n, t] = np.random.choice(st, p=trans_matrix[states[n, t - 1] - offset])
        return states

    def simulate_HMM(self, Ntrials=1, T=100, K=100, get_rate=True, GammaShape = None):
        """
        Simulate a K-state discretization of the ramp model (states 0..K-1, x = state / (K-1)).
        :param Ntrials: (int) number of trials
        :param T: (int) duration of each trial in number of time-steps.
        :param K: (int) number of states of the HMM
        :param get_rate: whether or not to return the rate time-series
        :param GammaShape: if not None, overrides self.isi_gamma_shape before simulating
        :return:
        spikes: shape = (Ntrial, T); spikes[j] gives the spike train, n_t, in trial j, as
                an array of spike counts in each time-bin (= time step)
        xs:     shape = (Ntrial, T); xs[j] is the latent variable time-series x_t in trial j
        rates:  shape = (Ntrial, T); rates[j] is the rate time-series, r_t, in trial j (returned only if get_rate=True)
        normalized_log_trans_matrix: (K, K) row-normalized log transition matrix
        normalized_log_pi0: (K,) log of pi @ trans_matrix
        """
        if GammaShape is not None:
            self.isi_gamma_shape = GammaShape
        # set dt (time-step duration in seconds) such that trial duration is always 1 second, regardless of T.
        dt = 1 / T
        self.dt = dt

        st = np.arange(K)  # states
        normalized_log_trans_matrix, trans_matrix, pi, normalized_log_pi0 = self._hmm_chain(st, dt, K)

        # NOTE(review): states[:, 0] is drawn from pi, while the returned normalized_log_pi0 is
        # log(pi @ trans_matrix) -- confirm this one-step offset is intended for inference.
        states = self._sample_states(st, pi, trans_matrix, Ntrials, T)
        xs = states / (K-1)

        rates = self.f_io(xs)  # shape = (Ntrials, T)

        spikes = np.array([self.emit(rate) for rate in rates])  # shape = (Ntrial, T)

        if get_rate:
            return spikes, xs, rates, normalized_log_trans_matrix, normalized_log_pi0
        else:
            return spikes, xs, normalized_log_trans_matrix, normalized_log_pi0

    def simulate_HMM_ns(self, Ntrials=1, num_ns=30, T=100, K=100, get_rate=True, GammaShape = None):
        """
        Like simulate_HMM, but the lattice also has num_ns * K negative states (states -num_ns*K .. K-1).
        :param Ntrials: (int) number of trials
        :param num_ns: (int) number of negative states, as a multiple of K
        :param T: (int) duration of each trial in number of time-steps.
        :param K: (int) number of non-negative states of the HMM
        :param get_rate: whether or not to return the rate time-series
        :param GammaShape: if not None, overrides self.isi_gamma_shape before simulating
        :return:
        spikes: shape = (Ntrial, T); spikes[j] gives the spike train, n_t, in trial j, as
                an array of spike counts in each time-bin (= time step)
        xs:     shape = (Ntrial, T); xs[j] is the latent variable time-series x_t in trial j
        rates:  shape = (Ntrial, T); rates[j] is the rate time-series, r_t, in trial j (returned only if get_rate=True)
        normalized_log_trans_matrix, normalized_log_pi0: as in simulate_HMM, over all (num_ns + 1) * K states
        states: shape = (Ntrial, T); the sampled integer states
        """
        if GammaShape is not None:
            self.isi_gamma_shape = GammaShape
        # set dt (time-step duration in seconds) such that trial duration is always 1 second, regardless of T.
        dt = 1 / T
        self.dt = dt

        st = np.arange(-num_ns*K, K)  # states
        normalized_log_trans_matrix, trans_matrix, pi, normalized_log_pi0 = self._hmm_chain(st, dt, K)

        states = self._sample_states(st, pi, trans_matrix, Ntrials, T)
        xs = states / (K-1)

        rates = self.f_io(xs)  # shape = (Ntrials, T)

        spikes = np.array([self.emit(rate) for rate in rates])  # shape = (Ntrial, T)

        if get_rate:
            return spikes, xs, rates, normalized_log_trans_matrix, normalized_log_pi0, states
        else:
            return spikes, xs, normalized_log_trans_matrix, normalized_log_pi0, states
        


# models2.py

In [None]:
import numpy as np
import numpy.random as npr
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter
from models import *
from inference import *
import time
import warnings
warnings.filterwarnings("ignore", "divide by zero encountered in log")

# def MLR_classifier(data_points):
# #     :param data_points: 2*M data points, M data points for each model, where each data point is a (N by T) matrix
    
    
#     M=26
#     K=100
#     logMLR = MLR_calculator(counts_matrix, M = M, K = K, Print = False)
    
#     if


def MLR_classifier(data_points, M=10, K=10):
    """
    Classify each dataset as generated by the step model (0) or the ramp model (1) from its log MLR.

    :param data_points: array indexed as data_points[model, dataset], each entry an (N by T) spike-count
                        matrix; model index 0 holds step-model datasets and index 1 ramp-model datasets
    :param M: number of grid points per parameter axis, passed to MLR_calculator
    :param K: number of ramp HMM states, passed to MLR_calculator
    :return: (predictions, logMLRs), each of shape (data_points.shape[0], data_points.shape[1]);
             predictions are 1 (ramp) where logMLR > 0 and 0 (step) otherwise
    """
    logMLRs = np.empty((data_points.shape[0], data_points.shape[1]))  # 2 x M

    for ii in (0, 1):
        # ii = 0 -> STEP spike trains
        # ii = 1 -> RAMP spike trains
        for jj in range(data_points.shape[1]):
            spike_trains = data_points[ii, jj]  # (N by T) spike train matrix
            counts_matrix = generate_psth(spike_trains, return_counts=True)
            logMLRs[ii, jj], _, _ = MLR_calculator(counts_matrix, M=M, K=K, Print=False)

    # log MLR = ramp MLL - step MLL, so a positive value favours the ramp model.
    predictions = np.where(logMLRs > 0, 1, 0)
    return predictions, logMLRs


def MLR_calculator(counts_matrix, M = 26, K = 100, Print = False):
    """
    Compute the log marginal-likelihood ratio (ramp vs step) for one dataset of spike counts.

    Each model's marginal log-likelihood is the log-sum-exp over an M x M x M parameter grid of the
    HMM log-likelihood of all trials, under a uniform prior 1 / M**3 over the grid. NaN
    log-likelihoods are treated as -inf.
      - Ramp grid: beta in [0, 4], log(sigma) in [log 0.04, log 4], x0 in [0, 1]; K discrete states.
      - Step grid: m in [0, 100], r in [1, 10], x0 in [0, 1].

    :param counts_matrix: (N x T) spike counts, N trials of T time steps; a trial is taken to last
                          1 s, so dt = 1 / T (as in the simulators)
    :param M: number of grid points per parameter axis
    :param K: number of discrete latent states of the ramp HMM
    :param Print: if True, print the time spent on each model
    :return: (logMLR, ramp_MLL, step_MLL) with logMLR = ramp_MLL - step_MLL
    """
    start_time = time.time()

    N = counts_matrix.shape[0]
    # Derive the trial length from the data so dt and the HMM sizes match how it was generated.
    T = counts_matrix.shape[1]

    ## Parameter grids ##
    values_r = np.linspace(1, 10, M)
    values_m = np.linspace(0, 100, M)
    values_logs = np.linspace(np.log(0.04), np.log(4), M)  # -3.22 - 1.386
    values_b = np.linspace(0, 4, M)
    values_x0 = np.linspace(0, 1, M)

    dt = 1/T
    Rh = 50
    st = np.arange(K)  # ramp HMM states
    xt = st/(K-1)      # latent value of each ramp state

    ## Ramp MLL ##
    log_prior = np.log(1/M**3)
    model_ll = np.zeros((M, M, M))

    # The emission log-likelihoods depend only on the data and the per-state rates, not on
    # (beta, sigma, x0), so compute them once instead of M**3 times.
    ramp_lls = poisson_logpdf(counts=counts_matrix, lambdas=xt*Rh*dt, mask=None)  # N x T x K

    for b_idx, b in enumerate(values_b):
        for s_idx, logs in enumerate(values_logs):
            for x_idx, x0 in enumerate(values_x0):
                ramp = RampModel(beta=b, sigma=np.exp(logs), x0=x0, Rh=Rh)
                [_, _, _, normalized_log_trans_matrix, normalized_log_pi0] = ramp.simulate_HMM(Ntrials=0,
                                                                                               T=T,
                                                                                               K=K,
                                                                                               get_rate=True)
                normalized_log_trans_matrix = normalized_log_trans_matrix[np.newaxis, :, :]  # (1,K,K) for homogeneous MC

                # Model log-likelihood for N trials:
                for n in range(N):
                    model_ll[b_idx, s_idx, x_idx] += hmm_normalizer(log_pi0 = normalized_log_pi0,
                                                                    log_Ps = normalized_log_trans_matrix,
                                                                    ll = ramp_lls[n])

    # convert all nan to -inf
    model_ll = np.nan_to_num(model_ll, nan=-np.inf)

    # Model log-posterior (unnormalised); its log-sum-exp is the marginal log-likelihood.
    unnormalised_log_poste = model_ll + log_prior
    ramp_MLL = logsumexp_scipy(unnormalised_log_poste)

    if Print:
        end_time = time.time()
        elapsed_time = end_time - start_time
        print("Ramp marginal log-likelihood inferred：", elapsed_time, "s")

    start_time = time.time()

    ## Step MLL ##
    log_prior = np.log(1/M**3)
    model_ll = np.zeros((M, M, M))
    for m_idx, m in enumerate(values_m):
        for r_idx, r in enumerate(values_r):
            for x_idx, x0 in enumerate(values_x0):
                step = StepModel(r=r, m=m, x0=x0, Rh=Rh)
                [_, _, _, normalized_log_trans_matrix, normalized_log_pi0] = step.simulate_HMM_inhomo(Ntrials=0,
                                                                                                      T=T,
                                                                                                      get_rate=True)

                # Two states: pre-step rate x0*Rh and post-step rate Rh (per time step).
                rt = np.array([x0, 1]) * Rh * dt
                lls = poisson_logpdf(counts=counts_matrix, lambdas= rt, mask=None)  # N x T x 2

                if np.isnan(lls).any():
                    print(lls)

                # Model log-likelihood for N trials:
                for n in range(N):
                    model_ll[m_idx, r_idx, x_idx] += hmm_normalizer(log_pi0 = normalized_log_pi0,
                                                                    log_Ps = normalized_log_trans_matrix,
                                                                    ll = lls[n])
                if np.isnan(model_ll[m_idx, r_idx, x_idx]):
                    print(normalized_log_trans_matrix)
                    print(lls)

    # convert all nan to -inf
    model_ll = np.nan_to_num(model_ll, nan=-np.inf)

    # Model log-posterior (unnormalised); its log-sum-exp is the marginal log-likelihood.
    unnormalised_log_poste = model_ll + log_prior
    step_MLL = logsumexp_scipy(unnormalised_log_poste)

    if Print:
        end_time = time.time()
        elapsed_time = end_time - start_time
        print("Step marginal log-likelihood inferred：", elapsed_time, "s")

    logMLR = ramp_MLL - step_MLL
    return logMLR, ramp_MLL, step_MLL


def classifier_tester(M=100, N=400, T=100, classifier="mlr", thresholds=None):
    """
    Generate 2*M random test datasets and return the accuracy of the chosen classifier on them.

    :param M: number of datasets per model
    :param N: number of trials per dataset
    :param T: number of time steps per trial
    :param classifier: "mlr" (MLR_classifier) or "var" (var_classifier)
    :param thresholds: variance threshold, required when classifier == "var"
    :return: fraction of the 2*M datasets classified correctly (step -> 0, ramp -> 1)
    :raises ValueError: for an unknown classifier, or "var" without thresholds
    """
    # Validate before the (slow) data generation so a typo fails immediately.
    if classifier not in ("var", "mlr"):
        raise ValueError(f"unknown classifier {classifier!r}; expected 'var' or 'mlr'")
    if classifier == "var" and thresholds is None:
        raise ValueError("thresholds must be given when classifier == 'var'")

    start_time = time.time()

    data_points = generate_test_spike_trains(M=M, N=N, T=T, model="original")
    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f"2x{M} datasets generated (NxT = {N}x{T})：", elapsed_time, "s")
    start_time = time.time()

    if classifier == "var":
        predictions, _, _ = var_classifier(data_points, thresholds)
    else:
        predictions, _ = MLR_classifier(data_points)
    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f"Pridictions generated：", elapsed_time, "s")

    # Row 0 holds step datasets (correct label 0), row 1 ramp datasets (correct label 1).
    accuracy = (np.sum(1-predictions[0]) + np.sum(predictions[1]))/(2*M)
    return accuracy

def generate_test_spike_trains(num_grid=10, M=20, N=400, T=100, rmin=1, rmax=100, bmin=0, bmax=4, logsmin=0.04, logsmax=4, mmin=0, mmax=100, xmin=0, xmax=1, model="original", GammaShape=None):
    """
    Generate M datasets for each of the step and ramp models, each with freshly drawn random parameters.

    For every dataset: m ~ U(mmin, mmax), r ~ U(rmin, rmax), beta ~ U(bmin, bmax),
    sigma = exp(U(logsmin, logsmax)), and independent x0 ~ U(xmin, xmax) for each model (Rh = 50).

    :param num_grid: unused; kept for backward compatibility
    :param M: number of datasets for each model
    :param N: number of trials per dataset
    :param T: duration of each trial in number of time-steps
    :param model: "original" to use the models' simulate(), "hmm" to use their HMM simulators
    :param GammaShape: ISI gamma shape, passed to both models' simulators
    :return: integer array of shape (2, M, N, T); [0] holds step-model data, [1] ramp-model data
    :raises ValueError: if model is neither "original" nor "hmm"
    """
    if model not in ("original", "hmm"):
        raise ValueError(f"unknown model {model!r}; expected 'original' or 'hmm'")

    # NOTE(review): logsmin/logsmax are exponentiated, so the defaults give sigma in
    # [exp(0.04), exp(4)] ~ [1.04, 54.6], whereas MLR_calculator's grid spans sigma in [0.04, 4]
    # (log-bounds log(0.04), log(4)). Confirm whether the defaults were meant to be log-values.

    data_points = np.empty((2, M, N, T))
    for MM in range(M):
        # Draw random model parameters (draw order kept stable for reproducibility under a seed).
        m = npr.uniform(mmin, mmax)
        r = npr.uniform(rmin, rmax)
        b = npr.uniform(bmin, bmax)
        s = np.exp(npr.uniform(logsmin, logsmax))
        x0_ramp = npr.uniform(xmin, xmax)
        x0_step = npr.uniform(xmin, xmax)

        step_model = StepModel(m=m, r=r, x0=x0_step, Rh=50)
        ramp_model = RampModel(beta=b, sigma=s, x0=x0_ramp, Rh=50)

        if model == "original":
            step_spikes, _, _ = step_model.simulate(Ntrials=N, T=T, GammaShape = GammaShape)
            ramp_spikes, _, _ = ramp_model.simulate(Ntrials=N, T=T, GammaShape = GammaShape)
        else:  # model == "hmm"
            step_spikes, _, _, _, _ = step_model.simulate_HMM_inhomo(Ntrials=N, T=T, GammaShape = GammaShape)
            ramp_spikes, _, _, _, _ = ramp_model.simulate_HMM(Ntrials=N, T=T, K=100, GammaShape = GammaShape)

        data_points[0, MM] = step_spikes
        data_points[1, MM] = ramp_spikes

    # Convert data_points to integer type to save memory
    return data_points.astype(int)

def generate_spike_trains(M=20, N=400, T=100, m=50, r=10, sigma=0.2, beta=0.5):
    """
    Simulate M datasets from each of one fixed step model and one fixed ramp model.

    The step model uses x0=0.2 and Rh=50; the ramp model uses RampModel's defaults for the
    remaining parameters.

    :param M: number of datasets per model
    :param N: number of trials per dataset
    :param T: duration of each trial in number of time-steps
    :param m: mean jump time (in # of time-steps) for StepModel
    :param r: parameter r ("# of successes") of the Negative Binomial (NB) distribution of jump (stepping) time for StepModel
    :param sigma: diffusion strength of the drift-diffusion process for RampModel
    :param beta: drift rate of the drift-diffusion process for RampModel
    :return: integer array of shape (2, M, N, T); [0] holds step-model data, [1] ramp-model data
    """
    step_model = StepModel(m=m, r=r, x0=0.2, Rh=50)
    ramp_model = RampModel(beta=beta, sigma=sigma)

    data = np.empty((2, M, N, T))
    for idx in range(M):
        # Step then ramp within each iteration, so the RNG stream is consumed in the same order.
        step_counts, _, _ = step_model.simulate(Ntrials=N, T=T)
        ramp_counts, _, _ = ramp_model.simulate(Ntrials=N, T=T)
        data[0, idx] = step_counts
        data[1, idx] = ramp_counts

    # Integer storage is smaller than the float buffer used while filling.
    return data.astype(int)


def generate_raster_and_timestamps(spike_trains, plot=False):
    """
    Convert per-time-step spike counts into spike timestamps, optionally drawing a raster plot.

    Each trial is treated as lasting 1000 ms split into T = len(spike_trains[0]) equal steps; a
    count c at step i contributes c copies of the timestamp i * 1000 / T.

    :param spike_trains: (N by T) matrix of integer spike counts
    :param plot: whether to draw the raster plot
    :return: list of N lists of spike timestamps in milliseconds
    """
    n_steps = len(spike_trains[0])
    spike_trains_timestamp = [
        [step * 1e3 / n_steps for step, count in enumerate(train) for _ in range(count)]
        for train in spike_trains
    ]

    if plot:
        fig, ax = plt.subplots()
        fig.suptitle("Spike Raster Plot")
        # One colour per trial.
        palette = ['C{}'.format(k) for k in range(len(spike_trains))]
        ax.eventplot(spike_trains_timestamp, colors=palette, linelengths=0.2)
        ax.yaxis.set_tick_params(labelleft=False)
        ax.set_xlabel("time from motion onset (ms)")
        ax.set_ylabel("spike trains")
        plt.show()

    return spike_trains_timestamp


def generate_psth(spike_trains, bin_size=20, bin_size_2=50, plot=False, return_counts=False):
    """
    Generate a Peri-Stimulus Time Histogram (PSTH) from a matrix of spike counts.

    Each trial is treated as lasting 1000 ms, split into T equal time steps.

    :param spike_trains: (N by T) matrix of spike counts (N trials, T time steps)
    :param bin_size: bin size for the PSTH (in milliseconds)
    :param bin_size_2: a larger bin size used for the across-trial variance and Fano factor (in milliseconds)
    :param plot: whether to plot the PSTH, variance and Fano factor
    :param return_counts: if True, return only the (N by T) integer matrix of spike counts per trial and time step
    :return: the counts matrix if return_counts, otherwise
             (averaged PSTH in spikes/s, Gaussian-smoothed PSTH, across-trial variance, Fano factor)
    """
    spike_trains = np.asarray(spike_trains)
    N = spike_trains.shape[0]  # number of trials

    if return_counts:
        # With one bin per time step, the per-trial counts are the entries of spike_trains themselves.
        # Re-binning timestamps with np.arange(0, 1e3 + 1000/T, 1000/T) was fragile: with a
        # non-integer step the number and placement of edges can be off by rounding. Negative
        # entries produce no timestamps, hence zero counts.
        return np.maximum(spike_trains, 0).astype(int)

    # Spike timestamps (ms), one list per trial, and all of them pooled.
    spike_trains_timestamp = generate_raster_and_timestamps(spike_trains)
    all_spike_times = np.concatenate(spike_trains_timestamp)

    # PSTH in bins of bin_size ms, converted to spikes per second per trial.
    bin_edges = np.arange(0, 1e3 + bin_size, bin_size)
    psth, _ = np.histogram(all_spike_times, bins=bin_edges)
    averaged_psth = (psth / bin_size * 1e3) / N

    # Apply Gaussian smoothing
    sigma = 1.5  # Standard deviation of the Gaussian filter (in bins)
    gaussian_smoothed_psth = gaussian_filter(averaged_psth, sigma)

    # PSTH for the larger bins, in spikes per trial.
    bin_edges_2 = np.arange(0, 1e3 + bin_size_2, bin_size_2)
    psth_2, _ = np.histogram(all_spike_times, bins=bin_edges_2)
    averaged_psth_2 = psth_2 / N

    # Per-trial histograms (N x time_bins), then the variance across trials for each bin.
    psth_matrix = np.zeros((N, len(averaged_psth_2)))
    for ii, trial_times in enumerate(spike_trains_timestamp):
        psth_matrix[ii], _ = np.histogram(np.array(trial_times), bins=bin_edges_2)
    var_s = np.var(psth_matrix, axis=0)

    ## Calculate Fano Factor ##
    fano_factors = var_s / averaged_psth_2

    if plot:
        fig, (ax1, ax2, ax3) = plt.subplots(3)
        fig.suptitle("PSTH Diagram")

        # Plot the PSTH
        ax1.plot(bin_edges[:-1], averaged_psth,  label='Original')
        ax1.plot(bin_edges[:-1], gaussian_smoothed_psth,  label='Smoothed')
        ax2.plot(bin_edges_2[:-1], var_s,  label='Variance')
        ax3.plot(bin_edges_2[:-1], fano_factors,  label='Fano factor')

        ax1.set_ylabel("spike rate (sp/s)")
        ax2.set_ylabel("Variance")
        ax3.set_ylabel("Fano factor")
        ax3.set_xlabel("time from motion onset (ms)")
        ax1.legend()
        ax2.legend()
        plt.show()

    return averaged_psth, gaussian_smoothed_psth, var_s, fano_factors




def var_classifier(data_points, thresholds):
    """
    Classify each dataset as step-generated (0) or ramp-generated (1) from the roughness of its PSTH.

    For every dataset the Gaussian-smoothed PSTH (20 ms bins) is differentiated with np.gradient;
    if the variance of that gradient exceeds `thresholds` the dataset is labelled 0 (step),
    otherwise 1 (ramp).

    :param data_points: array indexed as data_points[model, dataset], each entry an (N by T) spike-count
                        matrix; model index 0 holds step-model datasets and index 1 ramp-model datasets
    :param thresholds: variance threshold compared against the gradient variance
    :return: (predictions, var_s, fano_factors), each of shape (data_points.shape[0], data_points.shape[1]);
             fano_factors is the gradient variance divided by the mean gradient
    """
    grid_shape = (data_points.shape[0], data_points.shape[1])
    predictions = np.empty(grid_shape)
    var_s = np.empty(grid_shape)
    fano_factors = np.empty(grid_shape)

    for model_idx in (0, 1):
        for set_idx, spike_trains in enumerate(data_points[model_idx]):
            _, smoothed_psth, _, _ = generate_psth(spike_trains, bin_size=20, bin_size_2=50)

            slope = np.gradient(smoothed_psth)
            slope_var = np.var(slope)

            var_s[model_idx, set_idx] = slope_var
            fano_factors[model_idx, set_idx] = slope_var / np.mean(slope)
            # A jumpy (high-variance) slope suggests a step; a smooth one suggests a ramp.
            predictions[model_idx, set_idx] = 0 if slope_var > thresholds else 1

    return predictions, var_s, fano_factors

def higher_order_classifier(data_points, thresholds):
    """
    Classify spike trains as step model (0) or ramp model (1) by thresholding
    the variance of the PSTH gradient.

    :param data_points: array indexed as ``data_points[ii, jj]``, where each entry
        is an (N by T) spike-train matrix. By the convention used in this file,
        row ``ii = 0`` holds STEP spike trains and row ``ii = 1`` holds RAMP spike
        trains (TODO confirm with callers).
    :param thresholds: scalar variance threshold. A gradient variance strictly
        greater than this is classified as step (0); otherwise ramp (1).
    :return: tuple ``(predictions, var_s, fano_factors)``. Each element is a float
        array of shape ``(data_points.shape[0], data_points.shape[1])``:
        - predictions: 0.0 (step model) or 1.0 (ramp model) per data point
        - var_s: variance of the PSTH gradient per data point
        - fano_factors: that variance divided by the mean of the gradient
    """
    n_groups, n_points = data_points.shape[0], data_points.shape[1]
    predictions = np.empty((n_groups, n_points))
    var_s = np.empty((n_groups, n_points))
    fano_factors = np.empty((n_groups, n_points))

    # Fill every row allocated above. A hard-coded [0, 1] would leave extra rows
    # uninitialised (np.empty) and fail on inputs with fewer than two rows.
    for ii in range(n_groups):
        for jj in range(n_points):
            spike_trains = data_points[ii, jj]  # (N by T) spike train matrix
            # generate_psth is defined elsewhere in this file; only its second
            # return value (the PSTH) is used here.
            _, psth, _, _ = generate_psth(spike_trains, bin_size=20, bin_size_2=50)

            # TODO: restrict the PSTH to a valid region before differentiating.

            # Variance and Fano factor of the PSTH gradient.
            grad_psth = np.gradient(psth)
            var = np.var(grad_psth)
            fano_factor = var / np.mean(grad_psth)

            var_s[ii, jj] = var
            fano_factors[ii, jj] = fano_factor

            # Gradient variance above threshold -> step model (0), else ramp (1).
            predictions[ii, jj] = 0 if var > thresholds else 1

    return predictions, var_s, fano_factors


       

# def normalised_var_classifier(data_points, m, r, sigma, beta, threshold):
#     """
#     Classify spike trains as being generated by the step model (return 0) or the ramp model (return 1).
#     :param data_points: 2*M data points, M data points for each model, where each data point is a (N by T) matrix
#     :param m: mean jump time (in # of time-steps) for StepModel
#     :param r: parameter r ("# of successes") of the Negative Binomial (NB) distribution of jump (stepping) time for StepModel
#     :param sigma: diffusion strength of the drift-diffusion process for RampModel
#     :param beta: drift rate of the drift-diffusion process for RampModel
#     :param threshold: threshold for variance
#     :return: M predictions, each being 0 (step model) or 1 (ramp model)
#     """
#     predictions = np.empty((data_points.shape[0], data_points.shape[1])) # 2 x M
#     var_s = np.empty((data_points.shape[0], data_points.shape[1])) # 2 x M
#     fano_factors = np.empty((data_points.shape[0], data_points.shape[1])) # 2 x M
    
#     for ii in [0,1]:
#         # ii = 0 -> STEP spike trains
#         # ii = 1 -> RAMP spike trains
#         for jj in range(data_points[ii].shape[0]): 
#             spike_trains = data_points[ii, jj]; # (N by T) spike train matrix
#             # Calculate the PSTH
#             psth,_,_ = generate_psth(spike_trains, bin_size=20, bin_size_2=50)

#             #%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
#             # Scale the PSTH 
#     #         psth_step_scaled = psth_step * 2 * m / len(psth_step)
#     #         psth_ramp_scaled = psth_ramp * 2 * m / len(psth_ramp)

#             # Find the gradient of the PSTH
#             grad_psth = np.gradient(psth)
            
#             # Normalize the gradient so that the area under it is equal to 1
#             grad_psth_normalized = grad_psth / np.sum(grad_psth)

#             # Find the variance and the Fano factor of the gradient
#             var = np.var(grad_psth_normalized)
#             fano_factor = var / np.mean(grad_psth_normalized)
            
            
#             # Print the variance and the Fano factor
#             # print(f"variance = {var}, Fano factor = {fano_factor}")
#             var_s[ii,jj] = var
#             fano_factors[ii,jj] = fano_factor
#             #%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
#             # Classify the spike trains based on the variance 
# #             if var > threshold:
# #                 predictions[ii,].append(0)  # step model
# #             else:
# #                 predictions[ii].append(1)  # ramp model

# #             if var_ramp > threshold:
# #                 predictions.append(0)  # step model
# #             else:
# #                 predictions.append(1)  # ramp model

#     return predictions, var_s, fano_factors
