In [2]:
import re
from typing import Dict

import numpy as np
import torch
from torchdiffeq import odeint


In [15]:
def _rewrite_matlab_index(code: str, name: str, template: str, upper=None) -> str:
    """Rewrite 1-based MATLAB indexing ``name(i)`` into 0-based Python indexing.

    Args:
        code: One line of MATLAB code with its comment already removed.
        name: Variable name to rewrite, e.g. ``"x"``.
        template: ``str.format`` template that receives the 0-based index,
            e.g. ``"x[...,{}]"``.
        upper: Largest valid 1-based index, or ``None`` for no upper bound.

    Returns:
        ``code`` with every ``name(i)`` replaced by ``template.format(i - 1)``.

    Raises:
        ValueError: If an index is 0 or exceeds ``upper``. Index 0 would otherwise
            become Python index -1 and silently address the last element.
    """

    def to_python(match) -> str:
        one_based = int(match.group(1))
        if one_based < 1 or (upper is not None and one_based > upper):
            expected = f"1..{upper}" if upper is not None else ">= 1"
            raise ValueError(
                f"MATLAB index {name}({one_based}) is out of range (expected {expected})"
            )
        return template.format(one_based - 1)

    # `\b` keeps `x(` from matching inside `dx(`; `\s*` accepts MATLAB's alignment
    # padding such as `dx( 1)`.
    return re.sub(rf"\b{re.escape(name)}\(\s*(\d+)\s*\)", to_python, code)


def _rewrite_matlab_arithmetic(code: str) -> str:
    """Rewrite MATLAB arithmetic whose Python spelling differs.

    ``m*10^e`` literals become ``mee``. Element-wise ``.^``, ``.*`` and ``./`` become
    ``**``, ``*`` and ``/``. Every remaining ``^`` becomes ``**``: in Python ``^`` is
    bitwise XOR, so e.g. ``10^6`` would silently evaluate to 12.
    """
    # Collapse `m*10^e` only when it is a standalone operand (line start or right after
    # `=`, `(`, `[`, `,`). Otherwise, e.g. `a/2*10^3` would change meaning; such cases
    # fall through to the `**` rewrite below, which evaluates them correctly.
    code = re.sub(
        r"(?:^|(?<=[=(\[,]))(\s*[+-]?\s*)(\d+(?:\.\d*)?)\s*\*\s*10\s*\^\s*([+-]?\d+)(?![\w.^])",
        r"\1\2e\3",
        code,
    )
    code = code.replace(".^", "**").replace(".*", "*").replace("./", "/")
    return code.replace("^", "**")


def _translate_matlab_lines(code: str, rewrite) -> str:
    """Translate MATLAB source line by line, applying ``rewrite`` to code parts only.

    Text after the first ``%`` on a line is kept verbatim as a ``#`` comment, so e.g.
    ``10^6`` inside a comment is not rewritten. A trailing statement ``;`` is dropped.
    NOTE: a ``%`` inside a MATLAB string literal would be mistaken for a comment.

    Args:
        code: MATLAB source text.
        rewrite: Callable mapping the code part of one line to Python code.

    Returns:
        The translated lines joined with ``"\\n"``.
    """
    python_lines = []
    for line in code.splitlines():
        code_part, has_comment, comment = line.partition("%")
        code_part = rewrite(code_part)
        statement = code_part.rstrip().rstrip(";").rstrip()
        if not has_comment:
            python_lines.append(statement)
        elif statement:
            python_lines.append(f"{statement}  #{comment}")
        else:
            # Comment-only line: keep its original indentation.
            python_lines.append(f"{code_part}#{comment}")
    return "\n".join(python_lines)


def ode_matlab_to_python(code: str, state_width: int) -> str:
    """Translate a MATLAB ODE right-hand side into the body of a PyTorch function.

    Mapping of MATLAB names (1-based) to Python:
    - ``dx(i)`` becomes ``dx[i-1]``.
    - ``x(i)`` becomes ``x[...,i-1]``.
    - ``k(i)`` becomes ``k[...,i-1]``.

    States and constants are therefore read from the last axis. ``%`` comments become
    ``#`` comments, and MATLAB operators are rewritten (``^``/``.^`` -> ``**``,
    ``.*`` -> ``*``, ``./`` -> ``/``).

    The result is an unindented function body that starts with
    ``dx = [None]*state_width`` and ends with ``return torch.stack(dx, dim=-1)``.
    Indent it under a ``def f(x, k):`` header to use it.

    Args:
        code: MATLAB statements assigning ``dx(1)`` .. ``dx(state_width)``.
        state_width: Number of state variables; bounds the ``x``/``dx`` indices.

    Returns:
        Python source text.

    Raises:
        ValueError: If an ``x``/``dx`` index lies outside ``1..state_width`` or a
            ``k`` index is below 1.
    """

    def rewrite(line: str) -> str:
        line = _rewrite_matlab_arithmetic(line)
        # `dx` first, so its `x(` can never be captured by the state rewrite.
        line = _rewrite_matlab_index(line, "dx", "dx[{}]", upper=state_width)
        line = _rewrite_matlab_index(line, "x", "x[...,{}]", upper=state_width)
        return _rewrite_matlab_index(line, "k", "k[...,{}]")

    body = _translate_matlab_lines(code, rewrite)
    return f"dx = [None]*{state_width}\n\n{body}\n\nreturn torch.stack(dx, dim=-1)"


