# Comparing methods to simulate Ornstein-Uhlenbeck bridge

In [1]:
import DiffusionBridge as db
import torch
import matplotlib.pyplot as plt
from DiffusionBridge.utils import normal_logpdf
from DiffusionBridge.auxiliary import AuxiliaryDiffusion
plt.style.use('ggplot')

In [2]:
# problem settings read by the cells below

# model: state dimension and length of the time interval [0, interval]
d, interval = 1, 1
# third size argument handed to db.diffusion.model
# (presumably the number of time-discretization steps -- TODO confirm)
M = 50

# score-matching training: iteration count and the "pos_dim" network setting
num_iterations, pos_dim = 500, 16

# when set, the terminal state is placed this many standard deviations away
# from the mean of the transition density at the terminal time
terminal_std = None

# results are written under f"{path_dir}/{folder}"
path_dir, folder = ".", "results_ou"

In [None]:
# time interval
T = torch.tensor(float(interval))

# diffusion model with Ornstein-Uhlenbeck drift f(t, x) = alpha - beta * x
alpha = torch.tensor(0.0)
beta = torch.tensor(2.0)
f = lambda t, x: alpha - beta * x
sigma = torch.tensor(1.0)
diffusion = db.diffusion.model(f, sigma, d, T, M)

# initial and terminal constraints
X0 = 1.0 * torch.ones(d)
XT = 1.0 * torch.ones(d)

# transition density: Gaussian with the mean and variance given below
# NOTE(review): transition_var carries no sigma factor, so these closed forms
# presumably match the model only for sigma = 1 -- confirm before changing sigma
ratio = alpha / beta  # level that transition_mean converges to as t grows
transition_mean = lambda t, x: ratio + (x - ratio) * torch.exp(-beta * t)
transition_var = lambda t: (1.0 - torch.exp(-2.0 * beta * t)) / (2.0 * beta)
# score (gradient in x of the log-density) of the transition from X0 over time t
score_transition = lambda t, x: (transition_mean(t, X0) - x) / transition_var(t)

# terminal constraint: optionally move XT to terminal_std standard deviations
# away from the mean of the transition density at time T
if terminal_std:
    XT = (
        float(terminal_std) * torch.sqrt(transition_var(T)) + transition_mean(T, X0)
    ) * torch.ones(d)
print(f"terminal state: {float(XT):.4f}")

# transition density from X0 to XT (reference value for the RMSE comparison)
log_transition_density = normal_logpdf(
    XT.reshape(1, d), transition_mean(T, X0), transition_var(T)
)
print(f"log-transition: {float(log_transition_density):.4f}")

# marginal density of the bridge at time t: proportional in x to the product of
# the transition density from X0 over time t and the transition density from x
# to XT over time T - t; both are Gaussian in x, so precisions add
marginal_var = lambda t: 1.0 / (
    1.0 / transition_var(t)
    + torch.exp(-2.0 * beta * (T - t)) / transition_var(T - t)
)
# the backward factor must be centred at XT minus the part of
# transition_mean(T - t, x) that does not depend on x, which is
# ratio * (1 - exp(-beta * (T - t))); it vanishes for alpha = 0, and keeping it
# makes score_marginal(t, x) == score_transition(t, x) + grad_logh(t, x)
# for any alpha
marginal_mean = lambda t: (
    transition_mean(t, X0) / transition_var(t)
    + (XT - ratio * (1.0 - torch.exp(-beta * (T - t))))
    * torch.exp(-beta * (T - t))
    / transition_var(T - t)
) * marginal_var(t)
score_marginal = lambda t, x: (marginal_mean(t) - x) / marginal_var(t)
# gradient in x of the log transition density from x at time t to XT at time T
grad_logh = (
    lambda t, x: (XT - transition_mean(T - t, x))
    * torch.exp(-beta * (T - t))
    / transition_var(T - t)
)

# sample size
N = 2**10

# repetitions
R = 100

