# Bar chart animation for COVID-19 statistics worldwide

## Load tools and data

This script makes use of the following libraries:
 - pandas and numpy for the data analysis; (currently tested with pandas-v1.1.4 and numpy-v1.19.2)
 - matplotlib for plotting; (currently tested with matplotlib-v3.3.2)
 - matplotlib.animation and IPython.display for producing animations
 - iso3166 for standardized geographic data of countries

In [None]:
import os, sys
import pandas as pd
import numpy as np
import seaborn as sns
from iso3166 import Country
import flag
import itertools
import matplotlib
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import matplotlib.animation as animation
from IPython.display import HTML
import tqdm

# Raise the size cap for animations embedded as inline HTML/JS (used by the to_jshtml() calls below).
matplotlib.rcParams['animation.embed_limit'] = 2**10
# Draw markers (e.g. the stem-plot markers) filled rather than hollow.
matplotlib.rcParams['markers.fillstyle'] = 'full'

In [None]:
datacvs = 'owid-covid-data.csv'

# Read the OWID CSV file into a pandas DataFrame, keeping only the columns
# used in this notebook; column 2 ('date') is parsed to datetimes.
df = pd.read_csv(
    datacvs,
    usecols=['continent', 'location', 'date',
             'new_cases_smoothed_per_million', 'total_deaths', 'new_deaths',
             'total_deaths_per_million', 'new_deaths_per_million',
             'new_deaths_smoothed_per_million', 'reproduction_rate',
             'new_tests', 'new_tests_per_thousand', 'tests_per_case'],
    parse_dates=[2],
)

# Minimal set of columns needed for the animation.
dfshort = df[['continent', 'location', 'date',
              'total_deaths_per_million', 'new_deaths_smoothed_per_million']]

# Snapshot of a single day, ranked by total deaths per million.
df_now = (
    df[df["date"] == '20220106']
    .sort_values(by="total_deaths_per_million", ascending=False)
    .reset_index(drop=True)
)

# Rows without a continent are aggregates (World, continents, income groups, ...);
# every location that appears in such a row is left out of the ranking.
nocountry_mask = pd.isna(df.continent.values)
exclude_locs = set(df.loc[nocountry_mask, 'location'].values)
locations = set(df["location"]) - exclude_locs
loclist = sorted(locations)

# Concatenate several qualitative colormaps and hand out colours cyclically
# in alphabetical order of location.
cols = (plt.cm.Pastel1.colors + plt.cm.Pastel2.colors + plt.cm.Paired.colors
        + plt.cm.Accent.colors + plt.cm.Set1.colors + plt.cm.Set2.colors
        + plt.cm.Dark2.colors + plt.cm.tab20c.colors + plt.cm.Set3.colors
        + plt.cm.tab20.colors + plt.cm.tab20b.colors)
color_dict = {loc: cols[i % len(cols)] for i, loc in enumerate(loclist)}

## Transitions

Whenever one bar overtakes another, the single-frame transition is difficult for the eye to follow.
The following functions define a smoother transition by storing the current positions and velocities of all bars and updating the next frame based on the target positions.
As a simple motion model, the velocity of a bar undergoing a transition is kept constant, so that the bar reaches its target position within a predefined amount of time. A possible improvement is to decelerate smoothly during the final few frames.
Additional features may include: 
 - a visual indicator for the up or down motion of a bar (currently change in color)
 - a z-index giving depth to the overlapping bars (foreground for ascending bars and background for descending ones); magnifying the bars by $1+e^{-(y-y_0)^2}$ was tried but did not look good

In [None]:
# Transition functions

def initData(initFrame=None, N=len(loclist)):
    '''Initialise the global bar state Y_NOW, V_NOW and RANK_NOW.

    All three arrays are ordered like ``loclist``.

    Input:
      - initFrame : index into ``datelist``. If None, bars are placed at
                    N, N-1, ..., 1 (in loclist order); otherwise every bar starts
                    at its 1-based rank on that date.
      - N         : number of bars used when initFrame is None
                    (defaults to len(loclist)).
    Output (globals):
      - Y_NOW    : current bar coordinates
      - V_NOW    : current bar velocities (all zero, same size as Y_NOW)
      - RANK_NOW : last target ranks (a copy of Y_NOW)
    '''
    global Y_NOW, V_NOW, RANK_NOW
    if initFrame is None:
        Y_NOW = N - 1.0*np.arange(N)
    else:
        sdf = df_interp.groupby('date').get_group(datelist[initFrame])
        # Drop aggregate rows (World, continents, income groups, ...), which have
        # no continent, so the ranks agree with those of the draw functions.
        sdf = sdf[~pd.isna(sdf.continent.values)]
        sdf = sdf.sort_values(by='total_deaths_per_million', ascending=False).reset_index(drop=True)
        # 1-based ranks, matching the targets used by update_data / draw_barchart*;
        # 0-based ranks would make every bar move on the first frame.
        target_dict = dict(zip(sdf['location'].to_numpy(), np.arange(len(sdf)) + 1))

        # Read off target locations in loclist order
        Y_NOW = np.array([1.0*target_dict[k] for k in loclist])

    # Velocities must match Y_NOW element-wise, whichever branch built it.
    V_NOW = np.zeros_like(Y_NOW, dtype=float)
    RANK_NOW = Y_NOW.copy()

    return

