# Run policies on a base ensemble

In [None]:
# Packages

import numpy as np

from scipy import integrate, stats
from scipy.special import expit, binom

import pandas as pd
import xlrd

import copy
import warnings

from datetime import datetime
import random
import string
import os
import shutil
import sys
import cloudpickle
import dask
import distributed
from dask.distributed import Client
import itertools

import pymc3

import inspect
from collections import OrderedDict

import dask.dataframe as dd

from datetime import datetime
import random
import string
import os
import shutil
import sys
import cloudpickle
import dask
import distributed
from dask.distributed import Client
from dask.distributed import as_completed
import itertools



# # import helper funcs
# modulePath = "/mnt/efs/modules"
# shutil.copyfile("modelHelperFuncs.py", modulePath + "modelHelperFuncs.py")
# if not modulePath in sys.path:
#     sys.path.insert(0, modulePath)

# from modelHelperFuncs import regroup_by_age, build_paramDict, paramDict_toTable, paramTable_toDict

# Helper functions

Distributed dask workers are unreliable at importing non-standard (local) modules, so the helper functions below are copied directly into this notebook instead of being imported from modelHelperFuncs.py.

In [None]:
# To get relative age-related risks, we first re-group into our basic age groups,
# then divide by the total population (this is a not-well-defined subset of the UK, so absolute values won't be used, only relative ones)
import numpy as np
import inspect
from collections import OrderedDict

def regroup_by_age(inp, fromAgeSplits, toAgeSplits, maxAge=100., maxAgeWeight = 5.):
    """
    Redistribute age-binned values from one set of age bins onto another.

    Both split arrays hold the inner bin boundaries; each grouping is closed with 0 on the
    left and ``maxAge`` on the right. Every source bin's value is spread over the destination
    bins in proportion to the number of years they share. Whenever a source bin reaches into
    the last destination bin, that bin's weight is set to ``maxAgeWeight`` instead of its
    year overlap before normalising.

    Parameters
    ----------
    inp : np.ndarray
        Values per source bin along axis 0 (``len(fromAgeSplits) + 1`` entries); any
        trailing axes are carried through unchanged.
    fromAgeSplits, toAgeSplits : array-like
        Inner boundaries of the source and destination age bins.
    maxAge : float
        Upper edge of the last bin of both groupings.
    maxAgeWeight : float
        Replacement weight for the last destination bin (see above).

    Returns
    -------
    np.ndarray of shape ``(len(toAgeSplits) + 1,) + inp.shape[1:]``.
    """
    # Close both groupings with 0 at the start and maxAge at the end
    src_edges = np.concatenate([np.array([0]), fromAgeSplits, np.array([maxAge])])
    dst_edges = np.concatenate([np.array([0]), toAgeSplits, np.array([maxAge])])
    dst_bins = list(zip(dst_edges[:-1], dst_edges[1:]))

    out = np.zeros((len(dst_bins),) + inp.shape[1:])
    for src_idx, (src_lo, src_hi) in enumerate(zip(src_edges[:-1], src_edges[1:])):
        # Years this source bin shares with every destination bin
        weights = np.array(
            [max(0, min(dst_hi, src_hi) - max(dst_lo, src_lo)) for dst_lo, dst_hi in dst_bins],
            dtype=float,
        )
        if weights[-1] > 0:
            weights[-1] = maxAgeWeight
        weights = weights / np.sum(weights)
        # Add this source bin's share to every destination bin at once
        out += np.multiply.outer(weights, inp[src_idx])

    return out


# PARAMETER DICTIONARIES AND TABLES
# -----------------------------------------------------------------------------------------


def build_paramDict(cur_func):
    """
    Collect the default-valued arguments of ``cur_func`` into an ordered dictionary.

    Only positional-or-keyword parameters that have a default are collected (taken from
    ``inspect.getfullargspec(cur_func).args`` / ``.defaults``), in signature order;
    keyword-only parameters are not included. Whenever a default value is itself callable,
    an extra entry ``<argname>_params`` is inserted right after it, holding the dictionary
    built recursively (depth-first) for that callable.

    Per the original design, the result is passed as ``**kwargs`` to the top-level function,
    which forwards the nested ``*_params`` dictionaries to its sub-functions.

    Returns an empty OrderedDict when the function has no defaulted arguments.
    """
    spec = inspect.getfullargspec(cur_func)
    result = OrderedDict()
    if spec.defaults is None:
        return result

    # Defaults always belong to the trailing positional parameters
    defaulted_names = spec.args[-len(spec.defaults):]
    for name, default in zip(defaulted_names, spec.defaults):
        result[name] = default
        if callable(default):
            result[name + "_params"] = build_paramDict(default)
    return result




# Do a mapping between dictionary and parameter table row (for convenient use)

