In [4]:
import argparse
from collections import defaultdict
import random
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm
import datetime

import os
import sys
stderr = sys.stderr
sys.stderr = open(os.devnull, 'w')
import keras
sys.stderr = stderr

from ortools.linear_solver import pywraplp
from covid_xprize.standard_predictor.xprize_predictor import XPrizePredictor

NPI_COLS_NAMES = ['C1_School closing',
            'C2_Workplace closing',
            'C3_Cancel public events',
            'C4_Restrictions on gatherings',
            'C5_Close public transport',
            'C6_Stay at home requirements',
            'C7_Restrictions on internal movement',
            'C8_International travel controls',
            'H1_Public information campaigns',
            'H2_Testing policy',
            'H3_Contact tracing',
            'H6_Facial Coverings']

NPI_values = [[0,1,2,3],
             [0,1,2,3],
             [0,1,2],
             [0,1,2,3,4],
             [0,1,2],
             [0,1,2,3],
             [0,1,2],
             [0,1,2,3,4],
             [0,1,2],
             [0,1,2,3],
             [0,1,2],
             [0,1,2,3,4]]

NPI_dict = {}
for i, col in enumerate(NPI_COLS_NAMES):
    NPI_dict[col] = NPI_values[i]

col_names = ['PrescriptionIndex', 'CountryName', 'RegionName', 'Date'] + NPI_COLS_NAMES


In [6]:
def prescribe(start_date: str,
              end_date: str,
              path_to_prior_ips_file: str,
              path_to_cost_file: str,
              output_file_path) -> None:

    start_date = np.datetime64(start_date)
    end_date = np.datetime64(end_date)

    # get weights
    case_weights_dict = {}

    with open("weights/weights_reformat.pickle", "rb") as file:
        case_weights_dict['case_weights_1'] = pickle.load(file)

    case_weights_names = ['case_weights_1']

    # get stringency
    stringency_weight_df = pd.read_csv(path_to_cost_file)
    stringency_weight_df["GeoID"] = np.where(stringency_weight_df["RegionName"].isnull(),
                                        stringency_weight_df["CountryName"],
                                        stringency_weight_df["CountryName"] + ' / ' + stringency_weight_df["RegionName"])

    GeoIDs = stringency_weight_df["GeoID"].values
    countries = stringency_weight_df["CountryName"].values
    regions = stringency_weight_df["RegionName"].values

    # process stringency
    stringency_weight = defaultdict(lambda : defaultdict(lambda : defaultdict(np.float64)))
    for index, row in stringency_weight_df.iterrows():
        row_sum = 0
        for col in NPI_COLS_NAMES:
            row_sum +=  row[col]*sum(NPI_dict[col])
    #     row_sum = row[NPI_COLS_NAMES].sum()
        if row_sum == 0:
            row_sum = 1
        for col in NPI_COLS_NAMES:
            for j in range(len(NPI_dict[col])):
    #             tmp.append((stringency_weight_tmp[i] * (NPI_values[i][j]))/stringency_weight_sum)
                stringency_weight[row['GeoID']][col][j] = (row[col] * NPI_dict[col][j])/row_sum
    
    stringency_weight = default_to_regular(stringency_weight)

    # set global variables
    data_url = path_to_prior_ips_file
    OxCGRT_latest = pd.read_csv(data_url,
                                    parse_dates=['Date'],
                                    encoding="ISO-8859-1",
                                    dtype={"RegionName": str,
                                        "RegionCode": str},
                                    error_bad_lines=False)
    OxCGRT_latest["GeoID"] = np.where(OxCGRT_latest["RegionName"].isnull(),
                                        OxCGRT_latest["CountryName"],
                                        OxCGRT_latest["CountryName"] + ' / ' + OxCGRT_latest["RegionName"])
    OxCGRT_latest['NewCases'] = OxCGRT_latest.groupby('GeoID').ConfirmedCases.diff().fillna(0)
    initial_day_cases = OxCGRT_latest[OxCGRT_latest['Date'] == '2020-07-31'].set_index('GeoID')[['NewCases']]
    initial_day_cases['NewCases'] = initial_day_cases['NewCases'].replace(0, 1)

    predictor = XPrizePredictor()

    ip_file_path = 'prescriptions/'
    preds_file_path = 'predictions/'


    day_count = (end_date - start_date)/np.timedelta64(1, 'D')

    previous_day_cases = initial_day_cases
    for w in case_weights_names:
        
        prescriptions_total_df = pd.DataFrame(columns = col_names)
        predictions_total = pd.DataFrame()
        
        for i in tqdm(range(int(day_count) + 1)):
            cur_start_date = start_date + np.timedelta64(i,'D')
            cur_end_date = start_date + np.timedelta64(i,'D')
            cur_ip_file_path = ip_file_path + 'prescriptions_initial_' + w + '_'+ str(cur_start_date) + '.csv'
            print('prescribing for day ' + str(cur_start_date))

            prescriptions_total = []
            for geo, country, RegionName in tqdm(zip(GeoIDs, countries, regions)):
                cur_case_weight = case_weights_dict[w][geo]
                cur_stringency_weight = stringency_weight[geo]

                ip_solution = run_opt(previous_day_cases, cur_case_weight, cur_stringency_weight, geo, initial_day_cases)
                prescriptions_total.append([0] + [country] + [RegionName] + [str(cur_start_date)] + ip_solution)
                
                
            prescriptions_df = pd.DataFrame()
            prescriptions_df = prescriptions_df.append(pd.DataFrame(prescriptions_total))
            prescriptions_df.columns = col_names
            
            prescriptions_total_df = prescriptions_total_df.append(prescriptions_df)

            prescriptions_df.to_csv(cur_ip_file_path)

            # predict for next day for all geo
            print('predicting for day ' + str(cur_start_date))
            previous_day_cases = predictor.predict(cur_start_date, cur_start_date, cur_ip_file_path)
            previous_day_cases["GeoID"] = np.where(previous_day_cases["RegionName"].isnull(),
                                                previous_day_cases["CountryName"],
                                                previous_day_cases["CountryName"] + ' / ' + previous_day_cases["RegionName"])
            previous_day_cases = previous_day_cases.set_index('GeoID')[['PredictedDailyNewCases']]
            previous_day_cases.rename(columns={'PredictedDailyNewCases':'NewCases'}, inplace=True)
            
            predictions_total[cur_start_date] = previous_day_cases['NewCases']
            
        
        final_ip_path = output_file_path
        prescriptions_total_df.to_csv(final_ip_path)
        final_pred_path = preds_file_path + 'final_predictions_' + w + '_'+ str(cur_start_date) + '.csv'
        predictions_total.to_csv(final_pred_path)


    # raise NotImplementedError


def default_to_regular(d):
    if isinstance(d, defaultdict):
        d = {k: default_to_regular(v) for k, v in d.items()}
    return d

def run_opt(previous_day_cases, case_weight, stringency_weight, geo_id, initial_day_cases):
    solver = pywraplp.Solver.CreateSolver('SCIP')
    
    # Create variables
    x = {}
    for i,col in enumerate(NPI_COLS_NAMES):
        for j in range(len(NPI_dict[col])):
            x[i, j] = solver.IntVar(0, 1, (geo_id + '_' + col + '_' + str(NPI_values[i][j])))
    
    # Create contraint

    for i,col in enumerate(NPI_COLS_NAMES):
        solver.Add(solver.Sum([x[i, j] for j in range(len(NPI_dict[col]))]) == 1)
    
    # Create objective function

    objective_terms = []
    objective_terms.append((previous_day_cases.loc[geo_id]['NewCases']/initial_day_cases.loc[geo_id]['NewCases'])) # scale to initial number of cases
    
    for i,col in enumerate(NPI_COLS_NAMES):
        for j in range(len(NPI_dict[col])):
            objective_terms.append(((case_weight[col][j] * x[i,j]) * previous_day_cases.loc[geo_id]['NewCases'])/initial_day_cases.loc[geo_id]['NewCases'])
    for i,col in enumerate(NPI_COLS_NAMES):
        for j in range(len(NPI_dict[col])):
            objective_terms.append(stringency_weight[col][j] * x[i, j])
    solver.Minimize(solver.Sum(objective_terms))
    
    status = solver.Solve()
