# Comparing methods to simulate diffusion bridges in a cell differentiation and development model

In [1]:
import DiffusionBridge as db
import torch
import matplotlib.pyplot as plt
from scipy.optimize import fsolve
from DiffusionBridge.auxiliary import AuxiliaryDiffusion
plt.style.use('ggplot')

In [2]:
# problem settings
interval = 2  # terminal time T of the bridge (converted to a tensor below)
sigma_level = "smaller"  # selects the diffusion coefficient: "smaller" or "larger"
M = 100  # passed to db.diffusion.model; presumably the number of time steps — confirm
num_iterations = 1000  # optimisation iterations for the learned proposals

# results are written under path_dir/folder
path_dir, folder = ".", "results_cell"

In [3]:
# state dimension: the drift couples exactly two coordinates
d = 2

# terminal time of the bridge as a scalar float tensor
T = torch.tensor(float(interval))

# drift parameters used by model_drift:
#   alpha - weight of the self-activating term x_i**p / (xi**p + x_i**p)
#   beta  - weight of the term xi**p / (xi**p + x_j**p), decreasing in the other coordinate
#   kappa - linear decay rate
#   p, xi - exponent and threshold shared by both terms
alpha, beta, kappa = torch.tensor(1.0), torch.tensor(1.0), torch.tensor(1.0)
p, xi = torch.tensor(4.0), torch.tensor(0.5)

# drift of the two-dimensional model, vectorised over states
def model_drift(x):
    """Evaluate the model drift at every state in ``x``.

    For coordinate i, with j the other coordinate, the drift is

        alpha * x_i**p / (xi**p + x_i**p)     (increasing in x_i)
        + beta * xi**p / (xi**p + x_j**p)     (decreasing in x_j)
        - kappa * x_i                         (linear decay)

    Args:
        x: tensor whose last dimension has size 2, e.g. a batch of shape
            (N, 2) or a single state of shape (2,).

    Returns:
        A new tensor with the same shape as ``x``. It is built out of place
        from ``x`` instead of being written into a default-dtype
        ``torch.zeros(x.shape)`` buffer. float64 input therefore keeps float64
        precision, and the result follows ``x``'s device.

    Raises:
        ValueError: if the last dimension of ``x`` is not 2.
    """
    if x.shape[-1] != 2:
        raise ValueError(
            f"model_drift expects states of size 2, got shape {tuple(x.shape)}"
        )
    x_p = x**p
    xi_p = xi**p
    # flip(-1) swaps the two coordinates, pairing each x_i with its partner x_j
    x_p_partner = x_p.flip(-1)
    return alpha * x_p / (xi_p + x_p) + beta * xi_p / (xi_p + x_p_partner) - kappa * x


# time-homogeneous drift in the (t, x) signature taken by the diffusion model
f = lambda t, x: model_drift(x)

# diffusion coefficient, selected by sigma_level
if sigma_level == "smaller":
    sigma = torch.sqrt(torch.tensor(1 * 1e-1))  # sigma**2 = 0.1
elif sigma_level == "larger":
    sigma = torch.tensor(1.0)
else:
    # fail here with a clear message instead of leaving `sigma` undefined and
    # hitting an unrelated NameError when the model is built
    raise ValueError(
        f"sigma_level must be 'smaller' or 'larger', got {sigma_level!r}"
    )

# Monte Carlo budget: N = 2**10 proposals per repetition, R = 100 repetitions
N = 1 << 10
R = 100

# build the diffusion model from the drift f, coefficient sigma, dimension d,
# terminal time T and M (presumably the number of time steps — confirm)
diffusion = db.diffusion.model(f, sigma, d, T, M)

In [None]:
# drift at a single state, used as the residual function for scipy's fsolve
def drift_(x):
    """Return the model drift at a single two-dimensional state ``x``.

    ``x`` may be a torch tensor or a NumPy array; fsolve passes float64
    arrays. The result is computed in the input's floating-point precision
    rather than in a float32 buffer. fsolve estimates the Jacobian by forward
    differences with steps of order 1e-8, which float32 rounding would largely
    erase.

    Returns:
        Tensor of shape (2,) holding the drift of each coordinate.
    """
    x = torch.as_tensor(x)
    x0_p = x[0] ** p
    x1_p = x[1] ** p
    xi_p = xi**p
    return torch.stack(
        [
            alpha * x0_p / (xi_p + x0_p) + beta * xi_p / (xi_p + x1_p) - kappa * x[0],
            alpha * x1_p / (xi_p + x1_p) + beta * xi_p / (xi_p + x0_p) - kappa * x[1],
        ]
    )

import warnings

# undifferentiated cell state: the point with all coordinates equal to 1
X0 = torch.ones(d)
print(f"Undifferentiated cell state {X0} has drift {drift_(X0)}")

# differentiated cell state: a root of the drift found by fsolve from (0, 2)
XT_root, _, ier, mesg = fsolve(func=drift_, x0=[0.0, 2.0], full_output=True)
if ier != 1:
    # fsolve does not raise on failure; flag it so a bad bridge endpoint is
    # not used silently in everything that follows
    warnings.warn(f"fsolve did not converge to a fixed point: {mesg}")
