In [1]:
#Set the devices which Jax should use. 
#This must be done before importing jax
import os
# --- Control Flag ---
USE_GPU = False
# --------------------

if USE_GPU:
    # Set JAX to use the GPU. The device number (0) is for the first GPU.
    os.environ['JAX_PLATFORMS'] = 'cuda,cpu'
    # Optional: Pin JAX to a specific GPU
    os.environ['CUDA_VISIBLE_DEVICES'] = '0' 
else:
    os.environ['JAX_PLATFORMS'] = 'cpu'
    import numpyro
    numpyro.set_host_device_count(4)
    

import jax
print(f"JAX is running on: {jax.default_backend()}")
print(f"JAX devices: {jax.devices()}")

import finitediffx as fdx
import pandas as pd
import numpyro
from niceode.pymc_utils import make_pymc_model
import pymc as pm
from niceode.utils import (CompartmentalModel, 
                           ODEInitVals,
                           PopulationCoeffcient,
                           neg2_log_likelihood_loss,
                           ObjectiveFunctionColumn,
                           FOCE_approx_ll_loss,
                           FOCEi_approx_ll_loss,
                           FO_approx_ll_loss
                           )
from niceode.diffeqs import OneCompartmentAbsorption
import numpy as np
import joblib as jb
from niceode.jax_utils import FO_approx_neg2ll_loss_jax


  from .autonotebook import tqdm as notebook_tqdm


JAX is running on: cpu
JAX devices: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3)]


# Data Loading

In [2]:
df = pd.read_csv(r"/workspaces/PK-Analysis/data/theo_nlmixr2.csv", ) 
df.loc[df['AMT'] == 0.0, 'AMT'] = pd.NA
df['AMT'] = df['AMT'].ffill()
df = df.loc[df['EVID'] == 0, :].copy()

# NCA Parameter Estimation

