In [1]:
import plotly.graph_objects as go
import plotly
from plotly.subplots import make_subplots
import pandas as pd
import urllib.request
import os
import re
import glob
import datetime
import dateutil.relativedelta
import numpy as np
from ipywidgets import widgets

from IPython.display import display

import matplotlib.pyplot as plt

In [18]:
def row_sum(arow):
    i = 0
    for x in arow:
        if type(arow[x]) == float:
            i+=arow[x]
    return i


def refresh_data():

    
    url = 'https://github.com/CSSEGISandData/COVID-19/tree/master/csse_covid_19_data/csse_covid_19_daily_reports'
    response = urllib.request.urlopen(url)
    data = response.read()      # a `bytes` object
    text = data.decode('utf-8') # a `str`; this step can't be used if data is binary

    [os.system('rm {}'.format(i)) for i in glob.glob("daily_data/*csv*")]

    for x in set(re.findall("[0-9]+\-[0-9]+\-20[0-9]+", text)):
        os.system("wget https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_daily_reports/{}.csv  -P daily_data/".format(x))
        
    data_files = glob.glob('daily_data/*.csv')
    
    summaries = []
    deaths = []
    
    for entry in data_files:
        temp = {}
        death = {}
        x = pd.read_csv(entry)
        date = (entry.split('/')[1].split('.')[0]).split('-')
        date =  datetime.datetime.fromisoformat("-".join([date[2],date[0], date[1]]))
        x['date'] = date
        temp['date'] = date
        death['date'] = date
        
        # country coding varies as of 3-11-2020
        
        x['Country/Region'] = x['Country/Region'].apply(lambda i: "Iran" if "Iran" in i else i)
        x['Country/Region'] = x['Country/Region'].apply(lambda i: "South Korea" if "Republic of Korea" == i else i)
        
        for country in x['Country/Region'].unique():
            temp[country] = sum( x[x['Country/Region']==country]['Confirmed'])
            death[country] = sum(x[x['Country/Region']==country]['Deaths'])
        deaths.append(death)
        summaries.append(temp)

    death_table = pd.DataFrame(deaths).sort_values(by="date").fillna(0)
    summary_table = pd.DataFrame(summaries).sort_values(by="date").fillna(0)


    return death_table, summary_table


In [19]:
death_table, summary_table = refresh_data()

In [20]:
death_table['global'] = death_table.apply(lambda x: sum([i for i in x if type(i) == float]), axis=1)
summary_table['global'] = summary_table.apply(lambda x: sum([i for i in x if type(i) == float]), axis=1)

all_countries = ['global'] + sorted([i for i in death_table.columns if i not in ["date", 'global']])
col_entry = summary_table.tail(1).T.columns[0]
all_countries = summary_table[all_countries].tail(1).T.sort_values(by=col_entry, ascending=False).index.to_list()

In [21]:
# Create figure
fig = make_subplots(rows=2, cols=2, shared_xaxes=True, subplot_titles=["","South Korea - Linear","", "South Korea - Log"])


n_traces = 12

# Add surface trace
for country in all_countries:
    if country == 'global':
        viz = True
    else:
        viz = False
        
    ## core scatter lots
    
    fig.add_trace(go.Scatter(
        x=death_table["date"], y=death_table[country],
        name='death', visible=viz, line=dict(color="#3498DB"),
        mode='lines+markers', showlegend=False),
        row=1, col=1
    )

    fig.add_trace(go.Scatter(
        x=summary_table["date"], y=summary_table[country],
        name='confirmed', visible=viz, line=dict(color="#FF0000"),
        mode='lines+markers', showlegend=False),
        row=1, col=1
    )

    fig.add_trace(go.Scatter(
        x=death_table["date"], y=np.log10(death_table[country]),
        name='log10 death - {}'.format(country), visible=viz, line=dict(color="#3498DB"),
        mode='lines+markers', showlegend=False),
        row=2, col=1
    )

    fig.add_trace(go.Scatter(
        x=summary_table["date"], y=np.log10(summary_table[country]),
        name='log10 confirmed - {}'.format(country), visible=viz, line=dict(color="#FF0000"),
        mode='lines+markers', showlegend=False),
        row=2, col=1
    )
    
    ## additional analysis
    
    x_min = min(summary_table['date'])
    x_max = max(summary_table['date'])
        
    pct_05 = max(summary_table[country]) * 0.005
    pct_1 = max(summary_table[country]) * 0.01
    pct_3 = max(summary_table[country]) * 0.03
    pct_5 = max(summary_table[country]) * 0.05
    
    
    for pct, h_color, val in zip([pct_05, pct_1, pct_3, pct_5], 
                                 ["#D2B4DE", "#AED6F1", "#ABEBC6", "#F9E79F"],
                                 ['0.05%', '1%', '3%', '5%']
                                ):
        
        fig.add_trace(go.Scatter(x=[x_min, x_max], y=[pct, pct],
                                 mode='lines', name=val,  visible=viz, 
                                 line=dict(color=h_color, width=3), showlegend=False), row=1, col=1)
        
        fig.add_trace(go.Scatter(x=[x_min, x_max], y=[np.log10(pct), np.log10(pct)],
                                 mode='lines', name=val,  visible=viz, 
                                 line=dict(color=h_color, width=3), showlegend=False), row=2, col=1)