def constants_matlab_to_python(code: str, num_constants: int) -> str:
    """Translate MATLAB ``k(i) = value;`` assignments into Python list assignments.

    ``k(i)`` (1-based) becomes ``k[i-1]``. ``m*10^e`` literals become ``mee``, and
    any other ``^`` becomes ``**``. ``%`` comments become ``#`` comments. The output
    pre-allocates ``k = [None]*num_constants`` and finishes with
    ``k = torch.Tensor(k)``.

    Args:
        code: MATLAB assignments of ``k(1)`` .. ``k(num_constants)``.
        num_constants: Length of the constant list; bounds the ``k`` indices.

    Returns:
        Python source text.

    Raises:
        ValueError: If a ``k`` index lies outside ``1..num_constants``.
    """

    def rewrite(line: str) -> str:
        line = _rewrite_matlab_arithmetic(line)
        return _rewrite_matlab_index(line, "k", "k[{}]", upper=num_constants)

    body = _translate_matlab_lines(code, rewrite)
    return f"k = [None]*{num_constants}\n\n{body}\n\nk = torch.Tensor(k)"


# Sample input: the MATLAB right-hand side of the model whose states are labelled
# Fibrobasts (x(1)), Mph (x(2)), CSF (x(3)) and PDGF (x(4)), plus its 14 constants.
# Both strings are kept verbatim from the MATLAB source, comments included.
ode = """dx( 1) = x(1)*(k(1) * x(4)/(k(11)+x(4))*(1 - x(1)/k(4)) - k(3)); % Fibrobasts
dx( 2) = x(2)*(k(2) * x(3)/(k(12)+x(3)) - k(3)) + k(13); % Mph
dx( 3) = k(7)*x(1) - k(9)*x(2) * x(3)/(k(12)+x(3)) - k(5)*x(3);% CSF
dx( 4) = k(8)*x(2) + k(6)*x(1) - k(10)*x(1)*x(4)/(k(11)+x(4)) - k(5)*x(4);% PDGF"""

constants = """k(1)=0.9; %proliferation rates: lambda1=0.9/day,  
k(2)=0.8; %lambda2=0.8/day
k(3) = 0.3; %mu_1, mu_2, death rates: 0.3/day
k(4) = 1e6; %carrying capacity: 10^6 cells
k(5)= 2; %growth factor degradation: gamma=2/day
k(6)=240*1440;%growth factor secretion rates: beta3=240 molecules/cell/min  ---- beta_3 
k(7)=470*1440;% beta1=470 molecules/cell/min                                ---- beta_1
k(8)=70*1440;% beta2=70 molecules/cell/min                                 ---- beta_2
k(9)=940*1440; %alpha1=940 molecules/cell/min, endocytosis rate CSF1       ---- alpha_1
k(10)=510*1440; %alpha2=510 molecules/cell/min, endocytosis rate PDGF     ---- alpha_2
k(11)=6*10^8; % %binding affinities: k1=6x10^8 molecules (PDGF)     ---- k_1
k(12)=6*10^8; % k2=6x10^8 (CSF)                                   ---- k_2
k(13)=0;%120 inflammation pulse
k(14)=1e6;"""

# Print both translations, separated by a dashed rule (4 states, 14 constants).
print(
    ode_matlab_to_python(ode, 4),
    "\n----------\n",
    constants_matlab_to_python(constants, 14),
    sep="\n",
)


dx = [None]*4

