In [1]:
import os
import pandas as pd
import numpy as np

# Load data set

In [2]:
LATEST_DATA_URL = 'https://raw.githubusercontent.com/OxCGRT/covid-policy-tracker/master/data/OxCGRT_latest.csv'

In [3]:
def load_dataset(url):
    latest_df = pd.read_csv(url,
                            parse_dates=['Date'],
                            encoding="ISO-8859-1",
                            dtype={"RegionName": str,
                                   "RegionCode": str},
                            error_bad_lines=False)
    latest_df["RegionName"] = latest_df["RegionName"].fillna("")
    return latest_df

In [4]:
latest_df = load_dataset(LATEST_DATA_URL)

In [5]:
latest_df.sample(3)

Unnamed: 0,CountryName,CountryCode,RegionName,RegionCode,Jurisdiction,Date,C1_School closing,C1_Flag,C2_Workplace closing,C2_Flag,...,StringencyIndex,StringencyIndexForDisplay,StringencyLegacyIndex,StringencyLegacyIndexForDisplay,GovernmentResponseIndex,GovernmentResponseIndexForDisplay,ContainmentHealthIndex,ContainmentHealthIndexForDisplay,EconomicSupportIndex,EconomicSupportIndexForDisplay
77992,United States,USA,Arkansas,US_AR,STATE_TOTAL,2020-07-19,3.0,1.0,2.0,1.0,...,61.11,61.11,70.95,70.95,59.17,59.17,58.65,58.65,62.5,62.5
97155,Yemen,YEM,,,NAT_TOTAL,2020-01-04,0.0,,0.0,,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
80340,United States,USA,Florida,US_FL,STATE_TOTAL,2020-03-25,3.0,1.0,3.0,0.0,...,77.31,77.31,80.48,80.48,55.28,55.28,59.94,59.94,25.0,25.0


# Get NPIs

In [6]:
NPI_COLUMNS = ['C1_School closing',
               'C2_Workplace closing',
               'C3_Cancel public events',
               'C4_Restrictions on gatherings',
               'C5_Close public transport',
               'C6_Stay at home requirements',
               'C7_Restrictions on internal movement',
               'C8_International travel controls',
               'H1_Public information campaigns',
               'H2_Testing policy',
               'H3_Contact tracing',
               'H6_Facial Coverings']

In [7]:
npis_df = latest_df[["CountryName", "RegionName", "Date"] + NPI_COLUMNS]

In [8]:
npis_df.sample(3)

Unnamed: 0,CountryName,RegionName,Date,C1_School closing,C2_Workplace closing,C3_Cancel public events,C4_Restrictions on gatherings,C5_Close public transport,C6_Stay at home requirements,C7_Restrictions on internal movement,C8_International travel controls,H1_Public information campaigns,H2_Testing policy,H3_Contact tracing,H6_Facial Coverings
37220,United Kingdom,Wales,2020-09-17,1.0,2.0,2.0,4.0,1.0,0.0,2.0,2.0,2.0,2.0,2.0,2.0
36302,United Kingdom,Northern Ireland,2020-02-16,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2.0,1.0,1.0,0.0
90179,United States,Oregon,2020-03-08,0.0,0.0,0.0,0.0,0.0,0.0,0.0,3.0,1.0,1.0,1.0,0.0


# Dates

In [316]:
start_date_str = "2020-11-01"
end_date_str = "2020-11-08"

In [317]:
start_date = pd.to_datetime(start_date_str, format='%Y-%m-%d')
end_date = pd.to_datetime(end_date_str, format='%Y-%m-%d')

In [318]:
actual_npis_df = npis_df[(npis_df.Date >= start_date) & (npis_df.Date <= end_date)]
actual_npis_df.sample(3)

Unnamed: 0,CountryName,RegionName,Date,C1_School closing,C2_Workplace closing,C3_Cancel public events,C4_Restrictions on gatherings,C5_Close public transport,C6_Stay at home requirements,C7_Restrictions on internal movement,C8_International travel controls,H1_Public information campaigns,H2_Testing policy,H3_Contact tracing,H6_Facial Coverings
78801,United States,California,2020-11-01,3.0,2.0,2.0,3.0,0.0,1.0,1.0,3.0,2.0,3.0,2.0,3.0
73877,Tonga,,2020-11-05,,,,,,,,,,,,
29170,Cyprus,,2020-11-02,1.0,2.0,2.0,4.0,1.0,2.0,0.0,3.0,2.0,3.0,2.0,3.0


