In [204]:
# %load_ext nb_black

<IPython.core.display.Javascript object>

In [51]:
import datetime

import numpy as np
import pandas as pd
import plotly
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
from plotly.subplots import make_subplots
from scipy.integrate import odeint

# Every plotly figure in this notebook uses the white template.
pio.templates.default = "plotly_white"

# Today's date as MM/DD/YYYY; used in figure titles.
today_date = datetime.datetime.today().strftime("%m/%d/%Y")

<IPython.core.display.Javascript object>

In [65]:
# Show plotly's qualitative "Vivid" palette (repeated twice). Most of the
# custom `colors` list defined in the next cell is taken from these colours.
print(px.colors.qualitative.Vivid * 2)


['rgb(229, 134, 6)', 'rgb(93, 105, 177)', 'rgb(82, 188, 163)', 'rgb(153, 201, 69)', 'rgb(204, 97, 176)', 'rgb(36, 121, 108)', 'rgb(218, 165, 27)', 'rgb(47, 138, 196)', 'rgb(118, 78, 159)', 'rgb(237, 100, 90)', 'rgb(165, 170, 153)', 'rgb(229, 134, 6)', 'rgb(93, 105, 177)', 'rgb(82, 188, 163)', 'rgb(153, 201, 69)', 'rgb(204, 97, 176)', 'rgb(36, 121, 108)', 'rgb(218, 165, 27)', 'rgb(47, 138, 196)', 'rgb(118, 78, 159)', 'rgb(237, 100, 90)', 'rgb(165, 170, 153)']


<IPython.core.display.Javascript object>

In [114]:
# Colour cycle used by all figures below. There are eleven distinct colours,
# mostly from plotly's "Vivid" palette (reordered, with rgb(255, 111, 97) in
# first place). The list is repeated twice so that indices 0-21 are valid.
colors = [
    "rgb(255, 111, 97)",
    "rgb(93, 105, 177)",
    "rgb(82, 188, 163)",
    "rgb(165, 170, 153)",
    "rgb(204, 97, 176)",
    "rgb(36, 121, 108)",
    "rgb(218, 165, 27)",
    "rgb(47, 138, 196)",
    "rgb(118, 78, 159)",
    "rgb(229, 134, 6)",
    "rgb(153, 201, 69)",
] * 2

<IPython.core.display.Javascript object>

### Exponential growth figure

In [115]:
def get_series(t, i):
    """
    Simulate ``t + 1`` days of exponential spread with daily growth factor ``i``.

    Day ``z`` contributes ``i ** z`` new infections (a single case on day 0).

    Returns
    -------
    tuple of (list, list)
        ``(cumulative, adjusted)``:

        - ``cumulative[z]`` is the running sum of ``i ** k`` for ``k <= z``.
        - ``adjusted[z]`` is ``i ** z`` for ``z <= 14``, and
          ``i ** z - adjusted[z - 14]`` afterwards.

    NOTE(review): ``adjusted`` is meant to reflect an assumed 14-day recovery
    period. It subtracts the already-adjusted entry 14 days back, not
    ``i ** (z - 14)``, and the first subtraction happens on day 15. Confirm
    this is intended.
    """
    cumulative = []
    adjusted = []
    running_total = 0
    for day in range(t + 1):
        new_cases = i ** day
        running_total += new_cases
        cumulative.append(running_total)
        if day > 14:
            adjusted.append(new_cases - adjusted[day - 14])
        else:
            adjusted.append(new_cases)
    return (cumulative, adjusted)


def get_frame(t, vals):
    """
    Collect ``get_series`` output for several growth factors into one DataFrame.

    For every element ``i`` of ``vals``, the first series from
    ``get_series(t, i)`` becomes column ``cum_<i>`` and the second becomes
    ``ind_<i>``. Rows are days ``0..t``; the final ``reset_index`` exposes the
    day number as column ``index``.

    Parameters
    ----------
    t : int
        Number of days to simulate (rows 0..t).
    vals : iterable of numbers
        Growth factors passed to ``get_series``.

    Returns
    -------
    pandas.DataFrame
        The ``index`` column, followed by ``cum_<i>``/``ind_<i>`` pairs in the
        order of ``vals``. An empty ``vals`` yields a frame with only ``index``.
    """
    frames = []
    for i in vals:
        cumulative, second = get_series(t, i)
        frames.append(
            pd.DataFrame({"cum_{}".format(i): cumulative, "ind_{}".format(i): second})
        )
    if not frames:
        # pd.concat([]) raises; keep the previous empty-input result instead.
        return pd.DataFrame().reset_index()
    # Concatenate once. Growing a frame inside the loop is quadratic and
    # concatenating onto an empty DataFrame triggers pandas warnings.
    return pd.concat(frames, axis=1).reset_index()

