In [2]:
import pandas as pd
import numpy as np
import json
import shutil
import sys
import os

from matplotlib import pyplot as plt
from statsforecast import StatsForecast
from statsforecast.models import AutoETS
from glob import glob
from ast import literal_eval
from copy import deepcopy

# Add the model directory to Python path to import moved scripts
sys.path.append(os.path.join(os.path.dirname(os.getcwd()), 'model'))

import MeaslesDataLoader as md
import MeaslesModelEval as mm
import EpiPreprocessor as ep

from IPython.display import clear_output

import seaborn as sb
sb.set()


%matplotlib inline

In [3]:
# The CSV was written together with its row index; reading it without
# `index_col` brings that index back as a junk "Unnamed: 0" data column.
# Use the first column as the index instead.
alphaTable = pd.read_csv('input/alpha_model_by_country.csv', index_col=0)
alphaTable

Unnamed: 0,ID,country,ROW_ID,run_folder,validation,predictor,environmentalArg,model,method,Tier_Sim,Tier_Obs
0,AGO,AGO,1460,s3://metabiota-modeling-internal/BMGF_measles/...,FAILED,{'birth_per_1k': 0},"{'mean_precip_mm_per_day': 3, 'mean_max_temp': 3}",Random Forest,Scikit-learn generic: Random Forest regressor,A,C
1,ARE,ARE,1423,s3://metabiota-modeling-internal/BMGF_measles/...,FAILED,{'incoming_air_passengers': 0},"{'mean_precip_mm_per_day': 3, 'mean_max_temp': 3}",Bagging regressor,Scikit-learn generic: Bagging regressor,,B
2,BEL,BEL,1373,s3://metabiota-modeling-internal/BMGF_measles/...,FAILED,{'MCV2_Cov_RecAge': 0},"{'mean_precip_mm_per_day': 3, 'mean_max_temp': 3}",CatBoost,Scikit-learn generic: CatBoostRegressor,,C
3,CIV,CIV,1136,s3://metabiota-modeling-internal/BMGF_measles/...,FAILED,{'MCV2': 0},"{'mean_precip_mm_per_day': 3, 'mean_max_temp': 3}",gradient boosting,Scikit-learn gradient boosted regression,C,B
4,CMR,CMR,1072,s3://metabiota-modeling-internal/BMGF_measles/...,FAILED,{'MCV2': 0},"{'mean_precip_mm_per_day': 3, 'mean_max_temp': 3}",CatBoost,Scikit-learn generic: CatBoostRegressor,A,C
...,...,...,...,...,...,...,...,...,...,...,...
96,UZB,unicef_region:ECAR,53679,s3://metabiota-modeling-internal/BMGF_measles/...,PASSED,"{'MCV2': 0, 'mnths_since_outbreak_2_per_M': 0,...","{'mean_precip_mm_per_day': 3, 'mean_max_temp': 3}",XGBRegressor,Scikit-learn generic: XGBRegressor,A,B
97,AFG,unicef_region:ROSA,72453,s3://metabiota-modeling-internal/BMGF_measles/...,PASSED,"{'MCV1': 0, 'incoming_air_passengers': 0, 'pas...","{'mean_precip_mm_per_day': 3, 'mean_max_temp': 3}",CatBoost,Scikit-learn generic: CatBoostRegressor,A,A
98,GNQ,unicef_region:WCAR,81971,s3://metabiota-modeling-internal/BMGF_measles/...,PASSED,"{'migrations_per_1k': 0, 'MCV1': 0, 'passenger...","{'mean_precip_mm_per_day': 3, 'mean_max_temp': 3}",XGBRegressor,Scikit-learn generic: XGBRegressor,A,B
99,MLI,unicef_region:WCAR,81962,s3://metabiota-modeling-internal/BMGF_measles/...,PASSED,"{'migrations_per_1k': 0, 'birth_per_1k': 0, 'M...","{'mean_precip_mm_per_day': 3, 'mean_max_temp': 3}",XGBRegressor,Scikit-learn generic: XGBRegressor,C,C


In [4]:
from sklearn.ensemble import RandomForestRegressor
from sklearn.ensemble import BaggingRegressor
from catboost import CatBoostRegressor
from xgboost.sklearn import XGBRegressor

