# Comparing methods to simulate a radial Ornstein-Uhlenbeck bridge
This is a special case of an interest rate model proposed by Aït-Sahalia and Lo (Journal of Finance, 1998).

In [1]:
import os

import matplotlib.pyplot as plt
import torch
from scipy.special import iv, ive, ivp
from torch.distributions.gamma import Gamma

import DiffusionBridge as db
from DiffusionBridge.auxiliary import AuxiliaryDiffusion
plt.style.use("ggplot")

In [2]:
# specify problem settings: time horizon, discretisation and output location
interval = 1  # time horizon T of the bridge
M = 50  # passed to db.diffusion.model; presumably the number of time steps
path_dir, folder = ".", "results_radial"  # results are written to path_dir/folder

In [3]:
# problem settings
# Radial Ornstein-Uhlenbeck diffusion dX_t = (theta / X_t - X_t) dt + sigma dW_t
# with sigma = 1; the closed-form expressions below contain no sigma and assume it.
d = 1
T = torch.tensor(float(interval))
theta = torch.tensor(4.0)
f = lambda t,x: theta / x - x 
sigma = torch.tensor(1.0)
diffusion = db.diffusion.model(f, sigma, d, T, M)

# Numerically stable modified-Bessel helpers.
# scipy's iv(nu, z) overflows to inf for large z (already around z ~ 90 when the
# result is float32, ~700 in float64). The argument z = x * x0 / sinh(t) grows
# without bound as t -> 0, i.e. in grad_logh as t -> T. There, log(iv) would be
# inf and iv'/iv would be inf / inf = nan.
# The exponentially scaled ive(nu, z) = iv(nu, z) * exp(-z) (z >= 0) stays finite, and:
#   log iv(nu, z)          = log ive(nu, z) + z
#   iv'(nu, z) / iv(nu, z) = (ive(nu - 1, z) + ive(nu + 1, z)) / (2 ive(nu, z))
# The second identity follows from the recurrence 2 I_nu' = I_{nu-1} + I_{nu+1}.
def _log_bessel_iv(nu, z):
    """Return log I_nu(z) for z >= 0 without overflowing."""
    return torch.log(ive(nu, z)) + z


def _bessel_iv_log_derivative(nu, z):
    """Return I_nu'(z) / I_nu(z) for z >= 0 without overflowing."""
    return (ive(nu - 1.0, z) + ive(nu + 1.0, z)) / (2.0 * ive(nu, z))


# transition density
def log_transition_density(t, x, x0):
    """Closed-form log transition density log p(t, x | x0) of the radial OU process.

    t is the elapsed time (a tensor: torch.sinh / torch.expm1 are applied to it).
    x is the end state and x0 the start state; both must be positive (log(x / x0)).
    """
    z = x * x0 / torch.sinh(t)
    return (
        theta * torch.log(x / x0) + 0.5 * torch.log(x * x0) - x**2
        + (theta + 0.5) * t - torch.log(torch.sinh(t))
        - (x**2 + x0**2) / torch.expm1(2.0 * t)  # expm1: exact for small t
        + _log_bessel_iv(theta - 0.5, z)
    )


def score_transition(t, x, x0):
    """Gradient with respect to x of log p(t, x | x0)."""
    z = x * x0 / torch.sinh(t)
    return (
        theta / x + 1.0 / (2.0 * x) - 2.0 * x - 2.0 * x / torch.expm1(2.0 * t)
        + _bessel_iv_log_derivative(theta - 0.5, z) * x0 / torch.sinh(t)
    )


# h-transform
def grad_logh(t, x, xT):
    """Gradient with respect to x of log h(t, x) = log p(T - t, xT | x).

    This is the drift correction that conditions the process to reach xT at time T.
    """
    tau = T - t  # time remaining until the bridge end point
    z = xT * x / torch.sinh(tau)
    return (
        -theta / x + 1.0 / (2.0 * x) - 2.0 * x / torch.expm1(2.0 * tau)
        + _bessel_iv_log_derivative(theta - 0.5, z) * xT / torch.sinh(tau)
    )


