# T4 - Optimization of Infection Duration

In this tutorial, we'll introduce parameter optimization against an objective function: in this case, maximizing the mean infection duration of challenge infections in naive individuals by varying the antigenic switching rate parameter (`Antigen_Switch_Rate`).

We'll start by defining a function that runs challenge infections in multiple individuals (similar to the previous tutorial) and records their parasite and gametocyte densities over time.

```python
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt

from emodlib.malaria import IntrahostComponent


def multiple_challenges(n_people, duration):
    """Challenge n_people naive individuals and record densities at each of `duration` time steps."""

    asexuals = np.zeros((n_people, duration))
    gametocytes = np.zeros((n_people, duration))

    # Create one intrahost model per individual and give each a single infectious challenge
    pp = [IntrahostComponent.create() for _ in range(n_people)]
    for p in pp:
        p.challenge()

    # Advance every individual one time step (dt=1) at a time, recording both density channels
    for t in range(duration):
        for i, p in enumerate(pp):
            p.update(dt=1)
            asexuals[i, t] = p.parasite_density
            gametocytes[i, t] = p.gametocyte_density

    # Collect the results in a labeled (individual, time, channel) array
    da = xr.DataArray(dims=('individual', 'time', 'channel'),
                      coords=(range(n_people), range(duration), ['parasite_density', 'gametocyte_density']))

    da.loc[dict(channel='parasite_density')] = asexuals
    da.loc[dict(channel='gametocyte_density')] = gametocytes

    return da
```
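The function above creates the DataArray from its coordinates alone and then fills each channel with `.loc`. An equivalent way to obtain the same `(individual, time, channel)` layout is to stack the channels and pass them as data directly. This is a minimal sketch using small made-up arrays in place of emodlib output, so it runs without the model:

```python
import numpy as np
import xarray as xr

n_people, duration = 3, 4  # toy sizes, for illustration only

# Made-up stand-ins for the per-individual time series filled in by multiple_challenges()
asexuals = np.arange(n_people * duration, dtype=float).reshape(n_people, duration)
gametocytes = asexuals / 10

# Stack the two channels along a new last axis -> shape (individual, time, channel)
da = xr.DataArray(
    np.stack([asexuals, gametocytes], axis=-1),
    dims=("individual", "time", "channel"),
    coords={
        "individual": np.arange(n_people),
        "time": np.arange(duration),
        "channel": ["parasite_density", "gametocyte_density"],
    },
)

# Selecting one channel gives an (individual, time) array, as used in the objective below
parasites = da.sel(channel="parasite_density").values
print(da.shape)      # (3, 4, 2)
print(parasites[1])  # [4. 5. 6. 7.]
```

Building the array in one call avoids the intermediate NaN-filled allocation; either form yields the same labeled result.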

We'll also define a helper function that estimates each challenge infection's duration as the time index of its last non-zero parasite density.

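```python
def get_last_nonzero_by_row(A):
    """Return (row indices, column index of the last non-zero entry in each row).

    Adapted from https://stackoverflow.com/a/39959511
    """
    # Reverse the columns so argmax finds the first non-zero from the end, then map back
    return np.arange(A.shape[0]), A.shape[1] - 1 - (A[:, ::-1] != 0).argmax(axis=1)
```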
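A quick check of the helper on a small made-up density matrix (the helper is repeated so the snippet runs on its own). It also illustrates two properties worth keeping in mind when reading the objective values: an infection still present at the end of the simulated window is right-censored at `duration - 1`, and a row with no non-zero entries at all would also report `A.shape[1] - 1`, because `argmax` of an all-`False` row is 0.

```python
import numpy as np


def get_last_nonzero_by_row(A):
    """Same helper as above: (row indices, column index of each row's last non-zero entry)."""
    return np.arange(A.shape[0]), A.shape[1] - 1 - (A[:, ::-1] != 0).argmax(axis=1)


# Toy parasite-density matrix: 3 individuals x 6 time steps (made-up values)
A = np.array([
    [5.0, 3.0, 1.0, 0.0, 0.0, 0.0],  # cleared, last positive at index 2
    [5.0, 0.0, 2.0, 4.0, 0.0, 0.0],  # zero at index 1, last positive at index 3
    [5.0, 4.0, 3.0, 2.0, 1.0, 0.5],  # still infected at the end: censored at index 5
])

rows, last_idx = get_last_nonzero_by_row(A)
print(last_idx)  # [2 3 5]

# Caveat: an all-zero row is indistinguishable from one that lasts the whole window
empty = np.zeros((1, 6))
print(get_last_nonzero_by_row(empty)[1])  # [5]
```

If individuals that never show parasites could occur, masking them out with `(A != 0).any(axis=1)` before averaging would keep them from inflating the mean.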

Then we'll define our objective function, which performs:
- log-uniform sampling of the antigen switching rate within a defined range (here 5e-10 to 5e-8)
- setting the model parameters
- running the multi-individual challenge time series
- returning the mean infection duration across individuals

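```python
def objective(trial):
    """Optuna objective: mean infection duration for one sampled Antigen_Switch_Rate."""

    n_people = 50
    duration = 500

    # Sample the switching rate log-uniformly between 5e-10 and 5e-8
    antigen_switch_rate = trial.suggest_float("Antigen_Switch_Rate", 5e-10, 5e-8, log=True)

    # set_params is called on the class (not an instance), before this trial's individuals are created
    IntrahostComponent.set_params(dict(infection_params=dict(Antigen_Switch_Rate=antigen_switch_rate)))

    da = multiple_challenges(duration=duration, n_people=n_people)

    # Last non-zero time index of each individual's parasite density = infection duration
    infection_durations = get_last_nonzero_by_row(da.sel(channel='parasite_density').values)[1]

    return infection_durations.mean()
```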
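The `log=True` flag asks Optuna to sample the rate on a log scale, which suits a parameter spanning two orders of magnitude. The NumPy sketch below only illustrates what log-uniform sampling means; it is not Optuna's internal code, and Optuna's default TPE sampler draws independently like this only for its initial startup trials before focusing on promising regions.

```python
import numpy as np

low, high = 5e-10, 5e-8  # same bounds as the suggest_float call above
rng = np.random.default_rng(0)

# Log-uniform sampling: draw uniformly in log space, then exponentiate
samples = np.exp(rng.uniform(np.log(low), np.log(high), size=10_000))

# Each decade of the range gets roughly equal probability; with these bounds the
# geometric midpoint sqrt(low * high) = 5e-9 splits the draws about 50/50
geo_mid = np.sqrt(low * high)
frac_below = np.mean(samples < geo_mid)
print(samples.min() >= low, samples.max() <= high)
print(round(frac_below, 2))
```

By contrast, a plain uniform draw over the same bounds would place about 90% of samples above 5e-9, leaving the low end of the range barely explored.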

Finally, we'll create an Optuna study that maximizes the objective and run 25 optimization trials:

```python
import optuna

study = optuna.create_study(study_name='maximize_duration', direction='maximize')
study.optimize(objective, n_trials=25)
```

```
[I 2023-05-26 13:05:42,713] A new study created in memory with name: maximize_duration
[I 2023-05-26 13:05:42,760] Trial 0 finished with value: 353.46 and parameters: {'Antigen_Switch_Rate': 7.42087214725448e-09}. Best is trial 0 with value: 353.46.
[I 2023-05-26 13:05:42,799] Trial 1 finished with value: 313.36 and parameters: {'Antigen_Switch_Rate': 1.6115573043218315e-09}. Best is trial 0 with value: 353.46.
[I 2023-05-26 13:05:42,844] Trial 2 finished with value: 286.5 and parameters: {'Antigen_Switch_Rate': 3.347336402474378e-08}. Best is trial 0 with value: 353.46.
[I 2023-05-26 13:05:42,879] Trial 3 finished with value: 229.88 and parameters: {'Antigen_Switch_Rate': 9.379407533084728e-10}. Best is trial 0 with value: 353.46.
[I 2023-05-26 13:05:42,924] Trial 4 finished with value: 322.72 and parameters: {'Antigen_Switch_Rate': 1.2638851044622985e-08}. Best is trial 0 with value: 353.46.
...
```

```python
study.best_params  # parameter value giving the longest mean duration
```

```
{'Antigen_Switch_Rate': 2.818178737648485e-09}
```

```python
study.best_value  # longest mean duration
```

```
405.62
```

Now let's look at two of Optuna's built-in visualizations of the study:
- the optimization history: each trial's objective value, and the convergence of the best value so far over successive trials
- the slice plot: the objective value as a function of the sampled `Antigen_Switch_Rate` across the 1-D range explored

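```python
# Optuna's plotting functions return Plotly figures (Plotly must be installed);
# a notebook displays them automatically, while a script needs fig.show()
from optuna.visualization import plot_optimization_history, plot_slice

plot_optimization_history(study)
```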
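For a maximization study, the "best value" line in the optimization-history plot is a running maximum of the trial values. A small sketch using the five trial values printed in the log above (the remaining trials are not shown there):

```python
import numpy as np

# Objective values of trials 0-4, copied from the study log above
values = np.array([353.46, 313.36, 286.5, 229.88, 322.72])

# Best-so-far curve for a "maximize" study = cumulative maximum
best_so_far = np.maximum.accumulate(values)
best_trial = int(np.argmax(values))
print(best_so_far)  # [353.46 353.46 353.46 353.46 353.46]
print(best_trial)   # 0
```

This matches the "Best is trial 0 with value: 353.46" messages in the log; the curve only rises when a later trial beats the current best.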

```python
plot_slice(study)
```