In [1]:
import ipywidgets as widgets
import qgrid

from scipy.stats import binom
import matplotlib.pyplot as plt
from ipywidgets import interact, Dropdown
import pandas as pd
import csv
import numpy as np
import datetime
import re
import warnings
from IPython.display import HTML, display
import requests
import sys
from scipy.optimize import curve_fit
from math import pi
from scipy import stats
from scipy.stats import lognorm

from ipywidgets import Layout, Button, VBox, HBox, Box, FloatText, Textarea, Dropdown, Label, IntSlider, FloatSlider

# Silence every Python warning for the rest of the session.
# NOTE(review): this may also hide scipy curve_fit fit-quality warnings and pandas
# chained-assignment warnings that would flag bad fits or data problems —
# consider narrowing the filter to specific categories.
warnings.filterwarnings('ignore')
#plt.ioff()

# Render matplotlib figures inline in the notebook as SVG (vector) images.
%matplotlib inline
%config InlineBackend.figure_format = 'svg'

In [2]:

def logistic(x, a, b, c):
    """Sigmoid-shaped growth curve used as the 'logistic' model.

    Evaluates ``a / (b + exp(-c * x**3))`` element-wise. Note the cubic exponent:
    this is not the textbook logistic ``a / (1 + exp(-k * (x - x0)))``.

    Parameters
    ----------
    x : scalar or numpy array
        Day index (0 = first day with cases).
    a, b, c : float
        Curve parameters (fitted by ``scipy.optimize.curve_fit``).
    """
    denominator = b + np.exp(-c * x ** 3)
    return a / denominator

#def exponential(x, a, b, c):
#    return a * np.exp(b * x) + c
        

def obs_pred_rsquare(obs, pred):
    """Coefficient of determination (R^2) of predictions against observations.

    Computes ``1 - SS_res / SS_tot``. The result is negative when the
    predictions are worse than the observed mean. When ``obs`` is constant,
    SS_tot is 0 and numpy division yields inf/nan.

    Parameters
    ----------
    obs, pred : numpy arrays of equal length
    """
    residual_ss = sum((obs - pred) ** 2)
    total_ss = sum((obs - np.mean(obs)) ** 2)
    return 1 - residual_ss / total_ss


def get_logistic(obs_x, obs_y, ForecastDays):
    """Fit ``logistic`` to the observations and extrapolate it.

    Zeros after the first point are replaced by the previous value before
    fitting. If ``curve_fit`` fails, falls back to ``get_polynomial``.

    Parameters
    ----------
    obs_x : sequence of int
        Day indices of the observations.
    obs_y : sequence of numbers
        Observed values. Not modified; a float copy is used.
    ForecastDays : int
        The forecast is evaluated on x = 0 .. max(obs_x) + ForecastDays - 1.

    Returns
    -------
    (forecasted_y, forecasted_x, pred_y)
        Curve values on the forecast grid, the grid itself, and fitted values
        at ``obs_x``.
    """
    obs_x = np.array(obs_x)
    # Work on a float copy so the caller's sequence is never mutated.
    obs_y = np.array(obs_y, dtype=float)
    # Carry the previous value forward over zeros (e.g. missing reports).
    # Start at 1: index 0 has no predecessor, and obs_y[-1] would silently
    # wrap around to the LAST value.
    for i in range(1, len(obs_y)):
        if obs_y[i] == 0:
            obs_y[i] = obs_y[i - 1]

    try:
        popt, _ = curve_fit(logistic, obs_x, obs_y)
        pred_y = logistic(obs_x, *popt)
        forecasted_x = np.array(list(range(max(obs_x) + ForecastDays)))
        forecasted_y = logistic(forecasted_x, *popt)
    except Exception:
        # curve_fit raises e.g. RuntimeError (no convergence) or ValueError /
        # TypeError (bad or too little data); keep the app usable via a simpler fit.
        print('Logistic failed to fit. Using 2nd degree polynomial.')
        forecasted_y, forecasted_x, pred_y = get_polynomial(obs_x, obs_y, ForecastDays)

    return forecasted_y, forecasted_x, pred_y