fig.add_trace(go.Scatter(
    x=death_table["date"], y=death_table["South Korea"],
    name='death', visible=True, line=dict(color="#3498DB"),
    mode='lines+markers', showlegend=True),
    row=1, col=2
)

fig.add_trace(go.Scatter(
    x=summary_table["date"], y=summary_table["South Korea"],
    name='confirmed', visible=True, line=dict(color="#FF0000"),
    mode='lines+markers', showlegend=True),
    row=1, col=2
)

fig.add_trace(go.Scatter(
    x=death_table["date"], y=np.log10(death_table["South Korea"]),
    name='log10 death - SK', visible=True, line=dict(color="#3498DB"),
    mode='lines+markers', showlegend=False),
    row=2, col=2
)

fig.add_trace(go.Scatter(
    x=summary_table["date"], y=np.log10(summary_table["South Korea"]),
    name='log10 confirmed - SK', visible=True, line=dict(color="#FF0000"),
    mode='lines+markers', showlegend=False),
    row=2, col=2
)

## additional analysis - sk Specific    
pct_05 = max(summary_table["South Korea"]) * 0.005
pct_1 = max(summary_table["South Korea"]) * 0.01
pct_3 = max(summary_table["South Korea"]) * 0.03
pct_5 = max(summary_table["South Korea"]) * 0.05


for pct, h_color, val in zip([pct_05, pct_1, pct_3, pct_5], 
                             ["#D2B4DE", "#AED6F1", "#ABEBC6", "#F9E79F"],
                             ['0.5%', '1%', '3%', '5%']
                            ):

    fig.add_trace(go.Scatter(x=[x_min, x_max], y=[pct, pct],
                             mode='lines', name=val,  visible=True, 
                             line=dict(color=h_color, width=3), showlegend=True), row=1, col=2)

    fig.add_trace(go.Scatter(x=[x_min, x_max], y=[np.log10(pct), np.log10(pct)],
                             mode='lines', name=val,  visible=True, 
                             line=dict(color=h_color, width=3), showlegend=False), row=2, col=2)


# Update plot sizing
fig.update_layout(
    width=800,
    height=900,
    autosize=False,
    margin=dict(t=100, b=0, l=0, r=0),
    title_text="Global vs South Korea"
)

fig.update_yaxes(title_text="Linear infected and deceased", row=1, col=1)
fig.update_yaxes(title_text="Log10 infected and deceased", row=2, col=1)


fig.update_layout(
    updatemenus=[
        dict(
            buttons=[
                dict(
                    args=[{"visible": [i==x for i in all_countries for _ in range(n_traces)] + [True] * n_traces},
                          {"title": "{} vs South Korea".format(x)},
                          {"subplot_titles": [x, "", "SK", ""]}
                         ],
                    label=x,
                    method="update",
                   # country=x
                    
                ) for x in all_countries
            ],
            direction="down",
            pad={"r": 10, "t": 10},
            showactive=True,
            x=0,
            xanchor="left",
            y=1.2,
            yanchor="top"
        ),
    ]
)


fig.show()


divide by zero encountered in log10



In [None]:
plotly.offline.plot(fig, filename = 'test-covid.html', auto_open=False)

In [17]:
summary_table

Unnamed: 0,date,Mainland China,South Korea,Others,Italy,Japan,Iran,Singapore,Hong Kong,US,...,Ukraine,Palestine,Vatican City,Ivory Coast,Azerbaijan,North Ireland,St. Martin,Moldova,Republic of Ireland,global
28,2020-01-22,0.0,1.0,0.0,0.0,2.0,0.0,0.0,0.0,1.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,8.0
29,2020-01-23,0.0,1.0,0.0,0.0,1.0,0.0,1.0,2.0,1.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,14.0
39,2020-01-24,0.0,2.0,0.0,0.0,2.0,0.0,3.0,2.0,2.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,25.0
38,2020-01-25,1399.0,2.0,0.0,0.0,2.0,0.0,3.0,5.0,2.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1438.0
22,2020-01-26,2062.0,3.0,0.0,0.0,4.0,0.0,4.0,8.0,5.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2118.0
23,2020-01-27,2863.0,4.0,0.0,0.0,4.0,0.0,5.0,8.0,5.0,...,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,2927.0
34,2020-01-28,5494.0,4.0,0.0,0.0,7.0,0.0,7.0,8.0,5.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,5578.0
35,2020-01-29,6070.0,4.0,0.0,0.0,7.0,0.0,7.0,10.0,5.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,6165.0
17,2020-01-30,8124.0,4.0,0.0,0.0,11.0,0.0,10.0,10.0,5.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,8235.0
16,2020-01-31,9783.0,11.0,0.0,2.0,15.0,0.0,13.0,12.0,6.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,9925.0