def V(delta_y, sec_per_transition=0.7):
    '''Return the vertical bar velocity for a remaining distance ``delta_y``.

    The velocity is proportional to the distance, with gain
    2/sec_per_transition, so a bar keeping this velocity covers ``delta_y`` in
    sec_per_transition/2 seconds. Works element-wise on numpy arrays.
    '''
    # TODO: needed in vertical index units per second.
    gain = 2./sec_per_transition
    return gain*delta_y
        
        

def update_positions(y_target, v_mask):
    ''' Advance all bars by one frame (time step 1/fps) toward their targets.

        All arrays are ordered like loclist.
     Input:
      - y_target :   array with target coordinate values
      - v_mask   :   array with 1 where the velocity is to be recomputed from V()
                     and 0 where the current velocity is kept
     Output (in place):
      - global Y_NOW    :   array with current coordinates [y units]
      - global V_NOW    :   array with current velocities [y units/sec]
    '''
    global Y_NOW, V_NOW

    step = 1.0/fps
    remaining = y_target - Y_NOW

    # Blend in the new velocity only where the mask is set.
    V_NOW += v_mask*(V(remaining) - V_NOW)
    displacement = V_NOW*step
    Y_NOW += displacement

    # Bars whose step was longer than the distance left have overshot:
    # put them exactly on target and stop them.
    overshot = np.abs(displacement) > np.abs(remaining)
    Y_NOW[overshot] = y_target[overshot]
    V_NOW[overshot] = 0.0
    
    
def data_interp(datelist, df):
    '''Interpolate every location's series onto a new list of timestamps.

    Input:
      - datelist : pandas DatetimeIndex of target timestamps
      - df       : DataFrame with columns 'continent', 'location', 'date',
                   'total_deaths_per_million' and 'new_deaths_smoothed_per_million'
    Output:
      - DataFrame with those five columns and len(datelist) rows per location,
        ordered by location. Values are linearly interpolated in time with
        np.interp and held at the first/last sample outside a location's date
        range. The index restarts at 0 for every location (it is not unique).
    '''
    columns = ['continent', 'location', 'date',
               'total_deaths_per_million', 'new_deaths_smoothed_per_million']
    n = len(datelist)
    # The target timestamps are identical for every location: convert them once.
    t_new = np.array([k.timestamp() for k in datelist.to_list()])

    frames = []
    # iterate through countries (groupby yields them in sorted order)
    for loc, dg in df.groupby('location'):
        # np.interp requires increasing sample points; do not rely on input order.
        dg = dg.sort_values(by='date')
        t_old = np.array([k.timestamp() for k in dg['date'].to_list()])
        frames.append(pd.DataFrame({
            'continent': [dg['continent'].values[0]]*n,
            'location': [loc]*n,
            'date': datelist,
            'total_deaths_per_million': np.interp(t_new, t_old, dg['total_deaths_per_million'].to_numpy()),
            'new_deaths_smoothed_per_million': np.interp(t_new, t_old, dg['new_deaths_smoothed_per_million'].to_numpy()),
        }))

    if not frames:
        return pd.DataFrame(columns=columns)
    # One concat at the end instead of one per location (which is quadratic).
    return pd.concat(frames)

## Set up animation

In [None]:
duration = 200.0                # animation duration in seconds
sec_per_transition = 1.0        # UNUSED how many seconds does a transition last (V() has its own default of 0.7)
fps = 24                      # frames per second (also sets the time step 1/fps in update_positions)
nframes = int(fps*duration)        # total number of frames
fpt = fps/sec_per_transition  # UNUSED frames per transition

# Country whose bar is followed and highlighted by the draw functions; give it a fixed colour.
focus_country = Country.gr
color_dict[focus_country.english_short_name] = (0.2,0.4,0.7)

startdate='03/24/2020'        # start animation from date
enddate='03/20/2022'          # end animation at date

# create array of timestamps by dividing date range with uniform steps
# (one timestamp per animation frame)
datelist = pd.date_range(start=startdate, end=enddate, periods=nframes)

# interpolate data for datelist
df_interp = data_interp(datelist, dfshort)
# Per-date grouping reused by the per-frame draw functions.
df_interp_bydate = df_interp.groupby('date')


## Draw Barchart

The basic idea for the barchart follows the example by @jburnmurdoch [link].
On top of that, several additions have been made, especially for making the animation smoother (interpolation), adapting to COVID-19 specific data and a few additional features (transitions, focus location, aesthetics, etc.).

In [None]:
def focus_color_shift(focus_idx):
    '''Return an RGBA tint for the focus bar based on its current velocity.

    Input:
      - focus_idx : index of the focus location in loclist (i.e. into V_NOW), or None
    Output:
      - (1-k, 1, 1-k, 0.9) when V_NOW[focus_idx] > 0,
        (1, 1-k, 1-k, 0.9) when V_NOW[focus_idx] < 0, with k = min(|v|/10, 1);
      - None when the bar is at rest or focus_idx is None.
    '''
    # Index 0 is a valid location, so test for None explicitly, not for truthiness.
    if focus_idx is None:
        return None

    v = V_NOW[focus_idx]
    # Tint strength saturates at |v| >= 10.
    k = min(np.abs(v)/10.0, 1.0)
    if v > 0:
        return (1-k, 1.0, 1-k, 0.9)
    if v < 0:
        return (1.0, 1-k, 1-k, 0.9)
    return None

