In [17]:
import pandas as pd
import numpy as np

from matplotlib import pyplot as plt
from matplotlib.dates import date2num, num2date
from matplotlib import dates as mdates
from matplotlib import ticker
from matplotlib.colors import ListedColormap
from matplotlib.patches import Patch

from scipy import stats as sps
from scipy.interpolate import interp1d

from IPython.display import clear_output

In [18]:
import plotly.express as px

In [19]:
#from plotly.offline import download_plotlyjs, init_notebook_mode, iplot
#import plotly.graph_objects as go
#init_notebook_mode(connected=True)

In [20]:
k = np.array([20, 40, 55, 90])

# We create an array for every possible value of Rt
R_T_MAX = 12
r_t_range = np.linspace(0, R_T_MAX, R_T_MAX*100+1)

# Gamma is 1/serial interval
# https://wwwnc.cdc.gov/eid/article/26/7/20-0282_article
# https://www.nejm.org/doi/full/10.1056/NEJMoa2001316
GAMMA = 1/7

In [21]:
def highest_density_interval(pmf, p=.9, debug=False):
    # If we pass a DataFrame, just call this recursively on the columns
    if(isinstance(pmf, pd.DataFrame)):
        return pd.DataFrame([highest_density_interval(pmf[col], p=p) for col in pmf],
                            index=pmf.columns)
    
    cumsum = np.cumsum(pmf.values)
    
    # N x N matrix of total probability mass for each low, high
    total_p = cumsum - cumsum[:, None]
    
    # Return all indices with total_p > p
    lows, highs = (total_p > p).nonzero()
    
    # Find the smallest range (highest density)
    best = (highs - lows).argmin()
    
    low = pmf.index[lows[best]]
    high = pmf.index[highs[best]]
    
    return pd.Series([low, high],
                     index=[f'Low_{p*100:.0f}',
                            f'High_{p*100:.0f}'])

In [22]:
state_key = {
'EC':'Eastern Cape',
'FS':'Free State',
'GP':'Gauteng',
'KZN':'Kwazulu Natal',
'LP':'Limpopo',
'MP':'Mpumalanga',
'NC':'Northern Cape',
'NW':'North-West',
'WC':'Western Cape'
}
state_filter = list(state_key.keys())

In [23]:
url = 'https://raw.githubusercontent.com/dsfsi/covid19za/master/data/covid19za_provincial_cumulative_timeline_confirmed.csv'
states = pd.read_csv(url,
                     parse_dates=['date'], dayfirst=True,
                     squeeze=True)
#index_col=0 .sort_index()

state_name = 'total'

states.tail()

Unnamed: 0,date,YYYYMMDD,EC,FS,GP,KZN,LP,MP,NC,NW,WC,UNKNOWN,total,source
80,2020-05-26,20200526,2864.0,206.0,3043.0,1927.0,132.0,103.0,45.0,115.0,15829.0,0.0,24264,https://twitter.com/nicd_sa/status/12653816780...
81,2020-05-27,20200527,3047.0,221.0,3167.0,2186.0,141.0,106.0,48.0,128.0,16893.0,0.0,25937,https://twitter.com/nicd_sa/status/12657387939...
82,2020-05-28,20200528,3306.0,225.0,3329.0,2349.0,144.0,111.0,51.0,134.0,17754.0,0.0,27403,https://twitter.com/COVID_19_ZA/status/1266131...
83,2020-05-29,20200529,3583.0,231.0,3583.0,2428.0,170.0,112.0,52.0,143.0,18906.0,32.0,29240,https://twitter.com/nicd_sa/status/12664583320...
84,2020-05-30,20200530,3759.0,261.0,3773.0,2476.0,173.0,113.0,57.0,162.0,20160.0,33.0,30967,https://twitter.com/nicd_sa/status/12667977983...


In [41]:
districts_gp = {
'date':'date',
'Ekurhuleni\tCases':'Ekurhuleni',
'Johannesburg\tCases':'Johannesburg',
'Sedibeng\tCases':'Sedibeng',
'Tshwane\tCases':'Tshwane',
'West Rand\tCases':'West Rand',
'West Rand\tCases':'West Rand',
'GP Unallocated\tCases':'Unknown'
}
districts = districts_gp.keys()

In [42]:
#districts = ['date','CT','CW','CK','GR','OB','WC','UNKNOWN']

