In [1]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


# A differentiable version of an auto-repression bursting model

$G \to G'$ with rate $k_{on}$

$G' + P \to G + P$ with rate $k_{off}$

$0 \to M$ with rate $r_m$

$M \to M + P$ with rate $r_p$

$M \to 0$ with rate $\gamma_m$

$P \to 0$ with rate $\gamma_p$

Note: we are following Krishna's paper where OFF corresponds to $G=-1$. This is useful for smoothing with a sigmoid, such that $\sigma(-1) \to 0$.  

In [19]:
# Define the stoichiometry matrix for the reactions
stoic_matrix = torch.tensor([[2.0, 0.0, 0.0],    # Reaction 1: Promoter state goes from -1 to +1
                             [-2.0, 0.0, 0.0],   # Reaction 2: Promoter state goes from 1 to -1
                             [0.0, 1.0, 0.0],    # Reaction 3: mRNA production
                             [0.0, 0.0, 1.0],    # Reaction 4: protein production
                             [0.0, -1.0, 0.0],   # Reaction 5: Degradation of mRNA
                             [0.0, 0.0, -1.0]])  # Reaction 6: protein degradation

# Define a function to compute the state jump
def state_jump(reaction_index, stoic_matrix):
    """
    Calculate state jump vector based on the selected reaction index and stoichiometry matrix, where, 
    state vector -> state vector + state jump vector.

    Arguments:
        reaction_index: Selected reaction index
        stoic_matrix: Stoichiometry matrix

    Returns:
        State jump vector
    """
    return torch.sum(stoic_matrix * (torch.exp(-b_inv* (reaction_index - torch.arange(stoic_matrix.shape[0]))**2)).view(-1, 1), dim=0)

# Define a function to select the reaction based on reaction selection thresholds
def reaction_selection(breaks, random_num):
    """
    Select reaction based on the transition points and a random number. Transition points are 
    given by the ratio of cumulative sum of rates and the total rate.
    
    Arguments:
        breaks: Transition points between [0,1]
        random_num: Random number in [0,1]

    Returns:
        Index of the next reaction
    """
    return torch.sum(torch.sigmoid(a_inv * (random_num - breaks)))

# Define the Gillespie simulation function
def gillespie_simulation(r_m, r_p, k_on, k_off, num_simulations, sim_time, a_inv, b_inv, c):
    """
    Perform differentiable Gillespie simulation for a 2-state promoter model.
    
    Arguments:
        r_m: Rate of mRNA production.
        r_p: Rate of protein production
        k_on: Rate of promoter switching from -1 to +1.
        k_off: Rate of promoter switching from 1 to -1 (not including protein level).
        num_simulations: Number of simulations to run.
        sim_time: Simulation time.
        a_inv: Inverse parameter for reaction selection.
        b_inv : Inverse parameter for state jump calculation.
        c: Sigmoid slope parameter for propensities.
        
    Returns:
        mean_final_state: Mean of the mRNA levels at the end of the simulation.
        variance: Variance of the mRNA levels at the end of the simulation.
    """
    # Initialize random seed for reproducibility
    random_seed = torch.randint(1, 10000000, (1,))
    #print (random_seed)
    torch.manual_seed(random_seed)
    final_states = 0.0
    final_states_squared = 0.0

    # Main simulation loop
    for j in range(num_simulations):
        # Initialize 'levels':
        # The first component of 'levels' is the promoter state, initialized to 0
        # The second component of 'levels' is the mRNA count, initailized to 0.
        levels = torch.stack([torch.tensor(-1.0), torch.tensor(0.0), torch.tensor(0.0)])
        current_time = 0.0

        # Main simulation loop
        while current_time < sim_time:
            # Calculate reaction propensities
            propensities = torch.stack([k_on*torch.sigmoid(-c*levels[0]),            # Rate of promoter state switching from -1 to +1
                                        k_off*levels[2]*torch.sigmoid(c*levels[0]),  # Rate of promoter state switching from +1 to -1
                                        r_m*torch.sigmoid(-c*levels[0]),             # Rate of mRNA production
                                        r_p*levels[1],                               # Rate of mRNA production
                                        gamma_m*levels[1],                           # Rate of mRNA degradation
                                        gamma_p*levels[2]])                          # Rate of protein degradation             

            # Calculate total propensity
            total_propensity = propensities.sum()

            # Generate a random number to determine time to next reaction
            dt = -torch.log(torch.rand(1)) / total_propensity
            current_time += dt.item()

            # Check if the simulation exceeds sim_time. If it exceeds, quit the simulation.
            if current_time >= sim_time:
                break

            # Update state vector
            breaks = (propensities[:-1] / total_propensity).cumsum(dim=0)
            reaction_index = reaction_selection(breaks, torch.rand(1))
            levels = levels + state_jump(reaction_index, stoic_matrix)
            levels[1] = torch.relu(levels[1]) 

        # Accumulate final states after each sumulation
        final_states += levels[1]
        final_states_squared += levels[1] ** 2

    # Calculate mean and variance of mRNA levels (from the accumulated final states)
    mean_final_state = final_states / num_simulations
    variance = final_states_squared / num_simulations - mean_final_state ** 2

    # Return mean mRNA level and variance
    return mean_final_state, variance