# Get actual cases between these dates

In [319]:
NUM_PREV_DAYS_TO_INCLUDE = 6
WINDOW_SIZE = 7

In [320]:
def get_actual_cases(df, start_date, end_date):
    # 1 day earlier to compute the daily diff
    start_date_for_diff = start_date - pd.offsets.Day(WINDOW_SIZE)
    actual_df = df[["CountryName", "RegionName", "Date", "ConfirmedCases"]]
    # Filter out the data set to include all the data needed to compute the diff
    actual_df = actual_df[(actual_df.Date >= start_date_for_diff) & (actual_df.Date <= end_date)]
    # Add GeoID column that combines CountryName and RegionName for easier manipulation of data
    # np.where usage: if A then B else C
    actual_df["GeoID"] = np.where(actual_df["RegionName"].isnull(),
                                  actual_df["CountryName"],
                                  actual_df["CountryName"] + ' / ' + actual_df["RegionName"])
    actual_df.sort_values(by=["GeoID","Date"], inplace=True)
    # Compute the diff
    actual_df["ActualDailyNewCases"] = actual_df.groupby("GeoID")["ConfirmedCases"].diff().fillna(0)
    # Compute the 7 day moving average
    actual_df["ActualDailyNewCases7DMA"] = actual_df.groupby(
        "GeoID")['ActualDailyNewCases'].rolling(
        WINDOW_SIZE, center=False).mean().reset_index(0, drop=True)
    return actual_df

In [321]:
actual_df = get_actual_cases(latest_df, start_date, end_date)

In [322]:
actual_df.head(12)

Unnamed: 0,CountryName,RegionName,Date,ConfirmedCases,GeoID,ActualDailyNewCases,ActualDailyNewCases7DMA
650,Afghanistan,,2020-10-25,40833.0,Afghanistan /,0.0,
651,Afghanistan,,2020-10-26,40937.0,Afghanistan /,104.0,
652,Afghanistan,,2020-10-27,41032.0,Afghanistan /,95.0,
653,Afghanistan,,2020-10-28,41145.0,Afghanistan /,113.0,
654,Afghanistan,,2020-10-29,41268.0,Afghanistan /,123.0,
655,Afghanistan,,2020-10-30,41334.0,Afghanistan /,66.0,
656,Afghanistan,,2020-10-31,41425.0,Afghanistan /,91.0,84.571429
657,Afghanistan,,2020-11-01,41501.0,Afghanistan /,76.0,95.428571
658,Afghanistan,,2020-11-02,41633.0,Afghanistan /,132.0,99.428571
659,Afghanistan,,2020-11-03,41728.0,Afghanistan /,95.0,99.428571


# Get historical data for 7 days moving average calculation
In order to compute the 7 days moving average, we need to get the historical true new cases for the last 7 days before start date

In [323]:
ma_df = actual_df[actual_df["Date"] < start_date]
ma_df = ma_df[["CountryName", "RegionName", "Date", "ActualDailyNewCases"]]
ma_df = ma_df.rename(columns={"ActualDailyNewCases": "PredictedDailyNewCases"})
ma_df.head()

Unnamed: 0,CountryName,RegionName,Date,PredictedDailyNewCases
650,Afghanistan,,2020-10-25,0.0
651,Afghanistan,,2020-10-26,104.0
652,Afghanistan,,2020-10-27,95.0
653,Afghanistan,,2020-10-28,113.0
654,Afghanistan,,2020-10-29,123.0


# Run the predictions
Evaluate some example submissions.  
__NOTE: Please run the corresponding example notebooks first in order to train the models that are used in this section.__

In [324]:
IP_FILE = "covid_xprize/validation/data/2020-12-16_historical_ip.csv"
predictions = {}

## Linear

In [325]:
# Check a model has been trained
if not os.path.isfile("covid_xprize/examples/predictors/linear/models/model.pkl"):
    print("ERROR: Please run the notebook in 'covid_xprize/examples/predictors/linear' in order to train a model!")