# Lookup from the alpha table's `method` label to the object handed to
# runWrapped: either a dict of model arguments (forwarded to mm.sklGeneric)
# or a model wrapper that runWrapped calls directly.
regressors = {
    'Scikit-learn generic: Random Forest regressor': {
        'model': RandomForestRegressor,
        'modelName': 'RandomForestRegressor',
    },
    'Scikit-learn generic: Bagging regressor': {
        'model': BaggingRegressor,
        'modelName': 'BaggingRegressor',
    },
    'Scikit-learn gradient boosted regression': mm.sklGradientBoostingRegression,
    'Scikit-learn generic: CatBoostRegressor': {
        'model': CatBoostRegressor,
        'modelName': 'CatBoostRegressor',
    },
    'Scikit-learn generic: XGBRegressor': {
        'model': XGBRegressor,
        'modelName': 'XGBRegressor',
    },
}

In [5]:
def runWrapped(model,
               geography,
               depVar,
               indepVars,
               metric,
               r=5,
               monthsForward = 0,
               seed = 1337):
    """Build and train `r` seeded replicates of one model configuration.

    Parameters
    ----------
    model : callable or dict
        If a dict, it is treated as model arguments (it must contain a
        'modelName' key) and forwarded to ``mm.sklGeneric`` as ``modelArgs``.
        Anything else is called directly as
        ``model(geography, depVar, indepVars=..., ...)``.
    geography, depVar, indepVars
        Passed through unchanged to the model constructor.
    metric
        Passed through as ``binaryLabelMetric``.
    r : int
        Number of replicates to train.
    monthsForward : int
        Passed through unchanged to the model constructor.
    seed : int
        Base seed. Replicate ``i`` (1-based) uses ``seed * i``, so a base
        seed of 0 gives every replicate the seed 0.

    Returns
    -------
    The single trained run when ``r == 1``; otherwise the list of trained
    runs (empty when ``r < 1``).
    """
    runs = []
    for replicateSeed in (seed * i for i in range(1, r + 1)):
        if isinstance(model, dict):
            # Work on a shallow copy: the caller's dict is typically a shared
            # module-level registry entry, and writing the seeded estimator
            # into it would leak state across calls and make every replicate's
            # modelArgs alias the same (last) estimator instance.
            modelArgs = dict(model)
            # XGBRegressor is instantiated here so the seed can be set on it;
            # every other generic model is handed over uninitialised.
            initialized = modelArgs['modelName'] == 'XGBRegressor'
            if initialized:
                modelArgs['model'] = XGBRegressor(random_state=replicateSeed)

            mlRun = mm.sklGeneric(geography,
                                  depVar,
                                  indepVars = indepVars,
                                  modelArgs = modelArgs,
                                  randomState = replicateSeed,
                                  useCache = True,
                                  binaryLabelMetric = metric,
                                  monthsForward = monthsForward,
                                  initialized = initialized)
        else:
            mlRun = model(geography,
                          depVar,
                          indepVars = indepVars,
                          randomState = replicateSeed,
                          useCache = True,
                          binaryLabelMetric = metric,
                          monthsForward = monthsForward)

        mlRun.train()
        runs.append(mlRun)

    if r == 1:
        return runs[0]
    return runs

