In [1]:
# Intended to be `%run lib.ipynb` from another notebook to initialize the database `db` and common options.
# Optionally set `INTERACTIVE_PLOTS = True` before to get `%matplotlib notebook` behavior (the default is `inline`).

import csv
import datetime
import math
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

from IPython.display import display, Markdown, Latex
from collections import defaultdict
from operator import itemgetter
from scipy.optimize import curve_fit
from scipy.special import exp10

# Ideally, this would use `%matplotlib notebook` as that makes the charts interactive in the notebook,
# but it also means the notebook must be rendered in the browser as it depends on JavaScript execution.
# That makes it much more difficult to automate data updates. This setting also requires that
# `plt.tight_layout()` is used instead of `plt.show()`, and the plot sizes need to be different...
#
# The default is `%matplotlib inline`, which makes just simple plots, but is friendly to CLI rendering.

# Switch to the interactive backend only if the including notebook set `INTERACTIVE_PLOTS = True`.
try:
    if INTERACTIVE_PLOTS:
        %matplotlib notebook
except NameError:
    # `INTERACTIVE_PLOTS` was never defined by the caller: fall back to non-interactive plots.
    INTERACTIVE_PLOTS = False

# Use 80 DPI both for displayed and saved figures.
mpl.rcParams['savefig.dpi'] = 80
mpl.rcParams['figure.dpi'] = 80
# Many figures are created per run; 0 turns off matplotlib's "too many open figures" warning.
mpl.rcParams['figure.max_open_warning'] = 0


# Returns a (figure, axes) pair with a size that works well. For `%matplotlib notebook`, `(12, 8)` is a good option.
def get_plot():
    """Create a new figure and return the `(figure, axes)` pair from `plt.subplots`.

    The figure is 12x8 inches when `INTERACTIVE_PLOTS` is set and 14x9 inches otherwise.
    """
    if INTERACTIVE_PLOTS:
        size = (12, 8)
    else:
        size = (14, 9)
    return plt.subplots(figsize=size)


def show_plot(plt):
    """Finish the current plot on the given pyplot-like object.

    Calls `plt.tight_layout()` when `INTERACTIVE_PLOTS` is set, and `plt.show()` otherwise.
    """
    if not INTERACTIVE_PLOTS:
        plt.show()
        return
    plt.tight_layout()


# Column indices in the per-series CSV files read by `read_one_data_series`:
# `KEY_COUNTRY` is the column holding the country name, and `KEY_DATES_START` is the first
# of the per-date columns (every column from there on is one date).
KEY_COUNTRY = 1
KEY_DATES_START = 4
# Number of days to extend projections by (past the last date that has data).
DAYS_EXTEND = 7
# `plot` kwargs to use for data lines to make them stand out vs. projections.
DATALINE_KWARGS = {'marker': 'x', 'markersize': 7, 'linewidth': 3}


# Returns a `datetime.date` parsed from `date_str` in a format like
# '1/22/20', which represents January 22nd, 2020.
def make_date(date_str):
    """Parse a `month/day/year` string such as '1/22/20' into a `datetime.date`.

    The year component is offset by 2000, so '20' becomes 2020. The string must
    have exactly three '/'-separated components.
    """
    fields = date_str.split('/')
    assert len(fields) == 3
    year = 2000 + int(fields[2])
    month = int(fields[0])
    day = int(fields[1])
    return datetime.date(year, month, day)


def delta_series(lst):
    """Return the differences between consecutive elements, i.e. `lst[i] - lst[i-1]` for `i >= 1`.

    The result has one element fewer than `lst`, and is empty when `lst` has fewer than two elements.
    """
    deltas = []
    for idx in range(1, len(lst)):
        deltas.append(lst[idx] - lst[idx - 1])
    return deltas


def trailing_average_series(lst, k):
    """Return the trailing `k`-element average for every position of `lst`.

    Each output is the sum of the current element and up to `k - 1` elements before it,
    always divided by `k` (so the first `k - 1` outputs behave as if `lst` were preceded by zeros).
    """
    averages = []
    window_sum = 0
    for pos, value in enumerate(lst):
        window_sum += value
        if pos >= k:
            # The element that just fell out of the window.
            window_sum -= lst[pos - k]
        averages.append(window_sum / k)
    return averages


