# Combined figure with SMC fit and CIs for $t \leq t_{\max}$

The fits and confidence intervals for multiple regions are combined in one figure.

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import re
import scipy.stats as sts
import xml.etree.ElementTree as ET
import warnings
import pickle
import copy
import csv
import datetime
import json
import string
from scipy.interpolate import UnivariateSpline
from scipy.optimize import minimize_scalar, root_scalar

from matplotlib.patches import Patch
from matplotlib.lines import Line2D
import matplotlib.ticker as ticker

from mpl_toolkits.axes_grid1.inset_locator import inset_axes


import sys, importlib
## add the parent directory to the module search path before importing
## evpytools (presumably that is where the package lives -- verify)
sys.path.append("..")
from evpytools import evplot
from evpytools import pftools
from evpytools import auxiliary as aux
from evpytools import definitions as defn
## re-execute the project modules that were just imported, so that re-running
## this cell picks up the current state of their source
for mod in [evplot, pftools, aux, defn]:
    importlib.reload(mod)

In [None]:
## default font size for all text in the figures created below
plt.rcParams['font.size'] = 18

### Import data files

In [None]:
## one data file per region; the order has to match the other per-region
## lists in this notebook, because they are all indexed by region number
fdata_files = [
    "../data/in/sars2-seq-death-week-Netherlands-B.1.351.tsv",
    "../data/in/sars2-seq-death-week-Japan-R.1.tsv",
]

#fdata_files = [
#    "../data/in/sars2-seq-death-week-United_Kingdom-B.1.1.7.tsv",
#    "../data/in/sars2-seq-death-week-Netherlands-B.1.1.7.tsv"
#]

#fdata_files = [
#    "../data/in/sars2-seq-death-week-United_Kingdom-D614G.tsv",
#    "../data/in/sars2-seq-death-week-Netherlands-D614G.tsv"
#]


## parse every file with the project's importer
fdatadicts = list(map(pftools.import_filter_data, fdata_files))

### Import filter results

In [None]:
## one particle-filter output file per region (same order as the data files)
pfout_files = [
    "../data/out/ipf_result-sars_model_Netherlands_B.1.351.xml",
    "../data/out/ipf_result-sars_model_Japan_R.1.xml",
]

#pfout_files = [
#    "../data/out/ipf_result-sars_model_United_Kingdom_B.1.1.7.xml",
#    "../data/out/ipf_result-sars_model_Netherlands_B.1.1.7.xml"
#]

#pfout_files = [
#    "../data/out/ipf_result-sars_model_UK-614-wk.xml",
#    "../data/out/ipf_result-sars_model_NL-614-wk.xml"
#]

idx = -1 ## select one of the PF iterations

## extract the filter results from every output file
pf_datas = list(map(pftools.extract_pfilter_data, pfout_files))

### Import profile likelihood results

In [None]:
## one JSON file with profile-likelihood results per region
prof_lik_files = [
    "../data/out/profile-lik-tmax-results_Netherlands_B.1.351.json",
    "../data/out/profile-lik-tmax-results_Japan-R.1.json",
]

#prof_lik_files = [
#    "../data/out/profile-lik-tmax-results_United_Kingdom-B.1.1.7.json",
#    "../data/out/profile-lik-tmax-results_Netherlands-B.1.1.7.json"
#]

#prof_lik_files = []

## parse each file; the file is closed again before the result is stored
proflik_result_dicts = []
for prof_lik_file in prof_lik_files:
    with open(prof_lik_file, 'r') as f:
        proflik_result_dict = json.load(f)
    proflik_result_dicts.append(proflik_result_dict)


### Functions for creating the figure

In [None]:
def plot_data(axs, dds):
    """Plot the observations: deaths on ``axs[0]``, mutant frequency on ``axs[1]``.

    axs -- indexable collection of axes; element 0 receives the death counts,
           element 1 the mutant frequencies
    dds -- list of data rows (dicts); the keys "t", "deaths", "deaths_cc",
           "Nmut" and "Ntot" are read
    """
    ## deaths: only rows whose death count carries the "uncensored" code
    uncensored = [row for row in dds if row["deaths_cc"] == defn.uncensored_code]
    death_times = [row["t"] for row in uncensored]
    death_counts = [row["deaths"] for row in uncensored]
    axs[0].scatter(death_times, death_counts, color='k', edgecolor='k',
                   zorder=4, label='data', s=20)

    ## mutant frequency: only time points with at least one sequence
    freq_ax = axs[1]
    sequenced = [row for row in dds if row["Ntot"] > 0]
    seq_times = [row["t"] for row in sequenced]
    freqs = [row["Nmut"] / row["Ntot"] for row in sequenced]

    ## 95% interval per time point: the 2.5% and 97.5% quantiles of
    ## Beta(Nmut + 1/2, Ntot - Nmut + 1/2)
    beta_params = [(row["Nmut"]+0.5, row["Ntot"] - row["Nmut"]+0.5)
                   for row in sequenced]
    lowers = [sts.beta.ppf(0.025, a, b) for a, b in beta_params]
    uppers = [sts.beta.ppf(0.975, a, b) for a, b in beta_params]
    for t, lower, upper in zip(seq_times, lowers, uppers):
        ## vertical bar from the lower to the upper quantile
        freq_ax.plot([t,t], [lower,upper], color='k', alpha=1)

    ## point estimates as short horizontal dashes on top of the bars
    freq_ax.scatter(seq_times, freqs, color='k', edgecolor='k', zorder=4,
                    label='data', s=40,  marker='_')

    
