In [1]:
#Import required packages
import matplotlib.pyplot as plt
import numpy as np
import torch 
import torch.nn as nn
import torch.optim as optim

from torch.autograd import grad
from torchdiffeq import odeint_adjoint
from torchdiffeq import odeint
import pandas as pd

In [5]:
#Helper functions
def activation(x, k, theta, n):
    """Increasing Hill function of ``x``.

    Evaluates ``k * (x/theta)**n / (1 + (x/theta)**n)``.

    Args:
        x: Input value (number or tensor-like supporting ``/`` and ``**``).
        k: Multiplier applied to the numerator.
        theta: Value ``x`` is divided by before exponentiation.
        n: Exponent applied to ``x/theta``.

    Returns:
        The Hill-function value; equals 0 at ``x == 0`` for positive ``n``.
    """
    numerator = k * (x / theta) ** n
    denominator = 1 + (x / theta) ** n
    return numerator / denominator

def repression(x, k, theta, n):
    """Decreasing Hill function of ``x``.

    Evaluates ``k / (1 + (x/theta)**n)``.

    Args:
        x: Input value (number or tensor-like supporting ``/`` and ``**``).
        k: Numerator; the value returned when ``x == 0`` (for positive ``n``).
        theta: Value ``x`` is divided by before exponentiation.
        n: Exponent applied to ``x/theta``.

    Returns:
        The repression-function value.
    """
    denominator = 1 + (x / theta) ** n
    return k / denominator

def nonlinearity(x, kc, km):
    """Saturating (Michaelis-Menten form) rate term.

    Evaluates ``kc * x / (km + x)``.

    Args:
        x: Input value (number or tensor-like).
        kc: Multiplier of ``x`` in the numerator.
        km: Constant added to ``x`` in the denominator.

    Returns:
        The saturating rate value; 0 when ``x == 0`` and ``km != 0``.
    """
    numerator = kc * x
    denominator = km + x
    return numerator / denominator


In [6]:
#Biological System  
class UpstreamRepression(torch.nn.Module):
    """ODE right-hand side with trainable parameters ``W``.

    The state vector ``y`` passed to :meth:`forward` is read at indices 0-3
    (named ``x0``, ``x1``, ``e1``, ``e2`` below); indices 4 and 5 are not read.
    The returned vector has six entries: the four state derivatives followed
    by two running-cost integrands (``j1``, ``j2``), so an ODE integrator
    accumulates the costs in the last two state entries.

    ``W`` is a 3x2 parameter. As used in :meth:`forward`:
        * ``W[0][0]``, ``W[1][0]``, ``W[2][0]`` are passed to ``repression`` as
          ``n``, ``theta`` and ``k`` respectively (the ``e1`` production term);
        * ``W[2][1]`` is the constant ``e2`` production term;
        * ``W[0][1]`` and ``W[1][1]`` are not read by :meth:`forward`.
    """

    def __init__(self):
        super().__init__()
        #Initialize constants, taken from Verma et al paper.
        self.Vin = 1.
        self.e0 = 0.0467
        self.lam = 1.93E2 #1/s (unit as noted by the original author)
        #Assume equal kinetics for all three enzymes
        self.kc = 12
        self.km = 10 #1/s (unit as noted by the original author -- TODO confirm)
        # Rows: [n, theta, k]; columns: [e1 term, e2 term] (see class docstring).
        self.W = torch.nn.Parameter(torch.tensor([[2,2],[1,1], [1E-7, 1E-7]]), requires_grad=True)

    def forward(self, t, y):
        """Evaluate the time derivative of the augmented state.

        Args:
            t: Current time supplied by the integrator; not used.
            y: State vector; entries 0-3 are read.

        Returns:
            1-D tensor ``[dx0, dx1, de1, de2, j1, j2]``.
        """
        x0, x1, e1, e2 = y[0], y[1], y[2], y[3]

        # Each shared term is evaluated once: this function is called on every
        # solver step, and the original recomputed each of these twice.
        flux0 = nonlinearity(x0, self.kc, self.km)
        flux1 = nonlinearity(x1, self.kc, self.km)
        e1_production = repression(x1, self.W[2][0], self.W[1][0], self.W[0][0])

        # NOTE(review): dx0 subtracts lam*x1 (not a term involving e1*flux0);
        # kept exactly as in the original -- confirm against the reference model.
        dx0 = self.Vin - self.lam*x0 - self.e0*flux0 - self.lam*x1
        dx1 = e1*flux0 - e2*flux1 - self.lam*x1
        de1 = e1_production - self.lam*e1
        de2 = self.W[2][1] - self.lam*e2

        # Running-cost integrands: squared gap between Vin and the e2-driven
        # flux, and the e1 production rate.
        j1 = (self.Vin - e2*flux1)**2
        j2 = e1_production
        return torch.stack([dx0, dx1, de1, de2, j1, j2])

In [7]:
#Custom loss function based on biological parameters
def loss_fun_bio(pred, alpha=1E7):
    """Scalarized loss built from the final row of an integrated trajectory.

    Args:
        pred: Indexable trajectory; only its last row is read, and from that
            row only the last two entries.
        alpha: Weight applied to the last entry before it is added to the
            second-to-last entry.

    Returns:
        ``pred[-1][-2] + alpha * pred[-1][-1]``.
    """
    final_state = pred[-1]
    first_cost, second_cost = final_state[-2], final_state[-1]
    return first_cost + alpha*second_cost

In [10]:
#TorchDiffEQ solution
# Configuration for the optimisation / integration run.
func = UpstreamRepression()
learning_rate = 0.01
num_iters = 1000
optimizer = torch.optim.RMSprop(func.parameters(), lr=learning_rate)
adjoint = False  # True -> integrate with odeint_adjoint instead of odeint
alpha = 1E7
solver = 'dopri8'
int_time = 5E4
np.random.seed(2021)

#Establish initial conditions
t = torch.linspace(0, int_time, 100) 
y0 = torch.tensor([2290., 0., 0., 0., 0., 0.]) 

# Honour the `adjoint` switch (previously odeint was hard-coded and the flag
# was ignored). With adjoint = False this is the same call as before.
# NOTE(review): the recorded run of this cell ended in KeyboardInterrupt. With
# lam = 1.93E2 and an integration horizon of 5E4 the explicit 'dopri8' solver
# may need a very large number of steps -- consider a shorter `int_time` or a
# different solver/tolerances; confirm by profiling.
odeint_fn = odeint_adjoint if adjoint else odeint
pred = odeint_fn(func, y0, t, method=solver)


KeyboardInterrupt: 

In [None]:
# Pass the configured weight explicitly so that editing `alpha` in the setup
# cell actually changes the loss (the call previously relied on the default).
loss = loss_fun_bio(pred, alpha=alpha)