# Flatten the dictionary into a table with a single row (but many headers):
def paramDict_toTable(paramDict):
    """
    Flatten a (possibly nested) parameter dictionary into a one-row DataFrame.

    Keys ending in ``"_params"`` hold sub-dictionaries; their entries become columns named
    ``<prefix><key>_<subkey>`` (recursively). Every other key becomes a column holding its
    value in row 0. An empty (sub)dictionary would produce no columns, so a placeholder column
    named after its prefix (``""`` at the top level, ``"<key>_"`` when nested) holding an empty
    OrderedDict is added to keep it.
    """
    table = pd.DataFrame()

    def _flatten_into(target, subdict, prefix):
        for name, entry in subdict.items():
            if not name.endswith("_params"):
                target[prefix + name] = [entry]
            else:
                # Descend into the nested parameter dictionary
                _flatten_into(target, entry, prefix + name + "_")

        if not subdict:
            target[prefix] = [OrderedDict()]

        return target

    return _flatten_into(table, paramDict, "")

# Example dict -> table
# paramTable_default = paramDict_toTable(paramDict_default)
    

def paramTable_toDict(paramTable, defaultDict=None, to_flat=False):
    """
    Inverse of the table flattening: turn row 0 of ``paramTable`` into a nested dictionary.

    Each ``"_params_"`` in a column name descends one level: ``"f_params_b"`` is stored as
    ``d["f_params"]["b"]``, and missing intermediate levels are created as OrderedDicts.

    Parameters
    ----------
    paramTable : pd.DataFrame
        Values are read with ``paramTable.at[0, column]``, so the row label must be 0.
    defaultDict : dict, optional
        Dictionary to fill in. It is updated IN PLACE (values added/overwritten) and returned;
        pass a copy if the original must stay untouched.
    to_flat : bool
        If True, column names are used as keys directly without any nesting.
    """
    paramDict = OrderedDict() if defaultDict is None else defaultDict
    marker = "_params_"

    for column in paramTable.columns:
        value = paramTable.at[0, column]
        node = paramDict
        remaining = column
        if not to_flat:
            pos = remaining.find(marker)
            while pos != -1:
                # The level key keeps the "_params" suffix, without the trailing underscore
                level_key = remaining[:pos + len(marker) - 1]
                remaining = remaining[pos + len(marker):]
                if level_key not in node:
                    node[level_key] = OrderedDict()
                node = node[level_key]
                pos = remaining.find(marker)
        node[remaining] = value

    return paramDict

# Example table -> dict 
# paramDict_new = paramTable_toDict(paramTable_default)

## Load a full ensemble and select the base ensemble with queries


In [None]:
# Load a full ensemble and select the base ensemble with queries
base_ensemble_params = OrderedDict()

base_ensemble_params["ensembleDir"] = "/mnt/efs/results/run_20200421T002117/"
# Rows to keep (pandas DataFrame.query syntax), the column to rank them by, and how many of the
# highest-ranked rows to keep
base_ensemble_params["ensembleQuery"] = "likelihood_0 > -260 & likelihood_2 > -280"
base_ensemble_params["ensembleSortby"] = "likelihood_total"
base_ensemble_params["ensembleMaxnumber"] = 200

# The ensemble parameter table is split over consecutive HDF files paramTable_part0, paramTable_part1, ...
# Read them until the first missing part. Only a missing file ends the loop; any other read error
# (corrupt file, wrong key, ...) now propagates instead of being swallowed by a bare `except`,
# which silently truncated the ensemble.
paramTable_ensemble_parts = []
i = 0
while os.path.exists(base_ensemble_params["ensembleDir"] + 'paramTable_part{}'.format(i)):
    print("Loading {}".format(i))
    paramTable_ensemble_parts.append(
        pd.read_hdf(base_ensemble_params["ensembleDir"] + 'paramTable_part{}'.format(i), key="paramTable"))
    i += 1

if not paramTable_ensemble_parts:
    raise FileNotFoundError(
        "No paramTable_part files found in {}".format(base_ensemble_params["ensembleDir"]))

# Concatenate once: DataFrame.append is deprecated/removed in pandas and dask, and appending part
# by part copies the growing table each time. pd.read_hdf already loads every part into memory,
# so wrapping the data in dask added nothing.
paramTable_ensemble = pd.concat(paramTable_ensemble_parts)

# Keep the rows passing the query, sorted ascending, and take the last ensembleMaxnumber rows
# (i.e. the ones with the largest ensembleSortby values)
paramTable_ensemble_dd_queriedSorted = (
    paramTable_ensemble
    .query(base_ensemble_params["ensembleQuery"])
    .sort_values(base_ensemble_params["ensembleSortby"])
    .tail(base_ensemble_params["ensembleMaxnumber"])
)

In [None]:
# Show the selected base ensemble: rows passing ensembleQuery, sorted ascending by ensembleSortby,
# keeping the last ensembleMaxnumber rows (those with the largest ensembleSortby values)
paramTable_ensemble_dd_queriedSorted

## Load all the functions from the base ensemble (default_dict)

In [None]:
# Load the default parameter dictionary saved with the base ensemble. cloudpickle.load can
# execute arbitrary code, so ensembleDir must point at a trusted results directory.
with open(base_ensemble_params["ensembleDir"] + "paramDict_default.cpkl", mode="rb") as fh:
    paramDict_default = cloudpickle.load(fh)

