# Supplementary codes for:
## Potential severity and control of Omicron waves depending on pre-existing immunity and immune evasion

## Ferenc A. Bartha, Péter Boldog, Tamás Tekeli, Zsolt Vizi, Attila Dénes and Gergely Röst

---

In [None]:
# Switch to True when running on Google Colab: saved figures are then also
# offered for download through the Colab `files` helper.
use_colab = False

if use_colab:
    # google.colab is presumably only importable inside a Colab runtime
    from google.colab import files


from typing import Union

from ipywidgets import fixed, interact
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from matplotlib.patches import Rectangle
from scipy.integrate import odeint

## Parametrization

### Epidemiological Parameters

In [None]:
# --- Delta variant ---
# R0 of Delta: reproduction number in a fully susceptible population with
# no interventions in place.
r0_delta_glob: float = 6.0

In [None]:
# --- Observations from South Africa (the "laboratory" country) ---

# fraction of the South African population with pre-existing immunity
p_south_africa: float = 0.85

# observed ratio of the effective reproduction numbers:
#   R_t^{Omicron} / R_t^{Delta}
ratio_omicron_per_delta_south_africa: int = 4

In [None]:
# Assumptions on Omicron

# latent period (days); the model splits it into a 2-stage chain L1 -> L2
omicron_latent_period = 2.5

# infectious period (days); the model splits it into a 4-stage chain
# I1 -> I2 -> I3 -> I4
omicron_infectious_period = 5.

# hospital evasion with pre-existing immunity (probability of evasion);
# 1 - omicron_hospital_evasion serves as the default relative severity weight
# of the Omicron-only susceptible ("_p") compartments
omicron_hospital_evasion = 0.85

In [None]:
# Model rates derived from the assumptions above. The ODE model splits each
# period into equal stages, so its per-stage rates are multiples of these.
alpha_glob = 1.0 / omicron_latent_period      # alpha: inverse mean latent period
gamma_glob = 1.0 / omicron_infectious_period  # gamma: inverse mean infectious period

### Technical Parameters

In [None]:
# Parameter grid: immune evasion (e) and local pre-existing immunity (p_loc),
# each sampled at 100 evenly spaced points on [0, 1].
e_vals = np.linspace(0.0, 1.0, num=100)
p_loc_vals = np.linspace(0.0, 1.0, num=100)

In [None]:
# ODE solver settings

# integration time grid: 5000 evenly spaced points over [0, 500]
t_glob = np.linspace(0, 500, num=5000)

# State-vector names, in the order used by the ODE model.
# "_s": susceptible to both Omicron and Delta; "_p": susceptible to Omicron only.
comps = (
    ["s", "l1_s", "l2_s", "i1_s", "i2_s", "i3_s", "i4_s", "r_s"]
    + ["p", "l1_p", "l2_p", "i1_p", "i2_p", "i3_p", "i4_p", "r_p"]
)

In [None]:
# Figure settings

figures_dpi = 250                      # resolution of saved figures
figures_autodownload = True            # download saved files (only when use_colab)

# Percentage tick formatting (False -> plain fractions)
p_ticks_percentage = False             # pre-existing immunity axis
e_ticks_percentage = False             # immune evasion axis
npi_percentage = False                 # NPI axis
population_percentage = False          # population-fraction axes of timeplots

figures_timeplot_title = False         # keep titles on saved timeplots
figures_timeplot_title_bottom = False  # draw timeplot titles below the axes

## Methods

### Contour relation: pre-existing immunity vs immune evasion

In [None]:
def r0_omicron_from_contour_relation(
        e: Union[float, np.ndarray],
        p: float = p_south_africa,
        r0_delta: float = r0_delta_glob,
        ratio_omicron_per_delta: float = ratio_omicron_per_delta_south_africa
    ) -> Union[float, np.ndarray]:
    """
    Approximates the basic reproduction number (R0) of the Omicron variant.

    Solving the observed ratio
        R_t^{Omicron} / R_t^{Delta} = R0_O * (1 - p + e * p) / (R0_D * (1 - p))
    for R0_O gives
        R0_O = ratio * R0_D / (1 + e * p / (1 - p)).
    For p == 1 the correction term e * p / (1 - p) is treated as 0, so
    ratio * r0_delta is returned.

    :param Union[float, np.ndarray] e: immune evasion of Omicron, i.e. ratio of individuals with
                  immunity against Delta who are susceptible to Omicron
    :param float p: pre-existing immunized fraction of the population
    :param float r0_delta: basic reproduction number of the Delta variant
    :param float ratio_omicron_per_delta: ratio of effective reproduction numbers
                                        for Omicron and Delta variants
    :return Union[float, np.ndarray]: basic reproduction number of the Omicron variant;
        an array of the same shape as ``e`` when ``e`` is an array
    """
    num = r0_delta * ratio_omicron_per_delta
    # p == 1 would divide by zero; the correction term is dropped in that case
    denom = 1 + (0 if p == 1 else e * p / (1 - p))

    return num / denom

### Level of non-pharmaceutical interventions (NPI) required to suppress an epidemic

In [None]:
def calculate_suppressing_npi(
        r0: float,
        p: float,
        goal: float = 1
    ) -> float:
    """
    Contact-rate reduction needed to bring the reproduction number down to <goal>.

    With NPI level ``npi`` the effective reproduction number is
    r0 * (1 - p) * (1 - npi); setting it equal to ``goal`` yields
    npi = 1 - goal / (r0 * (1 - p)), which is clipped at 0 when the
    reproduction number is already at or below the goal.

    :param float r0: basic reproduction number
    :param float p: pre-existing immunity (fraction of the population)
    :param float goal: desired reproduction number (<= 1)
    :return float: NPI, the fraction by which transmission must be reduced
    """
    # a fully immune population needs no intervention (and would divide by zero)
    if p == 1:
        return 0

    # share of transmission that may remain; values >= 1 mean "no NPI needed"
    allowed_fraction = goal / (r0 * (1 - p))
    return 1 - np.min((1.0, allowed_fraction))

### Compartmental ODE modeling of the Omicron variant

