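# Evaluation code for the COVID-19 forecasting competition (Track 1)

Scores each submitter's predictions against the JHU COVID-19 time series using the RMSE of log counts (lower is better).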

## Load libraries

In [1]:
import os
import numpy as np
import pandas as pd
from sklearn.metrics import mean_squared_error

## Set variables

In [2]:
# Case types to score; each name is also matched against submission file names.
types = ['confirmed', 'deaths', 'recovered']

# Root folder holding one sub-folder per submitter.
data_dir = 'Track1_Submitted/'

## Load in data files

### Submitter (predicted)

In [3]:
def get_submitter_dict(data_dir, case_types=None):
    """Load every submitter's prediction CSVs into nested dicts.

    Parameters
    ----------
    data_dir : str
        Folder containing one sub-folder per submitter (trailing slash optional).
    case_types : list of str, optional
        Case types to look for; defaults to the module-level ``types`` list.
        A file belongs to a type when the type name appears in its file name.

    Returns
    -------
    dict
        ``{submitter: {type: DataFrame}}``. A type is omitted for a submitter
        with no matching file; when several files match, the first one in
        sorted file-name order is read.
    """
    if case_types is None:
        case_types = types
    submitter_dict = {}
    # Sorted so submitter order and the "first match" pick do not depend on
    # filesystem listing order.
    for uni in sorted(os.listdir(data_dir)):
        uni_path = os.path.join(data_dir, uni)
        # Skip hidden entries (.DS_Store, .ipynb_checkpoints) and stray files.
        if uni.startswith('.') or '.DS' in uni or not os.path.isdir(uni_path):
            continue
        # Hidden files include macOS AppleDouble "._name.csv" copies, which
        # would otherwise match a type name and fail to parse as CSV.
        files = sorted(x for x in os.listdir(uni_path)
                       if not x.startswith('.') and '.DS' not in x)
        uni_dict = {}
        for type_ in case_types:
            matches = [x for x in files if type_ in x]
            if matches:
                uni_dict[type_] = pd.read_csv(os.path.join(uni_path, matches[0]))
        submitter_dict[uni] = uni_dict
    return submitter_dict

# {submitter: {type: DataFrame}} loaded from every submitter folder under data_dir
submitter_dict = get_submitter_dict(data_dir=data_dir)

### JHU (ground truth)

In [4]:
# Date column names as they appear in the JHU CSV headers.
true_dte = '5/1/20'    # counts that predictions are scored against
filt_dte = '4/23/20'   # counts that decide which regions get scored
# Each JHU file path is prefix + type + suffix.
prefix = 'data/time_series_covid19_'
suffix = '_global.csv'

def get_jhu_target(true_dte, filt_dte, fname, min_count=100):
    """Load one JHU global time-series CSV as a ground-truth table.

    Parameters
    ----------
    true_dte : str
        Date column (e.g. '5/1/20') with the counts predictions are scored against.
    filt_dte : str
        Date column used to choose which regions are scored; may equal ``true_dte``.
    fname : str or file-like
        Anything ``pd.read_csv`` accepts.
    min_count : int, default 100
        Keep only rows whose ``filt_dte`` count is strictly greater than this.

    Returns
    -------
    DataFrame
        Columns ``Province/State``, ``Country/Region``, ``true``, ``filt``,
        with the CSV's row index preserved.
    """
    raw = pd.read_csv(fname)
    target_df = (
        raw.loc[:, ['Province/State', 'Country/Region']]
        # assign() instead of two chained renames, so true_dte == filt_dte works
        .assign(true=raw[true_dte], filt=raw[filt_dte])
        # only view countries/states with filt counts > min_count
        .loc[lambda df: df['filt'] > min_count]
    )
    return target_df

def get_jhu_dict(true_dte, filt_dte, prefix, types, suffix):
    """Load the JHU ground-truth table for every case type.

    The file for each type is read from ``prefix + type + suffix`` via
    ``get_jhu_target``. Returns ``{type: DataFrame}`` in the order of ``types``.
    """
    return {
        case_type: get_jhu_target(true_dte, filt_dte, prefix + case_type + suffix)
        for case_type in types
    }

# {type: ground-truth DataFrame} for every case type in `types`
jhu_dict = get_jhu_dict(true_dte=true_dte, filt_dte=filt_dte, prefix=prefix,
                        types=types, suffix=suffix)

## Compare data files

In [5]:
# Column in each submission that holds the forecast being scored.
pred_dte = '5/1/20'
# Module-level accumulators filled by get_results_df, keyed by submitter id.
results_dict, compare_dfs = {}, {}


            
def get_compare_df(df_s, df_j, date=None, verbose=True):
    """Align one submitter's predictions with the JHU ground truth.

    Parameters
    ----------
    df_s : DataFrame
        Submitter predictions, indexed like ``df_j`` and containing a column
        named by ``date``.
    df_j : DataFrame
        Ground truth with a ``true`` column.
    date : str, optional
        Prediction column to score; defaults to the module-level ``pred_dte``.
    verbose : bool, default True
        Print the input and output shapes (debugging output).

    Returns
    -------
    DataFrame
        ``df_j``'s rows with columns ``true`` and ``pred``. ``pred`` is NaN
        where the submitter has no row for a region or predicted a negative value.
    """
    if date is None:
        date = pred_dte
    if verbose:
        print(df_s.shape)
        print(df_j.shape)
    preds = (
        df_s[[date]]
        .rename(columns={date: 'pred'})
        # negative predictions are treated like missing ones
        .assign(pred=lambda x: x.pred.where(x.pred.ge(0)))
    )
    # Left join keeps every ground-truth row; duplicate index entries in
    # df_s would duplicate rows here.
    compare_df = df_j[['true']].join(preds, how='left')
    if verbose:
        print(compare_df.shape)
    return compare_df

def get_rmse(y_true, y_pred):
    """Root-mean-squared error between two equal-length sequences (symmetric in its arguments)."""
    squared_error = mean_squared_error(y_true, y_pred)
    return squared_error ** 0.5

def get_log(col, compare_df):
    """Return the natural log of ``compare_df[col]`` as a NumPy array.

    Zeros, NaNs and +/-inf are mapped to 1 first, so they contribute
    log(1) = 0. Every other value is logged as-is (values in (0, 1) give
    negative logs; negative values give NaN).
    """
    values = compare_df[col]
    values = values.where(values != 0, 1)
    values = values.replace([np.inf, -np.inf], np.nan).fillna(1)
    return np.log(values).values

def get_results_df(submitter_dict, jhu_dict, exclude=('sl4646',)):
    """Score every submitter against the JHU ground truth, per case type.

    For each submitter not in ``exclude`` and each type in ``jhu_dict``, the
    submission and the ground truth are aligned on (Province/State,
    Country/Region) with ``get_compare_df``. They are then scored with
    ``get_rmse`` on the ``get_log``-transformed counts.

    Side effects: fills the module-level ``compare_dfs[uni][type]`` with each
    aligned frame and ``results_dict[uni]`` with each submitter's scores.

    Parameters
    ----------
    submitter_dict : dict
        ``{submitter: {type: DataFrame}}``; a type may be absent for a submitter.
    jhu_dict : dict
        ``{type: DataFrame}`` ground truth.
    exclude : collection of str, default ('sl4646',)
        Submitter ids to skip.

    Returns
    -------
    DataFrame
        One row per submitter scored in this call and one column per type.
        A type the submitter did not provide is scored as NaN.
    """
    index_cols = ['Province/State', 'Country/Region']
    run_results = {}
    for uni, uni_frames in submitter_dict.items():
        if uni in exclude:
            continue
        print(uni)
        uni_dict = {}
        compare_dfs[uni] = {}
        for type_, jhu_df in jhu_dict.items():
            print(type_)
            if type_ not in uni_frames:
                # get_submitter_dict omits types with no matching file; record
                # NaN instead of raising KeyError for the whole evaluation.
                uni_dict[type_] = np.nan
                continue
            compare_df = get_compare_df(uni_frames[type_].set_index(index_cols),
                                        jhu_df.set_index(index_cols))
            compare_dfs[uni][type_] = compare_df
            uni_dict[type_] = get_rmse(get_log('pred', compare_df),
                                       get_log('true', compare_df))
        results_dict[uni] = uni_dict
        run_results[uni] = uni_dict
    # Build the result from this call's scores only, so entries left in the
    # global results_dict by an earlier run cannot leak into the ranking.
    return pd.DataFrame(run_results).T

# One row per submitter, one RMSE-of-log column per case type
results_df = get_results_df(submitter_dict=submitter_dict, jhu_dict=jhu_dict)

lcb2165
confirmed
(264, 2)
(183, 2)
(183, 2)
deaths
(264, 2)
(50, 2)
(50, 2)
recovered
(250, 2)
(141, 2)
(141, 2)
manasi.sharma
confirmed
(264, 3)
(183, 2)
(183, 2)
deaths
(264, 3)
(50, 2)
(50, 2)
recovered
(250, 3)
(141, 2)
(141, 2)
dlm2202
confirmed
(264, 96)
(183, 2)
(183, 2)
deaths
(264, 96)
(50, 2)
(50, 2)
recovered
(250, 96)
(141, 2)
(141, 2)
mcg2208
confirmed
(264, 3)
(183, 2)
(183, 2)
deaths
(264, 3)
(50, 2)
(50, 2)
recovered
(250, 3)
(141, 2)
(141, 2)
a.saakyan
confirmed
(264, 1)
(183, 2)
(183, 2)
deaths
(264, 1)
(50, 2)
(50, 2)
recovered
(250, 1)
(141, 2)
(141, 2)
by2287
confirmed
(264, 3)
(183, 2)
(183, 2)
deaths
(264, 3)
(50, 2)
(50, 2)
recovered
(250, 3)
(141, 2)
(141, 2)


## Display RMSE ranking (RMSE of log counts; lower is better)

In [6]:
display(results_df.sort_values(by='confirmed'))

Unnamed: 0,confirmed,deaths,recovered
dlm2202,0.133013,0.131129,0.243079
mcg2208,0.139514,0.162693,0.307704
manasi.sharma,0.24049,0.160922,3.263254
lcb2165,0.326818,0.218046,0.42149
a.saakyan,0.377213,0.245567,0.883555
by2287,0.702877,0.401629,1.509048


In [7]:
display(results_df.sort_values(by='deaths'))

Unnamed: 0,confirmed,deaths,recovered
dlm2202,0.133013,0.131129,0.243079
manasi.sharma,0.24049,0.160922,3.263254
mcg2208,0.139514,0.162693,0.307704
lcb2165,0.326818,0.218046,0.42149
a.saakyan,0.377213,0.245567,0.883555
by2287,0.702877,0.401629,1.509048


In [8]:
display(results_df.sort_values(by='recovered'))

Unnamed: 0,confirmed,deaths,recovered
dlm2202,0.133013,0.131129,0.243079
mcg2208,0.139514,0.162693,0.307704
lcb2165,0.326818,0.218046,0.42149
a.saakyan,0.377213,0.245567,0.883555
by2287,0.702877,0.401629,1.509048
manasi.sharma,0.24049,0.160922,3.263254


In [9]:
# Spot-check ground-truth vs. predicted confirmed counts for submitter dlm2202
compare_dfs['dlm2202']['confirmed']

Unnamed: 0_level_0,Unnamed: 1_level_0,true,pred
Province/State,Country/Region,Unnamed: 2_level_1,Unnamed: 3_level_1
,Afghanistan,2335,1891
,Albania,782,824
,Algeria,4154,3830
,Andorra,745,781
,Argentina,4532,4535
,Armenia,2148,1944
Australian Capital Territory,Australia,106,108
New South Wales,Australia,3030,3059
Queensland,Australia,1034,1059
South Australia,Australia,438,452


In [10]:
# Spot-check ground-truth vs. predicted death counts for submitter dlm2202
compare_dfs['dlm2202']['deaths']

Unnamed: 0_level_0,Unnamed: 1_level_0,true,pred
Province/State,Country/Region,Unnamed: 2_level_1,Unnamed: 3_level_1
,Algeria,453,450
,Argentina,225,237
,Austria,589,648
,Bangladesh,170,223
,Belgium,7703,8144
,Brazil,6412,5422
Ontario,Canada,1265,1345
Quebec,Canada,2022,2090
,Chile,234,259
Hubei,China,4512,5049


In [11]:
# Spot-check ground-truth vs. predicted recovered counts for submitter dlm2202
compare_dfs['dlm2202']['recovered']

Unnamed: 0_level_0,Unnamed: 1_level_0,true,pred
Province/State,Country/Region,Unnamed: 2_level_1,Unnamed: 3_level_1
,Afghanistan,310,460
,Albania,488,537
,Algeria,1821,1979
,Andorra,468,489
,Argentina,1292,1425
,Armenia,977,1204
New South Wales,Australia,2293,1730
Queensland,Australia,965,980
South Australia,Australia,422,410
Victoria,Australia,1300,1334