<IPython.core.display.Javascript object>

In [203]:
def expfigure(t=10, vals_cont=None):
    """
    Plot cumulative exponential growth on a linear (left) and log (right) y-axis.

    Parameters
    ----------
    t : int, default 10
        Last day plotted (days 0..t).
    vals_cont : sequence of numbers, optional
        Growth factors. Defaults to five evenly spaced values from 1.5 to 3.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    if vals_cont is None:
        vals_cont = list(np.linspace(1.5, 3, 5))

    df = get_frame(t, vals_cont)
    fig = make_subplots(
        rows=1,
        cols=2,
        subplot_titles=("Linear scale", "Log scale"),
        x_title="Days",
        y_title="Number of infected",
    )
    # Both panels get the same curves; only the left panel adds legend entries.
    # Each growth factor gets its own legendgroup, so clicking a legend entry
    # toggles the curve in both panels. The old "lin"/"log" groups only
    # toggled the left panel.
    for col, show_legend in ((1, True), (2, False)):
        for i, val in enumerate(vals_cont):
            fig.add_trace(
                go.Scatter(
                    x=df["index"],
                    y=df["cum_{}".format(val)],
                    mode="lines",
                    name="Growth factor of {}".format(val),
                    line_shape="spline",
                    legendgroup="growth_{}".format(val),
                    showlegend=show_legend,
                    line=dict(color=colors[i % len(colors)]),
                ),
                row=1,
                col=col,
            )
    fig.update_layout(
        title="Exponential growth of infected population (cumulative sum)",
    )
    fig.update_xaxes(showgrid=False, zeroline=True)
    fig.update_yaxes(showgrid=False, zeroline=True)
    fig.update_yaxes(type="log", row=1, col=2)
    return fig


fig = expfigure()
fig.show()
fig.write_image("EXPmodel.svg", width=700, height=400)

<IPython.core.display.Javascript object>

### Intensive care beds

https://de.statista.com/statistik/daten/studie/1111057/umfrage/intensivbetten-je-einwohner-in-ausgewaehlten-laendern/#professional

In [128]:
def icufigure():
    """
    Horizontal bar chart of intensive-care beds per 100,000 inhabitants.

    Each country label carries the reference year of its value, and each bar
    is annotated with its value.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    countries = [
        "Germany (2017)",
        "Austria (2018)",
        "US (2018)",
        "France (2018)",
        "Spain (2017)",
        "Italy (2020)",
        "Denmark (2014)",
        "Ireland (2016)",
    ]
    icu_number = [33.9, 28.9, 25.8, 16.3, 9.7, 8.9, 7.8, 5]
    fig = go.Figure(
        data=[
            go.Bar(
                y=countries,
                x=icu_number,
                text=icu_number,
                textposition="outside",
                orientation="h",
                marker=dict(color=colors[0]),
            )
        ]
    )
    # The cited statista series ("Intensivbetten je Einwohner") counts
    # intensive-care *beds*. The previous title called them "units".
    fig.update_layout(
        title="Number of intensive care beds per 100,000 inhabitants by country",
        yaxis=dict(showgrid=False),
        xaxis=dict(showgrid=False),
    )
    return fig


fig = icufigure()
fig.show()
fig.write_image("ICUdata.svg", width=700, height=400)

<IPython.core.display.Javascript object>

### ECDC Data graphs