In [43]:
## Debugging dataset
file_name = 'provincial_' + 'gp' + '_cumulative.csv'  
url = 'https://raw.githubusercontent.com/dsfsi/covid19za/master/data/district_data/' + file_name
states_district = pd.read_csv(url,
                     parse_dates=['date'], dayfirst=True,
                     squeeze=True).sort_index()

states_district_filter = states_district[districts]
col_tol = states_district_filter.sum(axis=1, numeric_only=True)
pd.options.mode.chained_assignment = None
states_district_filter['Total'] = col_tol

#states = states_filter_wp
states_district_filter.tail()

Unnamed: 0,date,Ekurhuleni\tCases,Johannesburg\tCases,Sedibeng\tCases,Tshwane\tCases,West Rand\tCases,GP Unallocated\tCases,Total
45,2020-05-26,642.0,1467.0,64.0,412.0,286.0,172.0,3043.0
46,2020-05-27,667.0,1496.0,68.0,420.0,290.0,226.0,3167.0
47,2020-05-28,693.0,1556.0,75.0,436.0,306.0,263.0,3329.0
48,2020-05-29,732.0,1667.0,92.0,457.0,312.0,323.0,3583.0
49,2020-05-30,770.0,1762.0,93.0,474.0,323.0,351.0,3773.0


In [26]:
## for total analysis
state_filter.insert(0,'date')
#state_filter.append('total')
#state_filter

In [27]:
state_plot = states[state_filter]
state_plot

Unnamed: 0,date,EC,FS,GP,KZN,LP,MP,NC,NW,WC
0,2020-03-05,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0
1,2020-03-07,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0
2,2020-03-08,0.0,0.0,1.0,2.0,0.0,0.0,0.0,0.0,0.0
3,2020-03-09,0.0,0.0,1.0,6.0,0.0,0.0,0.0,0.0,0.0
4,2020-03-11,0.0,0.0,5.0,7.0,0.0,0.0,0.0,0.0,1.0
...,...,...,...,...,...,...,...,...,...,...
80,2020-05-26,2864.0,206.0,3043.0,1927.0,132.0,103.0,45.0,115.0,15829.0
81,2020-05-27,3047.0,221.0,3167.0,2186.0,141.0,106.0,48.0,128.0,16893.0
82,2020-05-28,3306.0,225.0,3329.0,2349.0,144.0,111.0,51.0,134.0,17754.0
83,2020-05-29,3583.0,231.0,3583.0,2428.0,170.0,112.0,52.0,143.0,18906.0


In [28]:
state_plotly = state_plot.melt(id_vars='date')
state_plotly

Unnamed: 0,date,variable,value
0,2020-03-05,EC,0.0
1,2020-03-07,EC,0.0
2,2020-03-08,EC,0.0
3,2020-03-09,EC,0.0
4,2020-03-11,EC,0.0
...,...,...,...
760,2020-05-26,WC,15829.0
761,2020-05-27,WC,16893.0
762,2020-05-28,WC,17754.0
763,2020-05-29,WC,18906.0


In [29]:
px.bar(state_plotly, x='date', y='value', color='variable')

In [30]:
states_daily = states['total'].diff()
states_daily['date'] = states['date']
states_daily

0                                                     NaN
1                                                       1
2                                                       1
3                                                       4
4                                                       6
                              ...                        
81                                                   1673
82                                                   1466
83                                                   1837
84                                                   1727
date    0    2020-03-05
1    2020-03-07
2    2020-03-0...
Name: total, Length: 86, dtype: object

In [31]:
px.bar(states_daily, x='total', y='total')

In [32]:
states['total'].diff().iplot(kind='bar', title='Daily Cases per Province', xTitle='Dates', yTitle='Cases', colors=colorscale)

AttributeError: 'Series' object has no attribute 'iplot'

In [None]:
states['WC'].diff().iplot(kind='bar', title='Daily Cases in Western Cape', xTitle='Dates', yTitle='Cases', colors=colorscale)

In [None]:
url = 'https://raw.githubusercontent.com/dsfsi/covid19za/master/data/covid19za_provincial_cumulative_timeline_deaths.csv'
states_all_deaths = pd.read_csv(url,
                     parse_dates=['date'], dayfirst=True,
                     squeeze=True,index_col=0).sort_index()
states_all_deaths.tail()

In [None]:
state_deaths = states_all_deaths[state_filter]
state_deaths.tail()