class Country(object):
    """Per-country data: name, population, continent and the named timeseries.

    `series_data` must contain the 'confirmed', 'recovered' and 'dead' series. The dict is
    stored as-is (not copied) and two derived series are added to it:
    'active' and 'daily dead (7-day average)'.
    """

    def __init__(self, name, population, continent, series_data):
        self.name = name
        self.population = population
        self.continent = continent
        self.timeseries = series_data

        confirmed = series_data['confirmed']
        recovered = series_data['recovered']
        dead = series_data['dead']
        # Active = confirmed minus those who recovered or died.
        series_data['active'] = [
            total - healed - died
            for total, healed, died in zip(confirmed, recovered, dead)
        ]

        # The leading 0 makes the day-over-day series as long as the original one,
        # which keeps it aligned with `db.dates`.
        daily_dead = delta_series([0] + dead)
        series_data['daily dead (7-day average)'] = trailing_average_series(daily_dead, 7)

    def get(self, series, perm=False):
        """Return a copy of the named series, scaled to per-thousand-inhabitants if `perm` is set."""
        if perm:
            factor = 1000.0 / self.population
        else:
            factor = 1.0
        values = self.timeseries[series]
        return [v * factor for v in values]

    def get_new(self, series):
        """Return new cases per day in the series, starting with the second date, up to the latest date."""
        values = self.timeseries[series]
        return delta_series(values)
            

def read_one_data_series(series):
    """Load `data/{series}.csv` and return `(dates, country_data)`.

    `dates` is the list of dates parsed from the header columns starting at `KEY_DATES_START`.
    `country_data` maps the country name (column `KEY_COUNTRY`) to a list of integer values,
    one per date. Rows that share a country name (provinces/states) are summed element-wise.
    """
    country_data = dict()
    # The `csv` module expects file objects to be opened with `newline=''`.
    with open(f'data/{series}.csv', newline='') as fin:
        reader = csv.reader(fin, delimiter=',', quotechar='"')
        header = next(reader)
        dates = [make_date(date_str) for date_str in header[KEY_DATES_START:]]
        for line in reader:
            if not line:
                # `csv.reader` yields an empty list for a blank line; there is nothing to parse.
                continue
            name = line[KEY_COUNTRY]
            vals = [int(v) for v in line[KEY_DATES_START:]]
            assert len(vals) == len(dates), (name, len(vals), len(dates))
            cur = country_data.get(name)
            if cur is None:
                country_data[name] = vals
            else:
                # Another province/state for the same country.
                # We just sum this up.
                country_data[name] = [a + b for a, b in zip(cur, vals)]
    return dates, country_data


class DB(object):
    """All the loaded data: the common list of dates and a `Country` per known country.

    Construction reads the 'confirmed', 'recovered' and 'dead' series plus the population and
    continent files under `data/`. Countries lacking population or continent data are skipped.
    """

    def __init__(self):
        self.dates = None
        # {country name -> {series name -> timeseries numbers}}
        series_by_country = defaultdict(dict)
        for series in ('confirmed', 'recovered', 'dead'):
            series_dates, series_values = read_one_data_series(series)
            if self.dates is not None:
                # We're assuming all the dates are present in all the datasets.
                assert self.dates == series_dates
            else:
                self.dates = series_dates
            if series_by_country:
                # Likewise, every dataset must cover exactly the same countries.
                assert set(series_by_country) == set(series_values)
            for country_name, values in series_values.items():
                series_by_country[country_name][series] = values

        # {country name -> population}
        # Each line is the country name (which may contain spaces) followed by the population.
        populations = dict()
        with open('data/country_population.txt') as fin:
            for raw_line in fin:
                tokens = raw_line.strip().split(' ')
                country_name = ' '.join(tokens[:-1])
                populations[country_name] = int(tokens[-1])

        # {country name -> continent}
        # Each line is `continent<TAB>country name`.
        continents = dict()
        with open('data/country_continents.txt') as fin:
            for raw_line in fin:
                fields = raw_line.strip().split('\t')
                assert len(fields) == 2, (raw_line, fields)
                continent, country_name = fields
                continents[country_name] = continent

        # {country name -> Country}
        self.countries = dict()
        for name, series_data in series_by_country.items():
            # Silently skip countries with no population or no continent data.
            if name not in populations or name not in continents:
                continue
            self.countries[name] = Country(name, populations[name], continents[name], series_data)

    def country(self, country_name):
        """Return the `Country` with the given name (raises `KeyError` if it is unknown)."""
        return self.countries[country_name]