#     print(status)
    
    solution = []
    for i,col in enumerate(NPI_COLS_NAMES): 
        for j in range(len(NPI_dict[col])):
            if x[i,j].solution_value() > 0.5:
                solution.append(j)
    return solution


In [134]:
weights = pd.read_csv('weights.csv')

In [140]:
weights_7 = pd.read_csv('weights_7.csv')

In [142]:
weights_7['Country_Region'] = weights_7['Country_Region'].apply(lambda x: x.replace('_nan', ''))

In [143]:
weights_7['Country_Region'] = weights_7['Country_Region'].apply(lambda x: x.replace('_', ' / '))

In [145]:
weights_7[weights_7['Country_Region'] == 'Timor-Leste']

Unnamed: 0.1,Unnamed: 0,Country_Region,IP,IP Val,impact
7106,7106,Timor-Leste,C1_School closing,1,-3.490943
7107,7107,Timor-Leste,C1_School closing,2,-6.618114
7108,7108,Timor-Leste,C1_School closing,3,-9.436813
7109,7109,Timor-Leste,C2_Workplace closing,1,-4.069077
7110,7110,Timor-Leste,C2_Workplace closing,2,-7.77616
7111,7111,Timor-Leste,C2_Workplace closing,3,-11.174943
7112,7112,Timor-Leste,C3_Cancel public events,1,-0.893682
7113,7113,Timor-Leste,C3_Cancel public events,2,-1.709507
7114,7114,Timor-Leste,C4_Restrictions on gatherings,1,-1.407571
7115,7115,Timor-Leste,C4_Restrictions on gatherings,2,-2.682706


In [146]:
case_weights = defaultdict(lambda : defaultdict(lambda : defaultdict(int)))
for geo in GeoIDs:
    for col in NPI_COLS_NAMES:
        case_weights[geo][col][0] = 0

In [147]:
for index, row in weights_7.iterrows():
    case_weights[row['Country_Region']][row['IP']][row['IP Val']] = row['impact']/100

In [148]:
case_weights = default_to_regular(case_weights)

In [149]:
case_weights['Timor-Leste']