# --- WIP! ---
def update_data(date):
    '''Advance the bar state by one frame toward the ranking on ``date``.

    Input:
      - date : entry of datelist (a key of df_interp_bydate)
    Output:
      - y_dict : {location: current bar coordinate} for every location in loclist
    Side effects: updates the globals Y_NOW, V_NOW (via update_positions) and RANK_NOW.
    '''
    # The following global variables as well as Y_target and update_mask are all sorted as loclist
    global Y_NOW, V_NOW, RANK_NOW

    # Select data on date
    dff = df_interp_bydate.get_group(date)
    # Remove aggregate entries (continents, income groups, ...): they have no
    # continent and would otherwise shift the ranks of real countries.
    dff = dff[~pd.isna(dff.continent.values)]
    # sort by column of interest
    dff = dff.sort_values(by='total_deaths_per_million', ascending=False).reset_index(drop=True)

    # Dictionary of 1-based target ranks over locations
    target_dict = dict(zip(dff['location'].to_numpy(), np.arange(len(dff)) + 1))

    # Read off target ranks in loclist order; only bars whose target rank
    # changed since the previous frame get a new velocity.
    Y_target = np.array([1.0*target_dict[k] for k in loclist])
    update_mask = (RANK_NOW != Y_target).astype(float)

    update_positions(Y_target, update_mask)
    RANK_NOW = Y_target.copy()

    # current positions keyed by location
    return dict(zip(loclist, Y_NOW))


# --- WIP! ---
def init_barchart(date, y_dict):
    '''Create one horizontal bar per ranked country for ``date``.

    Bars are placed at the coordinates in ``y_dict`` ({location: coordinate}),
    coloured with color_dict and tick-labelled with their 0-based rank on ``date``.
    Returns the BarContainer created on the global axis ``ax``.
    '''
    # Countries only (aggregates have no continent), ranked by total deaths per million.
    dfs = df_interp_bydate.get_group(date)
    dfs = dfs[~pd.isna(dfs.continent.values)]
    dfs = dfs.sort_values(by='total_deaths_per_million', ascending=False).reset_index(drop=True)

    locs = dfs['location'].to_numpy()
    # Positions and colours follow the rows of dfs so they line up with the widths.
    y_pos = [y_dict[k] for k in locs]
    bar_colors = [color_dict[k] for k in locs]
    tick_labels = [str(k) for k in dfs.index.to_numpy()]
    bars = ax.barh(y_pos, dfs['total_deaths_per_million'], color=bar_colors, alpha=0.9, tick_label=tick_labels)

    return bars