In [None]:
state_deaths.iplot(kind='bar', barmode='stack')

In [None]:
url = 'https://raw.githubusercontent.com/dsfsi/covid19za/master/data/covid19za_provincial_cumulative_timeline_recoveries.csv'
states_all_recover = pd.read_csv(url,
                     parse_dates=['date'], dayfirst=True,
                     squeeze=True,index_col=0).sort_index()
states_all_recover.tail()

In [None]:
states_recover = states_all_recover[state_filter]
states_recover.tail()

In [None]:
states_recover.iplot(kind='bar', barmode='stack')

In [None]:
states_series = pd.Series(states['total'].values, index=states.index.values, name='Cases')
states_series

In [None]:
deaths_series = pd.Series(states_all_deaths['total'].values, index=states_all_deaths.index, name='Deaths')
recover_series = pd.Series(states_all_recover['total'].values, index=states_all_recover.index, name='Recovered')

In [None]:
states_combine = pd.concat([states_series, recover_series, deaths_series], axis=1)
states_combine

In [None]:
states_master = states_combine.ffill(axis=0)

In [None]:
states_master.iplot()

In [None]:
states_changed = states_master[['Recovered','Deaths']].sum(axis=1)

In [None]:
active_all = states_master['Cases'].sub(states_changed)
active_all

In [None]:
states_master['Active'] = active_all

In [None]:
states_master.iplot(title='Combined Stats', xTitle='Dates', yTitle='Cases')

In [None]:
url = 'https://raw.githubusercontent.com/dsfsi/covid19za/master/data/covid19za_provincial_cumulative_timeline_testing.csv'
states_all_tests = pd.read_csv(url,
                     parse_dates=['date'], dayfirst=True,
                     squeeze=True,index_col=0).sort_index()
states_all_tests

In [None]:
states_tests = states_all_tests[state_filter]

In [None]:
states_tests.iplot(kind='bar', barmode='stack')

In [None]:
url = 'https://raw.githubusercontent.com/dsfsi/covid19za/master/data/covid19za_provincial_cumulative_timeline_confirmed.csv'
states = pd.read_csv(url,
                     parse_dates=['date'], dayfirst=True,
                     squeeze=True).sort_index()

state_name = 'total'

states

In [None]:
def prepare_cases(cases, cutoff=25):
    new_cases = cases.diff()

    smoothed = new_cases.rolling(7,
        win_type='gaussian',
        min_periods=1,
        center=True).mean(std=2).round()
    
    idx_start = np.searchsorted(smoothed, cutoff)
    
    smoothed = smoothed.iloc[idx_start:]
    original = new_cases.loc[smoothed.index]
    
    return original, smoothed

cases = pd.Series(states[state_name].values, index=states['date'])

original, smoothed = prepare_cases(cases, cutoff=25)

original.plot(title=f"{state_name} New Cases per Day",
               c='k',
               linestyle=':',
               alpha=.5,
               label='Actual',
               legend=True,
             figsize=(500/72, 300/72))

ax = smoothed.plot(label='Smoothed',
                   legend=True)

ax.get_figure().set_facecolor('w')

In [None]:
single_cases_df = pd.concat([original,smoothed], axis=1)
single_cases_df = single_cases_df.rename(columns={0:'Actual'})
single_cases_df = single_cases_df.rename(columns={1:'Smoothed'})

Layout = {'plot_bgcolor':'#fff',
          'xaxis':{
              'color':'#000',
              'title':{
                      'font':{
                          'color':'#000'
                      }
              }
          }
         }
fig = single_cases_df.iplot(asFigure = True, xTitle = 'Date', yTitle = 'Daily change in confirmed cases', title = f"{state_name} New Cases per Day", layout=Layout)
fig.show()