XT = torch.tensor(XT_root, dtype=torch.float32)
print(f"Differentiated cell state {XT} has drift {drift_(XT)}")

In [None]:
# learn backward diffusion bridge process with score matching
epsilon = 1.0
minibatch = 100
learning_rate = 0.01
ema_momentum = 0.99
output = diffusion.learn_score_transition(
    X0, XT, epsilon, minibatch, num_iterations, learning_rate, ema_momentum
)
score_transition_net = output['net']


def _bdb_is_summary(log_weights):
    """Return (ESS, log of the mean importance weight) for ``log_weights``.

    Weights are shifted by their maximum before exponentiating so that the
    exponentials cannot overflow; the shift is added back to the log-mean.
    """
    max_log_weights = torch.max(log_weights)
    weights = torch.exp(log_weights - max_log_weights)
    norm_weights = weights / torch.sum(weights)
    ess = 1.0 / torch.sum(norm_weights**2)
    log_transition_estimate = torch.log(torch.mean(weights)) + max_log_weights
    return ess, log_transition_estimate


def _bdb_imh_accept_rate(log_weights, current_log_weight):
    """Acceptance rate of an independent Metropolis-Hastings chain.

    The proposals are the samples behind ``log_weights``, taken in order. A
    proposal with log-weight w' replaces the current state (log-weight w)
    with probability min(1, exp(w' - w)). Only the acceptance rate is
    reported, so the chain's trajectories are not tracked.
    """
    num_accept = 0
    for proposed_log_weight in log_weights:
        if torch.log(torch.rand(1)) < proposed_log_weight - current_log_weight:
            current_log_weight = proposed_log_weight
            num_accept += 1
    return num_accept / len(log_weights)


# simulate backward diffusion bridge (BDB) process with approximate score
BDB = {
    measure: torch.zeros(R) for measure in ["ess", "logestimate", "acceptrate"]
}
for r in range(R):
    # importance sampling with N proposals
    with torch.no_grad():
        output = diffusion.simulate_bridge_backwards(
            score_transition_net, X0, XT, epsilon, N
        )
    log_weights = diffusion.law_bridge(output["trajectories"]) - output["logdensity"]
    ess, log_transition_estimate = _bdb_is_summary(log_weights)
    BDB["ess"][r] = ess
    BDB["logestimate"][r] = log_transition_estimate

    # independent Metropolis-Hastings, started from one fresh draw; like the
    # batch above it needs no autograd graph through the score network
    with torch.no_grad():
        initial = diffusion.simulate_bridge_backwards(
            score_transition_net, X0, XT, epsilon, 1
        )
    current_log_weight = (
        diffusion.law_bridge(initial["trajectories"]) - initial["logdensity"]
    )
    accept_rate = _bdb_imh_accept_rate(log_weights, current_log_weight)
    BDB["acceptrate"][r] = accept_rate

    # print
    print(
        f"BDB repetition: {r}",
        f"ESS%: {float(ess * 100 / N):.2f}",
        f"log-transition: {float(log_transition_estimate):.2f}",
        f"Accept rate: {float(accept_rate):.4f}",
    )

In [None]:
# drift functions of existing methods, keyed by method name; `modify` holds the
# optional modification passed to diffusion.simulate_proposal_bridge (methods
# without an entry are looked up with modify.get and receive None)
drifts = {}
modify = {}

# forward diffusion (FD) method of Pedersen (1995): propose with the model drift
drifts["FD"] = f


# modified diffusion bridge (MDB) method of Durham and Gallant (2002): the
# drift steers the state linearly towards XT over the remaining time T - t
def _mdb_drift(t, x):
    return (XT - x) / (T - t)


drifts["MDB"] = _mdb_drift
modify["MDB"] = "variance"

# diffusion bridge proposal of Clark (1990) and Delyon and Hu (2006): model
# drift plus a guiding term from a fixed (non-trainable) "bm" auxiliary process
auxiliary_type = "bm"
initial_params = {"alpha": torch.zeros(d)}
bm_auxiliary = AuxiliaryDiffusion(
    diffusion, auxiliary_type, initial_params, requires_grad=False
)


def _cdh_drift(t, x):
    model_part = f(t, x)
    return model_part + diffusion.Sigma * bm_auxiliary.grad_logh(XT, t, x)


drifts["CDH"] = _cdh_drift
modify["CDH"] = "time"

# learn guided proposal of Schauer, Van der Meulen and Van Zanten (2017) with a
# trainable "ou-full" auxiliary process
auxiliary_type = "ou-full"
initial_params = {
    # clone so that training the auxiliary parameters can never update the
    # bridge endpoint XT in place, should AuxiliaryDiffusion wrap this tensor
    # without copying it (NOTE(review): harmless if it already copies)
    "theta": XT.clone(),
    "eigvals": torch.ones(d),
    "eigvecs": torch.eye(d),
}
ou_auxiliary = AuxiliaryDiffusion(diffusion, auxiliary_type, initial_params)