## Set up new policy parameters

In [None]:
# Flatten the default parameter dictionary into a single-row table (one column per parameter;
# entries of nested "<name>_params" dictionaries are prefixed with "<name>_params_").
# NOTE(review): paramTable_default does not appear to be used later in this notebook
paramTable_default = paramDict_toTable(paramDict_default)

#### Case isolation policy parameters

In [None]:
# Policy grid for the case-isolation study: every combination (Cartesian product) of the values
# listed below becomes one row of df_newPolicyParams_caseIsolation.
newPolicyParams_caseIsolation = OrderedDict()

newPolicyParams_caseIsolation["tStartQuarantineCaseIsolation"] = [
    #pd.to_datetime("2020-04-20", format="%Y-%m-%d"),
    pd.to_datetime("2020-05-01", format="%Y-%m-%d")
]
newPolicyParams_caseIsolation['tStopSocialDistancing'] = [
    #pd.to_datetime("2020-04-30", format="%Y-%m-%d"),
    pd.to_datetime("2020-05-30", format="%Y-%m-%d"),
    pd.to_datetime("2020-06-30", format="%Y-%m-%d"),
    pd.to_datetime("2020-09-30", format="%Y-%m-%d")
]

newPolicyParams_caseIsolation['trFunc_testing_params_policyFunc_params_basic_policyFunc_params_antibody_testing_policy'] = [
    "virus_positive_only_hospworker_first"
]

newPolicyParams_caseIsolation['trFunc_testing_params_trFunc_testCapacity_params_testCapacity_antigenratio_country'] = [1.]


newPolicyParams_caseIsolation['trFunc_testing_params_trFunc_testCapacity_params_testCapacity_antibody_country_total'] = [
    1e5, 5e5, 1e6, 2e6, 3e6, 5e6, 10e6]

newPolicyParams_caseIsolation['trFunc_testing_params_trFunc_testCapacity_params_testCapacity_antibody_country_inflexday'] = [
     #pd.to_datetime("2020-04-30", format="%Y-%m-%d"),
     pd.to_datetime("2020-05-30", format="%Y-%m-%d")
]

newPolicyParams_caseIsolation['trFunc_testing_params_trFunc_testCapacity_params_testCapacity_antibody_country_inflexslope'] = [
    10.#, 20
]

newPolicyParams_caseIsolation['trFunc_quarantine_params_nDaysInHomeIsolation'] = [
    14.#, 24.
]


# Each row of this 2-D array is one candidate FNR vector (the product iterates over the rows)
newPolicyParams_caseIsolation["trFunc_testing_params_inpFunc_testSpecifications_params_antigen_FNR_I1_to_R2"]= np.array(
    [
        # Good test
        [0.8,   0.3,   0.1, 0.15, 0.3, 0.7],
        # Medium test
        [0.9,   0.5,   0.25,  0.35, 0.5, 0.9],
        #Bad test
        [0.99,   0.8,   0.4,  0.5, 0.6, 0.99]
    ]
)

newPolicyParams_caseIsolation["trFunc_testing_params_inpFunc_testSpecifications_params_antigen_FPR"] = [0.05, 0.1]


# Make the joint dataframe: one row per combination, one column per parameter
df_newPolicyParams_caseIsolation = pd.DataFrame(
    data=list(itertools.product(*newPolicyParams_caseIsolation.values())),
    columns=newPolicyParams_caseIsolation.keys()
    )


# Optionally restrict groups of columns to "zipped" combinations. For each list of column names in
# zippedColumnSets, keep only the rows whose values in those columns sit at the same position of
# their value lists, instead of the full product.
# (TODO - this could be done during the itertools.product instead of filtering afterwards)
zippedColumnSets = [
#    ["trFunc_testing_params_inpFunc_testSpecifications_params_antigen_FNR_I1_to_R2", "trFunc_testing_params_inpFunc_testSpecifications_params_antigen_FPR"]
]

all_validInds = df_newPolicyParams_caseIsolation.index
for zcs in zippedColumnSets:
    zippedVals = list(zip(*[newPolicyParams_caseIsolation[key] for key in zcs]))
    candidates = df_newPolicyParams_caseIsolation.loc[all_validInds]

    curZCS_validInds = set()
    for zv in zippedVals:
        # Rows matching every column of this zipped tuple (np.all also handles array-valued cells)
        rowMatches = np.ones(len(candidates), dtype=bool)
        for key, val in zip(zcs, zv):
            rowMatches &= candidates[key].apply(lambda x, val=val: np.all(x == val)).to_numpy(dtype=bool)
        curZCS_validInds.update(candidates.index[rowMatches])

    # Keep the existing row order: intersecting two Python sets returned the rows in arbitrary order
    all_validInds = [ind for ind in all_validInds if ind in curZCS_validInds]

df_newPolicyParams_caseIsolation = df_newPolicyParams_caseIsolation.loc[all_validInds]
df_newPolicyParams_caseIsolation = df_newPolicyParams_caseIsolation.reset_index(drop=True)


len(df_newPolicyParams_caseIsolation)