Perform non-compartmental analysis of concentration:time profiles to inform priors/intial estimates of PK parameters. 
NCA was not the focus of this project. The NCA estimates generated below are resonable and useful for setting inital estimates, but do not exactly match the results generated by mature NCA packages such as PKNCA (https://github.com/humanpred/pknca). 

In [3]:
from niceode.nca import NCA

nca_obj = NCA(
    subject_id_col='ID', 
    conc_col='DV',
    time_col='TIME', 
    dose_col='AMT',
    data = df
)

nca_result_df = nca_obj.estimate_all_nca_params(terminal_phase_adj_r2_thresh=0.85)
nca_result_df.describe()

  0%|          | 0/1 [00:00<?, ?it/s]

100%|██████████| 1/1 [00:01<00:00,  1.07s/it]


Unnamed: 0,ID,window_halflife_est,window_k_est,linup_logdown_auc,linup_logdown_aumc,cl/F,mrt,vss,boot_n
count,12.0,12.0,12.0,12.0,12.0,12.0,12.0,12.0,12.0
mean,6.5,11.387274,0.062483,148.710806,1940.940871,2.230496,12.315183,26.358032,0.0
std,3.605551,2.165094,0.009598,40.571084,1226.322035,0.46597,3.340649,3.520321,0.0
min,1.0,8.847931,0.039791,111.134079,1150.646323,1.263686,9.732603,21.146271,0.0
25%,3.75,10.007099,0.059189,126.788884,1358.684838,2.057916,10.19994,24.109577,0.0
50%,6.5,11.055922,0.062753,132.278256,1555.587703,2.290714,11.691217,25.438864,0.0
75%,9.25,11.711029,0.069251,155.734028,1691.303899,2.502026,12.345776,28.38802,0.0
max,12.0,17.41594,0.078323,253.221042,5539.478295,2.879405,21.876058,33.010618,0.0


# Model Fitting With NiceODE

In [4]:
me_mod_fo =  CompartmentalModel(
        model_name = "debug_theoph_abs_ka-clME-vd_JAXFOCE_jaxoptspwrapLbfgsb_fdxouteriftinner_nodep_omegadiag_dermal",
            ode_t0_cols=[ ODEInitVals('DV'), ODEInitVals('AMT'),],
            conc_at_time_col = 'DV',
            subject_id_col = 'ID', 
            time_col = 'TIME',
            population_coeff=[
                                PopulationCoeffcient('ka', 
                                                    optimization_init_val=1.6, 
                                                    subject_level_intercept=True,
                                                    #optimization_lower_bound = np.log(1e-6),
                                                    #optimization_upper_bound = np.log(15),
                                                    subject_level_intercept_sd_init_val = 0.6, 
                                                    #subject_level_intercept_sd_upper_bound = 20,
                                                    #subject_level_intercept_sd_lower_bound=1e-6
                                                    ),
                                PopulationCoeffcient('cl',
                                                    optimization_init_val = 3,
                                                    #optimization_lower_bound = np.log(1e-4),
                                                    #optimization_upper_bound=np.log(25),
                                                    subject_level_intercept=True, 
                                                    subject_level_intercept_sd_init_val = 0.3, 
                                                   # subject_level_intercept_sd_upper_bound = 5,
                                                   # subject_level_intercept_sd_lower_bound=1e-6
                                                    ),
                                PopulationCoeffcient('vd', optimization_init_val = 35,
                                                    #, optimization_lower_bound = np.log(.1)
                                                    #,optimization_upper_bound=np.log(80), 
                                                    subject_level_intercept=True, 
                                                    subject_level_intercept_sd_init_val = 0.1, 
                                                    #subject_level_intercept_sd_upper_bound = 5,
                                                    #subject_level_intercept_sd_lower_bound=1e-6
                                                    
                                                    #, optimization_upper_bound = np.log(.05)
                                                    ),
                            ],
            dep_vars= None, 

                                    pk_model_class=OneCompartmentAbsorption, 
                                    model_error_sigma=PopulationCoeffcient('sigma'
                                                                            ,log_transform_init_val=False
                                                                            , optimization_init_val=.5
                                                                            ,optimization_lower_bound=0.00001
                                                                            ,optimization_upper_bound=3
                                                                            ),
  
                                    batch_id='theoph_test1',

                                    significant_digits=3,
                                    #me_loss_function=FO_approx_ll_loss,
                                    jax_loss=FO_approx_neg2ll_loss_jax,
                                    use_full_omega=True, 
                                    use_surrogate_neg2ll=True, 
                                    fit_jax_objective=True,
                                    )

In [5]:
me_mod_fo = me_mod_fo.fit2(df, ci_level = None, debug_fit=False, )

Successfully compiled closed stiff ODE solver
Sucessfully complied non-stiff ODE solver
Sucessfully complied augmented dynamics stiff ODE solver




Sucessfully complied augmented dynamics stiff ODE solver
Sucessfully complied stiff ODE solver
Sucessfully complied KEYS stiff ODE solver
Sucessfully complied non-stiff PyMC ODE solver
Sucessfully complied hybrid-dimensional non-stiff PyMC ODE solver
Sucessfully complied non-dimensional non-stiff PyMC ODE solver
Sucessfully complied non-dimensional stiff PyMC ODE solver
Sucessfully complied stiff PyMC ODE solver
Compiling `FO_approx_neg2ll_loss_jax`
Compiling `approx_neg2ll_loss_jax`
Compiling `_solve_augdyn_ivp_jax_worker`
Compiling `FO_approx_neg2ll_loss_jax`
Compiling `approx_neg2ll_loss_jax`
Compiling `_solve_augdyn_ivp_jax_worker`
Compiling `_solve_ivp_jax_worker`
Compiling `_solve_ivp_jax_worker`
🏃 View run b-theoph_test1_m-debug_theoph_abs_ka-clME-vd_JAXFOCE_jaxoptspwrapLbfgsb_fdxouteriftinner_nodep_omegadiag_dermal_f-df3d73a8-7c30-485a-a44e-ddad84d18818 at: http://mlflow-server:5000/#/experiments/377/runs/bcff621b906d40aaa16be42db7161588
🧪 View experiment at: http://mlflow-serv

In [6]:
me_mod_fo.fit_result_summary_

Unnamed: 0,model_coeff,log_name,population_coeff,model_error,subject_level_intercept,coeff_dep_var,model_coeff_dep_var,subject_level_intercept_name,fit_result_summary_name,init_val,lower_bound,upper_bound,fitted_param_val,fitted_param_sd,fitted_param_rse,fitted_params_lower_ci95,fitted_params_upper_ci95,back_transformed_param_val,back_transformed_lower_ci95,back_transformed_upper_ci95
0,ka,ka_pop,True,False,False,False,,,ka_pop,0.470004,,,0.979452,,,,,2.662995,,
1,cl,cl_pop,True,False,False,False,,,cl_pop,1.098612,,,1.336739,,,,,3.806609,,
2,vd,vd_pop,True,False,False,False,,,vd_pop,3.555348,,,3.532256,,,,,34.201024,,
3,sigma,sigma_const,False,True,False,False,,,sigma_const,-0.693147,2e-09,22000.0,0.740784,,,,,,,
4,ka,lchol_omega_omega2_ka_omega2_ka,False,False,True,False,,lchol_omega_omega2_ka_omega2_ka,omega2_ka,-0.510826,2e-09,22000.0,0.023184,,,,,,,
5,cl_ka,lchol_omega_omega2_cl_omega2_ka,False,False,True,False,,lchol_omega_omega2_cl_omega2_ka,corr_b_i_cl_b_i_ka,0.0,2e-09,22000.0,0.024155,,,,,,,
6,cl,lchol_omega_omega2_cl_omega2_cl,False,False,True,False,,lchol_omega_omega2_cl_omega2_cl,omega2_cl,-1.203973,2e-09,22000.0,0.152193,,,,,,,
7,vd_ka,lchol_omega_omega2_vd_omega2_ka,False,False,True,False,,lchol_omega_omega2_vd_omega2_ka,corr_b_i_vd_b_i_ka,0.0,2e-09,22000.0,-0.025256,,,,,,,
8,vd_cl,lchol_omega_omega2_vd_omega2_cl,False,False,True,False,,lchol_omega_omega2_vd_omega2_cl,corr_b_i_vd_b_i_cl,0.0,2e-09,22000.0,0.995485,,,,,,,
9,vd,lchol_omega_omega2_vd_omega2_vd,False,False,True,False,,lchol_omega_omega2_vd_omega2_vd,omega2_vd,-2.302585,2e-09,22000.0,0.145562,,,,,,,


In [7]:
me_mod_fo.fit_result_summary_

Unnamed: 0,model_coeff,log_name,population_coeff,model_error,subject_level_intercept,coeff_dep_var,model_coeff_dep_var,subject_level_intercept_name,fit_result_summary_name,init_val,lower_bound,upper_bound,fitted_param_val,fitted_param_sd,fitted_param_rse,fitted_params_lower_ci95,fitted_params_upper_ci95,back_transformed_param_val,back_transformed_lower_ci95,back_transformed_upper_ci95
0,ka,ka_pop,True,False,False,False,,,ka_pop,0.470004,,,0.979452,,,,,2.662995,,
1,cl,cl_pop,True,False,False,False,,,cl_pop,1.098612,,,1.336739,,,,,3.806609,,
2,vd,vd_pop,True,False,False,False,,,vd_pop,3.555348,,,3.532256,,,,,34.201024,,
3,sigma,sigma_const,False,True,False,False,,,sigma_const,-0.693147,2e-09,22000.0,0.740784,,,,,,,
4,ka,lchol_omega_omega2_ka_omega2_ka,False,False,True,False,,lchol_omega_omega2_ka_omega2_ka,omega2_ka,-0.510826,2e-09,22000.0,0.023184,,,,,,,
5,cl_ka,lchol_omega_omega2_cl_omega2_ka,False,False,True,False,,lchol_omega_omega2_cl_omega2_ka,corr_b_i_cl_b_i_ka,0.0,2e-09,22000.0,0.024155,,,,,,,
6,cl,lchol_omega_omega2_cl_omega2_cl,False,False,True,False,,lchol_omega_omega2_cl_omega2_cl,omega2_cl,-1.203973,2e-09,22000.0,0.152193,,,,,,,
7,vd_ka,lchol_omega_omega2_vd_omega2_ka,False,False,True,False,,lchol_omega_omega2_vd_omega2_ka,corr_b_i_vd_b_i_ka,0.0,2e-09,22000.0,-0.025256,,,,,,,
8,vd_cl,lchol_omega_omega2_vd_omega2_cl,False,False,True,False,,lchol_omega_omega2_vd_omega2_cl,corr_b_i_vd_b_i_cl,0.0,2e-09,22000.0,0.995485,,,,,,,
9,vd,lchol_omega_omega2_vd_omega2_vd,False,False,True,False,,lchol_omega_omega2_vd_omega2_vd,omega2_vd,-2.302585,2e-09,22000.0,0.145562,,,,,,,


In [8]:
me_mod_fo.fit_result_summary_pk_stats_['no_me']

Unnamed: 0,ka,ke,t_half_elim,t_half_abs,cl,vd
0,2.662995,0.1113010346251162,6.227679579930244,0.2602885495463712,3.806609,34.201024
1,2.662995,0.1113010346251162,6.227679579930244,0.2602885495463712,3.806609,34.201024
2,2.662995,0.1113010346251162,6.227679579930244,0.2602885495463712,3.806609,34.201024
3,2.662995,0.1113010346251162,6.227679579930244,0.2602885495463712,3.806609,34.201024
4,2.662995,0.1113010346251162,6.227679579930244,0.2602885495463712,3.806609,34.201024
5,2.662995,0.1113010346251162,6.227679579930244,0.2602885495463712,3.806609,34.201024
6,2.662995,0.1113010346251162,6.227679579930244,0.2602885495463712,3.806609,34.201024
7,2.662995,0.1113010346251162,6.227679579930244,0.2602885495463712,3.806609,34.201024
8,2.662995,0.1113010346251162,6.227679579930244,0.2602885495463712,3.806609,34.201024
9,2.662995,0.1113010346251162,6.227679579930244,0.2602885495463712,3.806609,34.201024


In [9]:
me_mod_fo.fit_result_.pred_df

Unnamed: 0,ID,TIME,DV,pop_pred_DV__PRED,pop_pred_resid__RES,weighted_pop_pred_resid__WRES,indiv_pred_DV__IPRED,indiv_pred_resid__IRES,weighted_indiv_pred_resid__IWRES
1,1,0.00,0.74,0.021637,0.718363,1.309065,0.024358,0.715642,1.304105
2,1,0.25,2.84,4.499723,-1.659723,-3.024494,5.036034,-2.196034,-4.001806
3,1,0.57,6.57,7.044429,-0.474429,-0.864546,7.909258,-1.339258,-2.440513
4,1,1.12,10.50,8.144386,2.355614,4.292608,9.180074,1.319926,2.405285
5,1,2.02,9.66,7.770563,1.889437,3.443098,8.793707,0.866293,1.578636
...,...,...,...,...,...,...,...,...,...
139,12,5.07,8.57,5.564961,3.005039,5.476047,5.956351,2.613649,4.762822
140,12,7.07,6.59,4.454390,2.135610,3.891696,4.790172,1.799828,3.279805
141,12,9.03,6.11,3.581352,2.528648,4.607926,3.869127,2.240873,4.083517
142,12,12.05,4.57,2.558867,2.011133,3.664864,2.784164,1.785836,3.254308


In [10]:
all_pred_df = []

In [11]:
pred_df = me_mod_fo.fit_result_.pred_df.copy()
pred_cols = {'pop_pred_DV__PRED':'PRED_full_omega', 'indiv_pred_DV__IPRED':'IPRED_full_omega'}
pred_df = pred_df.rename(columns=pred_cols)
value_vars = ['DV'] + list(pred_cols.values())
pred_df = pred_df.melt(id_vars = ['ID', 'TIME'], value_vars = value_vars, var_name = 'DV_and_PREDs', value_name = 'Conc')
all_pred_df.append(pred_df.copy())

In [12]:
pred_df

Unnamed: 0,ID,TIME,DV_and_PREDs,Conc
0,1,0.00,DV,0.740000
1,1,0.25,DV,2.840000
2,1,0.57,DV,6.570000
3,1,1.12,DV,10.500000
4,1,2.02,DV,9.660000
...,...,...,...,...
391,12,5.07,IPRED_full_omega,5.956351
392,12,7.07,IPRED_full_omega,4.790172
393,12,9.03,IPRED_full_omega,3.869127
394,12,12.05,IPRED_full_omega,2.784164


In [13]:
import plotly.express as px


px.line(data_frame=pred_df, x = 'TIME', y = 'Conc', color = 'DV_and_PREDs', animation_frame='ID')

# Use neg2ll as objective

In [15]:
me_mod_fo_nosurrogate =  CompartmentalModel(
        model_name = "debug_theoph_abs_ka-clME-vd_JAXFOCE_jaxoptspwrapLbfgsb_fdxouteriftinner_nodep_omegadiag_dermal",
            ode_t0_cols=[ ODEInitVals('DV'), ODEInitVals('AMT'),],
            conc_at_time_col = 'DV',
            subject_id_col = 'ID', 
            time_col = 'TIME',
            population_coeff=[
                                PopulationCoeffcient('ka', 
                                                    optimization_init_val=1.6, 
                                                    subject_level_intercept=True,
                                                    #optimization_lower_bound = np.log(1e-6),
                                                    #optimization_upper_bound = np.log(15),
                                                    subject_level_intercept_sd_init_val = 0.6, 
                                                    #subject_level_intercept_sd_upper_bound = 20,
                                                    #subject_level_intercept_sd_lower_bound=1e-6
                                                    ),
                                PopulationCoeffcient('cl',
                                                    optimization_init_val = 3,
                                                    #optimization_lower_bound = np.log(1e-4),
                                                    #optimization_upper_bound=np.log(25),
                                                    subject_level_intercept=True, 
                                                    subject_level_intercept_sd_init_val = 0.3, 
                                                   # subject_level_intercept_sd_upper_bound = 5,
                                                   # subject_level_intercept_sd_lower_bound=1e-6
                                                    ),
                                PopulationCoeffcient('vd', optimization_init_val = 35,
                                                    #, optimization_lower_bound = np.log(.1)
                                                    #,optimization_upper_bound=np.log(80), 
                                                    subject_level_intercept=True, 
                                                    subject_level_intercept_sd_init_val = 0.1, 
                                                    #subject_level_intercept_sd_upper_bound = 5,
                                                    #subject_level_intercept_sd_lower_bound=1e-6
                                                    
                                                    #, optimization_upper_bound = np.log(.05)
                                                    ),
                            ],
            dep_vars= None, 

                                    pk_model_class=OneCompartmentAbsorption, 
                                    model_error_sigma=PopulationCoeffcient('sigma'
                                                                            ,log_transform_init_val=False
                                                                            , optimization_init_val=.5
                                                                            ,optimization_lower_bound=0.00001
                                                                            ,optimization_upper_bound=3
                                                                            ),
  
                                    batch_id='theoph_test1',

                                    significant_digits=3,
                                    #me_loss_function=FO_approx_ll_loss,
                                    jax_loss=FO_approx_neg2ll_loss_jax,
                                    use_full_omega=True, 
                                    use_surrogate_neg2ll=False, 
                                    fit_jax_objective=True,
                                    )

In [16]:
me_mod_fo_nosurrogate = me_mod_fo_nosurrogate.fit2(df, ci_level = None, debug_fit=False, )

Successfully compiled closed stiff ODE solver
Sucessfully complied non-stiff ODE solver
Sucessfully complied augmented dynamics stiff ODE solver




Sucessfully complied augmented dynamics stiff ODE solver
Sucessfully complied stiff ODE solver
Sucessfully complied KEYS stiff ODE solver
Sucessfully complied non-stiff PyMC ODE solver
Sucessfully complied hybrid-dimensional non-stiff PyMC ODE solver
Sucessfully complied non-dimensional non-stiff PyMC ODE solver
Sucessfully complied non-dimensional stiff PyMC ODE solver
Sucessfully complied stiff PyMC ODE solver
Compiling `FO_approx_neg2ll_loss_jax`
Compiling `approx_neg2ll_loss_jax`
Compiling `_solve_augdyn_ivp_jax_worker`
Compiling `FO_approx_neg2ll_loss_jax`
Compiling `approx_neg2ll_loss_jax`
Compiling `_solve_augdyn_ivp_jax_worker`
Compiling `_solve_ivp_jax_worker`
Compiling `_solve_ivp_jax_worker`
🏃 View run b-theoph_test1_m-debug_theoph_abs_ka-clME-vd_JAXFOCE_jaxoptspwrapLbfgsb_fdxouteriftinner_nodep_omegadiag_dermal_f-36882e23-5e43-4be3-a603-9d3af9ac3114 at: http://mlflow-server:5000/#/experiments/377/runs/5fec3cd30f034eb3adaca147aec76fa3
🧪 View experiment at: http://mlflow-serv

In [17]:
me_mod_fo_omegadiag =  CompartmentalModel(
        model_name = "debug_theoph_abs_ka-clME-vd_JAXFO_jaxoptspwrapLbfgsb_AD_nodep_omegadiag_dermal",
            ode_t0_cols=[ ODEInitVals('DV'), ODEInitVals('AMT'),],
            conc_at_time_col = 'DV',
            subject_id_col = 'ID', 
            time_col = 'TIME',
            population_coeff=[
                                PopulationCoeffcient('ka', 
                                                    optimization_init_val=1.6, 
                                                    subject_level_intercept=True,
                                                    #optimization_lower_bound = np.log(1e-6),
                                                    #optimization_upper_bound = np.log(15),
                                                    subject_level_intercept_sd_init_val = 0.6, 
                                                    #subject_level_intercept_sd_upper_bound = 20,
                                                    #subject_level_intercept_sd_lower_bound=1e-6
                                                    ),
                                PopulationCoeffcient('cl',
                                                    optimization_init_val = 3,
                                                    #optimization_lower_bound = np.log(1e-4),
                                                    #optimization_upper_bound=np.log(25),
                                                    subject_level_intercept=True, 
                                                    subject_level_intercept_sd_init_val = 0.3, 
                                                   # subject_level_intercept_sd_upper_bound = 5,
                                                   # subject_level_intercept_sd_lower_bound=1e-6
                                                    ),
                                PopulationCoeffcient('vd', optimization_init_val = 35,
                                                    #, optimization_lower_bound = np.log(.1)
                                                    #,optimization_upper_bound=np.log(80), 
                                                    subject_level_intercept=True, 
                                                    subject_level_intercept_sd_init_val = 0.1, 
                                                    #subject_level_intercept_sd_upper_bound = 5,
                                                    #subject_level_intercept_sd_lower_bound=1e-6
                                                    
                                                    #, optimization_upper_bound = np.log(.05)
                                                    ),
                            ],
            dep_vars= None, 

                                    pk_model_class=OneCompartmentAbsorption, 
                                    model_error_sigma=PopulationCoeffcient('sigma'
                                                                            ,log_transform_init_val=False
                                                                            , optimization_init_val=.5
                                                                            ,optimization_lower_bound=0.00001
                                                                            ,optimization_upper_bound=3
                                                                            ),
  
                                    batch_id='theoph_test1',

                                    significant_digits=3,
                                    #me_loss_function=FO_approx_ll_loss,
                                    jax_loss=FO_approx_neg2ll_loss_jax,
                                    use_full_omega=False, 
                                    use_surrogate_neg2ll=True, 
                                    fit_jax_objective=True,
                                    )

In [18]:
me_mod_fo_omegadiag = me_mod_fo_omegadiag.fit2(df, ci_level = None, debug_fit=False, )

Successfully compiled closed stiff ODE solver
Sucessfully complied non-stiff ODE solver
Sucessfully complied augmented dynamics stiff ODE solver




Sucessfully complied augmented dynamics stiff ODE solver
Sucessfully complied stiff ODE solver
Sucessfully complied KEYS stiff ODE solver
Sucessfully complied non-stiff PyMC ODE solver
Sucessfully complied hybrid-dimensional non-stiff PyMC ODE solver
Sucessfully complied non-dimensional non-stiff PyMC ODE solver
Sucessfully complied non-dimensional stiff PyMC ODE solver
Sucessfully complied stiff PyMC ODE solver
Compiling `FO_approx_neg2ll_loss_jax`
Compiling `approx_neg2ll_loss_jax`
Compiling `_solve_augdyn_ivp_jax_worker`
Compiling `FO_approx_neg2ll_loss_jax`
Compiling `approx_neg2ll_loss_jax`
Compiling `_solve_augdyn_ivp_jax_worker`
Compiling `_solve_ivp_jax_worker`
Compiling `_solve_ivp_jax_worker`
🏃 View run b-theoph_test1_m-debug_theoph_abs_ka-clME-vd_JAXFO_jaxoptspwrapLbfgsb_AD_nodep_omegadiag_dermal_f-12f7fc76-74a7-48ce-a106-7dea1de6cdae at: http://mlflow-server:5000/#/experiments/377/runs/0f64cb964b9a4b2380b6b3bf0b2a30b5
🧪 View experiment at: http://mlflow-server:5000/#/experi

In [49]:
tmp = me_mod_fo_omegadiag.fit_result_.pred_df
tmp['inner_pred_y'] = me_mod_fo_omegadiag.fit_result_.pred_y

In [51]:
tmp

Unnamed: 0,ID,TIME,DV,pop_pred_DV__PRED,pop_pred_resid__RES,weighted_pop_pred_resid__WRES,indiv_pred_DV__IPRED,indiv_pred_resid__IRES,weighted_indiv_pred_resid__IWRES,inner_pred_y
1,1,0.00,0.74,0.022107,0.717893,1.258000,0.021793,0.718207,1.258550,0.022107
2,1,0.25,2.84,4.266604,-1.426604,-2.499910,4.192989,-1.352989,-2.370911,4.266666
3,1,0.57,6.57,6.889475,-0.319475,-0.559832,6.782762,-0.212762,-0.372834,6.889540
4,1,1.12,10.50,8.239820,2.260180,3.960627,8.132088,2.367912,4.149412,8.239796
5,1,2.02,9.66,8.058147,1.601853,2.807007,7.976241,1.683759,2.950536,8.058145
...,...,...,...,...,...,...,...,...,...,...
139,12,5.07,8.57,5.951950,2.618050,4.587741,5.948123,2.621877,4.594446,5.951950
140,12,7.07,6.59,4.848884,1.741116,3.051045,4.863072,1.726928,3.026183,4.848886
141,12,9.03,6.11,3.966427,2.143573,3.756290,3.991942,2.118058,3.711579,3.966438
142,12,12.05,4.57,2.910349,1.659651,2.908290,2.944873,1.625127,2.847792,2.910511


In [16]:
me_mod_fo_omegadiag.fit_result_.b_i

Array([[-4.77074881e-03, -1.09204238e-02,  1.42905333e-02],
       [-4.66954756e-03, -1.15181610e-03,  1.50727359e-03],
       [-6.44966189e-04, -2.28021040e-03,  2.98389727e-03],
       [-2.29015512e-02, -3.56920133e-03,  4.67067868e-03],
       [-8.66547024e-03, -7.27991617e-03,  9.52654280e-03],
       [-3.10496807e-02,  3.75736811e-03, -4.91691488e-03],
       [-3.59968403e-02, -2.82319529e-04,  3.69445062e-04],
       [-1.98300627e-02,  8.80091326e-05, -1.15169289e-04],
       [ 2.98649445e-02, -2.54131167e-03,  3.32557599e-03],
       [-2.77756814e-02, -1.08148855e-02,  1.41524253e-02],
       [ 5.56954352e-03,  1.99880714e-03, -2.61565125e-03],
       [-2.34760368e-02, -7.59338487e-03,  9.93674987e-03]],      dtype=float64)

In [17]:
pred_df = me_mod_fo_omegadiag.fit_result_.pred_df.copy()
pred_cols = {'pop_pred_DV__PRED':'PRED_diag_omega', 'indiv_pred_DV__IPRED':'IPRED_diag_omega'}
pred_df = pred_df.rename(columns=pred_cols)
value_vars = list(pred_cols.values())
pred_df = pred_df.melt(id_vars = ['ID', 'TIME'], value_vars = value_vars, var_name = 'DV_and_PREDs', value_name = 'Conc')
all_pred_df.append(pred_df.copy())

In [18]:
px.line(data_frame=pred_df, x = 'TIME', y = 'Conc', color = 'DV_and_PREDs', animation_frame='ID')

In [19]:
import rpy2.robjects as ro


Environment variable "LD_LIBRARY_PATH" redefined by R and overriding existing variable. Current: "/usr/local/nvidia/lib:/usr/local/nvidia/lib64", R: "/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/default-java/lib/server:/usr/local/nvidia/lib:/usr/local/nvidia/lib64"


Environment variable "PWD" redefined by R and overriding existing variable. Current: "/vscode/vscode-server/bin/linux-x64/488a1f239235055e34e673291fb8d8c810886f81", R: "/workspaces/PK-Analysis/examples"

Sourcing user .Rprofile to configure library paths...
-> User library set to: /home/vscode/R/library

Environment variable "LD_LIBRARY_PATH" redefined by R and overriding existing variable. Current: "/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/default-java/lib/server:/usr/local/nvidia/lib:/usr/local/nvidia/lib64", R: "/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/default-java/lib/server:/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/default-java/lib/server:/usr/local/nvidia/lib:/usr/local

In [20]:
r_str = """
library(nlmixr2)

one.cmt <- function() {
  ini({
    ## You may label each parameter with a comment
    tka <- 0.45 # Log Ka
    tcl <- log(c(0, 2.7, 100)) # Log Cl
    ## This works with interactive models
    ## You may also label the preceding line with label("label text")
    tv <- 3.45; label("log V")
    ## the label("Label name") works with all models
    eta.ka ~ 0.6
    eta.cl ~ 0.3
    eta.v ~ 0.1
    add.sd <- 0.7
  })
  model({
    ka <- exp(tka + eta.ka)
    cl <- exp(tcl + eta.cl)
    v <- exp(tv + eta.v)
    linCmt() ~ add(add.sd)
  })
}

f <- nlmixr(one.cmt)

fit <- nlmixr(one.cmt, theo_sd, est="fo",
                   control=foceiControl(interaction=FALSE, print=0))
"""

ro.r(r_str)

R callback write-console: ── Attaching packages ───────────────────────────────────────── nlmixr2 4.0.0 ──
  
R callback write-console: ✔ lotri        1.0.1     ✔ nlmixr2extra 3.0.2
✔ nlmixr2data  2.0.9     ✔ nlmixr2plot  3.0.2
✔ nlmixr2est   4.0.2     ✔ rxode2       4.0.3
  
R callback write-console: ── Optional Packages Loaded/Ignored ─────────────────────────── nlmixr2 4.0.0 ──
  
R callback write-console: ✖ babelmixr2     ✖ nlmixr2rpt
✖ ggPMX     ✖ nonmem2rx
✖ monolix2rx     ✖ shinyMixR
✖ nlmixr2lib     ✖ xpose.nlmixr2
  


R callback write-console: ── Conflicts ───────────────────────────────────────────── nlmixr2conflicts() ──
✖ rxode2::boxCox()     masks nlmixr2est::boxCox()
✖ rxode2::yeoJohnson() masks nlmixr2est::yeoJohnson()
  


ℹ parameter labels from comments are typically ignored in non-interactive mode
ℹ Need to run with the source intact to parse comments
ℹ parameter labels from comments are typically ignored in non-interactive mode
ℹ Need to run with the source intact to parse comments
→ loading into symengine environment...
→ pruning branches (`if`/`else`) of full model...
✔ done
→ calculate jacobian
→ calculate ∂(f)/∂(η)                                                            
→ finding duplicate expressions in inner model...                                
→ optimizing duplicate expressions in inner model...                             
→ finding duplicate expressions in EBE model...                                  [====|====|====|====|====|====|====|====|====|====] 100%; 0:00:00 [====|====|====|====|====|====|====|====|====|====] 100%; 0:00:00 
→ optimizing duplicate expressions in EBE model...                               
→ compiling inner model...                                              

R callback write-console: using C compiler: ‘gcc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0’

  


✔ done
→ finding duplicate expressions in FD model...
→ compiling EBE model...                                                         


R callback write-console: using C compiler: ‘gcc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0’

  


✔ done
→ compiling events FD model...


R callback write-console: using C compiler: ‘gcc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0’

  


✔ done
done
→ compress origData in nlmixr2 object, save 5952
→ compress parHistData in nlmixr2 object, save 4928
calculating covariance matrix
[>------------------------------------------------]  03%; 0:00:00 ====|====|====|====|====|====|====|====|====|====] 100%; 0:00:00 → Calculating residuals/tables
✔ done
→ compress origData in nlmixr2 object, save 5952


