In [1]:
# Select the devices JAX should use.
# This must run before JAX initializes its backends: before `import jax`, and
# before importing any library that itself imports jax (e.g. numpyro).
import os

# --- Control flags ---
USE_GPU = False       # True: initialize CUDA (default) + CPU backends; False: CPU only
GPU_DEVICE_ID = '0'   # physical GPU exposed to the process when USE_GPU is True
N_HOST_DEVICES = 4    # number of CPU devices to expose when USE_GPU is False
# ---------------------

if USE_GPU:
    # Initialize the CUDA and CPU platforms; the first listed is the default backend.
    os.environ['JAX_PLATFORMS'] = 'cuda,cpu'
    # Optional: restrict the process to a single physical GPU.
    os.environ['CUDA_VISIBLE_DEVICES'] = GPU_DEVICE_ID
else:
    os.environ['JAX_PLATFORMS'] = 'cpu'
    import numpyro
    # Make XLA expose N_HOST_DEVICES CPU devices (see `jax.devices()` below).
    numpyro.set_host_device_count(N_HOST_DEVICES)

import jax
print(f"JAX is running on: {jax.default_backend()}")
print(f"JAX devices: {jax.devices()}")

import finitediffx as fdx
import joblib as jb
import numpy as np
import numpyro
import pandas as pd
import plotly.express as px
import pymc as pm

from niceode.diffeqs import OneCompartmentAbsorption
from niceode.jax_utils import FO_approx_neg2ll_loss_jax, FOCE_approx_neg2ll_loss_jax_fdxOUTER
from niceode.pymc_utils import make_pymc_model
from niceode.utils import (CompartmentalModel, 
                           ODEInitVals,
                           PopulationCoeffcient,
                           neg2_log_likelihood_loss,
                           ObjectiveFunctionColumn,
                           FOCE_approx_ll_loss,
                           FOCEi_approx_ll_loss,
                           FO_approx_ll_loss
                           )


  from .autonotebook import tqdm as notebook_tqdm


JAX is running on: cpu
JAX devices: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3)]


In [2]:
# Load the theophylline dataset and keep only EVID == 0 rows, each carrying the
# most recent non-zero AMT seen for the same subject.
# TODO: resolve relative to a configurable DATA_DIR instead of an absolute path.
DATA_PATH = r"/workspaces/PK-Analysis/data/theo_nlmixr2.csv"
raw_df = pd.read_csv(DATA_PATH)

# Rows with AMT == 0 (presumably observation records) are blanked so the
# preceding dose amount can be carried forward onto them.
amt_blanked = raw_df['AMT'].mask(raw_df['AMT'] == 0.0)

# Forward-fill within each subject only: an ungrouped ffill would let a subject
# whose first row is not a dose inherit the previous subject's dose.
# NOTE(review): assumes rows are ordered by TIME within each ID — confirm.
df = raw_df.assign(AMT=amt_blanked.groupby(raw_df['ID']).ffill())
df = df.loc[df['EVID'] == 0, :].copy()

In [5]:
# Define the mixed-effects one-compartment-with-absorption model for the
# theophylline data.
# NOTE(review): despite the `_nojax` suffix, this model is configured to fit a
# JAX FOCE objective (`jax_loss=...`, `fit_jax_objective=True`). The variable
# name is kept because later cells reference it.
# NOTE(review): the model name contains "omegadiag" and "dermal", but
# `use_full_omega=True` is set below — confirm the name before relying on it to
# identify saved results.
me_mod_foce_nojax = CompartmentalModel(
    model_name="debug_theoph_abs_ka-clME-vd_JAXFOCE_jaxoptspwrapLbfgsb_fdxouteriftinner_nodep_omegadiag_dermal",
    # Initial ODE state values are read from these columns.
    # NOTE(review): order presumably matches OneCompartmentAbsorption's state
    # vector — confirm.
    ode_t0_cols=[
        ODEInitVals('DV'),
        ODEInitVals('AMT'),
    ],
    conc_at_time_col='DV',
    subject_id_col='ID',
    time_col='TIME',
    # ka, cl and vd each get a population value plus a subject-level intercept.
    # No explicit bounds are passed, so PopulationCoeffcient's own defaults apply.
    population_coeff=[
        PopulationCoeffcient(
            'ka',
            optimization_init_val=1.6,
            subject_level_intercept=True,
            subject_level_intercept_sd_init_val=0.6,
        ),
        PopulationCoeffcient(
            'cl',
            optimization_init_val=3,
            subject_level_intercept=True,
            subject_level_intercept_sd_init_val=0.3,
        ),
        PopulationCoeffcient(
            'vd',
            optimization_init_val=35,
            subject_level_intercept=True,
            subject_level_intercept_sd_init_val=0.1,
        ),
    ],
    dep_vars=None,
    pk_model_class=OneCompartmentAbsorption,
    # Model error term: its init value is not log-transformed, and it is
    # optimized within the explicit bounds below.
    model_error_sigma=PopulationCoeffcient(
        'sigma',
        log_transform_init_val=False,
        optimization_init_val=.5,
        optimization_lower_bound=0.00001,
        optimization_upper_bound=3,
    ),
    batch_id='theoph_test1',
    significant_digits=3,
    jax_loss=FOCE_approx_neg2ll_loss_jax_fdxOUTER,
    use_full_omega=True,
    use_surrogate_neg2ll=True,
    fit_jax_objective=True,
)