In [None]:
# learn backward diffusion bridge process with score matching
epsilon = 1.0
minibatch = 100
learning_rate = 0.01
ema_momentum = 0.99
network_config = {"pos_dim": pos_dim}
output = diffusion.learn_score_transition(
    X0,
    XT,
    epsilon,
    minibatch,
    num_iterations,
    learning_rate,
    ema_momentum,
    network_config,
)
# the "net" entry of the training output is used as the approximate score below
score_transition_net = output["net"]

# simulate backward diffusion bridge (BDB) process with approximate score
# one slot per repetition for each performance measure
BDB = {
    measure: torch.zeros(R) for measure in ["ess", "logestimate", "acceptrate"]
}
for r in range(R):
    # N proposal trajectories and their log-densities under the proposal
    with torch.no_grad():
        output = diffusion.simulate_bridge_backwards(
            score_transition_net, X0, XT, epsilon, N
        )
        trajectories = output["trajectories"]
        log_proposal = output["logdensity"]
    log_target = diffusion.law_bridge(trajectories)
    log_weights = log_target - log_proposal

    # importance sampling
    # subtract the largest log-weight before exponentiating to avoid overflow
    max_log_weights = torch.max(log_weights)
    weights = torch.exp(log_weights - max_log_weights)
    norm_weights = weights / torch.sum(weights)
    # effective sample size: reciprocal of the sum of squared normalized weights
    ess = 1.0 / torch.sum(norm_weights**2)
    # log of the average unnormalized weight, with the shift added back
    log_transition_estimate = torch.log(torch.mean(weights)) + max_log_weights
    BDB["ess"][r] = ess
    BDB["logestimate"][r] = log_transition_estimate

    # independent Metropolis-Hastings
    # the chain starts from one extra draw; no gradients are needed here, so
    # skip autograd exactly as for the proposal draws above
    with torch.no_grad():
        initial = diffusion.simulate_bridge_backwards(
            score_transition_net, X0, XT, epsilon, 1
        )
    current_trajectory = initial["trajectories"]
    current_log_proposal = initial["logdensity"]
    current_log_target = diffusion.law_bridge(current_trajectory)
    current_log_weight = current_log_target - current_log_proposal
    num_accept = 0
    # the N importance samples are reused as the proposals of the chain
    for n in range(N):
        proposed_trajectory = trajectories[n, :, :]
        proposed_log_weight = log_weights[n]
        # log acceptance ratio is the difference of the log importance weights
        log_accept_prob = proposed_log_weight - current_log_weight

        if (torch.log(torch.rand(1)) < log_accept_prob):
            current_trajectory = proposed_trajectory.clone()
            current_log_weight = proposed_log_weight.clone()
            num_accept += 1
    accept_rate = num_accept / N
    BDB["acceptrate"][r] = accept_rate

    # print
    print(
        f"BDB repetition: {r}",
        f"ESS%: {float(ess * 100 / N):.2f}",
        f"log-transition: {float(log_transition_estimate):.2f}",
        f"Accept rate: {float(accept_rate):.4f}"
    )

In [None]:
# learn forward diffusion bridge process with score matching
epsilon = 1.0
minibatch = 100
learning_rate = 0.01
ema_momentum = 0.99
network_config = {"pos_dim": pos_dim}
# training takes the previously learned transition-score network as input
output = diffusion.learn_score_marginal(
    score_transition_net,
    X0,
    XT,
    epsilon,
    minibatch,
    num_iterations,
    learning_rate,
    ema_momentum,
    network_config,
)
# the "net" entry of the training output is used as the approximate score below
score_marginal_net = output["net"]

