# Optimize Initial Conditions
## Yabox

In [None]:
import numpy as np
from datetime import datetime,timedelta,date
import pandas as pd
from yabox import DE

# Initialize Ray

In [None]:
from environs import Env
env = Env()
# NOTE(review): the return values of the four calls below are discarded.
# environs' typed getters (env.str / env.int / env.bool) presumably only READ a
# variable, falling back to the given default, and do not write to os.environ --
# confirm whether the intent was to set these variables.
env.str("CUDA_DEVICE_ORDER",'PCI_BUS_ID')
env.int("CUDA_VISIBLE_DEVICES",1)
env.int("NUMBA_ENABLE_CUDASIM",1)
env.bool("OMPI_MCA_opal_cuda_support",True)

import os
import ray
# Byte-size helpers used for the memory limits passed to ray.init below.
MB=1024*1024
GB=MB*1024
# Shut down first so that re-running this cell starts from a clean Ray state.
ray.shutdown()
# Start Ray with explicit resource limits: 1 GB object store, 220 GB memory,
# 500 MB driver object store, 5 GPUs and 1 CPU.
ray.init(object_store_memory=1*GB,memory=220*GB,
         lru_evict=True,
         driver_object_store_memory=500*MB,num_gpus=5,num_cpus=1,
         ignore_reinit_error=True) # , include_webui=False)

@ray.remote(num_gpus=1)
def use_gpu():
    """Print the first GPU id Ray assigned to this task and the worker's
    ``CUDA_VISIBLE_DEVICES`` environment variable."""
    first_gpu = ray.get_gpu_ids()[0]
    print("ray.get_gpu_ids(): {}".format(first_gpu))
    visible_devices = os.environ["CUDA_VISIBLE_DEVICES"]
    print("CUDA_VISIBLE_DEVICES: {}".format(visible_devices))


# Submit one sanity-check task; the returned object ref is not awaited.
use_gpu.remote()

# Load and Process New Data from the data.brasil.io Website

In [None]:
# IPython autoreload: pick up edits to imported modules without restarting the kernel.
%reload_ext autoreload
%autoreload 2
import get_data
# Toggle: set to False to skip the get_data.get_data() call below.
LoadData=True

if LoadData:
    get_data.get_data()

# Functions to Load Processed Data

In [None]:
def load_confirmed(country,start_date=None,end_date=None):
    """Return the confirmed-cases row for ``country``.

    Reads ``data/time_series_19-covid-Confirmed-country.csv`` on every call and
    takes the first row whose ``Country/Region`` equals ``country``.

    Args:
        country: value matched against the ``Country/Region`` column.
        start_date: first column label to keep. When ``None`` the whole row is
            returned (including ``Country/Region``) and ``end_date`` is ignored.
        end_date: last column label to keep (label slices are inclusive);
            ``None`` keeps everything up to the last column.

    Returns:
        pandas.Series indexed by the CSV column labels.

    Raises:
        IndexError: if no row matches ``country``.
    """
    df = pd.read_csv('data/time_series_19-covid-Confirmed-country.csv')
    row = df[df['Country/Region'] == country].iloc[0]
    if start_date is None:
        return row
    return row.loc[start_date:end_date]

def load_recovered(country,start_date=None,end_date=None):
    """Return the recovered-cases row for ``country``.

    Reads ``data/time_series_19-covid-Recovered-country.csv`` on every call and
    takes the first row whose ``Country/Region`` equals ``country``.

    Args:
        country: value matched against the ``Country/Region`` column.
        start_date: first column label to keep. When ``None`` the whole row is
            returned (including ``Country/Region``) and ``end_date`` is ignored.
        end_date: last column label to keep (label slices are inclusive);
            ``None`` keeps everything up to the last column.

    Returns:
        pandas.Series indexed by the CSV column labels.

    Raises:
        IndexError: if no row matches ``country``.
    """
    df = pd.read_csv('data/time_series_19-covid-Recovered-country.csv')
    row = df[df['Country/Region'] == country].iloc[0]
    if start_date is None:
        return row
    return row.loc[start_date:end_date]

def load_dead(country,start_date=None,end_date=None):
    """Return the deaths row for ``country``.

    Reads ``data/time_series_19-covid-Deaths-country.csv`` on every call and
    takes the first row whose ``Country/Region`` equals ``country``.

    Args:
        country: value matched against the ``Country/Region`` column.
        start_date: first column label to keep. When ``None`` the whole row is
            returned (including ``Country/Region``) and ``end_date`` is ignored.
        end_date: last column label to keep (label slices are inclusive);
            ``None`` keeps everything up to the last column.

    Returns:
        pandas.Series indexed by the CSV column labels.

    Raises:
        IndexError: if no row matches ``country``.
    """
    df = pd.read_csv('data/time_series_19-covid-Deaths-country.csv')
    row = df[df['Country/Region'] == country].iloc[0]
    if start_date is None:
        return row
    return row.loc[start_date:end_date]