# Define the loss function 
def loss_function(mean_final_state, variance, target_mean, target_std):
    """
    Calculates the mean squared error of the simulation results against data
    """
    return (mean_final_state - target_mean) ** 2 + (variance ** 0.5 - target_std) ** 2



### Test function

In [5]:
r_m = 1.0
r_p = 1.0
k_on = 0.05
k_off = 0.10
gamma_m = 0.23
gamma_p = 0.23

# Hyperparameters
num_simulations=50
a_inv=200.0
b_inv=20.0
c=20.0
sim_time=5.0

In [6]:
mean_final_state, variance =  gillespie_simulation(r_m, r_p, k_on, k_off, num_simulations, sim_time, a_inv, b_inv, c)

tensor([1439364])


## Try out fitting simulated data

In [7]:
"""generate ground truth"""
r_m = 1.0
r_p = 1.0
k_on = 0.05
k_off = 0.10
gamma_m = 0.23
gamma_p = 0.23

# Hyperparameters
num_simulations=100
a_inv=200.0
b_inv=20.0
c=20.0
sim_time=5.0

In [8]:
# run sims
mean_final_state, variance =  gillespie_simulation(r_m, r_p, k_on, k_off, num_simulations, sim_time, a_inv, b_inv, c)

# store output
ground_truth = {'r_m': r_m, 'r_p': r_p, 'k_on': k_on, 'k_off': k_off, 'mean_mrna': mean_final_state, 'variance_mrna': variance}

tensor([6609668])


In [21]:
# Set seed for reproducibility 
torch.manual_seed(40)

# Define simulation hyperparameters
num_simulations = 100
num_iterations = 30

# Initialize parameters
r_m = torch.nn.Parameter(torch.tensor(2.0))
r_p =  torch.nn.Parameter(torch.tensor(0.5))
k_on =  torch.nn.Parameter(torch.tensor(0.01))
k_off =  torch.nn.Parameter(torch.tensor(0.2))

# Define the Adam optimizer
optimizer = torch.optim.Adam([r_m, r_p, k_on, k_off], lr=0.1)

# Set target mean and standard deviation
target_mean = torch.tensor(ground_truth['mean_mrna'])
target_std = torch.sqrt(torch.tensor(ground_truth['variance_mrna']))

# Loop through each iteration
for iteration in tqdm(range(num_iterations)):

    # Forward differentiable Gillespie simulation
    mean_final_state, variance =  gillespie_simulation(r_m, r_p, k_on, k_off, num_simulations, sim_time, a_inv, b_inv, c)

    # Compute the loss for the current iteration
    loss = loss_function(mean_final_state, variance, target_mean, target_std)

    # Zero the gradients to prepare for backward pass
    optimizer.zero_grad()

    # Compute the gradient of the loss with respect to parameters
    loss.backward()

    # Gradient clipping
    torch.nn.utils.clip_grad_norm_([r_m, r_p, k_on, k_off], max_norm=0.2)

    # Update the parameters using the optimizer
    optimizer.step()

    # Clamp the parameter values to certain bounds
    r_m.data = torch.clamp(r_m.data, min=0.01, max=100.0)
    r_p.data = torch.clamp(r_p.data, min=0.01, max=100.0)
    k_on.data = torch.clamp(k_on.data, min=0.001, max=100.0)
    k_off.data = torch.clamp(k_off.data, min=0.001, max=100.0)
    
    print(r_m.data)


  target_mean = torch.tensor(ground_truth['mean_mrna'])
  target_std = torch.sqrt(torch.tensor(ground_truth['variance_mrna']))
  3%|██████                                                                                                                                                                                | 1/30 [00:01<00:42,  1.46s/it]

