# Make final plots

As of 8 pm on Wed 22 Apr 2020

In [None]:
# Packages

import numpy as np

from scipy import integrate, stats
from scipy.special import expit, binom

import pandas as pd
import xlrd

import copy
import warnings

from datetime import datetime
import random
import string
import os
import shutil
import sys
import cloudpickle
import dask
import distributed
from dask.distributed import Client
import itertools

import pymc3

import inspect
from collections import OrderedDict

import dask.dataframe as dd


import functools

# Plotly 
import cufflinks as cf
import plotly
from plotly.offline import iplot as plt
from plotly import graph_objs as plt_type
from plotly import graph_objs as go
# Initialise plotly's offline notebook mode so figures can be shown inside the notebook
plotly.offline.init_notebook_mode(connected=True)

# Put cufflinks into offline mode as well
cf.go_offline()


# Helper functions

Distributed dask workers are unreliable at importing non-standard modules, so the helper functions are copy-pasted into this notebook instead of being imported.

In [None]:
# To get relative age-related risks, we first have to regroup the data into our basic age groups,
# then divide by the total population (this is a not-well-defined subset of the UK, so absolute values won't be used, only relative ones)
import numpy as np
import inspect
from collections import OrderedDict

def regroup_by_age(inp, fromAgeSplits, toAgeSplits, maxAge=100., maxAgeWeight = 5.):
    """
    Re-bin values given per source age group into a different set of age groups.

    Every source bin is spread over the target bins in proportion to the number
    of years the two bins share. The only exception is the last target bin:
    whenever a source bin reaches into it, its share is replaced by the fixed
    weight maxAgeWeight before the shares are normalised to sum to one.

    inp           -- numpy array; its first axis is read as the source age groups
    fromAgeSplits -- interior boundaries of the source bins (0 and maxAge are added here)
    toAgeSplits   -- interior boundaries of the target bins (0 and maxAge are added here)
    maxAge        -- upper edge used for the last bin of both binnings
    maxAgeWeight  -- weight given to the last target bin when a source bin overlaps it

    Returns an array of shape (len(toAgeSplits)+1,) + inp.shape[1:].
    """
    # Close both binnings with 0 at the bottom and maxAge at the top
    src_edges = np.concatenate([np.array([0]), fromAgeSplits, np.array([maxAge])])
    dst_edges = np.concatenate([np.array([0]), toAgeSplits, np.array([maxAge])])

    def shared_years(dst_lo, dst_hi, src_lo, src_hi):
        # Length of the intersection of the two intervals, 0 if they are disjoint
        return max(0, min(dst_hi, src_hi) - max(dst_lo, src_lo))

    regrouped = np.zeros((len(dst_edges)-1,)+inp.shape[1:])

    for src_idx, (src_lo, src_hi) in enumerate(zip(src_edges[:-1], src_edges[1:])):
        # Years of the current source bin that fall into each target bin
        weights = [
            shared_years(dst_lo, dst_hi, src_lo, src_hi)
            for dst_lo, dst_hi in zip(dst_edges[:-1], dst_edges[1:])
        ]

        # The last target bin gets a fixed relative weight instead of its width in years
        if weights[-1] > 0:
            weights[-1] = maxAgeWeight

        weights = weights/np.sum(weights)

        for dst_idx in range(len(regrouped)):
            regrouped[dst_idx] += weights[dst_idx] * inp[src_idx]

    return regrouped


# PARAMETER DICTIONARIES AND TABLES
# -----------------------------------------------------------------------------------------


def build_paramDict(cur_func):
    """
    Collect the default arguments of cur_func into an ordered dictionary.

    Every positional-or-keyword argument that has a default is stored as
    <argument name> -> <default value>. When a default value is itself callable,
    an additional entry <argument name + "_params"> is stored right after it,
    holding the (recursively built) dictionary of that callable's own defaults.

    The result can be handed as a "kwargs" object to the top-level function,
    which is expected to pass the nested "_params" dictionaries further down.

    Returns an empty OrderedDict if cur_func has no default arguments.
    """
    spec = inspect.getfullargspec(cur_func)
    defaults = spec.defaults

    collected = OrderedDict()

    # Nothing to record for functions without defaults
    if defaults is None:
        return collected

    # Defaults always belong to the trailing arguments
    names_with_defaults = spec.args[-len(defaults):]

    for name, default in zip(names_with_defaults, defaults):
        collected[name] = default

        # Depth-first descent into defaults that are functions themselves
        if callable(default):
            collected[name+"_params"] = build_paramDict(default)

    return collected




# Do a mapping between dictionary and parameter table row (for convenient use)

# Flatten the dictionary into a table with a single row (but many headers):
def paramDict_toTable(paramDict):
    """
    Flatten a nested parameter dictionary into a one-row pandas DataFrame.

    Keys that do not end in "_params" become columns named <prefix + key>.
    Keys that end in "_params" are expected to hold a dictionary; that dictionary
    is flattened recursively with <key + "_"> appended to the prefix.
    An empty dictionary is kept as a column named after the prefix alone, holding
    an empty OrderedDict.
    """
    table = pd.DataFrame()

    def flatten(level, prefix):
        for name, entry in level.items():
            if not name.endswith("_params"):
                # Plain parameter: one column, one value
                table[prefix+name] = [entry]
                continue
            # Nested parameter dictionary: descend with an extended prefix
            flatten(entry, prefix+name+"_")

        # The loop above leaves no trace of an empty dictionary, so record it explicitly
        if len(level)==0:
            table[prefix] = [OrderedDict()]

    flatten(paramDict, "")
    return table

# Example dict -> table
# paramTable_default = paramDict_toTable(paramDict_default)
    

def paramTable_toDict(paramTable, defaultDict=None):
    """
    Rebuild a nested parameter dictionary from a one-row parameter table.

    Column names are split at every "_params_": the part up to and including
    "_params" names a nested dictionary, the remainder is placed inside it
    (recursively). Values are read from the row labelled 0.

    A column whose name ends right after "_params_" (or is empty) is the
    placeholder that paramDict_toTable writes for an empty dictionary; it is
    turned back into an empty nested dictionary instead of an entry with an
    empty-string key. An already existing nested dictionary is left untouched
    by such a placeholder.

    paramTable  -- pandas DataFrame with the flattened parameter names as columns
    defaultDict -- optional dictionary to start from (useful if paramTable is
                   incomplete); it is updated IN PLACE and returned

    Returns the (nested) dictionary.
    """
    paramDict = defaultDict if defaultDict is not None else OrderedDict()
    marker = "_params_"

    def placeArgInDictRecurse(argName, argVal, cur_dict):
        # Empty name: placeholder of an empty dictionary, nothing to store at this level
        if argName == "":
            return cur_dict

        strloc = argName.find(marker)
        if strloc == -1:
            # We're at the correct level of dictionary
            cur_dict[argName] = argVal
            return cur_dict

        # Step to the next level of dictionary; the key keeps "_params" but not the trailing "_"
        nextKey = argName[:strloc+len(marker)-1]
        nextArgName = argName[strloc+len(marker):]
        if nextKey not in cur_dict:
            cur_dict[nextKey] = OrderedDict()
        placeArgInDictRecurse(nextArgName, argVal, cur_dict[nextKey])

        return cur_dict

    for key in paramTable.columns:
        paramDict = placeArgInDictRecurse(key, paramTable.at[0,key], paramDict)

    return paramDict