In [None]:
def runAlphaTable(ref = 'input/alpha_model_by_country.csv',
                  monthsForward = 36,
                  targetRuns=25,
                  runType = 'regressors'):
    """Train, evaluate and export repeated runs for every row of the alpha table.

    For each row of the CSV at `ref`, the model registered in `regressors`
    under the row's 'method' is trained repeatedly through `runWrapped`
    (one replicate per attempt, seeded with the attempt counter) until
    `targetRuns` attempts have succeeded. Each successful run is evaluated,
    its tables are exported with `mm.exportTables`, and the evaluation row
    for the row's 'ID' is kept.

    The collected rows are written to 'output/projections/MetaData.csv', and
    each referenced projection table is copied from 'output/tables/' to
    'output/projections/'.

    Parameters
    ----------
    ref : str
        Path of the alpha-table CSV. Its 'predictor' and 'environmentalArg'
        columns must hold dict literals.
    monthsForward : int
        Forwarded to `runWrapped`.
    targetRuns : int
        Number of successful runs wanted per table row.
    runType : str
        Only 'regressors' is implemented.

    Returns
    -------
    pandas.DataFrame
        The metadata table that was written to disk.

    Raises
    ------
    ValueError
        If `runType` is unsupported, or if no run succeeded for any row.
    """
    # Validate before doing any work. Previously an unknown runType left
    # depVar/binaryLabelMetric undefined, which surfaced only as a stream of
    # swallowed NameErrors followed by an unrelated concat error.
    if runType != 'regressors':
        raise ValueError(
            "Unsupported runType {!r}; only 'regressors' is implemented".format(runType))
    depVar = 'cases_1M'
    binaryLabelMetric = lambda x: x >= 5

    alphaTable = pd.read_csv(ref)
    # The dict columns are stored as Python reprs with single quotes; swap the
    # quotes so they parse as JSON.
    loadDicts = lambda x: x.str.replace("'",'"').apply(json.loads)
    alphaTable['predictor'] = loadDicts(alphaTable['predictor'])
    alphaTable['environmentalArg'] = loadDicts(alphaTable['environmentalArg'])

    collected = []

    for index, row in alphaTable.iterrows():
        geography = row['country']
        country = row['ID']
        # Merge into a fresh dict rather than updating the row's dict in place.
        indepVars = dict(row['predictor'])
        indepVars.update(row['environmentalArg'])
        model = regressors[row['method']]

        if geography == 'all':
            geoType = 'global'
        elif geography.startswith('cluster'):
            geoType = 'cluster'
        else:
            geoType = 'single'

        dfs = []
        goodRuns = 0
        attemptedRuns = 0

        # Give up on this row only if nothing has succeeded after more than
        # 2 * targetRuns attempts. NOTE: once at least one run has succeeded,
        # retries continue without an upper bound until targetRuns is reached.
        while goodRuns < targetRuns:
            if attemptedRuns > targetRuns * 2 and not dfs:
                break
            try:
                run = runWrapped(model,
                                 geography,
                                 depVar,
                                 indepVars,
                                 binaryLabelMetric,
                                 r = 1,
                                 monthsForward = monthsForward,
                                 seed = attemptedRuns)

                df = pd.DataFrame(run.evaluate().loc[country,:]).T
                mm.exportTables(run)
                df['hash'] = run.hash
                selectColumns = [col for col in df.columns if not col.startswith("ID_")]
                df = df[selectColumns]
                dfs.append(df)
                goodRuns += 1

            except Exception as err:
                # Best-effort retry with the next seed. Catch Exception only,
                # so KeyboardInterrupt still stops the notebook, and report
                # why the attempt failed.
                print('Failed',goodRuns,attemptedRuns,repr(err))

            attemptedRuns += 1

        if dfs:
            merged = pd.concat(dfs)
            merged['Geography'] = geography
            merged['Geo type'] = geoType
            merged['ISO3'] = merged.index
            merged['table'] = 'output/projections/' + merged['hash'] + '_' + merged['ISO3'] + '_Projection.csv'

            collected.append(merged)
        clear_output()

    if not collected:
        raise ValueError('No successful runs for any row of {!r}; nothing to export'.format(ref))

    metaData = pd.concat(collected)
    # Make sure the destination exists before writing the metadata and copies.
    os.makedirs('output/projections', exist_ok=True)
    metaData.to_csv("output/projections/MetaData.csv")

    for tableOut in metaData['table']:
        tableIn = tableOut.replace('output/projections/','output/tables/')
        shutil.copy(tableIn,tableOut)

    return metaData

# Run the alpha-table pipeline with its default arguments and keep the
# returned metadata frame; the last expression displays it.
temp = runAlphaTable()
temp

Dropped [('country', 'BDI', 'Country dropped due to missing var'), ('country', 'COM', 'Country dropped due to missing var'), ('country', 'GMB', 'Country dropped due to missing var'), ('country', 'SYC', 'Country dropped due to missing var')]


Seed set to 0