# simulate forward diffusion bridge (FDB) process using approximate score
# one slot per repetition for each performance measure
FDB = {
    measure: torch.zeros(R) for measure in ["ess", "logestimate", "acceptrate"]
}
for r in range(R):
    # N proposal trajectories and their log-densities under the proposal
    with torch.no_grad():
        output = diffusion.simulate_bridge_forwards(
            score_transition_net, score_marginal_net, X0, XT, epsilon, N
        )
        trajectories = output["trajectories"]
        log_proposal = output["logdensity"]
    log_target = diffusion.law_bridge(trajectories)
    log_weights = log_target - log_proposal

    # importance sampling
    # subtract the largest log-weight before exponentiating to avoid overflow
    max_log_weights = torch.max(log_weights)
    weights = torch.exp(log_weights - max_log_weights)
    norm_weights = weights / torch.sum(weights)
    # effective sample size: reciprocal of the sum of squared normalized weights
    ess = 1.0 / torch.sum(norm_weights**2)
    # log of the average unnormalized weight, with the shift added back
    log_transition_estimate = torch.log(torch.mean(weights)) + max_log_weights
    FDB["ess"][r] = ess
    FDB["logestimate"][r] = log_transition_estimate

    # independent Metropolis-Hastings
    # the chain starts from one extra draw; no gradients are needed here, so
    # skip autograd exactly as for the proposal draws above
    with torch.no_grad():
        initial = diffusion.simulate_bridge_forwards(
            score_transition_net, score_marginal_net, X0, XT, epsilon, 1
        )
    current_trajectory = initial["trajectories"]
    current_log_proposal = initial["logdensity"]
    current_log_target = diffusion.law_bridge(current_trajectory)
    current_log_weight = current_log_target - current_log_proposal
    num_accept = 0
    # the N importance samples are reused as the proposals of the chain
    for n in range(N):
        proposed_trajectory = trajectories[n, :, :]
        proposed_log_weight = log_weights[n]
        # log acceptance ratio is the difference of the log importance weights
        log_accept_prob = proposed_log_weight - current_log_weight

        if (torch.log(torch.rand(1)) < log_accept_prob):
            current_trajectory = proposed_trajectory.clone()
            current_log_weight = proposed_log_weight.clone()
            num_accept += 1
    accept_rate = num_accept / N
    FDB["acceptrate"][r] = accept_rate

    # print
    print(
        f"FDB repetition: {r}",
        f"ESS%: {float(ess * 100 / N):.2f}",
        f"log-transition: {float(log_transition_estimate):.2f}",
        f"Accept rate: {float(accept_rate):.4f}",
    )

In [None]:
# proposal drifts of existing methods, keyed by method name, together with
# the optional modification flag handed to the proposal simulator

# Brownian-motion auxiliary process, used by the proposal of Clark (1990)
# and Delyon and Hu (2006)
auxiliary_type = "bm"
initial_params = {"alpha": torch.zeros(d)}
bm_auxiliary = AuxiliaryDiffusion(
    diffusion, auxiliary_type, initial_params, requires_grad=False
)

# Ornstein-Uhlenbeck auxiliary process with the model's own alpha and beta,
# used by the guided proposal of Schauer, Van Der Meulen and Van Zanten
auxiliary_type = "ou"
initial_params = {"alpha": alpha * torch.ones(d), "beta": beta * torch.ones(d)}
ou_auxiliary = AuxiliaryDiffusion(
    diffusion, auxiliary_type, initial_params, requires_grad=False
)


def _mdb_drift(t, x):
    # modified diffusion bridge (MDB) of Durham and Gallant (2002):
    # remaining displacement to XT divided by the remaining time
    return (XT - x) / (T - t)


def _cdh_drift(t, x):
    # model drift plus a guiding term built from the Brownian-motion auxiliary
    return f(t, x) + diffusion.Sigma * bm_auxiliary.grad_logh(XT, t, x)


def _gdb_drift(t, x):
    # model drift plus a guiding term built from the Ornstein-Uhlenbeck auxiliary
    return f(t, x) + diffusion.Sigma * ou_auxiliary.grad_logh(XT, t, x)


# "FD" is the forward diffusion method of Pedersen (1995): the unmodified drift
drifts = {
    "FD": f,
    "MDB": _mdb_drift,
    "CDH": _cdh_drift,
    "GDB": _gdb_drift,
}

# methods without an entry here ("FD") are simulated without a modification
modify = {
    "MDB": "variance",
    "CDH": "time",
    "GDB": "time",
}

In [None]:
# run every existing method through the same importance-sampling and
# independent Metropolis-Hastings benchmark as the learned bridges
results = {"BDB": BDB, "FDB": FDB}