In [118]:
def ECDCdata(
    source="https://opendata.ecdc.europa.eu/covid19/casedistribution/csv",
    countries=("China", "South_Korea", "Sweden", "Italy", "Germany"),
):
    """
    Load the ECDC case-distribution CSV and aggregate it per country.

    Parameters
    ----------
    source : str or file-like, default the ECDC open-data URL
        Anything ``pandas.read_csv`` accepts. It needs at least the columns
        ``dateRep``, ``cases``, ``deaths`` and ``countriesAndTerritories``.
    countries : iterable of str
        Countries to return in the second result. "United_States_of_America"
        was commented out of the original filter; pass it explicitly to include it.

    Returns
    -------
    df_ECDC_pv : pandas.DataFrame
        One row per ``countriesAndTerritories`` (sorted; this is the index)
        holding the per-country maximum of every column. For ``cases_cum`` and
        ``deaths_cum`` that is the latest running total.
    df_ECDC_pv_group_2 : list of [country, cases_cum, deaths_cum]
        Rows of ``df_ECDC_pv`` for ``countries``, in index order.
    """
    df_ECDC = pd.read_csv(source)

    # NOTE(review): ECDC publishes dateRep as dd/mm/yyyy; confirm. dayfirst
    # avoids month/day swaps, and infer_datetime_format is deprecated in pandas 2.
    df_ECDC["dateRep"] = pd.to_datetime(df_ECDC["dateRep"], dayfirst=True)

    # Running totals per country in chronological order. cumsum keeps the
    # original row labels, so the results align back onto df_ECDC.
    by_country = df_ECDC.sort_values("dateRep").groupby("countriesAndTerritories")
    df_ECDC["cases_cum"] = by_country["cases"].cumsum()
    df_ECDC["deaths_cum"] = by_country["deaths"].cumsum()

    # Per-country maximum of every column. This equals the previous
    # pivot_table(df_ECDC, columns=..., aggfunc="max").transpose(), but keeps
    # numeric dtypes instead of turning every column into object.
    df_ECDC_pv = df_ECDC.groupby("countriesAndTerritories").max()

    flat = df_ECDC_pv.reset_index()
    df_ECDC_pv_group_2 = flat[flat["countriesAndTerritories"].isin(list(countries))][
        ["countriesAndTerritories", "cases_cum", "deaths_cum"]
    ].values.tolist()
    return df_ECDC_pv, df_ECDC_pv_group_2

<IPython.core.display.Javascript object>

Graph based on https://ourworldindata.org/covid-mortality-risk

In [127]:
def ECDCfigure():
    """
    Log-log scatter of cumulative confirmed cases vs. deaths per country.

    - Points are coloured and shaped by ``continentExp``; the "Other" group is dropped.
    - Dotted grey diagonals mark constant case-fatality ratios
      (CFR = deaths / cases) of 10 %, 2.5 %, 0.625 % and about 0.156 %.
    - The countries in the second value returned by ``ECDCdata`` are labelled.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    # Fix: the second value used to be bound to ``df_ECDC_pv_group_2C``, while
    # the annotation loop below read ``df_ECDC_pv_group_2``. That raised a
    # NameError on a fresh kernel.
    df_ECDC_pv, df_ECDC_pv_group_2 = ECDCdata()
    max_cases = df_ECDC_pv["cases_cum"].max()

    fig = px.scatter(
        df_ECDC_pv[df_ECDC_pv["continentExp"] != "Other"].reset_index(),
        x="cases_cum",
        y="deaths_cum",
        color="continentExp",
        log_x=True,
        log_y=True,
        hover_data=["countriesAndTerritories"],
        title="Cumulative cases and deaths in different countries until {}".format(
            today_date
        ),
        symbol="continentExp",
        color_discrete_sequence=colors,
    )

    for cfr in [0.1, 0.025, 0.00625, 0.0015625]:
        diag = 1 / cfr  # cases per death along this reference line
        cfr_percent = round((1 / diag * 100), 3)
        fig.add_trace(
            go.Scatter(
                x=[1 * diag, max_cases],
                y=[1, max_cases / diag],
                mode="lines+text",
                name="CFR of {}%".format(cfr_percent),
                line_shape="spline",
                line=dict(width=1, dash="dot", color="grey"),
                textposition="middle right",
                showlegend=False,
            )
        )
        # On log axes, annotation coordinates are given in log10 units.
        fig.add_annotation(
            text="{}% CFR".format(cfr_percent),
            x=np.log10(max_cases),
            y=np.log10(max_cases / diag),
            xref="x",
            yref="y",
            ax=42,
            ay=0,
            width=80,
            align="left",
            font=dict(size=8),
        )

    fig.update_layout(legend_title_text="")
    fig.update_traces(textfont_size=7)
    fig.update_xaxes(nticks=6, title_text="Log of cumulative confirmed cases")
    fig.update_yaxes(nticks=6, title_text="Log of cumulative confirmed deaths")

    for country, cases, deaths in df_ECDC_pv_group_2:
        fig.add_annotation(
            text=country.replace("_", " "),
            x=np.log10(cases),
            y=np.log10(deaths),
            xref="x",
            yref="y",
            ax=-40,
            ay=-40,
            font=dict(size=8),
        )
    return fig


fig = ECDCfigure()
fig.show()
fig.write_image("ECDCdata.svg", width=700, height=400)

<IPython.core.display.Javascript object>

### Testing fallacies

https://de.statista.com/statistik/daten/studie/1107749/umfrage/labortest-fuer-das-coronavirus-covid-19-in-deutschland/

In [202]:
def RKIfigure():
    """
    Weekly lab tests in Germany, calendar weeks 11-20.

    Bars show the number of tests per week. A line on the secondary y-axis
    shows the week-over-week change in test volume, in percent.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    columns = [
        "Week",
        "Number of tests",
        "Number of positive tests",
        "positive tests as a percentage",
        "number of reporting laboratiories",
    ]
    # Each row: week, tests, positive tests, positive share in %, reporting labs.
    # The cumulative row "up to and including week 10" (124,716 tests,
    # 3,892 positive, 3.1 %, 90 labs) is deliberately left out.
    rows = [
        ["Week 11", 127_457, 7_582, 5.9, 114],
        ["Week 12", 348_619, 23_820, 6.8, 152],
        ["Week 13", 361_515, 31_414, 8.7, 151],
        ["Week 14", 408_348, 36_885, 9, 154],
        ["Week 15", 379_233, 30_728, 8.1, 163],
        ["Week 16", 330_027, 21_993, 6.7, 167],
        ["Week 17", 361_999, 18_052, 5, 177],
        ["Week 18", 325_259, 12_585, 3.9, 174],
        ["Week 19", 402_044, 10_746, 2.7, 181],
        ["Week 20", 425_842, 7_060, 1.7, 176],
    ]
    weekly = pd.DataFrame(rows, columns=columns)
    weekly["Test growth"] = weekly["Number of tests"].pct_change() * 100

    bar_trace = go.Bar(
        x=weekly["Week"],
        y=weekly["Number of tests"],
        showlegend=False,
        marker=dict(color=colors[0]),
    )
    growth_trace = go.Scatter(
        x=weekly["Week"],
        y=weekly["Test growth"],
        mode="markers+lines",
        showlegend=False,
        line=dict(color=colors[1]),
    )

    fig = make_subplots(specs=[[{"secondary_y": True}]])
    fig.add_trace(bar_trace, secondary_y=False)
    fig.add_trace(growth_trace, secondary_y=True)
    fig.update_layout(title="Number of tests performed in Germany per week")
    fig.update_yaxes(
        title_text="Number of weekly tests", showgrid=False, secondary_y=False
    )
    fig.update_yaxes(
        title_text="Change in percent", showgrid=False, zeroline=False, secondary_y=True
    )
    return fig


