In [None]:
import re
from typing import Dict, List

import matplotlib.pyplot as plt
import numpy as np
import torch
from torchdiffeq import odeint


In [None]:
def ode_matlab_to_python(code: str, state_width: int) -> str:
    """Translate a MATLAB ODE right-hand side into a PyTorch function body.

    The MATLAB snippet assigns derivatives as ``dx(i) = ...`` using state entries
    ``x(i)`` and constants ``k(i)`` (all 1-based). The result is a function *body*
    (no ``def`` line, not indented). It expects tensors ``x`` and ``k`` in scope and
    returns the derivatives stacked along the last dimension.

    Args:
        code: MATLAB source, one assignment per line; ``%`` starts a comment.
        state_width: number of derivatives ``dx0 .. dx{state_width-1}`` to stack.

    Returns:
        Python/PyTorch source text.
    """

    def _convert(expr: str) -> str:
        # Derivatives first (``dx( 1)`` / ``dx(1)`` -> ``dx0``). Converting them before
        # the state pattern stops ``x(1)`` inside ``dx(1)`` from being rewritten.
        expr = re.sub(r"\bdx\(\s*(\d+)\s*\)", lambda m: f"dx{int(m.group(1)) - 1}", expr)
        # State and constants, 1-based -> 0-based. ``\b`` keeps e.g. ``kk(1)`` untouched.
        expr = re.sub(r"\bx\(\s*(\d+)\s*\)", lambda m: f"xs[{int(m.group(1)) - 1}]", expr)
        expr = re.sub(r"\bk\(\s*(\d+)\s*\)", lambda m: f"ks[{int(m.group(1)) - 1}]", expr)
        # Element-wise MATLAB operators are plain operators on tensors.
        expr = re.sub(r"\.([*/^])", r"\1", expr)
        # MATLAB power; in Python ``^`` would be bitwise XOR.
        expr = expr.replace("^", "**")
        # Drop MATLAB statement terminators.
        return expr.replace(";", "")

    # Only the code part of each line is rewritten; comments keep their text.
    converted = []
    for line in code.splitlines(keepends=True):
        expr, has_comment, comment = line.partition("%")
        converted.append(_convert(expr) + ("#" + comment if has_comment else ""))
    body = "".join(converted)

    # ``unbind`` drops the last axis, so each component has shape ``x.shape[:-1]`` and
    # the final ``stack(..., dim=-1)`` restores ``x.shape``. Fresh names (``xs``/``ks``)
    # keep the generated code TorchScript-compatible, since ``x`` is never rebound.
    header = "xs = torch.unbind(x, dim=-1)\nks = torch.unbind(k, dim=-1)\n\n"
    # A list (not a parenthesised name) is required so that state_width == 1 works.
    derivatives = ", ".join(f"dx{i}" for i in range(state_width))
    return header + body + f"\n\nreturn torch.stack([{derivatives}], dim=-1)"


def constants_matlab_to_python(code: str, num_constants: int) -> str:
    """Translate a MATLAB block of ``k(i) = value`` assignments into Python.

    Args:
        code: MATLAB source assigning the 1-based constants ``k(1)..k(num_constants)``;
            ``%`` starts a comment.
        num_constants: length of the list ``k`` that the generated code pre-allocates.

    Returns:
        Python source that fills ``k = [None]*num_constants`` with 0-based
        assignments and finally converts it with ``torch.Tensor(k)``.
    """

    def _convert(expr: str) -> str:
        # 1-based ``k(i)`` -> 0-based ``k[i-1]``; ``\b`` avoids matching e.g. ``kk(1)``.
        expr = re.sub(r"\bk\(\s*(\d+)\s*\)", lambda m: f"k[{int(m.group(1)) - 1}]", expr)
        # ``a*10^b`` -> ``aeb``. Decimal mantissas and signed exponents are allowed; the
        # look-behind keeps digits that belong to an identifier (``x2*10^3``) untouched.
        expr = re.sub(
            r"(?<![\w.])(\d+(?:\.\d*)?)\s*\*\s*10\s*\^\s*([+-]?\d+)", r"\1e\2", expr
        )
        # Any remaining MATLAB power; in Python ``^`` would be bitwise XOR.
        expr = expr.replace("^", "**")
        # Drop MATLAB statement terminators.
        return expr.replace(";", "")

    # Only the code part of each line is rewritten; comments keep their text.
    converted = []
    for line in code.splitlines(keepends=True):
        expr, has_comment, comment = line.partition("%")
        converted.append(_convert(expr) + ("#" + comment if has_comment else ""))

    return (
        f"k = [None]*{num_constants}\n\n"
        + "".join(converted)
        + "\n\nk = torch.Tensor(k)"
    )