def plot_trajectories(ax, pf_data, varname, date0, color="tab:blue", 
                      pretty_varname=None):
    """Draw every stored path of one state variable as a thin line on ``ax``.

    ax             -- axes to draw on
    pf_data        -- dict with the filter results; ``pf_data["pfIDs"][0]`` selects
                      the entry of ``pf_data["paths"]`` that is plotted
    varname        -- value of the ``name`` attribute of the ``var_vec`` element
                      to extract from each ``state`` element
    date0          -- not used by this function
    color          -- line color of the trajectories
    pretty_varname -- legend label; defaults to ``varname``
    """
    legend_label = varname if pretty_varname is None else pretty_varname
    first_id = pf_data["pfIDs"][0] ## select single ID
    var_xpath = f"var_vec[@name='{varname}']/var"

    for j, path in enumerate(pf_data["paths"][first_id]):
        ## time series of this path: time stamps and values of the variable
        states = path.findall("state")
        times = [float(state.attrib["t"]) for state in states]
        values = [float(state.find(var_xpath).attrib["val"]) for state in states]
        ## only the first line gets a label, so the legend has a single entry
        ax.plot(times, values, color=color, alpha=0.7, linewidth=0.5,
                zorder=1, label=legend_label if j == 0 else None)

            
def plot_predictions(axs, pf_data, dds):
    """Draw the model predictions of the observables D and Fm as box plots.

    axs     -- indexable collection of axes; ``axs[0]`` is used for "D" and
               ``axs[1]`` for "Fm"
    pf_data -- dict with the filter results; the entries "ranges",
               "pred_medians" and "filter_medians" of the first ID in
               ``pf_data["pfIDs"]`` are read
    dds     -- list of data rows (dicts); "t" and "deaths_cc" are read
    """
    dt = 1
    varcolor = ['purple', 'tab:blue']
    obsvarnames = ['D', 'Fm']
    ID = pf_data["pfIDs"][0] ## select single ID
    rans = pf_data["ranges"][ID]
    pred_medians = pf_data["pred_medians"][ID]
    filter_medians = pf_data["filter_medians"][ID]
    ts = [float(x.attrib["t"]) for x in pred_medians]

    ## True for prediction times WITHOUT an uncensored death observation.
    ## This does not depend on the observable, so it is computed once, with
    ## a set for the membership test.
    ## NOTE(review): the death-based mask is also passed for the Fm panel;
    ## how pfilter_boxplot uses it is not visible here.
    observed_ts = {row["t"] for row in dds
                   if row["deaths_cc"] == defn.uncensored_code}
    mask = [t not in observed_ts for t in ts]

    for i, X in enumerate(obsvarnames):
        ax = axs[i]
        xpath = f"var_vec[@name='{X}']/var"
        Xs_ran = [[float(x.find(xpath).attrib["val"]) for x in ran]
                  for ran in rans]
        Xs_pred = [float(x.find(xpath).attrib["val"]) for x in pred_medians]
        Xs_filt = [float(x.find(xpath).attrib["val"]) for x in filter_medians]
        evplot.pfilter_boxplot(ax, ts, Xs_ran, Xs_pred, Xs_filt, mask=mask,
                               color=varcolor[i], dt=dt)

        
