In [175]:
import math
import warnings

import numpy as np
import pandas as pd
import plotly.colors
import plotly.graph_objects as go

import data
#from tools import write_pandas_to_gsheet

# Unpack the seven objects returned by the project-local `data.get_data()`.
# JHU_TIME and CSBS are the inputs to the aggregation cell that follows.
(
    MERGE_NO_US,
    MERGED_CSBS_JHU,
    JHU_TIME,
    JHU_RECENT,
    DATE_MAPPER,
    CSBS,
    CENTROID_MAPPER,
) = data.get_data()

In [188]:
def _aggregate_by(frame, key, drop_blank=True):
    """Sum confirmed cases and deaths per (Date, key) and keep one location per group.

    ``lat``/``lon`` come from the group's row with the most confirmed cases
    (``'first'`` takes the first non-null value after sorting descending).
    ``ascending=False`` keeps rows with a missing ``confirmed`` at the end, so
    they cannot supply the location. When ``drop_blank`` is true, rows whose
    ``key`` is the empty string are removed first.
    """
    if drop_blank:
        frame = frame[frame[key] != '']
    return (
        frame.sort_values('confirmed', ascending=False)
        .groupby(['Date', key])
        .agg({'lat': 'first', 'lon': 'first', 'confirmed': 'sum', 'deaths': 'sum'})
        .reset_index()
    )


JHU_DF_AGG_COUNTRY = _aggregate_by(JHU_TIME, 'country', drop_blank=False)
JHU_DF_AGG_PROVINCE = _aggregate_by(JHU_TIME, 'province')
CSBS_DF_AGG_STATE = _aggregate_by(CSBS, 'province').rename({'province': 'state'}, axis=1)
CSBS_DF_AGG_COUNTY = _aggregate_by(CSBS, 'county')

In [463]:
## Growth-trajectory plot: total cases (x) vs. new cases in a trailing window (y).
def _country_trajectory(agg_df, country, backtrack, min_cases):
    """Return one country's trajectory as a DataFrame with columns x, y, date.

    ``x`` is the cumulative confirmed count for each date, ``y`` is the sum of
    new confirmed cases over the trailing ``backtrack`` rows (including the
    current one), and ``date`` is the ``'%m/%d/%Y'`` label. Rows are in date
    order; only rows with both ``x`` and ``y`` >= ``min_cases`` are kept.
    """
    # Select just the numeric column: summing a frame that still contains the
    # string ``country`` column is pandas-version dependent and breaks diff().
    cumulative = (
        agg_df.loc[agg_df['country'] == country]
        .groupby('Date')['confirmed']
        .sum()
    )
    # The first date has no predecessor; count it as 0 new cases.
    new_cases = cumulative.diff().fillna(0)
    # A rolling window sums exactly `backtrack` rows (a .loc label slice is
    # inclusive at both ends and would sum one row too many).
    window_new = new_cases.rolling(backtrack, min_periods=1).sum()

    # The first date is never plotted (it has no preceding day).
    trajectory = pd.DataFrame({'x': cumulative, 'y': window_new}).iloc[1:]
    trajectory = trajectory[(trajectory['x'] >= min_cases) & (trajectory['y'] >= min_cases)]
    return trajectory.assign(date=[d.strftime('%m/%d/%Y') for d in trajectory.index])


def _axis_style(title):
    """Shared axis layout: black axis line, no grid, Arial title and ticks."""
    return dict(
        showline=True,
        showgrid=False,
        showticklabels=True,
        linecolor='black',
        linewidth=3,
        title=dict(text=title, font=dict(color='black', family='arial', size=20)),
        tickfont=dict(family='Arial', size=16, color='rgb(82, 82, 82)'),
    )