# Example table -> dict 
# paramDict_new = paramTable_toDict(paramTable_default)

## Load datasets

### NHS England deaths dataset

In [None]:
# NHS daily deaths report (about 24 hours behind)
# TODO manually update link and column numbers (maybe not consistent across days, cannot yet automate)
#
# NOTE: this call used to pass `skip_rows = range(17)`, which is not a read_excel parameter
# (the real one is spelled `skiprows`). Current pandas rejects unknown keyword arguments with a
# TypeError. The argument is dropped here rather than renamed, because the `.iloc[14:]` below
# already removes the leading rows of the sheet and is what the rest of the chain relies on.
df_UK_NHS_daily_COVID_deaths = pd.read_excel(
    "https://www.england.nhs.uk/statistics/wp-content/uploads/sites/2/2020/04/COVID-19-total-announced-deaths-22-April-2020.xlsx",
    sheet_name = "COVID19 total deaths by age",
    index_col=0,
    usecols = "B,E:AX",
    nrows = 22
).iloc[14:].transpose().set_index("Age group").rename_axis(index = "Date", columns = "AgeGroup")

# After the transpose the rows are dates; parse them into a proper datetime index
df_UK_NHS_daily_COVID_deaths.index = pd.to_datetime(df_UK_NHS_daily_COVID_deaths.index, format="%Y-%m-%d")

# Drop the first two columns, keeping only the remaining age-group columns
df_UK_NHS_daily_COVID_deaths = df_UK_NHS_daily_COVID_deaths.drop(df_UK_NHS_daily_COVID_deaths.columns[:2], axis=1)

# Ignore very recent unreliable data points
df_UK_NHS_daily_COVID_deaths = df_UK_NHS_daily_COVID_deaths.loc[
            df_UK_NHS_daily_COVID_deaths.index <= '2020-04-14']

df_UK_NHS_daily_COVID_deaths

### NHS England CHESS - COVID hospitalisations - dataset

In [None]:
# Load the aggregate data (ask @SebastianVollmer for aggregation details!)
# and discard the row labelled 0 straight away
df_CHESS = pd.read_csv("/mnt/efs/data/CHESS_Aggregate20200417.csv")
df_CHESS = df_CHESS.drop(0)

# Clean the dates and make them the index. Notes on the raw dates:
#  - "1899-12-30" is simply the total, ignore that.
#  - 2020-09-03, 2020-10-03, 2020-11-03, 2020-12-03 are parsed wrong and are supposedly March 09-12.
#  - Data collection only officially started across England on 09 March; the February dates seem empty, delete.
#  - The rest are ok.
df_CHESS.index = pd.to_datetime(df_CHESS["DateOfAdmission"].values,format="%d-%m-%Y")

# Order by date; the date column is redundant now that it is the index
df_CHESS = df_CHESS.sort_index()
df_CHESS = df_CHESS.drop("DateOfAdmission", axis=1)

# Ignore too old and too recent data points
# NOTE(review): relies on query() treating the bare numbers as dates when they are compared with the datetime index - confirm
df_CHESS = df_CHESS.query('20200309 <= index <= 20200414')

df_CHESS


In [None]:
# Show the column names of the CHESS table (columns are selected by name prefix below)
df_CHESS.columns

In [None]:
# Get the hospitalised people who tested positive for COVID, using cumsum (TODO: for now assuming they're all still in hospital)
# Keep only the columns whose name starts with the "new lab confirmed" prefix
df_CHESS_newCOVID = df_CHESS.loc[
    :,
    [col.startswith("AllAdmittedPatientsWithNewLabConfirmedCOVID19") for col in df_CHESS.columns]
]

# Sum across the selected columns per day, then accumulate over the days
df_CHESS_newCOVID.sum(axis=1).cumsum(axis=0)

## Define dummy likelihood that helps plotting

In [None]:
def joinDataAndSimCurves(
    df_curves, # a pandas dataframe with dates as index, and each column is a curve
    simCurves, # curves x time np array
    simStartDate, # curves start dates
    simCurvesNames = None,
    fulljoin = False # if true, then one keeps all dates in the simulation, instead of just the ones in the date 
    ):
    """
    Align simulated curves with data curves on a common daily date index.

    Each row of simCurves is turned into one column on a daily date range that
    starts at simStartDate, and is joined onto a deep copy of df_curves
    (df_curves itself is not modified).

    simCurvesNames -- column names for the simulated curves; defaults to
                      "simCurve_0", "simCurve_1", ...
    fulljoin       -- False: keep only the dates present in df_curves (left join);
                      True: keep the dates of both (outer join)

    Returns the joined DataFrame.
    """
    # Daily date axis shared by all simulated curves
    simDates = pd.date_range(start=simStartDate, periods=simCurves.shape[1], freq='D')

    if simCurvesNames is None:
        simCurvesNames = ["simCurve_{}".format(k) for k in range(simCurves.shape[0])]

    how = "left"
    if fulljoin:
        how = "outer"

    joined = copy.deepcopy(df_curves)

    # Attach the curves one by one, each as a single-column frame
    for k, curve in enumerate(simCurves):
        curveFrame = pd.DataFrame(
            data = curve,
            index = simDates,
            columns = [simCurvesNames[k]]
        )
        joined = joined.join(curveFrame, how = how)

    return joined