# --- WIP! ---
def redraw_barchart(date, y_dict, focus=focus_country.english_short_name, nplot=11, magnify=False):
    ''' Draw the bar chart for ``date`` from precomputed bar coordinates.

    Unlike draw_barchart this does not advance the bar state; the coordinates
    come from ``y_dict`` (e.g. the return value of update_data(date)).
    In:
        - date    : date entry in datelist
        - y_dict  : {location: current bar coordinate}
        - focus   : Focus country to follow; None or an unknown name plots the top of the ranking
        - nplot   : Number of positions to plot from top or around focus country
                    (10 extra rows are drawn so bars can slide in and out)
        - magnify : if True, scale bar heights by their distance to the focus bar
    Out:
        - None; draws on the global axis ``ax``
    '''
    # Rank the countries on this date; aggregates (rows without a continent) are dropped.
    dff = df_interp_bydate.get_group(date)
    dff = dff[~pd.isna(dff.continent.values)]
    dff = dff.sort_values(by='total_deaths_per_million', ascending=False).reset_index(drop=True)

    # Select range of indices to plot
    nplotpp = nplot + 10

    # if no focus country is given, plot top N
    if focus not in locations:
        focus = None
        print("No focus!")
        imin = 0
    else:
        focus_index = dff[dff['location'] == focus].index.to_numpy()[0]
        imin = np.max([focus_index - (nplotpp//2), 0])
    # Never run past the last ranked country.
    imaxpp = np.min([imin + nplotpp, len(dff)])
    dfs = dff.iloc[np.arange(imin, imaxpp)]

    # Plot horizontal bars: first laid out at nbars..1 (this fixes the autoscaled
    # y-limits), then moved to their animated coordinates below.
    ax.clear()
    nbars = len(dfs)
    bar_colors = [color_dict[k] for k in dfs['location'].to_numpy()]
    tick_labels = [str(k) for k in dfs.index.to_numpy()]
    bars = ax.barh(nbars - np.arange(nbars), dfs['total_deaths_per_million'], color=bar_colors, alpha=0.9, tick_label=tick_labels)

    fcol = focus_color_shift(loclist.index(focus)) if focus else None
    dx = dfs['total_deaths_per_million'].max() / 200

    y_pos = np.array([y_dict[k] for k in dfs['location'].to_numpy()])

    # update positions for moving bars
    foc_y = None
    for i, (bar, location) in enumerate(zip(bars, dfs['location'])):
        bar.set_y(imaxpp - y_pos[i] - bar.get_height()/2)

        width = bar.get_width()

        fw = 'light'
        if location == focus:
            foc_y = bar.get_y()
            fw = 'bold'
            if fcol:
                # Blend the bar colour with the velocity tint.
                bar.set_fc(0.7*np.array(bar.get_fc()) + 0.3*np.array(fcol))

        ax.text(width + dx, bar.get_y() + bar.get_height() / 2, location, size=14, weight=fw, ha='left', va='center')
        ax.annotate(f'{width:.0F}',
                    xy = (width , bar.get_y() + bar.get_height() / 2),
                    xytext = (-25, 0),
                    textcoords = "offset points",
                    fontsize = 'x-large',
                    fontweight = fw,
                    ha = 'right',
                    va = 'center')

    if focus and magnify:
        for bar in bars:
            magfactor = (0.8 + 0.4*np.exp(-0.5/0.5*(bar.get_y()-foc_y)**2))
            bar.set_height(bar.get_height()*magfactor)
            bar.set_y(bar.get_y() - (magfactor-1)*bar.get_height()/2.0)

    ax.text(1, 0.2, date.strftime("%d %B, %Y"), transform=ax.transAxes, color='#AAAAAA', size=24, ha='right', weight=800)
    ax.text(0, 1.06, 'Deaths per million (total)', transform=ax.transAxes, size=12, color='#AAAAAA')

    ax.xaxis.set_major_formatter(ticker.StrMethodFormatter('{x:,.0f}'))
    ax.xaxis.set_ticks_position('top')
    ax.tick_params(axis='x', color='#777777', labelsize=12)
    ax.tick_params(axis='y', color='#777777', labelsize=16)

    ax.grid(True, axis = 'x')
    if focus:
        ax.set_xlim(0, 2.0*dff.iloc[focus_index]['total_deaths_per_million'])
        dy = y_dict[focus] - focus_index
    else:
        # Settled bars sit one unit below their initial barh slot
        # (1-based rank vs 0-based slot), so shift the view by one.
        dy = 1.0
    ax.set_ylim(ax.get_ylim()[0]-dy, ax.get_ylim()[1]-dy)

    ax.margins(0, 0.01)
    ax.grid(which='major', axis='x', linestyle='-')

    ax.set_axisbelow(True)

    ax.text(0, 1.15, 'World ranking in COVID-19 deaths per million population', transform=ax.transAxes, size=24, weight=600, ha='left', va='top')
    ax.text(1, 0, '@magathos', transform=ax.transAxes, color='#777777', ha='right', va='bottom', bbox=dict(facecolor='white', alpha=0.8, edgecolor='white'))
    ax.set_frame_on(False)

    return

    # -- WIP --
def update_plot(date, focus=focus_country.english_short_name, nplot=None):
    '''Advance the bar state to ``date`` and redraw the chart.

    Meant as an animation callable: update_data moves the bars (it already calls
    update_positions with the proper arguments) and redraw_barchart draws them
    from the returned coordinates. ``nplot=None`` keeps redraw_barchart's default.
    '''
    y_dict = update_data(date)
    if nplot is None:
        redraw_barchart(date, y_dict, focus=focus)
    else:
        redraw_barchart(date, y_dict, focus=focus, nplot=nplot)
    return

def draw_barchart(date, focus=focus_country.english_short_name, nplot=11, magnify=False):
    ''' Advance the bar positions by one frame and draw the bar chart.
    In:
        - date    : date entry in datelist
        - focus   : Focus country to follow; None or an unknown name plots the top nplot
        - nplot   : Number of positions to plot from top or around focus country
        - magnify : if True, scale bar heights by their distance to the focus bar
    Out:
        - None; draws on the global axis ``ax`` and updates the globals
          Y_NOW, V_NOW and RANK_NOW
    '''

    # The following global variables as well as Y_target and update_mask are all sorted as loclist
    global Y_NOW, V_NOW, RANK_NOW

    # Select data on date and remove aggregate entries (continents, income
    # groups, ...), i.e. the rows without a continent.
    dff = df_interp_bydate.get_group(date)
    dff = dff[~pd.isna(dff.continent.values)]
    # sort by column of interest
    dff = dff.sort_values(by='total_deaths_per_million', ascending=False).reset_index(drop=True)

    # 1-based target ranks, read off in loclist order; only bars whose rank
    # changed since the previous frame get a new velocity.
    target_dict = dict(zip(dff['location'].to_numpy(), np.arange(len(dff)) + 1))
    Y_target = np.array([1.0*target_dict[k] for k in loclist])
    update_mask = (RANK_NOW != Y_target).astype(float)

    update_positions(Y_target, update_mask)
    RANK_NOW = 1.0*Y_target.copy()

    # current positions keyed by location
    y_dict = dict(zip(loclist, Y_NOW))

    # Select range of ranks to plot: top nplot, or nplot around the focus country
    if focus not in locations:
        focus = None
        print("No focus!")
        imin = 0
    else:
        focus_index = dff[dff['location'] == focus].index.to_numpy()[0]
        imin = np.max([focus_index - (nplot//2), 0])
    # Never run past the last ranked country.
    imaxpp = np.min([imin + nplot, len(dff)])
    dfs = dff.iloc[np.arange(imin, imaxpp)]

    # Plot horizontal bars: first laid out at nbars..1 (this fixes the autoscaled
    # y-limits), then moved to their animated coordinates below. nbars can be
    # smaller than nplot at the end of the ranking.
    ax.clear()
    nbars = len(dfs)
    bar_colors = [color_dict[k] for k in dfs['location'].to_numpy()]
    tick_labels = [str(k) for k in dfs.index.to_numpy()]
    bars = ax.barh(nbars - np.arange(nbars), dfs['total_deaths_per_million'], color=bar_colors, alpha=0.9, tick_label=tick_labels)

    fcol = focus_color_shift(loclist.index(focus)) if focus else None
    dx = dfs['total_deaths_per_million'].max() / 200

    y_pos = np.array([y_dict[k] for k in dfs['location'].to_numpy()])

    # update positions for moving bars
    foc_y = None
    for i, (bar, location) in enumerate(zip(bars, dfs['location'])):
        bar.set_y(imaxpp - y_pos[i] - bar.get_height()/2)

        width = bar.get_width()

        fw = 'light'
        fc = 'white'
        if location == focus:
            foc_y = bar.get_y()
            fw = 'bold'
            # Tint the focus labels according to the focus bar's velocity.
            if fcol:
                fc = fcol

        ax.text(width + 3*dx, bar.get_y() + bar.get_height() / 2, location, size=14, weight=fw, color=fc, ha='left', va='center')
        ax.annotate(f'{width:.0F}',
                    xy = (width , bar.get_y() + bar.get_height() / 2),
                    xytext = (-25, 0),
                    textcoords = "offset points",
                    fontsize = 'x-large',
                    fontweight = fw,
                    color = fc,
                    ha = 'right',
                    va = 'center')

    if focus and magnify:
        for bar in bars:
            magfactor = (0.8 + 0.4*np.exp(-0.5/0.5*(bar.get_y()-foc_y)**2))
            bar.set_height(bar.get_height()*magfactor)
            bar.set_y(bar.get_y() - (magfactor-1)*bar.get_height()/2.0)

    ax.text(1, 0.2, date.strftime("%d %B, %Y"), transform=ax.transAxes, color='#AAAAAA', size=24, ha='right', weight=800)
    ax.text(0, 1.06, 'Deaths per million (total)', transform=ax.transAxes, size=12, color='#AAAAAA')

    ax.xaxis.set_major_formatter(ticker.StrMethodFormatter('{x:,.0f}'))
    ax.xaxis.set_ticks_position('top')
    ax.tick_params(axis='x', color='#777777', labelsize=12)
    ax.tick_params(axis='y', color='#777777', labelsize=16, width=3, length=5)

    ax.grid(True, axis = 'x')
    if focus:
        ax.set_xlim(0, 2.0*dff.iloc[focus_index]['total_deaths_per_million'])
        # Follow the (possibly moving) focus bar smoothly.
        dy = y_dict[focus] - focus_index
    else:
        # Settled bars sit one unit below their initial barh slot
        # (1-based rank vs 0-based slot), so shift the view by one.
        dy = 1.0
    ax.set_ylim(ax.get_ylim()[0]-dy, ax.get_ylim()[1]-dy)

    ax.margins(0, 0.01)
    ax.grid(which='major', axis='x', linestyle='-')

    ax.set_axisbelow(True)

    ax.text(0, 1.15, 'World ranking in COVID-19 deaths per million population', transform=ax.transAxes, size=24, weight=600, ha='left', va='top')
    ax.text(1, 0, '@magathos', transform=ax.transAxes, color='#777777', ha='right', va='bottom', bbox=dict(facecolor='white', alpha=0.8, edgecolor='white'))
    ax.set_frame_on(False)

    return

def draw_barchart_padded(date, focus=focus_country.english_short_name, nplot=11, magnify=False):
    ''' Advance the bar positions by one frame and draw the bar chart, with
    extra bars drawn outside the visible window so bars can slide in and out.
    In:
        - date    : date entry in datelist
        - focus   : Focus country to follow; None or an unknown name plots the top nplot
        - nplot   : Number of positions shown from top or around focus country
        - magnify : if True, scale bar heights by their distance to the focus bar
    Out:
        - None; draws on the global axis ``ax`` and updates the globals
          Y_NOW, V_NOW and RANK_NOW
    '''

    # The following global variables as well as Y_target and update_mask are all sorted as loclist
    global Y_NOW, V_NOW, RANK_NOW

    # Select data on date and remove aggregate entries (continents, income
    # groups, ...), i.e. the rows without a continent.
    dff = df_interp_bydate.get_group(date)
    dff = dff[~pd.isna(dff.continent.values)]
    # sort by column of interest
    dff = dff.sort_values(by='total_deaths_per_million', ascending=False).reset_index(drop=True)

    # 1-based target ranks, read off in loclist order; only bars whose rank
    # changed since the previous frame get a new velocity.
    target_dict = dict(zip(dff['location'].to_numpy(), np.arange(len(dff)) + 1))
    Y_target = np.array([1.0*target_dict[k] for k in loclist])
    update_mask = (RANK_NOW != Y_target).astype(float)

    update_positions(Y_target, update_mask)
    RANK_NOW = 1.0*Y_target.copy()

    # current positions keyed by location
    y_dict = dict(zip(loclist, Y_NOW))

    # Select range of ranks to plot. Around a focus country, 10 extra rows are
    # drawn and then cropped by the y-limits; without a focus, plot the top nplot.
    if focus not in locations:
        focus = None
        print("No focus!")
        nplotpp = nplot
        imin = 0
    else:
        nplotpp = nplot + 10
        focus_index = dff[dff['location'] == focus].index.to_numpy()[0]
        imin = np.max([focus_index - (nplotpp//2), 0])
    # Never run past the last ranked country.
    imaxpp = np.min([imin + nplotpp, len(dff)])
    dfs = dff.iloc[np.arange(imin, imaxpp)]

    # Plot horizontal bars: first laid out at nbars..1 (this fixes the autoscaled
    # y-limits), then moved to their animated coordinates below.
    ax.clear()
    nbars = len(dfs)
    bar_colors = [color_dict[k] for k in dfs['location'].to_numpy()]
    tick_labels = [str(k) for k in dfs.index.to_numpy()]
    bars = ax.barh(nbars - np.arange(nbars), dfs['total_deaths_per_million'], color=bar_colors, alpha=0.9, tick_label=tick_labels)

    fcol = focus_color_shift(loclist.index(focus)) if focus else None
    dx = dfs['total_deaths_per_million'].max() / 200

    y_pos = np.array([y_dict[k] for k in dfs['location'].to_numpy()])

    # Visible y-range. With a focus, follow the focus bar and crop the padding
    # symmetrically; without one, settled bars sit one unit below their initial
    # barh slot (1-based rank vs 0-based slot) and there is no padding.
    ylim0, ylim1 = ax.get_ylim()
    if focus:
        dy = y_dict[focus] - focus_index
        crop = (nplotpp - nplot)/2
    else:
        dy = 1.0
        crop = 0
    bottom = ylim0 - dy + crop
    top = ylim1 - dy - crop

    # update positions for moving bars
    foc_y = None
    for i, (bar, location) in enumerate(zip(bars, dfs['location'])):
        bar.set_y(imaxpp - y_pos[i] - bar.get_height()/2)

        width = bar.get_width()

        fw = 'light'
        fc = 'white'
        if location == focus:
            foc_y = bar.get_y()
            fw = 'bold'
            # Tint the focus labels according to the focus bar's velocity.
            if fcol:
                fc = fcol
        # Only label bars inside the visible window.
        if bottom < bar.get_y() < top:
            ax.text(width + 3*dx, bar.get_y() + bar.get_height() / 2, location, size=14, weight=fw, color=fc, ha='left', va='center')
            ax.annotate(f'{width:.0F}',
                        xy = (width , bar.get_y() + bar.get_height() / 2),
                        xytext = (-25, 0),
                        textcoords = "offset points",
                        fontsize = 'x-large',
                        fontweight = fw,
                        color = fc,
                        ha = 'right',
                        va = 'center')

    if focus and magnify:
        for bar in bars:
            magfactor = (0.8 + 0.4*np.exp(-0.5/0.5*(bar.get_y()-foc_y)**2))
            bar.set_height(bar.get_height()*magfactor)
            bar.set_y(bar.get_y() - (magfactor-1)*bar.get_height()/2.0)

    ax.text(1, 0.2, date.strftime("%d %B, %Y"), transform=ax.transAxes, color='#AAAAAA', size=24, ha='right', weight=800)
    ax.text(0, 1.06, 'Deaths per million (total)', transform=ax.transAxes, size=12, color='#AAAAAA')

    ax.xaxis.set_major_formatter(ticker.StrMethodFormatter('{x:,.0f}'))
    ax.xaxis.set_ticks_position('top')
    ax.tick_params(axis='x', color='#777777', labelsize=12)
    ax.tick_params(axis='y', color='#777777', labelsize=16, width=3, length=5)

    ax.grid(True, axis = 'x')

    if focus:
        ax.set_xlim(0, 2.0*dff.iloc[focus_index]['total_deaths_per_million'])
    ax.set_ylim(bottom, top)

    ax.margins(0, 0.01)
    ax.grid(which='major', axis='x', linestyle='-')

    ax.set_axisbelow(True)

    ax.text(0, 1.15, 'World ranking in COVID-19 deaths per million population', transform=ax.transAxes, size=24, weight=600, ha='left', va='top')
    ax.text(1, 0, '@magathos', transform=ax.transAxes, color='#777777', ha='right', va='bottom', bbox=dict(facecolor='white', alpha=0.8, edgecolor='white'))
    ax.set_frame_on(False)

    return

In [None]:
#plt.style.available
# Preview a single frame: initialise the bar state one frame earlier, then draw frame `frameno`.
# The style must be set before the figure is created.
plt.style.use('dark_background')
N_plot = 11
frameno = 1
initData(initFrame=(frameno-1))
fig = plt.figure(figsize=(14, (N_plot+3)/2.))
# `ax` is the global axis the draw_barchart* functions draw on.
ax = fig.add_subplot(111)

draw_barchart(datelist[frameno], focus=focus_country.english_short_name, nplot=N_plot) 


## Animate

To animate, we will use [`FuncAnimation`][FuncAnimation] from `matplotlib.animation`.

[`FuncAnimation`][FuncAnimation] makes an animation by repeatedly calling a function (that draws on canvas). 
In our case, it is `draw_barchart_padded` (single bar chart) or `draw_joint_plot` (bar chart plus stem plot).

The `frames` argument lists the values on which the drawing function is called; here these are the dates in `datelist`.

Run the cell below.

[FuncAnimation]: https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.animation.FuncAnimation.html

In [None]:
# Single plot animator
# Reset the bar state to the first frame before animating.
initData(0)
N_plot=11 # USE update_plot() wrapper as animation callable (note: not passed below; draw_barchart_padded uses nplot=11)
plt.style.use('dark_background')

# Rebinds the global `ax` used by draw_barchart_padded.
fig_single, ax = plt.subplots(figsize=(14,10))
# draw_barchart_padded is called once per date; interval = 1000/fps milliseconds per frame.
animator = animation.FuncAnimation(fig_single, draw_barchart_padded, frames=datelist[:], interval=1000./fps)
html_animator = HTML(animator.to_jshtml())

In [None]:
# Joint plot animator

def initStemplot(date_idx):
    '''Create the per-continent bar + stem plot on the global axis ``ax_new``.

    For every location of ``mycontinent`` on ``datelist[date_idx]`` (ordered by
    location name), bars on ``ax_new`` show new_deaths_smoothed_per_million and
    stems on a twin y-axis show total_deaths_per_million.

    Returns (bbars, stems) so that draw_stemgraph can update them in place.
    '''
    # Sort by location: draw_stemgraph sorts the same way and updates these
    # artists positionally, so both must use the same order.
    x = (df_interp_bydate.get_group(datelist[date_idx])
         .groupby('continent').get_group(mycontinent)
         .sort_values(by='location')
         .reset_index(drop=True))
    ax_stems = ax_new.twinx()

    bbars = ax_new.bar(np.arange(len(x)), list(x['new_deaths_smoothed_per_million']), color='g', alpha=0.6)
    # use_line_collection=False keeps stemlines as a list of Line2D, which
    # draw_stemgraph updates one by one.
    stems = ax_stems.stem(x['location'].replace({'Bosnia and Herzegovina': 'Bosnia & Herzegovina'}), x['total_deaths_per_million'], basefmt='None', use_line_collection=False)

    ax_new.set_xlabel(mycontinent, fontsize=16)
    ax_new.xaxis.set_label_coords(0.5,-.75)
    ax_new.set_ylim(0,20)
    ax_new.set_ylabel('Daily deaths/million', fontsize=12)
    ax_new.spines['top'].set_visible(False)
    ax_new.spines['right'].set_visible(False)

    # Colour the right-hand axis like the stems it belongs to.
    stemcolor = stems.markerline.get_color()
    ax_stems.set_ylim(auto=True)
    ax_stems.spines['top'].set_visible(False)
    ax_stems.spines['bottom'].set_visible(False)
    ax_stems.spines['left'].set_visible(False)
    ax_stems.spines['right'].set_bounds((0, ax_stems.get_ylim()[1]))
    ax_stems.spines['right'].set_position(('outward', -20))
    ax_stems.tick_params(axis='y', color=stemcolor, labelcolor=stemcolor)
    ax_stems.spines['right'].set_color(stemcolor)
    ax_stems.set_ylabel('Total', color=stemcolor, fontsize=10)

    ax_new.spines['left'].set_bounds((0, ax_new.get_ylim()[1]))
    ax_new.spines['left'].set_position(('outward', -20))
    ax_new.spines['bottom'].set_bounds((0, len(x)-1))
    ax_new.spines['bottom'].set_position(('outward', 10))

    plt.setp(ax_new.get_xticklabels(), rotation=60, ha="right", rotation_mode="anchor")

    return bbars, stems

def draw_stemgraph(date, continent='Europe'):
    '''Update the bars and stems created by initStemplot with the data of ``date``.

    Bars get new_deaths_smoothed_per_million; stem lines and their markers get
    total_deaths_per_million. Locations of ``continent`` are taken in
    alphabetical order, which must match the order used to create the artists.

    Returns (stem lines, marker line, bar container) for FuncAnimation.
    '''
    dff = df_interp_bydate.get_group(date)
    dfcont = dff.groupby('continent').get_group(continent).sort_values(by='location').reset_index(drop=True)
    new_deaths = dfcont['new_deaths_smoothed_per_million'].to_numpy()
    total_deaths = dfcont['total_deaths_per_million'].to_numpy()

    for line, total in zip(stems.stemlines, total_deaths):
        line.set_ydata((0, total))
    for b, height in zip(bbars, new_deaths):
        b.set(height=height)
    # The markers sit on top of the stems, i.e. at the total values.
    stems.markerline.set_ydata(total_deaths)

    return stems.stemlines, stems.markerline, bbars

def draw_joint_plot(date):
    '''Draw one frame of the joint animation.

    The bar chart is redrawn on the global ``ax`` (draw_barchart_padded) and the
    stem plot on ``ax_new`` is updated for ``mycontinent``, the same continent
    initStemplot used to create it.
    Returns the updated stem-plot artists (stem lines, marker line, bar patches).
    '''
    draw_barchart_padded(date)
    lines, markers, bars = draw_stemgraph(date, continent=mycontinent)

    return lines + [markers] + bars.patches


In [None]:
    
initData(0)
N_plot=11 # USE update_plot() wrapper as animation callable
plt.style.use('dark_background')

# Bar chart on top (`ax`, used by draw_barchart_padded), stem plot below (`ax_new`).
fig, (ax, ax_new) = plt.subplots(nrows=2, figsize=(14,12), gridspec_kw={'height_ratios':[4, 1]})

mycontinent='Europe'
bbars, stems = initStemplot(0)

# blit=False: every frame clears and redraws the whole bar-chart axis (including
# its limits), which blitting only the returned stem artists cannot capture.
animator_joint = animation.FuncAnimation(fig, draw_joint_plot, frames=datelist, blit=False, interval=1000./fps)
# html_animator_joint = HTML(animator_joint.to_jshtml())

In [None]:
# Render the single-plot animation (`animator`) as an HTML5 video element.
html5_video = animator.to_html5_video()

In [None]:
# Output files in the current working directory (f_gif is not used in the cells shown).
f_anim = os.path.join(os.getcwd(), 'animation.mp4')
f_gif = os.path.join(os.getcwd(), 'animation.gif')
# ffmpeg writer at the animation frame rate.
writervideo = animation.FFMpegWriter(fps=fps) 

In [None]:
# Write the joint (bar chart + stem plot) animation to animation.mp4 with the ffmpeg writer.
animator_joint.save(f_anim, writer=writervideo)

# Stem graph for all of Europe and more

In [None]:
def draw_stemgraph(date, continent='Europe'):
    '''Update the bars and stems created by initStemplot with the data of ``date``.

    Bars get new_deaths_smoothed_per_million; stem lines and their markers get
    total_deaths_per_million. Locations of ``continent`` are taken in
    alphabetical order, which must match the order used to create the artists.

    Returns (stem lines, marker line, bar container) for FuncAnimation.
    '''
    dff = df_interp_bydate.get_group(date)
    dfcont = dff.groupby('continent').get_group(continent).sort_values(by='location').reset_index(drop=True)
    new_deaths = dfcont['new_deaths_smoothed_per_million'].to_numpy()
    total_deaths = dfcont['total_deaths_per_million'].to_numpy()

    for line, total in zip(stems.stemlines, total_deaths):
        line.set_ydata((0, total))
    for b, height in zip(bbars, new_deaths):
        b.set(height=height)
    # The markers sit on top of the stems, i.e. at the total values.
    stems.markerline.set_ydata(total_deaths)

    return stems.stemlines, stems.markerline, bbars


## Set up stem plot

In [None]:
mycontinent='Europe'
ax_new.clear()

# Locations of `mycontinent` on datelist[1000], sorted by name: draw_stemgraph
# updates the bars/stems positionally in that order, so create them in it too.
x = (df_interp_bydate.get_group(datelist[1000])
     .groupby('continent').get_group(mycontinent)
     .sort_values(by='location')
     .reset_index(drop=True))
ax_stems = ax_new.twinx()
ax_stems.clear()

# Bars: daily deaths (left axis); stems: total deaths (twin right axis).
# use_line_collection=False keeps stemlines as a list of Line2D for draw_stemgraph.
bbars = ax_new.bar(np.arange(len(x)), list(x['new_deaths_smoothed_per_million']), color='g', alpha=0.6)
stems = ax_stems.stem(x['location'].replace({'Bosnia and Herzegovina': 'Bosnia & Herzegovina'}), x['total_deaths_per_million'], basefmt='None', use_line_collection=False)

ax_new.set_xlabel(mycontinent, fontsize=16)
ax_new.set_ylim(0,20)
ax_new.set_ylabel('Daily deaths/million', fontsize=12)
ax_new.spines['top'].set_visible(False)
ax_new.spines['right'].set_visible(False)

# Colour the right-hand axis like the stems it belongs to.
stemcolor = stems.markerline.get_color()
ax_stems.set_ylim(auto=True)
ax_stems.spines['top'].set_visible(False)
ax_stems.spines['bottom'].set_visible(False)
ax_stems.spines['left'].set_visible(False)
ax_stems.spines['right'].set_bounds((0, ax_stems.get_ylim()[1]))
ax_stems.spines['right'].set_position(('outward', -20))
ax_stems.tick_params(axis='y', color=stemcolor, labelcolor=stemcolor)
ax_stems.spines['right'].set_color(stemcolor)
ax_stems.set_ylabel('Total', color=stemcolor, fontsize=10)

ax_new.spines['left'].set_bounds((0, ax_new.get_ylim()[1]))
ax_new.spines['left'].set_position(('outward', -20))
ax_new.spines['bottom'].set_bounds((0, len(x)-1))
ax_new.spines['bottom'].set_position(('outward', 10))

ppp = plt.setp(ax_new.get_xticklabels(), rotation=60, ha="right", rotation_mode="anchor")
fig

In [None]:
# Animate only the stem plot (first 100 dates). The stem axes `ax_new` belong to
# the joint figure `fig`, so that is the figure to animate.
animate_stem = animation.FuncAnimation(fig, draw_stemgraph, frames=datelist[:100], interval=1000./fps)
html_animstem = HTML(animate_stem.to_jshtml())
# or use animator.to_html5_video() or animator.save()
html_animstem

In [None]:
# ax_all.hlines(df1eu['location'], 0, df1eu['total_deaths_per_million'])
# change the style of the axis spines
# my_range=list(range(1,len(df1eu.index)+1))
# ax_all.spines['left'].set_bounds((1, len(my_range)))
# ax_all.set_xlim(0,25)
# add some space between the axis and the plot

## List of references
- twin y-scales in single plot https://matplotlib.org/3.5.1/gallery/subplots_axes_and_figures/two_scales.html
- beautify bar/stem plot https://scentellegher.github.io/visualization/2018/10/10/beautiful-bar-plots-matplotlib.html
- gradient bars https://matplotlib.org/stable/gallery/lines_bars_and_markers/gradient_bar.html