In [None]:
def omicron_model(
        xs: np.ndarray,
        ts: np.ndarray,
        params: dict
    ) -> np.ndarray:
    """
    Right-hand side of the SL_2I_4R model with dual immunity (odeint signature).

    Two parallel chains S -> L1 -> L2 -> I1 -> I2 -> I3 -> I4 -> R share one
    force of infection driven by all infectious compartments:
      * "_s" chain: individuals susceptible to both Omicron and Delta,
      * "_p" chain: individuals susceptible to Omicron but immune to Delta.
    Each latent stage is left at rate 2 * alpha, each infectious stage at
    rate 4 * gamma; transmission is scaled down by the NPI factor (1 - npi).

    :param np.ndarray xs: current state, 16 values (8 per chain, "_s" first)
    :param np.ndarray ts: time value (not used)
    :param dict params: must contain "alpha", "beta", "gamma" and "npi"
    :return np.ndarray: time derivatives, in the same order as ``xs``
    """
    latent_rate = 2 * params["alpha"]
    infectious_rate = 4 * params["gamma"]
    contact = params["beta"] * (1 - params["npi"])

    (s, l1_s, l2_s, i1_s, i2_s, i3_s, i4_s, _r_s,
     p, l1_p, l2_p, i1_p, i2_p, i3_p, i4_p, _r_p) = xs

    # everybody infectious, from both chains, contributes to transmission
    infected = i1_s + i2_s + i3_s + i4_s + i1_p + i2_p + i3_p + i4_p

    def chain_derivatives(susceptible, l1, l2, i1, i2, i3, i4):
        # derivatives of one S-L1-L2-I1..I4-R chain
        new_infections = contact * susceptible * infected
        return [
            -new_infections,
            new_infections - latent_rate * l1,
            latent_rate * l1 - latent_rate * l2,
            latent_rate * l2 - infectious_rate * i1,
            infectious_rate * i1 - infectious_rate * i2,
            infectious_rate * i2 - infectious_rate * i3,
            infectious_rate * i3 - infectious_rate * i4,
            infectious_rate * i4,
        ]

    return np.array(
        chain_derivatives(s, l1_s, l2_s, i1_s, i2_s, i3_s, i4_s)
        + chain_derivatives(p, l1_p, l2_p, i1_p, i2_p, i3_p, i4_p)
    )

In [None]:
def calculate_beta(
        r0: float,
        params: dict
    ) -> float:
    """
    Transmission rate beta corresponding to a given R0.

    The mean infectious period of the model is 1 / gamma, so in a fully
    susceptible population R0 = beta / gamma, i.e. beta = R0 * gamma.

    :param float r0: basic reproduction number
    :param dict params: model parameters; only params["gamma"] is read
    :return float: transmission rate beta
    """
    gamma = params["gamma"]
    return r0 * gamma

In [None]:
def solve_omicron_model(
        r0_omicron: float,
        e: Union[float, np.ndarray],
        p_loc: float,
        npi_loc: float,
        initial_l1: float,
        t: np.ndarray = t_glob
    ) -> np.ndarray:
    """
    Numerically solve the Omicron model for one model country.

    Initial state: the "_s" susceptible compartment holds 1 - p_loc, the "_p"
    susceptible compartment holds e * p_loc (immune to Delta, but Omicron
    evades it), ``initial_l1`` is split equally between L1_s and L1_p, and
    every other compartment starts empty.

    :param float r0_omicron: basic reproduction number of the Omicron variant
    :param Union[float, np.ndarray] e: immune evasion of Omicron; must be a scalar
        here, since it becomes a single initial value
    :param float p_loc: pre-existing immunity in the model country
    :param float npi_loc: npi in effect in the model country
    :param float initial_l1: initially infected (L1_s + L1_p, symmetric)
    :param np.ndarray t: timespan and resolution of the numerical solution
    :return np.ndarray: odeint solution of shape (len(t), 16), columns ordered as ``comps``
    """
    n_stages = 8  # compartments per chain: S, L1, L2, I1, I2, I3, I4, R

    both_susceptible = [0.0] * n_stages
    both_susceptible[0] = 1 - p_loc
    both_susceptible[1] = initial_l1 / 2.

    omicron_only = [0.0] * n_stages
    omicron_only[0] = e * p_loc
    omicron_only[1] = initial_l1 / 2.

    params = {
        "alpha": alpha_glob,
        "gamma": gamma_glob,
        "npi": npi_loc,
    }
    params["beta"] = calculate_beta(r0=r0_omicron, params=params)

    return odeint(
        func=omicron_model,
        y0=both_susceptible + omicron_only,
        t=t,
        args=(params,),
    )

In [None]:
def calculate_peak_and_final_size(
        sol: np.ndarray,
        severity: float = 1,
        relative_severity: float = (1 - omicron_hospital_evasion)
    ) -> list:
    """
    Weighted peak prevalence and final size of an Omicron model solution.

    The "_p" compartments (Omicron-only susceptible chain) are multiplied by
    ``relative_severity``; the combined total is multiplied by ``severity``.

    :param np.ndarray sol: odeint solution, columns ordered as in ``comps``
    :param float severity: common weight of _s and _p compartments
    :param float relative_severity: additional multiplicative weight of _p compartments
    :return list: [peak size, final size] - maximum over time of the weighted
        infectious total, and the weighted recovered total at the last time point
    """
    column = {name: sol[:, idx] for idx, name in enumerate(comps)}

    def weighted_total(stage_names):
        # sum the listed stages per chain, then combine with the weights
        s_part = sum(column[name + "_s"] for name in stage_names)
        p_part = sum(column[name + "_p"] for name in stage_names)
        return severity * (s_part + relative_severity * p_part)

    infectious = weighted_total(["i1", "i2", "i3", "i4"])
    recovered = weighted_total(["r"])

    return [np.max(infectious), recovered[-1]]

## Results

### Contours: R0 of Omicron vs immune evasion

#### Code

In [None]:
def plot_r0_omicron_vs_immune_evasion(
        es: np.ndarray,
        ps: Union[np.ndarray, list],
        save_this_figure: bool = False
    ) -> None:
    """
    Plot R0 of Omicron against its immune evasion, one curve per assumed level
    of pre-existing immunity in South Africa.

    :param np.ndarray es: immune evasion values for the horizontal axis (resolution)
    :param Union[np.ndarray, list] ps: pre-existing immunity values (one curve each)
    :param bool save_this_figure: if True the figure is saved to "contourRelation.pdf"
        (and downloaded on Colab when enabled); otherwise a title is added
    :return None
    """
    # the largest attainable R0 of Omicron (no evasion correction)
    r0_max = r0_delta_glob * ratio_omicron_per_delta_south_africa

    plt.rcParams.update({'font.size': 10})

    # one shade of the reversed "bone" colormap per curve, skipping the extremes
    palette = plt.cm.bone_r(np.linspace(0, 1, len(ps) + 3))[2:-1]

    _, ax = plt.subplots(
        dpi=figures_dpi if save_this_figure else 180,
        figsize=(5, 3)
    )

    for p, color in zip(ps, palette):
        if p_ticks_percentage:
            curve_label = str(int(p * 100)) + '%'
        else:
            curve_label = str(round(p, 2))

        curve = r0_omicron_from_contour_relation(
            e=es,
            p=p,
            r0_delta=r0_delta_glob,
            ratio_omicron_per_delta=ratio_omicron_per_delta_south_africa
        )
        ax.plot(es, curve, label=curve_label, color=color)

    lgd = ax.legend(loc='right', bbox_to_anchor=(1.6, 0.5),
                    title='Pre-existing immunity\nin South Africa\n(fraction of population)')

    ax.set_xlim(0, 1)
    ax.set_ylim(0, r0_max)
    ax.set_yticks(range(0, int(r0_max) + 1, 4))

    ax.set_xlabel('immune evasion')
    ax.set_ylabel('$R_0$ of Omicron')

    if e_ticks_percentage:
        tick_positions = [0, 0.25, 0.5, 0.75, 1]
        tick_labels = ["0%", "25%", "50%", "75%", "100%"]
        ax.xaxis.set_major_locator(ticker.FixedLocator(tick_positions))
        ax.xaxis.set_major_formatter(ticker.FixedFormatter(tick_labels))

    if save_this_figure:
        file_name = "contourRelation.pdf"
        plt.savefig(file_name, dpi=figures_dpi,
                    bbox_extra_artists=(lgd,), bbox_inches='tight')
        if figures_autodownload and use_colab:
            files.download(file_name)
    else:
        ax.set_title('Immune Evasion vs $R_0$ of Omicron')