In [None]:
# Dummy lik function to plot curves without data, get correct projection via functools.partial(..., projFunc = ...)
def likFunc_dummyForPlotting(
    sim, 
    simStartDate, 
    df = None,
    sumAges = True,
    outputDataframe = False, # If true, outputs the data-curves and simulated curves instead of likelihood, for plotting
    fulljoin = False, # if true, then one keeps all dates in the simulation, instead of just the ones in the date
    # pass a lambda function to create a simCurve, this example does daily new deaths.
    # FIX: the difference is taken along the last axis (axis=-1); the previous default passed -1 as the
    # *order* of np.diff, which raises a ValueError as soon as the default is used.
    projFunc = lambda sim: np.diff(np.sum(sim[:,-1,:,:,:],axis=(1,2)),axis=-1)
    ):
    """
    Project a simulation onto plottable curves and align them in time with (optional) data.

    This function has the call signature of a likelihood function but never
    computes a likelihood; it only exists to produce the joined data frame.

    sim             -- simulation output; it is only touched through projFunc
                       (the default projFunc indexes it with five axes)
    simStartDate    -- date of the first time point of the simulated curves
    df              -- data curves with dates as index, or None for "no data"
                       (then an empty frame covering the simulated dates is used)
    sumAges         -- if True, the projected curves are summed over their first axis
                       and the data columns are summed per date
    outputDataframe -- if True, return the joined frame; otherwise return None
    fulljoin        -- passed on to joinDataAndSimCurves
    projFunc        -- maps sim to a (curves x time) array

    Returns the joined pandas DataFrame if outputDataframe is True, else None.
    """
    
    simCurves = projFunc(sim)
    if sumAges:
        # Collapse the curves into a single one, keeping the array two-dimensional
        simCurves = simCurves.sum(0, keepdims=True)
    
    if df is None:
        df = pd.DataFrame(index= pd.date_range(start=simStartDate, freq='D', periods=simCurves.shape[1])) # an empty pandas dataframe with dates as index
    else:
        if sumAges:
            # Collapse the data columns in the same way as the simulated curves
            df = pd.DataFrame(df.sum(1))
            
    # Join the two dataframes to align in time
    df_full = joinDataAndSimCurves(
        df_curves = df,
        simCurves = simCurves, # curves x time np array
        simStartDate = simStartDate, # curves start dates
        fulljoin = fulljoin
    )
    
    # If true, outputs the data-curves and simulated curves instead of likelihood, for plotting
    if outputDataframe:
        return df_full
    
    return None
    
    

## Helper funcs for organising data and simulations for plotting

In [None]:
def tryLoad(
    fn,
    shape = (9,8,4,3,100,),
    fill = 0.
    ):
    """
    Load a numpy array from fn, falling back to a constant array if that fails.

    fn    -- file to load with np.load
    shape -- shape of the fallback array
    fill  -- value of every entry of the fallback array

    Returns the loaded array, or fill*np.ones(shape) if loading raised an
    ordinary exception (missing file, unreadable content, ...). In that case a
    warning naming the file is issued, so that placeholder curves do not end up
    in a plot unnoticed. KeyboardInterrupt and SystemExit are not intercepted.
    """
    try:
        return np.load(fn)
    except Exception as err:
        import warnings # local import keeps this helper self-contained when copy-pasted
        warnings.warn(
            "tryLoad: could not load {!r} ({}: {}); using a constant array instead".format(
                fn, type(err).__name__, err
            )
        )
        return fill*np.ones(shape)

In [None]:
def joinSimulations(
    loadDir,
    paramTable,
    simRowIndices, # from paramTables filled with bestStartTime column
    likFunc,
    simSuffix,
    df, # data
    paramTable_startTimeColname = "bestStartTime",
    sumAges = True
    ):
    """
    Load several simulations and collect their projected curves in one data frame.

    For every row label in simRowIndices the simulation file
    <loadDir + out_fname without its last 4 characters + simSuffix + ".npy">
    is loaded with tryLoad and handed to likFunc (called with
    outputDataframe=True and fulljoin=True), which returns data columns plus
    columns named "simCurve_*".

    The data columns are taken from the first simulation's frame; the
    "simCurve_*" columns of every simulation are renamed to
    "row_<row label>_simCurve_*" and outer-joined onto them.

    Returns the combined DataFrame.
    """
    simPrefix = "simCurve_"

    # One frame per simulation, holding the projection(s) that match the data
    perSimFrames = []
    for rowLabel in simRowIndices:
        simFile = loadDir + paramTable.at[rowLabel,"out_fname"][:-4] + simSuffix + ".npy"
        perSimFrames.append(
            likFunc(
                sim = tryLoad(simFile),
                simStartDate = paramTable.at[rowLabel, paramTable_startTimeColname],
                df = df,
                sumAges = sumAges,
                outputDataframe = True,
                fulljoin = True
            )
        )

    # Split the first frame into its data part and the names of its simulated curves
    firstFrame = perSimFrames[0]
    firstIsSim = firstFrame.columns.str.startswith(simPrefix)==True
    dataPart = firstFrame.loc[:, ~firstIsSim]
    simNames = firstFrame.columns[firstIsSim]

    # Keep only the simulated columns of each frame, tagged with the row they came from
    renamedSimParts = []
    for frame, rowLabel in zip(perSimFrames, simRowIndices):
        simPart = frame.loc[:, frame.columns.str.startswith(simPrefix)==True]
        renamedSimParts.append(
            simPart.rename(
                columns = {name: "row_" + str(rowLabel) + "_" + name for name in simNames}
            )
        )

    return dataPart.join(other = renamedSimParts, how = "outer")
    
    
    

In [None]:
# Load the parameter-type lookup; its "policy" entry is used below as the default set of columns to compare
# NOTE: unpickling can execute arbitrary code - only load a paramTypes.cpkl that you trust
with open('paramTypes.cpkl', 'rb') as fh:
    paramTypes = cloudpickle.load(fh)

In [None]:
# Find a given row in a paramTable
def allEqual(row1, row2, policyColumns = paramTypes["policy"]):
    """
    Check whether two parameter rows agree on all policy columns.

    row1, row2    -- pandas Series (rows of a parameter table)
    policyColumns -- container of column names to compare; entries of row1
                     that are not in it are ignored

    Values are compared with != first. If that comparison has no single truth
    value (e.g. array-valued cells), np.array_equal is used instead. If even
    that fails, the offending column is printed and treated as "not different".

    Returns False at the first differing policy column, True otherwise.
    """
    for c in row1.index:
        if c not in policyColumns:
            continue

        left, right = row1.at[c], row2.at[c]

        try:
            differs = bool(left != right)
        except (ValueError, TypeError):
            # No scalar truth value (array-like cells): fall back to a whole-array comparison
            try:
                differs = not np.array_equal(left, right)
            except Exception:
                # Could not compare at all: report the column and carry on with the next one
                print(c, row1[c], row2[c])
                continue

        if differs:
            return False

    return True

def findRow(row_in, df, policyColumns = paramTypes["policy"]):
    """
    Return the index label of the first row of df that matches row_in on all
    policy columns (as decided by allEqual), or -1 if no row matches.
    """
    matchingLabels = (
        label
        for label, candidate in df.iterrows()
        if allEqual(row_in, candidate, policyColumns= policyColumns)
    )
    # The generator is lazy, so the search stops at the first match
    return next(matchingLabels, -1)

In [None]:
def defaultLayout(scale=1.):
    """
    Return the shared font settings for the figures as a dictionary with the
    keys "font" and "titlefont". Both use Times New Roman; the sizes are 14 and
    18 multiplied by scale and truncated to an integer.
    """
    family = "Times New Roman"
    return {
        "font": {"family": family, "size": int(14*scale)},
        "titlefont": {"family": family, "size": int(18*scale)},
    }