In [None]:
# Merge the new policy params with the selected ensemble params: every ensemble member is paired
# with every row of the policy grid (cross join, len(ensemble) * len(policy grid) rows).
# The ensemble's index becomes "EnsembleIndex" and its own "out_fname" becomes "Ensemble_out_fname",
# so it does not clash with the output file names given to the new runs later.
# The reset/rename is done once here instead of once per column inside the loop.
ensemble_members = (
    paramTable_ensemble_dd_queriedSorted
    .reset_index()
    .rename(columns={"index":"EnsembleIndex", "out_fname":"Ensemble_out_fname"})
)
if len(ensemble_members) == 0:
    raise ValueError("No ensemble members selected - cannot build the case isolation policy table")

merged_parts = []
for i in range(len(ensemble_members)):
    tmp = copy.deepcopy(df_newPolicyParams_caseIsolation)
    for colname in ensemble_members.columns:
        # NOTE(review): a column present in both tables is overwritten by the ensemble value -
        # confirm the ensemble table holds no policy columns
        tmp[colname] = [ensemble_members.at[i, colname]]*len(df_newPolicyParams_caseIsolation)
    merged_parts.append(tmp)

# One concat instead of repeated DataFrame.append (quadratic, and removed in pandas 2.0)
paramTable_merged = pd.concat(merged_parts).reset_index(drop=True)



In [None]:
# Show the merged table: one row per (ensemble member, policy-grid row) pair
paramTable_merged

### Also make an alternative paramTable to evaluate the "keep social distancing" policy

In [None]:
# Baseline runs: the selected ensemble members with their own parameters and no new policy applied.
# The ensemble index is exposed as "EnsembleIndex"; the ensemble's own output file name is kept
# as "Ensemble_out_fname".
paramTable_baseParams_keepSocialDistancing = (
    paramTable_ensemble_dd_queriedSorted
    .copy(deep=True)
    .reset_index()
    .rename(columns={"index":"EnsembleIndex", "out_fname":"Ensemble_out_fname"})
)

In [None]:
# Show the baseline ("keep social distancing") run table: one row per selected ensemble member
paramTable_baseParams_keepSocialDistancing

### Also make an alternative paramTable to evaluate the "stop social distancing" policy

In [None]:
# "Stop social distancing" runs: every selected ensemble member with its own parameters, combined
# with each social-distancing end date below (cross join, len(ensemble) * len(dates) rows).

newPolicyParams_stopSocialDistancing = OrderedDict()

newPolicyParams_stopSocialDistancing['tStopSocialDistancing'] = [
    pd.to_datetime("2020-04-30", format="%Y-%m-%d"),
    pd.to_datetime("2020-05-30", format="%Y-%m-%d"),
    pd.to_datetime("2020-06-30", format="%Y-%m-%d"),
    pd.to_datetime("2020-09-30", format="%Y-%m-%d")
]

df_newPolicyParams_stopSocialDistancing = pd.DataFrame(
    data=list(itertools.product(*newPolicyParams_stopSocialDistancing.values())),
    columns=newPolicyParams_stopSocialDistancing.keys()
    )


# Merge the new policy params with the selected ensemble params. The ensemble index becomes
# "EnsembleIndex" and its "out_fname" becomes "Ensemble_out_fname". Reset/rename once, outside the loops.
ensemble_members = (
    paramTable_ensemble_dd_queriedSorted
    .reset_index()
    .rename(columns={"index":"EnsembleIndex", "out_fname":"Ensemble_out_fname"})
)
if len(ensemble_members) == 0:
    raise ValueError("No ensemble members selected - cannot build the stop social distancing table")

stopSocialDistancing_parts = []
for i in range(len(ensemble_members)):
    tmp = copy.deepcopy(df_newPolicyParams_stopSocialDistancing)
    for colname in ensemble_members.columns:
        # NOTE(review): if the ensemble table itself had a tStopSocialDistancing column it would
        # overwrite the dates above - confirm it does not
        tmp[colname] = [ensemble_members.at[i, colname]]*len(df_newPolicyParams_stopSocialDistancing)
    stopSocialDistancing_parts.append(tmp)

# One concat instead of repeated DataFrame.append (quadratic, and removed in pandas 2.0)
paramTable_baseParams_stopSocialDistancing = pd.concat(stopSocialDistancing_parts).reset_index(drop=True)



In [None]:
# Show the "stop social distancing" run table: one row per (ensemble member, end date) pair
paramTable_baseParams_stopSocialDistancing


# Set up and save policy control dicts for dashboard

In [None]:
# Load the parameter classification from the current working directory
# (paramTypes["policy"] lists the names of the policy parameters)
with open("paramTypes.cpkl", mode="rb") as fh:
    paramTypes = cloudpickle.load(fh)