# score marginal 
def score_marginal(t, x, x0, xT):
    """Score of the bridge marginal at time t: grad log p(t, x | x0) + grad log h(t, x)."""
    return score_transition(t, x, x0) + grad_logh(t, x, xT)


# sample size
N = 2**10

# repetitions
R = 100

In [4]:
# (initial, terminal) endpoint values of the three bridge configurations;
# entry c of each list belongs to configuration c
initial_condition, terminal_condition = (
    [1.5, 1.5, 3.0],
    [1.0, 2.5, 4.0],
)

In [None]:
# learn backward diffusion bridge process with score matching
# Initial states for training are drawn from Gamma(concentration=5, rate=2).
distribution_X0 = Gamma(torch.tensor(5.0), torch.tensor(2.0))


def simulate_X0(n):
    """Draw n initial states from distribution_X0, reshaped to (n, d)."""
    draws = distribution_X0.sample((n, ))
    return draws.reshape(n, d)


# NOTE(review): an empty list is passed as XT; presumably this means "no fixed
# terminal state" for learn_full_score_transition -- confirm against its API.
XT = []

# training hyperparameters (passed positionally below, in this order)
epsilon = 1.0
minibatch = 100
num_initial_per_batch = 10
num_iterations = 1000
learning_rate = 0.01
ema_momentum = 0.99

output = diffusion.learn_full_score_transition(
    simulate_X0,
    XT,
    epsilon,
    minibatch,
    num_initial_per_batch,
    num_iterations,
    learning_rate,
    ema_momentum,
)
score_transition_net, loss_values_transition = output["net"], output["loss"]

In [None]:
# simulate backward diffusion bridge (BDB) process with approximate score
# For every (initial, terminal) pair c and repetition r this cell:
#   - draws N bridge paths from the learned backward proposal;
#   - records the effective sample size (ESS) and the importance-sampling
#     estimate of the log transition density;
#   - records the acceptance rate of an independent Metropolis-Hastings (IMH)
#     sampler fed the same N paths.
BDB = {
    measure: torch.zeros(len(initial_condition), R)
    for measure in ["ess", "logestimate", "acceptrate"]
}

for c in range(len(initial_condition)):
    X0 = initial_condition[c] * torch.ones(d)
    XT = terminal_condition[c] * torch.ones(d)
    for r in range(R):
        with torch.no_grad():
            output = diffusion.simulate_bridge_backwards(
                score_transition_net, X0, XT, epsilon, N, full_score=True,
            )
            trajectories = output["trajectories"]
            log_proposal = output["logdensity"]
        log_target = diffusion.law_bridge(trajectories)
        log_weights = log_target - log_proposal

        # importance sampling: shift by the max log weight before exponentiating
        # to avoid overflow; log mean(w) = log mean(exp(lw - max)) + max
        max_log_weights = torch.max(log_weights)
        weights = torch.exp(log_weights - max_log_weights)
        norm_weights = weights / torch.sum(weights)
        ess = 1.0 / torch.sum(norm_weights**2)
        log_transition_estimate = torch.log(torch.mean(weights)) + max_log_weights
        BDB["ess"][c, r] = ess
        BDB["logestimate"][c, r] = log_transition_estimate

        # independent Metropolis-Hastings
        # The chain starts from one extra proposal draw. Like the batch above,
        # it is drawn under no_grad: no gradients are needed, and otherwise
        # autograd would record the score network's forward passes (if its
        # parameters require grad) for nothing.
        with torch.no_grad():
            initial = diffusion.simulate_bridge_backwards(
                score_transition_net, X0, XT, epsilon, 1, full_score=True,
            )
        current_trajectory = initial["trajectories"]
        current_log_proposal = initial["logdensity"]
        current_log_target = diffusion.law_bridge(current_trajectory)
        current_log_weight = current_log_target - current_log_proposal
        num_accept = 0
        for n in range(N):
            proposed_trajectory = trajectories[n, :, :]
            proposed_log_weight = log_weights[n]
            # IMH acceptance ratio = w(proposed) / w(current)
            log_accept_prob = proposed_log_weight - current_log_weight

            if torch.log(torch.rand(1)) < log_accept_prob:
                current_trajectory = proposed_trajectory.clone()
                current_log_weight = proposed_log_weight.clone()
                num_accept += 1
        accept_rate = num_accept / N
        BDB["acceptrate"][c, r] = accept_rate

        # print
        print(
            f"Initial: {initial_condition[c]}",
            f"Terminal: {terminal_condition[c]}",
            f"BDB repetition: {r}",
            f"ESS%: {float(ess * 100 / N):.2f}",
            f"log-transition: {float(log_transition_estimate):.2f}",
            f"Accept rate: {float(accept_rate):.4f}"
        )