# MATLAB source of the model's right-hand side: states x(1..4) (labelled by the
# trailing comments as fibroblasts, Mph, CSF, PDGF) and constants k(1..14).
# It is passed to ``ode_matlab_to_python`` below. NOTE(review): "Fibrobasts" is a
# typo for "Fibroblasts"; it is kept as-is because the text is copied into the output.
ode = """dx( 1) = x(1)*(k(1) * x(4)/(k(11)+x(4))*(1 - x(1)/k(4)) - k(3)); % Fibrobasts
dx( 2) = x(2)*(k(2) * x(3)/(k(12)+x(3)) - k(3)) + k(13); % Mph
dx( 3) = k(7)*x(1) - k(9)*x(2) * x(3)/(k(12)+x(3)) - k(5)*x(3);% CSF
dx( 4) = k(8)*x(2) + k(6)*x(1) - k(10)*x(1)*x(4)/(k(11)+x(4)) - k(5)*x(4);% PDGF
"""

# MATLAB parameter block for k(1..14), passed to ``constants_matlab_to_python`` below.
# The ``*1440`` factors presumably convert the per-minute rates quoted in the comments
# into per-day rates (1440 minutes per day) -- NOTE(review): confirm units.
constants = """k(1)=0.9; %proliferation rates: lambda1=0.9/day,  
k(2)=0.8; %lambda2=0.8/day
k(3) = 0.3; %mu_1, mu_2, death rates: 0.3/day
k(4) = 1e6; %carrying capacity: 10^6 cells
k(5)= 2; %growth factor degradation: gamma=2/day
k(6)=240*1440;%growth factor secretion rates: beta3=240 molecules/cell/min  ---- beta_3 
k(7)=470*1440;% beta1=470 molecules/cell/min                                ---- beta_1
k(8)=70*1440;% beta2=70 molecules/cell/min                                 ---- beta_2
k(9)=940*1440; %alpha1=940 molecules/cell/min, endocytosis rate CSF1       ---- alpha_1
k(10)=510*1440; %alpha2=510 molecules/cell/min, endocytosis rate PDGF     ---- alpha_2
k(11)=6*10^8; % %binding affinities: k1=6x10^8 molecules (PDGF)     ---- k_1
k(12)=6*10^8; % k2=6x10^8 (CSF)                                   ---- k_2
k(13)=0;%120 inflammation pulse
k(14)=1e6;"""

# Show the generated PyTorch ODE body and the constant block, separated by a divider.
print(
    ode_matlab_to_python(ode, 4),
    constants_matlab_to_python(constants, 14),
    sep="\n\n----------\n\n",
)


In [None]:
@torch.jit.script
def fm_ode(x: torch.Tensor, k: List[float]) -> torch.Tensor:
    """Right-hand side of the fibroblast / macrophage growth-factor ODE.

    This is a hand-edited port of the MATLAB model in the ``ode`` string.

    Args:
        x: state tensor whose last dimension holds the 4 components (fibroblasts,
            Mph, CSF, PDGF). Any leading dimensions are treated as batch dimensions.
        k: the 14 model constants, 0-based (``k[i]`` is MATLAB ``k(i+1)``).

    Returns:
        dx/dt with the same shape as ``x``.
    """
    # ``unbind`` removes the last axis, so each component has shape ``x.shape[:-1]``
    # and the ``stack(..., dim=-1)`` below restores ``x.shape``. The previous
    # ``split(x, 1, dim=-1)`` kept a trailing size-1 axis and returned ``(..., 1, 4)``.
    # A new name is needed because TorchScript cannot rebind the Tensor ``x`` to a list.
    xs = torch.unbind(x, dim=-1)

    dx0 = xs[0] * (k[0] * xs[3] / (k[10] + xs[3]) * (1 - xs[0] / k[3]) - k[2])  # Fibroblasts
    dx1 = xs[1] * (k[1] * xs[2] / (k[11] + xs[2]) - k[2]) + k[12]  # Mph
    dx2 = k[6] * xs[0] - k[8] * xs[1] * xs[2] / (k[11] + xs[2]) - k[4] * xs[2]  # CSF
    dx3 = (
        k[7] * xs[1] + k[5] * xs[0] - k[9] * xs[0] * xs[3] / (k[10] + xs[3]) - k[4] * xs[3]
    )  # PDGF

    return torch.stack((dx0, dx1, dx2, dx3), dim=-1)


# Constant vector for ``fm_ode``: 0-based port of the MATLAB k(1..14), with the
# inflammation pulse k[12] switched off.
k = [
    0.9,  # k[0]  proliferation rates: lambda1=0.9/day
    0.8,  # k[1]  lambda2=0.8/day
    0.3,  # k[2]  mu_1, mu_2, death rates: 0.3/day
    1e6,  # k[3]  carrying capacity: 10^6 cells
    2,  # k[4]  growth factor degradation: gamma=2/day
    240 * 1440,  # k[5]  growth factor secretion: beta3=240 molecules/cell/min
    470 * 1440,  # k[6]  beta1=470 molecules/cell/min
    70 * 1440,  # k[7]  beta2=70 molecules/cell/min
    940 * 1440,  # k[8]  alpha1=940 molecules/cell/min, endocytosis rate CSF1
    510 * 1440,  # k[9]  alpha2=510 molecules/cell/min, endocytosis rate PDGF
    6e8,  # k[10] binding affinity k1=6x10^8 molecules (PDGF)
    6e8,  # k[11] binding affinity k2=6x10^8 (CSF)
    0,  # k[12] inflammation pulse (off; the MATLAB comment mentions 120)
    1e6,  # k[13]
]