In [None]:
def heatmap_r0_omicron_vs_immune_evasion(
        es: np.ndarray,
        ps: Union[np.ndarray, list],
        r0s: list,
        add_r0_delta = False,
        add_frame = None,
        save_this_figure: bool = False
    ) -> None:
    """
    Heatmap of R0 of Omicron with respect to pre-existing immunity in South
    Africa (horizontal axis) and immune evasion (vertical axis).

    :param np.ndarray es: immune evasion values for the vertical axis (resolution)
    :param Union[np.ndarray, list] ps: pre-existing immunity values for the horizontal axis (resolution)
    :param list r0s: R0 levels drawn and labelled as contour lines
    :param bool add_r0_delta: if True, a wide translucent line marks the contour where
                              R0 of Omicron equals R0 of Delta
    :param dict add_frame: None, or a dict whose "frame_p" and "frame_e" entries give the
                           (lower, upper) extent of a highlighted rectangle
    :param bool save_this_figure: if True the figure is saved to "contourRelationHeatmap.pdf";
                                  otherwise a title is added
    :return None
    """
    # R0 of Omicron on the grid: one row per e value, one column per p value
    r0_grid = [
        [r0_omicron_from_contour_relation(e=e, p=p_sa) for p_sa in ps]
        for e in es
    ]

    # filled bands of width 1 from 0 up to the largest attainable R0
    r0_ceiling = np.ceil(r0_delta_glob * ratio_omicron_per_delta_south_africa)
    band_levels = np.arange(0, r0_ceiling + 1, 1)
    band_colors = plt.cm.bone_r(np.linspace(0, 1, len(band_levels) + 32))[2:-30]

    plt.rcParams.update({'font.size': 10})

    _, ax = plt.subplots(
        dpi=figures_dpi if save_this_figure else 200,
        figsize=(4, 4)
    )

    ax.contourf(ps, es, r0_grid,
                levels=band_levels, colors=band_colors, alpha=1)

    if add_r0_delta:
        ax.contour(ps, es, r0_grid, [r0_delta_glob],
                   colors='#6a0033', linewidths=7,
                   alpha=0.2, linestyles='solid')

    contour_lines = ax.contour(ps, es, r0_grid, r0s,
                               colors='#2a0033', linewidths=1, alpha=0.8)
    ax.clabel(contour_lines, inline=True, fmt=str, fontsize=7)

    ax.set_ylabel("immune evasion")
    ax.set_xlabel("pre-existing immunity in South Africa")

    # optional highlighting frame
    if add_frame is not None:
        p_range = add_frame["frame_p"]
        e_range = add_frame["frame_e"]
        ax.add_patch(Rectangle(
            (p_range[0], e_range[0]),
            p_range[1] - p_range[0],
            e_range[1] - e_range[0],
            fc='none',
            ec='#5064a0',
            lw=5,
            alpha=0.5))

    if p_ticks_percentage:
        ax.xaxis.set_major_locator(ticker.FixedLocator([0.4, 0.6, 0.8, 0.99]))
        ax.xaxis.set_major_formatter(ticker.FixedFormatter(["40%", "60%", "80%", "100%"]))

    if e_ticks_percentage:
        ax.yaxis.set_major_locator(ticker.FixedLocator([0, 0.2, 0.4, 0.6, 0.8, 0.99]))
        ax.yaxis.set_major_formatter(
            ticker.FixedFormatter(["0%", "20%", "40%", "60%", "80%", "100%"]))

    ax.margins(0)
    plt.tight_layout()

    if save_this_figure:
        file_name = "contourRelationHeatmap.pdf"
        plt.savefig(file_name, dpi=figures_dpi)

        if figures_autodownload and use_colab:
            files.download(file_name)
    else:
        ax.set_title('$R_0$ of Omicron')

#### Figure

In [None]:
# Interactive R0-vs-evasion figure; tick "production" to save it to disk.
interact(
    lambda production: plot_r0_omicron_vs_immune_evasion(
        es=e_vals,
        ps=[0.75, 0.8, 0.85, 0.9, 0.95],
        save_this_figure=production,
    ),
    production=False,
)

In [None]:
# Rectangle drawn on the heatmap when "add_frame" is ticked
frame_to_add = {
    "frame_p": [0.7, 0.95],   # pre-existing immunity extent
    "frame_e": [0.35, 0.97],  # immune evasion extent
}

interact(
    lambda add_r0_delta, add_frame, production: heatmap_r0_omicron_vs_immune_evasion(
        es=np.linspace(0, 1, 100),
        ps=np.linspace(0.4, 0.99, 100),
        r0s=[3, 6, 9, 12, 15, 18, 20, 22],
        add_r0_delta=add_r0_delta,
        add_frame=frame_to_add if add_frame else None,
        save_this_figure=production,
    ),
    add_r0_delta=True,
    add_frame=True,
    production=False,
)

### Level of non-pharmaceutical interventions (NPI) required to suppress Omicron, compared with Delta

#### Code