In [None]:
# drift functions of existing methods
# Every drift has signature (t, x). The bridge drifts (MDB, CDH, GDB) read the
# module-level XT (and T) when called, not when defined. The simulation loop
# therefore rebinds XT to the current terminal condition before each simulation.
drifts = {}
modify = {}

# forward diffusion method of Pedersen (1995)
drifts["FD"] = f

# modified diffusion bridge (MDB) method of Durham and Gallant (2002)
drifts["MDB"] = lambda t, x: (XT - x) / (T - t)
modify["MDB"] = "variance"

# diffusion bridge proposal of Clark (1990) and Delyon and Hu (2006)
auxiliary_type = "bm"
initial_params = {"alpha": torch.zeros(d)}
bm_auxiliary = AuxiliaryDiffusion(
    diffusion, auxiliary_type, initial_params, requires_grad=False
)
drifts["CDH"] = lambda t, x: f(t, x) + diffusion.Sigma * bm_auxiliary.grad_logh(
    XT, t, x
)
modify["CDH"] = "time"

# guided diffusion bridge (GDB) proposal of Schauer, Van Der Meulen and Van Zanten (2017)
# The Ornstein-Uhlenbeck auxiliary uses fixed parameters (requires_grad=False):
# nothing is learned here. Its construction does not involve the endpoints; XT
# is supplied to grad_logh at call time. A single instance therefore serves all
# (initial, terminal) pairs.
auxiliary_type = "ou"
initial_params = {
    "alpha": 4.0 * torch.ones(d),
    "beta": 2.0 * torch.ones(d),
}
ou_auxiliary = AuxiliaryDiffusion(diffusion, auxiliary_type, initial_params, requires_grad=False)
drifts["GDB"] = lambda t, x: f(t, x) + diffusion.Sigma * ou_auxiliary.grad_logh(XT, t, x)
modify["GDB"] = "time"

In [None]:
# simulate existing methods
# Repeat the BDB experiment for each baseline proposal in `drifts`. For every
# (initial, terminal) pair this records the ESS, the importance-sampling
# log-transition estimate and the independent Metropolis-Hastings acceptance rate.
results = {"BDB": BDB}

# CDH and GDB weight their paths with their auxiliary process's
# log_radon_nikodym. Every other method uses target law minus proposal log-density.
_auxiliary_for_method = {"CDH": bm_auxiliary, "GDB": ou_auxiliary}


def _log_weights_for(method, sample):
    """Return the log importance weights of sample["trajectories"] drawn by `method`.

    `sample` is a dict returned by diffusion.simulate_proposal_bridge. Its
    "logdensity" entry is only read for methods without an auxiliary process.
    """
    trajectories = sample["trajectories"]
    auxiliary = _auxiliary_for_method.get(method)
    if auxiliary is not None:
        return auxiliary.log_radon_nikodym(trajectories)
    log_proposal = sample["logdensity"]
    log_target = diffusion.law_bridge(trajectories)
    return log_target - log_proposal


