In [9]:
# %load_ext autoreload
# %autoreload 2

import json
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import requests
import pandas as pd
import altair as alt

# Run from the project root: step up one level when the kernel was started inside
# `notebooks/`. Checking the directory name first keeps this cell idempotent on re-run.
# `Path.cwd().name` is used instead of splitting on '/' so the check also works with
# Windows path separators.
if Path.cwd().name == 'notebooks':
    os.chdir('..')

plt.style.use('ggplot')
alt.themes.enable('fivethirtyeight')
# Relative path; it is resolved against the working directory selected above.
CHARTS_DIR = Path('../covid19-analysis/layouts/partials/covid')

In [2]:
from fetch import fetch_timeseries, TS_URL

# Wide frame as returned by the project's fetch helper.
df = fetch_timeseries(TS_URL)

# Long form: stack the columns into an extra index level named `status`,
# with the values in a series called `count`.
df_long = (
    df
    .stack()
    .rename('count')
    .rename_axis(index={None: 'status'})
)

# Preview both layouts.
display(df.head(), df_long.head())

# Timeseries

Unnamed: 0_level_0,Unnamed: 1_level_0,confirmed,deaths,recovered
country,date,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
Afghanistan,2020-01-22,0,0,0
Afghanistan,2020-01-23,0,0,0
Afghanistan,2020-01-24,0,0,0
Afghanistan,2020-01-25,0,0,0
Afghanistan,2020-01-26,0,0,0


country      date        status   
Afghanistan  2020-01-22  confirmed    0
                         deaths       0
                         recovered    0
             2020-01-23  confirmed    0
                         deaths       0
Name: count, dtype: int64

In [163]:
# Build the combined map / time-series / day-over-day chart from the long-form data.
from IPython.display import display
from importlib import reload
import charts
import render
# Manual reload of the local modules (the autoreload magics in the first cell are commented out).
reload(charts)
reload(render)
# NOTE(review): the star import hides where the helper names below come from
# (make_data_long, make_dod, make_ts_selections, make_ts_chart, make_map_data, make_map,
# make_dod_chart, combine_map_ts, status_schemes) — presumably all from `charts`; confirm.
from charts import *
from render import make_chart

# alt.data_transformers.enable('default', max_rows=None)
alt.data_transformers.enable('data_server')

data_long = make_data_long(df_long)
dod_long = make_dod(df_long).reset_index()

base_ts =  (alt.Chart(data_long).encode(x='date:T'))
selection_legend, selection_tooltip = make_ts_selections()
ts_chart = make_ts_chart(base_ts, sorted(dod_long.status.unique()), selection_legend, selection_tooltip)

# NOTE(review): `countries` is not assigned in any earlier visible cell; the only visible
# assignment is in the later "day slider" cell. Unless `charts` exports it, this line
# depends on out-of-order execution — verify with Restart & Run All.
map_data = make_map_data(data_long, countries)
map_chart = make_map(map_data, status_schemes)

dod_chart = make_dod_chart(dod_long)
chart = combine_map_ts(map_chart, ts_chart, dod_chart, selection_legend)

# Last expression: render the combined chart.
chart

In [161]:
# Debug probe: compares the selection object itself with a status string (displayed False).
selection_legend == 'confirmed'

False

## Attempt: status selection in heatmap 

In [136]:
# Pivot the map data so each status becomes its own `count` column,
# keyed by country / id / date (the `day` column is discarded first).
(
    map_data
    .set_index(['country', 'id', 'date', 'status'])
    .drop(columns='day')
    .unstack()['count']
    .reset_index()
)

status,country,id,date,confirmed,deaths,recovered
0,Afghanistan,4,2020-04-02,273.0,6.0,10.0
1,Albania,8,2020-04-02,277.0,16.0,76.0
2,Algeria,12,2020-04-02,986.0,86.0,61.0
3,Andorra,20,2020-04-02,428.0,15.0,10.0
4,Angola,24,2020-04-02,8.0,2.0,1.0
...,...,...,...,...,...,...
169,Uzbekistan,860,2020-04-02,205.0,2.0,25.0
170,Venezuela,862,2020-04-02,146.0,5.0,43.0
171,Vietnam,704,2020-04-02,233.0,,75.0
172,Zambia,894,2020-04-02,39.0,1.0,


## Attempt: day slider in heatmap

In [119]:
# Choropleth of confirmed counts for China with a slider that picks the day.
# NOTE(review): `data` (source of the world topology URL) and `map_data` are not defined in
# this cell; they must already exist in the kernel namespace — confirm for Restart & Run All.
countries = alt.topo_feature(data.world_110m.url, 'countries')