In [None]:
def plot_omicron_suppressing_npi(
        ps: Union[np.ndarray, list],
        es: Union[np.ndarray, list],
        p_sa: float = p_south_africa,
        r0_delta: float = r0_delta_glob,
        save_this_figure: bool = True
    ) -> None:
    """
    Plot the NPI levels suppressing Omicron, compared to the NPI suppressing Delta.

    For each immune evasion value e in ``es``, R0 of Omicron is inferred from
    the South African observation (pre-existing immunity ``p_sa``) using the
    same ``r0_delta`` as the Delta reference curve. In a model country with
    pre-existing immunity p, only the share p * (1 - e) is protected against
    Omicron. The dashed red curve is the NPI that suppresses Delta there.

    :param Union[np.ndarray, list] ps: pre-existing immunity values for the horizontal axis (resolution)
    :param Union[np.ndarray, list] es: immune evasion values (one curve each)
    :param float p_sa: pre-existing immunity in South Africa
    :param float r0_delta: R0 of the Delta variant
    :param bool save_this_figure: if True the figure is saved to "npiRequirementPlot.pdf",
        otherwise a title is added. NOTE: unlike the other plot helpers this defaults to True.
    :return None
    """

    # compute the npi suppressing Delta for all model locations (ps)
    npi_suppressing_delta = np.array([
        calculate_suppressing_npi(r0=r0_delta, p=p)
        for p in ps
    ])

    # ensure proper fontsize
    plt.rcParams.update({'font.size': 10})

    # setup the coloring scheme
    colors = plt.cm.bone_r(np.linspace(0, 1, len(es) + 5))[2:-3]

    # setup the figure
    plt.figure(
      dpi=figures_dpi if save_this_figure else 150,
      figsize=(5, 3)
    )

    # plot a curve for each e \in es
    for idx, e in enumerate(es):
        # R0 of Omicron from the South African observation; pass r0_delta so the
        # Omicron curves and the Delta reference rest on the same R0 of Delta
        r0_omicron = r0_omicron_from_contour_relation(
          e=e,
          p=p_sa,
          r0_delta=r0_delta
        )
        # compute the npi suppressing Omicron for all model locations (ps);
        # the evaded share e of the immune population is susceptible again
        npi_suppressing_omicron = np.array([
            calculate_suppressing_npi(
                r0=r0_omicron,
                p=p * (1 - e)
            )
            for p in ps
        ])
        plt.plot(ps, npi_suppressing_omicron,
                 label=str(round(e, 1)) if not e_ticks_percentage else (str(int(e * 100)) + '%'),
                 color=colors[idx])

    # plot a curve for Delta suppression (raw string: "\D" is not a valid escape)
    plt.plot(ps, npi_suppressing_delta, 'r--',
             linewidth=3,
             label=r"suppression of $\Delta$")

    lgd = plt.legend(loc='right', bbox_to_anchor=(1.55, 0.5),
                     title='Immune evasion\nof the Omicron variant')

    plt.xlim(ps[0], ps[-1])
    plt.ylim(0, 1)

    plt.xlabel('pre-existing immunity')
    plt.ylabel('reduction of transmission by NPIs')

    ax = plt.gca()

    # label axes with %
    if p_ticks_percentage:
        p_positions = [0.4, 0.6, 0.8, 0.99]
        p_labels = ["40%", "60%", "80%", "100%"]

        ax.xaxis.set_major_locator(ticker.FixedLocator(p_positions))
        ax.xaxis.set_major_formatter(ticker.FixedFormatter(p_labels))

    if npi_percentage:
        npi_positions = [0, 0.25, 0.5, 0.75, 1]
        npi_labels = ["0%", "25%", "50%", "75%", "100%"]
        ax.yaxis.set_major_locator(ticker.FixedLocator(npi_positions))
        ax.yaxis.set_major_formatter(ticker.FixedFormatter(npi_labels))

    if not save_this_figure:
        plt.title('NPI requirement for controlling Omicron')
    else:
        my_file_name = "npiRequirementPlot.pdf"
        plt.savefig(my_file_name, dpi=figures_dpi,
                    bbox_extra_artists=(lgd,), bbox_inches='tight')

        if figures_autodownload and use_colab:
            files.download(my_file_name)

#### Figures

In [None]:
# Interactive NPI-requirement figure; the slider varies South African immunity.
# NOTE(review): with a float step, np.arange(0.2, 0.8, 0.1) may also yield an
# element ~0.8 because of rounding - confirm the intended set of curves.
interact(
    lambda p_sa=p_south_africa, production=False: plot_omicron_suppressing_npi(
        ps=np.linspace(0.4, 1, 1000),
        es=np.arange(0.2, 0.8, 0.1),
        p_sa=p_sa,
        save_this_figure=production,
    ),
    p_sa=(0, 1, 0.01),
    production=False,
)

### Timeplots of the Omicron model

#### Code

In [None]:
def plot_omicron_model_on_axes(
        ax,
        p_loc: float,
        e: Union[list, np.ndarray],
        t: np.ndarray,
        use_npi_loc: bool = False,
        npi_loc: float = 0,
        y_range: int = 100,
        add_title: bool = True,
        title_prefix: str = '',
        title_r0: bool = False
    ) -> None:
    """
    Plot the Omicron model for one model country on the given axes.

    The main axes show the stacked infectious prevalence of the "_s" and "_p"
    chains; an inset shows how the whole population splits into recovered,
    still-susceptible/in-progress and protected groups over time. Values are
    population fractions; ``population_percentage`` only relabels the ticks.

    :param ax: axes of the figure
    :param float p_loc: pre-existing immunity of the model country
    :param Union[list, np.ndarray] e: immune evasion ratio of Omicron (used as a scalar)
    :param np.ndarray t: time range and resolution
    :param bool use_npi_loc: if False then Delta suppressing npi is assumed
    :param float npi_loc: npi in effect in the model country (ignored unless use_npi_loc)
    :param int y_range: upper y-limit of the main plot, in population fraction
        (NOTE(review): the default 100 is far above the data range; callers pass it)
    :param bool add_title: enable / disable title
    :param str title_prefix: prepends title
    :param bool title_r0: adds R_0, R_t of Omicron to title
    :return: None
    """
    percent_suffix = " (%)" if population_percentage else ""
    positions = [0, 0.2, 0.4, 0.6, 0.8, 0.99]
    labels = ["0%", "20%", "40%", "60%", "80%", "100%"]

    # local npi: by default the level that just suppresses Delta
    if not use_npi_loc:
        npi_loc = calculate_suppressing_npi(r0=r0_delta_glob, p=p_loc)

    # r0 omicron (from the South African observation)
    r0_omicron = r0_omicron_from_contour_relation(e=e)

    # Get model solution
    sol = solve_omicron_model(
        r0_omicron=r0_omicron,
        e=e,
        p_loc=p_loc,
        npi_loc=npi_loc,
        initial_l1=0.00001,
        t=t
    )
    sol_d = {name: sol[:, idx] for idx, name in enumerate(comps)}

    # get the timeseries for compartments
    s   = sol_d["s"]
    l_s = sol_d["l1_s"] + sol_d["l2_s"]
    i_s = sol_d["i1_s"] + sol_d["i2_s"] + sol_d["i3_s"] + sol_d["i4_s"]
    r_s = sol_d["r_s"]

    p   = sol_d["p"]
    l_p = sol_d["l1_p"] + sol_d["l2_p"]
    i_p = sol_d["i1_p"] + sol_d["i2_p"] + sol_d["i3_p"] + sol_d["i4_p"]
    r_p = sol_d["r_p"]

    # main plot: stacked infectious prevalence of both chains
    color_map = ["#ff6666", "#ffaaaa"]
    ax.stackplot(t, i_s, i_p, colors=color_map)

    ax.set_xlabel("time (days)")
    # the data are fractions; only claim "%" when the ticks are relabelled
    ax.set_ylabel("infected" + percent_suffix)

    if population_percentage:
        ax.yaxis.set_major_locator(ticker.FixedLocator(positions))
        ax.yaxis.set_major_formatter(ticker.FixedFormatter(labels))

    title = title_prefix + "p=" + str(p_loc) + ", e=" + str(e) + ", npi=" + "{:.2f}".format(npi_loc)

    if title_r0:
        # R_{t^*}: Omicron's reproduction number under local NPI and immunity
        title = title + ", $R_0$=" + "{:.2f}".format(r0_omicron) + \
                ", $R_{t^*}$=" + "{:.2f}".format(r0_omicron * (1 - npi_loc) * (1 - p_loc + e * p_loc))
    if add_title:
        if figures_timeplot_title_bottom:
            ax.set_title(title, y=-.2)
        else:
            ax.set_title(title)

    ax.set_xlim([0, t[-1]])
    ax.set_ylim([0, y_range])

    # create the inset
    left, bottom, width, height = [0.55, 0.55, 0.40, 0.40]
    ax2 = ax.inset_axes([left, bottom, width, height])

    color_map_inset = color_map + ["#ffffff", "#dfdfdf", "#d0d0d0"]

    # recovered (_s, _p), everyone else of the _s chain plus latent/infectious _p,
    # remaining _p susceptibles, and the share (1 - e) * p_loc protected from Omicron
    ax2.stackplot(
        t,
        r_s,
        r_p,
        (s + l_s + i_s + l_p + i_p),
        p,
        np.full(r_s.shape, (1 - e) * p_loc),
        colors=color_map_inset)

    ax2.set_ylabel("affected" + percent_suffix)

    if population_percentage:
        ax2.yaxis.set_major_locator(ticker.FixedLocator(positions))
        ax2.yaxis.set_major_formatter(ticker.FixedFormatter(labels))

    ax2.set_xlim([0, t[-1]])
    ax2.set_ylim([0, 1])