def identity(lst):
    """Return `lst` itself, unchanged.

    This is the "linear" counterpart of `log`, for code that applies a transform to a data series.
    """
    return lst


def log(lst):
    """Return the base-10 logarithm of each element, with `None` in place of non-positive elements."""
    result = []
    for x in lst:
        if x <= 0:
            # log10 is undefined here; `None` leaves a gap in the series.
            result.append(None)
        else:
            result.append(math.log10(x))
    return result


def latex_scientific(numstr):
    r"""Convert a number string in `e` notation (e.g. '1.2e+05') to LaTeX (`1.2 \times 10^{5}`).

    Strings without an 'e' are returned unchanged.
    """
    mantissa, marker, exponent = numstr.partition('e')
    if not marker:
        return numstr
    k = mantissa
    # Going through `int` normalizes the exponent, e.g. '+05' -> 5.
    e = int(exponent)
    return fr'{k} \times 10^{{{e}}}'


def exponential(t, a, b):
    """Evaluate the exponential model `a * 10^(b * t)`."""
    growth = exp10(b * t)
    return a * growth


def display_exponential(prefix, a, b):
    """Display the exponential `a * 10^(b t)` as LaTeX, labelled with `prefix`.

    Also shows the number of days in which the value doubles, `log10(2) / b`.
    """
    doubling_days = math.log10(2) / b
    # Three significant digits; may produce `e` notation, which is then converted to LaTeX.
    latex_a = latex_scientific(f'{a:.3g}')
    formula = fr'{prefix}: \({latex_a} \times 10^{{{b:.3f}t}}\) (doubling rate \({doubling_days:.1f}\) days)'
    display(Latex(formula))


def sigmoid(t, a, b, c):
    """Evaluate the sigmoid model `a / (1 + 10^(-b * (t - c)))`.

    `a` is the value the curve approaches for large `t` (with positive `b`), and `c` is the
    `t` at which the curve reaches `a / 2`.
    """
    exponent = -b * (t - c)
    denominator = 1.0 + exp10(exponent)
    return a / denominator


def display_sigmoid(prefix, a, b, c):
    """Display the sigmoid `a / (1 + 10^(-b (t - c)))` as LaTeX, labelled with `prefix`.

    The asymptote `a` is shown separately after the formula.
    """
    # Format the already-negated coefficient, so that a negative `b` doesn't render as "--0.123".
    neg_b = -b
    display(Latex(
        fr'{prefix}: \(\dfrac{{{a:,.1f}}}{{1 + 10^{{{neg_b:.3f} (t - {c:.1f})}}}}\) (asymptote \({a:,.1f}\))'))


# Returns dates and data starting from the first day where the value is at least
# `start_val_perm` per thousand in the given country, regardless of whether the data
# needs to be returned per thousand or not (which is controlled with `perm`).
def get_start_and_align(country_name, series, start_val_perm, perm):
    """Return `(dates, data)` for the country/series, trimmed to start on the threshold day.

    The threshold day is the first day on which the series reaches `start_val_perm` per
    thousand inhabitants. `perm` only controls whether the returned data is per thousand
    (`True`) or absolute (`False`); the threshold is per thousand either way.

    Returns `([], [])`, after printing a note, if the threshold has not been reached yet.
    """
    country = db.country(country_name)
    data = country.get(series, perm=perm)
    # Express the threshold in the same units as `data`.
    start_val = start_val_perm if perm else start_val_perm * country.population / 1000.0
    for i, (_, val) in enumerate(zip(db.dates, data)):
        if val >= start_val:
            return db.dates[i:], data[i:]
    # Report the per-thousand threshold: `start_val` is an absolute number when `perm` is false.
    print(f'{country_name} doesn\'t yet have {start_val_perm:g} {series} per thousand.')
    return [], []


def get_fit_params(fn, data, **kwargs):
    """Fit `fn(t, *params)` to `data` and return the best-fit parameters.

    The data points are taken to be at `t = 1, 2, ..., len(data)`. Extra keyword arguments
    (e.g. `p0`) are passed through to `scipy.optimize.curve_fit`.
    """
    ts = np.arange(1, len(data) + 1)
    best_params, _covariance = curve_fit(fn, ts, data, **kwargs)
    return best_params


