In [17]:
import torch
import torch.nn as nn
from scipy.integrate import odeint
import numpy as np

In [18]:
# Device used when allocating the mechanistic-model parameter tensors.
DEVICE: torch.device = torch.device('cpu')

In [19]:
def func(params, y0, t_end=140., n_points=40, stress_onset=30.):
    """Solve the mechanistic ODE system to obtain the true mechanistic
    component for graphing.

    The four states are, in order, CRH, ACTH, CORT and GR. The right-hand
    side is evaluated with torch ops and integrated with
    ``scipy.integrate.odeint``, entirely in double precision.

    Args:
        params: dict of scalar tensors keyed by parameter name (e.g. the dict
            returned by ``param_init_tsst``). Their values are copied once as
            detached float64 CPU scalars, so parameters that require
            gradients are accepted (no gradient flows through the solution).
        y0: initial state ``[CRH, ACTH, CORT, GR]`` as a sequence, numpy
            array or torch tensor.
        t_end: last output time; the grid starts at 0. Defaults to 140.
        n_points: number of evenly spaced output times. Defaults to 40.
        stress_onset: time at which the stress input switches on.
            Defaults to 30.

    Returns:
        float64 tensor of shape ``(n_points, 5)``: column 0 is time and
        columns 1-4 are the CRH, ACTH, CORT and GR trajectories.
    """
    # Snapshot the parameters once (loop-invariant for the solver) as detached
    # float64 CPU scalars: odeint integrates in double precision (default rtol
    # ~1.5e-8), and float32 arithmetic in the right-hand side would add ~6e-8
    # relative rounding noise to every derivative evaluation.
    p = {name: torch.as_tensor(value).detach().to('cpu', torch.float64)
         for name, value in params.items()}

    def stress(t):
        # Stress input: off before onset, then decays exponentially at kdStress.
        if t < stress_onset:
            return 0
        return torch.exp(-p['kdStress'] * (t - stress_onset))

    def saturating(max_rate, w):
        # Sigmoidal drive max_rate / (1 + exp(-sigma * w)); torch.sigmoid is the
        # same function without an intermediate exp() that can overflow to inf.
        return max_rate * torch.sigmoid(p['sigma'] * w)

    def ode_rhs(y, t):
        # odeint calls func(y, t) with y as a float64 numpy array.
        crh, acth, cort, gr = torch.as_tensor(y, dtype=torch.float64).reshape(-1)

        w_crh = (p['R0CRH'] + p['RCRH_CRH'] * crh
                 + p['RSS_CRH'] * stress(t) + p['RGR_CRH'] * gr)
        d_crh = saturating(p['MaxCRH'], w_crh) - p['tsCRH'] * crh

        w_acth = p['R0ACTH'] + p['RCRH_ACTH'] * crh + p['RGR_ACTH'] * gr
        d_acth = (saturating(p['MaxACTH'], w_acth) + p['BasalACTH']
                  - p['tsACTH'] * acth)

        w_cort = p['R0CORT'] + p['RACTH_CORT'] * acth
        d_cort = (saturating(p['MaxCORT'], w_cort) + p['BasalCORT']
                  - p['tsCORT'] * cort)

        w_gr = p['R0GR'] + p['RCORT_GR'] * cort + p['RGR_GR'] * gr
        d_gr = saturating(p['ksGR'], w_gr) - p['kdGR'] * gr

        # Hand odeint a plain float64 numpy array rather than a torch tensor.
        return np.array([d_crh.item(), d_acth.item(), d_cort.item(), d_gr.item()])

    if torch.is_tensor(y0):
        y0 = y0.detach().cpu().numpy()
    y0 = np.asarray(y0, dtype=np.float64).reshape(-1)

    # Keep the time grid in numpy float64 so solver times and the returned
    # time column are identical.
    t_eval = np.linspace(0., t_end, n_points)
    trajectory = odeint(ode_rhs, y0, t_eval)
    return torch.from_numpy(np.column_stack((t_eval, trajectory)))



In [20]:
def param_init_tsst(model: nn.Module, requires_grad: bool = False, device=None):
    """Create the parameters of the mechanistic model and register them on
    ``model``.

    Each value becomes a scalar ``torch.nn.Parameter`` registered on ``model``
    under its dictionary key, so it is also reachable as ``model.<name>`` and
    via ``model.parameters()``.

    Args:
        model: module on which the parameters are registered.
        requires_grad: whether the parameters should require gradient.
            Defaults to ``False``, i.e. they stay frozen at the initial values.
        device: device for the parameter tensors; ``None`` uses ``DEVICE``.

    Returns:
        dict mapping parameter name -> the registered ``nn.Parameter``, in
        the insertion order of the table below.
    """
    if device is None:
        device = DEVICE

    # Initial parameter values; this order is the order of the returned dict
    # and of registration on ``model``.
    initial_values = {
        'R0CRH': -0.52239,
        'RCRH_CRH': 0.97555,
        'RGR_CRH': -2.0241,
        'RSS_CRH': 9.8594,
        'sigma': 4.974,
        'tsCRH': 0.10008,
        'R0ACTH': -0.29065,
        'RCRH_ACTH': 6.006,
        'RGR_ACTH': -10.004,
        'tsACTH': 0.046655,
        'R0CORT': -0.95265,
        'RACTH_CORT': 0.022487,
        'tsCORT': 0.048451,
        'R0GR': -0.49428,
        'RCORT_GR': 0.02745,
        'RGR_GR': 0.10572,
        'kdStress': 0.19604,
        'stressStr': 1.,
        'MaxCRH': 1.0011,
        'MaxACTH': 140.2386,
        'MaxCORT': 30.3072,
        'BasalACTH': 0.84733,
        'BasalCORT': 0.29757,
        'ksGR': 0.40732,
        'kdGR': 0.39307,
    }

    params = {
        name: torch.nn.Parameter(torch.tensor(value, device=device),
                                 requires_grad=requires_grad)
        for name, value in initial_values.items()
    }
    for name, param in params.items():
        model.register_parameter(name, param)
    return params

In [22]:
# A bare Module just hosts the registrations; the notebook keeps only the dict.
params = param_init_tsst(model=nn.Module())

In [None]:
for i in range(15):
    