In [None]:
def plot_omicron_model(
        p_loc: Union[int, float] = 0.5,
        e: Union[float, np.ndarray] = 0.5,
        use_npi_loc = False,
        npi_loc = 0,
        t_end: Union[int, float] = 200,
        y_range: Union[int, float] = 20,
        title_prefix: str = '',
        title_r0: bool = True,
        save_this_figure: bool = False
    ) -> None:
    """
    Single timeplot of the Omicron model for one model country.

    The model is solved on 200 evenly spaced time points in [0, t_end] and drawn
    with plot_omicron_model_on_axes on a new 4x4 inch figure.

    :param Union[int, float] p_loc: pre-existing immunity of the model country
    :param Union[float, np.ndarray] e: immune evasion ratio of Omicron
    :param bool use_npi_loc: if False then Delta suppressing npi is assumed
    :param float npi_loc: npi in effect in the model country
    :param Union[int, float] t_end: final simulation time
    :param Union[int, float] y_range: sets the y-range of the main plot (I-plot)
    :param str title_prefix: prepends title
    :param bool title_r0: adds R_0, R_t of Omicron to title
    :param bool save_this_figure: if True the figure is saved to "singleTimeplot.pdf"
    :return: None
    """
    figure = plt.figure(
        dpi=figures_dpi if save_this_figure else 150,
        figsize=(4, 4))
    axes = plt.gca()

    plt.rcParams.update({'font.size': 9})

    time_grid = np.linspace(0, t_end, 200)
    plot_omicron_model_on_axes(
        ax=axes,
        p_loc=p_loc,
        e=e,
        use_npi_loc=use_npi_loc,
        npi_loc=npi_loc,
        t=time_grid,
        y_range=y_range,
        title_prefix=title_prefix,
        title_r0=title_r0,
    )

    figure.tight_layout()

    if not save_this_figure:
        return

    file_name = "singleTimeplot.pdf"
    plt.savefig(file_name, dpi=figures_dpi)
    if figures_autodownload and use_colab:
        files.download(file_name)

In [None]:
def multiplot_omicron_model(
        ps: Union[list, np.ndarray],
        es: Union[list, np.ndarray, float],
        title_prefixes: list,
        t_end: Union[float, int],
        y_range: Union[float, int],
        npis: Union[list, np.ndarray] = None,
        title_r0: bool = False,
        save_this_figure: bool = False
    ) -> None:
    """
    Timeplots of Omicron spread for several scenarios on one figure.

    Exactly four scenarios (len(ps) == 4) are laid out on a 2x2 grid; for any
    other length the first two scenarios are drawn side by side.

    :param Union[list, np.ndarray] ps: pre-existing immunity levels of model countries (4- or 2-list)
    :param Union[list, np.ndarray, float] es: immune evasion ratios of Omicron (indexed like ps)
    :param list title_prefixes: prefixes to titles; anything that is not a list means no prefixes
    :param Union[float, int] t_end: final simulation time (common, 1000 time points)
    :param Union[float, int] y_range: sets the y-range of the main I-plots (common)
    :param list npis: npis overriding the Delta suppressing npi, or None (indexed like ps)
    :param bool title_r0: adds R_0, R_t of Omicron to titles (common)
    :param bool save_this_figure: if True the figure is saved to "fourTimeplots.pdf"
    :return None
    """
    four_plot = len(ps) == 4

    fig = plt.figure(
        dpi=figures_dpi if save_this_figure else 110,
        figsize=(7, 7) if four_plot else (8, 4)
    )

    plt.rcParams.update({'font.size': 7})

    if not isinstance(title_prefixes, list):
        title_prefixes = [''] * 4

    t = np.linspace(0, t_end, 1000)

    # without explicit npis every panel uses the Delta-suppressing level
    if npis is None:
        use_npi_locs = [False] * 4
        npi_locs = [0] * 4
    else:
        use_npi_locs = [True] * 4
        npi_locs = npis

    # saved figures only keep titles when figures_timeplot_title is enabled
    show_titles = figures_timeplot_title or not save_this_figure
    subplot_codes = [221, 222, 223, 224] if four_plot else [121, 122]

    for idx, code in enumerate(subplot_codes):
        plot_omicron_model_on_axes(
            ax=fig.add_subplot(code),
            p_loc=ps[idx],
            e=es[idx],
            use_npi_loc=use_npi_locs[idx],
            npi_loc=npi_locs[idx],
            t=t,
            add_title=show_titles,
            title_prefix=title_prefixes[idx],
            y_range=y_range,
            title_r0=title_r0,
        )

    fig.tight_layout()

    if save_this_figure:
        file_name = "fourTimeplots.pdf"
        plt.savefig(file_name, dpi=figures_dpi)
        if figures_autodownload and use_colab:
            files.download(file_name)

#### Figures

In [None]:
# Interactive single timeplot: sliders for p_loc, e, t_end and y_range;
# every other argument is pinned with `fixed` so no widget is created for it.
interact(
    plot_omicron_model,
    p_loc=(0, 1, 0.01),
    e=(0, 1, 0.01),
    t_end=(0, 500, 1),
    y_range=(0, 1, 0.01),
    title_r0=fixed(True),
    title_prefix=fixed(''),
    save_this_figure=fixed(False),
    use_npi_loc=fixed(False),
    npi_loc=fixed(0),
)