def get_exponential(obs_x, obs_y, ForecastDays):
    """Fit ``y = exp(intercept + slope * x)`` by linear regression on log(y).

    Zeros after the first point are replaced by the previous value. Only
    strictly positive points enter the regression, because log() is undefined
    for them otherwise and a single -inf/nan would make the whole fit NaN.

    Parameters
    ----------
    obs_x : sequence of int
        Day indices of the observations.
    obs_y : sequence of numbers
        Observed values. Not modified; a float copy is used.
    ForecastDays : int
        The forecast is evaluated on x = 0 .. max(obs_x) + ForecastDays - 1.

    Returns
    -------
    (forecasted_y, forecasted_x, pred_y)

    Raises
    ------
    ValueError
        From ``scipy.stats.linregress`` when too few positive points remain.
    """
    obs_x = np.array(obs_x)
    # Float copy: never mutate the caller's data.
    obs_y = np.array(obs_y, dtype=float)
    # Forward-fill zeros; start at 1 so index 0 does not wrap to obs_y[-1].
    for i in range(1, len(obs_y)):
        if obs_y[i] == 0:
            obs_y[i] = obs_y[i - 1]

    positive = obs_y > 0
    slope, intercept, _, _, _ = stats.linregress(obs_x[positive], np.log(obs_y[positive]))

    pred_y = np.exp(intercept + slope * obs_x)
    forecasted_x = np.array(list(range(max(obs_x) + ForecastDays)))
    forecasted_y = np.exp(intercept + slope * forecasted_x)

    return forecasted_y, forecasted_x, pred_y
        


def get_polynomial(obs_x, obs_y, ForecastDays, deg=2):
    """Fit a polynomial of degree ``deg`` (default 2) and extrapolate it.

    Zeros after the first point are replaced by the previous value before
    fitting.

    Parameters
    ----------
    obs_x : sequence of int
        Day indices of the observations.
    obs_y : sequence of numbers
        Observed values. Not modified; a float copy is used.
    ForecastDays : int
        The forecast is evaluated on x = 0 .. max(obs_x) + ForecastDays - 1.
    deg : int, optional
        Polynomial degree passed to ``numpy.polyfit``.

    Returns
    -------
    (forecasted_y, forecasted_x, pred_y)

    Raises
    ------
    Whatever ``numpy.polyfit`` raises (e.g. on empty input). The previous
    ``except: pass`` only turned such failures into an UnboundLocalError.
    """
    obs_x = np.array(obs_x)
    # Float copy: never mutate the caller's data.
    obs_y = np.array(obs_y, dtype=float)
    # Forward-fill zeros; start at 1 so index 0 does not wrap to obs_y[-1].
    for i in range(1, len(obs_y)):
        if obs_y[i] == 0:
            obs_y[i] = obs_y[i - 1]

    poly = np.poly1d(np.polyfit(obs_x, obs_y, deg))
    pred_y = poly(obs_x)

    forecasted_x = np.array(list(range(max(obs_x) + ForecastDays)))
    forecasted_y = poly(forecasted_x)

    return forecasted_y, forecasted_x, pred_y



def fit_curve(obs_x, obs_y, model, df_sub, ForecastDays):
    """Fit the named model to ``obs_y`` and score it with R^2.

    Parameters
    ----------
    obs_x : sequence
        Ignored: the observations are re-indexed as days 0 .. len(obs_y) - 1.
    obs_y : sequence of numbers
        Observed values.
    model : str
        One of 'logistic', 'exponential', 'polynomial'.
    df_sub : object
        Unused; kept for call compatibility.
    ForecastDays : int
        Passed through to the model helper.

    Returns
    -------
    (obs_pred_r2, model, best_loc, obs_x, pred_y, forecasted_x, forecasted_y)
        ``best_loc`` is always the empty string.

    Raises
    ------
    ValueError
        If ``model`` is not a known model name.
    """
    fitters = {
        'logistic': get_logistic,
        'exponential': get_exponential,
        'polynomial': get_polynomial,
    }
    if model not in fitters:
        raise ValueError('Unknown model ' + repr(model) + '; expected one of ' + str(sorted(fitters)))

    obs_y = np.array(obs_y)
    obs_x = np.arange(len(obs_y))
    best_loc = str()

    forecasted_y, forecasted_x, pred_y = fitters[model](obs_x, obs_y, ForecastDays)
    obs_pred_r2 = obs_pred_rsquare(obs_y, pred_y)

    return obs_pred_r2, model, best_loc, obs_x, pred_y, forecasted_x, forecasted_y