def should_fit_lines(data):
    """Return whether curve fitting should be attempted for `data`.

    That requires at least 5 data points, with the latest one being at least 10.
    """
    if len(data) < 5:
        return False
    return data[-1] >= 10


# The exponential fit is commented out as it doesn't make sense anymore and causes "artifacts"
# with multiple infection waves.
def backpredict_exp_vs_sigmoid(ax, dates_all, data_all, backpredict_days):
    """Plot best-fit sigmoids on `ax` for the full data and for the data as of a few days ago.

    For each `backd` in `0..backpredict_days`, the last `backd` data points are dropped, a
    sigmoid is fitted to the rest, and the fitted curve is drawn over all the dates plus
    `DAYS_EXTEND` extra days. The fit on all the data is a solid line; the older ones are
    dashed and get progressively thinner and more transparent.

    Returns whether at least one sigmoid was fitted and plotted.
    """
    got_any_sigmoid = False
    ts_extended = np.arange(1, len(data_all) + 1 + DAYS_EXTEND)
    dates_extended = dates_all[::] + [dates_all[-1] + datetime.timedelta(days=d) for d in range(1, DAYS_EXTEND + 1)]

    backpredict_days = min(backpredict_days, len(data_all) - 2)
    for backd in range(0, backpredict_days + 1):
        data = data_all if backd == 0 else data_all[:-backd]
        if not should_fit_lines(data):
            # Shorter prefixes can't be fitted either, so stop here.
            break
        # aexp, bexp = get_fit_params(exponential, data, p0=(1, 0.1))
        try:
            sigmoid_params = get_fit_params(sigmoid, data, p0=(data[-1], 0.2, 21))
        except Exception:
            # Fitting is best-effort: no curve is drawn for this prefix if it fails.
            # (Not a bare `except`, so that KeyboardInterrupt/SystemExit still propagate.)
            sigmoid_params = None

        # Scale from fully opaque (latest fit) down towards `min_alpha` (oldest fit).
        min_alpha = 0.3
        alpha = (backpredict_days + 1 - backd) / (backpredict_days + 1) * (1.0 - min_alpha) + min_alpha
        lw = alpha * 2.0
        style = '-' if backd == 0 else '--'
        # ax.plot(
        #     dates_extended,
        #     exponential(ts_extended, aexp, bexp),
        #     'C1',
        #     linestyle=style,
        #     linewidth=lw,
        #     alpha=alpha,
        #     label='Best fit exponential' if backd == 0 else None)
        if sigmoid_params is not None:
            asig, bsig, csig = sigmoid_params
            ax.plot(
                dates_extended,
                sigmoid(ts_extended, asig, bsig, csig),
                'C2',
                linestyle=style,
                linewidth=lw,
                alpha=alpha,
                # Only the first plotted sigmoid gets a legend entry.
                label='Best fit sigmoid' if not got_any_sigmoid else None)
            got_any_sigmoid = True
    return got_any_sigmoid
    
    
