Skip to content

Commit

Permalink
Added Rt estimation & plotting function
Browse files Browse the repository at this point in the history
  • Loading branch information
trouleau committed Apr 24, 2020
1 parent 0156342 commit 88328a4
Show file tree
Hide file tree
Showing 2 changed files with 338 additions and 184 deletions.
317 changes: 133 additions & 184 deletions sim/lib/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,20 @@
import networkx as nx
import scipy
import scipy.optimize
from scipy.interpolate import interp1d
import scipy as sp
import random as rd
import os, math
from datetime import datetime
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from matplotlib.dates import date2num, num2date

from lib.measures import (MeasureList, BetaMultiplierMeasure,
SocialDistancingForAllMeasure, BetaMultiplierMeasureByType,
SocialDistancingForPositiveMeasure, SocialDistancingByAgeMeasure, SocialDistancingForSmartTracing, ComplianceForAllMeasure)
from lib.rt import compute_daily_rts, R_T_RANGE

import numpy as np
import seaborn as sns
Expand Down Expand Up @@ -55,21 +58,22 @@ def days_to_datetime(arr, start_date):
return pd.to_datetime(ts, unit='s')


def lockdown_widget(lockdown_at, start_date, lockdown_label_y, ymax, lockdown_label, ax, ls='--', xshift=0.0):
def lockdown_widget(lockdown_at, start_date, lockdown_label_y, ymax,
lockdown_label, ax, ls='--', xshift=0.0, zorder=None):
# Convert x-axis into posix timestamps and use pandas to plot as dates
lckdn_x = days_to_datetime(lockdown_at, start_date=start_date)
ax.plot([lckdn_x, lckdn_x], [0, ymax], linewidth=2.5, linestyle=ls,
color='black', label='_nolegend_')
color='black', label='_nolegend_', zorder=zorder)
lockdown_label_y = lockdown_label_y or ymax*0.4
ax.text(x=lckdn_x - pd.Timedelta(2.1 + xshift, unit='d'),
y=lockdown_label_y, s=lockdown_label, rotation=90)


def target_widget(show_target,start_date, ax):
def target_widget(show_target,start_date, ax, zorder=None):
txx = np.linspace(0, show_target.shape[0] - 1, num=show_target.shape[0])
txx = days_to_datetime(txx, start_date=start_date)
ax.plot(txx, show_target, linewidth=4, linestyle='', marker='X', ms=6,
color='red', label='COVID-19 case data')
color='red', label='COVID-19 case data', zorder=zorder)


class Plotter(object):
Expand Down Expand Up @@ -105,7 +109,7 @@ def __init__(self):
'#dc2ade',
'#21ff53',
'#323edd',
'#ff9021',
'#ff9021',
'#4d089a',
'#cc0066',
'#ff6666',
Expand All @@ -114,6 +118,13 @@ def __init__(self):
'#ff2222'
]

self.color_different_scenarios_alt = [
'#a1dab4',
'#41b6c4',
'#2c7fb8',
'#253494',
]



# sequential
Expand Down Expand Up @@ -195,7 +206,6 @@ def __comp_state_over_time(self, sim, state, acc):
stds.append(np.std(restarts))
return np.array(ts), np.array(means), np.array(stds)