In [6]:
# Fit the model to the observation rows. The result is re-bound to the same
# name because the prediction cell reads `me_mod_foce_nojax.fit_result_`.
me_mod_foce_nojax = me_mod_foce_nojax.fit2(
    df,
    ci_level=None,    # presumably disables confidence-interval computation — confirm
    debug_fit=False,
)

Successfully compiled closed stiff ODE solver
Sucessfully complied non-stiff ODE solver
Sucessfully complied augmented dynamics stiff ODE solver




Sucessfully complied augmented dynamics stiff ODE solver
Sucessfully complied stiff ODE solver
Sucessfully complied KEYS stiff ODE solver
Sucessfully complied non-stiff PyMC ODE solver
Sucessfully complied hybrid-dimensional non-stiff PyMC ODE solver
Sucessfully complied non-dimensional non-stiff PyMC ODE solver
Sucessfully complied non-dimensional stiff PyMC ODE solver
Sucessfully complied stiff PyMC ODE solver
Compiling `FOCE_approx_neg2ll_loss_jax`
Compiling `approx_neg2ll_loss_jax`
Compiling `estimate_b_i_vmapped`
Compiling `FOCE_inner_loss_fn` WITHOUT LAX
Compiling `_solve_ivp_jax_worker`
Compiling `FOCE_inner_loss_fn` WITHOUT LAX
Compiling `_solve_ivp_jax_worker`
Compiling `FOCE_inner_loss_fn` WITHOUT LAX
Compiling `_solve_ivp_jax_worker`
Compiling `_solve_augdyn_ivp_jax_worker`
Compiling `_solve_augdyn_ivp_jax_worker`
Compiling `FOCE_approx_neg2ll_loss_jax`
Compiling `approx_neg2ll_loss_jax`
Compiling `estimate_b_i_vmapped`
Compiling `FOCE_inner_loss_fn` WITHOUT LAX
Compiling `_

In [8]:
# Reshape observed DV and the fitted population / individual predictions into
# long format (one row per ID, TIME and series) for plotting.
pred_cols = {
    'pop_pred_DV__PRED': 'PRED_full_omega',
    'indiv_pred_DV__IPRED': 'IPRED_full_omega',
}
value_vars = ['DV', *pred_cols.values()]
pred_df = (
    me_mod_foce_nojax.fit_result_.pred_df
    .rename(columns=pred_cols)  # returns a new frame; the fit result is not mutated
    .melt(
        id_vars=['ID', 'TIME'],
        value_vars=value_vars,
        var_name='DV_and_PREDs',
        value_name='Conc',
    )
)

In [9]:
# Per-subject comparison of observed concentrations (DV) with the PRED and IPRED
# series; the animation slider steps through subjects.
# Axis ranges are pinned to the full data extent: Plotly does not rescale axes
# between animation frames, so subjects outside the first frame's range would
# otherwise be clipped.
px.line(
    data_frame=pred_df,
    x='TIME',
    y='Conc',
    color='DV_and_PREDs',
    animation_frame='ID',
    range_x=[pred_df['TIME'].min(), pred_df['TIME'].max()],
    range_y=[min(0, pred_df['Conc'].min()), pred_df['Conc'].max() * 1.05],
    title='Observed vs. predicted concentration by subject',
    labels={'TIME': 'Time', 'Conc': 'Concentration', 'DV_and_PREDs': 'Series', 'ID': 'Subject'},
)