In [1]:
from __future__ import annotations
from typing import Callable, List, Tuple

import aesara
import aesara.tensor as at
import arviz as az
import numpy as np
import pymc as pm
import pandas as pd

from aesara.tensor.random.op import RandomVariable
from pymc.distributions.continuous import PositiveContinuous

## Disambiguation and naming conventions

The inputs to each function are easy to confuse. The WFPT distribution itself has positive support, but the input data may be negative: responses at the lower boundary are encoded as negative RTs. To keep the code clear about what each input contains, we fix the variable names below and the range of values each can take.

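* `data`: signed RTs in (-inf, inf), excluding 0, with non-decision time NOT removed. Negative values are lower-bound responses.
* `rt`: RTs in (0, inf) with non-decision time removed, i.e. `abs(data) - t`.
* `tt`: normalized RTs, i.e. `rt / a**2`, in (0, inf).

* `w`: normalized decision starting point in (0, 1). It equals `z` for lower-bound responses and `1 - z` for upper-bound responses.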

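## Lessons learned

* Partitioning the data into fast-expansion and slow-expansion subsets is the fastest way to compute the PDF, but it is not differentiable.
* Selecting between the two expansions with `at.where()` / `at.switch()` is slightly faster than blending them with boolean arithmetic (`mask * p_fast + (1 - mask) * p_slow`).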

In [2]:
def decision_func() -> Callable:
    """Build the function that chooses, per RT, which series expansion to use.

    The returned function memoizes its most recent call. If it is called again
    with the same ``tt`` and ``err``, it returns the previously built graph
    instead of constructing a new one. NumPy arrays are compared by value;
    anything else (e.g. symbolic tensors) is compared by identity.

    Returns:
        A callable ``inner_func(tt, err=1e-7)``. It returns a boolean tensor
        that is ``True`` where the small-time ("fast") expansion should be used
        and ``False`` where the large-time ("slow") expansion should be used.
    """
    internal_tt = None
    internal_err = None
    internal_result = None

    def _is_cached(tt, err: float) -> bool:
        """Return True if ``(tt, err)`` matches the memoized call."""
        if internal_result is None or err != internal_err:
            return False
        if isinstance(tt, np.ndarray) and isinstance(internal_tt, np.ndarray):
            # np.array_equal returns False on a shape mismatch instead of
            # broadcasting or raising like `==` does.
            return np.array_equal(tt, internal_tt)
        # Symbolic tensors cannot be compared by value; fall back to identity.
        return tt is internal_tt

    def inner_func(tt: np.ndarray, err: float = 1e-7) -> np.ndarray:
        """Choose, for each normalized RT, the expansion that needs fewer terms.

        Args:
            tt: Normalized RTs, ``rt / a**2``. (0, inf).
            err: Error bound.

        Returns:
            A boolean tensor that is ``True`` where the small-time (fast)
            expansion needs fewer terms than the large-time (slow) expansion.
        """
        nonlocal internal_tt, internal_err, internal_result

        if _is_cached(tt, err):
            return internal_result

        # Number of terms needed for the small-time expansion.
        ks = 2 + at.sqrt(-2 * tt * at.log(2 * at.sqrt(2 * np.pi * tt) * err))
        ks = at.maximum(ks, at.sqrt(tt) + 1)
        ks = at.switch(2 * at.sqrt(2 * np.pi * tt) * err < 1, ks, 2)

        # Number of terms needed for the large-time expansion.
        kl = at.sqrt(-2 * at.log(np.pi * tt * err) / (np.pi**2 * tt))
        kl = at.maximum(kl, 1.0 / (np.pi * at.sqrt(tt)))
        kl = at.switch(np.pi * tt * err < 1, kl, 1.0 / (np.pi * at.sqrt(tt)))

        use_small_time = ks < kl

        # Update the cache only after the result was built successfully, so a
        # failed call can never pair a new key with a stale result. Arrays are
        # copied so in-place mutation by the caller cannot make the cache stale.
        internal_tt = tt.copy() if isinstance(tt, np.ndarray) else tt
        internal_err = err
        internal_result = use_small_time

        return use_small_time

    return inner_func


decision = decision_func()

In [3]:
def ftt01w_fast(tt_fast: np.ndarray, w: float, k_terms: int):
    """Small-time ("fast") series for the lower-bound first-passage-time
    density with zero drift and unit boundary separation, f(tt | 0, 1, w).

    Args:
        tt_fast: Normalized RTs. (0, inf).
        w: Normalized decision starting point. (0, 1).
        k_terms: Number of series terms, centred on k = 0.

    Returns:
        The approximated density, one value per RT (assuming a 1-D ``tt_fast``).
    """
    half_span = (k_terms - 1) / 2
    offsets = at.arange(-at.floor(half_span), at.ceil(half_span) + 1)

    # Rows index the series term, columns index the RT.
    shifted = w + 2 * offsets.reshape((-1, 1))
    exponents = -at.power(shifted, 2) / 2 / tt_fast

    # Log-sum-exp style rescaling: subtract the per-RT maximum exponent
    # before exponentiating to keep the terms in a representable range.
    shift = at.max(exponents, axis=0)
    series = at.sum(shifted * at.exp(exponents - shift), axis=0)
    density = at.exp(shift + at.log(series))

    return density / at.sqrt(2 * np.pi * at.power(tt_fast, 3))


def ftt01w_slow(tt_slow: np.ndarray, w: float, k_terms: int) -> np.ndarray:
    """Large-time ("slow") series for the lower-bound first-passage-time
    density with zero drift and unit boundary separation, f(tt | 0, 1, w).

    Args:
        tt_slow: Normalized RTs. (0, inf).
        w: Normalized decision starting point. (0, 1).
        k_terms: Number of series terms, k = 1 .. k_terms.

    Returns:
        The approximated density, one value per RT (assuming a 1-D ``tt_slow``).
    """
    # Column vector of term indices, broadcast against the RTs.
    term_idx = at.arange(1, k_terms + 1).reshape((-1, 1))

    weights = term_idx * at.sin(term_idx * np.pi * w)
    decay = at.exp(-at.power(term_idx, 2) * at.power(np.pi, 2) * tt_slow / 2)

    return at.sum(weights * decay, axis=0) * np.pi


def ftt01w(
    rt: np.ndarray,
    a: float,
    w: float,
    err: float = 1e-7,
    k_terms: int = 10,
) -> np.ndarray:
    """Compute f(tt | 0, 1, w), the lower-bound first-passage-time density with
    zero drift and unit boundary separation, using the method and
    implementation of Navarro & Fuss, 2009.

    RTs are normalized as ``tt = rt / a**2`` before evaluation. The ``1 / a**2``
    scaling factor of the full density is NOT applied here.

    Args:
        rt: RTs with non-decision time removed. (0, inf).
        a: Value of decision upper bound. (0, inf).
        w: Normalized decision starting point. (0, 1).
        err: Error bound used to choose between the two expansions.
        k_terms: Number of series terms used by each expansion.

    Returns:
        The density evaluated at each normalized RT.
    """
    tt = rt / a**2

    # The term-count criterion that picks the expansion is defined on the
    # normalized time, so the decision must be made on `tt`, not on raw `rt`.
    use_fast = decision(tt, err)

    p_fast = ftt01w_fast(tt, w, k_terms)
    p_slow = ftt01w_slow(tt, w, k_terms)

    return at.switch(use_fast, p_fast, p_slow)


def log_pdf_sv(
    data: np.ndarray,
    v: float,
    sv: float,
    a: float,
    z: float,
    t: float,
    err: float = 1e-7,
    k_terms: int = 10,
) -> np.ndarray:
    """Compute the summed log-likelihood of the drift diffusion model with
    inter-trial drift variability, f(t | v, sv, a, z, t), using the method and
    implementation of Navarro & Fuss, 2009.

    Args:
        data: Signed RTs. (-inf, inf) except 0. Negative values correspond to
            the lower bound, positive values to the upper bound.
        v: Mean drift rate. (-inf, inf).
        sv: Standard deviation of the drift rate. [0, inf).
        a: Value of decision upper bound. (0, inf).
        z: Normalized decision starting point. (0, 1).
        t: Non-decision time. [0, inf).
        err: Error bound.
        k_terms: Number of series terms used by each expansion.

    Returns:
        A scalar tensor holding the sum of the per-RT log-likelihoods. It is
        ``-inf`` if any RT is not strictly greater than ``t``.
    """
    # Upper-bound responses are mapped onto the lower bound by flipping the
    # sign of the drift and mirroring the starting point.
    flip = data > 0
    v = at.switch(flip, -v, v)
    z = at.switch(flip, 1 - z, z)

    rt = abs(data) - t  # absolute RTs with non-decision time removed
    valid = rt > 0

    # Substitute a harmless placeholder for non-positive RTs so that the density
    # and its gradient stay finite; those entries are masked to -inf below.
    rt_safe = at.switch(valid, rt, 1.0)

    p = ftt01w(rt_safe, a, z, err, k_terms)

    logp = (
        at.log(p)
        + ((a * z * sv) ** 2 - 2 * a * v * z - (v**2) * rt_safe)
        / (2 * (sv**2) * rt_safe + 2)
        - at.log(sv**2 * rt_safe + 1) / 2
        - 2 * at.log(a)
    )

    return at.sum(at.switch(valid, logp, -np.inf))

In [4]:
# Later cells pass the `rt` column to `log_pdf_sv` as `data`, so it is
# presumably signed RTs (negative = lower bound) — confirm against the file.
cavanaugh_data = pd.read_csv(filepath_or_buffer="cavanagh_theta_nn.txt")

In [5]:
%%time
# Fixed parameter values for a single log-likelihood evaluation.
v = 1
sv = 0
a = 0.8
z = 0.5
t = 0.0

# Pass the configured `t` rather than repeating a literal, so editing `t`
# above actually changes the evaluation.
log_pdf_sv(
    cavanaugh_data.rt.values, v=v, sv=sv, a=a, z=z, t=t, err=1e-7, k_terms=15
).eval()

CPU times: user 258 ms, sys: 40.8 ms, total: 299 ms
Wall time: 298 ms


array(-37978.16611222)

In [7]:
%%timeit
log_pdf_sv(cavanaugh_data.rt.values, v=v, sv=sv, a=a, z=z, t=t, err=1e-7).eval()

118 ms ± 1.98 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [8]:
class WFPTRandomVariable(RandomVariable):
    """Aesara random variable op for the Wiener first-passage time (WFPT).

    Forward sampling is not implemented; the op only lets ``WFPT`` be used as a
    PyMC distribution for log-likelihood evaluation.
    """

    name: str = "WFPT_RV"
    ndim_supp: int = 0
    # One scalar entry per distribution parameter passed by `WFPT.dist`:
    # v, sv, a, z, t.
    ndims_params: List[int] = [0] * 5
    dtype: str = "floatX"
    _print_name: Tuple[str, str] = ("WFPT", "WFPT")

    @classmethod
    def rng_fn(cls, rng: np.random.RandomState, *args, **kwargs) -> np.ndarray:
        """Draw WFPT samples (not implemented).

        Accepts any parameter list (the distribution parameters followed by
        ``size``) so that callers get a clear error instead of an arity
        mismatch.

        Raises:
            NotImplementedError: Always; forward sampling is unsupported.
        """
        raise NotImplementedError("Not Implemented")

In [9]:
class WFPT(pm.Continuous):
    """Wiener first-passage time (WFPT) log-likelihood.

    Observations are signed RTs: negative values are lower-bound responses and
    positive values are upper-bound responses. The support therefore includes
    negative values, so this is not a positive-only distribution.
    """

    rv_op = WFPTRandomVariable()

    @classmethod
    def dist(cls, v, sv, a, z, t, **kwargs):
        """Create the WFPT distribution from its five parameters.

        Args:
            v: Mean drift rate.
            sv: Standard deviation of the drift rate.
            a: Value of decision upper bound.
            z: Normalized decision starting point.
            t: Non-decision time.
            **kwargs: Forwarded unchanged to the parent ``dist``.
        """
        params = [at.as_tensor_variable(pm.floatX(p)) for p in (v, sv, a, z, t)]
        return super().dist(params, **kwargs)

    def logp(data, v, sv, a, z, t, err=1e-7, k_terms=10, **kwargs):
        """Return the summed log-likelihood of ``data`` via ``log_pdf_sv``."""
        return log_pdf_sv(data, v, sv, a, z, t, err, k_terms)

In [11]:
# Estimate the drift rate `v`; all other WFPT parameters are held fixed.
with pm.Model():

    sv = 0
    a = 0.8
    z = 0.5
    t = 0.0

    v = pm.Normal(name="v", mu=0.0, sigma=1.0)
    WFPT(name="x", v=v, sv=sv, a=a, z=z, t=t, observed=cavanaugh_data.rt.values)
    # Fixed seed so the posterior draws are reproducible on Restart & Run All.
    results = pm.sample(1000, return_inferencedata=True, random_seed=42)

results

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [v]


Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 9 seconds.


Inference data with groups:
	> posterior
	> log_likelihood
	> sample_stats
	> observed_data


<a style='text-decoration:none;line-height:16px;display:flex;color:#5B5B62;padding:10px;justify-content:end;' href='https://deepnote.com?utm_source=created-in-deepnote-cell&projectId=309a84a8-4378-4e4a-ad40-54daa1f48cb7' target="_blank">
 </img>
Created in <span style='font-weight:600;margin-left:4px;'>Deepnote</span></a>

In [None]:
results.log_likelihood["x"]

In [None]:
# Trailing semicolon suppresses the text repr of the returned axes.
az.plot_posterior(results);