fig = RKIfigure()
fig.show()
fig.write_image("RKIdata.svg", width=700, height=400)

<IPython.core.display.Javascript object>

#### SIRD model including critical cases

In [141]:
def deriv(y, t, N, beta, gamma, rho, alpha, epsilon, zeta, eta, teta):
    """
    Right-hand side of the SICRD model for ``scipy.integrate.odeint``.

    ``y = (S, I, C, R, D)``: susceptible, infected, critically ill, recovered
    and dead. ``beta`` is a callable of time ``t``; critically ill people count
    with weight 0.5 toward new infections.

    Non-critical outflow of I:
    - a share ``epsilon`` of I moves to C at rate ``epsilon``;
    - of the rest, a share ``1 - alpha`` recovers at rate ``gamma`` and a
      share ``alpha`` dies at rate ``rho``.

    Critical cases:
    - up to the capacity ``beds(t, N)``, they recover at rate ``zeta``
      (share ``1 - teta``) or die at rate ``eta`` (share ``teta``);
    - cases above capacity move straight to D with coefficient 1.

    Returns the derivatives in the order ``(dS, dI, dC, dR, dD)``.
    """
    S, I, C, R, D = y
    capacity = beds(t, N)
    treated = min(capacity, C)  # critical cases that get a bed
    overflow = max(0, C - capacity)  # critical cases without a bed

    infection = beta(t) * S * (I + 0.5 * C) / N

    dSdt = -infection
    dIdt = (
        infection
        - ((1 - epsilon) * (((1 - alpha) * gamma * I) + (alpha * rho * I)))
        - (epsilon * I)
    )
    dCdt = (epsilon * I) - (
        ((1 - teta) * zeta * treated) + (teta * eta * treated) + overflow
    )
    dRdt = ((1 - epsilon) * (1 - alpha) * gamma * I) + ((1 - teta) * zeta * treated)
    dDdt = ((1 - epsilon) * alpha * rho * I) + (teta * eta * treated) + overflow
    return (dSdt, dIdt, dCdt, dRdt, dDdt)