In [None]:
linear_output_file = "covid_xprize/examples/predictors/linear/predictions/val_4_days.csv"

In [None]:
!python covid_xprize/examples/predictors/linear/predict.py -s {start_date_str} -e {end_date_str} -ip {IP_FILE} -o {linear_output_file}

In [None]:
predictions["Linear"] = linear_output_file

## LSTM

In [276]:
# Check a model has been trained
if not os.path.isfile("covid_xprize/examples/predictors/lstm/models/test_robojudge_2.h5"):
    print("ERROR: Please run the notebook in 'covid_xprize/examples/predictors/lstm' in order to train a model!")

In [326]:
lstm_output_file = "covid_xprize/examples/predictors/lstm/predictions/val_4_days.csv"

In [377]:
!python covid_xprize/examples/predictors/lstm/predict.py -s {start_date_str} -e {end_date_str} -ip {IP_FILE} -o {lstm_output_file}

Generating predictions from 2020-11-01 to 2020-11-08...
Saved predictions to covid_xprize/examples/predictors/lstm/predictions/val_4_days.csv
Done!


In [378]:
predictions["LSTM"] = lstm_output_file

# Get predictions from submissions

In [379]:
def get_predictions_from_file(predictor_name, predictions_file, ma_df):
    preds_df = pd.read_csv(predictions_file,
                           parse_dates=['Date'],
                           encoding="ISO-8859-1",
                           error_bad_lines=False)
    preds_df["RegionName"] = preds_df["RegionName"].fillna("")
    preds_df["PredictorName"] = predictor_name
    preds_df["Prediction"] = True
    
    # Append the true number of cases before start date
    ma_df["PredictorName"] = predictor_name
    ma_df["Prediction"] = False
    preds_df = ma_df.append(preds_df, ignore_index=True)

    # Add GeoID column that combines CountryName and RegionName for easier manipulation of data
    # np.where usage: if A then B else C
    preds_df["GeoID"] = np.where(preds_df["RegionName"].isnull(),
                                 preds_df["CountryName"],
                                 preds_df["CountryName"] + ' / ' + preds_df["RegionName"])
    # Sort
    preds_df.sort_values(by=["GeoID","Date"], inplace=True)
    # Compute the 7 days moving average for PredictedDailyNewCases
    preds_df["PredictedDailyNewCases7DMA"] = preds_df.groupby(
        "GeoID")['PredictedDailyNewCases'].rolling(
        WINDOW_SIZE, center=False).mean().reset_index(0, drop=True)

    # Put PredictorName first
    preds_df = preds_df[["PredictorName"] + [col for col in preds_df.columns if col != "PredictorName"]]
    return preds_df

In [380]:
test_predictor_name = "LSTM"
temp_df = get_predictions_from_file(test_predictor_name, predictions[test_predictor_name], ma_df.copy())
temp_df.head(12)

Unnamed: 0,PredictorName,CountryName,RegionName,Date,PredictedDailyNewCases,Prediction,GeoID,PredictedDailyNewCases7DMA
0,LSTM,Afghanistan,,2020-10-25,0.0,False,Afghanistan /,
1,LSTM,Afghanistan,,2020-10-26,104.0,False,Afghanistan /,
2,LSTM,Afghanistan,,2020-10-27,95.0,False,Afghanistan /,
3,LSTM,Afghanistan,,2020-10-28,113.0,False,Afghanistan /,
4,LSTM,Afghanistan,,2020-10-29,123.0,False,Afghanistan /,
5,LSTM,Afghanistan,,2020-10-30,66.0,False,Afghanistan /,
6,LSTM,Afghanistan,,2020-10-31,91.0,False,Afghanistan /,84.571429
1968,LSTM,Afghanistan,,2020-11-01,122.624376,True,Afghanistan /,102.089197
1969,LSTM,Afghanistan,,2020-11-02,239.655492,True,Afghanistan /,121.468553
1970,LSTM,Afghanistan,,2020-11-03,81.613632,True,Afghanistan /,119.556214


In [381]:
actual_df.head(8)