# `start_info` must be a pair where the 0th element is the number of cases per thousand where the
# analysis starts, and the 1st element is a textual description of that condition.
# `backpredict_days` is the number of the most recent days for which exponential and sigmoid fits are
# plotted, in addition to the latest day.
#
# The exponential fit is commented out as it doesn't make sense anymore and causes "artifacts"
# with multiple infection waves.
def analyze_country(country, series, start_info, backpredict_days=0, fit_lines=True):
    """Display the analysis of one series for one country.

    Shows a short textual summary, a chart of the series from the start day on (optionally
    with sigmoid fits), and a bar chart of the day-over-day changes.

    `country` is the country name. `start_info` is a `(per-thousand threshold, description)`
    pair that determines the start day. Nothing is displayed if the threshold was never reached.
    Fitting is skipped if `fit_lines` is false or there's too little data (`should_fit_lines`).
    """
    dates, data = get_start_and_align(country, series, start_val_perm=start_info[0], perm=False)
    if not data:
        return
    if not should_fit_lines(data):
        fit_lines = False
    display(Markdown(f'### {series.title()}'))
    display(Markdown(f'Start date {dates[0]} (1st day with {start_info[1]})'))
    display(Markdown(fr'Latest number \\({data[-1]:,.0f}\\) on {dates[-1]}'))

    if fit_lines:
        # This computation is duplicated in `backpredict_exp_vs_sigmoid` a few lines down, but we want
        # to write out the formulas above the chart, and this seems the easiest way to do that.
        #
        # aexp, bexp = get_fit_params(exponential, data, p0=(1, 0.1))
        # display_exponential('Best fit exponential', aexp, bexp)
        try:
            asig, bsig, csig = get_fit_params(sigmoid, data, p0=(data[-1], 0.2, 21))
            display_sigmoid('Best fit sigmoid', asig, bsig, csig)
        except Exception:
            # Best-effort: the formula is simply omitted if no sigmoid can be fitted.
            # (Not a bare `except`, so that KeyboardInterrupt/SystemExit still propagate.)
            pass

    got_any_sigmoid = False
    _, ax = get_plot()
    if fit_lines:
        got_any_sigmoid = backpredict_exp_vs_sigmoid(ax, dates, data, backpredict_days)
    ax.plot(dates, data, 'C0', label=f'{country} {series}', **DATALINE_KWARGS)

    # ax.set_title(f'Exponential{" and sigmoid" if got_any_sigmoid else ""} fit for {series} in {country},'
    #              + f' starting from {start_info[1]}')
    if got_any_sigmoid:
        ax.set_title(f'Sigmoid fit for {series} in {country}, starting from {start_info[1]}')
    else:
        # Also covers the case where fitting was requested but no sigmoid could be fitted,
        # so that the chart never ends up without a title.
        ax.set_title(f'{series.title()} for {country}, starting from {start_info[1]}')
    ax.legend()
    show_plot(plt)

    if len(data) > 1:
        _, ax = get_plot()
        ax.set_title(f'New {series} per day ({country})')
        ax.bar(dates[1:], delta_series(data))
        show_plot(plt)


def country_deep_dive(country_name, backpredict_days=5):
    """Display the per-country section: heading, population, and the 'confirmed' and 'dead' analyses."""
    display(Markdown(f'## {country_name}'))
    population = db.country(country_name).population
    display(Markdown(fr'Population \\({population:,d}\\)'))

    # Line fitting is disabled as it doesn't make sense anymore.
    # Number of active cases analyses are no longer useful as nobody is reporting recovery, i.e.
    # there is no ('active', (0.001, '1 active per million')) entry.
    analyses = (
        ('confirmed', (0.001, '1 confirmed per million')),
        ('dead', (0.0001, '0.1 dead per million')),
    )
    for series, start_info in analyses:
        analyze_country(
            country_name,
            series,
            start_info=start_info,
            backpredict_days=backpredict_days,
            fit_lines=False)
    

def get_top_countries_by_relative_deaths(continent, num=8):
    """Return the names of up to `num` countries with the highest latest daily deaths per thousand.

    The ranking uses the latest value of the 7-day average of daily deaths, relative to population.
    Only countries on the given continent are considered (all continents if `continent` is `None`).
    The 'Diamond Princess', countries with fewer than 100,000 inhabitants, and countries with
    fewer than 300 confirmed cases are excluded.
    """
    candidates = []
    for country in db.countries.values():
        if country.name == 'Diamond Princess':
            continue
        if continent is not None and country.continent != continent:
            continue
        if country.population < 100000:
            continue
        if country.get('confirmed')[-1] < 300:
            continue
        latest_rate = country.get('daily dead (7-day average)', perm=True)[-1]
        candidates.append((country.name, latest_rate))
    ranked = sorted(candidates, key=itemgetter(1), reverse=True)
    return [name for name, _ in ranked[:num]]

    
# Returns a list of (name, %-recovered) pairs for all "recovering" countries, where a country is defined
# as "recovering" if:
#   - of the last 7 days, the number of active cases has decreased on at least 5 (this is to account
#     for some weird reporting artifacts that can make it appear some country is "recovering" but
#     actually isn't),
#   - the number of active cases has been lower than its all-time-high number of active cases for at least
#     the last 5 days,
#   - the all-time-high was at least 300 cases (there are some countries which never had more than a few
#     cases, so they are not very interesting to look at).
#
# The countries are ordered by decreasing relative "recoveredness".
def get_recovering_countries_info(continent=None):
    """Return `(name, percent recovered)` pairs for "recovering" countries, most recovered first.

    The percentage is how far the latest number of active cases is below the all-time high.
    Only countries on the given continent are considered (all continents if `continent` is `None`).
    """
    recovering = []
    for country in db.countries.values():
        if country.name == 'Diamond Princess':
            continue
        if continent is not None and country.continent != continent:
            continue
        active = country.get('active')
        if len(active) < 8:
            continue
        # The last 8 values give the last 7 day-over-day changes; at least 5 must be decreases.
        num_decreases = sum(1 for delta in delta_series(active[-8:]) if delta < 0)
        if num_decreases < 5:
            continue
        peak = max(active)
        if peak < 300:
            continue
        # Each of the last 5 days must be below the all-time high.
        if all(recent < peak for recent in active[-5:]):
            recovering.append((country.name, (1 - active[-1] / peak) * 100))
    return sorted(recovering, key=itemgetter(1), reverse=True)