In [None]:
def get_posteriors(sr, sigma=0.15):

    # (1) Calculate Lambda
    lam = sr[:-1].values * np.exp(GAMMA * (r_t_range[:, None] - 1))

    
    # (2) Calculate each day's likelihood
    likelihoods = pd.DataFrame(
        data = sps.poisson.pmf(sr[1:].values, lam),
        index = r_t_range,
        columns = sr.index[1:])
    
    # (3) Create the Gaussian Matrix
    process_matrix = sps.norm(loc=r_t_range,
                              scale=sigma
                             ).pdf(r_t_range[:, None]) 

    # (3a) Normalize all rows to sum to 1
    process_matrix /= process_matrix.sum(axis=0)
    
    # (4) Calculate the initial prior
    #prior0 = sps.gamma(a=4).pdf(r_t_range)
    prior0 = np.ones_like(r_t_range)/len(r_t_range)
    prior0 /= prior0.sum()

    # Create a DataFrame that will hold our posteriors for each day
    # Insert our prior as the first posterior.
    posteriors = pd.DataFrame(
        index=r_t_range,
        columns=sr.index,
        data={sr.index[0]: prior0}
    )
    
    # We said we'd keep track of the sum of the log of the probability
    # of the data for maximum likelihood calculation.
    log_likelihood = 0.0

    # (5) Iteratively apply Bayes' rule
    for previous_day, current_day in zip(sr.index[:-1], sr.index[1:]):

        #(5a) Calculate the new prior
        current_prior = process_matrix @ posteriors[previous_day]
        
        #(5b) Calculate the numerator of Bayes' Rule: P(k|R_t)P(R_t)
        numerator = likelihoods[current_day] * current_prior
        
        #(5c) Calcluate the denominator of Bayes' Rule P(k)
        denominator = np.sum(numerator)
        
        # Execute full Bayes' Rule
        posteriors[current_day] = numerator/denominator
        
        # Add to the running sum of log likelihoods
        log_likelihood += np.log(denominator)
    
    return posteriors, log_likelihood

# Note that we're fixing sigma to a value just for the example
posteriors, log_likelihood = get_posteriors(smoothed, sigma=.25)

In [None]:
# Note that this takes a while to execute - it's not the most efficient algorithm
hdis = highest_density_interval(posteriors, p=.9)

most_likely = posteriors.idxmax().rename('ML')

# Look into why you shift -1
result = pd.concat([most_likely, hdis], axis=1)

# US: Since we now use a uniform prior, the first datapoint is pretty bogus, so just truncating it here
# ZA: rename to single_result to add to final province plots again
single_result = result.drop(result.index[0])
single_result.tail()