Unnamed: 0,CountryName,RegionName,Date,ConfirmedCases,GeoID,ActualDailyNewCases,ActualDailyNewCases7DMA
650,Afghanistan,,2020-10-25,40833.0,Afghanistan /,0.0,
651,Afghanistan,,2020-10-26,40937.0,Afghanistan /,104.0,
652,Afghanistan,,2020-10-27,41032.0,Afghanistan /,95.0,
653,Afghanistan,,2020-10-28,41145.0,Afghanistan /,113.0,
654,Afghanistan,,2020-10-29,41268.0,Afghanistan /,123.0,
655,Afghanistan,,2020-10-30,41334.0,Afghanistan /,66.0,
656,Afghanistan,,2020-10-31,41425.0,Afghanistan /,91.0,84.571429
657,Afghanistan,,2020-11-01,41501.0,Afghanistan /,76.0,95.428571


In [382]:
from covid_xprize.validation.predictor_validation import validate_submission

ranking_df = pd.DataFrame()
for predictor_name, predictions_file in predictions.items():
    print(f"Getting {predictor_name}'s predictions from: {predictions_file}")
    errors = validate_submission(start_date_str, end_date_str, IP_FILE, predictions_file)
    if not errors:
        preds_df = get_predictions_from_file(predictor_name, predictions_file, ma_df)
        merged_df = actual_df.merge(preds_df, on=['CountryName', 'RegionName', 'Date', 'GeoID'], how='left')
        ranking_df = ranking_df.append(merged_df)
    else:
        print(f"Predictor {predictor_name} did not submit valid predictions! Please check its errors:")
        print(errors)

Getting LSTM's predictions from: covid_xprize/examples/predictors/lstm/predictions/val_4_days.csv


In [383]:
ranking_df['DiffDaily'] = (ranking_df["ActualDailyNewCases"] - ranking_df["PredictedDailyNewCases"]).abs()

In [384]:
ranking_df['Diff7DMA'] = (ranking_df["ActualDailyNewCases7DMA"] - ranking_df["PredictedDailyNewCases7DMA"]).abs()

In [385]:
# Compute the cumulative sum of 7DMA errors
ranking_df['CumulDiff7DMA'] = ranking_df.groupby(["GeoID", "PredictorName"])['Diff7DMA'].cumsum()

In [386]:
# Keep only predictions (either Prediction == True) or on or after start_date
ranking_df = ranking_df[ranking_df["Date"] >= start_date]

In [387]:
# Sort by 7 days moving average diff
ranking_df.sort_values(by=["CountryName","RegionName","Date","Diff7DMA"], inplace=True)

In [388]:
ranking_df.head(4*2)

Unnamed: 0,CountryName,RegionName,Date,ConfirmedCases,GeoID,ActualDailyNewCases,ActualDailyNewCases7DMA,PredictorName,PredictedDailyNewCases,Prediction,PredictedDailyNewCases7DMA,DiffDaily,Diff7DMA,CumulDiff7DMA
7,Afghanistan,,2020-11-01,41501.0,Afghanistan /,76.0,95.428571,LSTM,122.624376,True,102.089197,46.624376,6.660625,6.660625
8,Afghanistan,,2020-11-02,41633.0,Afghanistan /,132.0,99.428571,LSTM,239.655492,True,121.468553,107.655492,22.039981,28.700606
9,Afghanistan,,2020-11-03,41728.0,Afghanistan /,95.0,99.428571,LSTM,81.613632,True,119.556214,13.386368,20.127643,48.828249
10,Afghanistan,,2020-11-04,41814.0,Afghanistan /,86.0,95.571429,LSTM,128.538539,True,121.776006,42.538539,26.204577,75.032826
11,Afghanistan,,2020-11-05,41935.0,Afghanistan /,121.0,95.285714,LSTM,315.317907,True,149.249992,194.317907,53.964278,128.997104
12,Afghanistan,,2020-11-06,41975.0,Afghanistan /,40.0,91.571429,LSTM,176.998277,True,165.106889,136.998277,73.53546,202.532564
13,Afghanistan,,2020-11-07,42033.0,Afghanistan /,58.0,86.857143,LSTM,253.901353,True,188.378511,195.901353,101.521368,304.053932
14,Afghanistan,,2020-11-08,42159.0,Afghanistan /,126.0,94.0,LSTM,135.901567,True,190.275252,9.901567,96.275252,400.329185