In [None]:
def plot(Y, X=None, now = True, **kwargs):
    """
    Build one scatter trace per column of Y and either show or return them.

    Y      -- array (n x num_lines); a one-dimensional Y is treated as a single line
    X      -- optional x values: one-dimensional (shared by all lines) or
              two-dimensional (one column per line); defaults to 0..n-1
    now    -- if True the traces are plotted immediately and nothing is returned,
              otherwise the list of traces is returned
    kwargs -- passed on unchanged to every Scatter trace
    """
    # Promote a single line to a one-column matrix
    if len(Y.shape)==1:
        Y = Y[:,np.newaxis]

    if X is None:
        X = np.arange(Y.shape[0])

    traces = [
        plt_type.Scatter(
            x = X if len(X.shape)!=2 else X[:,col],
            y = Y[:,col],
            **kwargs
        )
        for col in range(Y.shape[1])
    ]

    if not now:
        return traces

    plt(traces)

## Load simulations and evaluate

In [None]:
# Directory of the simulation run whose results are plotted below
loadDir = "/mnt/efs/results/run_20200422T001618/" # <- NEWGP+SMALL policy ensemble


In [None]:
# Get HDF keys: open the parameter-table file read-only and print which tables it contains
with pd.HDFStore(loadDir + 'paramTable_part0', 'r') as tmp:
    print(tmp.keys())

In [None]:
# Read the three policy tables from the same HDF file, one per key
paramTable_stopSocial, paramTable_keepSocial, paramTable_caseIsolation = (
    pd.read_hdf(loadDir + 'paramTable_part0', key=tableKey)
    for tableKey in (
        "Stop Social Distancing",
        "Keep Social Distancing",
        "Case Isolation Varying Test Numbers",
    )
)

# Set ensemble size and compute number of policies per table
# (number of rows divided by the ensemble size, rounded down)
nBaseEnsemble = 91
nPolicies_stopSocial = len(paramTable_stopSocial) // nBaseEnsemble
nPolicies_keepSocial = len(paramTable_keepSocial) // nBaseEnsemble
nPolicies_caseIsolation = len(paramTable_caseIsolation) // nBaseEnsemble


In [None]:
# Display the number of policies in the case-isolation table
nPolicies_caseIsolation

# FINAL PLOTS FOR GITHUB

## Figure 1 - Daily new infections for 4 "policies"

* a - keep social
* b - stop social 30 June
* c - stop social 30 June + 1 million "good tests" per day
* d - stop social 30 June + 5 million "terrible tests" per day

* e - comparing economic costs

In [None]:
# Figure 1a - Policy A: the lockdown is kept indefinitely

# Take every nPolicies_keepSocial-th row of the table, starting with the first one
indices = np.arange(0,len(paramTable_keepSocial),nPolicies_keepSocial)

# Daily new infections predicted by model, one column per selected simulation
tmp = joinSimulations(
    loadDir=loadDir,
    paramTable=paramTable_keepSocial,
    simRowIndices=paramTable_keepSocial.index[indices],
    likFunc=functools.partial(
        likFunc_dummyForPlotting,
        # select entry 1 of the second axis, then sum over axes 1 and 2 of what remains
        projFunc=(lambda sim: np.sum(sim[:,1,:,:,:],axis=(1,2)))
    ),
    simSuffix="_newOnly",
    df=None,
    paramTable_startTimeColname="realStartDate",
    sumAges=True
)

a = plot(tmp.values, tmp.index, now=False)

# Style the curves: faint blue lines without individual legend entries
for trace in a:
    trace["opacity"] = 0.05 #1./np.sqrt(len(a)/4.)
    trace["legendgroup"] = "Simulations"
    trace["showlegend"] = False
    trace["line"] = dict(color="blue")

layoutArgs = defaultLayout(scale=1.4)

layout = go.Layout(
    # Add basic info
    title="Policy A - Keeping lockdown indefinitely",
    xaxis_title="Date",
    yaxis_title="New infections per day",
    yaxis={"range": [0, 2e6], "autorange": False},

    # Add arrow(s)
    annotations=[
        {
            "text": 'Start of lockdown',
            "x": pd.to_datetime("2020-03-23"),
            "y": 0.7e6,
            "xref": 'x1',
            "yref": 'y1',
            "font": {"color": 'black'},
            "showarrow": True,
            "arrowhead": 2,
            "arrowcolor": 'black',
            "ax": 80,
            "ay": -140,
            "xshift": 5,
        }
    ],

    # Shared fonts
    **layoutArgs
)

fig = go.Figure(data=a, layout=layout)

plt(fig)

# Save as interactive html
plotly.offline.plot(fig, filename="images/figure1a.html")

# Save as png
plotly.offline.plot(
    fig,
    image='png',
    image_filename='figure1a',
    auto_open=False,
    image_width=1100,
    image_height=500,
    validate=False
)


In [None]:
# Figure 1b - Policy B: the lockdown is stopped on 30 June

# Template of the policy we are looking for: row 0 with the stop date replaced
searchRow = copy.deepcopy(paramTable_stopSocial.loc[0])
searchRow["tStopSocialDistancing"] =  pd.to_datetime("2020-06-30", format="%Y-%m-%d")

# findRow returns -1 if no row matches; np.arange(-1, ...) would then silently select the wrong rows
# NOTE(review): the returned index label is used as a position, which assumes a default 0..n-1 index - confirm
firstRow = findRow(searchRow, paramTable_stopSocial)
if firstRow == -1:
    raise ValueError("Figure 1b: no row of paramTable_stopSocial matches the requested policy")

# From the first matching row onwards, take every nPolicies_stopSocial-th row
indices = np.arange(firstRow,len(paramTable_stopSocial),nPolicies_stopSocial)

# Daily new infections predicted by model, one column per selected simulation
tmp = joinSimulations(
    loadDir = loadDir,
    paramTable=paramTable_stopSocial,
    simRowIndices = paramTable_stopSocial.index[indices],
    paramTable_startTimeColname = "realStartDate",
    likFunc = functools.partial(
        likFunc_dummyForPlotting, 
        projFunc=(lambda sim: np.sum(sim[:,1,:,:,:],axis=(1,2)))),
    simSuffix="_newOnly",
    df = None,
    sumAges=True
)


a = plot(tmp.values, tmp.index, now=False)

# Style the curves: faint red lines without individual legend entries
for trace in a:
    trace["opacity"] = 0.05 #1./np.sqrt(len(a)/4.)
    trace["legendgroup"] = "Simulations"
    trace["showlegend"] = False
    trace["line"] = dict(color="red")


layoutArgs = defaultLayout(scale=1.4)    