Learning rate set to 0.050999
0:	learn: 64.2708525	total: 3.92ms	remaining: 3.91s
1:	learn: 63.2507537	total: 5.17ms	remaining: 2.58s
2:	learn: 62.2476128	total: 6.29ms	remaining: 2.09s
3:	learn: 61.9173676	total: 7.38ms	remaining: 1.84s
4:	learn: 60.9412663	total: 8.52ms	remaining: 1.7s
5:	learn: 60.6383222	total: 9.57ms	remaining: 1.58s
6:	learn: 60.3935594	total: 10.7ms	remaining: 1.51s
7:	learn: 59.1239706	total: 11.7ms	remaining: 1.46s
8:	learn: 57.7893999	total: 12.9ms	remaining: 1.42s
9:	learn: 57.5149839	total: 13.9ms	remaining: 1.38s
10:	learn: 57.2460901	total: 15ms	remaining: 1.35s
11:	learn: 56.0594077	total: 16.1ms	remaining: 1.32s
12:	learn: 54.9096178	total: 17.2ms	remaining: 1.31s
13:	learn: 54.0083360	total: 18.3ms	remaining: 1.29s
14:	learn: 53.7670556	total: 19.4ms	remaining: 1.27s
15:	learn: 52.6837916	total: 20.5ms	remaining: 1.26s
16:	learn: 52.4518276	total: 21.6ms	remaining: 1.25s
17:	learn: 51.6083175	total: 22.7ms	remaining: 1.24s
18:	learn: 50.5849287	total: 

Seed set to 1


Learning rate set to 0.050999
0:	learn: 65.1465992	total: 1.34ms	remaining: 1.33s
1:	learn: 64.8885145	total: 2.58ms	remaining: 1.28s
2:	learn: 63.8510226	total: 3.72ms	remaining: 1.24s
3:	learn: 62.8784587	total: 4.8ms	remaining: 1.19s
4:	learn: 61.5457189	total: 5.83ms	remaining: 1.16s
5:	learn: 60.4943011	total: 6.88ms	remaining: 1.14s
6:	learn: 59.5472151	total: 7.88ms	remaining: 1.12s
7:	learn: 58.3034157	total: 8.9ms	remaining: 1.1s
8:	learn: 57.0821006	total: 9.9ms	remaining: 1.09s
9:	learn: 56.1335557	total: 11ms	remaining: 1.08s
10:	learn: 55.5252663	total: 12.1ms	remaining: 1.08s
11:	learn: 54.6855968	total: 13.1ms	remaining: 1.08s
12:	learn: 54.0379154	total: 14.2ms	remaining: 1.08s
13:	learn: 53.7177306	total: 15.3ms	remaining: 1.08s
14:	learn: 53.4576764	total: 16.3ms	remaining: 1.07s
15:	learn: 53.2132440	total: 17.3ms	remaining: 1.06s
16:	learn: 52.1449731	total: 18.4ms	remaining: 1.06s
17:	learn: 51.1116145	total: 19.4ms	remaining: 1.06s
18:	learn: 50.8755158	total: 20.

Seed set to 2


Learning rate set to 0.050999
0:	learn: 64.2706067	total: 1.28ms	remaining: 1.28s
1:	learn: 63.1611894	total: 2.38ms	remaining: 1.19s
2:	learn: 62.8345796	total: 3.47ms	remaining: 1.15s
3:	learn: 61.7558937	total: 4.55ms	remaining: 1.13s
4:	learn: 60.7718213	total: 5.6ms	remaining: 1.11s
5:	learn: 60.4659961	total: 6.63ms	remaining: 1.1s
6:	learn: 59.1960545	total: 7.66ms	remaining: 1.09s
7:	learn: 58.2052269	total: 8.72ms	remaining: 1.08s
8:	learn: 56.9962405	total: 9.8ms	remaining: 1.08s
9:	learn: 55.8255420	total: 10.9ms	remaining: 1.07s
10:	learn: 54.6863659	total: 11.9ms	remaining: 1.07s
11:	learn: 54.4233492	total: 13.1ms	remaining: 1.07s
12:	learn: 54.1573760	total: 14.1ms	remaining: 1.07s
13:	learn: 53.0625474	total: 15.1ms	remaining: 1.06s
14:	learn: 52.8197769	total: 16.1ms	remaining: 1.06s
15:	learn: 52.0353956	total: 17.2ms	remaining: 1.06s
16:	learn: 51.8031042	total: 18.2ms	remaining: 1.05s
17:	learn: 50.7845746	total: 19.2ms	remaining: 1.05s
18:	learn: 49.7934066	total: 