In [3]:
class App_GetFits:
    """Interactive COVID-19 case, bed and PPE forecasting dashboard (ipywidgets).

    Widgets collect a curve model, a location, hospital-share and length-of-stay
    assumptions, and per-patient PPE rates. Any change redraws a six-panel
    matplotlib figure inside ``self._plot_container``. Display ``self.container``.
    """

    def __init__(self, df):
        """Build the widgets and draw the initial figure.

        Parameters
        ----------
        df : pandas.DataFrame
            Case table with 'Province/State' and 'type' columns. Columns from
            position 6 onward hold daily cumulative counts (the layout
            produced by ``from_url``).
        """
        # model: 'logistic'; 'exponential'; 'polynomial';
        # query: Any location available within the dataframe
        # refer: Any location available within the dataframe

        self._df = df
        # dropna() avoids comparing a missing (float NaN) name with strings while sorting.
        available_indicators2 = sorted(self._df['Province/State'].dropna().unique())

        self._1_dropdown = self._create_dropdown(['logistic', 'exponential', 'polynomial'], 1, label = 'Choose a model to fit:')
        self._2_dropdown = self._create_dropdown(available_indicators2, 16, label = 'Choose a location:')
        self._3_floattext = self._create_floattext(label = '% Visiting your hospital:', 
                                                   val=10, minv=0, maxv=100, boxw='33%', desw='70%')
        self._4_floattext = self._create_floattext(label = '% Admitted to your hospital:', 
                                                   val=50, minv=0, maxv=100, boxw='33%', desw='70%')
        self._5_floattext = self._create_floattext(label = '% Admitted to critical care:', 
                                                   val=20, minv=0, maxv=100, boxw='33%', desw='70%')
        self._6_floattext = self._create_floattext(label = 'LOS (non-critical care):', 
                                                   val=5, minv=1, maxv=180, boxw='33%', desw='70%')
        self._7_floattext = self._create_floattext(label = 'LOS (critical care):', 
                                                   val=10, minv=1, maxv=180, boxw='33%', desw='70%')
        self._8_floattext = self._create_floattext(label = '% of ICU on vent:',
                                                   val=30, minv=0, maxv=100, boxw='33%', desw='70%')
        self._9_toggle = self._create_toggle()
        
        self._10_floattext = self._create_floattext(label = 'GLOVE SURGICAL', 
                                                    val=5, minv=0, maxv=1000, boxw='33%', desw='70%')
        self._11_floattext = self._create_floattext(label = 'GLOVE EXAM NITRILE', 
                                                    val=5, minv=0, maxv=1000, boxw='33%', desw='70%')
        self._12_floattext = self._create_floattext(label = 'GLOVE EXAM VINYL', 
                                                    val=5, minv=0, maxv=1000, boxw='33%', desw='70%')
        self._13_floattext = self._create_floattext(label = 'MASK FACE PROC. ANTI FOG', 
                                                    val=2, minv=0, maxv=1000, boxw='33%', desw='70%')
        self._14_floattext = self._create_floattext(label = 'MASK PROC. FLUID RESISTANT', 
                                                    val=2, minv=0, maxv=1000, boxw='33%', desw='70%')
        self._15_floattext = self._create_floattext(label = 'GOWN ISOLATION XL YELLOW', 
                                                    val=2, minv=0, maxv=1000, boxw='33%', desw='70%')
        self._16_floattext = self._create_floattext(label = 'MASK SURG. ANTI FOG W/FILM', 
                                                    val=1, minv=0, maxv=1000, boxw='33%', desw='70%')
        self._17_floattext = self._create_floattext(label = 'SHIELD FACE FULL ANTI FOG', 
                                                    val=3, minv=0, maxv=1000, boxw='33%', desw='70%')
        self._18_floattext = self._create_floattext(label = 'RESP. PART. FILTER REG', 
                                                    val=1, minv=0, maxv=1000, boxw='33%', desw='70%')
        self._19_floattext = self._create_floattext(label = 'Forecast length (days)', 
                                                    val=10, minv=0, maxv=14, boxw='33%', desw='70%')
        
        self._plot_container = widgets.Output()
        
        _app_container = widgets.VBox(
            [widgets.VBox([widgets.HBox([self._9_toggle, self._1_dropdown, self._2_dropdown], 
                             layout=widgets.Layout(align_items='flex-start', flex='0 auto auto', width='100%')),
                           
                           widgets.HBox([self._3_floattext, self._4_floattext, self._5_floattext],
                             layout=widgets.Layout(align_items='flex-start', flex='0 0 auto', width='100%')),
                           
                           widgets.HBox([self._6_floattext, self._7_floattext, self._8_floattext],
                             layout=widgets.Layout(align_items='flex-start', flex='0 0 auto', width='100%')),
                          
                           widgets.HBox([self._10_floattext, self._11_floattext, self._12_floattext],
                             layout=widgets.Layout(align_items='flex-start', flex='0 0 auto', width='100%')),
                           
                           widgets.HBox([self._13_floattext, self._14_floattext, self._15_floattext],
                             layout=widgets.Layout(align_items='flex-start', flex='0 0 auto', width='100%')),
                           
                           widgets.HBox([self._16_floattext, self._17_floattext, self._18_floattext],
                             layout=widgets.Layout(align_items='flex-start', flex='0 0 auto', width='100%')),
                          
                           widgets.HBox([self._19_floattext],
                             layout=widgets.Layout(align_items='flex-start', flex='0 0 auto', width='100%'))],
                           
                           layout=widgets.Layout(display='flex', flex_flow='column', border='solid 1px', 
                                        align_items='stretch', width='100%')),
                           self._plot_container], layout=widgets.Layout(display='flex', flex_flow='column', 
                                        border='solid 2px', align_items='stretch', width='100%'))
                
        # 'flex-start', 'flex-end', 'center', 'baseline', 'stretch', 'inherit', 'initial', 'unset'
        self.container = widgets.VBox([
            widgets.HBox([
                _app_container
            ])
        ], layout=widgets.Layout(flex='1 1 auto', margin='0 auto 0 auto', max_width='1024px'))
        self._update_app()
        
        
    @classmethod
    def from_url(cls):
        """Download the case table, keep US state-level rows and build the app.

        Rows whose 'Province/State' contains a comma (county-level entries) are
        dropped, as are locations whose counts (columns 6 onward after removing
        the leading index column) sum to 1 or less.
        """
        _URL = 'https://raw.githubusercontent.com/klocey/ResourceDemand/master/notebooks/COVID-CASES-DF.txt'
        df = pd.read_csv(_URL, sep='\t')
        df = df[df['Country/Region'] == 'US']

        # na=False: a missing name must not turn the mask into NaN (which breaks '~').
        is_county = df['Province/State'].str.contains(',', na=False)
        df = df[~is_county]
        # The first column is a leftover index column from the saved file.
        df = df.drop(df.columns[0], axis=1)

        df = df[df.iloc[:, 6:].sum(axis=1) > 1]
        return cls(df)
        
        
    def _create_dropdown(self, indicators, initial_index, label):
        """Dropdown over ``indicators`` that redraws the app on change.

        ``initial_index`` is clamped to the last option so short option lists
        do not raise IndexError.
        """
        initial_index = min(initial_index, len(indicators) - 1)
        dropdown = widgets.Dropdown(options=indicators, 
                                    layout={'width': '60%'},
                                    style={'description_width': '49%'},
                                    value=indicators[initial_index],
                                   description=label)
        
        dropdown.observe(self._on_change, names=['value'])
        return dropdown
    
    def _create_floattext(self, label, val, minv, maxv, boxw, desw):
        """Bounded float input in [minv, maxv] that redraws the app on change."""
        obj = widgets.BoundedFloatText(
                    value=val,
                    min=minv,
                    max=maxv,
                    description=label,
                    disabled=False,
                    layout={'width': boxw},
                    style={'description_width': desw},
                )
        obj.observe(self._on_change, names=['value'])
        return obj
    
    def _create_toggle(self): 
        """Log-scale toggle that redraws the app on change."""
        obj = widgets.ToggleButton(
                    value=False,
                    description='log-scale',
                    disabled=False,
                    button_style='', # 'success', 'info', 'warning', 'danger' or ''
                    tooltip='Use a logarithmic y-axis',
                    icon='check' # (FontAwesome names without the `fa-` prefix)
                )
        obj.observe(self._on_change, names=['value'])
        return obj
    
    def _on_change(self, _):
        """Widget observer: redraw on any value change."""
        self._update_app()

    def _update_app(self):
        """Read every widget value and redraw the figure in the output area."""
        model = self._1_dropdown.value
        focal_loc = self._2_dropdown.value
        per_loc  = self._3_floattext.value
        per_admit = self._4_floattext.value
        per_cc = self._5_floattext.value
        LOS_nc = self._6_floattext.value
        LOS_cc = self._7_floattext.value
        per_vent = self._8_floattext.value
        log_scl = self._9_toggle.value
        
        ppe_GLOVE_SURGICAL = self._10_floattext.value
        ppe_GLOVE_EXAM_NITRILE = self._11_floattext.value
        ppe_GLOVE_GLOVE_EXAM_VINYL = self._12_floattext.value
        ppe_MASK_FACE_PROCEDURE_ANTI_FOG= self._13_floattext.value
        ppe_MASK_PROCEDURE_FLUID_RESISTANT = self._14_floattext.value
        ppe_GOWN_ISOLATION_XLARGE_YELLOW= self._15_floattext.value
        ppe_MASK_SURGICAL_ANTI_FOG_W_FILM = self._16_floattext.value
        ppe_SHIELD_FACE_FULL_ANTI_FOG = self._17_floattext.value
        ppe_RESPIRATOR_PARTICULATE_FILTER_REG = self._18_floattext.value
        ForecastDays = self._19_floattext.value
        
        self._plot_container.clear_output(wait=True)
        with self._plot_container:
            self._get_fit(model, focal_loc, per_loc, per_admit, per_cc, LOS_cc, LOS_nc, per_vent, log_scl,
                         ppe_GLOVE_SURGICAL, ppe_GLOVE_EXAM_NITRILE, ppe_GLOVE_GLOVE_EXAM_VINYL,
                         ppe_MASK_FACE_PROCEDURE_ANTI_FOG, ppe_MASK_PROCEDURE_FLUID_RESISTANT, 
                         ppe_GOWN_ISOLATION_XLARGE_YELLOW, ppe_MASK_SURGICAL_ANTI_FOG_W_FILM,
                         ppe_SHIELD_FACE_FULL_ANTI_FOG, ppe_RESPIRATOR_PARTICULATE_FILTER_REG,
                         ForecastDays)
            
            plt.show()

    @staticmethod
    def _style_legend(ax, framed, fontsize):
        """Add a legend to ``ax`` whose entries are coloured text with hidden handles."""
        leg = ax.legend(handlelength=0, handletextpad=0, fancybox=framed,
                        loc='best', frameon=framed, fontsize=fontsize)
        for line, text in zip(leg.get_lines(), leg.get_texts()):
            text.set_color(line.get_color())
        # Legend.legendHandles was renamed legend_handles in matplotlib 3.7 and
        # the old name was removed in 3.9; support both.
        handles = getattr(leg, 'legend_handles', None)
        if handles is None:
            handles = leg.legendHandles
        for item in handles:
            item.set_visible(False)

    @staticmethod
    def _thin_xticks(ax, step):
        """Hide most x tick labels of ``ax`` using the dashboard's two-pass scheme.

        Pass 1 hides every ``step``-th label. Pass 2 re-reads the labels and
        hides all but every ``step``-th of them.
        NOTE(review): whether the second read returns only still-visible labels
        depends on the matplotlib version, and step=1 can hide every label —
        confirm the intended spacing.
        """
        for label in ax.xaxis.get_ticklabels()[::step]:
            label.set_visible(False)
        remaining = ax.xaxis.get_ticklabels()
        keep = set(remaining[::step])
        for label in remaining:
            if label not in keep:
                label.set_visible(False)

    @staticmethod
    def _census(new_cases, n_days, per_loc, per_admit, per_cc, LOS_cc, LOS_nc):
        """Daily in-hospital census of non-critical and critical-care COVID patients.

        Each day, that day's admissions (``new_cases[i]`` scaled by the
        visit/admit/critical-care percentages) enter a cohort vector. Older
        cohorts are multiplied by ``1 - binom(2*LOS, 0.5).cdf(age + 1)`` and
        shifted one day older.
        NOTE(review): the survival factor compounds day after day rather than
        using the conditional survival S(k+1)/S(k), so patients leave faster than
        the binomial LOS implies — confirm this is the intended model.

        Parameters
        ----------
        new_cases : sequence of float, length >= n_days
        n_days : int
        per_loc, per_admit, per_cc : float, percentages 0-100
        LOS_cc, LOS_nc : float, lengths of stay in days

        Returns
        -------
        (total_nc, total_cc) : lists of floats of length ``n_days``
        """
        p = 0.5
        # binom needs an integer number of trials; LOS comes from a float text box,
        # and a non-integer n would make every CDF value NaN.
        rv_nc = binom(int(round(LOS_nc * 2)), p)
        rv_cc = binom(int(round(LOS_cc * 2)), p)
        days = np.arange(1, n_days + 1)
        p_nc = rv_nc.cdf(days)
        p_cc = rv_cc.cdf(days)

        # Start empty: the first day's cohort is added inside the loop. Seeding
        # index 0 beforehand counted that cohort twice.
        LOScc = np.zeros(n_days)
        LOSnc = np.zeros(n_days)

        total_nc = []
        total_cc = []
        for i in range(n_days):
            LOScc = LOScc * (1 - p_cc)
            LOSnc = LOSnc * (1 - p_nc)

            # Age every cohort by one day; index 0 is then today's admissions.
            LOScc = np.roll(LOScc, shift=1)
            LOSnc = np.roll(LOSnc, shift=1)

            LOScc[0] = new_cases[i] * (0.01 * per_cc) * (0.01 * per_admit) * (0.01 * per_loc)
            LOSnc[0] = new_cases[i] * (1 - (0.01 * per_cc)) * (0.01 * per_admit) * (0.01 * per_loc)

            total_nc.append(np.sum(LOSnc))
            total_cc.append(np.sum(LOScc))

        return total_nc, total_cc
            
    def _get_fit(self, model, focal_loc, per_loc, per_admit, per_cc, LOS_cc, LOS_nc, per_vent, log_scl,
                        ppe_GLOVE_SURGICAL, ppe_GLOVE_EXAM_NITRILE, ppe_GLOVE_GLOVE_EXAM_VINYL,
                        ppe_MASK_FACE_PROCEDURE_ANTI_FOG, ppe_MASK_PROCEDURE_FLUID_RESISTANT, 
                        ppe_GOWN_ISOLATION_XLARGE_YELLOW, ppe_MASK_SURGICAL_ANTI_FOG_W_FILM,
                        ppe_SHIELD_FACE_FULL_ANTI_FOG, ppe_RESPIRATOR_PARTICULATE_FILTER_REG,
                        ForecastDays):
        """Draw the six-panel dashboard for ``focal_loc`` on a new pyplot figure.

        Panels: (1) cumulative confirmed cases with the current fit and fits from
        1-4 days ago, (2) case forecast table, (3) census plot, (4) beds table,
        (5) PPE plot, (6) PPE table.

        ``per_*`` are percentages (0-100); ``LOS_*`` are lengths of stay in days;
        each ``ppe_*`` value is multiplied by the daily COVID census to give the
        daily PPE count; ``ForecastDays`` is the number of days to forecast.
        """
        # The fit helpers evaluate range(max(x) + ForecastDays), so the +1 makes the
        # forecast run ForecastDays days past the latest observation. The tables then
        # show the latest observed day plus the forecast days.
        ForecastDays = int(ForecastDays + 1)

        plt.figure(figsize=(11, 17))

        # ---------------- Panel 1: cumulative cases and model fits ----------------
        ax = plt.subplot2grid((6, 4), (0, 0), colspan=2, rowspan=2)

        df_sub = self._df[(self._df['type'] == 'Confirmed') &
                          (self._df['Province/State'] == focal_loc)]
        # Columns 6+ are dates. Zero-valued columns are deliberately NOT dropped:
        # dropping them left gaps that no longer matched the pd.date_range used
        # below (and could drop metadata columns, shifting the date offset).
        # Leading zeros are skipped per series instead.
        yi = list(df_sub)

        clrs =  ['mistyrose', 'pink', 'lightcoral', 'salmon', 'red']
        clrs2 = ['powderblue', 'lightskyblue', 'cornflowerblue', 'dodgerblue', 'blue']

        for i, j in enumerate([-4, -3, -2, -1, 0]):
            # j < 0 refits on data truncated |j| days ago to show how the forecast drifted.
            end = None if j == 0 else j
            DATES = yi[6:end]
            focal = df_sub.iloc[0, 6:end].values

            # The series starts at the first day with a positive count.
            y = []
            dates = []
            for date, val in zip(DATES, focal):
                if y or val > 0:
                    y.append(val)
                    dates.append(date)

            x = list(range(len(y)))
            obs_pred_r2_G, _, _, _, pred_y_G, forecasted_x_G, forecasted_y_G = fit_curve(
                x, y, model, df_sub, ForecastDays)

            if obs_pred_r2_G < 0:
                obs_pred_r2_G = 0.0

            # Counts cannot be negative: clip observations and curves at zero.
            y = np.array(y)
            y[y < 0] = 0
            pred_y_G = np.array(pred_y_G)
            pred_y_G[pred_y_G < 0] = 0
            forecasted_y_G = np.array(forecasted_y_G)
            forecasted_y_G[forecasted_y_G < 0] = 0

            first_date = pd.to_datetime(dates[0])
            latest_date = pd.to_datetime(dates[-1])
            future_date = latest_date + datetime.timedelta(days=ForecastDays - 1)
            fdates = pd.date_range(start=first_date, end=future_date).strftime('%m/%d')

            if j == 0:
                label = 'Current forecast'
            else:
                label = str(-j) + ' day old forecast'
            plt.plot(fdates, forecasted_y_G, c=clrs[i], linewidth=3, label=label)

            obs_dates = pd.date_range(start=first_date, end=latest_date).strftime('%m/%d')
            plt.plot(obs_dates, pred_y_G, c=clrs2[i], linewidth=3)
            plt.scatter(obs_dates, y, c='0.2', s=100, alpha=0.8, linewidths=0.1)

        # New daily cases implied by the current (j == 0) cumulative forecast.
        # They are rounded to 0.1 and drive the table, the census and the PPE panels.
        new_cases = np.round(np.diff(forecasted_y_G, prepend=0), 1)

        self._style_legend(ax, framed=False, fontsize=14)

        plt.xticks(rotation=35, ha='center')
        plt.xlabel('Date', fontsize=14, fontweight='bold')
        plt.ylabel('Confirmed cases', fontsize=14, fontweight='bold')
        if log_scl:
            plt.yscale('log')

        n_points = len(forecasted_x_G)
        if n_points < 10:
            step = 1
        elif n_points < 20:
            step = 4
        elif n_points < 40:
            step = 6
        else:
            step = 8
        self._thin_xticks(ax, step)

        plt.title('Model fitting, current ' + r'$r^{2}$' + ' = ' + str(np.round(obs_pred_r2_G, 2)), fontsize = 16, fontweight = 'bold')

        # ---------------- Panel 2: case forecast table ----------------
        ax = plt.subplot2grid((6, 4), (0, 2), colspan=2, rowspan=2)
        ax.axis('off')

        loc = str(focal_loc)
        if len(loc) > 12:
            loc = loc[:12] + '...'

        # All tables show only the last ForecastDays rows.
        row_labels = fdates.tolist()[-ForecastDays:]
        new_cases_tail = new_cases[-ForecastDays:]
        cumulative_tail = forecasted_y_G[-ForecastDays:]

        table_vals = []
        for cum, new in zip(cumulative_tail, new_cases_tail):
            table_vals.append([int(np.round(cum)),
                               int(np.round(new)),
                               int(np.round(new * (per_loc * 0.01))),
                               int(np.round((0.01 * per_admit) * new * (per_loc * 0.01)))])

        the_table = plt.table(cellText=table_vals,
                              colWidths=[0.26] * 4,
                              rowLabels=row_labels,
                              colLabels=['Cumulative', 'New', 'Your hospital', 'Admitted'],
                              cellLoc='center',
                              loc='upper center')
        the_table.auto_set_font_size(False)
        the_table.set_fontsize(10)
        the_table.scale(1, 1.32)
        plt.title('Forecasted cases for '+ loc, fontsize = 16, fontweight = 'bold')

        # ---------------- Panel 3: census ----------------
        ax = plt.subplot2grid((6, 4), (2, 0), colspan=2, rowspan=2)

        total_nc, total_cc = self._census(new_cases, len(fdates), per_loc, per_admit,
                                          per_cc, LOS_cc, LOS_nc)

        plt.plot(fdates, total_cc, c='Crimson', label='Critical care', linewidth=3)
        plt.plot(fdates, total_nc, c='0.3', label='Non-critical care', linewidth=3)
        plt.title('Forecasted census', fontsize = 16, fontweight = 'bold')
        if log_scl:
            plt.yscale('log')

        self._thin_xticks(ax, 8)
        self._style_legend(ax, framed=False, fontsize=14)

        plt.ylabel('COVID-19 patients', fontsize=14, fontweight='bold')
        plt.xlabel('Date', fontsize=14, fontweight='bold')

        # ---------------- Panel 4: beds table ----------------
        ax = plt.subplot2grid((6, 4), (2, 2), colspan=2, rowspan=2)
        ax.axis('off')

        row_labels = fdates.tolist()[-ForecastDays:]
        table_vals = []
        for nc, cc in zip(total_nc[-ForecastDays:], total_cc[-ForecastDays:]):
            table_vals.append([int(np.round(nc + cc)),
                               int(np.round(nc)),
                               int(np.round(cc)),
                               int(np.round(cc * (0.01 * per_vent)))])

        the_table = plt.table(cellText=table_vals,
                              colWidths=[0.255] * 4,
                              rowLabels=row_labels,
                              colLabels=['All COVID', 'Non-ICU', 'ICU', 'Vent'],
                              cellLoc='center',
                              loc='upper center')
        the_table.auto_set_font_size(False)
        the_table.set_fontsize(10)
        the_table.scale(1, 1.32)
        plt.title('Beds needed for COVID-19 cases', fontsize = 16, fontweight = 'bold')

        # ---------------- Panel 5: PPE plot ----------------
        ax = plt.subplot2grid((6, 4), (4, 0), colspan=2, rowspan=2)

        # Total COVID inpatients per day; PPE need = per-patient rate x census.
        PUI_COVID = np.array(total_nc) + np.array(total_cc)

        ppe_specs = [
            (ppe_GLOVE_SURGICAL, 'GLOVE SURGICAL', 'r'),
            (ppe_GLOVE_EXAM_NITRILE, 'GLOVE EXAM NITRILE', 'orange'),
            (ppe_GLOVE_GLOVE_EXAM_VINYL, 'GLOVE EXAM VINYL', 'goldenrod'),
            (ppe_MASK_FACE_PROCEDURE_ANTI_FOG, 'MASK FACE PROCEDURE ANTI FOG', 'limegreen'),
            (ppe_MASK_PROCEDURE_FLUID_RESISTANT, 'MASK PROCEDURE FLUID RESISTANT', 'green'),
            (ppe_GOWN_ISOLATION_XLARGE_YELLOW, 'GOWN ISOLATION XLARGE YELLOW', 'cornflowerblue'),
            (ppe_MASK_SURGICAL_ANTI_FOG_W_FILM, 'MASK SURGICAL ANTI FOG W/FILM', 'blue'),
            (ppe_SHIELD_FACE_FULL_ANTI_FOG, 'SHIELD FACE FULL ANTI FOG', 'plum'),
            (ppe_RESPIRATOR_PARTICULATE_FILTER_REG, 'RESPIRATOR PARTICULATE FILTER REG', 'darkviolet'),
        ]
        # Each entry: [daily counts (int array), name, colour]; reused by the table.
        ppe_ls = [[np.round(rate * PUI_COVID).astype('int'), name, clr]
                  for rate, name, clr in ppe_specs]

        linestyles = ['dashed', 'dotted', 'dashdot',
                      'dashed', 'dotted', 'dashdot',
                      'dotted', 'dashed', 'dashdot']

        for ppe, ls in zip(ppe_ls, linestyles):
            plt.plot(fdates, ppe[0], c=ppe[2], label=ppe[1], linewidth=2, ls=ls)

        plt.title('Forecasted PPE needs', fontsize = 16, fontweight = 'bold')

        self._thin_xticks(ax, 8)
        self._style_legend(ax, framed=True, fontsize=8)

        plt.ylabel('PPE Supplies', fontsize=14, fontweight='bold')
        plt.xlabel('Date', fontsize=14, fontweight='bold')
        if log_scl:
            plt.yscale('log')

        # ---------------- Panel 6: PPE table ----------------
        ax = plt.subplot2grid((6, 4), (4, 2), colspan=2, rowspan=2)
        ax.axis('off')

        row_labels = fdates.tolist()[-ForecastDays:]
        ppe_tail = [counts[-ForecastDays:] for counts, _, _ in ppe_ls]
        table_vals = [[counts[row] for counts in ppe_tail] for row in range(len(row_labels))]

        the_table = plt.table(cellText=table_vals,
                              colWidths=[0.15] * 9,
                              rowLabels=row_labels,
                              colLabels=None,
                              cellLoc='center',
                              loc='upper center')
        the_table.auto_set_font_size(False)
        the_table.set_fontsize(10)
        the_table.scale(1, 1.32)

        # Colour each PPE column's numbers like its line in panel 5.
        for col, (_, _, clr) in enumerate(ppe_ls):
            for row in range(len(row_labels)):
                the_table[(row, col)].get_text().set_color(clr)

        # Rotated, coloured column headers drawn as annotations instead of colLabels.
        hoffset = -0.4 #find this number from trial and error
        voffset = 1.0 #find this number from trial and error
        col_width = [0.06, 0.09, 0.09, 0.12, 0.133, 0.138, 0.128, 0.135, 0.142]

        col_labels2 =[['GLOVE SURGICAL', 'r'],
             ['GLOVE EXAM NITRILE', 'orange'],
             ['GLOVE EXAM VINYL', 'goldenrod'],
             ['MASK FACE PROC. A-FOG', 'limegreen'],
             ['MASK PROC. FLUID RES.', 'green'],
             ['GOWN ISO. XL YELLOW', 'cornflowerblue'],
             ['MASK SURG. ANTI FOG W/FILM', 'blue'],
             ['SHIELD FACE FULL ANTI FOG', 'plum'],
             ['RESP. PART. FILTER REG', 'darkviolet']]

        for k, (text, clr) in enumerate(col_labels2):
            ax.annotate('  ' + text, xy=(hoffset + k * col_width[k], voffset),
                        xycoords='axes fraction', ha='left', va='bottom',
                        rotation=-25, size=8, c=clr)

        plt.subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=1.1, hspace=1.1)
            
    


In [4]:
# Re-applies the SVG figure format already set in the first cell (redundant but harmless).
%config InlineBackend.figure_format = 'svg'

# Downloads the case table over the network (App_GetFits.from_url); constructing
# the app also draws the initial figure.
app1 = App_GetFits.from_url()

# Wrap the app's container in a 1x1 grid and an outer VBox for display.
grid = widgets.GridspecLayout(1, 1, layout=widgets.Layout(justify_content='flex-start'))
grid[0, 0] = app1.container

app_contents = [grid]
app = widgets.VBox(app_contents)

display(app)

VBox(children=(GridspecLayout(children=(VBox(children=(HBox(children=(VBox(children=(VBox(children=(HBox(child…