def plot_CIs(ax, LLss, tmaxs, sigmas, max_diff=11):
    """Draw sideways profile-likelihood curves and report their 95% CIs.

    For every element of ``LLss`` a smoothing spline is fitted through the
    mean log-likelihood as a function of sigma. The spline is drawn with
    sigma on the y-axis and the log-likelihood on the x-axis, shifted such
    that the CI threshold (maximum minus ``DL``) lies exactly on the vertical
    line ``x = tmaxs[i]``. The optimum, the maximal log-likelihood and the CI
    are printed; a "*" is put above the panel when the lower CI bound is
    positive.

    ax       -- axes to draw on
    LLss     -- one 2-D array-like of log-likelihoods per curve; its rows are
                paired with ``sigmas`` and averaged over axis 1 (presumably
                the columns are replicate estimates -- verify)
    tmaxs    -- x-position of each curve, indexed like ``LLss``
    sigmas   -- grid of sigma values; used as x-data of ``UnivariateSpline``
                and therefore has to be increasing
    max_diff -- grid points with a mean log-likelihood more than ``max_diff``
                below ``np.max(LLs)`` are left out of the fit
    """
    ## log-likelihood drop that defines a 95% CI (chi-square, 1 d.o.f.)
    DL = sts.chi2.ppf(0.95,1)/2

    for i, LLs in enumerate(LLss):
        ## compute means
        meanLLs = np.mean(LLs, axis=1)

        ## remove very small LLs
        cutoff = np.max(LLs) - max_diff
        sigs, lls = aux.unzip([(s, l) for s, l in zip(sigmas, meanLLs) if l >= cutoff])

        bounds = (sigs[0], sigs[-1])

        ## ext='raise': evaluating the spline outside `bounds` is an error
        cs = UnivariateSpline(sigs, lls, s=10, ext='raise')
        xs = np.linspace(*bounds, 250)

        ## find max of spline and CI
        res = minimize_scalar(lambda x: -cs(x), bounds=bounds, method='bounded')
        max_LL = -res.fun
        sigma_opt = res.x

        ax.plot(cs(xs)-max_LL+tmaxs[i]+DL, xs, label='spline', color='k', linewidth=2)

        print(f"s_opt = {sigma_opt:0.2f}")
        print(f"max LL = {max_LL:0.2f}")

        def LL_above_threshold(x):
            ## zero where the spline crosses the CI threshold
            return cs(x)-max_LL + DL

        ## root_scalar raises ValueError if the threshold is not crossed
        ## inside the bracket; in that case the bound is reported as NaN.
        ## Only solver failures are caught, so interrupts still propagate.
        try:
            lCI = root_scalar(LL_above_threshold, bracket=[sigs[0], sigma_opt]).root
        except (ValueError, RuntimeError):
            print("unable to compute lower bound CI!")
            lCI = np.nan
        try:
            rCI = root_scalar(LL_above_threshold, bracket=[sigma_opt, sigs[-1]]).root
        except (ValueError, RuntimeError):
            print("unable to compute upper bound CI!")
            rCI = np.nan

        print(f"95% CI = [{lCI:0.2f}, {rCI:0.2f}]")

        ## mark curves whose CI lies entirely above zero
        if not np.isnan(lCI) and lCI > 0.0:
            ax.text(tmaxs[i], 1.005, "*", fontsize=18, ha='center',
                    transform=evplot.hybrid_trans(ax))

        ## plot dots
        ax.scatter(np.array(lls)-max_LL+tmaxs[i]+DL, sigs, color='k', s=5)
        ax.axvline(x=tmaxs[i], color='k', alpha=0.4)
        
        


In [None]:
## ---- settings ----

## marker and line color of the hand-made "data" legend entries of the
## death-incidence and mutant-frequency panels
data_markers = ['o', '|']
#legend_locs = [1, 1, 4] ## D614G
legend_locs = [1, 1, 2] ## others
data_colors = ['w', 'k']
trajcolor = ["pink", "deepskyblue"]
varcolor = ['purple', 'tab:blue']
varnames = ["D", "Fm"]

## column titles; the order has to match the lists of imported files
#regions = ["United Kingdom D614G", "Netherlands D614G"]
#regions = ["United Kingdom B.1.1.7", "Netherlands B.1.1.7"]
regions = ["Netherlands B.1.351", "Japan R.1"]

## insets only used for D614G
xlim_insets = [(65,75), (58,68)]
ylim_insets = [(0,10000), (0,1000)]

## plot profile-likelihood results?
#plot_prof_lik = False
plot_prof_lik = True

## plot an inset with a close-up of the population sizes?
plot_inset = False
#plot_inset = True

## scale the y-axis limits to [0,1]?
#unit_freq_limits = True
unit_freq_limits = False

## times are converted to calendar dates as days since date0
date0 = datetime.datetime.strptime("01/01/2020", "%m/%d/%Y")

## rows: population size, deaths, mutant frequency (, profile likelihood)
numrows = 4 if plot_prof_lik else 3

## ---- figure ----

## squeeze=False: axs is always a 2-D array indexed as [row, region],
## also when there is only a single region
fig, axs = plt.subplots(numrows, len(regions), figsize=(7*len(regions),10),
                        sharex='col', squeeze=False)