In [21]:
ro.r("fit$omega")

0,1,2,3,4,5,6
0.84091,0.0,0.0,...,0.0,0.0,0.019252


In [22]:

from rpy2.robjects import pandas2ri
from rpy2.robjects.conversion import localconverter

with localconverter(ro.default_converter + pandas2ri.converter):
    r_pred_df = ro.conversion.rpy2py(ro.r['fit'])


In [23]:
with localconverter(ro.default_converter + pandas2ri.converter):
    r_omega = ro.r("fit$omega")

In [24]:
r_omega

array([[0.84091031, 0.        , 0.        ],
       [0.        , 0.08978456, 0.        ],
       [0.        , 0.        , 0.01925209]])

In [22]:
r_pred_df

Unnamed: 0,ID,TIME,DV,PRED,RES,WRES,IPRED,IRES,IWRES,CPRED,...,central,rx__sens_central_BY_p1,rx__sens_central_BY_v1,rx__sens_central_BY_ka,rx__sens_depot_BY_ka,ka,cl,v,tad,dosenum
1,1,0.00,0.74,0.000000,0.740000,1.203485,0.000000,0.740000,1.203485,0.000000,...,0.000000,0.0,0.0,0.0,0.0,1.785768,1.651905,29.397752,0.00,1.0
2,1,0.25,2.84,4.901736,-2.061736,-2.248670,3.890218,-1.050218,-1.708001,3.408287,...,114.363654,0.0,0.0,0.0,0.0,1.785768,1.651905,29.397752,0.25,1.0
3,1,0.57,6.57,7.546761,-0.976761,-0.804374,6.823158,-0.253158,-0.411719,5.977887,...,200.585517,0.0,0.0,0.0,0.0,1.785768,1.651905,29.397752,0.57,1.0
4,1,1.12,10.50,8.636333,1.863667,1.383717,9.032173,1.467827,2.387173,7.913243,...,265.525575,0.0,0.0,0.0,0.0,1.785768,1.651905,29.397752,1.12,1.0
5,1,2.02,9.66,8.343234,1.316766,1.004549,9.727770,-0.067770,-0.110217,8.522668,...,285.974586,0.0,0.0,0.0,0.0,1.785768,1.651905,29.397752,2.02,1.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
128,12,5.07,8.57,6.512145,2.057855,1.882858,8.430759,0.139241,0.226452,6.334619,...,218.778039,0.0,0.0,0.0,0.0,0.941572,2.420869,25.949982,5.07,1.0
129,12,7.07,6.59,5.515837,1.074163,1.094142,7.074306,-0.484306,-0.787641,5.315421,...,183.578104,0.0,0.0,0.0,0.0,0.941572,2.420869,25.949982,7.07,1.0
130,12,9.03,6.11,4.687494,1.422506,1.589321,5.904048,0.205952,0.334946,4.436125,...,153.209943,0.0,0.0,0.0,0.0,0.941572,2.420869,25.949982,9.03,1.0
131,12,12.05,4.57,3.647961,0.922039,1.157736,4.456386,0.113614,0.184774,3.348394,...,115.643128,0.0,0.0,0.0,0.0,0.941572,2.420869,25.949982,12.05,1.0