# the returned object is kept as `guided` but not used below; presumably the
# call updates ou_auxiliary's parameters in place — confirm
minibatch = 100
learning_rate = 0.01
guided = diffusion.learn_guided_proposal(
    ou_auxiliary, X0, XT, minibatch, num_iterations, learning_rate
)


def _gdb_drift(t, x):
    model_part = f(t, x)
    return model_part + diffusion.Sigma * ou_auxiliary.grad_logh(XT, t, x)


drifts["GDB"] = _gdb_drift
modify["GDB"] = "time"

In [None]:
# simulate existing methods
def _proposal_log_weights(method, output):
    """Return detached log importance weights for ``output["trajectories"]``.

    CDH and GDB use their auxiliary process's ``log_radon_nikodym``. The other
    methods use ``diffusion.law_bridge`` minus the proposal log-density
    reported by the simulator. The result is detached because ``ou_auxiliary``
    was built without ``requires_grad=False`` and then trained, so its
    parameters presumably require grad. Without the detach, the GDB weights
    would carry an autograd graph into the stored result tensors, and that
    graph would accumulate across repetitions.
    """
    trajectories = output["trajectories"]
    if method == "CDH":
        log_weights = bm_auxiliary.log_radon_nikodym(trajectories)
    elif method == "GDB":
        log_weights = ou_auxiliary.log_radon_nikodym(trajectories)
    else:
        log_weights = diffusion.law_bridge(trajectories) - output["logdensity"]
    return log_weights.detach()


def _is_summary(log_weights):
    """Return (ESS, log of the mean importance weight), computed with a max-shift."""
    max_log_weights = torch.max(log_weights)
    weights = torch.exp(log_weights - max_log_weights)
    norm_weights = weights / torch.sum(weights)
    ess = 1.0 / torch.sum(norm_weights**2)
    log_transition_estimate = torch.log(torch.mean(weights)) + max_log_weights
    return ess, log_transition_estimate


def _imh_accept_rate(log_weights, current_log_weight):
    """Acceptance rate of independent Metropolis-Hastings over ``log_weights``.

    Proposals are taken in order; a proposal with log-weight w' replaces the
    current state (log-weight w) with probability min(1, exp(w' - w)).
    """
    num_accept = 0
    for proposed_log_weight in log_weights:
        if torch.log(torch.rand(1)) < proposed_log_weight - current_log_weight:
            current_log_weight = proposed_log_weight
            num_accept += 1
    return num_accept / len(log_weights)


results = {"BDB": BDB}

for method, drift in drifts.items():
    # measures to store, one entry per repetition
    result = {
        measure: torch.zeros(R) for measure in ["ess", "logestimate", "acceptrate"]
    }
    modification = modify.get(method)

    for r in range(R):
        # importance sampling with N proposals
        with torch.no_grad():
            output = diffusion.simulate_proposal_bridge(
                drift, X0, XT, N, modification
            )
        log_weights = _proposal_log_weights(method, output)
        ess, log_transition_estimate = _is_summary(log_weights)
        result["ess"][r] = ess
        result["logestimate"][r] = log_transition_estimate

        # independent Metropolis-Hastings, started from one fresh draw; like
        # the batch above it needs no autograd graph
        with torch.no_grad():
            initial = diffusion.simulate_proposal_bridge(
                drift, X0, XT, 1, modification
            )
        current_log_weight = _proposal_log_weights(method, initial)
        accept_rate = _imh_accept_rate(log_weights, current_log_weight)
        result["acceptrate"][r] = accept_rate

        # print
        print(
            f"{method} repetition: {r}",
            f"ESS%: {float(ess * 100 / N):.2f}",
            f"log-transition: {float(log_transition_estimate):.2f}",
            f"Accept rate: {float(accept_rate):.4f}",
        )

    # store result
    results[method] = result

In [None]:
# summary tables over the R repetitions: one table per measure, each listing
# every method and closed by a horizontal rule
_summaries = [
    # mean effective sample size, as a percentage of N
    lambda res: f"ESS%: {float(torch.mean(res['ess']) * 100 / N):.2f}",
    # mean log of the importance-sampling transition estimate. By Jensen's
    # inequality its expectation lower-bounds the log transition density
    # (assuming unbiased weights), hence the "ELBO" label
    lambda res: f"ELBO: {torch.mean(res['logestimate']):.3f}",
    # mean independent Metropolis-Hastings acceptance rate, in percent
    lambda res: f"Accept rate%: {float(torch.mean(res['acceptrate']) * 100):.2f}",
]
for _summary in _summaries:
    for method, result in results.items():
        print(f"{method}", _summary(result))
    print("-" * 30)

In [None]:
# save results under path_dir/folder. torch.save does not create missing
# directories, so make sure the folder exists first.
from pathlib import Path

Path(path_dir, folder).mkdir(parents=True, exist_ok=True)
file_name = f"{path_dir}/{folder}/cell_sigma_{sigma_level}_T{interval}"
torch.save(results, file_name + ".pt")