for r, region in enumerate(regions):
    ## data, trajectories and predictions of the observables (rows 1 and 2)
    plot_data(axs[1:,r], fdatadicts[r])
    for i, varname in enumerate(varnames):
        plot_trajectories(axs[i+1,r], pf_datas[r], 
                          varname, date0, color=trajcolor[i])
    plot_predictions(axs[1:,r], pf_datas[r], fdatadicts[r])

    ## population sizes of wild type and mutant (row 0)
    plot_trajectories(axs[0,r], pf_datas[r], "Iw", date0, color='tab:orange', 
                      pretty_varname="$I_{\\rm wt}$")
    plot_trajectories(axs[0,r], pf_datas[r], "Im", date0, color='tab:blue',
                      pretty_varname="$I_{\\rm mt}$")

    axs[0,r].yaxis.set_major_formatter(ticker.FuncFormatter(evplot.y_fmt))
    axs[0,r].tick_params(axis="y", labelsize=12)

    ## dates in x-axis

    days = [dd["t"] for dd in fdatadicts[r]]
    dates = [date0 + datetime.timedelta(days=d) for d in days]
    xticks = days[::2] ## every 2 weeks
    xticklabels = [d.strftime("%b %d") for d in dates[::2]]

    axs[-1,r].set_xlabel("date")
    axs[-1,r].set_xticks(xticks)
    axs[-1,r].set_xticklabels(xticklabels, fontsize='x-small', rotation=45, ha='right')
    
    ## add legends
    leg = axs[0,r].legend(ncol=1, loc=legend_locs[0], fontsize='x-small')
    ## Matplotlib 3.7 renamed Legend.legendHandles to legend_handles and the
    ## old name was removed later; support both
    legend_handles = getattr(leg, "legend_handles", None)
    if legend_handles is None:
        legend_handles = leg.legendHandles
    ## the trajectories are thin and transparent; make the legend keys readable
    for lh in legend_handles: 
        lh.set_alpha(1)
        lh.set_linewidth(1)

    for i, ax in enumerate(axs[1:3,r]):
        ## Legend
        legend_elements = [
            Line2D([0], [0], marker=data_markers[i], color=data_colors[i], label='data', 
                   markerfacecolor='k', markeredgecolor='k', markersize=7),
            Line2D([0], [0], color=varcolor[i], label='model'),
        ]
        ax.legend(handles=legend_elements, ncol=1, fontsize='x-small', loc=legend_locs[i+1])  


    ## profile likelihoods
    if plot_prof_lik:
        proflik_result_dict = proflik_result_dicts[r]
        LLss = proflik_result_dict["LLss"]
        tmaxs = proflik_result_dict["tmaxs"]
        sigmas = proflik_result_dict["sigmas"]

        ## replace tmax with the largest observation time $\leq$ tmax
        tmaxs = [np.max([t for t in days if t <= tm])
                 for tm in tmaxs]

        plot_CIs(axs[-1,r], LLss, tmaxs, sigmas)

        ## reference line at s = 0
        axs[-1,r].axhline(y=0, color='red', linewidth=0.5)
        
    ## inset
    if plot_inset:
        axins = inset_axes(axs[0,r], width="20%", height="35%", loc=1,
                           bbox_to_anchor=(0,0,0.8,1), bbox_transform=axs[0,r].transAxes)
        plot_trajectories(axins, pf_datas[r], "Iw", date0, color='tab:orange')
        plot_trajectories(axins, pf_datas[r], "Im", date0, color='tab:blue')
        axins.set_xlim(*xlim_insets[r])
        axins.set_ylim(*ylim_insets[r])
        axins.tick_params(axis='both', which='major', labelsize='xx-small')
        ## dates as xticklabels
        xmin, xmax = xlim_insets[r]
        xticks = range(xmin+1, xmax, 4)
        xtickdates = [date0 + datetime.timedelta(days=x) for x in xticks]
        xticklabels = [d.strftime("%b %d") for d in xtickdates]
        axins.set_xticks(xticks)
        axins.set_xticklabels(xticklabels, rotation=45, ha='right')

        
    ## title
    axs[0,r].set_title(region)
    
# y-labels
ylabs = [
    "population\nsize",
    "death\nincidence", 
    "mutant\nfrequency", 
    "selection ($s$)"
]
## only the left-most column gets y-labels; zip stops at the number of rows
for ax, ylab in zip(axs[:,0], ylabs):
    ax.set_ylabel(ylab, fontsize='small')

if unit_freq_limits:
    for ax in axs[2,:]:
        ax.set_ylim(-0.05, 1.05)

    
fig.align_ylabels(axs)
    
## add labels
subplot_labels = string.ascii_uppercase

## panels are lettered A, B, C, ... row by row
for i, ax in enumerate(axs.flatten()):
    ax.text(-0.15, 1.02, subplot_labels[i], fontsize=22, transform=ax.transAxes)
    
fig.savefig("../data/out/SMCFitTmax.pdf", bbox_inches='tight')