{'C1_School closing': {0: 0,
  1: -0.034909433121432916,
  2: -0.06618114375472878,
  3: -0.09436812768852203},
 'C2_Workplace closing': {0: 0,
  1: -0.0406907706070623,
  2: -0.07776160212934506,
  3: -0.11174943355125118},
 'C3_Cancel public events': {0: 0,
  1: -0.008936821644175108,
  2: -0.01709506688238163},
 'C4_Restrictions on gatherings': {0: 0,
  1: -0.01407571118257945,
  2: -0.02682706028848098,
  3: -0.03841480161277496,
  4: -0.04896806622391457},
 'C5_Close public transport': {0: 0,
  1: -0.0218403225395296,
  2: -0.04098249644803259},
 'C6_Stay at home requirements': {0: 0,
  1: -0.011628459647918692,
  2: -0.02188331055652517,
  3: -0.03095566395717869},
 'C7_Restrictions on internal movement': {0: 0,
  1: -0.012456110678536844,
  2: -0.02390428669582771},
 'C8_International travel controls': {0: 0,
  1: -0.017022611457860787,
  2: -0.03200276678776179,
  3: -0.04524231199457974,
  4: -0.05698167831560192},
 'H1_Public information campaigns': {0: 0,
  1: -0.02146848058

In [150]:
with open("weights/weights_7_reformat.pickle", "wb") as file:
    pickle.dump(case_weights, file)

In [128]:
for geo in GeoIDs:
#     print(geo)
    w_sum = 0
    for col in NPI_COLS_NAMES:
        for j in range(len(NPI_dict[col])):
            w_sum -= case_weights[geo][col][j]
#             print(w_sum)
    if w_sum >= 1:
        print(geo, ': ', w_sum)
#         for col in NPI_COLS_NAMES:
#             for j in range(len(NPI_dict[col])):
#                 case_weights[geo][col][j] /= w_sum

Afghanistan
Afghanistan :  1.0000000000000004
Albania
Albania :  1.0000000000000002
Algeria
Andorra
Andorra :  1.0000000000000002
Angola
Argentina
Argentina :  1.0
Aruba
Australia
Australia :  1.0000000000000002
Austria
Azerbaijan
Bahamas
Bahamas :  1.0
Bahrain
Bahrain :  1.0000000000000002
Bangladesh
Barbados
Belarus
Belarus :  1.0
Belgium
Belgium :  1.0
Belize
Belize :  1.0
Benin
Bermuda
Bhutan
Bolivia
Bolivia :  1.0
Bosnia and Herzegovina
Bosnia and Herzegovina :  1.0000000000000004
Botswana
Brazil
Brunei
Bulgaria
Burkina Faso
Burundi
Burundi :  1.0000000000000002
Cambodia
Cameroon
Cameroon :  1.0
Canada
Canada :  1.0
Cape Verde
Cape Verde :  1.0000000000000002
Central African Republic
Central African Republic :  1.0000000000000002
Chad
Chile
China
Colombia
Colombia :  1.0
Comoros
Congo
Costa Rica
Cote d'Ivoire
Croatia
Croatia :  1.0
Cuba
Cuba :  1.0000000000000002
Cyprus
Cyprus :  1.0000000000000002
Czech Republic
Czech Republic :  1.0000000000000002
Democratic Republic of Congo
De

KeyError: 1

In [129]:
case_weights

{'Afghanistan': {'C1_School closing': {0: 0.0,
   1: -0.029345944623924366,
   2: -0.055665742561588294,
   3: -0.07941240196395348},
  'C2_Workplace closing': {0: 0.0,
   1: -0.03411692916040185,
   2: -0.06525088171015503,
   3: -0.09383699373208651},
  'C3_Cancel public events': {0: 0.0,
   1: -0.007430874975464715,
   2: -0.014221515788969298},
  'C4_Restrictions on gatherings': {0: 0.0,
   1: -0.011775713666435655,
   2: -0.022451688286314404,
   3: -0.03216122342993968,
   4: -0.041013121692067715},
  'C5_Close public transport': {0: 0.0,
   1: -0.018335332305910737,
   2: -0.034421498740048885},
  'C6_Stay at home requirements': {0: 0.0,
   1: -0.009719723109245649,
   2: -0.018292791702269537,
   3: -0.02588031699556849},
  'C7_Restrictions on internal movement': {0: 0.0,
   1: -0.010405729899273538,
   2: -0.019975486096627176},
  'C8_International travel controls': {0: 0.0,
   1: -0.014264104468060897,
   2: -0.026828040106034564,
   3: -0.03793662054890491,
   4: -0.04779227

In [43]:
OxCGRT_latest.columns

Index(['CountryName', 'CountryCode', 'RegionName', 'RegionCode',
       'Jurisdiction', 'Date', 'C1_School closing', 'C1_Flag',
       'C2_Workplace closing', 'C2_Flag', 'C3_Cancel public events', 'C3_Flag',
       'C4_Restrictions on gatherings', 'C4_Flag', 'C5_Close public transport',
       'C5_Flag', 'C6_Stay at home requirements', 'C6_Flag',
       'C7_Restrictions on internal movement', 'C7_Flag',
       'C8_International travel controls', 'E1_Income support', 'E1_Flag',
       'E2_Debt/contract relief', 'E3_Fiscal measures',
       'E4_International support', 'H1_Public information campaigns',
       'H1_Flag', 'H2_Testing policy', 'H3_Contact tracing',
       'H4_Emergency investment in healthcare', 'H5_Investment in vaccines',
       'H6_Facial Coverings', 'H6_Flag', 'H7_Vaccination policy', 'H7_Flag',
       'M1_Wildcard', 'ConfirmedCases', 'ConfirmedDeaths', 'StringencyIndex',
       'StringencyIndexForDisplay', 'StringencyLegacyIndex',
       'StringencyLegacyIndexForDispla

In [47]:
b = pd.read_csv( 'https://raw.githubusercontent.com/OxCGRT/covid-policy-tracker/master/data/OxCGRT_latest.csv')

  interactivity=interactivity, compiler=compiler, result=result)


In [71]:
b.to_csv('OxCGRT_latest.csv', index = False)

In [76]:
GeoIDs.delete('United States / Virgin Islands')

AttributeError: 'numpy.ndarray' object has no attribute 'delete'

In [80]:
GeoIDs = list(GeoIDs)

In [81]:
GeoIDs

In [48]:
b[(b['CountryName'] == 'United Kingdom') & (b['RegionName'] == 'Wales')][['Date','ConfirmedCases']]

Unnamed: 0,Date,ConfirmedCases
42000,20200101,
42001,20200102,
42002,20200103,
42003,20200104,
42004,20200105,
...,...,...
42395,20210130,193261.0
42396,20210131,193525.0
42397,20210201,193526.0
42398,20210202,


In [46]:
OxCGRT_latest[(OxCGRT_latest['CountryName'] == 'United Kingdom') & (OxCGRT_latest['RegionName'] == 'Wales')][['Date','ConfirmedCases']]

Unnamed: 0,Date,ConfirmedCases
39585,2020-01-01,
39586,2020-01-02,
39587,2020-01-03,
39588,2020-01-04,
39589,2020-01-05,
...,...,...
39957,2021-01-07,170088.0
39958,2021-01-08,171403.0
39959,2021-01-09,171547.0
39960,2021-01-10,


In [55]:
start_date = '2021-01-08'
end_date = '2021-05-05'
path_to_prior_ips_file = '/Users/chang/workplace/covid_xprize/covid-xprize/covid_xprize/standard_predictor/data/OxCGRT_latest.csv' 
path_to_cost_file = 'prescriptions/test.csv'
output_file_path = 'covid_xprize/validation/data/uniform_random_costs.csv'

In [82]:
GeoIDs = list(stringency_weight_df["GeoID"].values)
countries = list(stringency_weight_df["CountryName"].values)
regions = list(stringency_weight_df["RegionName"].values)

In [83]:
GeoIDs.remove('United States / Virgin Islands')

In [84]:
GeoIDs

['Afghanistan',
 'Albania',
 'Algeria',
 'Andorra',
 'Angola',
 'Argentina',
 'Aruba',
 'Australia',
 'Austria',
 'Azerbaijan',
 'Bahamas',
 'Bahrain',
 'Bangladesh',
 'Barbados',
 'Belarus',
 'Belgium',
 'Belize',
 'Benin',
 'Bermuda',
 'Bhutan',
 'Bolivia',
 'Bosnia and Herzegovina',
 'Botswana',
 'Brazil',
 'Brunei',
 'Bulgaria',
 'Burkina Faso',
 'Burundi',
 'Cambodia',
 'Cameroon',
 'Canada',
 'Cape Verde',
 'Central African Republic',
 'Chad',
 'Chile',
 'China',
 'Colombia',
 'Comoros',
 'Congo',
 'Costa Rica',
 "Cote d'Ivoire",
 'Croatia',
 'Cuba',
 'Cyprus',
 'Czech Republic',
 'Democratic Republic of Congo',
 'Denmark',
 'Djibouti',
 'Dominica',
 'Dominican Republic',
 'Ecuador',
 'Egypt',
 'El Salvador',
 'Eritrea',
 'Estonia',
 'Eswatini',
 'Ethiopia',
 'Faeroe Islands',
 'Fiji',
 'Finland',
 'France',
 'Gabon',
 'Gambia',
 'Georgia',
 'Germany',
 'Ghana',
 'Greece',
 'Greenland',
 'Guam',
 'Guatemala',
 'Guinea',
 'Guyana',
 'Haiti',
 'Honduras',
 'Hong Kong',
 'Hungary',


In [87]:
stringency_weight_df["GeoID"].unique()

array(['Afghanistan', 'Albania', 'Algeria', 'Andorra', 'Angola',
       'Argentina', 'Aruba', 'Australia', 'Austria', 'Azerbaijan',
       'Bahamas', 'Bahrain', 'Bangladesh', 'Barbados', 'Belarus',
       'Belgium', 'Belize', 'Benin', 'Bermuda', 'Bhutan', 'Bolivia',
       'Bosnia and Herzegovina', 'Botswana', 'Brazil', 'Brunei',
       'Bulgaria', 'Burkina Faso', 'Burundi', 'Cambodia', 'Cameroon',
       'Canada', 'Cape Verde', 'Central African Republic', 'Chad',
       'Chile', 'China', 'Colombia', 'Comoros', 'Congo', 'Costa Rica',
       "Cote d'Ivoire", 'Croatia', 'Cuba', 'Cyprus', 'Czech Republic',
       'Democratic Republic of Congo', 'Denmark', 'Djibouti', 'Dominica',
       'Dominican Republic', 'Ecuador', 'Egypt', 'El Salvador', 'Eritrea',
       'Estonia', 'Eswatini', 'Ethiopia', 'Faeroe Islands', 'Fiji',
       'Finland', 'France', 'Gabon', 'Gambia', 'Georgia', 'Germany',
       'Ghana', 'Greece', 'Greenland', 'Guam', 'Guatemala', 'Guinea',
       'Guyana', 'Haiti', 'Hondur

In [89]:
start_date = np.datetime64('2020-02-15')
#     if 
start_date > np.datetime64('2020-01-08')

True

In [90]:
'United Kingdom / Wales'.split(' / ')

['United Kingdom', 'Wales']

In [91]:
'United Kingdom'.split(' / ')

['United Kingdom']

In [None]:
prior_ip_file = pd.read_csv(path_to_prior_ips_file,
                                parse_dates=['Date'],
                                encoding="ISO-8859-1",
                                dtype={"RegionName": str,
                                    "RegionCode": str},
                                error_bad_lines=False)
prior_ip_file["GeoID"] = np.where(prior_ip_file["RegionName"].isnull(),
                                    prior_ip_file["CountryName"],
                                    prior_ip_file["CountryName"] + ' / ' + prior_ip_file["RegionName"])

GeoIDs = list(prior_ip_file["GeoID"].unique())
print('number of GeoID ', len(GeoIDs))
countries = list(prior_ip_file["CountryName"].unique())
regions = list(prior_ip_file["RegionName"].unique())

In [70]:
start_date = np.datetime64(start_date)
end_date = np.datetime64(end_date)

# get weights
case_weights_dict = {}

with open("weights/weights_reformat.pickle", "rb") as file:
    case_weights_dict['case_weights_1'] = pickle.load(file)

case_weights_names = ['case_weights_1']

# get stringency
stringency_weight_df = pd.read_csv(path_to_cost_file)
stringency_weight_df["GeoID"] = np.where(stringency_weight_df["RegionName"].isnull(),
                                    stringency_weight_df["CountryName"],
                                    stringency_weight_df["CountryName"] + ' / ' + stringency_weight_df["RegionName"])

GeoIDs = list(stringency_weight_df["GeoID"].values)
countries = list(stringency_weight_df["CountryName"].values)
regions = list(stringency_weight_df["RegionName"].values)

# process stringency
stringency_weight = defaultdict(lambda : defaultdict(lambda : defaultdict(np.float64)))
for index, row in stringency_weight_df.iterrows():
    row_sum = 0
    for col in NPI_COLS_NAMES:
        row_sum +=  row[col]*sum(NPI_dict[col])
#     row_sum = row[NPI_COLS_NAMES].sum()
    if row_sum == 0:
        row_sum = 1
    for col in NPI_COLS_NAMES:
        for j in range(len(NPI_dict[col])):
#             tmp.append((stringency_weight_tmp[i] * (NPI_values[i][j]))/stringency_weight_sum)
            stringency_weight[row['GeoID']][col][j] = (row[col] * NPI_dict[col][j])/row_sum

stringency_weight = default_to_regular(stringency_weight)

# set global variables
data_url = path_to_prior_ips_file
OxCGRT_latest = pd.read_csv(data_url,
                                parse_dates=['Date'],
                                encoding="ISO-8859-1",
                                dtype={"RegionName": str,
                                    "RegionCode": str},
                                error_bad_lines=False)
OxCGRT_latest["GeoID"] = np.where(OxCGRT_latest["RegionName"].isnull(),
                                    OxCGRT_latest["CountryName"],
                                    OxCGRT_latest["CountryName"] + ' / ' + OxCGRT_latest["RegionName"])
OxCGRT_latest['NewCases'] = OxCGRT_latest.groupby('GeoID').ConfirmedCases.diff().fillna(0)
initial_day_cases = OxCGRT_latest[OxCGRT_latest['Date'] == '2020-07-31'].set_index('GeoID')[['NewCases']]
initial_day_cases['NewCases'] = initial_day_cases['NewCases'].replace(0, 1)

predictor = XPrizePredictor()

ip_file_path = 'prescriptions/'
preds_file_path = 'predictions/'


day_count = (end_date - start_date)/np.timedelta64(1, 'D')

previous_day_cases = initial_day_cases
for w in case_weights_names:

    prescriptions_total_df = pd.DataFrame(columns = col_names)
    predictions_total = pd.DataFrame()
    
    cur_ip_file_path = ip_file_path + 'prescriptions_initial_' + w + '.csv'
    for i in tqdm(range(int(day_count) + 1)):
        cur_date = start_date + np.timedelta64(i,'D')
        print('prescribing for day ' + str(cur_date))

        prescriptions_total = []
        for geo, country, RegionName in tqdm(zip(GeoIDs, countries, regions)):
            cur_case_weight = case_weights_dict[w][geo]
            cur_stringency_weight = stringency_weight[geo]

            ip_solution = run_opt(previous_day_cases, cur_case_weight, cur_stringency_weight, geo, initial_day_cases)
            prescriptions_total.append([0] + [country] + [RegionName] + [str(cur_date)] + ip_solution)


        prescriptions_df = pd.DataFrame()
        prescriptions_df = prescriptions_df.append(pd.DataFrame(prescriptions_total))
        prescriptions_df.columns = col_names

        prescriptions_total_df = prescriptions_total_df.append(prescriptions_df)

        prescriptions_total_df.to_csv(cur_ip_file_path)

        # predict for next day for all geo
        print('predicting for day ' + str(cur_date))
        previous_day_cases = predictor.predict(start_date, cur_date, cur_ip_file_path)
        previous_day_cases = previous_day_cases[previous_day_cases['Date'] == cur_date]
        previous_day_cases["GeoID"] = np.where(previous_day_cases["RegionName"].isnull(),
                                            previous_day_cases["CountryName"],
                                            previous_day_cases["CountryName"] + ' / ' + previous_day_cases["RegionName"])
        previous_day_cases = previous_day_cases.set_index('GeoID')[['PredictedDailyNewCases']]
        previous_day_cases.rename(columns={'PredictedDailyNewCases':'NewCases'}, inplace=True)

        predictions_total[cur_date] = previous_day_cases['NewCases']


#     final_ip_path = output_file_path
#     prescriptions_total_df.to_csv(final_ip_path)
    final_pred_path = preds_file_path + 'final_predictions_' + w + '_'+ str(cur_date) + '.csv'
    predictions_total.to_csv(final_pred_path)


  0%|          | 0/118 [00:00<?, ?it/s]
0it [00:00, ?it/s][A
6it [00:00, 54.88it/s][A

prescribing for day 2021-01-08



12it [00:00, 54.75it/s][A
18it [00:00, 54.40it/s][A
24it [00:00, 54.96it/s][A
30it [00:00, 55.21it/s][A
36it [00:00, 55.78it/s][A
42it [00:00, 53.83it/s][A
48it [00:00, 53.92it/s][A
54it [00:00, 54.96it/s][A
60it [00:01, 54.75it/s][A
66it [00:01, 55.46it/s][A
72it [00:01, 56.24it/s][A
78it [00:01, 56.86it/s][A
84it [00:01, 57.08it/s][A
90it [00:01, 56.88it/s][A
96it [00:01, 57.29it/s][A
102it [00:01, 57.36it/s][A
108it [00:01, 57.64it/s][A
114it [00:02, 57.67it/s][A
120it [00:02, 57.15it/s][A
126it [00:02, 57.50it/s][A
132it [00:02, 57.91it/s][A
138it [00:02, 57.78it/s][A
144it [00:02, 57.41it/s][A
150it [00:02, 57.31it/s][A
156it [00:02, 57.37it/s][A
162it [00:02, 57.24it/s][A
168it [00:02, 57.22it/s][A
174it [00:03, 55.99it/s][A
180it [00:03, 55.31it/s][A
186it [00:03, 55.61it/s][A
192it [00:03, 56.27it/s][A
198it [00:03, 56.58it/s][A
204it [00:03, 56.30it/s][A
210it [00:03, 56.33it/s][A
216it [00:03, 56.88it/s][A
222it [00:03, 56.91it/s][A
228it 

predicting for day 2021-01-08


  1%|          | 1/118 [00:08<16:36,  8.52s/it]
0it [00:00, ?it/s][A
6it [00:00, 55.65it/s][A

prescribing for day 2021-01-09



12it [00:00, 56.13it/s][A
18it [00:00, 55.86it/s][A
24it [00:00, 55.88it/s][A
30it [00:00, 54.79it/s][A
36it [00:00, 55.30it/s][A
42it [00:00, 54.70it/s][A
48it [00:00, 52.77it/s][A
54it [00:00, 53.96it/s][A
60it [00:01, 54.76it/s][A
66it [00:01, 54.47it/s][A
72it [00:01, 54.83it/s][A
78it [00:01, 55.21it/s][A
84it [00:01, 55.56it/s][A
90it [00:01, 54.66it/s][A
96it [00:01, 54.62it/s][A
102it [00:01, 55.19it/s][A
108it [00:01, 55.89it/s][A
114it [00:02, 56.27it/s][A
120it [00:02, 56.50it/s][A
126it [00:02, 56.44it/s][A
132it [00:02, 55.75it/s][A
138it [00:02, 54.87it/s][A
144it [00:02, 53.91it/s][A
150it [00:02, 54.33it/s][A
156it [00:02, 54.47it/s][A
162it [00:02, 54.22it/s][A
168it [00:03, 54.86it/s][A
174it [00:03, 55.71it/s][A
180it [00:03, 55.02it/s][A
186it [00:03, 55.26it/s][A
192it [00:03, 55.78it/s][A
198it [00:03, 55.34it/s][A
204it [00:03, 55.99it/s][A
210it [00:03, 56.13it/s][A
216it [00:03, 56.39it/s][A
222it [00:04, 56.62it/s][A
228it 

predicting for day 2021-01-09


  2%|▏         | 2/118 [00:17<16:35,  8.58s/it]
0it [00:00, ?it/s][A
6it [00:00, 52.92it/s][A

prescribing for day 2021-01-10



12it [00:00, 53.42it/s][A
18it [00:00, 53.89it/s][A
24it [00:00, 54.97it/s][A
30it [00:00, 55.48it/s][A
36it [00:00, 55.84it/s][A
42it [00:00, 56.17it/s][A
48it [00:00, 55.78it/s][A
54it [00:00, 56.35it/s][A
60it [00:01, 56.83it/s][A
66it [00:01, 56.73it/s][A
72it [00:01, 56.67it/s][A
78it [00:01, 56.85it/s][A
84it [00:01, 56.22it/s][A
90it [00:01, 56.41it/s][A
96it [00:01, 56.39it/s][A
102it [00:01, 55.52it/s][A
108it [00:01, 55.84it/s][A
114it [00:02, 55.99it/s][A
120it [00:02, 56.31it/s][A
126it [00:02, 55.72it/s][A
132it [00:02, 56.26it/s][A
138it [00:02, 56.77it/s][A
144it [00:02, 56.75it/s][A
150it [00:02, 56.54it/s][A
156it [00:02, 56.39it/s][A
162it [00:02, 55.43it/s][A
168it [00:02, 55.32it/s][A
174it [00:03, 55.12it/s][A
180it [00:03, 55.42it/s][A
186it [00:03, 54.90it/s][A
192it [00:03, 55.52it/s][A
198it [00:03, 56.16it/s][A
204it [00:03, 56.07it/s][A
210it [00:03, 55.94it/s][A
216it [00:03, 53.74it/s][A
222it [00:03, 54.87it/s][A
228it 

predicting for day 2021-01-10


  3%|▎         | 3/118 [00:26<16:44,  8.74s/it]
0it [00:00, ?it/s][A
6it [00:00, 57.21it/s][A

prescribing for day 2021-01-11



12it [00:00, 56.97it/s][A
18it [00:00, 56.61it/s][A
24it [00:00, 56.61it/s][A
30it [00:00, 56.39it/s][A
36it [00:00, 55.89it/s][A
42it [00:00, 56.20it/s][A
48it [00:00, 56.60it/s][A
54it [00:00, 56.65it/s][A
60it [00:01, 56.52it/s][A
66it [00:01, 56.73it/s][A
72it [00:01, 56.47it/s][A
78it [00:01, 56.84it/s][A
84it [00:01, 56.75it/s][A
90it [00:01, 56.72it/s][A
96it [00:01, 56.38it/s][A
102it [00:01, 55.97it/s][A
108it [00:01, 55.68it/s][A
114it [00:02, 56.65it/s][A
120it [00:02, 56.46it/s][A
126it [00:02, 56.62it/s][A
132it [00:02, 56.37it/s][A
138it [00:02, 56.89it/s][A
144it [00:02, 56.91it/s][A
150it [00:02, 56.94it/s][A
156it [00:02, 55.45it/s][A
162it [00:02, 55.88it/s][A
168it [00:02, 56.29it/s][A
174it [00:03, 56.07it/s][A
180it [00:03, 56.39it/s][A
186it [00:03, 56.27it/s][A
192it [00:03, 56.85it/s][A
198it [00:03, 56.59it/s][A
204it [00:03, 56.90it/s][A
210it [00:03, 55.89it/s][A
216it [00:03, 56.26it/s][A
222it [00:03, 56.53it/s][A
228it 

predicting for day 2021-01-11


  3%|▎         | 4/118 [00:35<17:04,  8.98s/it]
0it [00:00, ?it/s][A
6it [00:00, 57.53it/s][A

prescribing for day 2021-01-12



12it [00:00, 56.45it/s][A
18it [00:00, 56.36it/s][A
24it [00:00, 56.15it/s][A
30it [00:00, 56.40it/s][A
36it [00:00, 56.13it/s][A
42it [00:00, 56.42it/s][A
48it [00:00, 56.46it/s][A
54it [00:00, 56.74it/s][A
60it [00:01, 56.73it/s][A
66it [00:01, 53.95it/s][A
72it [00:01, 54.87it/s][A
78it [00:01, 55.54it/s][A
84it [00:01, 55.60it/s][A
90it [00:01, 55.61it/s][A
96it [00:01, 55.61it/s][A
102it [00:01, 55.86it/s][A
108it [00:01, 56.41it/s][A
114it [00:02, 56.35it/s][A
120it [00:02, 55.89it/s][A
126it [00:02, 55.07it/s][A
132it [00:02, 55.70it/s][A
138it [00:02, 55.39it/s][A
144it [00:02, 55.42it/s][A
150it [00:02, 55.71it/s][A
156it [00:02, 56.04it/s][A
162it [00:02, 55.88it/s][A
168it [00:03, 55.91it/s][A
174it [00:03, 55.96it/s][A
180it [00:03, 55.46it/s][A
186it [00:03, 55.45it/s][A
192it [00:03, 56.03it/s][A
198it [00:03, 56.35it/s][A
204it [00:03, 56.18it/s][A
210it [00:03, 55.90it/s][A
216it [00:03, 56.01it/s][A
222it [00:03, 55.97it/s][A
228it 

predicting for day 2021-01-12


  4%|▍         | 5/118 [00:46<17:39,  9.37s/it]
0it [00:00, ?it/s][A
5it [00:00, 49.16it/s][A

prescribing for day 2021-01-13



11it [00:00, 50.90it/s][A
17it [00:00, 51.93it/s][A
23it [00:00, 52.80it/s][A
29it [00:00, 53.52it/s][A
35it [00:00, 54.34it/s][A
41it [00:00, 54.93it/s][A
47it [00:00, 54.06it/s][A
53it [00:00, 54.97it/s][A
59it [00:01, 55.84it/s][A
65it [00:01, 56.11it/s][A
71it [00:01, 56.17it/s][A
77it [00:01, 56.23it/s][A
83it [00:01, 56.61it/s][A
89it [00:01, 56.85it/s][A
95it [00:01, 56.62it/s][A
101it [00:01, 56.84it/s][A
107it [00:01, 56.18it/s][A
113it [00:02, 56.58it/s][A
119it [00:02, 56.95it/s][A
125it [00:02, 56.76it/s][A
131it [00:02, 56.85it/s][A
137it [00:02, 56.77it/s][A
143it [00:02, 56.38it/s][A
149it [00:02, 56.45it/s][A
155it [00:02, 56.56it/s][A
161it [00:02, 55.50it/s][A
167it [00:02, 55.74it/s][A
173it [00:03, 55.76it/s][A
179it [00:03, 55.14it/s][A
185it [00:03, 55.38it/s][A
191it [00:03, 54.95it/s][A
197it [00:03, 55.22it/s][A
203it [00:03, 55.18it/s][A
209it [00:03, 55.02it/s][A
215it [00:03, 53.76it/s][A
221it [00:03, 53.68it/s][A
227it 

predicting for day 2021-01-13


  5%|▌         | 6/118 [00:57<18:22,  9.84s/it]
0it [00:00, ?it/s][A
6it [00:00, 53.17it/s][A

prescribing for day 2021-01-14



12it [00:00, 54.12it/s][A
18it [00:00, 54.83it/s][A
24it [00:00, 55.02it/s][A
30it [00:00, 54.93it/s][A
36it [00:00, 55.10it/s][A
42it [00:00, 55.26it/s][A
48it [00:00, 53.99it/s][A
54it [00:00, 54.41it/s][A
60it [00:01, 54.09it/s][A
66it [00:01, 55.15it/s][A
72it [00:01, 55.75it/s][A
78it [00:01, 51.00it/s][A
84it [00:01, 45.86it/s][A
90it [00:01, 47.41it/s][A
96it [00:01, 48.60it/s][A
101it [00:01, 48.70it/s][A
107it [00:02, 50.00it/s][A
113it [00:02, 50.43it/s][A
119it [00:02, 50.69it/s][A
125it [00:02, 50.74it/s][A
131it [00:02, 49.83it/s][A
136it [00:02, 42.99it/s][A
141it [00:02, 42.72it/s][A
146it [00:02, 44.08it/s][A
152it [00:03, 46.33it/s][A
158it [00:03, 48.27it/s][A
164it [00:03, 50.37it/s][A
170it [00:03, 49.92it/s][A
176it [00:03, 49.41it/s][A
181it [00:03, 48.71it/s][A
186it [00:03, 47.42it/s][A
192it [00:03, 48.92it/s][A
198it [00:03, 49.94it/s][A
204it [00:04, 50.22it/s][A
210it [00:04, 51.33it/s][A
216it [00:04, 49.01it/s][A
222it 

predicting for day 2021-01-14


  6%|▌         | 7/118 [01:08<18:58, 10.26s/it]
0it [00:00, ?it/s][A
6it [00:00, 58.15it/s][A

prescribing for day 2021-01-15



12it [00:00, 57.44it/s][A
18it [00:00, 57.13it/s][A
23it [00:00, 54.21it/s][A
29it [00:00, 54.19it/s][A
35it [00:00, 54.79it/s][A
41it [00:00, 54.53it/s][A
47it [00:00, 55.28it/s][A
53it [00:00, 55.63it/s][A
59it [00:01, 56.16it/s][A
65it [00:01, 56.55it/s][A
71it [00:01, 57.07it/s][A
77it [00:01, 57.17it/s][A
83it [00:01, 56.88it/s][A
89it [00:01, 57.12it/s][A
95it [00:01, 56.65it/s][A
101it [00:01, 56.36it/s][A
107it [00:01, 55.34it/s][A
113it [00:02, 55.80it/s][A
119it [00:02, 56.29it/s][A
125it [00:02, 56.53it/s][A
131it [00:02, 56.68it/s][A
137it [00:02, 56.83it/s][A
143it [00:02, 56.84it/s][A
149it [00:02, 56.89it/s][A
155it [00:02, 56.60it/s][A
161it [00:02, 56.93it/s][A
167it [00:02, 57.36it/s][A
173it [00:03, 57.06it/s][A
179it [00:03, 57.12it/s][A
185it [00:03, 57.29it/s][A
191it [00:03, 57.51it/s][A
197it [00:03, 57.59it/s][A
203it [00:03, 57.51it/s][A
209it [00:03, 56.88it/s][A
215it [00:03, 56.75it/s][A
221it [00:03, 56.94it/s][A
227it 

predicting for day 2021-01-15


  7%|▋         | 8/118 [01:19<19:11, 10.47s/it]
0it [00:00, ?it/s][A
6it [00:00, 57.86it/s][A

prescribing for day 2021-01-16



12it [00:00, 57.78it/s][A
18it [00:00, 57.64it/s][A
24it [00:00, 57.50it/s][A
30it [00:00, 57.63it/s][A
36it [00:00, 57.96it/s][A
42it [00:00, 57.45it/s][A
48it [00:00, 56.95it/s][A
54it [00:00, 56.63it/s][A
60it [00:01, 56.84it/s][A
66it [00:01, 57.08it/s][A
72it [00:01, 56.80it/s][A
78it [00:01, 57.34it/s][A
84it [00:01, 57.50it/s][A
90it [00:01, 57.59it/s][A
96it [00:01, 57.51it/s][A
102it [00:01, 56.45it/s][A
108it [00:01, 56.42it/s][A
114it [00:01, 56.71it/s][A
120it [00:02, 56.63it/s][A
126it [00:02, 56.32it/s][A
132it [00:02, 50.69it/s][A
138it [00:02, 51.59it/s][A
144it [00:02, 53.05it/s][A
150it [00:02, 54.36it/s][A
156it [00:02, 54.25it/s][A
162it [00:02, 55.03it/s][A
168it [00:03, 55.35it/s][A
174it [00:03, 55.48it/s][A
180it [00:03, 45.18it/s][A
185it [00:03, 34.83it/s][A
191it [00:03, 39.06it/s][A
197it [00:03, 42.09it/s][A
203it [00:03, 45.74it/s][A
209it [00:03, 48.50it/s][A
215it [00:04, 50.67it/s][A
221it [00:04, 52.40it/s][A
227it 

predicting for day 2021-01-16


  8%|▊         | 9/118 [01:31<19:49, 10.91s/it]
0it [00:00, ?it/s][A
6it [00:00, 53.91it/s][A

prescribing for day 2021-01-17



12it [00:00, 53.64it/s][A
18it [00:00, 54.05it/s][A
24it [00:00, 54.61it/s][A
30it [00:00, 54.85it/s][A
36it [00:00, 54.89it/s][A
42it [00:00, 55.31it/s][A
48it [00:00, 54.59it/s][A
54it [00:00, 55.05it/s][A
60it [00:01, 55.50it/s][A
66it [00:01, 55.38it/s][A
72it [00:01, 55.73it/s][A
78it [00:01, 55.76it/s][A
84it [00:01, 55.52it/s][A
90it [00:01, 55.58it/s][A
96it [00:01, 55.97it/s][A
102it [00:01, 55.01it/s][A
108it [00:01, 55.38it/s][A
114it [00:02, 55.77it/s][A
120it [00:02, 55.76it/s][A
126it [00:02, 56.16it/s][A
132it [00:02, 56.25it/s][A
138it [00:02, 56.24it/s][A
144it [00:02, 56.01it/s][A
150it [00:02, 56.33it/s][A
156it [00:02, 55.33it/s][A
162it [00:02, 55.85it/s][A
168it [00:03, 56.23it/s][A
174it [00:03, 56.14it/s][A
180it [00:03, 56.49it/s][A
186it [00:03, 56.58it/s][A
192it [00:03, 56.41it/s][A
198it [00:03, 56.25it/s][A
204it [00:03, 56.14it/s][A
210it [00:03, 56.01it/s][A
216it [00:03, 55.40it/s][A
222it [00:03, 55.60it/s][A
228it 

predicting for day 2021-01-17


  8%|▊         | 10/118 [01:43<20:09, 11.20s/it]
0it [00:00, ?it/s][A
6it [00:00, 56.32it/s][A

prescribing for day 2021-01-18



12it [00:00, 56.15it/s][A
18it [00:00, 55.91it/s][A
24it [00:00, 56.00it/s][A
30it [00:00, 56.22it/s][A
36it [00:00, 56.27it/s][A
42it [00:00, 56.37it/s][A
48it [00:00, 56.26it/s][A
54it [00:00, 55.28it/s][A
60it [00:01, 55.67it/s][A
66it [00:01, 55.94it/s][A
72it [00:01, 55.98it/s][A
78it [00:01, 55.68it/s][A
84it [00:01, 55.85it/s][A
90it [00:01, 55.32it/s][A
96it [00:01, 55.55it/s][A
102it [00:01, 56.09it/s][A
108it [00:01, 55.51it/s][A
114it [00:02, 55.27it/s][A
120it [00:02, 55.83it/s][A
126it [00:02, 56.06it/s][A
132it [00:02, 56.14it/s][A
138it [00:02, 56.37it/s][A
144it [00:02, 56.76it/s][A
150it [00:02, 56.04it/s][A
156it [00:02, 56.19it/s][A
162it [00:02, 55.88it/s][A
168it [00:03, 55.26it/s][A
174it [00:03, 55.59it/s][A
180it [00:03, 55.62it/s][A
186it [00:03, 55.56it/s][A
192it [00:03, 55.67it/s][A
198it [00:03, 55.89it/s][A
204it [00:03, 55.85it/s][A
210it [00:03, 55.90it/s][A
216it [00:03, 56.22it/s][A
222it [00:03, 55.56it/s][A
228it 

predicting for day 2021-01-18


  9%|▉         | 11/118 [01:55<20:39, 11.58s/it]
0it [00:00, ?it/s][A
6it [00:00, 56.48it/s][A

prescribing for day 2021-01-19



12it [00:00, 56.53it/s][A
18it [00:00, 56.16it/s][A
24it [00:00, 55.83it/s][A
30it [00:00, 54.77it/s][A
36it [00:00, 54.68it/s][A
42it [00:00, 54.60it/s][A
48it [00:00, 54.36it/s][A
54it [00:00, 54.21it/s][A
60it [00:01, 54.96it/s][A
66it [00:01, 54.82it/s][A
72it [00:01, 55.11it/s][A
78it [00:01, 55.16it/s][A
84it [00:01, 54.39it/s][A
90it [00:01, 54.97it/s][A
96it [00:01, 55.04it/s][A
102it [00:01, 55.46it/s][A
108it [00:01, 55.27it/s][A
114it [00:02, 55.44it/s][A
120it [00:02, 55.70it/s][A
126it [00:02, 55.88it/s][A
132it [00:02, 56.11it/s][A
138it [00:02, 53.59it/s][A
144it [00:02, 54.51it/s][A
150it [00:02, 54.91it/s][A
156it [00:02, 55.47it/s][A
162it [00:02, 55.55it/s][A
168it [00:03, 56.09it/s][A
174it [00:03, 56.44it/s][A
180it [00:03, 56.38it/s][A
186it [00:03, 56.48it/s][A
192it [00:03, 55.50it/s][A
198it [00:03, 55.73it/s][A
204it [00:03, 55.47it/s][A
210it [00:03, 55.80it/s][A
216it [00:03, 56.00it/s][A
222it [00:04, 55.90it/s][A
228it 

predicting for day 2021-01-19


 10%|█         | 12/118 [02:08<21:16, 12.04s/it]
0it [00:00, ?it/s][A
6it [00:00, 57.11it/s][A

prescribing for day 2021-01-20



12it [00:00, 56.59it/s][A
18it [00:00, 56.36it/s][A
24it [00:00, 55.35it/s][A
30it [00:00, 55.56it/s][A
36it [00:00, 55.98it/s][A
42it [00:00, 56.33it/s][A
48it [00:00, 56.36it/s][A
54it [00:00, 56.43it/s][A
60it [00:01, 56.42it/s][A
66it [00:01, 56.63it/s][A
72it [00:01, 56.55it/s][A
78it [00:01, 54.54it/s][A
84it [00:01, 55.21it/s][A
90it [00:01, 55.14it/s][A
96it [00:01, 54.91it/s][A
102it [00:01, 55.58it/s][A
108it [00:01, 56.37it/s][A
114it [00:02, 56.03it/s][A
120it [00:02, 55.59it/s][A
126it [00:02, 56.07it/s][A
132it [00:02, 55.61it/s][A
138it [00:02, 56.16it/s][A
144it [00:02, 56.43it/s][A
150it [00:02, 56.23it/s][A
156it [00:02, 56.39it/s][A
162it [00:02, 56.14it/s][A
168it [00:03, 56.23it/s][A
174it [00:03, 56.33it/s][A
180it [00:03, 56.11it/s][A
186it [00:03, 56.65it/s][A
192it [00:03, 55.36it/s][A
198it [00:03, 56.16it/s][A
204it [00:03, 56.09it/s][A
210it [00:03, 56.35it/s][A
216it [00:03, 55.65it/s][A
222it [00:03, 55.75it/s][A
228it 

predicting for day 2021-01-20


 10%|█         | 12/118 [02:17<20:10, 11.42s/it]


KeyboardInterrupt: 

In [62]:
start_date

numpy.datetime64('2021-01-08')

In [63]:
cur_date

numpy.datetime64('2021-01-10')

In [69]:
prescriptions_total_df

Unnamed: 0,PrescriptionIndex,CountryName,RegionName,Date,C1_School closing,C2_Workplace closing,C3_Cancel public events,C4_Restrictions on gatherings,C5_Close public transport,C6_Stay at home requirements,C7_Restrictions on internal movement,C8_International travel controls,H1_Public information campaigns,H2_Testing policy,H3_Contact tracing,H6_Facial Coverings
0,0,Afghanistan,,2021-01-08,0,3,2,0,2,3,2,4,2,3,2,4
1,0,Albania,,2021-01-08,0,3,0,4,0,3,2,0,2,3,2,4
2,0,Algeria,,2021-01-08,0,0,2,4,2,3,2,4,2,0,2,4
3,0,Andorra,,2021-01-08,0,3,2,0,0,3,2,4,0,3,2,4
4,0,Angola,,2021-01-08,0,0,2,4,2,3,2,4,0,3,2,4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
231,0,Venezuela,,2021-01-11,0,3,2,4,0,3,2,0,2,3,2,4
232,0,Vietnam,,2021-01-11,0,0,2,4,2,3,0,4,2,3,0,4
233,0,Yemen,,2021-01-11,3,0,2,4,0,3,2,0,2,3,2,4
234,0,Zambia,,2021-01-11,3,3,2,4,2,0,2,0,0,3,0,4


In [68]:
pd.read_csv(cur_ip_file_path)

Unnamed: 0.1,Unnamed: 0,PrescriptionIndex,CountryName,RegionName,Date,C1_School closing,C2_Workplace closing,C3_Cancel public events,C4_Restrictions on gatherings,C5_Close public transport,C6_Stay at home requirements,C7_Restrictions on internal movement,C8_International travel controls,H1_Public information campaigns,H2_Testing policy,H3_Contact tracing,H6_Facial Coverings
0,0,0,Afghanistan,,2021-01-09,0,3,2,0,2,3,2,4,2,3,2,4
1,1,0,Albania,,2021-01-09,0,3,0,4,0,3,2,0,2,3,2,4
2,2,0,Algeria,,2021-01-09,0,0,2,4,2,3,2,4,2,0,2,4
3,3,0,Andorra,,2021-01-09,3,3,2,4,2,3,2,4,2,3,2,4
4,4,0,Angola,,2021-01-09,0,0,2,4,2,3,2,4,0,3,2,4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
231,231,0,Venezuela,,2021-01-09,0,3,2,4,0,3,2,0,2,3,2,4
232,232,0,Vietnam,,2021-01-09,0,0,2,4,2,3,0,4,2,3,0,4
233,233,0,Yemen,,2021-01-09,3,0,2,4,0,3,2,0,2,3,2,4
234,234,0,Zambia,,2021-01-09,3,3,2,4,2,0,2,0,0,3,0,4


In [61]:
predictor.predict(start_date, cur_date, cur_ip_file_path)

Unnamed: 0_level_0,2021-01-08,2021-01-09
GeoID,Unnamed: 1_level_1,Unnamed: 2_level_1
Afghanistan,170.492377,
Albania,94.485507,
Algeria,332.875935,
Andorra,65.561760,
Angola,33.126745,
...,...,...
Venezuela,378.697679,
Vietnam,12.603948,
Yemen,2.378473,
Zambia,680.659270,


In [58]:
predictions_total

Unnamed: 0_level_0,2021-01-12
GeoID,Unnamed: 1_level_1
Afghanistan,
Albania,
Algeria,
Andorra,
Angola,
...,...
Venezuela,
Vietnam,
Yemen,
Zambia,


In [40]:
prescriptions_total_df

Unnamed: 0,PrescriptionIndex,CountryName,RegionName,Date,C1_School closing,C2_Workplace closing,C3_Cancel public events,C4_Restrictions on gatherings,C5_Close public transport,C6_Stay at home requirements,C7_Restrictions on internal movement,C8_International travel controls,H1_Public information campaigns,H2_Testing policy,H3_Contact tracing,H6_Facial Coverings
0,0,Afghanistan,,2021-01-08,0,3,2,0,2,3,2,4,2,3,2,4
1,0,Albania,,2021-01-08,0,3,0,4,0,3,2,0,2,3,2,4
2,0,Algeria,,2021-01-08,0,0,2,4,2,3,2,4,2,0,2,4
3,0,Andorra,,2021-01-08,0,3,2,0,0,3,2,4,0,3,2,4
4,0,Angola,,2021-01-08,0,0,2,4,2,3,2,4,0,3,2,4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
231,0,Venezuela,,2021-01-11,0,3,2,4,0,3,2,0,2,3,2,4
232,0,Vietnam,,2021-01-11,0,0,2,4,2,3,0,4,2,3,0,4
233,0,Yemen,,2021-01-11,3,0,2,4,0,3,2,0,2,3,2,4
234,0,Zambia,,2021-01-11,3,3,2,4,2,0,2,0,0,3,0,4


In [33]:
tmp = pd.read_csv(cur_ip_file_path)
tmp[(tmp['CountryName'] == 'United Kingdom')]

Unnamed: 0.1,Unnamed: 0,PrescriptionIndex,CountryName,RegionName,Date,C1_School closing,C2_Workplace closing,C3_Cancel public events,C4_Restrictions on gatherings,C5_Close public transport,C6_Stay at home requirements,C7_Restrictions on internal movement,C8_International travel controls,H1_Public information campaigns,H2_Testing policy,H3_Contact tracing,H6_Facial Coverings
170,170,0,United Kingdom,,2021-01-11,3,3,2,4,0,3,2,4,0,3,2,4
171,171,0,United Kingdom,England,2021-01-11,3,3,2,4,2,3,0,4,2,3,2,4
172,172,0,United Kingdom,Northern Ireland,2021-01-11,3,0,2,4,0,3,2,4,2,3,2,4
173,173,0,United Kingdom,Scotland,2021-01-11,3,3,2,4,2,3,2,0,2,3,2,0
174,174,0,United Kingdom,Wales,2021-01-11,0,0,2,4,2,3,2,4,2,3,2,4


In [53]:
cur_ip_file_path = ip_file_path + 'prescriptions_initial_case_weights_1_2021-01-11.csv'

tmp2 = predictor.predict('2021-01-11', '2021-01-11', cur_ip_file_path)

In [54]:
tmp2

Unnamed: 0,CountryName,RegionName,Date,PredictedDailyNewCases
0,Afghanistan,,2021-01-11,1681.014872
1,Albania,,2021-01-11,336.611741
2,Algeria,,2021-01-11,285.852256
3,Andorra,,2021-01-11,57.186560
4,Angola,,2021-01-11,71.255222
...,...,...,...,...
231,Venezuela,,2021-01-11,290.566163
232,Vietnam,,2021-01-11,3.731618
233,Yemen,,2021-01-11,0.144612
234,Zambia,,2021-01-11,1156.393040


In [52]:
tmp2

Unnamed: 0,CountryName,RegionName,Date,PredictedDailyNewCases
0,Afghanistan,,2021-01-01,176.163194
1,Albania,,2021-01-01,609.968496
2,Algeria,,2021-01-01,518.532788
3,Andorra,,2021-01-01,57.126983
4,Angola,,2021-01-01,92.236357
...,...,...,...,...
231,Venezuela,,2021-01-01,437.935459
232,Vietnam,,2021-01-01,9.074572
233,Yemen,,2021-01-01,0.567098
234,Zambia,,2021-01-01,474.140673


In [35]:
tmp2[tmp2['CountryName'] == 'United Kingdom']

Unnamed: 0,CountryName,RegionName,Date,PredictedDailyNewCases
170,United Kingdom,,2021-01-11,71006.887865
171,United Kingdom,England,2021-01-11,72088.7195
172,United Kingdom,Northern Ireland,2021-01-11,2394.348731
173,United Kingdom,Scotland,2021-01-11,2857.256491


In [12]:
prescriptions_df

Unnamed: 0,PrescriptionIndex,CountryName,RegionName,Date,C1_School closing,C2_Workplace closing,C3_Cancel public events,C4_Restrictions on gatherings,C5_Close public transport,C6_Stay at home requirements,C7_Restrictions on internal movement,C8_International travel controls,H1_Public information campaigns,H2_Testing policy,H3_Contact tracing,H6_Facial Coverings
0,0,Afghanistan,,2021-02-03,0,3,2,0,2,3,2,4,2,3,2,4
1,0,Albania,,2021-02-03,0,3,0,4,0,3,2,0,2,3,2,4
2,0,Algeria,,2021-02-03,0,0,2,4,2,3,2,4,2,0,2,4
3,0,Andorra,,2021-02-03,0,3,2,0,0,3,2,4,0,3,2,4
4,0,Angola,,2021-02-03,0,0,2,4,2,3,2,4,0,3,2,4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
231,0,Venezuela,,2021-02-03,0,3,2,4,0,3,2,0,2,3,2,4
232,0,Vietnam,,2021-02-03,0,0,2,4,2,3,0,4,2,3,0,4
233,0,Yemen,,2021-02-03,3,0,2,4,0,3,2,0,2,3,2,4
234,0,Zambia,,2021-02-03,3,3,2,4,2,0,2,0,0,3,0,4


In [11]:
previous_day_cases

Unnamed: 0_level_0,NewCases
GeoID,Unnamed: 1_level_1


In [7]:
prescribe('2021-02-03', '2021-05-05', '/Users/chang/workplace/covid_xprize/covid-xprize/covid_xprize/standard_predictor/data/OxCGRT_latest.csv', 'prescriptions/test.csv', 'covid_xprize/validation/data/uniform_random_costs.csv')

  0%|          | 0/92 [00:00<?, ?it/s]
0it [00:00, ?it/s][A
6it [00:00, 52.77it/s][A

prescribing for day 2021-02-03



12it [00:00, 53.85it/s][A
18it [00:00, 53.55it/s][A
24it [00:00, 53.89it/s][A
30it [00:00, 54.82it/s][A
36it [00:00, 55.37it/s][A
42it [00:00, 55.47it/s][A
48it [00:00, 55.47it/s][A
54it [00:00, 56.57it/s][A
60it [00:01, 56.65it/s][A
66it [00:01, 56.78it/s][A
72it [00:01, 56.98it/s][A
78it [00:01, 56.80it/s][A
84it [00:01, 57.24it/s][A
90it [00:01, 56.83it/s][A
96it [00:01, 56.42it/s][A
102it [00:01, 56.39it/s][A
108it [00:01, 56.40it/s][A
114it [00:02, 55.52it/s][A
120it [00:02, 55.40it/s][A
126it [00:02, 55.34it/s][A
132it [00:02, 56.11it/s][A
138it [00:02, 55.68it/s][A
144it [00:02, 56.12it/s][A
150it [00:02, 56.21it/s][A
156it [00:02, 56.50it/s][A
162it [00:02, 57.12it/s][A
168it [00:02, 56.97it/s][A
174it [00:03, 56.01it/s][A
180it [00:03, 56.05it/s][A
186it [00:03, 56.35it/s][A
192it [00:03, 56.56it/s][A
198it [00:03, 56.94it/s][A
204it [00:03, 56.95it/s][A
210it [00:03, 57.21it/s][A
216it [00:03, 57.09it/s][A
222it [00:03, 57.30it/s][A
228it 

predicting for day 2021-02-03



  1%|          | 1/92 [00:08<12:45,  8.41s/it]
0it [00:00, ?it/s][A
  1%|          | 1/92 [00:08<12:46,  8.42s/it]

prescribing for day 2021-02-04





KeyError: 'Afghanistan'