# Basic working examples of parameter estimation

In [1]:
using DifferentialEquations
using Thyrosim
using Plots
using DiffEqCallbacks
using Optim
using DiffEqParamEstim

┌ Info: Recompiling stale cache file /Users/biona001/.julia/compiled/v1.2/Thyrosim/Oo7R6.ji for Thyrosim [7ef34fca-2b35-11e9-1aeb-a527bedb189e]
└ @ Base loading.jl:1240


## Import data and initial conditions

In [2]:
train, test, toy = schneider_data();

In [3]:
# useful parameters
train[!, Symbol("Days.to.euthyroid")]
train[!, Symbol("Wt.kg")]
train[!, Symbol("Ht.m")]
train[!, Symbol("TSH.preop")]
train[!, Symbol("Dose.changes")]
train[!, Symbol("LT4.euthyroid.dose")]
train[!, Symbol("LT4.initial.dose")]'
train[!, Symbol("Sex")]

400-element CSV.Column{Int64,Int64}:
 0
 1
 1
 1
 0
 1
 1
 1
 0
 1
 1
 1
 1
 ⋮
 1
 0
 1
 1
 0
 0
 1
 1
 1
 1
 1
 1

## Solve 1 schneider patient

In [4]:
# each row is a patient's tspan
total_days = train[!, Symbol("Days.to.euthyroid")]
tspans = [(0.0, 24.0total_days[i]) for i in 1:length(total_days)]

400-element Array{Tuple{Float64,Float64},1}:
 (0.0, 1320.0) 
 (0.0, 2376.0) 
 (0.0, 5208.0) 
 (0.0, 4296.0) 
 (0.0, 3864.0) 
 (0.0, 3744.0) 
 (0.0, 4560.0) 
 (0.0, 1032.0) 
 (0.0, 2424.0) 
 (0.0, 3720.0) 
 (0.0, 2136.0) 
 (0.0, 1656.0) 
 (0.0, 5376.0) 
 ⋮             
 (0.0, 6744.0) 
 (0.0, 2208.0) 
 (0.0, 3792.0) 
 (0.0, 1896.0) 
 (0.0, 336.0)  
 (0.0, 1416.0) 
 (0.0, 2208.0) 
 (0.0, 1080.0) 
 (0.0, 10224.0)
 (0.0, 5904.0) 
 (0.0, 2424.0) 
 (0.0, 4008.0) 

In [67]:
# initialize ODE problem
ic, p = initialize([0.0; 0.88; 0.0; 0.88]) # schneider patients are completely thyroidectomized

# designate p55 as oral T4 dose (400 mcg), p56 as oral T3 dose
p[55] = 400.0 / 777.0
p[56] = 0.0

# call back function for adding dose every 24 hour
function add_dose!(integrator)
    integrator.u[10] += integrator.p[55]
    integrator.u[12] += integrator.p[56]
end
cbk = PeriodicCallback(add_dose!, 24.0);

# define ODE problem
train_patient_1 = ODEProblem(thyrosim,ic,tspans[1],p,callback=cbk)
    
# solve ODE problem
sol = solve(train_patient_1, save_idxs=7)

retcode: Success
Interpolation: Automatic order switching interpolation
t: 4559-element Array{Float64,1}:
    0.0                
    0.00589049502729523
    0.04331680083793424
    0.10775925933305852
    0.18828674241224558
    0.29923986313348233
    0.4437005253206401 
    0.644852613888753  
    0.9211243196572607 
    1.2520854980785918 
    1.5654962222870112 
    1.833032589387102  
    2.0650280564109496 
    ⋮                  
 1316.1781678297414    
 1316.5961926455232    
 1317.0001173677451    
 1317.3909561429982    
 1317.7697570125836    
 1318.137574658376     
 1318.4954539673524    
 1318.8444217148797    
 1319.1854843936421    
 1319.5196308744833    
 1319.847839162069     
 1320.0                
u: 4559-element Array{Float64,1}:
 1.7882958476437     
 1.7873661315008171  
 1.7814425498755786  
 1.7711794243750691  
 1.7582510574481578  
 1.7402748479390808  
 1.7166452356747597  
 1.6834692924951937  
 1.6377629070233788  
 1.5834986432550218  
 1.5333428402691