for method, drift in drifts.items():
    # measures to store
    result = {
        measure: torch.zeros(len(initial_condition), R)
        for measure in ["ess", "logestimate", "acceptrate"]
    }
    for c in range(len(initial_condition)):
        # X0 / XT stay module-level: the drift lambdas read the global XT
        X0 = initial_condition[c] * torch.ones(d)
        XT = terminal_condition[c] * torch.ones(d)

        # repetition
        for r in range(R):
            with torch.no_grad():
                output = diffusion.simulate_proposal_bridge(
                    drift, X0, XT, N, modify.get(method)
                )
            trajectories = output["trajectories"]
            log_weights = _log_weights_for(method, output)

            # importance sampling: shift by the max log weight to avoid overflow
            max_log_weights = torch.max(log_weights)
            weights = torch.exp(log_weights - max_log_weights)
            norm_weights = weights / torch.sum(weights)
            ess = 1.0 / torch.sum(norm_weights**2)
            log_transition_estimate = torch.log(torch.mean(weights)) + max_log_weights
            result["ess"][c, r] = ess
            result["logestimate"][c, r] = log_transition_estimate

            # independent Metropolis-Hastings; the starting state is one extra
            # proposal draw, generated under no_grad like the batch above
            with torch.no_grad():
                initial = diffusion.simulate_proposal_bridge(
                    drift, X0, XT, 1, modify.get(method)
                )
            current_trajectory = initial["trajectories"]
            current_log_weight = _log_weights_for(method, initial)
            num_accept = 0
            for n in range(N):
                proposed_trajectory = trajectories[n, :, :]
                proposed_log_weight = log_weights[n]
                # IMH acceptance ratio = w(proposed) / w(current)
                log_accept_prob = proposed_log_weight - current_log_weight
                if torch.log(torch.rand(1)) < log_accept_prob:
                    current_trajectory = proposed_trajectory.clone()
                    current_log_weight = proposed_log_weight.clone()
                    num_accept += 1
            accept_rate = num_accept / N
            result["acceptrate"][c, r] = accept_rate

            # print
            print(
                f"Initial: {initial_condition[c]}",
                f"Terminal: {terminal_condition[c]}",
                f"{method} repetition: {r}",
                f"ESS%: {float(ess * 100 / N):.2f}",
                f"log-transition: {float(log_transition_estimate):.2f}",
                f"Accept rate: {float(accept_rate):.4f}",
            )

    # store result
    results[method] = result

In [None]:
# Summarise, per (initial, terminal) pair, the mean ESS, the RMSE of the
# log-transition estimates and the mean IMH acceptance rate of every method.
for c in range(len(initial_condition)):
    X0 = initial_condition[c] * torch.ones(d)
    XT = terminal_condition[c] * torch.ones(d)
    # exact log p(T, XT | X0): the reference for the RMSE comparison below.
    # It does not depend on the method, so evaluate it once per condition.
    log_transition_exact = log_transition_density(T, XT, X0)

    print("-" * 30)
    print(
        f"Initial: {initial_condition[c]}",
        f"Terminal: {terminal_condition[c]}",
    )
    print("-" * 30)
    for method, result in results.items():
        # compare ESS (mean over repetitions, as a percentage of N)
        print(
            f"{method}",
            f"ESS%: {float(torch.mean(result['ess'][c, :]) * 100 / N):.2f}",
        )
    print("-" * 30)

    # compare RMSE of log-transition density estimates against the exact value
    for method, result in results.items():
        RMSE = float(
            torch.sqrt(torch.mean((result["logestimate"][c, :] - log_transition_exact) ** 2))
        )
        print(
            f"{method}",
            f"log-transition RMSE: {RMSE:.4f}",
        )
    print("-" * 30)

    # compare independent Metropolis-Hastings acceptance rate
    for method, result in results.items():
        print(
            f"{method}",
            f"Accept rate%: {float(torch.mean(result['acceptrate'][c, :]) * 100):.2f}",
        )
    print("-" * 30)

In [None]:
# save results
# torch.save does not create missing directories, so make sure path_dir/folder
# exists first; exist_ok keeps re-runs from failing.
file_name = f"{path_dir}/{folder}/radial_T{interval}.pt"
os.makedirs(os.path.dirname(file_name), exist_ok=True)
torch.save(results, file_name)