# methods whose log-weights come from an auxiliary process rather than from
# the difference between target and proposal log-densities
auxiliaries = {"CDH": bm_auxiliary, "GDB": ou_auxiliary}

for method, drift in drifts.items():
    # one slot per repetition for each performance measure
    result = {
        measure: torch.zeros(R) for measure in ["ess", "logestimate", "acceptrate"]
    }
    auxiliary = auxiliaries.get(method)

    for r in range(R):
        # draw N proposal trajectories
        with torch.no_grad():
            output = diffusion.simulate_proposal_bridge(
                drift, X0, XT, N, modify.get(method)
            )
        trajectories = output["trajectories"]
        if auxiliary is not None:
            log_weights = auxiliary.log_radon_nikodym(trajectories)
        else:
            log_proposal = output["logdensity"]
            log_target = diffusion.law_bridge(trajectories)
            log_weights = log_target - log_proposal

        # importance sampling, shifting by the largest log-weight for stability
        max_log_weights = log_weights.max()
        weights = (log_weights - max_log_weights).exp()
        norm_weights = weights / weights.sum()
        ess = 1.0 / (norm_weights**2).sum()
        log_transition_estimate = weights.mean().log() + max_log_weights
        result["ess"][r] = ess
        result["logestimate"][r] = log_transition_estimate

        # independent Metropolis-Hastings, started from one extra draw
        initial = diffusion.simulate_proposal_bridge(
            drift, X0, XT, 1, modify.get(method)
        )
        current_trajectory = initial["trajectories"]
        if auxiliary is not None:
            current_log_weight = auxiliary.log_radon_nikodym(current_trajectory)
        else:
            current_log_proposal = initial["logdensity"]
            current_log_target = diffusion.law_bridge(current_trajectory)
            current_log_weight = current_log_target - current_log_proposal

        # the N importance samples are reused as the proposals of the chain
        num_accept = 0
        for n in range(N):
            proposed_trajectory = trajectories[n, :, :]
            proposed_log_weight = log_weights[n]
            log_accept_prob = proposed_log_weight - current_log_weight
            log_uniform = torch.log(torch.rand(1))
            if log_uniform < log_accept_prob:
                current_trajectory = proposed_trajectory.clone()
                current_log_weight = proposed_log_weight.clone()
                num_accept += 1
        accept_rate = num_accept / N
        result["acceptrate"][r] = accept_rate

        # progress report for this repetition
        print(
            f"{method} repetition: {r}",
            f"ESS%: {float(ess * 100 / N):.2f}",
            f"log-transition: {float(log_transition_estimate):.2f}",
            f"Accept rate: {float(accept_rate):.4f}",
        )

    # keep the measures of this method next to those of BDB and FDB
    results[method] = result

In [None]:
# average effective sample size per method, as a percentage of N
for method, result in results.items():
    mean_ess_percent = float(result["ess"].mean() * 100 / N)
    print(f"{method}", f"ESS%: {mean_ess_percent:.2f}")
print("-" * 30)

# root mean squared error of the log-transition density estimates, measured
# against the analytic value log_transition_density
for method, result in results.items():
    squared_errors = (result["logestimate"] - log_transition_density) ** 2
    RMSE = float(squared_errors.mean().sqrt())
    print(f"{method}", f"{RMSE:.4f}")
print("-" * 30)

# average independent Metropolis-Hastings acceptance rate, as a percentage
for method, result in results.items():
    mean_accept_percent = float(result["acceptrate"].mean() * 100)
    print(f"{method}", f"Accept rate%: {mean_accept_percent:.2f}")
print("-" * 30)

In [None]:
# save results
import os

# file name encodes the dimension, the time interval and, when set, the
# terminal standard deviation
file_name = f"{path_dir}/{folder}/ou_dim{d}_T{interval}"
if terminal_std:
    file_name += f"_std{terminal_std}"

# torch.save does not create missing parent directories, so make sure the
# output folder exists instead of losing the results at the very last step
os.makedirs(os.path.dirname(file_name), exist_ok=True)
torch.save(results, file_name + ".pt")