layout = go.Layout(
    **layoutArgs,
    
    # Add basic info
    title="Policy B - Stop lockdown on 30 June",
    xaxis_title="Date",
    yaxis_title="New infections per day",
    yaxis=dict(range=[0, 2e6], autorange=False),
    
    # Add arrow(s): all share the same styling and differ only in the values listed below
    annotations=[
        dict(
            text=text,
            x=pd.to_datetime(date),
            y=y,
            xref='x1',
            yref='y1',
            font=dict(color='black'),
            showarrow=True,
            arrowhead=2,
            arrowcolor='black',
            ax=ax,
            ay=ay,
            xshift=xshift,
        )
        for text, date, y, ax, ay, xshift in [
            ('Start of lockdown', "2020-03-23", 0.7e6, 80, -140, 5),
            ('Stop lockdown', "2020-06-30", 0.2e6, -80, -120, -5),
            ('Unmitigated second wave', "2020-08-23", 1.6e6, 140, 40, 5),
            ('Herd immunity reached', "2020-09-30", 0.2e6, 160, -40, 5),
        ]
    ]
)

    
fig = go.Figure(
        data=a, 
        layout=layout)
    
plt(fig)

    
# Save as interactive html
plotly.offline.plot(fig, filename="images/figure1b.html")

# Save as png
plotly.offline.plot(fig, 
                    image='png', 
                    image_filename='figure1b', 
                    auto_open=False, 
                    image_width=1100, 
                    image_height=500, 
                    validate=False)


In [None]:
# Figure 1c - Policy C: the lockdown is stopped on 30 June, 1 million "good" tests per day

# Template of the policy we are looking for: row 0 with the settings below replaced.
# The antigen false-negative profile ("..._antigen_FNR_I1_to_R2") is deliberately left at the value copied from row 0.
searchRow = copy.deepcopy(paramTable_caseIsolation.loc[0])
searchRow["tStopSocialDistancing"] =  pd.to_datetime("2020-06-30", format="%Y-%m-%d")
searchRow["trFunc_testing_params_inpFunc_testSpecifications_params_antigen_FPR"] = 0.05
searchRow["trFunc_testing_params_trFunc_testCapacity_params_testCapacity_antibody_country_total"] = 1e6

# Only rows up to label 500 are searched.
# findRow returns -1 if no row matches; np.arange(-1, ...) would then silently select the wrong rows
# NOTE(review): the returned index label is used as a position, which assumes a default 0..n-1 index - confirm
firstRow = findRow(searchRow, paramTable_caseIsolation.loc[:500])
if firstRow == -1:
    raise ValueError("Figure 1c: no row of paramTable_caseIsolation matches the requested policy")

# From the first matching row onwards, take every nPolicies_caseIsolation-th row
indices = np.arange(firstRow,len(paramTable_caseIsolation),nPolicies_caseIsolation)

# Daily new infections predicted by model, one column per selected simulation
tmp = joinSimulations(
    loadDir = loadDir,
    paramTable=paramTable_caseIsolation,
    simRowIndices = paramTable_caseIsolation.index[indices],
    paramTable_startTimeColname = "realStartDate",
    likFunc = functools.partial(
        likFunc_dummyForPlotting,
        projFunc=(lambda sim: np.sum(sim[:,1,:,:,:],axis=(1,2)))),    
    simSuffix="_newOnly",
    df = None,
    sumAges=True
)

a = plot(tmp.values, tmp.index, now=False)

# Style the curves: faint purple lines without individual legend entries
for trace in a:
    trace["opacity"] = 0.05 #1./np.sqrt(len(a)/4.)
    trace["legendgroup"] = "Simulations"
    trace["showlegend"] = False
    trace["line"] = dict(color="purple")


layoutArgs = defaultLayout(scale=1.4)    

layout = go.Layout(
    **layoutArgs,
    
    # Add basic info
    title="Policy C - Stop lockdown on 30 June, flatten 2nd wave via 1 million '90% sensitive' tests each day",
    xaxis_title="Date",
    yaxis_title="New infections per day",
    yaxis=dict(range=[0, 2e6], autorange=False),
    
    # Add arrow(s): all share the same styling and differ only in the values listed below
    annotations=[
        dict(
            text=text,
            x=pd.to_datetime(date),
            y=y,
            xref='x1',
            yref='y1',
            font=dict(color='black'),
            showarrow=True,
            arrowhead=2,
            arrowcolor='black',
            ax=ax,
            ay=ay,
            xshift=xshift,
        )
        for text, date, y, ax, ay, xshift in [
            ('Start of lockdown', "2020-03-23", 0.7e6, 80, -140, 5),
            ('Stop lockdown', "2020-06-30", 0.2e6, -80, -120, -5),
            ('Flattened second wave', "2020-09-08", 0.65e6, 40, -100, 5),
            ('Herd immunity still reached!', "2020-11-11", 0.22e6, 20, -80, 5),
        ]
    ]
)

    
fig = go.Figure(
        data=a, 
        layout=layout)
    
plt(fig)

# Save as interactive html
plotly.offline.plot(fig, filename="images/figure1c.html")

# Save as png
plotly.offline.plot(fig, 
                    image='png', 
                    image_filename='figure1c', 
                    auto_open=False, 
                    image_width=1100, 
                    image_height=500, 
                    validate=False)

In [None]:
# Figure 1d - Policy D: the lockdown is stopped on 30 June, 5 million "terrible" tests per day

# Template of the policy we are looking for: row 0 with the settings below replaced
searchRow = copy.deepcopy(paramTable_caseIsolation.loc[0])
searchRow["tStopSocialDistancing"] =  pd.to_datetime("2020-06-30", format="%Y-%m-%d")
searchRow["trFunc_testing_params_inpFunc_testSpecifications_params_antigen_FNR_I1_to_R2"] = [0.99, 0.8, 0.4, 0.5, 0.6, 0.99]
searchRow["trFunc_testing_params_inpFunc_testSpecifications_params_antigen_FPR"] = 0.05
searchRow["trFunc_testing_params_trFunc_testCapacity_params_testCapacity_antibody_country_total"] = 5e6

# Only rows up to label 500 are searched.
# findRow returns -1 if no row matches; np.arange(-1, ...) would then silently select the wrong rows
# NOTE(review): the returned index label is used as a position, which assumes a default 0..n-1 index - confirm
firstRow = findRow(searchRow, paramTable_caseIsolation.loc[:500])
if firstRow == -1:
    raise ValueError("Figure 1d: no row of paramTable_caseIsolation matches the requested policy")

# From the first matching row onwards, take every nPolicies_caseIsolation-th row
indices = np.arange(firstRow,len(paramTable_caseIsolation),nPolicies_caseIsolation)

# Daily new infections predicted by model, one column per selected simulation
tmp = joinSimulations(
    loadDir = loadDir,
    paramTable=paramTable_caseIsolation,
    simRowIndices = paramTable_caseIsolation.index[indices],
    paramTable_startTimeColname = "realStartDate",
    likFunc = functools.partial(
        likFunc_dummyForPlotting,
        projFunc=(lambda sim: np.sum(sim[:,1,:,:,:],axis=(1,2)))),    
    simSuffix="_newOnly",
    df = None,
    sumAges=True
)

a = plot(tmp.values, tmp.index, now=False)

