In [None]:
from scipy import stats
import pandas as pd
import plotly.express as px
import numpy as np
from scipy.stats import gamma
import statsmodels.api as sm
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from numpy.random import RandomState

In [None]:
# Dedicated, explicitly seeded legacy NumPy generator: every random draw in
# this notebook goes through `prng`, so Restart & Run All reproduces it.
prng = RandomState(seed=123)

In [None]:
# Draw 100 values from Gamma(shape=15, scale=1/3):
#   mean = 15 * (1/3) = 5,  SD = sqrt(15) * (1/3) ≈ 1.29
sample = prng.gamma(shape=15, scale=1 / 3, size=100)

In [None]:
# Visualise the raw sample before it is smoothed with a KDE.
# Title and x label are set explicitly so the figure stands on its own;
# Plotly Express labels the y axis "count" by default.
fig = px.histogram(
    x=sample,
    title="Histogram of gamma sample",
    labels={"x": "sampled value"},
)
fig.show()

In [None]:
# Smooth the sample with a univariate kernel density estimate, using
# statsmodels' default fit() settings.
# fit() stores its results on `kde` itself (kde.support / kde.density are
# read by later cells), so the return value is not kept. The old `k = ...`
# binding was never used and `k` is rebound later by loop variables.
# The trailing ';' stops Jupyter from echoing fit()'s return value.
kde = sm.nonparametric.KDEUnivariate(sample)
kde.fit();

In [None]:
# Exploratory: draw the KDE curve, then two copies shifted right (dx, in x)
# and lifted (dy, in y), to try out spacings for the final figure.
fig = px.line(x=kde.support, y=kde.density)
for dx, dy in ((3, 0.1), (10, 0.25)):
    fig.add_scatter(x=kde.support + dx, y=kde.density + dy, mode='lines')
fig.show()

In [None]:
def find_nearest(array, value):
    """Return the element of ``array`` closest to ``value``.

    Parameters
    ----------
    array : array_like
        Values to search. Any shape is accepted; it is searched in
        flattened (C) order.
    value : scalar
        Target value.

    Returns
    -------
    numpy scalar
        The element with the smallest absolute difference from ``value``.
        On ties the first such element in flattened order wins, because
        ``argmin`` returns the first minimum.

    Raises
    ------
    ValueError
        If ``array`` is empty.
    """
    # Flatten first: argmin returns an index into the *flattened* array, so
    # indexing an N-d (or 0-d) original with it would pick a whole row or fail.
    flat = np.asarray(array).ravel()
    if flat.size == 0:
        raise ValueError("find_nearest() requires a non-empty array")
    idx = np.abs(flat - value).argmin()
    return flat[idx]

In [None]:
# Two copies of the KDE curve placed on the figure's time axis:
# - x: shift each curve along the axis (+0.3 and +5.2).
# - y: lift then scale by 15. The first curve's baseline ends up near
#   0.0667 * 15 ≈ 1 and the second's near 1 * 15 = 15.
first_dist_x, first_dist_y = kde.support + 0.3, (kde.density + 0.0667) * 15
second_dist_x, second_dist_y = kde.support + 5.2, (kde.density + 1) * 15

In [None]:
# Number of secondary-case dots to draw at each x position (day) on top of
# the first generation-interval curve.
sec_days_lookup = {2.5: 1,
                   3: 1,
                   4: 2,
                   5: 3,
                   6: 2,
                   7: 1,
                   }

# Stack one marker per case above the curve: the n-th dot on a day sits
# 2*n - 1 units above the curve (+1, +3, +5, ...). This single formula
# replaces the former `y == 1` special case, which computed the same value.
list_x1 = []
list_y1 = []
for day, n_cases in sec_days_lookup.items():
    # Curve height at the support point nearest to `day`, looked up once per
    # day rather than once per dot.
    curve_y = first_dist_y[np.where(first_dist_x == find_nearest(first_dist_x, day))][0]
    for n in range(1, n_cases + 1):
        list_x1.append(day)
        list_y1.append(curve_y + (2 * n - 1))

In [None]:
# Number of tertiary-case dots to draw at each x position (day) on top of
# the second generation-interval curve.
tet_days_lookup = {8: 2,
                   9: 3,
                   10: 1,
                   12: 1,
                   }