# Load solver

In [None]:
%reload_ext autoreload
%autoreload 2
import LearnerICRayNoLoadBH_v3NewModel as L 

# Data for Countries

In [None]:
# Load the previously optimised parameters (one row per country) and attach
# each country's 2020 total population from the UN WPP2019 estimates.
modelHist = "YaboxAndBasinHopping"
dfparam = pd.read_csv("data/param_optimized_"+modelHist+"_HistMin.csv")
countries=dfparam.country
popEst = pd.read_csv("data/WPP2019_TotalPopulationBySex.csv")
# Numeric copy of PopTotal; unparsable entries become NaN.
popEst['popTotal']=pd.to_numeric(popEst.PopTotal, errors='coerce')

for country in countries:
    # The WPP file spells out the United States in full.
    country2 = "United States of America" if country=="US" else country
    # Select by column name with a single aligned mask instead of a positional
    # column index, so the lookup does not depend on the CSV column order.
    pop2020 = popEst.loc[(popEst.Location==country2) & (popEst.Time==2020), 'popTotal']
    if pop2020.empty:
        raise IndexError("no 2020 population estimate found for {!r}".format(country2))
    # PopTotal is multiplied by 1000 (presumably the file stores thousands of people).
    dfparam.loc[dfparam.country==country,'popTotal']=pop2020.iloc[0]*1000

display(dfparam)
    

# Functions for Optimization

In [None]:
from scipy.integrate import odeint
import sys
import io
import gc

def create_f(country,e0,a0,date, end_dateFirstWave, wcases, wrec, wdth, predict_range, version):
    """Build the objective function used to tune a country's initial conditions.

    The observed series and the parsed start date do not depend on the
    candidate point, so they are loaded once here instead of on every
    objective evaluation.

    Args:
        country: country name as used in the time-series CSV files.
        e0, a0: fixed initial values forwarded unchanged to the learner.
        date: start date string in ``%m/%d/%y`` form.
        end_dateFirstWave: last date label of the observed window.
        wcases, wrec, wdth: weights forwarded to the learner.
        predict_range: forwarded to the learner.
        version: forwarded to the learner.

    Returns:
        ``fobjective(point)``, where ``point`` is
        ``(s0, deltaDate, i0, r0, d0, startNCases)``. It returns whatever
        ``L.Learner.train`` returns (presumably the loss to minimise).
    """
    dead = load_dead(country, date, end_dateFirstWave)
    recovered = load_recovered(country, date, end_dateFirstWave)
    # Confirmed cases minus recovered and dead.
    data = load_confirmed(country, date, end_dateFirstWave) - recovered - dead
    start = datetime.strptime(date, "%m/%d/%y")
    cleanRecovered = False

    def fobjective(point):
        # The order must match the bounds built in opt() and the indices used
        # when results are written back: index 3 is r0, index 4 is d0.
        s0, deltaDate, i0, r0, d0, startNCases = point
        end_date = start + timedelta(days=deltaDate)
        learner = L.Learner.remote(country, end_date.strftime("%m/%d/%y"), predict_range,
                                   s0, e0, a0, i0, r0, d0, startNCases, wcases, wrec, wdth,
                                   cleanRecovered, version, data, dead, recovered, savedata=False)
        result = ray.get(learner.train.remote())

        # Drop the actor handle before collecting so a finished learner is not
        # kept alive across evaluations.
        del learner
        gc.collect()

        return result
    return fobjective

In [None]:
@ray.remote(memory=50 * 1024 * 1024, max_calls=1)
def opt(country,s0,i0,e0,a0,r0,d0,date,end_date,startNCases, wcases, wrec, wdth,
        predict_range, version):
    """Fine-tune one country's initial conditions with yabox differential evolution.

    Args:
        country: country name forwarded to ``create_f``.
        s0, i0, r0, d0, startNCases: starting values; each is searched within
            a multiplicative band around it (1.5x for ``s0``, 1.2x otherwise).
        e0, a0: fixed initial values forwarded to ``create_f``.
        date: start date string forwarded to ``create_f``.
        end_date: end of the observed window forwarded to ``create_f``.
        wcases, wrec, wdth, predict_range, version: forwarded to ``create_f``.

    Returns:
        The best parameter vector found, laid out as
        ``[s0, deltaDate, i0, r0, d0, startNCases]``.

    Raises:
        RuntimeError: if no best vector could be extracted from any generation.
    """
    # deltaDate is pinned to 0 by a zero-width bound.
    bounds = [
        (s0/1.5, s0*1.5),
        (0, 0),
        (i0/1.2, i0*1.2),
        (r0/1.2, r0*1.2),
        (d0/1.2, d0*1.2),
        (startNCases/1.2, startNCases*1.2),
    ]
    f = create_f(country,e0,a0,date,end_date, wcases, wrec, wdth, predict_range, version)
    maxiterations = 500
    de = DE(f, bounds, maxiters=maxiterations)

    best_params = None
    for step in de.geniterator():
        try:
            # Map the best individual of this generation back to real units.
            best_params = de.denormalize([step.population[step.best_idx]])
        except Exception as err:
            # Best effort: keep the last good vector and carry on.
            print("error in function evaluation: {}".format(err))

    if best_params is None:
        raise RuntimeError("differential evolution produced no result for {}".format(country))

    return best_params[0]