In [23]:
pred_df = r_pred_df.copy()
pred_cols = {'PRED':'PRED_R_diag_omega', 'IPRED':'IPRED_R_diag_omega'}
pred_df = pred_df.rename(columns=pred_cols)
value_vars = list(pred_cols.values())
pred_df = pred_df.melt(id_vars = ['ID', 'TIME'], value_vars = value_vars, var_name = 'DV_and_PREDs', value_name = 'Conc')
all_pred_df.append(pred_df.copy())

In [24]:
px.line(data_frame=pred_df, x = 'TIME', y = 'Conc', color = 'DV_and_PREDs', animation_frame='ID')

In [25]:
all_pred_df = pd.concat(all_pred_df) if isinstance(all_pred_df, list) else all_pred_df
all_pred_df['ID'] = all_pred_df['ID'].astype(int)
px.line(data_frame=all_pred_df, x = 'TIME', y = 'Conc', color = 'DV_and_PREDs', animation_frame='ID')

In [26]:
me_mod_fo_nojax =  CompartmentalModel(
        model_name = "debug_theoph_abs_ka-clME-vd_JAXFOCE_jaxoptspwrapLbfgsb_fdxouteriftinner_nodep_omegadiag_dermal",
            ode_t0_cols=[ ODEInitVals('DV'), ODEInitVals('AMT'),],
            conc_at_time_col = 'DV',
            subject_id_col = 'ID', 
            time_col = 'TIME',
            population_coeff=[
                                PopulationCoeffcient('ka', 
                                                    optimization_init_val=1.6, 
                                                    subject_level_intercept=True,
                                                    #optimization_lower_bound = np.log(1e-6),
                                                    #optimization_upper_bound = np.log(15),
                                                    subject_level_intercept_sd_init_val = 0.6, 
                                                    #subject_level_intercept_sd_upper_bound = 20,
                                                    #subject_level_intercept_sd_lower_bound=1e-6
                                                    ),
                                PopulationCoeffcient('cl',
                                                    optimization_init_val = 3,
                                                    #optimization_lower_bound = np.log(1e-4),
                                                    #optimization_upper_bound=np.log(25),
                                                    subject_level_intercept=True, 
                                                    subject_level_intercept_sd_init_val = 0.3, 
                                                   # subject_level_intercept_sd_upper_bound = 5,
                                                   # subject_level_intercept_sd_lower_bound=1e-6
                                                    ),
                                PopulationCoeffcient('vd', optimization_init_val = 35,
                                                    #, optimization_lower_bound = np.log(.1)
                                                    #,optimization_upper_bound=np.log(80), 
                                                    subject_level_intercept=True, 
                                                    subject_level_intercept_sd_init_val = 0.1, 
                                                    #subject_level_intercept_sd_upper_bound = 5,
                                                    #subject_level_intercept_sd_lower_bound=1e-6
                                                    
                                                    #, optimization_upper_bound = np.log(.05)
                                                    ),
                            ],
            dep_vars= None, 

                                    pk_model_class=OneCompartmentAbsorption, 
                                    model_error_sigma=PopulationCoeffcient('sigma'
                                                                            ,log_transform_init_val=False
                                                                            , optimization_init_val=.5
                                                                            ,optimization_lower_bound=0.00001
                                                                            ,optimization_upper_bound=3
                                                                            ),
  
                                    batch_id='theoph_test1',

                                    significant_digits=3,
                                    me_loss_function=FO_approx_ll_loss,
                                    jax_loss=FO_approx_neg2ll_loss_jax,
                                    use_full_omega=False, 
                                    use_surrogate_neg2ll=True, 
                                    fit_jax_objective=False,
                                    )