In [389]:
ranking_df[(ranking_df.CountryName == "United States") &
           (ranking_df.Date == '2020-08-02')]

Unnamed: 0,CountryName,RegionName,Date,ConfirmedCases,GeoID,ActualDailyNewCases,ActualDailyNewCases7DMA,PredictorName,PredictedDailyNewCases,Prediction,PredictedDailyNewCases7DMA,DiffDaily,Diff7DMA,CumulDiff7DMA


In [390]:
# Save to file
# ranking_df.to_csv("/Users/m_754337/workspace/esp-demo/xprize/tests/fixtures/ranking.csv", index=False)

# Ranking

## Global

In [391]:
ranking_df.groupby('PredictorName').Diff7DMA.sum().sort_values()

PredictorName
LSTM    3.159673e+06
Name: Diff7DMA, dtype: float64

## Countries

In [407]:
countries_ranking_df = ranking_df.groupby(["CountryName", "RegionName", "PredictorName"])[["CountryName", "RegionName", "PredictorName", "Diff7DMA"]].sum().sort_values(by=["CountryName", "RegionName", "Diff7DMA"])


TypeError: sort_values() got an unexpected keyword argument 'order'

In [408]:
countries_ranking_df.head(12)

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Diff7DMA
CountryName,RegionName,PredictorName,Unnamed: 3_level_1
Afghanistan,,LSTM,400.329185
Albania,,LSTM,1956.599309
Algeria,,LSTM,871.060786
Andorra,,LSTM,269.223044
Angola,,LSTM,673.367764
Argentina,,LSTM,25239.564822
Aruba,,LSTM,15.577464
Australia,,LSTM,8.606308
Austria,,LSTM,9840.764692
Azerbaijan,,LSTM,5759.714286


## Specific country

In [394]:
cr_df = countries_ranking_df.reset_index()

In [395]:
cr_df[(cr_df.CountryName == "Italy") & (cr_df.RegionName == "")]

Unnamed: 0,CountryName,RegionName,PredictorName,Diff7DMA
110,Italy,,LSTM,49300.489917


In [396]:
ranking_df[ranking_df.CountryName == "Italy"]

Unnamed: 0,CountryName,RegionName,Date,ConfirmedCases,GeoID,ActualDailyNewCases,ActualDailyNewCases7DMA,PredictorName,PredictedDailyNewCases,Prediction,PredictedDailyNewCases7DMA,DiffDaily,Diff7DMA,CumulDiff7DMA
1852,Italy,,2020-11-01,709335.0,Italy /,29905.0,26221.857143,LSTM,17191.875064,True,24405.696438,12713.124936,1816.160705,1816.160705
1853,Italy,,2020-11-02,731588.0,Italy /,22253.0,26971.285714,LSTM,20943.170315,True,24968.006483,1309.829685,2003.279232,3819.439937
1854,Italy,,2020-11-03,759829.0,Italy /,28241.0,27864.428571,LSTM,23620.984313,True,25201.147099,4620.015687,2663.281473,6482.721409
1855,Italy,,2020-11-04,790377.0,Italy /,30548.0,28658.714286,LSTM,23370.343355,True,24970.053292,7177.656645,3688.660993,10171.382403
1856,Italy,,2020-11-05,824879.0,Italy /,34502.0,29754.857143,LSTM,21065.118696,True,24146.641678,13436.881304,5608.215465,15779.597868
1857,Italy,,2020-11-06,862681.0,Italy /,37802.0,30715.285714,LSTM,14623.299379,True,21795.827303,23178.700621,8919.458411,24699.056279
1858,Italy,,2020-11-07,902490.0,Italy /,39809.0,31865.714286,LSTM,16496.84682,True,19615.948277,23312.15318,12249.766008,36948.822288
1859,Italy,,2020-11-08,935104.0,Italy /,32614.0,32252.714286,LSTM,19187.563717,True,19901.046656,13426.436283,12351.667629,49300.489917


## Specific country (group by)

In [397]:
ranking_df[(ranking_df.CountryName == "United States") & (ranking_df.RegionName == "")].groupby(["PredictorName"]).Diff7DMA.sum().sort_values()