In [None]:
# Dashboard controls for the policy parameters. Each entry maps a full (flattened) parameter name to
# [short identifier, display label, value type, widget ("dropdown"/"slider"), predefined options or None, flag].
# For "spec_list" entries, the options are [[candidate arrays], [their display names]].
# NOTE(review): the meaning of the final boolean is not visible in this notebook - confirm with the dashboard code.
# Short identifiers must be unique. "retesting_antigen_immunepos_ratio" used to reuse
# "retesting_antigen_viruspos_ratio", so two different controls shared one identifier.
policy_control_parameters = {
        # Timings
        'tStopSocialDistancing': ["stop_social_distancing", "Social Distancing End Date", "datetime64", "dropdown", None, True],
        'tStartImmunityPassports': ["start_immunity_passports", "Immunity Passports Start Date", "datetime64", "dropdown", None, False],
        'tStopImmunityPassports': ["stop_immunity_passports", "Immunity Passports End Date", "datetime64", "dropdown", None, False],
        'tStartQuarantineCaseIsolation': ["start_case_isolation", "Case Isolation Start Date", "datetime64", "dropdown", None, False],
        'tStopQuarantineCaseIsolation': ["stop_case_isolation", "Case Isolation End Date", "datetime64", "dropdown", None, True],
        # Quarantine
        'trFunc_quarantine_params_nDaysInHomeIsolation': [
            "caseiso_ndayshome", "Length of strict home quarantine (days)", "int", "slider", None, False
        ],
        'trFunc_quarantine_params_timeToIsolation': [
            "caseiso_timetoiso", "Time between test and quarantine start (days)", "float", "slider", None, False
        ],
        # Testing
        'trFunc_testing_params_trFunc_testCapacity_params_testCapacity_pcr_phe_total': [
            "testcapacity_pcr_phe_total", "PHE lab PCR tests - maximum capacity per day", "float", "slider", None, False
        ],
        'trFunc_testing_params_trFunc_testCapacity_params_testCapacity_pcr_phe_inflexday': [
            "testcapacity_pcr_phe_inflexday", "PHE lab PCR tests - date of reaching half maximum capacity", "datetime64", "dropdown", None, False
        ],
        'trFunc_testing_params_trFunc_testCapacity_params_testCapacity_pcr_phe_inflexslope': [
            "testcapacity_pcr_phe_inflexslope", "PHE lab PCR tests - days to reach maximum capacity", "float", "slider", None, False
        ],
        'trFunc_testing_params_trFunc_testCapacity_params_testCapacity_pcr_country_total': [
            "testcapacity_pcr_country_total", "UK non-PHE PCR tests - maximum capacity per day", "float", "slider", None, False
        ],
        'trFunc_testing_params_trFunc_testCapacity_params_testCapacity_pcr_country_inflexday': [
            "testcapacity_pcr_country_inflexday", "UK non-PHE PCR tests - date of reaching half maximum capacity", "datetime64", "dropdown", None, False
        ],
        'trFunc_testing_params_trFunc_testCapacity_params_testCapacity_pcr_country_inflexslope': [
            "testcapacity_pcr_country_inflexslope", "UK non-PHE PCR tests - days to reach maximum capacity", "float", "slider", None, False
        ],
        'trFunc_testing_params_trFunc_testCapacity_params_testCapacity_antibody_country_firstday': [
            "testcapacity_lfa_country_firstday", "UK home tests (LFA) - first date deployed", "datetime64", "dropdown", None, False
        ],
        'trFunc_testing_params_trFunc_testCapacity_params_testCapacity_antibody_country_total': [
            "testcapacity_lfa_country_total", "UK home tests (LFA) - maximum capacity per day", "float", "slider", None, False
        ],
        'trFunc_testing_params_trFunc_testCapacity_params_testCapacity_antibody_country_inflexday': [
            "testcapacity_lfa_country_inflexday", "UK home tests (LFA) - date of reaching half maximum capacity", "datetime64", "dropdown", None, False
        ],
        'trFunc_testing_params_trFunc_testCapacity_params_testCapacity_antibody_country_inflexslope': [
            "testcapacity_lfa_country_inflexslope", "UK home tests (LFA) - days to reach maximum capacity", "float", "slider", None, False
        ],
        'trFunc_testing_params_trFunc_testCapacity_params_testCapacity_antigenratio_country': [
            "testcapacity_lfa_country_antigenratio", "UK home tests (LFA) - ratio of virus and immunity tests", "float", "slider", None, True
        ],
        'trFunc_testing_params_policyFunc_params_basic_policyFunc_params_antibody_testing_policy': [
            "antibody_testing_policy", "Antibody test distribution policy", "string", "dropdown", None, False
        ],
        # Test specs
        'trFunc_testing_params_inpFunc_testSpecifications_params_antigen_FNR_I1_to_R2': [
            "testspecs_antigen_FNR", "LFA antigen test specifications (FNR)", "spec_list", "dropdown",
            [[np.array([0.8, 0.3, 0.1, 0.15, 0.3, 0.7]),
              np.array([0.9, 0.5, 0.25, 0.35, 0.5, 0.9]),
              np.array([0.99, 0.8, 0.4, 0.5, 0.6, 0.99])], ["Good", "Medium", "Bad"]], True
        ],
        'trFunc_testing_params_inpFunc_testSpecifications_params_antigen_FPR' : [
            "testspecs_antigen_FPR","LFA antigen test specifications (FPR)", "float", "slider", None, True
        ],
        'trFunc_testing_params_inpFunc_testSpecifications_params_antibody_FNR_I1_to_R2': [
            "testspecs_antibody_FNR", "LFA antibody test specifications (FNR)", "spec_list", "dropdown", [[np.array([0.99, 0.85, 0.8, 0.65, 0.3, 0.05])], ["Baseline"]], True
        ],
        'trFunc_testing_params_inpFunc_testSpecifications_params_antibody_FPR_S_to_I4': [
            "testspecs_antibody_FPR","LFA antibody test specifications (FPR)", "float", "slider", None, True
        ],
        # Re-testing
        'trFunc_testing_params_policyFunc_params_retesting_antigen_viruspos_ratio': [
            "retesting_antigen_viruspos_ratio", "Ratio of leftover home virus tests used for re-testing virus pos", "float", "slider", None, False
        ],
        'trFunc_testing_params_policyFunc_params_retesting_antigen_immunepos_ratio': [
            "retesting_antigen_immunepos_ratio", "Ratio of leftover home virus tests used for re-testing immune pos", "float", "slider", None, False
        ],
        'trFunc_testing_params_policyFunc_params_retesting_antibody_immunepos_ratio': [
            "retesting_antibody_immunepos_ratio", "Ratio of leftover home immunity tests used for re-testing", "float", "slider", None, False
        ]
    }