In [27]:
me_mod_fo_nojax = me_mod_fo_nojax.fit2(df, ci_level = None, debug_fit=False, )

Successfully compiled closed stiff ODE solver
Sucessfully complied non-stiff ODE solver
Sucessfully complied augmented dynamics stiff ODE solver




Sucessfully complied augmented dynamics stiff ODE solver
Sucessfully complied stiff ODE solver
Sucessfully complied KEYS stiff ODE solver
Sucessfully complied non-stiff PyMC ODE solver
Sucessfully complied hybrid-dimensional non-stiff PyMC ODE solver
Sucessfully complied non-dimensional non-stiff PyMC ODE solver
Sucessfully complied non-dimensional stiff PyMC ODE solver
Sucessfully complied stiff PyMC ODE solver
Sucessfully complied augmented dynamics stiff ODE solver
Compiling `_solve_ivp_jax_worker`
Sucessfully complied augmented dynamics stiff ODE solver
Sucessfully complied augmented dynamics stiff ODE solver
Sucessfully complied augmented dynamics stiff ODE solver
Sucessfully complied augmented dynamics stiff ODE solver
Sucessfully complied augmented dynamics stiff ODE solver
Sucessfully complied augmented dynamics stiff ODE solver
Sucessfully complied augmented dynamics stiff ODE solver
Sucessfully complied augmented dynamics stiff ODE solver
Sucessfully complied augmented dynami

