In [33]:
from functools import partial

import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np

from emodlib.malaria import IntrahostComponent

# --- UI controls -------------------------------------------------------------
# One button per action, in display order; the labels are the button captions.
infect_button, treat_button, step1_button, step7_button, step30_button = (
    widgets.Button(description=label)
    for label in ('Infect', 'Treat', 't+1', 't+7', 't+30')
)
# Area that receives the plots drawn by the decorated `run` function.
output = widgets.Output()

# How many of the most recent days are shown in the plots.
display_days = 100

# --- Simulated individual ----------------------------------------------------
ic = IntrahostComponent.create()

# --- Daily time series -------------------------------------------------------
# Six independent lists, each seeded with a single day-0 value of 0.
(days, parasite_densities, gametocyte_densities,
 infectiousness, fever_temperature, n_infections) = ([0] for _ in range(6))

@output.capture()
def run(steps):
    """Advance the simulated individual by ``steps`` days and redraw the plots.

    Each day is simulated as four ``ic.update(dt=1.0/4)`` sub-steps. After
    each day, the current parasite density, gametocyte density,
    infectiousness, fever temperature and infection count are appended to
    the module-level time series. The last ``display_days`` entries are then
    plotted:

    * top panel (log y-axis): parasite density (navy) and gametocyte density
      (dark green). It also has two rows of markers near the top. At
      y=4e5, marker size and colour encode ``infectiousness``. At y=8e5,
      they encode fever temperature above 37.
    * bottom panel: ``n_infections``.

    Printed/displayed output is captured into the ``output`` widget by the
    decorator.

    Args:
        steps: number of whole days to simulate; 0 only redraws.
    """
    for _ in range(steps):
        for _ in range(4):
            ic.update(dt=1.0 / 4)

        days.append(days[-1] + 1)
        parasite_densities.append(ic.parasite_density)
        gametocyte_densities.append(ic.gametocyte_density)
        infectiousness.append(ic.infectiousness)
        fever_temperature.append(ic.fever_temperature)
        n_infections.append(ic.n_infections)

    # Take the display window once and reuse it for every plot call.
    window = slice(-display_days, None)
    t = np.array(days[window])
    infectiousness_pct = 100 * np.array(infectiousness[window])
    # Degrees above 37, clipped at 0. The day-0 seed (0) or any sub-37
    # reading would otherwise yield a negative marker area, which matplotlib
    # cannot draw. Colour is unaffected because vmin=0 already clips it.
    fever_excess = np.clip(np.array(fever_temperature[window]) - 37, 0, None)

    fig, axs = plt.subplots(2, 1, figsize=(8, 6), sharex=True)

    ax = axs[0]
    ax.plot(t, parasite_densities[window], c='navy')
    ax.plot(t, gametocyte_densities[window], c='darkgreen')
    ax.set(yscale='log', ylim=(1e-3, 1e6), ylabel='density (1/uL)')

    # Draw infectiousness measurement points.
    ax.scatter(
        x=t,
        y=np.full(len(t), 4e5),
        s=1 + infectiousness_pct,
        c=infectiousness_pct,
        cmap="Greens",
        vmin=0,
        vmax=100,
        lw=0.5,
        edgecolors="darkgreen",
    )

    # Draw fever measurement points.
    ax.scatter(
        x=t,
        y=np.full(len(t), 8e5),
        s=20 * fever_excess + 1,
        c=fever_excess,
        cmap="Reds",
        vmin=0,
        vmax=4,
        lw=0.5,
        edgecolors="firebrick",
    )

    axs[1].plot(t, n_infections[window], c='darkgray')
    axs[1].set(ylabel='n_infections')

    # set_tight_layout(True) is deprecated; apply tight layout directly.
    fig.tight_layout()

    return plt.show()

def infect(b):
    """Handle the 'Infect' button by calling ``ic.challenge()``.

    ``b`` is the clicked button widget passed by ``on_click``. It is unused.
    """
    del b  # unused; required by the on_click callback signature
    ic.challenge()

def treat(b):
    """Handle the 'Treat' button by calling ``ic.treat()``.

    ``b`` is the clicked button widget passed by ``on_click``. It is unused.
    """
    del b  # unused; required by the on_click callback signature
    ic.treat()

def step(b, steps):
    """Handle the 't+N' buttons: clear the plot area, then simulate and redraw.

    Args:
        b: the clicked button widget (unused).
        steps: number of days to advance, passed on to ``run``.
    """
    del b  # unused; required by the on_click callback signature
    # wait=True defers clearing until new output arrives, which avoids flicker.
    output.clear_output(wait=True)
    run(steps)

# Wire each button to its callback (same registration order as before).
infect_button.on_click(infect)
treat_button.on_click(treat)
for _button, _n_days in ((step1_button, 1), (step7_button, 7), (step30_button, 30)):
    # partial binds the day count now, so each button keeps its own value.
    _button.on_click(partial(step, steps=_n_days))
del _button, _n_days

# Layout: action buttons on top, time-step buttons below, then the plot area.
action_row = widgets.HBox([infect_button, treat_button])
time_row = widgets.HBox([step1_button, step7_button, step30_button])
app = widgets.VBox([action_row, time_row, output])
display(app)

VBox(children=(HBox(children=(Button(description='Infect', style=ButtonStyle()), Button(description='Treat', s…