In [None]:
# Four scenarios a)-d) under the Delta-suppressing NPI (npis=None), R0/R_t in titles.
interact(
    lambda production=False: multiplot_omicron_model(
        ps=[0.1, 0.75, 0.9, 0.96],
        es=[0.03, 0.08, 0.47, 0.68],
        npis=None,
        title_prefixes=['a) ', 'b) ', 'c) ', 'd) '],
        t_end=75,
        y_range=0.6,
        title_r0=True,
        save_this_figure=production,
    ),
    production=False,
)


In [None]:
# Four scenarios: immunity 0.6 / 0.9 crossed with evasion 0.8 / 0.5,
# all under the Delta-suppressing NPI (npis=None).
interact(
    lambda production=False: multiplot_omicron_model(
        ps=[0.6, 0.9, 0.6, 0.9],
        es=[0.8, 0.8, 0.5, 0.5],
        npis=None,
        title_prefixes=['a) ', 'b) ', 'c) ', 'd) '],
        t_end=200,
        y_range=0.4,
        title_r0=False,
        save_this_figure=production,
    ),
    production=False,
)

In [None]:
# Scenarios b) and d) again, now with a fixed 40% NPI instead of the
# Delta-suppressing level.
interact(
    lambda production=False: multiplot_omicron_model(
        ps=[0.9, 0.9],
        es=[0.8, 0.5],
        npis=[0.40, 0.40],
        title_prefixes=['b) ', 'd) '],
        t_end=200,
        y_range=0.40,
        title_r0=False,
        save_this_figure=production,
    ),
    production=False,
)

### Analysis of peak and final size

#### Code

##### Data generators

In [None]:
def calculate_for_fixed_e_all_peak_and_final_sizes(
        e: Union[list, np.ndarray, float],
        ps: Union[list, np.ndarray],
        p_sa: float = p_south_africa,
        severity: float = 1,
        relative_severity: float = (1 - omicron_hospital_evasion)
    ) -> list:
    """
    Peak and final sizes of the Omicron wave for one immune evasion value,
    evaluated for every model country in ``ps``.

    R0 of Omicron is inferred from the South African observation; each model
    country applies the NPI that just suppresses Delta there.

    :param Union[list, np.ndarray, float] e: immune evasion of Omicron (used as a scalar)
    :param Union[list, np.ndarray] ps: pre-existing immunity values of the model countries
    :param float p_sa: pre-existing immunity in South Africa
    :param float severity: common weight of _s and _p compartments
    :param float relative_severity: additional weight of _p compartments
    :return list: [peak sizes, final sizes], each a list aligned with ``ps``
    """
    r0_omicron = r0_omicron_from_contour_relation(e=e, p=p_sa)

    def sizes_for(p_loc):
        # simulate one model country under its Delta-suppressing NPI
        sol = solve_omicron_model(
            r0_omicron=r0_omicron,
            e=e,
            p_loc=p_loc,
            npi_loc=calculate_suppressing_npi(r0=r0_delta_glob, p=p_loc),
            initial_l1=0.00001,
            t=t_glob)
        return calculate_peak_and_final_size(
            sol=sol,
            severity=severity,
            relative_severity=relative_severity)

    results = [sizes_for(p_loc) for p_loc in ps]

    peak_sizes = [peak for peak, _ in results]
    final_sizes = [final for _, final in results]
    return [peak_sizes, final_sizes]

In [None]:
def generate_heatmap_data(
        severity: float = 1,
        relative_severity: float = (1-omicron_hospital_evasion),
        p_sa: float = p_south_africa,
    ) -> tuple:
    """
    Generates data for heatmaps over the (e_vals x p_loc_vals) grid.

    Rows correspond to immune evasion values (e_vals), columns to local
    pre-existing immunity values (p_loc_vals). Each model country applies the
    NPI that just suppresses Delta.

    :param float severity: common weight of _s and _p compartments
    :param float relative_severity: additional weight of _p compartments
    :param float p_sa: pre-existing immunity in South Africa used to infer R0 of
        Omicron (defaults to p_south_africa, matching the previous behavior)
    :return tuple: (peak sizes, final sizes, R_{t^*} of Omicron) as 2-D arrays
    """
    # The Delta-suppressing NPI depends only on p_loc: compute it once per
    # model country rather than once per (e, p_loc) pair.
    npi_locs = [
        calculate_suppressing_npi(r0=r0_delta_glob, p=p_loc)
        for p_loc in p_loc_vals
    ]

    peak_sizes = []
    final_sizes = []
    reproduction_numbers = []

    for e in e_vals:
        peaks, finals = calculate_for_fixed_e_all_peak_and_final_sizes(
            e=e,
            ps=p_loc_vals,
            p_sa=p_sa,
            severity=severity,
            relative_severity=relative_severity)

        peak_sizes.append(peaks)
        final_sizes.append(finals)

        # r0 omicron
        r0_omicron = r0_omicron_from_contour_relation(e=e, p=p_sa)

        # R_{t^*} in model countries: R0 reduced by the local NPI and by the
        # share of the population still susceptible to Omicron
        reproduction_numbers.append([
            r0_omicron * (1 - npi_loc) * (1 - p_loc + e * p_loc)
            for p_loc, npi_loc in zip(p_loc_vals, npi_locs)
        ])

    return np.array(peak_sizes), np.array(final_sizes), np.array(reproduction_numbers)

##### Figure generators

In [None]:
def plot_peak_and_final_size(
        e: float,
        p_sa: float = p_south_africa,
        severity: float = 1,
        relative_severity: float = (1 - omicron_hospital_evasion),
        y_limit_peak: float = 1.,
        y_limit_final: float = 1.,
        save_this_figure: bool = False
    ) -> None:
    """
    Plot of peak and final size wrt. pre-existing immunity in model country (p_loc)
    :param float e: immune evasion of Omicron
    :param float p_sa: pre-existing immunity in South Africa
    :param float severity: common weight of _s and _p compartments
    :param float relative_severity: additional weight of _p compartments
    :param float y_limit_peak: ymax for the peak size panel
    :param float y_limit_final: ymax for the final size panel
    :param bool save_this_figure: if True then the figure is saved (and also
        downloaded when running on Colab with figures_autodownload enabled)
    :return None
    """
    peak_sizes, final_sizes = calculate_for_fixed_e_all_peak_and_final_sizes(
        e=e,
        ps=p_loc_vals,
        p_sa=p_sa,
        severity=severity,
        relative_severity=relative_severity
    )

    # Set the font size before the figure and its axes exist, so every text
    # element created below is sized from it.
    plt.rcParams.update({'font.size': 7})

    fig = plt.figure(
        dpi=figures_dpi if save_this_figure else 110,
        figsize=(5, 3))

    # tick positions / labels for axes shown in %
    positions = [0, 0.25, 0.5, 0.75, 1]
    labels = ["0%", "25%", "50%", "75%", "100%"]

    def _use_percentage_ticks(axis) -> None:
        """Replace the ticks of a matplotlib Axis with fixed 0%..100% labels."""
        axis.set_major_locator(ticker.FixedLocator(positions))
        axis.set_major_formatter(ticker.FixedFormatter(labels))

    # The two panels share the x quantity (pre-existing immunity) and differ
    # only in the plotted data, title and y-limit.
    panels = (
        (121, peak_sizes, "peak size", y_limit_peak),
        (122, final_sizes, "final size", y_limit_final),
    )
    for subplot_spec, sizes, title, y_limit in panels:
        ax = fig.add_subplot(subplot_spec)
        ax.plot(p_loc_vals, sizes)
        ax.set_xlabel("pre-existing immunity")
        ax.set_title(title)
        ax.set_ylim(0.0, y_limit)
        if p_ticks_percentage:
            _use_percentage_ticks(ax.xaxis)
        if population_percentage:
            _use_percentage_ticks(ax.yaxis)

    # finalize
    fig.tight_layout()

    if save_this_figure:
        my_file_name = "peakAndFinalSize.pdf"
        fig.savefig(my_file_name, dpi=figures_dpi)
        if figures_autodownload and use_colab:
            files.download(my_file_name)

