# Comparing methods to simulate diffusion bridges in a cell differentiation and development model

In [None]:
import DiffusionBridge as db
import torch
import matplotlib.pyplot as plt
from scipy.optimize import fsolve
plt.style.use('ggplot')

In [None]:
# problem settings
d = 2
T = torch.tensor(2.0)

# parameters
alpha = torch.tensor(1.0)
beta = torch.tensor(1.0)
kappa = torch.tensor(1.0)
p = torch.tensor(4.0)
xi = torch.tensor(0.5)

# drift
def drift(x):
    """Evaluate the two-dimensional drift for a batch of states.

    For coordinate i, with j denoting the other coordinate, the drift is

        alpha * x_i^p / (xi^p + x_i^p) + beta * xi^p / (xi^p + x_j^p) - kappa * x_i,

    i.e. a Hill-type term increasing in x_i, a Hill-type term decreasing in the
    other coordinate x_j, and linear decay.

    Args:
        x: tensor of shape (batch, 2); each row is one state.

    Returns:
        Tensor of shape (batch, 2). Its dtype follows torch type promotion of
        ``x`` with the float32 parameters above, so float64 input stays
        float64 instead of being downcast. No in-place writes are used.
    """
    x_other = x[:, [1, 0]]  # swap columns so each coordinate sees its partner
    xi_p = xi**p
    x_p = x**p
    return alpha * x_p / (xi_p + x_p) + beta * xi_p / (xi_p + x_other**p) - kappa * x

# drift in the (t, x) signature; the model is time-homogeneous, so t is ignored
f = lambda t,x: drift(x)
    
# diffusion coefficient: sigma = sqrt(0.1), i.e. sigma^2 = 0.1
sigma = torch.tensor(1 * 1e-1).sqrt()
print('Diffusion coefficient: ' + str(sigma))

# DiffusionBridge model object, built from the drift f, sigma, the dimension d,
# the horizon T and M = 100 time-discretization steps
M = 100
diffusion = db.diffusion.model(f, sigma, d, T, M)

# Monte Carlo sizes: N = 2**10 paths per repetition, R independent repetitions
N, R = 1024, 100

In [None]:
# drift to find fixed points
def drift_(x):
    """Evaluate the drift at a single state; used as the residual for ``fsolve``.

    Same formula as ``drift``, but for one state vector instead of a batch.

    Args:
        x: state of length ``d``. This is either a torch tensor or an
            array/sequence, such as the float64 NumPy array that
            ``scipy.optimize.fsolve`` passes in.

    Returns:
        1-D torch tensor of length ``d``. Floating-point inputs keep their
        precision: float32 tensors give float32, float64 arrays give float64.
        Non-floating inputs are evaluated in float64.
    """
    x = torch.as_tensor(x)
    if not torch.is_floating_point(x):
        x = x.to(torch.float64)
    # Evaluate in the input's precision. fsolve approximates the Jacobian by
    # forward differences with relative steps of order sqrt(float64 eps)
    # (~1.5e-8), which float32 rounding of x and of the residual would erase.
    out = torch.zeros(d, dtype = x.dtype)
    out[0] = alpha * x[0]**p / (xi**p + x[0]**p) + beta * xi**p / (xi**p + x[1]**p) - kappa * x[0]
    out[1] = alpha * x[1]**p / (xi**p + x[1]**p) + beta * xi**p / (xi**p + x[0]**p) - kappa * x[1]
    return out

import warnings

# undifferentiated cell state: with the parameters above the drift vanishes at
# (1, 1); the residual is printed as a check
X0 = torch.ones(d)
print(drift_(X0))
print('Undifferentiated cell state: ' + str(X0))

# differentiated cell state: a root of the drift found by fsolve, started at (0, 2)
XT_root, _, fsolve_flag, fsolve_msg = fsolve(func = drift_, x0 = [0.0, 2.0], full_output = True)
if fsolve_flag != 1:
    # fsolve reports failure through its flag instead of raising; surface it
    # rather than silently using a non-converged point as the bridge end point
    warnings.warn('fsolve did not converge for the differentiated cell state: ' + fsolve_msg)
XT = torch.tensor(XT_root, dtype = torch.float32)
print(drift_(XT))
print('Differentiated cell state: ' + str(XT))

In [None]:
# learn backward diffusion bridge process with score matching
epsilon = 1.0
minibatch = 100
num_iterations = 1000
learning_rate = 0.01
ema_momentum = 0.99
# keep the training output under its own name so the simulation loop below
# does not overwrite it
train_output = diffusion.learn_score_transition(X0, XT, epsilon, minibatch, num_iterations, learning_rate, ema_momentum)
score_transition_net = train_output['net']

