<a href="https://colab.research.google.com/github/bobby-mclaughlinjr/covid/blob/master/Rt*.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import datetime as d

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from ipywidgets import interact, interactive, IntSlider, FloatSlider, fixed, HBox, VBox, Label, Output, Button, Dropdown
from matplotlib import dates as mdates
from matplotlib import ticker as mticker

In [0]:
# CSSE COVID-19 global confirmed-cases time series: one row per
# Province/State + Country/Region, one column per date.
URL = ('https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/'
       'csse_covid_19_data/csse_covid_19_time_series/'
       'time_series_covid19_confirmed_global.csv')

# CSV header -> short column name used by the Rt class.
COLUMNS = {
    'Province/State': 'area',
    'Country/Region': 'region',
    'Lat': 'latitude',
    'Long': 'longitude',
}

In [0]:
# Default model parameters; the Rt class uses them as constructor defaults.
# NOTE(review): the durations are presumably in days -- confirm.
ASYMPTOMATIC = 0.75         # share of total cases treated as absent from confirmed counts
R0 = 2.2                    # basic reproduction number (stored, not used in the calculations)
INCUBATION_DURATION = 5.2   # Tinc, feeds the Rt* estimate
INFECTION_DURATION = 2.9    # Tinf, feeds Rt^ and the ex-ante projection