fm_k = torch.Tensor(k)


# def fm_k_helper(t: torch.Tensor) -> torch.Tensor:
def fm_k_helper(
    t: float, pulse: float = 140 * 1440, pulse_end: float = 4.0
) -> List[float]:
    """Time-dependent constants for ``fm_ode``.

    All constants are fixed except k[12], the additive source term in the Mph
    equation. It equals ``pulse`` while ``t < pulse_end`` and 0 afterwards.

    Args:
        t: current time (the rate comments below are per day).
        pulse: value of k[12] while the pulse is on (default ``140 * 1440``).
            NOTE(review): the MATLAB comment reads "120 inflammation pulse" -- confirm
            the intended amplitude.
        pulse_end: time at which the pulse switches off (default 4.0).

    Returns:
        14 floats; ``k[i]`` corresponds to MATLAB ``k(i+1)``.
    """
    return [
        0.9,  # k[0]  proliferation rates: lambda1=0.9/day
        0.8,  # k[1]  lambda2=0.8/day
        0.3,  # k[2]  mu_1, mu_2, death rates: 0.3/day
        1e6,  # k[3]  carrying capacity: 10^6 cells
        2.0,  # k[4]  growth factor degradation: gamma=2/day
        240.0 * 1440,  # k[5]  secretion: beta3=240 molecules/cell/min
        470.0 * 1440,  # k[6]  beta1=470 molecules/cell/min
        70.0 * 1440,  # k[7]  beta2=70 molecules/cell/min
        940.0 * 1440,  # k[8]  alpha1=940 molecules/cell/min, endocytosis rate CSF1
        510.0 * 1440,  # k[9]  alpha2=510 molecules/cell/min, endocytosis rate PDGF
        6e8,  # k[10] binding affinity k1=6x10^8 molecules (PDGF)
        6e8,  # k[11] binding affinity k2=6x10^8 (CSF)
        float(pulse) if t < pulse_end else 0.0,  # k[12] inflammation pulse
        1e6,  # k[13]
    ]


In [None]:
# Single adaptive run over t in [0, 100]. The parameter set switches at t = 4
# (pulse on -> off); torchdiffeq's ``jump_t`` option marks that discontinuity.
with torch.no_grad():
    fm_k0 = fm_k_helper(0)  # pulse on (t < 4)
    fm_k1 = fm_k_helper(100)  # pulse off (t >= 4)

    t = torch.linspace(0.0, 100.0, 300)
    y = odeint(
        lambda t_now, state: fm_ode(state, fm_k1 if t_now >= 4.0 else fm_k0),
        torch.Tensor([1.0, 1.0, 0.0, 0.0]),
        t,
        rtol=1e-5,
        atol=1e-5,
        method="dopri5",
        options={"jump_t": torch.Tensor([4.0])},
    )


In [None]:
# Reference run without ``jump_t``: integrate the pulse phase and the post-pulse phase
# as two separate initial-value problems, restarting the second from the last state
# of the first.
with torch.no_grad():
    t_pulse = torch.linspace(0.0, 4.0, 4)
    # NOTE(review): this segment ends at t=10 while the jump_t run above goes to
    # t=100 -- confirm the intended horizon.
    t_post = torch.linspace(4.0, 10.0, 300 - 4)

    y_pulse = odeint(
        lambda t_now, state: fm_ode(state, fm_k_helper(0)),
        torch.Tensor([1.0, 1.0, 0.0, 0.0]),
        t_pulse,
        method="dopri5",
    )
    y_post = odeint(
        lambda t_now, state: fm_ode(state, fm_k_helper(10)),
        y_pulse[-1],
        t_post,
        method="dopri5",
    )

    # Previously ``y`` was overwritten by the second segment (losing the pulse phase)
    # while ``t`` still held the 300-point grid from the cell above, so the plot cell
    # failed. Stitch both segments (dropping the duplicated t=4 sample) so that
    # ``t`` and ``y`` stay aligned.
    t = torch.cat((t_pulse, t_post[1:]))
    y = torch.cat((y_pulse, y_post[1:]))


In [None]:
# One log-scale figure per state component. ``t`` and ``y`` come from whichever
# integration cell ran last, so check that they line up before plotting.
t_np = t.cpu().numpy()
y_np = y.cpu().numpy()
assert y_np.shape[0] == t_np.shape[0], (
    f"t has {t_np.shape[0]} samples but y has {y_np.shape[0]}"
)

# Component names follow the comments in the MATLAB model (x(1)..x(4)).
for i, name in enumerate(("Fibroblasts", "Macrophages (Mph)", "CSF", "PDGF")):
    fig, ax = plt.subplots()
    ax.plot(t_np, y_np[..., i])
    ax.set(yscale="log", title=name, xlabel="t (days)", ylabel=name)
    plt.show()