# Style the curves: faint green lines without individual legend entries
for trace in a:
    trace["opacity"] = 0.05 #1./np.sqrt(len(a)/4.)
    trace["legendgroup"] = "Simulations"
    trace["showlegend"] = False
    trace["line"] = dict(color="green")


layoutArgs = defaultLayout(scale=1.4)    

layout = go.Layout(
    **layoutArgs,
    
    # Add basic info
    title="Policy D - Stop lockdown on 30 June, suppress 2nd wave via 5 million '60% sensitive' tests each day",
    xaxis_title="Date",
    yaxis_title="New infections per day",
    yaxis=dict(range=[0, 2e6], autorange=False),
    
    # Add arrow(s): all share the same styling and differ only in the values listed below
    annotations=[
        dict(
            text=text,
            x=pd.to_datetime(date),
            y=y,
            xref='x1',
            yref='y1',
            font=dict(color='black'),
            showarrow=True,
            arrowhead=2,
            arrowcolor='black',
            ax=ax,
            ay=ay,
            xshift=xshift,
        )
        for text, date, y, ax, ay, xshift in [
            ('Start of lockdown', "2020-03-23", 0.7e6, 80, -140, 5),
            ('Stop lockdown', "2020-06-30", 0.2e6, -80, -120, -5),
            ('Suppressed second wave', "2020-08-30", 0.15e6, 20, -160, 5),
            ('No herd immunity needed,<br>only testing based case isolation!', "2020-11-11", 0.04e6, 40, -80, 5),
        ]
    ]
)

    
fig = go.Figure(
        data=a, 
        layout=layout)
    
plt(fig)

# Save as interactive html
plotly.offline.plot(fig, filename="images/figure1d.html")

# Save as png
plotly.offline.plot(fig, 
                    image='png', 
                    image_filename='figure1d', 
                    auto_open=False, 
                    image_width=1100, 
                    image_height=500, 
                    validate=False)


## Comparing people in home quarantine

In [None]:
# GOOD TESTS
# ---------------
# Build a search row from the first case-isolation row and override the
# policy parameters that identify this scenario. The antigen FNR profile is
# NOT overridden here, so it keeps whatever value row 0 holds.
# NOTE(review): the capacity key below is the *antibody* one, while the error
# rates set in this cell are *antigen* ones - confirm this is intended.
searchRow = copy.deepcopy(paramTable_caseIsolation.loc[0])
searchRow["tStopSocialDistancing"] = pd.to_datetime("2020-06-30", format="%Y-%m-%d")
searchRow["trFunc_testing_params_inpFunc_testSpecifications_params_antigen_FPR"] = 0.05
searchRow["trFunc_testing_params_trFunc_testCapacity_params_testCapacity_antibody_country_total"] = 1e6

# First matching row (searched within .loc[:500]), then every
# nPolicies_caseIsolation-th row after it
indices = np.arange(
    findRow(searchRow, paramTable_caseIsolation.loc[:500]),
    len(paramTable_caseIsolation),
    nPolicies_caseIsolation,
)

# Simulated curves for this scenario; no observed data is joined (df=None)
tmp = joinSimulations(
    loadDir=loadDir,
    paramTable=paramTable_caseIsolation,
    simRowIndices=paramTable_caseIsolation.index[indices],
    paramTable_startTimeColname="realStartDate",
    likFunc=functools.partial(
        likFunc_dummyForPlotting,
        # Index 1 on the third axis - presumably the home-quarantine state,
        # going by the y-axis title below; TODO confirm - summed over axes (1, 2)
        projFunc=lambda sim: np.sum(sim[:, :, 1, :, :], axis=(1, 2)),
    ),
    simSuffix="",
    df=None,
    sumAges=True,
)

a = plot(tmp.values, tmp.index, now=False)

# Every trace is a simulation here: draw them faint (fixed opacity rather than
# the previously tried 1/sqrt(len(a)/4)), purple, and hidden from the legend
for trace in a:
    trace["line"] = dict(color="purple")
    trace["opacity"] = 0.05
    trace["legendgroup"] = "1 million '90% sensitive' tests a day"
    trace["showlegend"] = False

a_goodtests = copy.deepcopy(a)


# BAD TESTS
# -------------
# Same search as above, but with an explicit antigen FNR profile and a
# capacity of 5e6 instead of 1e6
searchRow = copy.deepcopy(paramTable_caseIsolation.loc[0])
searchRow["tStopSocialDistancing"] = pd.to_datetime("2020-06-30", format="%Y-%m-%d")
searchRow["trFunc_testing_params_inpFunc_testSpecifications_params_antigen_FNR_I1_to_R2"] = [0.99, 0.8, 0.4, 0.5, 0.6, 0.99]
searchRow["trFunc_testing_params_inpFunc_testSpecifications_params_antigen_FPR"] = 0.05
searchRow["trFunc_testing_params_trFunc_testCapacity_params_testCapacity_antibody_country_total"] = 5e6

indices = np.arange(
    findRow(searchRow, paramTable_caseIsolation.loc[:500]),
    len(paramTable_caseIsolation),
    nPolicies_caseIsolation,
)

# Same projection as for the good-test scenario (not deaths)
tmp = joinSimulations(
    loadDir=loadDir,
    paramTable=paramTable_caseIsolation,
    simRowIndices=paramTable_caseIsolation.index[indices],
    paramTable_startTimeColname="realStartDate",
    likFunc=functools.partial(
        likFunc_dummyForPlotting,
        projFunc=lambda sim: np.sum(sim[:, :, 1, :, :], axis=(1, 2)),
    ),
    simSuffix="",
    df=None,
    sumAges=True,
)

a = plot(tmp.values, tmp.index, now=False)

# Same faint styling as above, in green
for trace in a:
    trace["line"] = dict(color="green")
    trace["opacity"] = 0.05
    trace["legendgroup"] = "5 million '60% sensitive' tests a day"
    trace["showlegend"] = False

a_badtests = copy.deepcopy(a)


# Add one more "curve": a dashed horizontal reference line at the total
# England population
a_englandpop = [
    go.Scatter(
        name="England population",
        x=np.array([pd.to_datetime("2020-01-30"), pd.to_datetime("2021-02-20")]),
        y=np.array([55.98e6, 55.98e6]),
        mode="lines",
        line={"color": "black", "dash": "dash"},
        showlegend=False,
    )
]


