In [15]:
import math
import time
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import norm
from mpmath import mp, jtheta

mp.dps = 5; mp.pretty = True
np_jtheta = np.frompyfunc(jtheta, nin=3, nout=1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [16]:
a, b = -2, 2

def conditional_probability_to_survive(t, x, y, k_arr=range(-5, 5)):
    ans = 0
    for k in k_arr:
        ans += (
                torch.exp(((y - x) ** 2 - (y - x + 2 * k * (b - a)) ** 2) / (2 * t))
                - torch.exp(((y - x) ** 2 - (y + x - 2 * a + 2 * k * (b - a)) ** 2) / (2 * t))
        )
    return ans

def conditional_probability_to_survive_2(t, x, y):
    x, y = x.cpu().detach().numpy(), y.cpu().detach().numpy()
    ans_in_np = (
        np.sqrt(2*math.pi*t)
        * np.exp((y - x)**2 / (2*t))
        * (
            np_jtheta(3, math.pi * (x-y) / (2*(a-b)), np.exp(-math.pi**2*t / (2*(b-a)**2)))
            - np_jtheta(3, math.pi * (y+x-2*a)/(2*(b-a)), np.exp(-math.pi**2*t / (2*(b-a)**2)))
        )
        /(2*(b-a))
    )
    return torch.tensor(np.array(ans_in_np, dtype=float), dtype=torch.get_default_dtype(), device=device)

In [20]:
sample_size = 100000
t = 0.3
lo, hi = a, b

x = lo + (hi - lo) * torch.rand(sample_size, device=device)
y = lo + (hi - lo) * torch.rand(sample_size, device=device)
# approximation
start = time.time()
approx = conditional_probability_to_survive(t, x, y)
approx_time = time.time() - start
# exact
start = time.time()
exact = conditional_probability_to_survive_2(t, x, y)
exact_time = time.time() - start
# print results
diff = (approx - exact).abs().max().item()
idx = (approx - exact).abs().argmax().item()

print(f"Approx takes {approx_time} seconds; Exact takes {exact_time} seconds.")
print(f"Max diff is {diff} at index {idx}.")
print(f"Approx[{idx}] = {approx[idx]}; Exact[{idx}] = {exact[idx]}.")
print(f"x[{idx}] = {x[idx]}; y[{idx}] = {y[idx]}.")

Approx takes 0.0013048648834228516 seconds; Exact takes 7.748569011688232 seconds.
Max diff is 0.42206886410713196 at index 55802.
Approx[55802] = 0.08923652768135071; Exact[55802] = -0.33283233642578125.
x[55802] = 1.9861419200897217; y[55802] = -1.986928939819336.