dx[0] = x[...,0]*(k[...,0] * x[...,3]/(k[...,10]+x[...,3])*(1 - x[...,0]/k[...,3]) - k[...,2]) # Fibrobasts
dx[1] = x[...,1]*(k[...,1] * x[...,2]/(k[...,11]+x[...,2]) - k[...,2]) + k[...,12] # Mph
dx[2] = k[...,6]*x[...,0] - k[...,8]*x[...,1] * x[...,2]/(k[...,11]+x[...,2]) - k[...,4]*x[...,2]# CSF
dx[3] = k[...,7]*x[...,1] + k[...,5]*x[...,0] - k[...,9]*x[...,0]*x[...,3]/(k[...,10]+x[...,3]) - k[...,4]*x[...,3]# PDGF

return torch.stack(dx, dim=-1)

----------

k = [None]*14

k[0]=0.9 #proliferation rates: lambda1=0.9/day,  
k[1]=0.8 #lambda2=0.8/day
k[2] = 0.3 #mu_1, mu_2, death rates: 0.3/day
k[3] = 1e6 #carrying capacity: 10^6 cells
k[4]= 2 #growth factor degradation: gamma=2/day
k[5]=240*1440#growth factor secretion rates: beta3=240 molecules/cell/min  ---- beta_3 
k[6]=470*1440# beta1=470 molecules/cell/min                                ---- beta_1
k[7]=70*1440# beta2=70 molecules/cell/min                                 ---- beta_2
k[8]=940*1440 #alpha1=940 molec

In [18]:
def fm_ode(x: torch.Tensor, k: torch.Tensor):
    """Right-hand side of the fibroblast / macrophage / CSF / PDGF ODE system.

    Args:
        x: State tensor. Its last axis holds, in order, fibroblasts, macrophages
            (Mph), CSF and PDGF. Any leading dimensions are carried through
            element-wise.
        k: Parameter tensor. Its last axis holds the 0-based model constants; indices
            0..12 are read. Leading dimensions broadcast against those of ``x``.

    Returns:
        A tensor of the four time derivatives, stacked along a new last axis in the
        same order as the states.

    NOTE(review): the signature is ``(x, k)``, not the ``(t, y)`` form that
    ``torchdiffeq.odeint`` calls, so it needs a wrapper such as
    ``lambda t, y: fm_ode(y, fm_k)`` there.
    """
    fib, mac, csf, pdgf = (x[..., i] for i in range(4))
    lam1, lam2, mu, cap, gamma = (k[..., i] for i in range(5))
    beta3, beta1, beta2, alpha1, alpha2 = (k[..., i] for i in range(5, 10))
    k_pdgf, k_csf, pulse = (k[..., i] for i in range(10, 13))

    # Logistic, PDGF-driven proliferation minus death.
    d_fib = fib * (lam1 * pdgf / (k_pdgf + pdgf) * (1 - fib / cap) - mu)
    # CSF-driven proliferation minus death, plus the inflammation pulse.
    d_mac = mac * (lam2 * csf / (k_csf + csf) - mu) + pulse
    # Secreted by fibroblasts, taken up by macrophages, degraded.
    d_csf = beta1 * fib - alpha1 * mac * csf / (k_csf + csf) - gamma * csf
    # Secreted by both cell types, taken up by fibroblasts, degraded.
    d_pdgf = (
        beta2 * mac
        + beta3 * fib
        - alpha2 * fib * pdgf / (k_pdgf + pdgf)
        - gamma * pdgf
    )

    return torch.stack([d_fib, d_mac, d_csf, d_pdgf], dim=-1)


# Model constants, 0-based: k[i] holds MATLAB k(i+1). The secretion and endocytosis
# rates are given per minute in the source and multiplied by 1440 (minutes per day),
# presumably to match the per-day units of the other rates.
k = [
    0.9,  # lambda1: proliferation rate, 0.9/day
    0.8,  # lambda2: proliferation rate, 0.8/day
    0.3,  # mu_1, mu_2: shared death rate, 0.3/day
    1e6,  # carrying capacity: 10^6 cells
    2,  # gamma: growth factor degradation, 2/day
    240 * 1440,  # beta_3: growth factor secretion, 240 molecules/cell/min
    470 * 1440,  # beta_1: 470 molecules/cell/min
    70 * 1440,  # beta_2: 70 molecules/cell/min
    940 * 1440,  # alpha_1: CSF1 endocytosis, 940 molecules/cell/min
    510 * 1440,  # alpha_2: PDGF endocytosis, 510 molecules/cell/min
    6e8,  # k_1: PDGF binding affinity, 6x10^8 molecules
    6e8,  # k_2: CSF binding affinity, 6x10^8
    0,  # inflammation pulse (source comment mentions 120)
    1e6,  # undocumented in the source; not read by fm_ode (indices 0..12)
]

# Tensor of the constants in torch's default float dtype (float32 unless changed).
fm_k = torch.Tensor(k)