# Define error function for all Schneider patients

We simulate patients according to the number of days the patient achieved euthyrodism using the given dose, then calculate if their TSH levels are within [0.45, 4.5] mIU/ml.

#### Parameter definition:
- `p[55]:` Daily T4 oral dose
- `p[56]:` Daily T3 oral dose

#### Error definition:
+ When given an euthyroid T4 dose, if any TSH values $\notin [0.5, 4.5]$ in the last 24h of simulation, then error + 1 (i.e. **patients receiving correct dose should have normal TSH**)
+ When given the initial T4 dose, if the initial T4 dose is not equal to euthyroid T4 dose, and all TSH values $\in [0.5, 4.5]$, then error + 1 (i.e. **patients not receiving correct dose should NOT have normal TSH**)

In [9]:
#TODO: define atomic `tot_loss` and make loop multithreaded
function compute_schneider_error(train_data)
    dial = [0.0; 0.88; 0.0; 0.88]
    scale_Vp = true
    tot_loss = 0.0
    
    # define function for adding dose
    function add_dose!(integrator)
        integrator.u[10] += integrator.p[55]
        integrator.u[12] += integrator.p[56]
    end
    cbk = PeriodicCallback(add_dose!, 24.0);
    
    # preallocate vectors
    ic, p = initialize()

    #loop over all patients
    for i in 1:size(train_data, 1)
        height = train_data[i, Symbol("Ht.m")]
        weight = train_data[i, Symbol("Wt.kg")]
        sex    = Bool(train_data[i, Symbol("Sex")])
        Thyrosim.initialize!(ic, p, dial, scale_Vp, height, weight, sex) #initializes ic and p
        ic[7]  = train_data[i, Symbol("TSH.preop")] #set initial TSH value
        tspan  = (0.0, 24.0train_data[i, Symbol("Days.to.euthyroid")]) #(0, total hours)
        
        # calculate error for euthyroid dose
        euthyroid_dose = train_data[i, Symbol("LT4.euthyroid.dose")] / 777.0
        p[55] = euthyroid_dose
        prob  = ODEProblem(thyrosim,ic,tspan,p,callback=cbk)
        sol   = solve(prob, save_idxs=7)
        tot_loss += compute_euthyroid_dose_error(sol)
        
        # when initial dose != euthyroid dose, calculate error
        initial_dose = train_data[i, Symbol("LT4.initial.dose")] / 777.0
        if initial_dose != euthyroid_dose
            p[55] = initial_dose
            prob  = ODEProblem(thyrosim,ic,tspan,p,callback=cbk)
            sol   = solve(prob, save_idxs=7)
            tot_loss += compute_initial_dose_error(sol)
        end
    end
    
    return tot_loss
end

compute_schneider_error (generic function with 1 method)

### Helper functions for calculating error

These functions assume the solution object contains only TSH values, which can be achieved by using `save_idxs=7` when one call `solve`. 

In [5]:
function compute_euthyroid_dose_error(sol)
    tot_loss = 0.0
    if any((s.retcode != :Success for s in sol))
        tot_loss = Inf
    else
        total_hours  = sol.t[end]
        TSH_last_day = sol.u[sol.t .>= total_hours - 24]
        if !all(0.5 .≤ TSH_last_day .≤ 4.5)
            tot_loss += 1.0
        end
    end
    return tot_loss
end

function compute_initial_dose_error(sol)
    tot_loss = 0.0
    if any((s.retcode != :Success for s in sol))
        tot_loss = Inf
    else
        total_hours  = sol.t[end]
        TSH_last_day = sol.u[sol.t .>= total_hours - 24]
        if all(0.5 .≤ TSH_last_day .≤ 4.5)
            tot_loss += 1.0
        end
    end
    return tot_loss
end                     

compute_initial_dose_error (generic function with 1 method)

## Error and timing on toy and train data 

In [26]:
@time compute_schneider_error(toy) #single thread

  3.116344 seconds (4.66 M allocations: 257.820 MiB, 1.55% gc time)


17.0

In [27]:
@time compute_schneider_error(train) #single thread

 73.081660 seconds (95.49 M allocations: 5.502 GiB, 1.29% gc time)


374.0