In [40]:
me_mod_fo_nojax.fit_result_summary_

Unnamed: 0,model_coeff,log_name,population_coeff,model_error,subject_level_intercept,coeff_dep_var,model_coeff_dep_var,subject_level_intercept_name,fit_result_summary_name,init_val,lower_bound,upper_bound,best_fit_param_val
0,ka,ka_pop,True,False,False,False,,,ka_pop,0.470004,,,0.806248
1,cl,cl_pop,True,False,False,False,,,cl_pop,1.098612,,,0.923883
2,vd,vd_pop,True,False,False,False,,,vd_pop,3.555348,,,3.579973
3,sigma,sigma_const,False,True,False,False,,,sigma_const,-0.693147,2e-09,22000.0,-0.269861
4,ka,omega2_ka,False,False,True,False,,omega2_ka,omega2_ka,-0.510826,,,-1.018334
5,cl,omega2_cl,False,False,True,False,,omega2_cl,omega2_cl,-1.203973,,,-1.267918
6,vd,omega2_vd,False,False,True,False,,omega2_vd,omega2_vd,-2.302585,,,-2.303583


In [39]:
me_mod_fo_omegadiag.fit_result_summary_

Unnamed: 0,model_coeff,log_name,population_coeff,model_error,subject_level_intercept,coeff_dep_var,model_coeff_dep_var,subject_level_intercept_name,fit_result_summary_name,init_val,lower_bound,upper_bound,fitted_param_val,fitted_param_sd,fitted_param_rse,fitted_params_lower_ci95,fitted_params_upper_ci95,back_transformed_param_val,back_transformed_lower_ci95,back_transformed_upper_ci95
0,ka,ka_pop,True,False,False,False,,,ka_pop,0.470004,,,0.872916,,,,,2.393881,,
1,cl,cl_pop,True,False,False,False,,,cl_pop,1.098612,,,1.232761,,,,,3.43069,,
2,vd,vd_pop,True,False,False,False,,,vd_pop,3.555348,,,3.510758,,,,,33.473627,,
3,sigma,sigma_const,False,True,False,False,,,sigma_const,-0.693147,2e-09,22000.0,0.755422,,,,,,,
4,ka,omega2_ka,False,False,True,False,,omega2_ka,omega2_ka,-0.510826,,,0.021985,,,,,,,
5,cl,omega2_cl,False,False,True,False,,omega2_cl,omega2_cl,-1.203973,,,0.00891,,,,,,,
6,vd,omega2_vd,False,False,True,False,,omega2_vd,omega2_vd,-2.302585,,,0.010192,,,,,,,