def contact_rate(
    t, init_contact_rate=5, end_contact_rate=5, inflection_point=10, speed_of_change=1
):
    """
    Contact rate at time ``t`` (days).

    Logistic transition from ``init_contact_rate`` to ``end_contact_rate``,
    centred on ``inflection_point`` with steepness ``speed_of_change``. With
    the defaults both rates are 5, so the result is constant. The constants
    used to be hard-coded in the body; they are now keyword arguments with
    the same values.

    Works element-wise when ``t`` is a numpy array.
    """
    return (init_contact_rate - end_contact_rate) / (
        1 + np.exp(-speed_of_change * (-t + inflection_point))
    ) + end_contact_rate


def beta(t):
    """Transmission rate at time ``t``: ``contact_rate(t)`` scaled by a fixed
    per-contact infection probability of 0.25."""
    contacts = contact_rate(t)
    return contacts * 0.25  # per-contact probability of infection


def beds(t, N, beds_per_capita=30 / 1000, growth=0.0):
    """
    Bed capacity for critical cases at time ``t`` (days) for population ``N``.

    Capacity starts at ``beds_per_capita * N`` and grows linearly by
    ``growth`` times that base per day. With the defaults (30 beds per
    1000 people, no growth) capacity is constant. Both constants used to be
    hard-coded; they are now keyword arguments with the same values.

    NOTE(review): the ICU chart in this notebook lists 33.9 intensive-care beds
    per 100,000 inhabitants for Germany, far below 30 per 1000. The old
    "correct for Germany" comment looks wrong; confirm the intended value.
    """
    base_number = beds_per_capita * N
    return base_number + growth * t * base_number


def SICRD_model(params):
    """
    Solve the SICRD model over 100 days and plot compartments and bed capacity.

    Parameters
    ----------
    params : sequence of 7 floats
        ``[gamma, rho, alpha, epsilon, zeta, eta, teta]``: the model
        parameters ``deriv`` takes after ``N`` and ``beta``, in that order.

    Returns
    -------
    plotly.graph_objects.Figure
        Row 1: S, I, C, R and D. Row 2: bed capacity from ``beds`` compared
        with the critically ill and dead compartments.
    """
    # ``params`` used to be ignored in favour of module-level globals with the
    # same names. Unpack it so the argument actually drives the simulation.
    gamma, rho, alpha, epsilon, zeta, eta, teta = params

    N = 1  # population normalised to 1, so compartments are fractions
    init_infected = 0.000001
    y0 = (N - init_infected, init_infected, 0, 0, 0)  # S, I, C, R, D
    t = np.linspace(0, 100, 1000)  # time grid in days

    # NOTE(review): ``deriv`` is resolved globally. Later cells redefine it
    # with the 3-compartment SIR signature, so this only works while the
    # SICRD ``deriv`` is the current definition.
    S, I, C, R, D = odeint(
        deriv, y0, t, args=(N, beta, gamma, rho, alpha, epsilon, zeta, eta, teta)
    ).T
    SIRD_odeint = pd.DataFrame(
        {
            "Susceptible": S,
            "Infected": I,
            "Critically ill": C,
            "Recovered": R,
            "Dead": D,
        }
    )
    SIRD_odeint["total"] = SIRD_odeint.sum(axis=1)
    # Plot against the real solver times. linspace(0, 100, 1000) has a step of
    # 100 / 999, so the old ``index / 10`` drifted by up to 0.1 day.
    SIRD_odeint["days"] = t
    triage_df = pd.DataFrame({"beds": [beds(day, N) for day in t], "days": t})

    fig = make_subplots(
        rows=2,
        cols=1,
        x_title="Days",
        y_title="Percentage of total population",
        subplot_titles=("compartmental Model", "triage conditions"),
    )
    # Subfig 1: every compartment; the colour follows the column position.
    for i, col in enumerate(SIRD_odeint.columns):
        if col in ("total", "days"):
            continue
        fig.add_trace(
            go.Scatter(
                x=SIRD_odeint["days"],
                y=SIRD_odeint[col],
                mode="lines",
                name=col,
                line=dict(color=colors[i]),
                legendgroup="base model",
            ),
            row=1,
            col=1,
        )
    # Subfig 2: bed capacity against the critical and dead compartments.
    for name, x, y, color in (
        ("Beds", triage_df["days"], triage_df["beds"], colors[0]),
        ("Critically ill", SIRD_odeint["days"], SIRD_odeint["Critically ill"], colors[1]),
        ("Dead", SIRD_odeint["days"], SIRD_odeint["Dead"], colors[2]),
    ):
        fig.add_trace(
            go.Scatter(
                x=x,
                y=y,
                mode="lines",
                name=name,
                line=dict(color=color),
                legendgroup="beds",
            ),
            row=2,
            col=1,
        )

    fig.update_layout(title="Triage in the expanded SIR model")
    fig.update_xaxes(showgrid=False)
    fig.update_yaxes(showgrid=False)
    return fig