Seed set to 3


Learning rate set to 0.050999
0:	learn: 64.8607990	total: 1.26ms	remaining: 1.26s
1:	learn: 63.8179059	total: 2.42ms	remaining: 1.21s
2:	learn: 63.3464404	total: 3.5ms	remaining: 1.16s
3:	learn: 62.3396897	total: 4.58ms	remaining: 1.14s
4:	learn: 61.0133787	total: 5.63ms	remaining: 1.12s
5:	learn: 60.5259952	total: 6.7ms	remaining: 1.11s
6:	learn: 59.5012228	total: 7.8ms	remaining: 1.11s
7:	learn: 58.5809220	total: 8.85ms	remaining: 1.1s
8:	learn: 57.6021411	total: 9.9ms	remaining: 1.09s
9:	learn: 56.4087445	total: 11ms	remaining: 1.09s
10:	learn: 55.2534063	total: 12ms	remaining: 1.08s
11:	learn: 54.1295712	total: 13ms	remaining: 1.07s
12:	learn: 53.7990950	total: 14ms	remaining: 1.06s
13:	learn: 52.7232435	total: 15ms	remaining: 1.06s
14:	learn: 51.6796250	total: 15.9ms	remaining: 1.04s
15:	learn: 50.8485066	total: 16.9ms	remaining: 1.04s
16:	learn: 50.5440744	total: 17.8ms	remaining: 1.03s
17:	learn: 50.2999172	total: 18.8ms	remaining: 1.02s
18:	learn: 49.3232895	total: 19.8ms	remai

Seed set to 4


Learning rate set to 0.050999
0:	learn: 64.2732374	total: 1.21ms	remaining: 1.21s
1:	learn: 63.9279873	total: 2.35ms	remaining: 1.17s
2:	learn: 62.8073210	total: 3.49ms	remaining: 1.16s
3:	learn: 61.4723486	total: 4.55ms	remaining: 1.13s
4:	learn: 60.7114429	total: 5.59ms	remaining: 1.11s
5:	learn: 60.4033539	total: 6.7ms	remaining: 1.11s
6:	learn: 59.1337907	total: 7.74ms	remaining: 1.1s
7:	learn: 58.1373945	total: 8.78ms	remaining: 1.09s
8:	learn: 56.9261866	total: 9.82ms	remaining: 1.08s
9:	learn: 56.0538073	total: 10.9ms	remaining: 1.08s
10:	learn: 55.1974063	total: 12.8ms	remaining: 1.15s
11:	learn: 54.2933245	total: 14ms	remaining: 1.16s
12:	learn: 53.9789335	total: 15.2ms	remaining: 1.15s
13:	learn: 53.7250439	total: 16.3ms	remaining: 1.15s
14:	learn: 52.6355384	total: 17.4ms	remaining: 1.14s
15:	learn: 52.3891711	total: 18.4ms	remaining: 1.13s
16:	learn: 51.3450852	total: 19.5ms	remaining: 1.13s
17:	learn: 50.3382407	total: 20.5ms	remaining: 1.12s
18:	learn: 50.1117711	total: 2

Seed set to 5


Learning rate set to 0.050999
0:	learn: 65.3456572	total: 1.28ms	remaining: 1.28s
1:	learn: 63.9382396	total: 2.44ms	remaining: 1.22s
2:	learn: 63.5979671	total: 3.49ms	remaining: 1.16s
3:	learn: 63.2809751	total: 4.57ms	remaining: 1.14s
4:	learn: 61.9308812	total: 5.62ms	remaining: 1.12s
5:	learn: 60.8653099	total: 6.64ms	remaining: 1.1s
6:	learn: 60.5524620	total: 7.62ms	remaining: 1.08s
7:	learn: 59.2765013	total: 8.65ms	remaining: 1.07s
8:	learn: 58.2748428	total: 9.74ms	remaining: 1.07s
9:	learn: 58.0028124	total: 10.8ms	remaining: 1.07s
10:	learn: 57.0334410	total: 12ms	remaining: 1.07s
11:	learn: 55.8573283	total: 13ms	remaining: 1.07s
12:	learn: 55.6026067	total: 14ms	remaining: 1.06s
13:	learn: 54.4671371	total: 15.1ms	remaining: 1.06s
14:	learn: 53.8202009	total: 16.2ms	remaining: 1.06s
15:	learn: 53.0104362	total: 17.2ms	remaining: 1.06s
16:	learn: 51.8675529	total: 18.3ms	remaining: 1.06s
17:	learn: 51.0444829	total: 19.4ms	remaining: 1.06s
18:	learn: 50.2383366	total: 20.5