In [29]:
no_jax_pred = me_mod_fo_nojax.predict2(df, parallel=False,  return_loss = True )


Sucessfully complied non-stiff ODE solver
Sucessfully complied augmented dynamics stiff ODE solver
Sucessfully complied stiff ODE solver
Sucessfully complied stiff PyMC ODE solver
Sucessfully complied augmented dynamics stiff ODE solver
Compiling `_solve_ivp_jax_worker`
Sucessfully complied augmented dynamics stiff ODE solver
Sucessfully complied augmented dynamics stiff ODE solver
Sucessfully complied augmented dynamics stiff ODE solver
Sucessfully complied augmented dynamics stiff ODE solver
Sucessfully complied augmented dynamics stiff ODE solver
Sucessfully complied augmented dynamics stiff ODE solver
Sucessfully complied augmented dynamics stiff ODE solver


In [52]:
me_mod_fo_nojax.b_i_approx

Unnamed: 0_level_0,ka,cl,vd
Unnamed: 0_level_1,omega2_ka,omega2_cl,omega2_vd
0,-0.187099,-0.493764,-0.155573
1,-0.045576,0.195533,-0.05836
2,0.106288,0.081579,-0.030293
3,-0.595514,0.003016,-0.061193
4,-0.375635,-0.132252,-0.205003
5,-0.548327,0.355091,0.137196
6,-0.795108,0.100412,0.089886
7,-0.375264,0.184754,0.036088
8,0.956451,0.053466,-0.026429
9,-0.751074,-0.545763,-0.068302