# simulate modified backward diffusion bridge (MBDB) process with approximate score
MBDB = {'ess' : torch.zeros(R), 'logestimate' : torch.zeros(R), 'acceptrate' : torch.zeros(R)}
for r in range(R):
    # N proposal paths from the learned network (inference only, so no autograd)
    with torch.no_grad():
        output = diffusion.simulate_bridge_backwards(score_transition_net, X0, XT, epsilon, N, modify = True)
        trajectories = output['trajectories']
        log_proposal = output['logdensity']
    log_target = diffusion.law_bridge(trajectories)
    log_weights = log_target - log_proposal

    # importance sampling: shift by the maximum log weight for numerical stability
    max_log_weights = torch.max(log_weights)
    weights = torch.exp(log_weights - max_log_weights)
    norm_weights = weights / torch.sum(weights)
    ess = 1.0 / torch.sum(norm_weights**2)
    log_transition_estimate = torch.log(torch.mean(weights)) + max_log_weights
    MBDB['ess'][r] = ess
    MBDB['logestimate'][r] = log_transition_estimate

    # independent Metropolis-Hastings through the same N paths, started from one
    # extra proposal path; this draw also only evaluates the trained network, so
    # it runs without autograd, like the N-path draw above
    with torch.no_grad():
        initial = diffusion.simulate_bridge_backwards(score_transition_net, X0, XT, epsilon, 1, modify = True)
    current_trajectory = initial['trajectories']
    current_log_proposal = initial['logdensity']
    current_log_target = diffusion.law_bridge(current_trajectory)
    current_log_weight = current_log_target - current_log_proposal
    num_accept = 0
    for n in range(N):
        proposed_trajectory = trajectories[n, :, :]
        proposed_log_weight = log_weights[n]
        log_accept_prob = proposed_log_weight - current_log_weight

        if (torch.log(torch.rand(1)) < log_accept_prob):
            current_trajectory = proposed_trajectory.clone()
            current_log_weight = proposed_log_weight.clone()
            num_accept += 1
    accept_rate = num_accept / N
    MBDB['acceptrate'][r] = accept_rate

    # print
    print('Repeat: ' + str(r) + 
          ' ESS%: ' + str(float(ess * 100 / N)) + 
          ' log-transition density: ' + str(float(log_transition_estimate)),
          ' Accept rate: ' + str(float(accept_rate)))

In [None]:
# Forward diffusion (FD) proposal of Pedersen (1995): paths are simulated with the
# model's own drift f, then importance-weighted against the bridge law.
fs_drift = f
FD = {key: torch.zeros(R) for key in ('ess', 'logestimate', 'acceptrate')}

for r in range(R):
    # N proposal paths and their log importance weights (target minus proposal log-density)
    proposals = diffusion.simulate_proposal_bridge(fs_drift, X0, XT, N)
    paths = proposals['trajectories']
    log_weights = diffusion.law_bridge(paths) - proposals['logdensity']

    # importance-sampling diagnostics on weights shifted by their maximum (avoids overflow)
    shift = log_weights.max()
    weights = (log_weights - shift).exp()
    norm_weights = weights / weights.sum()
    ess = 1.0 / (norm_weights**2).sum()
    log_transition_estimate = weights.mean().log() + shift
    FD['ess'][r] = ess
    FD['logestimate'][r] = log_transition_estimate

    # independent Metropolis-Hastings through the same N paths, seeded by one extra proposal
    seed_draw = diffusion.simulate_proposal_bridge(fs_drift, X0, XT, 1)
    current_trajectory = seed_draw['trajectories']
    current_log_weight = diffusion.law_bridge(current_trajectory) - seed_draw['logdensity']
    num_accept = 0
    for path, path_log_weight in zip(paths, log_weights):
        if torch.rand(1).log() < path_log_weight - current_log_weight:
            current_trajectory = path.clone()
            current_log_weight = path_log_weight.clone()
            num_accept += 1
    accept_rate = num_accept / N
    FD['acceptrate'][r] = accept_rate

    summary = ('Repeat: ' + str(r) +
               ' ESS%: ' + str(float(ess * 100 / N)) +
               ' log-transition density: ' + str(float(log_transition_estimate)))
    print(summary, ' Accept rate: ' + str(float(accept_rate)))


In [None]:
# Modified diffusion bridge (MDB) proposal of Durham and Gallant (2002): the proposal
# drift (XT - x) / (T - t) points from the current state towards the end point XT.
def mdb_drift(t, x):
    return (XT - x) / (T - t)

MDB = dict(ess = torch.zeros(R), logestimate = torch.zeros(R), acceptrate = torch.zeros(R))