# Layout: shared defaults from `layoutArgs` (set in an earlier cell) plus
# titles and hand-placed annotations. No explicit y-axis range is set here.
layout = go.Layout(
    **layoutArgs,

    # Add basic info
    title="The economic cost for policies C and D",
    xaxis_title="Date",
    yaxis_title="People currently in home quarantine",

    annotations=[
        # Plain label next to the population reference line
        {
            "text": 'England population',
            "x": pd.to_datetime("2020-04-30"),
            "y": 52e6,
            "xref": 'x1',
            "yref": 'y1',
            "font": {"color": 'black'},
            "showarrow": False,
        },
        # Arrow pointing at the purple (good tests) curves
        {
            "text": "1 million '90% sensitive' tests a day",
            "x": pd.to_datetime("2020-09-20"),
            "y": 4.1e6,
            "xref": 'x1',
            "yref": 'y1',
            "font": {"color": 'black'},
            "showarrow": True,
            "arrowhead": 2,
            "arrowcolor": 'purple',
            "arrowwidth": 3,
            "ax": 60,
            "ay": -110,
            "xshift": 5,
        },
        # Arrow pointing at the green (bad tests) curves
        {
            "text": "5 million '60% sensitive' tests a day",
            "x": pd.to_datetime("2020-06-30"),
            "y": 6.66e6,
            "xref": 'x1',
            "yref": 'y1',
            "font": {"color": 'black'},
            "showarrow": True,
            "arrowhead": 2,
            "arrowcolor": 'green',
            "arrowwidth": 3,
            "ax": -120,
            "ay": -40,
            "xshift": -5,
        },
    ],
)


# Good-test curves first, then bad-test curves, then the population line
fig = go.Figure(
    data=[*a_goodtests, *a_badtests, *a_englandpop],
    layout=layout,
)

plt(fig)


# Save as interactive html
plotly.offline.plot(fig, filename="images/figure1e.html")

# Save as png
# NOTE(review): plotly.offline.plot's `image=` option downloads the image from
# the browser once the generated page is opened; with auto_open=False the page
# is only written to disk - confirm the PNG is actually produced.
plotly.offline.plot(
    fig,
    image="png",
    image_filename="figure1e",
    image_width=1100,
    image_height=500,
    auto_open=False,
    validate=False,
)


# Figure 2 - Deaths in-hospital

In [None]:
# Figure 2a
# Daily in-hospital deaths under policy A: simulations against the NHS data
# Every nPolicies_keepSocial-th row of the parameter table, starting at row 0
indices = np.arange(0, len(paramTable_keepSocial), nPolicies_keepSocial)

# Simulated curves joined with the observed daily deaths
df_deaths_a = joinSimulations(
    loadDir=loadDir,
    paramTable=paramTable_keepSocial,
    simRowIndices=paramTable_keepSocial.index[indices],
    paramTable_startTimeColname="realStartDate",
    likFunc=functools.partial(
        likFunc_dummyForPlotting,
        # Last entry along axis 1, then summed over axes (1, 2) of the result
        projFunc=lambda sim: np.sum(sim[:, -1, :, :, :], axis=(1, 2)),
    ),
    simSuffix="_newOnly",
    df=df_UK_NHS_daily_COVID_deaths,
    sumAges=True,
)

# Work on a copy so that df_deaths_a stays untouched for the histogram cell
tmp = copy.deepcopy(df_deaths_a)

a = plot(tmp.values, tmp.index, now=False)

# Style the curves: traces 1.. are treated as simulations (faint, blue)
for trace in a[1:]:
    trace["line"] = dict(color="blue")
    trace["opacity"] = 0.05
    trace["legendgroup"] = "Simulations"
    trace["showlegend"] = False

# Only the first simulation trace represents the whole group in the legend
a[1]["name"] = "Simulations"
a[1]["opacity"] = 0.15  # To make sure it shows clearly in legend
a[1]["showlegend"] = True

# Trace 0 is styled as the observed data
a[0]["name"] = "Data"
a[0]["legendgroup"] = "Data"
a[0]["mode"] = "lines+markers"
a[0]["line"] = dict(color="black", width=2)
a[0]["opacity"] = 1.
a[0]["showlegend"] = True

layoutArgs = defaultLayout(scale=1.4)

layout = go.Layout(
    **layoutArgs,

    # Add basic info
    title="Policy A - Keeping lockdown indefinitely",
    xaxis_title="Date",
    yaxis_title="In hospital deaths per day",
)

# The data trace goes last in the list, after all simulation traces
fig = go.Figure(data=[*a[1:], a[0]], layout=layout)

plt(fig)


In [None]:
# Figure 2b
# Daily in-hospital deaths under policy B: simulations against the NHS data

# Search row: first stop-social row, with the lockdown end date set to 30 June
searchRow = copy.deepcopy(paramTable_stopSocial.loc[0])
searchRow["tStopSocialDistancing"] = pd.to_datetime("2020-06-30", format="%Y-%m-%d")

# First matching row, then every nPolicies_stopSocial-th row after it
indices = np.arange(
    findRow(searchRow, paramTable_stopSocial),
    len(paramTable_stopSocial),
    nPolicies_stopSocial,
)

# Simulated curves joined with the observed daily deaths
df_deaths_b = joinSimulations(
    loadDir=loadDir,
    paramTable=paramTable_stopSocial,
    simRowIndices=paramTable_stopSocial.index[indices],
    paramTable_startTimeColname="realStartDate",
    likFunc=functools.partial(
        likFunc_dummyForPlotting,
        # Last entry along axis 1, then summed over axes (1, 2) of the result
        projFunc=lambda sim: np.sum(sim[:, -1, :, :, :], axis=(1, 2)),
    ),
    simSuffix="_newOnly",
    df=df_UK_NHS_daily_COVID_deaths,
    sumAges=True,
)

# Work on a copy so that df_deaths_b stays untouched for the histogram cell
tmp = copy.deepcopy(df_deaths_b)

a = plot(tmp.values, tmp.index, now=False)

# Style the curves: traces 1.. are treated as simulations (faint, red)
for trace in a[1:]:
    trace["line"] = dict(color="red")
    trace["opacity"] = 0.05
    trace["legendgroup"] = "Simulations"
    trace["showlegend"] = False

# Only the first simulation trace represents the whole group in the legend
a[1]["name"] = "Simulations"
a[1]["opacity"] = 0.15  # To make sure it shows clearly in legend
a[1]["showlegend"] = True

# Trace 0 is styled as the observed data
a[0]["name"] = "Data"
a[0]["legendgroup"] = "Data"
a[0]["mode"] = "lines+markers"
a[0]["line"] = dict(color="black", width=2)
a[0]["opacity"] = 1.
a[0]["showlegend"] = True

layoutArgs = defaultLayout(scale=1.4)

layout = go.Layout(
    **layoutArgs,

    # Add basic info
    title="Policy B - Stop lockdown on 30 June",
    xaxis_title="Date",
    yaxis_title="Daily in-hospital deaths",
)

# Traces are passed in their original order (data trace first)
fig = go.Figure(data=a, layout=layout)

plt(fig)


In [None]:
# Figure 2c
# Daily in-hospital deaths under policy C: simulations against the NHS data