# Optimize error function

In [20]:
# main objective function
function schneider_objective(p, data)
    fitting_index = [30, 31]
    dial = [0.0; 0.88; 0.0; 0.88]
    train, test, toy = schneider_data()
    
    if data == :train
        return compute_schneider_error(p, fitting_index, train, dial)
    elseif data == :toy
        return compute_schneider_error(p, fitting_index, toy, dial)
    else
        error("data must be :train or :toy but was $data")
    end
    
    return nothing
end

schneider_objective (generic function with 2 methods)

In [18]:
# function that computes schneider's error with current_iter storing the estimated parameters in current iteration
# TODO: define atomic `tot_loss` and make loop multithreaded
function compute_schneider_error(current_iter, fitting_index, train_data, dial)
    scale_Vp = true
    tot_loss = 0.0
    
    # preallocate vectors
    ic, p0 = initialize()
    
    # function for adding dose
    function add_dose!(integrator)
        integrator.u[10] += integrator.p[55]
    end
    cbk = PeriodicCallback(add_dose!, 24.0);

    # Helper function that sets the parameter of `p` that are not in `fitting_index` to corresponding values in `p0`. 
    function reset_p!(p, p0, fitting_index)
        for i in 1:length(p)
            if i in fitting_index 
                continue
            else
                p[i] = p0[i]
            end
        end
    end
    
    #loop over all patients
    for i in 1:size(train_data, 1)
        height = train_data[i, Symbol("Ht.m")]
        weight = train_data[i, Symbol("Wt.kg")]
        sex    = Bool(train_data[i, Symbol("Sex")])
        
        # parameters not in `fitting_index` gets reset back 
        Thyrosim.initialize!(ic, p0, dial, scale_Vp, height, weight, sex) #initializes ic and p0
        ic[7] = train_data[i, Symbol("TSH.preop")] #set initial TSH value
        tspan = (0.0, 24.0train_data[i, Symbol("Days.to.euthyroid")]) #(0, total hours)
        reset_p!(current_iter, p0, fitting_index)
        
        # calculate error for euthyroid dose
        euthyroid_dose = train_data[i, Symbol("LT4.euthyroid.dose")] / 777.0
        p[55] = euthyroid_dose
        prob  = ODEProblem(thyrosim,ic,tspan,current_iter,callback=cbk)
        sol   = solve(prob, save_idxs=7)
        tot_loss += compute_euthyroid_dose_error(sol)
        
        # when initial dose != euthyroid dose, calculate error
        initial_dose = train_data[i, Symbol("LT4.initial.dose")] / 777.0
        if initial_dose != euthyroid_dose
            p[55] = initial_dose
            prob  = ODEProblem(thyrosim,ic,tspan,current_iter,callback=cbk)
            sol   = solve(prob, save_idxs=7)
            tot_loss += compute_initial_dose_error(sol)
        end
    end
    
    return tot_loss
end

compute_schneider_error (generic function with 2 methods)

### Try optimizing schneider_objective

In [21]:
_, p = initialize([0.0; 0.88; 0.0; 0.88]) # schneider patients are completely thyroidectomized
result = optimize(p -> schneider_objective(p, :toy), p, BFGS())

 * Status: success

 * Candidate solution
    Minimizer: [0.00e+00, 8.00e+00, 8.68e-01,  ...]
    Minimum:   1.500000e+01

 * Found with
    Algorithm:     BFGS
    Initial Point: [0.00e+00, 8.00e+00, 8.68e-01,  ...]

 * Convergence measures
    |x - x'|               = 0.00e+00 ≤ 0.0e+00
    |x - x'|/|x'|          = 0.00e+00 ≤ 0.0e+00
    |f(x) - f(x')|         = NaN ≰ 0.0e+00
    |f(x) - f(x')|/|f(x')| = NaN ≰ 0.0e+00
    |g(x)|                 = 0.00e+00 ≤ 1.0e-08

 * Work counters
    Seconds run:   465  (vs limit Inf)
    Iterations:    0
    f(x) calls:    1
    ∇f(x) calls:   2


In [29]:
result.minimum #original = 17

15.0

In [30]:
result.minimizer[30:31]

2-element Array{Float64,1}:
 101.0 
  47.64