# Marker characters assigned to the plotted countries, in order; the sequence is reused
# cyclically when there are more countries than markers.
MARKERS_SEQUENCE = 'xovspPD*X13hH<>^|_+'
# Common implementation for cross-country comparisons, whether by relative or by absolute dates.
def compare_countries_impl(
        series,
        get_dates_and_data_fn,
        get_ts_from_dates_and_data_fn,
        get_data_label_fn,
        get_idn_title_fn,
        get_log_title_fn,
        start_info):
    """Plot one series for all the countries in `countries_to_plot` on a single chart.

    Parameters:
      series: name of the series, e.g. 'confirmed'.
      get_dates_and_data_fn: called with `(country_name, series, start_info)`; must return
        parallel lists of dates and data points. Countries with no data are left out.
      get_ts_from_dates_and_data_fn: called with `(dates, data)`; must return the x-axis
        sequence for the chart.
      get_data_label_fn: called with `(country_name, dates)`; must return the line label.
      get_idn_title_fn: called with `(series, start_info)`; must return the linear chart title.
      get_log_title_fn: same, for the log chart (currently not drawn, so it is not called).
      start_info: passed to the functions above to configure how the start date is picked.
        It is safe to pass `None` here if all the functions can handle it.
    """
    display(Markdown(f'## {series.title()}'))

    # The lines for the identity (i.e. linear) and log charts are collected separately and drawn as
    # separate figures instead of subplots, because `%matplotlib notebook` works poorly with
    # subplots (this makes it easier to enable notebook mode).
    idn_plot_args = []
    log_plot_args = []
    for cidx, country_name in enumerate(countries_to_plot):
        dates, data = get_dates_and_data_fn(country_name, series, start_info)
        if not data:
            continue
        data_label = get_data_label_fn(country_name, dates)
        ts = get_ts_from_dates_and_data_fn(dates, data)

        # Each country gets its own color and marker, both picked by its position.
        marker = MARKERS_SEQUENCE[cidx % len(MARKERS_SEQUENCE)]
        dataline_kwargs = {**DATALINE_KWARGS, 'marker': marker}
        color = f'C{cidx}'
        idn_plot_args.append((ts, identity(data), color, data_label, dataline_kwargs))
        log_plot_args.append((ts, log(data), color, data_label, dataline_kwargs))

    def draw_chart(lines, title):
        _, ax = get_plot()
        for xs, ys, line_color, label, line_kwargs in lines:
            ax.plot(xs, ys, color=line_color, label=label, **line_kwargs)
        ax.set_title(title)
        ax.legend()
        show_plot(plt)

    draw_chart(idn_plot_args, get_idn_title_fn(series, start_info))
    # Log comparisons are no longer useful.
    # draw_chart(log_plot_args, get_log_title_fn(series, start_info))


def compare_countries_absolute_dates(series, skip_days=0):
    """Compare the per-thousand `series` across countries over calendar dates.

    The first `skip_days` days of the dataset are left out.
    """
    def get_dates_and_data(country_name, series, start_info):
        dates = db.dates[skip_days:]
        values = db.country(country_name).get(series, perm=True)
        return dates, values[skip_days:]

    def get_data_label(country_name, dates):
        return country_name

    def get_ts(dates, data):
        # The x-axis is the dates themselves.
        return dates

    def get_idn_title(series, start_info):
        return f'{series.title()} per thousand over absolute dates'

    def get_log_title(series, start_info):
        return f'Log10 {series} per thousand over absolute dates'

    compare_countries_impl(
        series,
        get_dates_and_data_fn=get_dates_and_data,
        get_data_label_fn=get_data_label,
        get_ts_from_dates_and_data_fn=get_ts,
        get_idn_title_fn=get_idn_title,
        get_log_title_fn=get_log_title,
        start_info=None
    )
    