def plot_countries(countries, backtrack=7, log=True, agg_df=None, min_cases=100):
    """Plot each country's total cases against new cases in a trailing window.

    Each country gets a line (its trajectory) plus a large marker on its most
    recent qualifying point; both share a legend group so clicking the legend
    entry toggles them together. A dashed grey y = x diagonal labelled
    "Exponential" is drawn from ``min_cases`` to the largest plotted value.

    Parameters
    ----------
    countries : list of str
        Values of the ``country`` column to plot.
    backtrack : int, default 7
        Number of rows (dates) in the trailing window used to sum new cases.
    log : bool, default True
        Use log scales on both axes.
    agg_df : DataFrame, optional
        Country-level aggregate with ``Date``, ``country`` and ``confirmed``
        columns. Defaults to ``JHU_DF_AGG_COUNTRY``.
    min_cases : int, default 100
        Points where either coordinate is below this value are dropped; also
        the start of the reference diagonal.

    Returns
    -------
    plotly.graph_objects.Figure

    Raises
    ------
    ValueError
        If ``backtrack`` is less than 1.
    """
    if backtrack < 1:
        raise ValueError(f"backtrack must be >= 1, got {backtrack}")
    if agg_df is None:
        agg_df = JHU_DF_AGG_COUNTRY

    fig = go.Figure()
    colors = plotly.colors.qualitative.Prism
    window_label = 'Week' if backtrack == 7 else f'{backtrack} Days'
    # %{text}/%{x}/%{y} are plotly placeholders; braces are doubled in the f-string.
    hovertemplate = (
        "On %{text} <br> Total Cases: %{x}<br> "
        f"New Cases Past {window_label} %{{y}}"
    )
    max_number = 0

    for country_enum, country in enumerate(countries):
        trajectory = _country_trajectory(agg_df, country, backtrack, min_cases)
        if trajectory.empty:
            warnings.warn(f"{country}: no dates with at least {min_cases} cases; not plotted")
            continue

        xs = trajectory['x'].tolist()
        ys = trajectory['y'].tolist()
        dates = trajectory['date'].tolist()
        max_number = max(max_number, max(xs), max(ys))
        # Cycle the palette so more countries than colours does not fail.
        color = colors[country_enum % len(colors)]

        fig.add_trace(
            go.Scatter(
                x=xs,
                y=ys,
                mode='lines',
                name=country,
                text=dates,
                showlegend=False,
                legendgroup=country,
                line=dict(shape='linear', color=color),
                marker=dict(symbol='circle-open', size=5),
                hovertemplate=hovertemplate,
            )
        )
        # Latest point only; this trace carries the legend entry.
        fig.add_trace(
            go.Scatter(
                x=[xs[-1]],
                y=[ys[-1]],
                mode='markers',
                name=country,
                text=[dates[-1]],
                legendgroup=country,
                hoverlabel=dict(align='left'),
                marker=dict(symbol='circle', size=14, color=color),
                hovertemplate=hovertemplate,
            )
        )

    fig.add_trace(
        go.Scatter(
            x=[min_cases, max_number],
            y=[min_cases, max_number],
            mode='lines',
            name='Exponential',
            line=dict(color='grey', width=4, dash='dash'),
        )
    )
    if log:
        # On a log axis, dtick=1 places one tick per power of ten.
        fig.update_xaxes(type="log", dtick=1)
        fig.update_yaxes(type="log", dtick=1)
    fig.update_layout(
        xaxis=_axis_style('All Cases'),
        yaxis=_axis_style(f'New Cases Past {window_label}'),
        plot_bgcolor='white',
    )
    annotations = [
        dict(
            xref='paper', x=0.9, yref='paper', y=.95,
            text="Exponential Growth",
            font=dict(family='Arial', size=20, color='grey'),
            showarrow=True,
            arrowhead=2,
        )
    ]
    fig.update_layout(legend=dict(title='Click to Toggle'), annotations=annotations)
    return fig
    
# Linear axes for the comparison; pass log=True for the log-log view.
fig = plot_countries(
    countries=['US', 'China', 'France', 'Mexico', 'Germany', 'Italy', 'Spain'],
    backtrack=7,
    log=False,
)
fig.show()


In [439]:
# Example of the '%m/%d/%Y' date label used in plot_countries' hover text.
sample_date_label = pd.Timestamp('2020-03-05').strftime('%m/%d/%Y')
sample_date_label

'03/05/2020'