In [0]:
class Rt:
    """Interactive estimator of the effective reproduction number (Rt).

    Loads the CSSE confirmed-cases series (``URL``) and, per region, derives two
    Rt estimates (``Rt^`` and ``Rt*``) and their mean ``Rt``. It then projects
    new cases (``ex_ante``) for comparison with the reported ones (``ex_post``).
    Two sliders vary the incubation and infection durations and re-render a
    summary table and a chart into ``self.output``.

    Parameters
    ----------
    R0 : float
        Basic reproduction number. Stored only; not used by the calculations.
    incubation_duration : float
        Tinc, used by the ``Rt*`` estimate.
    infection_duration : float
        Tinf, used by the ``Rt^`` estimate and the ex-ante projection.
    asymptomatic : float
        Fraction of total cases treated as missing from the confirmed counts.
        Confirmed cases are scaled by ``1 / (1 - asymptomatic)``.
    data : pandas.Series or None
        Cases indexed by (region, date). When ``None`` it is downloaded lazily
        by :meth:`get_data`.
    """

    def __init__(self
                 , R0=R0
                 , incubation_duration=INCUBATION_DURATION
                 , infection_duration=INFECTION_DURATION
                 , asymptomatic=ASYMPTOMATIC
                 , data=None
                 ):

        self.R0 = R0
        self.incubation_duration = incubation_duration
        self.infection_duration = infection_duration
        self.asymptomatic = asymptomatic

        self.data = data
        self.transformed_data = None

        # The slider callback renders into this widget.
        self.output = Output()

        # Seed the sliders from the instance, not the module defaults.
        # ``interactive`` calls its function when displayed, so a default-seeded
        # slider would overwrite values passed to the constructor.
        self._incubation_widget = self._float_slider(self.incubation_duration, 'Incubation Duration (Tinc)')
        self.incubation_duration_slider = interactive(self.set_and_get_solution
                                                      , value=self._incubation_widget
                                                      , name=fixed('incubation_duration')
                                                      )

        self._infection_widget = self._float_slider(self.infection_duration, 'Infection Duration (Tinf)')
        self.infection_duration_slider = interactive(self.set_and_get_solution
                                                     , value=self._infection_widget
                                                     , name=fixed('infection_duration')
                                                     )

        self.interact = VBox([self.incubation_duration_slider, self.infection_duration_slider, self.output])

    @staticmethod
    def _float_slider(value, description):
        """Return the 0.1-10 (step 0.1) FloatSlider used for a duration parameter."""
        return FloatSlider(min=0.1
                           , max=10
                           , step=0.1
                           , value=value
                           , continuous_update=False
                           , description=description
                           , style={'description_width': 'initial'}
                           )

    @staticmethod
    def smoothing(X, window=7):
        """Return the trailing rolling mean of *X* over *window* observations."""
        return X.rolling(window=window).mean()

    def set_property(self, name, value=None):
        """Set attribute *name* on this instance to *value*."""
        setattr(self, name, value)

    def set_and_get_solution(self, name=None, value=None):
        """Optionally set attribute *name* to *value*, then recompute and redraw.

        This is the ``interactive`` callback for both sliders. The summary table
        and the chart are rendered into ``self.output``.
        """
        if name is not None:
            self.set_property(name=name, value=value)

        self.output.clear_output(wait=True)
        with self.output:
            display(self.re_calculate())
            self.plot()

    def reset(self, _=None):
        """Restore the module-default durations and re-render.

        The signature fits a ``Button.on_click`` handler; the argument is ignored.
        """
        self.incubation_duration = INCUBATION_DURATION
        self.infection_duration = INFECTION_DURATION

        # Move the sliders too, so the UI does not show stale values. A value
        # change fires the slider callback, which writes the same value back.
        self._incubation_widget.value = INCUBATION_DURATION
        self._infection_widget.value = INFECTION_DURATION

        self.set_and_get_solution()

    def get_data(self, source='CSSE', regions=('US', 'Spain', 'Italy')):
        """Download confirmed cases and store per-region daily totals in ``self.data``.

        Parameters
        ----------
        source : str
            Currently unused; only the CSV at ``URL`` is read.
        regions : sequence of str
            Country/Region names to keep (defaults to the original three).

        Returns
        -------
        Rt
            ``self``, to allow chaining.
        """
        df = pd.read_csv(URL).rename(columns=COLUMNS)
        data = pd.melt(df, id_vars=list(COLUMNS.values()), var_name='date', value_name='cases')
        data['date'] = pd.to_datetime(data['date'], format='%m/%d/%y')

        # Sum Province/State rows into one value per (region, date).
        self.data = data.groupby(['region', 'date'])['cases'].sum().loc[list(regions), :]

        return self

    def assumed(self, infectious, Rt, shift=7):
        """Return ``(Rt / Tinf) * infectious.shift(shift) * (1 - asymptomatic)``."""
        return (Rt / self.infection_duration) * infectious.shift(shift) * (1 - self.asymptomatic)

    def re_calculate(self, smoothing=7):
        """Recompute all regions with *smoothing* and return the latest summary table."""
        return self.calculate(smoothing=smoothing).get_historical()

    def calculate(self, smoothing=7):
        """Compute the per-region series into ``self.transformed_data``.

        Data is downloaded first if none is loaded. Returns ``self``.
        """
        if self.data is None:
            self.get_data()

        # _calculate builds a new frame, so no defensive copy of self.data is needed.
        self.transformed_data = self.data.groupby(level=0).apply(self._calculate, smoothing=smoothing)
        return self

    def _calculate(self, data, smoothing=7):
        """Derive the Rt columns and the ex-ante/ex-post series for one region.

        *data* is one region's slice of ``self.data`` (index: region, date).
        The returned frame is indexed by date.
        """
        data = data.to_frame('cases').reset_index(level=0, drop=True)

        # Scale confirmed cases up to an estimate of all cases.
        data['total_cases'] = data['cases'] / (1 - self.asymptomatic)
        data['asymptomatic_cases'] = data['total_cases'] - data['cases']

        data['new_cases'] = data['cases'].diff()
        data['new_total_cases'] = data['total_cases'].diff()
        data['new_total_cases_shift'] = data['new_total_cases'].shift(13)

        # Running sum of x_t - x_{t-3}. It telescopes to roughly the trailing
        # 3-day sum of new_total_cases (minus the first three daily values).
        data['infectious'] = (data['new_total_cases'] - data['new_total_cases'].shift(3)).expanding().apply(lambda x: np.nansum(x), raw=True)

        # Rt
        data['Rt^'] = self.smoothing(data['new_total_cases'].rolling(window=3).mean() * self.infection_duration / data['infectious'].rolling(window=3).mean().shift(8), window=smoothing)
        data['Rt*'] = self.smoothing(-(data['total_cases'] - data['new_total_cases_shift']) / ((data['total_cases'] - data['new_total_cases_shift']).shift(1)) * np.log(1 / self.incubation_duration) - 1, window=smoothing)
        data['Rt'] = data[['Rt^', 'Rt*']].mean(axis=1)

        # ex_ante on row t is the projection built from row t-1.
        # ex_post on row t is the next row's reported new cases.
        data['ex_ante'] = self.assumed(data['infectious'], data['Rt'], shift=6).shift()
        data['ex_post'] = data['new_cases'].shift(-1)

        return data

    def get_historical(self, date=None):
        """Return one summary row per region at *date* (default: latest date).

        Calculates first if nothing has been calculated yet.
        """
        if self.transformed_data is None:
            self.calculate()

        results = self.transformed_data.groupby(level=0).apply(self._get_historical, date=date)
        return pd.DataFrame(results.tolist(), index=results.index)

    @staticmethod
    def _get_historical(data, date=None):
        """Summarise one region at *date* (default: last row) against the row before.

        Raises
        ------
        KeyError
            If *date* is not in the region's date index.
        IndexError
            If there is no row before the requested one.
        """
        def as_int(value):
            # int(NaN) raises ValueError; keep missing values as NaN instead.
            return int(value) if pd.notna(value) else np.nan

        # Index by date only, without mutating the groupby group in place.
        data = data.reset_index(level=0, drop=True)

        pos = len(data) - 1 if date is None else data.index.get_loc(date)
        # With pos == 0, iloc[pos - 1] would silently wrap to the final row.
        if pos < 1:
            raise IndexError('need at least one row before the requested date')
        current, last = data.iloc[pos], data.iloc[pos - 1]

        return {'Rt': round(current['Rt'], 2)
                , 'Last Rt': round(last['Rt'], 2)
                , 'Delta': round(current['Rt'] - last['Rt'], 3)
                , 'Projected': as_int(current['ex_ante'])
                , 'Last Projected': as_int(last['ex_ante'])
                , 'Ex Post': as_int(last['ex_post'])
                # NOTE(review): 'Error' is ex_post - ex_ante, but 'Error %' (and the
                # chart) use ex_ante - ex_post, so the two carry opposite signs.
                , 'Error': as_int(last['ex_post'] - last['ex_ante'])
                , 'Error %': round((last['ex_ante'] - last['ex_post']) / last['ex_post'] * 100, 1)
                }

    def plot(self):
        """Display a region dropdown that drives an ex-ante vs ex-post chart.

        The lines are projected (blue) and reported (orange) new cases from
        2020-03-01. The translucent red bars, on a secondary axis, show the
        relative error ``(ex_ante - ex_post) / ex_post``.
        """
        if self.transformed_data is None:
            self.calculate()

        data = self.transformed_data.reset_index(level=0)[['region', 'ex_ante', 'ex_post']]
        # Boolean filter: the date index repeats per region and is not monotonic.
        data = data[data.index >= '2020-03-01']

        regions = sorted(data['region'].unique())
        dropdown_widget = Dropdown(options=regions
                                   , value='US' if 'US' in regions else regions[0]
                                   , disabled=False
                                   , description='Region:'
                                   )

        def plot_it(region):
            region_data = (data[data['region'] == region]
                           .drop('region', axis=1)
                           .rename(columns={'ex_ante': 'Ex Ante', 'ex_post': 'Ex Post'}))
            error_pct = (region_data['Ex Ante'] - region_data['Ex Post']) / region_data['Ex Post']

            _, ax = plt.subplots()
            ax.plot(region_data.index, region_data['Ex Ante'], color='b')
            ax.plot(region_data.index, region_data['Ex Post'], color='darkorange')
            ax.set_ylabel('New Cases')
            ax.set_xlabel('')

            ax2 = ax.twinx()
            ax2.bar(region_data.index, error_pct, color='r', alpha=0.1)
            ax2.set_ylabel('Error %')
            # A formatter stays correct if the ticks move later;
            # set_yticklabels would freeze labels to the current ticks.
            ax2.yaxis.set_major_formatter(mticker.FuncFormatter(lambda y, _: '{:,.0%}'.format(y)))

            ax2.xaxis.set_major_locator(mdates.MonthLocator())
            ax2.xaxis.set_major_formatter(mdates.DateFormatter('%b'))
            ax.xaxis.set_minor_locator(mdates.DayLocator())

            ax2.grid(which='major', axis='y', c='k', alpha=0.1, zorder=-2)
            ax2.axhline(0, linestyle=':', color='r', lw=1)

            # Render inside the widget's output area on every dropdown change.
            plt.show()

        display(interactive(plot_it, region=dropdown_widget))


In [9]:
# Build an estimator with the module-default parameters, then leave its widget
# panel as the cell's final bare expression so the notebook renders it.
obj = Rt()
panel = obj.interact
panel

VBox(children=(interactive(children=(FloatSlider(value=5.2, continuous_update=False, description='Incubation D…