tensor(1.9000)


  7%|████████████▏                                                                                                                                                                         | 2/30 [00:03<00:45,  1.63s/it]

tensor(1.8026)


 10%|██████████████████▏                                                                                                                                                                   | 3/30 [00:05<00:47,  1.78s/it]

tensor(1.7043)


 13%|████████████████████████▎                                                                                                                                                             | 4/30 [00:07<00:48,  1.85s/it]

tensor(1.6166)


 17%|██████████████████████████████▎                                                                                                                                                       | 5/30 [00:09<00:48,  1.93s/it]

tensor(1.5390)


 20%|████████████████████████████████████▍                                                                                                                                                 | 6/30 [00:11<00:45,  1.91s/it]

tensor(1.4786)


 23%|██████████████████████████████████████████▍                                                                                                                                           | 7/30 [00:12<00:43,  1.88s/it]

tensor(1.4096)


 27%|████████████████████████████████████████████████▌                                                                                                                                     | 8/30 [00:14<00:39,  1.78s/it]

tensor(1.3361)


 30%|██████████████████████████████████████████████████████▌                                                                                                                               | 9/30 [00:16<00:36,  1.73s/it]

tensor(1.2626)


 33%|████████████████████████████████████████████████████████████▎                                                                                                                        | 10/30 [00:18<00:40,  2.00s/it]

tensor(1.1837)


 37%|██████████████████████████████████████████████████████████████████▎                                                                                                                  | 11/30 [00:20<00:36,  1.92s/it]

tensor(1.1147)


 40%|████████████████████████████████████████████████████████████████████████▍                                                                                                            | 12/30 [00:22<00:33,  1.84s/it]

tensor(1.0464)


 43%|██████████████████████████████████████████████████████████████████████████████▍                                                                                                      | 13/30 [00:23<00:29,  1.75s/it]

tensor(0.9703)


 47%|████████████████████████████████████████████████████████████████████████████████████▍                                                                                                | 14/30 [00:25<00:26,  1.67s/it]

tensor(0.9156)


 50%|██████████████████████████████████████████████████████████████████████████████████████████▌                                                                                          | 15/30 [00:26<00:24,  1.67s/it]

tensor(0.8807)


 53%|████████████████████████████████████████████████████████████████████████████████████████████████▌                                                                                    | 16/30 [00:28<00:22,  1.61s/it]

tensor(0.8577)


 57%|██████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                                              | 17/30 [00:29<00:20,  1.56s/it]

tensor(0.8203)


 60%|████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                                        | 18/30 [00:31<00:18,  1.58s/it]

tensor(0.8065)


 63%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                                                  | 19/30 [00:32<00:17,  1.60s/it]

tensor(0.7772)


 67%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                                            | 20/30 [00:34<00:16,  1.63s/it]

tensor(0.7655)


 70%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                                      | 21/30 [00:36<00:15,  1.67s/it]

tensor(0.7635)


 73%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                                | 22/30 [00:37<00:12,  1.61s/it]

tensor(0.7834)


 77%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                                          | 23/30 [00:39<00:11,  1.59s/it]

tensor(0.8176)


 80%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                                    | 24/30 [00:41<00:09,  1.66s/it]

tensor(0.8561)


 83%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                              | 25/30 [00:43<00:08,  1.79s/it]

tensor(0.8759)


 87%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                        | 26/30 [00:45<00:07,  1.87s/it]

tensor(0.9062)


 90%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                  | 27/30 [00:47<00:05,  1.86s/it]

tensor(0.9351)


 93%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉            | 28/30 [00:49<00:03,  1.94s/it]

tensor(0.9747)


 97%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉      | 29/30 [00:51<00:02,  2.08s/it]

tensor(1.0120)


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [00:54<00:00,  1.80s/it]

tensor(1.0378)





In [23]:
r_m.data

tensor(1.0378)

In [24]:
r_p.data

tensor(2.0903)

In [25]:
k_on.data

tensor(0.0010)

In [26]:
k_off.data

tensor(0.3731)