In [30]:
no_jax_pred[0]

array([2.41009231e-02, 3.86898135e+00, 6.71725320e+00, 8.81150378e+00,
       9.45053551e+00, 8.85686901e+00, 8.31452242e+00, 7.54933566e+00,
       6.82311289e+00, 5.85088572e+00, 3.16832792e+00, 0.00000000e+00,
       4.07629559e+00, 6.14719843e+00, 7.82234497e+00, 8.10009395e+00,
       7.15519634e+00, 6.24002366e+00, 5.20216362e+00, 4.35250595e+00,
       3.31731536e+00, 1.08960600e+00, 0.00000000e+00, 4.44003731e+00,
       6.81870036e+00, 7.99905862e+00, 8.02346746e+00, 7.12947887e+00,
       6.35813102e+00, 5.43824048e+00, 4.67338352e+00, 3.64904471e+00,
       1.41979887e+00, 0.00000000e+00, 3.27979915e+00, 4.83718096e+00,
       6.62186542e+00, 7.87649421e+00, 7.63116688e+00, 6.90934683e+00,
       5.96430325e+00, 5.13611405e+00, 4.11517934e+00, 1.59355537e+00,
       0.00000000e+00, 3.99804600e+00, 5.89628716e+00, 8.20378091e+00,
       9.36985845e+00, 8.78665715e+00, 7.87577841e+00, 6.77583139e+00,
       5.79103241e+00, 4.65200609e+00, 1.83054001e+00, 0.00000000e+00,
      

In [32]:
pred_df = pred_df.loc[pred_df['DV_and_PREDs'] == 'PRED_R_diag_omega', :]
pred_df['DV_and_PREDs'] = 'IPRED_NoJax_diag_omega'
pred_df['Conc'] = no_jax_pred[0]
pred_df



A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy



A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy



Unnamed: 0,ID,TIME,DV_and_PREDs,Conc
0,1,0.00,IPRED_NoJax_diag_omega,0.024101
1,1,0.25,IPRED_NoJax_diag_omega,3.868981
2,1,0.57,IPRED_NoJax_diag_omega,6.717253
3,1,1.12,IPRED_NoJax_diag_omega,8.811504
4,1,2.02,IPRED_NoJax_diag_omega,9.450536
...,...,...,...,...
127,12,5.07,IPRED_NoJax_diag_omega,7.981420
128,12,7.07,IPRED_NoJax_diag_omega,6.972997
129,12,9.03,IPRED_NoJax_diag_omega,6.069823
130,12,12.05,IPRED_NoJax_diag_omega,4.893846


In [33]:
#the jax prediction process does not seem to be working properly. The IPREd w/ jax should be much closer given that the fitted params are similar between
#jax, pymc and scipy
all_pred_df = pd.concat([all_pred_df, pred_df])
all_pred_df['ID'] = all_pred_df['ID'].astype(int)
px.line(data_frame=all_pred_df, x = 'TIME', y = 'Conc', color = 'DV_and_PREDs', animation_frame='ID')