china_data = map_data.query('country == "China"').query('status == "confirmed"')
# Cast to built-in ints so the slider bounds and initial value are plain JSON numbers.
min_day = int(china_data.day.min())
max_day = int(china_data.day.max())
# china_data = china_data.pivot(index='fips', columns='year', values='Pill_per_pop').reset_index()
# Wide layout for the lookup transform: one row per `id`, one column per day.
china_data = china_data.set_index(['id', 'day'])['count'].unstack().reset_index()
# Column labels become strings because the fold transform below refers to fields by name.
china_data.columns = china_data.columns.map(str)
columns = list(china_data.columns.difference({'id'}))

slider = alt.binding_range(min=min_day,
                           max=max_day,
                           step=1)

select_day = alt.selection_single(name='day',
                                   fields=['day'],
                                   bind=slider,
                                #   on='none',
                                 init={'day': min_day}
                                 )

china_map = (alt
             .Chart(countries)
             .mark_geoshape()
             .encode(
                 # NOTE(review): `country` is not among the looked-up fields (only the day
                 # columns are), so this tooltip entry may be empty — confirm.
                 tooltip=['count:Q',
                          'country:N',
                          'day:Q'],
                 color=alt.Color('count:Q', scale=alt.Scale(scheme='reds'))
             )
             .transform_lookup(
                 lookup='id',
                 from_=alt.LookupData(data=china_data,
                                      key='id',
                                      fields=columns)
             )
             .transform_fold(
                 columns, as_=['day', 'count']
             )
             # The fold emits the former column names, i.e. strings such as '22'. The slider
             # produces numbers, so convert `day` back to a number; otherwise the selection
             # filter below never matches a row.
             .transform_calculate(day='parseInt(datum.day)')
             .add_selection(select_day)
             .transform_filter(select_day)
       )

china_map

In [109]:
# Inspect the wide per-day table that feeds the lookup transform.
china_data

day,id,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93
0,156,548.0,643.0,920.0,1406.0,2075.0,2877.0,5509.0,6087.0,8141.0,9802.0,11891.0,16630.0,19716.0,23707.0,27440.0,30587.0,34110.0,36814.0,39829.0,42354.0,44386.0,44759.0,59895.0,66358.0,68413.0,70513.0,72434.0,74211.0,74619.0,75077.0,75550.0,77001.0,77022.0,77241.0,77754.0,78166.0,78600.0,78928.0,79356.0,79932.0,80136.0,80261.0,80386.0,80537.0,80690.0,80770.0,80823.0,80860.0,80887.0,80921.0,80932.0,80945.0,80977.0,81003.0,81033.0,81058.0,81102.0,81156.0,81250.0,81305.0,81435.0,81498.0,81591.0,81661.0,81782.0,81897.0,81999.0,82122.0,82198.0,82279.0,82361.0,82432.0


In [116]:
# Inspect the day column names (strings) passed to the lookup and fold transforms.
columns

['22', '23', '24', '25', '26', '27', '28', '29', '30', '31', '32', '33', '34', '35', '36', '37', '38', '39', '40', '41', '42', '43', '44', '45', '46', '47', '48', '49', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '60', '61', '62', '63', '64', '65', '66', '67', '68', '69', '70', '71', '72', '73', '74', '75', '76', '77', '78', '79', '80', '81', '82', '83', '84', '85', '86', '87', '88', '89', '90', '91', '92', '93']

In [117]:
# Check that every column label of the wide table is a string after `.map(str)`.
china_data.columns

Index(['id', '22', '23', '24', '25', '26', '27', '28', '29', '30', '31', '32',
       '33', '34', '35', '36', '37', '38', '39', '40', '41', '42', '43', '44',
       '45', '46', '47', '48', '49', '50', '51', '52', '53', '54', '55', '56',
       '57', '58', '59', '60', '61', '62', '63', '64', '65', '66', '67', '68',
       '69', '70', '71', '72', '73', '74', '75', '76', '77', '78', '79', '80',
       '81', '82', '83', '84', '85', '86', '87', '88', '89', '90', '91', '92',
       '93'],
      dtype='object', name='day')

# Correlations

In [11]:
import seaborn as sns

# Correlation between the status columns over all country/date rows.
# Gaps are forward-filled within each country (never across countries); anything still
# missing afterwards is set to 0. `groupby(...).ffill()` replaces the deprecated
# `fillna(method='ffill')` inside a per-group lambda and yields the same values.
(
    df
    .groupby('country')
    .ffill()
    .fillna(0)
    .corr()
    .pipe(sns.heatmap, annot=True)
)