# Stack one marker per case above the curve: the n-th dot on a day sits
# 2*n - 1 units above the curve (+1, +3, +5, ...). This single formula
# replaces the former `y == 1` special case, which computed the same value.
list_x2 = []
list_y2 = []
for day, n_cases in tet_days_lookup.items():
    # Curve height at the support point nearest to `day`, looked up once per
    # day rather than once per dot.
    curve_y = second_dist_y[np.where(second_dist_x == find_nearest(second_dist_x, day))][0]
    for n in range(1, n_cases + 1):
        list_x2.append(day)
        list_y2.append(curve_y + (2 * n - 1))

In [None]:
# Transmission-chain figure:
# - the primary case;
# - a first generation-interval curve with the secondary cases stacked on it;
# - a second curve with the tertiary cases that follow one highlighted (red)
#   secondary case.
# Positions and offsets below were tuned by hand for this figure.

fig = make_subplots()
first_case_col = "blue"
secondary_case_col = "darkblue"
second_case_col = "red"
tertiary_case_col = "darkred"

# y of the highlighted secondary case: 3 units above the first curve at day 6.
# With sec_days_lookup[6] == 2 this equals the upper day-6 secondary dot, which
# this red marker is drawn over (it is added after the secondary-dot trace).
highlighted_case_y = (
    first_dist_y[np.where(first_dist_x == find_nearest(first_dist_x, 6))][0] + 3
)

# First generation-interval curve (primary -> secondary cases)
fig.add_trace(
    go.Scatter(x=first_dist_x, y=first_dist_y, fill="toself"),
    secondary_y=False,
)

# Second generation-interval curve (secondary -> tertiary cases)
fig.add_trace(
    go.Scatter(x=second_dist_x, y=second_dist_y, fill="toself"),
    secondary_y=False,
)

# Primary case dot. Marker traces set marker.color directly rather than
# relying on plotly inheriting it from line.color.
fig.add_trace(
    go.Scatter(x=[1], y=[1], mode='markers', marker=dict(color=first_case_col)),
    secondary_y=False,
)

# Secondary case dots
fig.add_trace(
    go.Scatter(x=list_x1, y=list_y1, mode='markers', marker=dict(color=secondary_case_col)),
    secondary_y=False,
)

# Highlighted secondary case that goes on to produce the tertiary cases
fig.add_trace(
    go.Scatter(x=[6], y=[highlighted_case_y], mode='markers', marker=dict(color=second_case_col)),
    secondary_y=False,
)

# Dotted connector from the highlighted case up to the second curve.
# The upper end is 15, the second curve's baseline ((density + 1) * 15 with
# density ≈ 0 in the tails).
# NOTE(review): the lower end (8.8) is hard-coded and presumably approximates
# highlighted_case_y for this seed/KDE. Confirm, or use highlighted_case_y.
fig.add_trace(
    go.Scatter(x=[6, 6], y=[8.8, 15], mode='lines', line=dict(color=tertiary_case_col, dash='dot')),
    secondary_y=False,
)

# Tertiary case dots
fig.add_trace(
    go.Scatter(x=list_x2, y=list_y2, mode='markers', marker=dict(color=tertiary_case_col)),
    secondary_y=False,
)

# Faint background bands spanning the x range of each curve
fig.add_vrect(x0=min(first_dist_x), x1=max(first_dist_x),
              line_width=0, fillcolor=secondary_case_col, opacity=0.05)
fig.add_vrect(x0=min(second_dist_x), x1=max(second_dist_x),
              line_width=0, fillcolor=tertiary_case_col, opacity=0.05)

# Labels for the primary case and the two generation intervals
fig.add_annotation(x=1, y=1.5,
                   text="Primary case",
                   showarrow=False,
                   yshift=10)

fig.add_annotation(x=5, y=1.5,
                   text="Generation interval",
                   showarrow=False,
                   yshift=10)

fig.add_annotation(x=10, y=15,
                   text="Generation interval",
                   showarrow=False,
                   yshift=10)

# Shared marker size, white background, no legend, unlabeled case counts
fig.update_traces(marker=dict(size=10))
fig.update_layout(plot_bgcolor='white', showlegend=False)
fig.update_yaxes(title='cases', showticklabels=False)
fig.update_xaxes(title='time (days)')

fig.show()