In [None]:
def plot_rt(result, ax, state_name):
    
    ax.set_title(f"{state_name}")
    
    # Colors
    ABOVE = [1,0,0]
    MIDDLE = [1,1,1]
    BELOW = [0,0,0]
    cmap = ListedColormap(np.r_[
        np.linspace(BELOW,MIDDLE,25),
        np.linspace(MIDDLE,ABOVE,25)
    ])
    color_mapped = lambda y: np.clip(y, .5, 1.5)-.5
    
    index = result['ML'].index.get_level_values('date')
    values = result['ML'].values
    
    # Plot dots and line
    ax.plot(index, values, c='k', zorder=1, alpha=.25)
    ax.scatter(index,
               values,
               s=40,
               lw=.5,
               c=cmap(color_mapped(values)),
               edgecolors='k', zorder=2)
    
    # Aesthetically, extrapolate credible interval by 1 day either side
    lowfn = interp1d(date2num(index),
                     result['Low_90'].values,
                     bounds_error=False,
                     fill_value='extrapolate')
    
    highfn = interp1d(date2num(index),
                      result['High_90'].values,
                      bounds_error=False,
                      fill_value='extrapolate')
    
    extended = pd.date_range(start=pd.Timestamp('2020-03-01'),
                             end=index[-1]+pd.Timedelta(days=1))
    
    ax.fill_between(extended,
                    lowfn(date2num(extended)),
                    highfn(date2num(extended)),
                    color='k',
                    alpha=.1,
                    lw=0,
                    zorder=3)

    ax.axhline(1.0, c='k', lw=1, label='$R_t=1.0$', alpha=.25);
    
    # Formatting
    ax.xaxis.set_major_locator(mdates.MonthLocator())
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%b'))
    ax.xaxis.set_minor_locator(mdates.DayLocator())
    
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_formatter(ticker.StrMethodFormatter("{x:.1f}"))
    ax.yaxis.tick_right()
    ax.spines['left'].set_visible(False)
    ax.spines['bottom'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.margins(0)
    ax.grid(which='major', axis='y', c='k', alpha=.1, zorder=-2)
    ax.margins(0)
    ax.set_ylim(0.0, 5.0)
    ax.set_xlim(pd.Timestamp('2020-03-06'), result.index.get_level_values('date')[-1]+pd.Timedelta(days=1))
    #fig.set_facecolor('w')
    
    return ax

    
fig, ax = plt.subplots(figsize=(600/72,400/72))

plot_rt(single_result, ax, state_name)
ax.set_title(f'Real-time $R_t$ for {state_name}')
ax.xaxis.set_major_locator(mdates.WeekdayLocator())
ax.xaxis.set_major_formatter(mdates.DateFormatter('%b %d'))

In [None]:
sigmas = np.linspace(1/20, 1, 20)

# ZA: only consider the official 9 provinces
states_to_process = list(states.columns.values[2:11])
# ZA: do not think the total RSA sigma needs to be included to find max later
# states_to_process.append('Total RSA') 

results = {}

for state_name in states_to_process:
    
    print(state_name)
    
    # --> ZA prepare data
    # ZA: Rt is very small for some provinces
    cases = pd.Series(states[state_name].values,index=states['date'])
    new, smoothed = prepare_cases(cases, cutoff=10)
    cut = 10
    
    # Rt for ZA is very small for some provinces
    # set threshold for smoothed data length at 3 to ensure posteriors can be calculated
    if len(smoothed) < 3:
        new, smoothed = prepare_cases(cases, cutoff=5)
        cut = 5
        if len(smoothed) < 3:
            new, smoothed = prepare_cases(cases, cutoff=2)
            cut = 2
            
            ## ignore Rt further for slow growth provinces
            if len(smoothed) < 3:
                print('BREAK')
                clear_output(wait=True)
                continue
            
    print(cut)
    ## <-- ZA prepare data
    
    result = {}
    
    # Holds all posteriors with every given value of sigma
    result['posteriors'] = []
    
    # Holds the log likelihood across all k for each value of sigma
    result['log_likelihoods'] = []
    
    for sigma in sigmas:
        posteriors, log_likelihood = get_posteriors(smoothed, sigma=sigma)
        result['posteriors'].append(posteriors)
        result['log_likelihoods'].append(log_likelihood)
    
    # Store all results keyed off of state name
    results[state_name] = result
    clear_output(wait=True)

print('Done.')

In [None]:
# Each index of this array holds the total of the log likelihoods for
# the corresponding index of the sigmas array.
total_log_likelihoods = np.zeros_like(sigmas)

# Loop through each state's results and add the log likelihoods to the running total.
for state_name, result in results.items():
    total_log_likelihoods += result['log_likelihoods']

# Select the index with the largest log likelihood total
max_likelihood_index = total_log_likelihoods.argmax()

# Select the value that has the highest log likelihood
sigma = sigmas[max_likelihood_index]


In [None]:
final_results = None

for state_name, result in results.items():
    print(state_name)
    posteriors = result['posteriors'][max_likelihood_index]
    hdis_90 = highest_density_interval(posteriors, p=.9)
    hdis_50 = highest_density_interval(posteriors, p=.5)
    most_likely = posteriors.idxmax().rename('ML')
    result = pd.concat([most_likely, hdis_90, hdis_50], axis=1)
    
    # ZA: add province index
    result.index = pd.MultiIndex.from_product([[state_name], result.index], names=['state','date'])
     
    if final_results is None:
        final_results = result
    else:
        final_results = pd.concat([final_results, result])
    clear_output(wait=True)

print('Done.')

In [None]:
# US: This can be moved before the plots
# Since we now use a uniform prior, the first datapoint is pretty bogus, so just truncating it here
final_results = final_results.groupby('state').apply(lambda x: x.iloc[1:].droplevel(0))

# Rt Data Import

In [None]:
#dsfsi
url = 'https://raw.githubusercontent.com/dsfsi/covid19za/master/data/calc/calculated_rt_sa_provincial_cumulative.csv'
states_raw = pd.read_csv(url,
                     parse_dates=['date'], dayfirst=True,
                     squeeze=True, index_col=[0,1])

In [None]:
state_single = states_raw.filter(like='Total RSA', axis=0)
state_single.tail()

In [None]:
last = state_single.groupby(level=0).last().iloc[0]['ML']
last

In [None]:
latestdate = state_single.index.get_level_values('date')[-1]
latestdate.strftime("%d %B %Y")

In [None]:
fig = state_single.iplot(asFigure = True)
fig.show()

In [None]:
states_raw.groupby('state').last()['ML']

In [None]:
len(states_raw.groupby('state'))

In [None]:
states_filter = states_raw.loc[['EC','FS','GP','KZN','LP','MP','NC','NW','WC']]

In [None]:
def all_plot(final_results):
    state_groups = final_results.groupby('state')
    
    ncols = 3
    nrows = int(np.ceil(len(state_groups) / ncols))

    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(15, nrows*3))

    for i, (state_name, result) in enumerate(state_groups):
        axes.flat[i] = plot_rt(result, axes.flat[i], state_name)

    fig.tight_layout()
    fig.set_facecolor('w')
    
    fig.suptitle(f'COVID-19 Confirmed Cases: Real-time $R_t$ for South African Provinces', size=14)
    fig.subplots_adjust(top=0.92)
    