In [None]:
def plot_heatmap(
        data: np.ndarray,
        typ: str = "final",
        add_frame: Union[dict, None] = None,
        add_npi_plot: bool = True,
        save_this_figure: bool = False
    ) -> None:
    """
    Generate heatmap of given type from the data
    :param np.ndarray data: data given as [[data(p, e) for p_loc_vals] for e_vals]
    :param str typ: "final", "peak" or "reproduction_number"
    :param dict add_frame: None or dictionary describing a highlighted frame:
        "frame_p" / "frame_e" are the [low, high] bounds of the rectangle and
        "markers" is a list of {"p", "e", "name"} dicts drawn as labelled points
    :param bool add_npi_plot: adding a plot of Delta suppressing npis
    :param bool save_this_figure: if True then the figure is saved
    :raises ValueError: if typ is not one of the supported heatmap types
    :return None
    """
    # Styling per heatmap type. typ is validated up front: an unknown value
    # would otherwise be styled as a reproduction number while still being
    # scaled to percentages below.
    styles = {
        "final": {
            "color": "#4d0000",
            "levels": [0.0001, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1],
            "colormap": 'Reds',
            "curves": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1],
            "curve_color": '#4a4a4a',
            "title": 'final size',
        },
        "peak": {
            "color": "#804000",
            "levels": [0.0001, 0.001, 0.01, 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4],
            "colormap": 'Oranges',
            "curves": [0.01, 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4],
            "curve_color": '#5e5e5e',
            "title": 'peak size',
        },
        "reproduction_number": {
            "color": "black",
            "levels": [1, 1.2, 1.5, 2, 2.5, 3, 3.5, 4, 5, 6],
            "colormap": 'Purples',
            "curves": [1, 1.2, 1.5, 2, 2.5, 3, 3.5, 4, 5, 6],
            "curve_color": '#5e5e5e',
            "title": 'control reproduction number',
        },
    }
    if typ not in styles:
        raise ValueError(
            "unknown heatmap type {!r}; expected one of {}".format(
                typ, sorted(styles)))
    style = styles[typ]
    highlight_color = style["color"]  # shared by the frame and the markers
    levels = style["levels"]
    curves = style["curves"]

    this_figure_dpi = figures_dpi if save_this_figure else 100

    # Set the font size BEFORE creating any axes: axis labels and titles
    # resolve their relative font size when the axes are created, so updating
    # afterwards left them at whatever size a previous plot had set.
    plt.rcParams.update({'font.size': 16})

    # tick positions / labels for axes shown in %
    positions = [0, 0.25, 0.5, 0.75, 1]
    labels = ["0%", "25%", "50%", "75%", "100%"]

    if add_npi_plot:
        fig, (ax1, ax) = plt.subplots(
            2, sharex=True, dpi=this_figure_dpi,
            figsize=(5, 7.5),
            gridspec_kw={'height_ratios': [1, 3]})

        # NPI plot
        ax1.plot(
            p_loc_vals,
            [calculate_suppressing_npi(r0=r0_delta_glob, p=p_loc)
             for p_loc in p_loc_vals])
        ax1.set_ylabel("NPI controlling Delta")
        if npi_percentage:
            ax1.yaxis.set_major_locator(ticker.FixedLocator(positions))
            ax1.yaxis.set_major_formatter(ticker.FixedFormatter(labels))
        ax1.set_ylim(0, 1.)
        ax1.margins(0)
    else:
        fig, ax = plt.subplots(dpi=this_figure_dpi,
                               figsize=(7, 7))

    # sizes are fractions of the population; show them in % if requested
    if population_percentage and typ != "reproduction_number":
        data = np.array(data) * 100
        levels = np.array(levels) * 100
        curves = np.array(curves) * 100

    # generate the main heatmap
    ax.contourf(
        p_loc_vals, e_vals, data,
        levels=levels,
        cmap=style["colormap"], alpha=1)

    contours = ax.contour(
        p_loc_vals, e_vals, data,
        curves,
        colors=style["curve_color"], linewidths=1)

    ax.clabel(contours, inline=True, fmt=str, fontsize=10)

    ax.set_ylabel("immune evasion")
    ax.set_xlabel("pre-existing immunity")

    if not save_this_figure:
        ax.set_title(style["title"], fontsize=20)

    ax.margins(0)

    if p_ticks_percentage:
        ax.xaxis.set_major_locator(ticker.FixedLocator(positions))
        ax.xaxis.set_major_formatter(ticker.FixedFormatter(labels))

    if e_ticks_percentage:
        ax.yaxis.set_major_locator(ticker.FixedLocator(positions))
        ax.yaxis.set_major_formatter(ticker.FixedFormatter(labels))

    # add highlighting frame
    if add_frame is not None:
        frame_p = add_frame["frame_p"]
        frame_e = add_frame["frame_e"]

        highlighted_area = Rectangle(
          (frame_p[0], frame_e[0]),
          frame_p[1] - frame_p[0], frame_e[1] - frame_e[0],
          fc='none',
          ec=highlight_color,
          lw=5,
          alpha=0.5)

        ax.add_patch(highlighted_area)

        for marker in add_frame["markers"]:
            ax.text(marker["p"] + 0.01, marker["e"] + 0.01,
                    s=marker["name"], fontsize=12, color=highlight_color)
            ax.plot(marker["p"], marker["e"], "o",
                    color=highlight_color, linewidth=3)

    # finalize
    fig.tight_layout()

    if save_this_figure:
        my_file_name = "heatmap-" + typ + ".pdf"
        fig.savefig(my_file_name, dpi=figures_dpi)
        if figures_autodownload and use_colab:
            files.download(my_file_name)

#### Figures

##### Plot of peak and final sizes for fixed immune evasion