# Main Code

In [None]:
%%javascript
// Replace the notebook output area's scroll predicate with one that returns
// true for any number of output lines.
// NOTE(review): presumably this forces every output area into scrolling mode;
// returning false is the usual way to disable auto-scroll -- confirm intent.
IPython.OutputArea.prototype._should_scroll = function(lines){
    return true;}

In [None]:
# When True, some slots of firstWave get fixed cut-off dates instead of yesterday.
flagFirstWave = True

# Yesterday, formatted without zero padding (e.g. 6/1/20).
# NOTE: the '%-m' / '%-d' directives are platform-specific strftime extensions.
finalDate = date.today() - timedelta(days=1)
finalDateStr = finalDate.strftime('%-m/%-d/%y')

# Five slots, all defaulting to yesterday (presumably one per country, in the
# order the countries are processed -- verify against the parameter file).
firstWave = [finalDateStr] * 5
if flagFirstWave:
    # Dates noted by the original author:
    #'10/1/20' Brazil
    #'10/25/20' US, India
    firstWave[1] = '6/1/20'
    firstWave[2] = '8/1/20'

In [None]:
# Submit one remote optimisation task per country.
countries=dfparam.country
display(countries)
version="115"
gc.enable()

# Ray object refs, in the same order as `countries`.
optimal=[]

for i, country in enumerate(countries):
    # Remove the previous history file for this country/version, if any.
    strFile='./results/history_'+country+version+'.csv'
    if os.path.isfile(strFile):
        os.remove(strFile)

    # Boolean mask instead of a string-built query, so country names that
    # contain quotes cannot break the expression.
    query = dfparam[dfparam.country == country].reset_index()
    # Skip the first two columns (the old index added by reset_index, then country).
    parameters = np.array(query.iloc[:, 2:])[0]

    # Round-trip through datetime to normalise the end date to un-padded m/d/yy.
    endDate = datetime.strptime(firstWave[i], '%m/%d/%y')
    end_dateStr= datetime.strftime(endDate, '%-m/%-d/%y')

    # The start date is unpacked into `startDate` rather than `date`, so the
    # imported `datetime.date` class is not overwritten at module scope.
    (startDate, predict_range, s0, e0, a0, i0, r0, d0, startNCases,
     wcases, wrec, wdth, pop) = parameters
    dateD = datetime.strptime(startDate, '%m/%d/%y')
    dateStr= datetime.strftime(dateD, '%-m/%-d/%y')

    optimal.append(opt.remote(country,s0,i0,e0,a0,r0,d0,dateStr,end_dateStr,startNCases, wcases, wrec, wdth,
                              predict_range, version))

In [None]:
# Block until all submitted tasks finish, then replace the list of object refs
# with their results (same order).
# NOTE(review): re-running this cell alone would pass results, not refs, to ray.get.
optimal=ray.get(optimal)


In [None]:
# Write each country's optimised initial conditions back into dfparam.
# Vector layout follows the bounds built in opt():
#   [s0, deltaDate, i0, r0, d0, startNCases]  -- deltaDate is not stored.
if len(optimal) != len(countries):
    raise ValueError("expected {} results, got {}".format(len(countries), len(optimal)))

for country, best in zip(countries, optimal):
    # Look the row up per country, so each result lands in its own row.
    j = dfparam.index[dfparam.country == country][0]
    dfparam.at[j, "s0"] = best[0]
    dfparam.at[j, "i0"] = best[2]
    dfparam.at[j, "r0"] = best[3]
    dfparam.at[j, "d0"] = best[4]
    dfparam.at[j, "startNCases"] = best[5]

# Save and show once, after every row has been updated.
dfparam.to_csv("data/param_optimized_FineTune.csv", sep=",", index=False)
display(dfparam)
    