all_plot(states_filter)

In [None]:
fig = states_raw.iplot(asFigure = True, subplots=(4,2), subplot_titles=True)
fig.show()

In [None]:
for i, (state_name, result) in enumerate(states_raw.groupby('state')):
    print(result.tail())

In [None]:
state = 'gp'
url = 'https://raw.githubusercontent.com/dsfsi/covid19za/master/data/calc/calculated_rt_' + state + '_district_cumulative.csv'
districts_raw = pd.read_csv(url,
                     parse_dates=['date'], dayfirst=True,
                     squeeze=True, index_col=[0,1])
all_plot(districts_raw)

### Plot All South African Provinces

In [None]:
ncols = 3
nrows = int(np.ceil(len(results) / ncols))

fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(15, nrows*3))

for i, (state_name, result) in enumerate(final_results.groupby('state')):
    plot_rt(result, axes.flat[i], state_name)
    
fig.tight_layout()
fig.suptitle(f'Real-time $R_t$ for South African Provinces', fontsize=14)
fig.subplots_adjust(top=0.92)
fig.set_facecolor('w')

### Standings

In [None]:
# ZA: South Arica lockdown level data as of 2020/05/03
no_lockdown = [

]
partial_lockdown = [

]
# add items as required -> 'Western Cape','WC',

FULL_COLOR = [.7,.7,.7]
NONE_COLOR = [179/255,35/255,14/255]
PARTIAL_COLOR = [.5,.5,.5]
ERROR_BAR_COLOR = [.3,.3,.3]

In [None]:
# ZA: df slighty different to US
mr = final_results.groupby(level=0)[['ML', 'High_90', 'Low_90']].last()
mr

In [None]:
def plot_standings(mr, figsize=None, title='Most Recent $R_t$ by Province'):
    if not figsize:
        figsize = ((15.9/50)*len(mr)+.1,2.5)
        
    fig, ax = plt.subplots(figsize=figsize)

    ax.set_title(title)
    err = mr[['Low_90', 'High_90']].sub(mr['ML'], axis=0).abs()
    bars = ax.bar(mr.index,
                  mr['ML'],
                  width=.825,
                  color=FULL_COLOR,
                  ecolor=ERROR_BAR_COLOR,
                  capsize=2,
                  error_kw={'alpha':.5, 'lw':1},
                  yerr=err.values.T)

    for bar, state_name in zip(bars, mr.index):
        if state_name in no_lockdown:
            bar.set_color(NONE_COLOR)
        if state_name in partial_lockdown:
            bar.set_color(PARTIAL_COLOR)

    labels = mr.index.to_series().replace({'District of Columbia':'DC'})
    ax.set_xticklabels(labels, rotation=90, fontsize=11)
    ax.margins(0)
    ax.set_ylim(0,2.)
    ax.axhline(1.0, linestyle=':', color='k', lw=1)

    leg = ax.legend(handles=[
                        Patch(label='Full', color=FULL_COLOR),
                        Patch(label='Partial', color=PARTIAL_COLOR),
                        Patch(label='None', color=NONE_COLOR)
                    ],
                    title='Lockdown',
                    ncol=3,
                    loc='upper left',
                    columnspacing=.75,
                    handletextpad=.5,
                    handlelength=1)

    leg._legend_box.align = "left"
    fig.set_facecolor('w')
    return fig, ax

mr.sort_values('ML', inplace=True)
plot_standings(mr);

In [None]:
mr.sort_values('High_90', inplace=True)
plot_standings(mr);

In [None]:
show = mr[mr.High_90.le(1)].sort_values('ML')
fig, ax = plot_standings(show, title='Likely Under Control');

In [None]:
show = mr[mr.Low_90.ge(1.0)].sort_values('Low_90')
fig, ax = plot_standings(show, title='Likely Not Under Control');
ax.get_legend().remove()