for r in range(R):
    draw = diffusion.simulate_proposal_bridge(mdb_drift, X0, XT, N, modify = True)
    trajectories = draw['trajectories']
    log_weights = diffusion.law_bridge(trajectories) - draw['logdensity']

    # importance sampling with weights shifted by the largest log weight
    top = torch.max(log_weights)
    weights = torch.exp(log_weights - top)
    ess = 1.0 / torch.sum((weights / torch.sum(weights))**2)
    log_transition_estimate = torch.log(torch.mean(weights)) + top
    MDB['ess'][r], MDB['logestimate'][r] = ess, log_transition_estimate

    # independent Metropolis-Hastings over the N paths, starting from one fresh MDB path
    start = diffusion.simulate_proposal_bridge(mdb_drift, X0, XT, 1, modify = True)
    current_trajectory = start['trajectories']
    current_log_weight = diffusion.law_bridge(current_trajectory) - start['logdensity']
    num_accept = 0
    for n in range(N):
        log_accept_prob = log_weights[n] - current_log_weight
        # written as `not (u < a)` rather than `u >= a` so a NaN log_accept_prob is rejected
        if not (torch.log(torch.rand(1)) < log_accept_prob):
            continue
        current_trajectory = trajectories[n].clone()
        current_log_weight = log_weights[n].clone()
        num_accept += 1
    accept_rate = num_accept / N
    MDB['acceptrate'][r] = accept_rate

    print('Repeat: ' + str(r) +
          ' ESS%: ' + str(float(ess * 100 / N)) +
          ' log-transition density: ' + str(float(log_transition_estimate)),
          ' Accept rate: ' + str(float(accept_rate)))

In [None]:
# Diffusion bridge proposal of Clark (1990) and Delyon and Hu (2006) (CDH):
# the model drift f plus the guiding term (XT - x) / (T - t).
def cdh_drift(t, x):
    return f(t, x) + (XT - x) / (T - t)

CDH = {name: torch.zeros(R) for name in ['ess', 'logestimate', 'acceptrate']}

for r in range(R):
    sim = diffusion.simulate_proposal_bridge(cdh_drift, X0, XT, N, modify = False)
    trajectories, log_proposal = sim['trajectories'], sim['logdensity']
    log_weights = diffusion.law_bridge(trajectories) - log_proposal

    # importance sampling (weights rescaled by exp(-max log weight) for stability)
    max_log_weight = log_weights.max()
    weights = torch.exp(log_weights - max_log_weight)
    normalized = weights / weights.sum()
    ess = 1.0 / torch.sum(normalized**2)
    log_transition_estimate = torch.log(weights.mean()) + max_log_weight
    CDH['ess'][r] = ess
    CDH['logestimate'][r] = log_transition_estimate

    # independent Metropolis-Hastings: walk through the N paths in order,
    # starting from one additional CDH path
    init = diffusion.simulate_proposal_bridge(cdh_drift, X0, XT, 1, modify = False)
    current_trajectory = init['trajectories']
    current_log_weight = diffusion.law_bridge(current_trajectory) - init['logdensity']
    num_accept = 0
    for n in range(N):
        log_u = torch.log(torch.rand(1))
        if log_u < log_weights[n] - current_log_weight:
            current_trajectory, current_log_weight = trajectories[n].clone(), log_weights[n].clone()
            num_accept += 1
    accept_rate = num_accept / N
    CDH['acceptrate'][r] = accept_rate

    print('Repeat: ' + str(r) + ' ESS%: ' + str(float(ess * 100 / N))
          + ' log-transition density: ' + str(float(log_transition_estimate)),
          ' Accept rate: ' + str(float(accept_rate)))

In [None]:
# compare ESS: mean effective sample size over the R repetitions, as a percentage of N
for method, stats in [('FD', FD), ('MDB', MDB), ('CDH', CDH), ('MBDB', MBDB)]:
    print(method + ' ESS%: ' + str(float(stats['ess'].mean() * 100 / N)))

In [None]:
# compare the mean of the R log-transition density estimates per method
# (printed under the label "ELBO")
for method, stats in [('FD', FD), ('MDB', MDB), ('CDH', CDH), ('MBDB', MBDB)]:
    print(method + ' ELBO: ' + str(float(stats['logestimate'].mean())))

In [None]:
# compare independent Metropolis-Hastings acceptance rates (mean over R repetitions, in %)
for method, stats in [('FD', FD), ('MDB', MDB), ('CDH', CDH), ('MBDB', MBDB)]:
    print(method + ' acceptance%: ' + str(float((stats['acceptrate'] * 100).mean())))

In [None]:
# store the per-method results (ESS, log-transition estimates, acceptance rates)
# together in one file
results = dict(FD = FD, MDB = MDB, CDH = CDH, MBDB = MBDB)
torch.save(results, 'cell_sigmasq_T2.pt')