def __comp_contained_over_time(self, sim, measure, acc):
'''
Computes `state` variable over time [0, self.max_time] with given accuracy `acc
Expand All @@ -209,25 +219,6 @@ def __comp_contained_over_time(self, sim, measure, acc):
stds.append(np.std(restarts))
return np.array(ts), np.array(means), np.array(stds)

def __comp_checkins_in_a_day(self, sim, r, t):
site_checkins = np.zeros(sim.n_sites, dtype='bool')

for site in range(sim.n_sites):
for indiv in range(sim.n_people):
if ( (not sim.measure_list[r].is_contained_prob(SocialDistancingForAllMeasure, t=t, j=indiv)) and
(not sim.measure_list[r].is_contained_prob(SocialDistancingForSmartTracing, t=t, j=indiv)) and
(not sim.measure_list[r].is_contained_prob(SocialDistancingByAgeMeasure, t=t, age=sim.people_age[r, indiv])) and
(not sim.measure_list[r].is_contained_prob(SocialDistancingForPositiveMeasure,
t=t, j=indiv,
state_posi_started_at=sim.state_started_at['posi'][r, :],
state_posi_ended_at=sim.state_ended_at['posi'][r, :],
state_resi_started_at=sim.state_started_at['resi'][r, :],
state_dead_started_at=sim.state_started_at['dead'][r, :])) and
(not sim.state['dead'][r, indiv]) and
len(sim.mob[r].list_intervals_in_window_individual_at_site(indiv, site, t, t+24.0)) > 0):
site_checkins[site] += 1
return site_checkins

def plot_cumulative_infected(self, sim, title='Example', filename='daily_inf_0',
figsize=(6, 5), errorevery=20, acc=1000, ymax=None,
lockdown_label='Lockdown', lockdown_at=None,
Expand Down Expand Up @@ -578,164 +569,6 @@ def plot_daily_at_home(self, sim, title='Example', filename='daily_at_home_0', f
plt.close()
return

def __compute_Rt_over_time(self, sim, estimation_window):
'''
Computes Rt over time by counting infected "children" of nodes
that got infectious in windows of size `estimation_window`
over the whole time of simulation,
Returns Rt as well as proportion accounted to by different
types of infections (iasy, ipre, isym)
'''

ts = [0.0]
Rt_mu, Rt_std = [0.0], [0.0]
prop_iasy_mu, prop_ipre_mu, prop_isym_mu = [0.33], [0.33], [0.34]
prop_iasy_std, prop_ipre_std, prop_isym_std = [0.0], [0.0], [0.0]

acc = math.ceil(sim.max_time / estimation_window)
for aa in range(acc - 1):

# discrete time window
t0 = aa * estimation_window
t1 = (aa + 1) * estimation_window
ts.append((t0 + t1) / 2)

tmp_Rt, tmp_prop_iasy, tmp_prop_ipre, tmp_prop_isym = [], [], [], []
for r in range(sim.random_repeats):

# people that got infectious in this window
became_iasy = (sim.state_started_at['iasy'][r] >= t0) & (
sim.state_started_at['iasy'][r] < t1)
became_ipre = (sim.state_started_at['ipre'][r] >= t0) & (
sim.state_started_at['ipre'][r] < t1)
became_isym = (sim.state_started_at['isym'][r] >= t0) & (
sim.state_started_at['isym'][r] < t1)

idx_became_iasy = np.where(became_iasy)[0]
idx_became_ipre = np.where(became_ipre)[0]
idx_became_isym = np.where(became_isym)[0]

# count children of people that got asymptomatic now
iasy_count = idx_became_iasy.shape[0]
iasy_children = sim.children_count_iasy[r, idx_became_iasy].sum()

# count children of people that got presymptomatic now
ipre_count = idx_became_ipre.shape[0]
ipre_children = sim.children_count_ipre[r, idx_became_ipre].sum()

# count children of people that got symptomatic now
isym_count = idx_became_isym.shape[0]
isym_children = sim.children_count_isym[r, idx_became_isym].sum()

total = (iasy_children + ipre_children + isym_children)
if total > 0:
tmp_Rt.append((iasy_children + ipre_children + isym_children) / (iasy_count + ipre_count + isym_count))
tmp_prop_iasy.append(iasy_children / total)
tmp_prop_ipre.append(ipre_children / total)
tmp_prop_isym.append(isym_children / total)
else:
tmp_Rt.append(0.0)
tmp_prop_iasy.append(0.33)
tmp_prop_ipre.append(0.33)
tmp_prop_isym.append(0.34)

Rt_mu.append(np.mean(tmp_Rt))
prop_iasy_mu.append(np.mean(tmp_prop_iasy))
prop_ipre_mu.append(np.mean(tmp_prop_ipre))
prop_isym_mu.append(np.mean(tmp_prop_isym))

Rt_std.append(np.std(tmp_Rt))
prop_iasy_std.append(np.std(tmp_prop_iasy))
prop_ipre_std.append(np.std(tmp_prop_ipre))
prop_isym_std.append(np.std(tmp_prop_isym))

Rt_mu, Rt_std = np.array(Rt_mu), np.array(Rt_std)
prop_iasy_mu, prop_ipre_mu, prop_isym_mu = \
np.array(prop_iasy_mu), np.array(prop_ipre_mu), np.array(prop_isym_mu)
prop_iasy_std, prop_ipre_std, prop_isym_std = \
np.array(prop_iasy_std), np.array(prop_ipre_std), np.array(prop_isym_std)
return np.array(ts) / 24.0, (Rt_mu, Rt_std), (prop_iasy_mu, prop_ipre_mu, prop_isym_mu), (prop_iasy_std, prop_ipre_std, prop_isym_std)

def plot_Rt_types(self, sim, title='Example', filename='reproductive_rate_inf_0', errorevery=20, estimation_window=7 * 24.0,
figsize=(10, 10), lockdown_at=None, lockdown_end=None):

'''
Plots Rt split up by infection types (iasy, ipre, isym) over time,
averaged over random restarts, using error bars for std-dev
'''
ts, (Rt_mu, Rt_std), \
(prop_iasy_mu, prop_ipre_mu, prop_isym_mu), \
(prop_iasy_std, prop_ipre_std, prop_isym_std) = \
self.__compute_Rt_over_time(sim, estimation_window)

fig = plt.figure(figsize=figsize)
ax = fig.add_subplot(111)

line_xaxis = np.zeros(ts.shape)
line_iasy = prop_iasy_mu * Rt_mu
line_ipre = prop_iasy_mu * Rt_mu + prop_ipre_mu * Rt_mu
line_isym = prop_iasy_mu * Rt_mu + prop_ipre_mu * Rt_mu + prop_isym_mu * Rt_mu

error_iasy = prop_iasy_std * Rt_mu
error_ipre = prop_iasy_std * Rt_mu + prop_ipre_std * Rt_mu
error_isym = prop_iasy_std * Rt_mu + prop_ipre_std * Rt_mu + prop_isym_std * Rt_mu

# lines
ax.errorbar(ts, line_iasy, c='black', linestyle='-', yerr=error_iasy, elinewidth=0.8, errorevery=errorevery)
ax.errorbar(ts, line_ipre, c='black', linestyle='-', yerr=error_ipre, elinewidth=0.8, errorevery=errorevery)
ax.errorbar(ts, line_isym, c='black', linestyle='-', yerr=error_isym, elinewidth=0.8, errorevery=errorevery)

# filling
ax.fill_between(ts, line_xaxis, line_iasy, alpha=self.filling_alpha,
edgecolor='black', facecolor=self.color_iasy, linewidth=0,
label=r'$R_t$ due to asymptomatic $I^a(t)$', zorder=0)
ax.fill_between(ts, line_iasy, line_ipre, alpha=self.filling_alpha,
edgecolor='black', facecolor=self.color_ipre, linewidth=0,
label=r'$R_t$ due to pre-symptomatic $I^p(t)$', zorder=0)
ax.fill_between(ts, line_ipre, line_isym, alpha=self.filling_alpha,
edgecolor='black', facecolor=self.color_isym, linewidth=0,
label=r'$R_t$ due to symptomatic $I^s(t)$', zorder=0)

# axis
maxx = np.max(ts)
ax.set_xlim((0, maxx))
ymax = 1.5 * np.max(Rt_mu)
ax.set_ylim((0, ymax))

ax.set_xlabel(r'$t$ [days]')
ax.set_ylabel(r'$R_t$')

if lockdown_at is not None:
ax.plot(lockdown_at * np.ones(10), np.linspace(0, ymax, num=10),
linewidth=1, linestyle='--', color='black')

if lockdown_end is not None:
ax.plot(lockdown_end * np.ones(10), np.linspace(0, ymax, num=10),
linewidth=1, linestyle='dotted', color='black', label='End of restrictive measures')
ax.set_xlim((0, max(maxx, lockdown_end + 2)))


# Hide the right and top spines
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)

# Only show ticks on the left and bottom spines
ax.yaxis.set_ticks_position('left')
ax.xaxis.set_ticks_position('bottom')

# legend
fig.legend(loc='center right', borderaxespad=0.1)
# Adjust the scaling factor to fit your legend text completely outside the plot
plt.subplots_adjust(right=0.70)
ax.set_title(title, pad=20)
plt.draw()
plt.savefig('plots/' + filename + '.png', format='png', facecolor=None,
dpi=DPI, bbox_inches='tight')
if NO_PLOT:
plt.close()
return

def compare_total_infections(self, sims, titles, figtitle='Title',
filename='compare_inf_0', figsize=(10, 10), errorevery=20, acc=1000, ymax=None,
lockdown_label='Lockdown', lockdown_at=None, lockdown_label_y=None,
Expand Down Expand Up @@ -833,7 +666,7 @@ def compare_total_infections(self, sims, titles, figtitle='Title',
# Get the bounding box of the original legend
bb = leg.get_bbox_to_anchor().inverse_transformed(ax.transAxes)

# Change to location of the legend.
# Change to location of the legend.
bb.y0 += legendYoffset
bb.y1 += legendYoffset
leg.set_bbox_to_anchor(bb, transform = ax.transAxes)
Expand Down Expand Up @@ -1144,3 +977,119 @@ def plot_positives_vs_target(self, sim, targets, test_lag, title='Example',
if NO_PLOT:
plt.close()
return

def plot_daily_rts(self, sims, filename, start_date, titles, sigma=None,
r_t_range=R_T_RANGE, window=3, figsize=(6, 5),
subplot_adjust=None, lockdown_label='Lockdown',
lockdown_at=None, lockdown_label_y=None, ymax=None,
colors=['grey'], fill_between=True, draw_dots=True,
errorevery=1, show_legend=False, xtick_interval=1):

# If a single summary is provided
if not isinstance(sims, list):
sims = [sims]
sigma = [sigma]

results = list()
for i, sim in enumerate(sims):
res = compute_daily_rts(sim, start_date, sigma[i], r_t_range, window)
results.append(res)

# Colors
ABOVE = [1,0,0]
MIDDLE = [1,1,1]
BELOW = [0,0,0]
cmap = ListedColormap(np.r_[
np.linspace(BELOW,MIDDLE,25),
np.linspace(MIDDLE,ABOVE,25)
])
color_mapped = lambda y: np.clip(y, .5, 1.5)-.5

ymax_computed = 0.0 # Keep track of max y to set limit

fig = plt.figure(figsize=figsize)
ax = fig.add_subplot(111)

# Trick to set the xticks like `compare_total_infections`
ts = np.linspace(0.0, sim.max_time, num=500, endpoint=True) / 24.0
ts = days_to_datetime(ts)
plt.plot(ts, np.zeros(len(ts)), lw=0.0)

sidx = 0

for i, result in enumerate(results):
result = result.iloc[sidx:,:]
index = result['ML'].index
values = result['ML'].values

# Plot dots and line
ax.plot(index, values, c=colors[i], zorder=1, alpha=1.0)

if draw_dots:
ax.scatter(index, values, s=40, lw=0.0,
c=cmap(color_mapped(values)),
edgecolors='k', zorder=2)

# Aesthetically, extrapolate credible interval by 1 day either side
lowfn = interp1d(date2num(index), result['Low_90'].values,
bounds_error=False, fill_value='extrapolate')
highfn = interp1d(date2num(index), result['High_90'].values,
bounds_error=False, fill_value='extrapolate')
extended = pd.date_range(start=index[0], end=index[-1])
error_low = lowfn(date2num(extended))
error_high = highfn(date2num(extended))

if fill_between:
ax.fill_between(extended, error_low, error_high,
color=colors[i], alpha=0.1, linewidth=0.0)
else:
# Ignore first value which is just prior, not informed by data
ax.errorbar(x=index[1:], y=values[1:], label=titles[i],
yerr=np.vstack((result['Low_90'], result['High_90']))[:,1:],
color=colors[i], linewidth=1.0,
elinewidth=0.8, capsize=3.0,
errorevery=errorevery)

ymax_computed = max(ymax_computed, np.max(error_high))

# Plot horizontal line at R_t = 1
ax.axhline(1.0, c='k', lw=1, alpha=.25);

# limits
ymax = ymax or 1.2 * ymax_computed
ax.set_ylim((0, ymax_computed))

if show_legend:
ax.legend(loc='upper left', borderaxespad=0.5)

# extra
if lockdown_at is not None:
lockdown_widget(lockdown_at, start_date,
lockdown_label_y, ymax,
lockdown_label, ax, zorder=-200)

# Hide the right and top spines
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)

# Only show ticks on the left and bottom spines
ax.yaxis.set_ticks_position('left')
ax.xaxis.set_ticks_position('bottom')

# Set label
ax.set_ylabel(r'$R_t$')

#set ticks every week
ax.xaxis.set_major_locator(mdates.WeekdayLocator(interval=xtick_interval))
#set major ticks format
ax.xaxis.set_major_formatter(mdates.DateFormatter('%b %d'))
fig.autofmt_xdate(bottom=0.2, rotation=0, ha='center')

subplot_adjust = subplot_adjust or {'bottom':0.14, 'top': 0.98, 'left': 0.12, 'right': 0.96}
plt.subplots_adjust(**subplot_adjust)

plt.savefig('plots/' + filename + '.png', format='png', facecolor=None,
dpi=DPI)#, bbox_inches='tight')

if NO_PLOT:
plt.close()
Loading

0 comments on commit 88328a4

Please sign in to comment.