In [2]:
from typing import Dict, List, Union

import numpy as np
import sdeint
import torch

from sbi.utils.torchutils import atleast_2d_float32_tensor

In [4]:
# Parameters of one simulated trial. `run` multiplies "starting_point" by
# "boundary_separation" to obtain the absolute start of the decision variable.
input_dict = {
    "starting_point": 0.5,
    "boundary_separation": 2.0,
    "drift": 0.5,
}

In [6]:
def run(
    input_dict: Dict[str, float],
    tmax: float = 4.0,
    n_steps: int = 4000,
) -> torch.Tensor:
    """Simulate one drift-diffusion trial and return its reaction time and choice.

    The decision variable starts at ``starting_point * boundary_separation`` and
    follows dx = drift * dt + 1 * dW (integrated with sdeint's ``itoSRI2``) on
    the time grid ``np.linspace(0.0, tmax, n_steps)``. The trial ends at the
    first grid point where the trajectory is strictly below 0 or strictly above
    ``boundary_separation``.

    Args:
        input_dict: Mapping with keys ``"starting_point"`` (relative to the
            boundary separation), ``"boundary_separation"`` and ``"drift"``.
        tmax: End of the simulated time window; also the upper limit of rt.
        n_steps: Number of time points passed to ``np.linspace``. The solver
            therefore takes ``n_steps - 1`` steps of size ``tmax / (n_steps - 1)``.

    Returns:
        ``torch.cat`` of the rt and choice tensors along dim 1, i.e.
        ``[[rt, choice]]`` as float32 (presumably shape (1, 2), assuming
        ``atleast_2d_float32_tensor`` maps a scalar to a (1, 1) tensor).
        ``choice`` is 1 for the upper boundary and 0 for the lower one. Both
        entries are NaN if no boundary is crossed before ``tmax``.
    """
    # diffusion coefficient is fixed to 1
    diffusion = 1
    tspan = np.linspace(0.0, tmax, n_steps)

    lower_bound = 0
    upper_bound = input_dict["boundary_separation"]
    # absolute starting point = relative starting point * boundary separation
    starting_point = input_dict["starting_point"] * upper_bound

    def f(x, t):
        # deterministic part of the SDE: constant drift
        return input_dict["drift"]

    def g(x, t):
        # stochastic part of the SDE: constant diffusion
        return diffusion

    # stochastic Runge-Kutta solver (Ito SRI2) for the decision-variable trajectory
    traj = sdeint.itoSRI2(f, g, starting_point, tspan).flatten()

    # True wherever the trajectory lies outside [lower_bound, upper_bound];
    # the first True entry is the first boundary crossing.
    crossed = (traj < lower_bound) | (traj > upper_bound)
    if crossed.any():
        first = int(np.argmax(crossed))  # argmax returns the first True index
        rt = tspan[first]
        # If both tests hold at the same index (only possible when
        # boundary_separation < 0), report the upper boundary. This is the
        # same tie-break as comparing the first lower/upper crossing indices.
        choice = 1 if traj[first] > upper_bound else 0
    else:
        # no boundary crossed within tmax
        rt, choice = torch.nan, torch.nan

    return torch.cat(
        (atleast_2d_float32_tensor(rt), atleast_2d_float32_tensor(choice)), dim=1
    )

In [8]:
# Wall-clock cost of simulating a single trial with the parameters above.
%timeit run(input_dict)

95.9 ms ± 211 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [17]:
# Back-of-the-envelope compute budget: the measured ~96 ms per simulation is
# rounded up to 100 ms.
ms_per_sim = 100  # ms
s_per_sim = ms_per_sim / 1000
hour_per_sim = s_per_sim / 60 / 60

# Hours needed for 1e5 and 1e11 simulations if run one after another.
for n_sims in (10**5, 10**11):
    print(hour_per_sim * n_sims)

2.777777777777778
2777777.777777778