# SICRD parameters, in the order SICRD_model / deriv expect them.
# Rates are per day; "share" values are fractions.
gamma, rho, alpha, epsilon, zeta, eta, teta = (
    1.0 / 4.0,  # gamma: recovery rate of normal cases, 1 / (4 days until recovery)
    1 / 9.0,  # rho: death rate of fatal normal cases, 1 / (9 days until death)
    0.02,  # alpha: share of normal cases that die
    0.05,  # epsilon: share of infected turning critical (also the I -> C rate)
    1.0 / 5.0,  # zeta: recovery rate of treated critical cases, 1 / (5 days)
    1.0 / 7.0,  # eta: death rate of fatal treated critical cases, 1 / (7 days)
    0.1,  # teta: share of treated critical cases that die
)
params = [gamma, rho, alpha, epsilon, zeta, eta, teta]

fig = SICRD_model(params)
fig.show()
fig.write_image("SCRDmodel.svg", width=700, height=400)

<IPython.core.display.Javascript object>

In [123]:
t = np.linspace(0, 100, 1000)  # Grid of time points (in days)
# Use the grid itself as the x-axis. The old ``index / 10`` drifted from t by
# up to 0.1 day, because the grid step is 100 / 999.
contact_rates = pd.DataFrame({"cr": [contact_rate(day) for day in t], "days": t})
fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=contact_rates["days"],
        y=contact_rates["cr"],
        mode="lines",
        name="contact rate",
        line=dict(color=colors[0]),
    )
)
fig.update_layout(
    title="Contact rate over time",
    xaxis_title="Days",
    yaxis_title="Contact rate",
)
fig.show()

<IPython.core.display.Javascript object>

### Testing in the SIR model

In [137]:
def deriv(y, t, N, beta, gamma):
    """
    Right-hand side of the basic SIR model for ``scipy.integrate.odeint``.

    ``y = (S, I, R)``. ``beta`` is a callable giving the transmission rate at
    time ``t``, and ``gamma`` is the recovery rate.
    Returns ``(dS/dt, dI/dt, dR/dt)``.
    """
    S, I, R = y
    infection = beta(t) * S * I / N
    recovery = gamma * I
    return (-infection, infection - recovery, recovery)


def contact_rate(
    t, init_contact_rate=5, end_contact_rate=5, inflection_point=10, speed_of_change=1
):
    """
    Contact rate at time ``t`` (days).

    Logistic transition from ``init_contact_rate`` to ``end_contact_rate``,
    centred on ``inflection_point`` with steepness ``speed_of_change``. With
    the defaults both rates are 5, so the result is constant. The constants
    used to be hard-coded in the body; they are now keyword arguments with
    the same values.

    Works element-wise when ``t`` is a numpy array.
    """
    return (init_contact_rate - end_contact_rate) / (
        1 + np.exp(-speed_of_change * (-t + inflection_point))
    ) + end_contact_rate


def beta(t):
    """Transmission rate at time ``t``: ``contact_rate(t)`` scaled by a fixed
    per-contact infection probability of 0.25."""
    contacts = contact_rate(t)
    return contacts * 0.25  # per-contact probability of infection


def test_rate(t, N):
    base_number = (100 / 1000) * N  # correct for Germany
    s = 0.04  # scaling constant meaning 1 bed per 1000 persons added every 5 days
    return base_number + s * t * base_number