<matplotlib.axes._subplots.AxesSubplot object at 0x7ff456b3cbe0>

In [12]:
# Correlation matrix of the per-date totals summed over all countries.
(
    df
    .groupby('date')
    .sum()
    .corr()
)

Unnamed: 0,confirmed,deaths,recovered
confirmed,1.0,0.997479,0.94617
deaths,0.997479,1.0,0.932249
recovered,0.94617,0.932249,1.0


In [14]:
def correlation_lags(df, column='deaths', max_lag=20, group=False):
    """Find the lag (in rows) at which shifted `confirmed` best correlates with `column`.

    For every lag ``t`` in ``range(max_lag)`` the ``confirmed`` series is shifted
    forward by ``t`` rows, its last ``max_lag`` rows are dropped, and the result is
    correlated with ``column``. The lag with the highest correlation is returned.

    Parameters
    ----------
    df : pandas.DataFrame
        Must provide a ``confirmed`` column and ``column``. When ``group`` is true
        it must also be groupable by ``country``.
    column : str, default 'deaths'
        Name of the series compared against the shifted ``confirmed`` series.
    max_lag : int, default 20
        Number of candidate lags (``0 .. max_lag - 1``). It is also the number of
        trailing rows trimmed from the shifted series; this used to be a
        hard-coded 20 regardless of ``max_lag``.
    group : bool, default False
        When true, compute the best lag separately for every ``country``.

    Returns
    -------
    The best lag for the whole frame, or, when ``group`` is true, a
    ``pandas.Series`` of best lags indexed by country.

    Raises
    ------
    ValueError
        If ``max_lag`` is smaller than 1 (there would be no lag to evaluate).
    """
    if max_lag < 1:
        raise ValueError('max_lag must be at least 1')

    def series_corr(frame):
        # One correlation per candidate lag, indexed by the lag itself.
        target = frame[column]
        return pd.Series({
            lag: frame['confirmed'].shift(lag).iloc[:-max_lag].corr(target)
            for lag in range(max_lag)
        })

    if group:
        # Rows are countries, columns are lags; pick the best lag per row.
        return df.groupby('country').apply(series_corr).idxmax(axis=1)
    return series_corr(df).idxmax()

# Worldwide totals per date, then the best-correlating lag of `confirmed`
# against deaths and against recoveries.
world_ts = df.groupby('date').sum()
days_to_death, days_to_recov = (
    correlation_lags(world_ts, status) for status in ('deaths', 'recovered')
)

In [18]:
# Show the worldwide lags (in rows) for deaths and recoveries.
# NOTE(review): correlation_lags only tries lags 0..max_lag-1 (default max_lag=20), so a
# result of 19 sits on the search boundary — the recovery lag may be clipped; consider
# re-running with a larger max_lag.
days_to_death, days_to_recov

(1, 19)

In [20]:
# Per-country fatality rate: deaths divided by `confirmed` shifted by that country's own lag.
# Steps:
#   1. join the per-country best lag (correlation_lags with group=True) as `days_to_death`,
#      with missing lags replaced by 0 and cast to int;
#   2. within each country, divide `deaths` by `confirmed` shifted by that lag
#      (the lag is read from the group's first row);
#   3. drop index level 0 of the result.
# NOTE(review): step 3 assumes groupby.apply prepends the group key to the index — this
# presumably depends on the pandas version; verify the resulting index.
# NOTE(review): rows where the shifted `confirmed` is 0 give inf/NaN ratios, which may
# distort the per-date mean plotted below — confirm.
fatality_rates = (df.join(correlation_lags(df, group=True, column='deaths').rename('days_to_death').fillna(0).astype(int))
                  .groupby('country')
                  .apply(lambda g: g.deaths.div(g.confirmed.shift(g['days_to_death'].iloc[0])))
                  .reset_index(0, drop=True)
                 )
# Average the per-country rates for every date and plot the result.
fatality_rates.groupby('date').mean().plot()

<matplotlib.axes._subplots.AxesSubplot object at 0x7f472639bd68>

In [23]:
# Worldwide fatality rate: deaths over the confirmed count from `days_to_death` rows earlier.
(world_ts['deaths'] / world_ts['confirmed'].shift(days_to_death)).plot()

<matplotlib.axes._subplots.AxesSubplot object at 0x7f472639bd68>