# Consistency check against paramTypes. First print the dashboard controls whose parameter is not
# classified as "policy", then the "policy" parameters that have no dashboard control.
# Plain loops (not list comprehensions) so the cell does not also display a list of Nones; sorted
# so the output order is stable.
for param_name in sorted(set(policy_control_parameters.keys()) - set(paramTypes["policy"])):
    print(param_name)
print("-------------")
for param_name in sorted(set(paramTypes["policy"]) - set(policy_control_parameters.keys())):
    print(param_name)

In [None]:
# Create the config dicts for each table
# Policy params: dashboard config for the case-isolation table, with one entry per column that has
# a dashboard control AND is part of this policy grid with more than one value.
# Checking membership in newPolicyParams_caseIsolation first means an ensemble column that happens
# to have a dashboard control no longer raises a KeyError.
controlDict_caseIsolation = OrderedDict()
for key in paramTable_merged.columns:
    if key in policy_control_parameters and key in newPolicyParams_caseIsolation:
        if len(newPolicyParams_caseIsolation[key])>1: # keep only the ones that actually vary!
            controlDict_caseIsolation[key] = policy_control_parameters[key]
df_controlDict_caseIsolation = paramDict_toTable(
    controlDict_caseIsolation
)

df_controlDict_caseIsolation

In [None]:
# Dashboard config for the "keep social distancing" table: no parameter varies, so the control
# dictionary is empty (paramDict_toTable then produces a single placeholder column named "").
controlDict_keepSocialDistancing = OrderedDict()
df_controlDict_keepSocialDistancing = paramDict_toTable(controlDict_keepSocialDistancing)

df_controlDict_keepSocialDistancing

In [None]:
# Create the config dicts for each table
# Policy params: dashboard config for the "stop social distancing" table, with one entry per column
# that has a dashboard control AND is part of this grid with more than one value.
# Checking membership in newPolicyParams_stopSocialDistancing first means an ensemble column with a
# dashboard control no longer raises a KeyError.
controlDict_stopSocialDistancing = OrderedDict()
for key in paramTable_baseParams_stopSocialDistancing.columns:
    if key in policy_control_parameters and key in newPolicyParams_stopSocialDistancing:
        if len(newPolicyParams_stopSocialDistancing[key])>1: # keep only the ones that actually vary!
            controlDict_stopSocialDistancing[key] = policy_control_parameters[key]
df_controlDict_stopSocialDistancing = paramDict_toTable(
    controlDict_stopSocialDistancing
)

df_controlDict_stopSocialDistancing

## Evaluate the full policy grid on all ensemble members

In [None]:
# Load the saved initialisation: the initial state tensor stored in the base ensemble's default
# parameter dictionary under the key "INIT_stateTensor_init"
stateTensor_init = paramDict_default["INIT_stateTensor_init"]