Seed set to 6


Learning rate set to 0.050999
0:	learn: 65.4690857	total: 1.26ms	remaining: 1.26s
1:	learn: 64.0533714	total: 2.35ms	remaining: 1.17s
2:	learn: 62.9454956	total: 3.41ms	remaining: 1.13s
3:	learn: 62.6183477	total: 4.42ms	remaining: 1.1s
4:	learn: 61.2907029	total: 5.47ms	remaining: 1.09s
5:	learn: 59.9938855	total: 6.54ms	remaining: 1.08s
6:	learn: 59.4734522	total: 7.69ms	remaining: 1.09s
7:	learn: 58.2346531	total: 8.76ms	remaining: 1.09s
8:	learn: 57.9450145	total: 9.77ms	remaining: 1.07s
9:	learn: 56.9756387	total: 10.8ms	remaining: 1.07s
10:	learn: 55.8048839	total: 11.9ms	remaining: 1.07s
11:	learn: 54.8828918	total: 12.9ms	remaining: 1.06s
12:	learn: 54.6201022	total: 13.9ms	remaining: 1.06s
13:	learn: 54.3613124	total: 15ms	remaining: 1.05s
14:	learn: 54.1122917	total: 16ms	remaining: 1.05s
15:	learn: 53.2298667	total: 17ms	remaining: 1.05s
16:	learn: 52.1554472	total: 18.1ms	remaining: 1.05s
17:	learn: 51.1215658	total: 19.2ms	remaining: 1.04s
18:	learn: 50.0192943	total: 20.2

Seed set to 7


Learning rate set to 0.050999
0:	learn: 65.3386644	total: 1.33ms	remaining: 1.33s
1:	learn: 65.0045460	total: 2.49ms	remaining: 1.24s
2:	learn: 63.6066038	total: 3.61ms	remaining: 1.2s
3:	learn: 62.5094024	total: 4.66ms	remaining: 1.16s
4:	learn: 61.0838002	total: 5.74ms	remaining: 1.14s
5:	learn: 60.0455667	total: 6.8ms	remaining: 1.13s
6:	learn: 59.0959950	total: 7.87ms	remaining: 1.12s
7:	learn: 58.8053066	total: 8.93ms	remaining: 1.11s
8:	learn: 57.8922124	total: 9.98ms	remaining: 1.1s
9:	learn: 56.6930549	total: 11ms	remaining: 1.09s
10:	learn: 56.3966745	total: 12.1ms	remaining: 1.08s
11:	learn: 55.2342734	total: 13.2ms	remaining: 1.08s
12:	learn: 54.3308227	total: 14.2ms	remaining: 1.08s
13:	learn: 54.0280719	total: 15.3ms	remaining: 1.07s
14:	learn: 52.8620855	total: 16.3ms	remaining: 1.07s
15:	learn: 51.8005291	total: 17.4ms	remaining: 1.07s
16:	learn: 51.5597072	total: 18.4ms	remaining: 1.06s
17:	learn: 50.5317121	total: 19.3ms	remaining: 1.05s
18:	learn: 49.5436900	total: 20

In [None]:
# The frame is indexed by ISO3 labels, so select by position explicitly;
# integer keys on a label index are no longer treated as positions by pandas.
temp.table.iloc[-1]

In [None]:
# Positional access via .iloc: the index holds ISO3 labels, not integers.
print(temp['model args'].iloc[-1])
pd.read_csv(temp.table.iloc[-1])

In [None]:
# Positional access via .iloc: the index holds ISO3 labels, not integers.
print(temp['model args'].iloc[-2])
pd.read_csv(temp.table.iloc[-2])

In [None]:
# Positional access via .iloc: the index holds ISO3 labels, not integers.
print(temp['model args'].iloc[-3])
pd.read_csv(temp.table.iloc[-3])