def compare_countries_relative_dates(series, start_info):
    """Compare the per-thousand `series` across countries over days since each country's start day.

    `start_info` is a `(per-thousand threshold, description)` pair; each country's start day is
    the first day on which it reaches the threshold.
    """
    def get_dates_and_data(country_name, series, start_info):
        return get_start_and_align(country_name, series, start_info[0], perm=True)

    def get_data_label(country_name, dates):
        return f'{country_name} (start date {dates[0]})'

    def get_ts(dates, data):
        # The x-axis is the day number, counting the start day as day 1.
        return np.arange(1, len(data) + 1)

    def get_idn_title(series, start_info):
        return f'{series.title()} per thousand over time (days), from the 1st day with {start_info[1]}'

    def get_log_title(series, start_info):
        return f'Log10 {series} per thousand over time (days), from the 1st day with {start_info[1]}'

    compare_countries_impl(
        series,
        get_dates_and_data_fn=get_dates_and_data,
        get_data_label_fn=get_data_label,
        get_ts_from_dates_and_data_fn=get_ts,
        get_idn_title_fn=get_idn_title,
        get_log_title_fn=get_log_title,
        start_info=start_info
    )

# 2020-01-22 is the first day in the dataset.
# `absolute_date_comparison_start_date`, if specified, is only used for the comparison over absolute dates.
# `backpredict_days` determines how many previous days we make predictions for in the per-country analyses.
def analyze_countries(absolute_date_comparison_start_date='2020-01-22', backpredict_days=5):
    """Display the full report for the countries in `countries_to_plot`.

    With more than one country, this starts with the cross-country comparisons (over absolute
    dates from `absolute_date_comparison_start_date`, a 'YYYY-MM-DD' string that must be a date
    in the dataset, and then over approximately aligned start days). The per-country analyses
    follow, with `backpredict_days` passed on to them.
    """
    if len(countries_to_plot) > 1:
        display(Markdown('# Cross-country comparison over absolute dates'))
        start_date = datetime.datetime.strptime(absolute_date_comparison_start_date, '%Y-%m-%d').date()
        skip_days = db.dates.index(start_date)
        # for series in ['confirmed', 'dead', 'daily dead (7-day average)', 'active']:
        for series in ('dead', 'daily dead (7-day average)', 'confirmed'):
            compare_countries_absolute_dates(series, skip_days=skip_days)

        display(Markdown('# Cross-country comparison with approximately aligned start days'))
        relative_comparisons = (
            ('dead', (0.001, '1 death per million')),
            ('daily dead (7-day average)', (0.00001, '0.01 deaths per million')),
            ('confirmed', (0.01, '10 confirmed per million')),
            # ('active', (0.01, '10 active per million')),
        )
        for series, start_info in relative_comparisons:
            compare_countries_relative_dates(series, start_info)

    # display(Markdown('''# Per-country analysis with exponential and sigmoidal projections, and new cases analysis
    #
    # <span style="color:red;">
    # IMPORTANT: The projections are only accurate if the fit is good (it often isn't), and assuming nothing changes
    # going forward. The sigmoid is omitted if a reasonable fit can't be computed, but this still doesn't mean that
    # the fit is good if it is shown.
    # </span>
    #
    # The dashed lines show best fit projections from a few previous days for comparison.'''))

    display(Markdown('# Per-country analysis'))
    for country_name in countries_to_plot:
        country_deep_dive(country_name, backpredict_days=backpredict_days)
    
    
def draw_recovering_countries(recovering_countries_info):
    """Print the given recovering countries and deep-dive into the top ones not yet covered.

    `recovering_countries_info` is a sequence of `(name, percent recovered)` pairs. At most 4
    countries, the first ones that are not in `countries_to_plot`, get a `country_deep_dive`.
    """
    display(Markdown('## List of all recovering countries (the top 4 not covered above are also analyzed below)'))
    for info in recovering_countries_info:
        name, percent = info[0], info[1]
        print(f'{name} recovered {percent:.0f}%')

    remaining = 4
    for info in recovering_countries_info:
        name = info[0]
        if name in countries_to_plot:
            # Already analyzed as part of the main report.
            continue
        country_deep_dive(name)
        remaining -= 1
        if remaining == 0:
            break
    

# The shared database instance that the functions above read as the global `db`.
# Constructing it loads all the files under `data/`.
db = DB()