PredictorName
LSTM    436962.330824
Name: Diff7DMA, dtype: float64

## Specific region

In [398]:
cr_df[(cr_df.CountryName == "United States") & (cr_df.RegionName == "California")]

Unnamed: 0,CountryName,RegionName,PredictorName,Diff7DMA
209,United States,California,LSTM,111181.400659


## Continent

In [399]:
NORTH_AMERICA = ["Canada", "United States", "Mexico"]

In [400]:
cr_df[(cr_df.CountryName.isin(NORTH_AMERICA)) & (cr_df.RegionName == "")].groupby('PredictorName').Diff7DMA.sum().sort_values().reset_index()

Unnamed: 0,PredictorName,Diff7DMA
0,LSTM,474954.610557


In [401]:
cr_df[(cr_df.CountryName.isin(NORTH_AMERICA)) & (cr_df.RegionName == "")]

Unnamed: 0,CountryName,RegionName,PredictorName,Diff7DMA
57,Canada,,LSTM,15655.181426
134,Mexico,,LSTM,22337.098308
204,United States,,LSTM,436962.330824


# Plots

In [402]:
ALL_GEO = "Overall"
DEFAULT_GEO = ALL_GEO

## Prediction vs actual

In [403]:
predictor_names = list(ranking_df.PredictorName.dropna().unique())
geoid_names = list(ranking_df.GeoID.unique())

## Filter by country

In [404]:
all_df = ranking_df.groupby(["PredictorName", "Date"])[["GeoID", "PredictorName", "PredictedDailyNewCases7DMA"]].sum(). \
    sort_values(by=["PredictorName", "Date"]).reset_index()
all_df

Unnamed: 0,PredictorName,Date,PredictedDailyNewCases7DMA
0,LSTM,2020-11-01,663422.498945
1,LSTM,2020-11-02,821837.749368
2,LSTM,2020-11-03,862053.300967
3,LSTM,2020-11-04,879613.547888
4,LSTM,2020-11-05,879821.609165
5,LSTM,2020-11-06,870208.410581
6,LSTM,2020-11-07,881150.726877
7,LSTM,2020-11-08,866514.862614


In [405]:
import plotly.graph_objects as go

fig = go.Figure(layout=dict(title=dict(text=f"{DEFAULT_GEO} Daily New Cases 7-day Average ",
                                       y=0.9,
                                       x=0.5,
                                       xanchor='center',
                                       yanchor='top'
                                       ),
                             plot_bgcolor='#f2f2f2',
                             xaxis_title="Date",
                             yaxis_title="Daily new cases 7-day average"
                             ))

# Keep track of trace visibility by geo ID name
geoid_plot_names = []

all_df = ranking_df.groupby(["PredictorName", "Date"])[["GeoID", "PredictorName", "PredictedDailyNewCases7DMA"]].sum(). \
    sort_values(by=["PredictorName", "Date"]).reset_index()

# Add 1 trace per predictor, for all geos
for predictor_name in predictor_names:
    all_geo_df = all_df[all_df.PredictorName == predictor_name]
    fig.add_trace(go.Scatter(x=all_geo_df.Date,
                             y=all_geo_df.PredictedDailyNewCases7DMA,
                             name=predictor_name,
                             visible=(ALL_GEO == DEFAULT_GEO))
                 )
    geoid_plot_names.append(ALL_GEO)

# Add 1 trace per predictor, per geo id
for predictor_name in predictor_names:
    for geoid_name in geoid_names:
        pred_geoid_df = ranking_df[(ranking_df.GeoID == geoid_name) &
                                   (ranking_df.PredictorName == predictor_name)]
        fig.add_trace(go.Scatter(x=pred_geoid_df.Date,
                                 y=pred_geoid_df.PredictedDailyNewCases7DMA,
                                 name=predictor_name,
                                 visible=(geoid_name == DEFAULT_GEO))
                     )
        geoid_plot_names.append(geoid_name)

# For each geo
# Add 1 trace for the true number of cases
for geoid_name in geoid_names:
    geo_actual_df = actual_df[(actual_df.GeoID == geoid_name) &
                                  (actual_df.Date >= start_date)]
    fig.add_trace(go.Scatter(x=geo_actual_df.Date,
                             y=geo_actual_df.ActualDailyNewCases7DMA,
                             name="Ground Truth",
                             visible= (geoid_name == DEFAULT_GEO),
                             line=dict(color='orange', width=4, dash='dash'))
                  )
    geoid_plot_names.append(geoid_name)
    