def SIRmodel(
    params, test_rate, base_model=True, random_testing=True, cluster_testing=True
):
    """
    Solve the basic SIR model over 100 days and plot S, I and R.

    Parameters
    ----------
    params : sequence
        ``params[0]`` is the recovery rate ``gamma`` (per day).
    test_rate, base_model, random_testing, cluster_testing
        Not used by this single-panel version. They are accepted so the
        signature matches the testing version of ``SIRmodel``.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    # Previously the module-level ``gamma`` was used and ``params`` ignored.
    gamma = params[0]

    N = 1  # population normalised to 1
    init_infected = 0.000001
    y0 = (N - init_infected, init_infected, 0)  # S, I, R
    t = np.linspace(0, 100, 1000)  # time grid in days

    S, I, R = odeint(deriv, y0, t, args=(N, beta, gamma)).T
    sir = pd.DataFrame({"Susceptible": S, "Infected": I, "Recovered": R})
    sir["total"] = sir.sum(axis=1)
    # Real solver times; ``index / 10`` drifted from t by up to 0.1 day.
    sir["days"] = t

    # A single panel, so a single subplot title (three were passed before).
    fig = make_subplots(
        rows=1,
        cols=1,
        x_title="Days",
        y_title="Percentage of total population",
        subplot_titles=("Basic Model",),
    )
    for i, col in enumerate(["Susceptible", "Infected", "Recovered"]):
        fig.add_trace(
            go.Scatter(
                x=sir["days"],
                y=sir[col],
                mode="lines",
                name=col,
                line=dict(color=colors[i]),
                legendgroup="Basic SIR model",
            ),
            row=1,
            col=1,
        )

    fig.update_layout(title="Testing in the SIR model")
    fig.update_xaxes(showgrid=False)
    fig.update_yaxes(showgrid=False)
    return fig


gamma = 1.0 / 4.0  # recovery rate of normal cases: 1 / (4 days until recovery)

# Run SIRmodel with gamma as its only parameter and export the figure as SVG.
params = [gamma]
fig = SIRmodel(params, test_rate)
fig.show()
fig.write_image("SIRmodel.svg", width=700, height=400)

<IPython.core.display.Javascript object>

In [145]:
# Basic reproduction number of the SIR model, R0 = beta / gamma:
# beta = 0.25 (prob_infection) * 5 (contact rate), and 1 / gamma = 4 days.
0.25 * 5 * 4

5.0

<IPython.core.display.Javascript object>

In [135]:
def deriv(y, t, N, beta, gamma):
    """
    Right-hand side of the basic SIR model for ``scipy.integrate.odeint``.

    ``y = (S, I, R)``. ``beta`` is a callable giving the transmission rate at
    time ``t``, and ``gamma`` is the recovery rate.
    Returns ``(dS/dt, dI/dt, dR/dt)``.
    """
    S, I, R = y
    infection = beta(t) * S * I / N
    recovery = gamma * I
    return (-infection, infection - recovery, recovery)


def contact_rate(
    t, init_contact_rate=5, end_contact_rate=5, inflection_point=10, speed_of_change=1
):
    """
    Contact rate at time ``t`` (days).

    Logistic transition from ``init_contact_rate`` to ``end_contact_rate``,
    centred on ``inflection_point`` with steepness ``speed_of_change``. With
    the defaults both rates are 5, so the result is constant. The constants
    used to be hard-coded in the body; they are now keyword arguments with
    the same values.

    Works element-wise when ``t`` is a numpy array.
    """
    return (init_contact_rate - end_contact_rate) / (
        1 + np.exp(-speed_of_change * (-t + inflection_point))
    ) + end_contact_rate


def beta(t):
    """Transmission rate at time ``t``: ``contact_rate(t)`` scaled by a fixed
    per-contact infection probability of 0.25."""
    contacts = contact_rate(t)
    return contacts * 0.25  # per-contact probability of infection


def test_rate(t, N):
    base_number = (100 / 1000) * N  # correct for Germany
    s = 0.04  # scaling constant meaning 1 bed per 1000 persons added every 5 days
    return base_number + s * t * base_number


def SIRmodel(
    params,
    test_rate,
    base_model=True,
    random_testing=True,
    cluster_testing=True,
    cluster_hit_rate=0.6,
):
    """
    Simulate the basic SIR model and compare random with cluster testing.

    - Row 1: the S/I/R compartments.
    - Row 2, random testing: with ``rate = test_rate(days, N)``, positives are
      ``Infected * rate`` and negatives are ``(Susceptible + Recovered) * rate``.
      The implied prevalence is therefore ``Infected / total``.
    - Row 3, cluster testing: positives are
      ``min(rate, Infected) * cluster_hit_rate``; every other test counts as
      negative.

    Parameters
    ----------
    params : sequence
        ``params[0]`` is the recovery rate ``gamma`` (per day).
    test_rate : callable ``(days, N)``
        Must work element-wise on a pandas Series of days.
    base_model, random_testing, cluster_testing : bool, default True
        Draw rows 1, 2 and 3 respectively. The grid always has three rows.
    cluster_hit_rate : float, default 0.6
        Multiplier applied to ``min(rate, Infected)`` to get cluster positives.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    # Previously the module-level ``gamma`` was used and ``params`` ignored.
    gamma = params[0]

    N = 1  # population normalised to 1
    init_infected = 0.000001
    y0 = (N - init_infected, init_infected, 0)  # S, I, R
    t = np.linspace(0, 100, 1000)  # time grid in days

    S, I, R = odeint(deriv, y0, t, args=(N, beta, gamma)).T
    sir = pd.DataFrame({"Susceptible": S, "Infected": I, "Recovered": R})
    sir["total"] = sir.sum(axis=1)
    # Real solver times; ``index / 10`` drifted from t by up to 0.1 day.
    sir["days"] = t

    tested = test_rate(sir["days"], N)

    # Column order matters below: trace colours are picked by column position.
    # Random testing
    sir["Tested overall - random"] = tested * sir["total"]
    sir["Tested negative - random"] = (
        sir["Susceptible"] * tested + sir["Recovered"] * tested
    )
    sir["Tested positive - random"] = sir["Infected"] * tested
    sir["Implied prevalence - random"] = sir["Tested positive - random"] / (
        sir["Tested positive - random"] + sir["Tested negative - random"]
    )
    # Cluster testing (vectorised; this was a row-wise DataFrame.apply)
    cluster_positive = np.minimum(tested, sir["Infected"]) * cluster_hit_rate
    sir["Tested positive - cluster"] = cluster_positive
    sir["Tested negative - cluster"] = tested * sir["total"] - cluster_positive
    sir["Tested overall - cluster"] = tested * sir["total"]
    sir["Implied prevalence - cluster"] = sir["Tested positive - cluster"] / (
        sir["Tested positive - cluster"] + sir["Tested negative - cluster"]
    )

    # (enabled, row, legend group, columns drawn in that row)
    panels = [
        (base_model, 1, "Basic SIR model", ["Susceptible", "Infected", "Recovered"]),
        (
            random_testing,
            2,
            "Random testing",
            [
                "Tested overall - random",
                "Tested positive - random",
                "Tested negative - random",
                "Implied prevalence - random",
                "Infected",
            ],
        ),
        (
            cluster_testing,
            3,
            "Cluster testing",
            [
                "Tested overall - cluster",
                "Tested positive - cluster",
                "Tested negative - cluster",
                "Implied prevalence - cluster",
                "Infected",
            ],
        ),
    ]

    fig = make_subplots(
        rows=3,
        cols=1,
        x_title="Days",
        y_title="Percentage of total population",
        subplot_titles=("Basic Model", "Random testing", "Cluster testing"),
    )
    for i, col in enumerate(sir.columns):
        for enabled, row, group, members in panels:
            if enabled and col in members:
                fig.add_trace(
                    go.Scatter(
                        x=sir["days"],
                        y=sir[col],
                        mode="lines",
                        name=col,
                        line=dict(color=colors[i]),
                        legendgroup=group,
                    ),
                    row=row,
                    col=1,
                )
    # One invisible "total" trace per testing panel, presumably so the y-axis
    # range covers the whole population. These used to sit inside the column
    # loop, so they were added once per column (13 copies per panel).
    for enabled, row, _, _ in panels[1:]:
        if enabled:
            fig.add_trace(
                go.Scatter(
                    x=sir["days"],
                    y=sir["total"],
                    mode="none",
                    name="total",
                    showlegend=False,
                ),
                row=row,
                col=1,
            )

    fig.update_layout(title="Testing in the SIR model")
    fig.update_xaxes(showgrid=False)
    fig.update_yaxes(showgrid=False)
    return fig


gamma = 1.0 / 4.0  # recovery rate of normal cases: 1 / (4 days until recovery)

# Run SIRmodel with gamma as its only parameter and export the figure as SVG.
params = [gamma]
fig = SIRmodel(params, test_rate)
fig.show()
fig.write_image("SIRtestmodel.svg", width=700, height=500)

<IPython.core.display.Javascript object>