# Search row: first case-isolation row with the scenario parameters overridden.
# The antigen FNR profile is NOT overridden, so it keeps the value from row 0.
searchRow = copy.deepcopy(paramTable_caseIsolation.loc[0])
searchRow["tStopSocialDistancing"] = pd.to_datetime("2020-06-30", format="%Y-%m-%d")
searchRow["trFunc_testing_params_inpFunc_testSpecifications_params_antigen_FPR"] = 0.05
searchRow["trFunc_testing_params_trFunc_testCapacity_params_testCapacity_antibody_country_total"] = 1e6

# First matching row (searched within .loc[:500]), then every
# nPolicies_caseIsolation-th row after it
indices = np.arange(
    findRow(searchRow, paramTable_caseIsolation.loc[:500]),
    len(paramTable_caseIsolation),
    nPolicies_caseIsolation,
)

# Simulated curves joined with the observed daily deaths
df_deaths_c = joinSimulations(
    loadDir=loadDir,
    paramTable=paramTable_caseIsolation,
    simRowIndices=paramTable_caseIsolation.index[indices],
    paramTable_startTimeColname="realStartDate",
    likFunc=functools.partial(
        likFunc_dummyForPlotting,
        # Last entry along axis 1, then summed over axes (1, 2) of the result
        projFunc=lambda sim: np.sum(sim[:, -1, :, :, :], axis=(1, 2)),
    ),
    simSuffix="_newOnly",
    df=df_UK_NHS_daily_COVID_deaths,
    sumAges=True,
)

# Work on a copy so that df_deaths_c stays untouched for the histogram cell
tmp = copy.deepcopy(df_deaths_c)

a = plot(tmp.values, tmp.index, now=False)

# Style the curves: traces 1.. are treated as simulations (faint, purple)
for trace in a[1:]:
    trace["line"] = dict(color="purple")
    trace["opacity"] = 0.05
    trace["legendgroup"] = "Simulations"
    trace["showlegend"] = False

# Only the first simulation trace represents the whole group in the legend
a[1]["name"] = "Simulations"
a[1]["opacity"] = 0.15  # To make sure it shows clearly in legend
a[1]["showlegend"] = True

# Trace 0 is styled as the observed data
a[0]["name"] = "Data"
a[0]["legendgroup"] = "Data"
a[0]["mode"] = "lines+markers"
a[0]["line"] = dict(color="black", width=2)
a[0]["opacity"] = 1.
a[0]["showlegend"] = True

layoutArgs = defaultLayout(scale=1.4)

layout = go.Layout(
    **layoutArgs,

    # Add basic info
    title="Policy C - Stop lockdown on 30 June, flatten 2nd wave via 1 million '90% sensitive' tests",
    xaxis_title="Date",
    yaxis_title="In hospital deaths per day",
)

# Traces are passed in their original order (data trace first)
fig = go.Figure(data=a, layout=layout)

plt(fig)

In [None]:
# Figure 2d
# Daily in-hospital deaths under policy D: simulations against the NHS data

# Search row: first case-isolation row with the scenario parameters
# overridden, including an explicit antigen FNR profile
searchRow = copy.deepcopy(paramTable_caseIsolation.loc[0])
searchRow["tStopSocialDistancing"] = pd.to_datetime("2020-06-30", format="%Y-%m-%d")
searchRow["trFunc_testing_params_inpFunc_testSpecifications_params_antigen_FNR_I1_to_R2"] = [0.99, 0.8, 0.4, 0.5, 0.6, 0.99]
searchRow["trFunc_testing_params_inpFunc_testSpecifications_params_antigen_FPR"] = 0.05
searchRow["trFunc_testing_params_trFunc_testCapacity_params_testCapacity_antibody_country_total"] = 5e6

# First matching row (searched within .loc[:500]), then every
# nPolicies_caseIsolation-th row after it
indices = np.arange(
    findRow(searchRow, paramTable_caseIsolation.loc[:500]),
    len(paramTable_caseIsolation),
    nPolicies_caseIsolation,
)

# Daily in-hospital deaths by model and data
df_deaths_d = joinSimulations(
    loadDir=loadDir,
    paramTable=paramTable_caseIsolation,
    simRowIndices=paramTable_caseIsolation.index[indices],
    paramTable_startTimeColname="realStartDate",
    likFunc=functools.partial(
        likFunc_dummyForPlotting,
        # Last entry along axis 1, then summed over axes (1, 2) of the result
        projFunc=lambda sim: np.sum(sim[:, -1, :, :, :], axis=(1, 2)),
    ),
    simSuffix="_newOnly",
    df=df_UK_NHS_daily_COVID_deaths,
    sumAges=True,
)

# Work on a copy so that df_deaths_d stays untouched for the histogram cell
tmp = copy.deepcopy(df_deaths_d)

a = plot(tmp.values, tmp.index, now=False)

# Style the curves: traces 1.. are treated as simulations (faint, green)
for trace in a[1:]:
    trace["line"] = dict(color="green")
    trace["opacity"] = 0.05
    trace["legendgroup"] = "Simulations"
    trace["showlegend"] = False

# Only the first simulation trace represents the whole group in the legend
a[1]["name"] = "Simulations"
a[1]["opacity"] = 0.15  # To make sure it shows clearly in legend
a[1]["showlegend"] = True

# Trace 0 is styled as the observed data
a[0]["name"] = "Data"
a[0]["legendgroup"] = "Data"
a[0]["mode"] = "lines+markers"
a[0]["line"] = dict(color="black", width=2)
a[0]["opacity"] = 1.
a[0]["showlegend"] = True

layoutArgs = defaultLayout(scale=1.4)

layout = go.Layout(
    **layoutArgs,

    # Add basic info
    title="Policy D - Stop lockdown on 30 June, suppress 2nd wave via 5 million '60% sensitive' tests",
    xaxis_title="Date",
    yaxis_title="In hospital deaths per day",
)

# Traces are passed in their original order (data trace first)
fig = go.Figure(data=a, layout=layout)

plt(fig)


In [None]:
# Histogram of deaths per policy per scenario:
# one semi-transparent histogram per policy, built from the column sums of the
# joined death tables produced in the Figure 2a-2d cells.
# Column 0 is skipped - presumably the observed-data column, since the plots
# above style trace 0 as "Data"; TODO confirm.
plt(go.Figure(
    data=[
        go.Histogram(
            x=deaths.iloc[:, 1:].sum(0),
            marker={"opacity": 0.3, "color": colour},
            histnorm='probability',
            name=label,
        )
        for deaths, colour, label in (
            (df_deaths_a, 'blue', "Policy A"),
            (df_deaths_b, 'red', "Policy B"),
            (df_deaths_c, 'purple', "Policy C"),
            (df_deaths_d, 'green', "Policy D"),
        )
    ],
    layout=go.Layout(
        **defaultLayout(),
        title="Number of total in-hospital deaths across the same scenarios using different policies",
        xaxis_title="Number of total in-hospital COVID deaths by end of 2020",
        yaxis_title="Fraction of scenarios",
    ),
))