# Add 1 trace for the overall ground truth
overall_actual_df = actual_df[actual_df.Date >= start_date].groupby(["Date"])[["GeoID", "ActualDailyNewCases7DMA"]].sum(). \
    sort_values(by=["Date"]).reset_index()
fig.add_trace(go.Scatter(x=overall_actual_df.Date,
                         y=overall_actual_df.ActualDailyNewCases7DMA,
                         name="Ground Truth",
                         visible= (ALL_GEO == DEFAULT_GEO),
                         line=dict(color='orange', width=4, dash='dash'))
                  )
geoid_plot_names.append(geoid_name)

# Format x axis
fig.update_xaxes(
dtick="D1",  # Means 1 day
tickformat="%d\n%b")

# Filter
buttons=[]
for geoid_name in ([ALL_GEO] + geoid_names):
    buttons.append(dict(method='update',
                        label=geoid_name,
                        args = [{'visible': [geoid_name==r for r in geoid_plot_names]},
                                {'title': f"{geoid_name} Daily New Cases 7-day Average "}]))
fig.update_layout(showlegend=True,
                  updatemenus=[{"buttons": buttons,
                                "direction": "down",
                                "active": ([ALL_GEO] + geoid_names).index(DEFAULT_GEO),
                                "showactive": True,
                                "x": 0.1,
                                "y": 1.15}])

fig.show()

## Rankings: by cumulative 7DMA error

In [406]:
ranking_fig = go.Figure(layout=dict(title=dict(text=f'{DEFAULT_GEO} submission rankings',
                                               y=0.9,
                                               x=0.5,
                                               xanchor='center',
                                               yanchor='top'
                                               ),
                                    plot_bgcolor='#f2f2f2',
                                    xaxis_title="Date",
                                    yaxis_title="Cumulative 7DMA error"
                                    ))

# Keep track of trace visibility by geo name
ranking_geoid_plot_names = []

all_df = ranking_df.groupby(["PredictorName", "Date"])[["GeoID", "PredictorName", "CumulDiff7DMA"]].sum(). \
    sort_values(by=["PredictorName", "Date"]).reset_index()

# Add 1 trace per predictor, for all geos
for predictor_name in predictor_names:
    ranking_geoid_df = all_df[all_df.PredictorName == predictor_name]
    ranking_fig.add_trace(go.Scatter(x=ranking_geoid_df.Date,
                             y=ranking_geoid_df.CumulDiff7DMA,
                             name=predictor_name,
                             visible=(ALL_GEO == DEFAULT_GEO))
                 )
    ranking_geoid_plot_names.append(ALL_GEO)


# Add 1 trace per predictor, per country
for predictor_name in predictor_names:
    for geoid_name in geoid_names:
        ranking_geoid_df = ranking_df[(ranking_df.GeoID == geoid_name) &
                                        (ranking_df.PredictorName == predictor_name)]
        ranking_fig.add_trace(go.Scatter(x=ranking_geoid_df.Date,
                                 y=ranking_geoid_df.CumulDiff7DMA,
                                 name=predictor_name,
                                 visible= (geoid_name == DEFAULT_GEO))
                     )
        ranking_geoid_plot_names.append(geoid_name)

# Format x axis
ranking_fig.update_xaxes(
dtick="D1",  # Means 1 day
tickformat="%d\n%b")

# Filter
buttons=[]
for geoid_name in ([ALL_GEO] + geoid_names):
    buttons.append(dict(method='update',
                        label=geoid_name,
                        args = [{'visible': [geoid_name==r for r in ranking_geoid_plot_names]},
                                {'title': f'{geoid_name} submission rankings'}]))
ranking_fig.update_layout(showlegend=True,
                          updatemenus=[{"buttons": buttons,
                                        "direction": "down",
                                        "active": ([ALL_GEO] + geoid_names).index(DEFAULT_GEO),
                                        "showactive": True,
                                        "x": 0.1,
                                        "y": 1.15}])

ranking_fig.show()