In [None]:
def solveSystem(stateTensor_init, total_days = 200, samplesPerDay=np.inf, **kwargs):
    """
    Integrate the model forward in time from ``stateTensor_init``.

    Parameters
    ----------
    stateTensor_init : np.ndarray
        Initial state. It is deep-copied and flattened, so the caller's array is not modified.
    total_days : int
        Number of daily samples to return (days 0 .. total_days-1).
    samplesPerDay : int or np.inf
        np.inf (default) uses scipy's adaptive RK23 integrator (rtol=atol=1e-3) evaluated at
        integer days. A finite integer uses forward Euler with step 1/samplesPerDay and records
        the state at the start of every day.
    **kwargs
        Must contain "dydt_Complete", which is called as ``dydt_Complete(t, y, **kwargs)`` with the
        flattened state ``y``. The optional "debugReturnNewPerDay" (default False) makes the state
        two stacked copies of the initial state; how the copies evolve is decided by dydt_Complete.

    Returns
    -------
    np.ndarray with shape ``stateTensor_init.shape + (n_days,)``. With debugReturnNewPerDay the shape
    is ``(2,) + stateTensor_init.shape + (n_days,)``. n_days equals total_days unless the adaptive
    integrator stops early, in which case a warning is issued.
    """
    debugReturnNewPerDay = kwargs.get("debugReturnNewPerDay", False)
    dydt = kwargs["dydt_Complete"]

    if debugReturnNewPerDay: # Keep the second copy as well
        cur_stateTensor = np.reshape(
            np.stack([copy.deepcopy(stateTensor_init), copy.deepcopy(stateTensor_init)], axis=0), -1)
    else:
        cur_stateTensor = np.reshape(copy.deepcopy(stateTensor_init), -1)

    if np.isinf(samplesPerDay):
        # Run precise integrator (takes a long time)
        solution = integrate.solve_ivp(
            fun=lambda t, y: dydt(t, y, **kwargs),
            t_span=(0., total_days),
            y0=cur_stateTensor,
            method='RK23',
            t_eval=range(total_days),
            rtol=1e-3, # default 1e-3
            atol=1e-3, # default 1e-6
        )
        if not solution.success:
            # solve_ivp then returns fewer than total_days samples - make that visible
            warnings.warn("solve_ivp did not reach total_days: {}".format(solution.message))
        out = solution.y

    else:
        # Simple forward Euler with step size 1/samplesPerDay (samplesPerDay must be an integer)
        deltaT = 1. / samplesPerDay
        # Size from the (possibly doubled) flattened state. Sizing it from stateTensor_init broke
        # this branch when debugReturnNewPerDay was set.
        out = np.zeros((cur_stateTensor.size, total_days))

        for tt in range(total_days * samplesPerDay):
            if tt % samplesPerDay == 0:
                out[:, tt // samplesPerDay] = cur_stateTensor

            cur_stateTensor += deltaT * dydt(tt / samplesPerDay, cur_stateTensor, **kwargs)

    # Reshape to the state shape (plus the leading copy axis) with time as the last axis
    if debugReturnNewPerDay:
        return np.reshape(out, (2,) + stateTensor_init.shape + (-1,))
    return np.reshape(out, stateTensor_init.shape + (-1,))

### Set up dask parallel run

In [None]:
# Connect to a dask scheduler at this fixed local address (presumably started separately - confirm before running)
client = Client("127.0.0.1:8786")

In [None]:
# Create a fresh, timestamped results directory for this run, and store in it the inputs needed to
# interpret its outputs: the default parameter dictionary and the ensemble selection settings.
timeOfRunning = datetime.now().strftime("%Y%m%dT%H%M%S")

saveDir = "/mnt/efs/results/run_{}/".format(timeOfRunning)
os.makedirs(saveDir, exist_ok=True)
# World-writable, intended to let the workers write files here.
# NOTE(review): in this notebook the .npy results are written by the client process (monitor cell),
# so 0o777 may be broader than needed.
os.chmod(saveDir, 0o777)

for pickle_name, pickle_obj in (
    ("paramDict_default.cpkl", paramDict_default),  # merged with each run's new parameters
    ("ensembleSelectionSettings.cpkl", base_ensemble_params),
):
    with open(saveDir + pickle_name, "wb") as fh:
        cloudpickle.dump(pickle_obj, fh)
    


In [None]:
# Give every policy-grid run a random, timestamped output file name:
# "outTensor_<timeOfRunning>_<20 random characters from a-z0-9>.npy" (uniqueness is not checked)
out_fnames = [
    "outTensor_" + timeOfRunning + "_"
    + "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(20))
    + ".npy"
    for _row in range(len(paramTable_merged))
]
paramTable_merged["out_fname"] = out_fnames

In [None]:
# Give every "keep social distancing" run a random, timestamped output file name:
# "outTensor_<timeOfRunning>_<20 random characters from a-z0-9>.npy" (uniqueness is not checked)
out_fnames = [
    "outTensor_" + timeOfRunning + "_"
    + "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(20))
    + ".npy"
    for _row in range(len(paramTable_baseParams_keepSocialDistancing))
]
paramTable_baseParams_keepSocialDistancing["out_fname"] = out_fnames

In [None]:
# Give every "stop social distancing" run a random, timestamped output file name:
# "outTensor_<timeOfRunning>_<20 random characters from a-z0-9>.npy" (uniqueness is not checked)
out_fnames = [
    "outTensor_" + timeOfRunning + "_"
    + "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(20))
    + ".npy"
    for _row in range(len(paramTable_baseParams_stopSocialDistancing))
]
paramTable_baseParams_stopSocialDistancing["out_fname"] = out_fnames

In [None]:
# Save every run table and its dashboard configuration into one HDF file (paramTable_part0).
# The HDF key is the "policy type" name; its config goes under "DashboardConfig-<policy type>".
hdf_path = saveDir + "paramTable_part{}".format(0)
for policy_name, run_table, config_table in (
    ("Case Isolation Varying Test Numbers", paramTable_merged, df_controlDict_caseIsolation),
    ("Keep Social Distancing", paramTable_baseParams_keepSocialDistancing, df_controlDict_keepSocialDistancing),
    ("Stop Social Distancing", paramTable_baseParams_stopSocialDistancing, df_controlDict_stopSocialDistancing),
):
    run_table.to_hdf(hdf_path, key=policy_name)
    config_table.to_hdf(hdf_path, key="DashboardConfig-" + policy_name)

In [None]:
# Run parallel for each parameter setting and save to out_fname
def runAll(newParams_row, stateTensor_init=stateTensor_init, defaultDict=paramDict_default, timeOfRunning=timeOfRunning):
    """
    Run one simulation (submitted to dask) for a single row of parameters.

    Parameters
    ----------
    newParams_row : pd.DataFrame
        One-row table of parameters. Only columns that also appear in the flattened default
        parameter table are applied. Other columns (likelihoods, file names, EnsembleIndex, ...)
        do not affect the simulation.
    stateTensor_init, defaultDict :
        Initial state and default parameter dictionary. The defaults are bound when the cell runs.
    timeOfRunning :
        Unused; kept for interface compatibility.

    Returns
    -------
    (out, out_newOnly, newParams_row)
        out is the first copy of solveSystem's debug output (the state trajectory over 365 days).
        out_newOnly holds the day-to-day differences of the second copy (per the original notes,
        the cumulative counts). Its day-0 entry is taken relative to the initial state of the
        first copy. newParams_row is passed through so the caller knows where to save.
    """
    curDict = copy.deepcopy(defaultDict)
    # Make sure the newOnly stuff is saved as well (solveSystem then integrates a second copy)
    curDict["debugReturnNewPerDay"] = True

    allowed_columns = set(paramDict_toTable(defaultDict).columns)
    # Keep the row's column order; a set intersection made it depend on the string hash seed
    applied_columns = [col for col in newParams_row.columns if col in allowed_columns]
    runParams = paramTable_toDict(
        newParams_row[applied_columns].reset_index(drop=True),
        defaultDict=curDict,  # curDict is already a private copy; paramTable_toDict fills it in place
    )

    out = solveSystem(stateTensor_init, total_days=365, **runParams)

    states, cumulative = out[0], out[1]
    # Prepend the day-0 state of the first copy along the time axis and difference along time.
    # `[..., :1]` works for any state rank (previously hard-coded to a 4-D state).
    out_newOnly = np.diff(np.concatenate([states[..., :1], cumulative], axis=-1), axis=-1)

    return states, out_newOnly, newParams_row

### First run the short ones

In [None]:
# Dask futures of all submitted simulation runs; each submit cell appends to this list
futures = []

In [None]:
# Submit one simulation per "keep social distancing" row. Every row is submitted; existing
# output files are NOT checked.
for row_label in range(len(paramTable_baseParams_keepSocialDistancing)):
    single_row = copy.deepcopy(paramTable_baseParams_keepSocialDistancing.loc[row_label:row_label])
    futures.append(client.submit(runAll, single_row))

In [None]:
# Submit one simulation per "stop social distancing" row. Every row is submitted; existing
# output files are NOT checked.
for row_label in range(len(paramTable_baseParams_stopSocialDistancing)):
    single_row = copy.deepcopy(paramTable_baseParams_stopSocialDistancing.loc[row_label:row_label])
    futures.append(client.submit(runAll, single_row))

### Then the variations on the policy

In [None]:
# Submit one simulation per row of the merged policy grid. Every row is submitted; existing
# output files are NOT checked.
for row_label in range(len(paramTable_merged)):
    single_row = copy.deepcopy(paramTable_merged.loc[row_label:row_label])
    futures.append(client.submit(runAll, single_row))

### Monitor and save results as they arrive

In [None]:
# Monitor and save: write each result into saveDir as soon as its future completes, then release
# it on the cluster. Failed runs are reported; previously they were skipped silently.
seq = as_completed(futures)

curIndex = 0           # number of results saved so far
failed_errors = []     # exceptions of runs that failed

for future in seq:
    if future.status == "finished":
        out, out_newOnly, newParams_row = future.result()
        out_fname = newParams_row["out_fname"].values[0]

        # Save the state trajectory, and the per-day new counts with "_newOnly" before ".npy"
        np.save(file=saveDir + out_fname, arr=out)
        np.save(file=saveDir + out_fname[:-4] + "_newOnly.npy", arr=out_newOnly)

        curIndex += 1
        client.cancel(future) # remove from memory after saved
    elif future.status == "error":
        err = future.exception()
        failed_errors.append(err)
        print("Run failed: {!r}".format(err))
        client.cancel(future)

print("Saved {} results, {} runs failed".format(curIndex, len(failed_errors)))

In [None]:
# Cancel/release every submitted future still held by the cluster (including runs not yet finished)
client.cancel(futures)

In [None]:
# Disconnect from the dask scheduler
client.close()