In [None]:
# Interactive explorer for plot_peak_and_final_size: each (min, max, step)
# tuple becomes a slider for the parameter of the same name, while
# save_this_figure is pinned to False via fixed(), so no widget is shown for it
# and the figure is never written to disk from this cell.
interact(
    plot_peak_and_final_size,
    e=(0.2, 1, 0.01),
    p_sa=(0, 1, 0.01),
    severity=(0, 1, 0.01),
    relative_severity=(0, 1, 0.01),
    # NOTE(review): a y-limit of 0 gives an empty y-range (matplotlib warns
    # and expands it); consider a positive slider minimum.
    y_limit_peak=(0, 1, 0.01),
    y_limit_final=(0, 1, 0.01),
    save_this_figure=fixed(False)
)

##### Heatmaps for peak size, final size, and control reproduction number of Omicron

###### Data generation [slow ~ 2 x 1m 30s]

In [None]:
# generate data considering the population not immune to Omicron
peak_sizes_to_plot, final_sizes_to_plot, _ = generate_heatmap_data(
    severity=1,
    relative_severity=1)

In [None]:
# generate data considering the population not immune to Delta
peak_sizes_s_only, final_sizes_s_only, reproduction_numbers_to_plot = generate_heatmap_data(
    severity=1,
    relative_severity=0)

###### Heatmaps

In [None]:
frame_to_add = {
    "frame_p": [0.5, 0.97],
    "frame_e": [0.35, 0.97],
    "markers": [
                 {"p": 0.6, "e": 0.8, "name": "a"},
                 {"p": 0.9, "e": 0.8, "name": "b"},
                 {"p": 0.6, "e": 0.5, "name": "c"},
                 {"p": 0.9, "e": 0.5, "name": "d"}
               ]
}

print('CONTROL REPRODUCTION NUMBER')

# The frame is handed to the callback as a fixed() argument (no widget is
# created for it) instead of being looked up in the global `frame_to_add` on
# every redraw: other cells rebind that global, some with a different frame,
# so a late lookup could draw another cell's frame in this widget.
interact(
    lambda add_npi_plot, add_frame, production, frame: plot_heatmap(
        data=reproduction_numbers_to_plot,
        typ="reproduction_number",
        add_frame=(frame if add_frame else None),
        add_npi_plot=add_npi_plot,
        save_this_figure=production
    ),
    add_npi_plot=True,
    add_frame=True,
    production=False,
    frame=fixed(frame_to_add)
)

In [None]:
frame_to_add = {
    "frame_p": [0.5, 0.97],
    "frame_e": [0.35, 0.97],
    "markers": [
                 {"p": 0.6, "e": 0.8, "name": "a"},
                 {"p": 0.9, "e": 0.8, "name": "b"},
                 {"p": 0.6, "e": 0.5, "name": "c"},
                 {"p": 0.9, "e": 0.5, "name": "d"}
               ]
}

print('PEAK SIZE FOR SEVERITY = 1, RELATIVE SEVERITY = 1')

# The frame is handed to the callback as a fixed() argument (no widget is
# created for it) instead of being looked up in the global `frame_to_add` on
# every redraw: other cells rebind that global, some with a different frame,
# so a late lookup could draw another cell's frame in this widget.
interact(
    lambda add_npi_plot, add_frame, production, frame: plot_heatmap(
        data=peak_sizes_to_plot,
        typ="peak",
        add_frame=(frame if add_frame else None),
        add_npi_plot=add_npi_plot,
        save_this_figure=production
    ),
    add_npi_plot=True,
    add_frame=True,
    production=False,
    frame=fixed(frame_to_add)
)

In [None]:
frame_to_add = {
    "frame_p": [0.5, 0.97],
    "frame_e": [0.35, 0.97],
    "markers": [
                 {"p": 0.6, "e": 0.8, "name": "a"},
                 {"p": 0.9, "e": 0.8, "name": "b"},
                 {"p": 0.6, "e": 0.5, "name": "c"},
                 {"p": 0.9, "e": 0.5, "name": "d"}
               ]
}

print('FINAL SIZE FOR SEVERITY = 1, RELATIVE SEVERITY = 1')

# The frame is handed to the callback as a fixed() argument (no widget is
# created for it) instead of being looked up in the global `frame_to_add` on
# every redraw: other cells rebind that global, some with a different frame,
# so a late lookup could draw another cell's frame in this widget.
interact(
    lambda add_npi_plot, add_frame, production, frame: plot_heatmap(
        data=final_sizes_to_plot,
        typ="final",
        add_frame=(frame if add_frame else None),
        add_npi_plot=add_npi_plot,
        save_this_figure=production
    ),
    add_npi_plot=True,
    add_frame=True,
    production=False,
    frame=fixed(frame_to_add)
)

In [None]:
frame_to_add = {
    "frame_p": [0.5, 0.97],
    "frame_e": [0.35, 0.97],
    "markers": [
                 {"p": 0.6, "e": 0.8, "name": "a"},
                 {"p": 0.9, "e": 0.8, "name": "b"},
                 {"p": 0.6, "e": 0.5, "name": "c"},
                 {"p": 0.9, "e": 0.5, "name": "d"}
               ]
}

print('PEAK SIZE FOR SEVERITY = 1, RELATIVE SEVERITY = 0')

# The frame is handed to the callback as a fixed() argument (no widget is
# created for it) instead of being looked up in the global `frame_to_add` on
# every redraw: other cells rebind that global, some with a different frame,
# so a late lookup could draw another cell's frame in this widget.
interact(
    lambda add_npi_plot, add_frame, production, frame: plot_heatmap(
        data=peak_sizes_s_only,
        typ="peak",
        add_frame=(frame if add_frame else None),
        add_npi_plot=add_npi_plot,
        save_this_figure=production
    ),
    add_npi_plot=True,
    add_frame=True,
    production=False,
    frame=fixed(frame_to_add)
)

In [None]:
frame_to_add = {
    "frame_p": [0.5, 0.97],
    "frame_e": [0.3, 0.97],
    "markers": [
                 {"p": 0.6, "e": 0.8, "name": "a"},
                 {"p": 0.9, "e": 0.8, "name": "b"},
                 {"p": 0.6, "e": 0.5, "name": "c"},
                 {"p": 0.9, "e": 0.5, "name": "d"}
               ]
}

print('FINAL SIZE FOR SEVERITY = 1, RELATIVE SEVERITY = 0')

# The frame is handed to the callback as a fixed() argument (no widget is
# created for it) instead of being looked up in the global `frame_to_add` on
# every redraw: other cells rebind that global, some with a different frame,
# so a late lookup could draw another cell's frame in this widget.
interact(
    lambda add_npi_plot, add_frame, production, frame: plot_heatmap(
        data=final_sizes_s_only,
        typ="final",
        add_frame=(frame if add_frame else None),
        add_npi_plot=add_npi_plot,
        save_this_figure=production
    ),
    add_npi_plot=True,
    add_frame=True,
    production=False,
    frame=fixed(frame_to_add)
)