# Combined figure with SMC fit and CIs for $t \leq t_{\max}$


In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import re
import scipy.stats as sts
import xml.etree.ElementTree as ET
import warnings
import pickle
import copy
import csv
import datetime
import json
from scipy.interpolate import UnivariateSpline
from scipy.optimize import minimize_scalar, root_scalar

from matplotlib.patches import Patch
from matplotlib.lines import Line2D
import matplotlib.ticker as ticker

from mpl_toolkits.axes_grid1.inset_locator import inset_axes


import sys, importlib
sys.path.append("..")
from evpytools import evplot
from evpytools import pftools
from evpytools import auxiliary as aux
from evpytools import definitions as defn
for mod in [evplot, pftools, aux, defn]:
    importlib.reload(mod)

In [None]:
## larger default font size for all figures produced in this notebook
plt.rcParams['font.size'] = 18

### Import data files

In [None]:
def import_filter_data(filename):
    """Read a whitespace-separated data table into a list of per-row dicts.

    Every non-blank line of ``filename`` is split on whitespace and mapped to a
    dict using fixed column positions:

    - column 0 -> ``"region"`` (str)
    - column 1 -> ``"t"`` (int)
    - column 2 -> ``"event"`` (str)
    - column 5 -> ``"deaths"`` (int)
    - column 6 -> ``"deaths_cc"`` (int)
    - column 7 -> ``"Nmut"`` (int)
    - column 9 -> ``"Ntot"`` (int)

    Columns 3, 4 and 8 are ignored. Empty and whitespace-only lines are
    skipped (previously a whitespace-only line caused an IndexError).

    Args:
        filename: path to the data file.

    Returns:
        list of dicts, one per data line, in file order.

    Raises:
        IndexError: if a data line has fewer than 10 columns.
        ValueError: if an integer column cannot be parsed.
    """
    data_dicts = []
    with open(filename, encoding="utf-8") as f:
        for line in f:
            row = line.split()
            if not row:  # blank or whitespace-only line
                continue
            data_dicts.append({
                "region": row[0],
                "t": int(row[1]),
                "event": row[2],
                "deaths": int(row[5]),
                "deaths_cc": int(row[6]),
                "Nmut": int(row[7]),
                "Ntot": int(row[9]),
            })
    return data_dicts

In [None]:
## observation table: death counts and mutant / total sequence counts per row
fdata_file = "../data/in/sars2-seq-death-week-Netherlands-B.1.351.tsv"
fdatadicts = import_filter_data(fdata_file)

### Import filter results

In [None]:
## particle-filter output file for the B.1.351 model
pfout_file = "../data/out/ipf_result-sars_model_Netherlands_B.1.351.xml"

## NOTE(review): ``idx`` is not used in the visible cells of this notebook;
## confirm whether it is still needed before removing it.
idx = -1  ## select one of the PF iterations

pf_data = pftools.extract_pfilter_data(pfout_file)

### Import profile likelihood results

In [None]:
## load the profile-likelihood results (presumably written by another notebook);
## the keys used below are "LLss", "tmaxs" and "sigmas"
with open("../data/out/profile-lik-tmax-results_Netherlands_B.1.351.json", 'r', encoding='utf-8') as f:
    proflik_result_dict = json.load(f)


### Functions for creating the figure

In [None]:
def plot_data(axs, dds):
    """Scatter the observed data on the first two panels.

    ``axs[0]``: death counts of the rows whose ``deaths_cc`` equals
    ``defn.uncensored_code``.

    ``axs[1]``: mutant frequency ``Nmut / Ntot`` of the rows with ``Ntot > 0``.
    Each point gets a vertical bar between the 2.5% and 97.5% quantiles of
    Beta(Nmut + 1/2, Ntot - Nmut + 1/2), i.e. a Jeffreys interval.

    Args:
        axs: sequence of Axes; only ``axs[0]`` and ``axs[1]`` are used.
        dds: list of data dicts as returned by ``import_filter_data``.
    """
    death_ax, freq_ax = axs[0], axs[1]

    # -- death incidence: only uncensored observations are drawn
    observed = [rec for rec in dds if rec["deaths_cc"] == defn.uncensored_code]
    death_ax.scatter(
        [rec["t"] for rec in observed],
        [rec["deaths"] for rec in observed],
        color='k', edgecolor='k', zorder=4, label='data', s=20,
    )

    # -- mutant frequency: only time points with at least one sequence
    sequenced = [rec for rec in dds if rec["Ntot"] > 0]
    times = [rec["t"] for rec in sequenced]
    freqs = [rec["Nmut"] / rec["Ntot"] for rec in sequenced]

    # 95% interval from the Beta posterior quantiles
    intervals = []
    for rec in sequenced:
        a = rec["Nmut"] + 0.5
        b = rec["Ntot"] - rec["Nmut"] + 0.5
        intervals.append((sts.beta.ppf(0.025, a, b), sts.beta.ppf(0.975, a, b)))

    for t, (lo, hi) in zip(times, intervals):
        freq_ax.plot([t, t], [lo, hi], color='k', alpha=0.3)
    freq_ax.scatter(times, freqs, color='k', edgecolor='k', zorder=4, label='data',
                    s=20, marker='_')

    
def plot_trajectories(ax, pf_data, varname, date0, color="tab:blue", pretty_varname=None):
    """Draw every sampled latent path of one model variable as a thin line.

    Only the first ID in ``pf_data["pfIDs"]`` is used. Each element of
    ``pf_data["paths"][ID]`` is an XML element. Its ``state`` children carry a
    time attribute ``t``, and the variable's value is found at
    ``var_vec[@name=varname]/var/@val``.

    Args:
        ax: Axes to draw on.
        pf_data: particle-filter output dict (keys ``"pfIDs"`` and ``"paths"``).
        varname: name of the model variable to plot.
        date0: accepted for interface compatibility; not used.
        color: line color of the trajectories.
        pretty_varname: accepted for interface compatibility; not used.
    """
    first_id = pf_data["pfIDs"][0]
    value_query = f"var_vec[@name='{varname}']/var"

    for path in pf_data["paths"][first_id]:
        states = path.findall("state")
        times = [float(state.attrib["t"]) for state in states]
        values = [float(state.find(value_query).attrib["val"]) for state in states]
        # thin, semi-transparent lines so many paths can overlap
        ax.plot(times, values, color=color, alpha=0.7, linewidth=0.5, zorder=1)

            
def plot_predictions(axs, pf_data, dds):
    """Draw particle-filter prediction boxplots for the observed variables.

    ``'D'`` is drawn on ``axs[0]`` and ``'Fm'`` on ``axs[1]``. For each one,
    this takes the first ID in ``pf_data["pfIDs"]``, extracts the per-time
    ranges, prediction medians and filter medians, and passes them to
    ``evplot.pfilter_boxplot``.

    Args:
        axs: sequence of at least two Axes.
        pf_data: particle-filter output dict with keys ``"pfIDs"``,
            ``"ranges"``, ``"pred_medians"`` and ``"filter_medians"``. The
            median entries are XML elements with a ``"t"`` attribute and
            ``var_vec[@name=...]/var`` children holding a ``"val"`` attribute.
        dds: list of data dicts as returned by ``import_filter_data``.
    """
    dt = 1
    varcolor = ['purple', 'tab:blue']
    obsvarnames = ['D', 'Fm']
    ID = pf_data["pfIDs"][0]  ## select single ID
    pred_medians = pf_data["pred_medians"][ID]
    filter_medians = pf_data["filter_medians"][ID]
    rans = pf_data["ranges"][ID]
    ts = [float(x.attrib["t"]) for x in pred_medians]

    # Times that have an uncensored death observation. This does not depend
    # on the plotted variable, so compute it once. A set gives O(1)
    # membership; int and float times compare and hash equal.
    # NOTE(review): the same death-based mask is also used for the 'Fm' panel;
    # confirm this is intended.
    ws = {row["t"] for row in dds if row["deaths_cc"] == defn.uncensored_code}
    mask = [t not in ws for t in ts]

    def _value(state, name):
        # value of variable `name` stored in one XML state element
        return float(state.find(f"var_vec[@name='{name}']/var").attrib["val"])

    for i, X in enumerate(obsvarnames):
        Xs_ran = [[_value(x, X) for x in ran] for ran in rans]
        Xs_pred = [_value(x, X) for x in pred_medians]
        Xs_filt = [_value(x, X) for x in filter_medians]
        # pass a fresh mask list to every call, as the original code did
        evplot.pfilter_boxplot(axs[i], ts, Xs_ran, Xs_pred, Xs_filt, mask=list(mask),
                               color=varcolor[i], dt=dt)

        
def plot_CIs(ax, LLss, tmaxs, sigmas, max_diff=10, spline_smoothing=10):
    """Plot spline-smoothed profile likelihoods and print 95% CIs.

    For each ``LLss[i]``:

    - The replicate mean of the log-likelihoods (mean over axis 1) is smoothed
      with a cubic ``UnivariateSpline``.
    - The curve is drawn rotated (sigma on the y-axis) and shifted horizontally
      so that the 95% CI threshold (maximum minus ``DL``) lies on the vertical
      line ``x = tmaxs[i]``.
    - The maximizing sigma and the maximal log-likelihood are printed, and the
      95% CI is printed too if it can be bracketed.

    Args:
        ax: Axes to draw on.
        LLss: sequence of 2-D array-likes. ``LLss[i][j]`` holds the replicate
            log-likelihoods evaluated at ``sigmas[j]``.
        tmaxs: x-position for each profile (one per entry of ``LLss``).
        sigmas: strictly increasing grid of selection coefficients, as
            required by ``UnivariateSpline``.
        max_diff: grid points whose mean LL lies more than ``max_diff`` below
            the largest single log-likelihood in ``LLss[i]`` are dropped
            before fitting.
        spline_smoothing: smoothing factor ``s`` passed to ``UnivariateSpline``.
    """
    # log-likelihood drop that defines a 95% CI (chi-squared, 1 degree of freedom)
    DL = sts.chi2.ppf(0.95, 1) / 2
    sigmas = np.asarray(sigmas, dtype=float)

    for i, LLs in enumerate(LLss):
        LLs = np.asarray(LLs, dtype=float)
        ## mean over replicates for each sigma
        meanLLs = np.mean(LLs, axis=1)

        ## remove very small LLs
        ## NOTE(review): threshold is relative to the largest single replicate
        ## LL, not to max(meanLLs); confirm this is intended.
        keep = meanLLs >= np.max(LLs) - max_diff
        sigs = sigmas[keep]
        lls = meanLLs[keep]
        if sigs.size < 4:
            # a cubic spline (k=3) needs more than 3 data points
            print(f"too few points to fit a spline (tmax = {tmaxs[i]})")
            continue

        bounds = (sigs[0], sigs[-1])
        cs = UnivariateSpline(sigs, lls, s=spline_smoothing, ext='raise')
        xs = np.linspace(*bounds, 250)

        ## find max of spline
        res = minimize_scalar(lambda x: -cs(x), bounds=bounds, method='bounded')
        max_LL = -res.fun
        sigma_opt = res.x

        ## rotated profile: x = LL - max_LL + tmax + DL puts the CI threshold at x = tmax
        ax.plot(cs(xs) - max_LL + tmaxs[i] + DL, xs, label='spline', color='k', linewidth=2)

        print(f"s_opt = {sigma_opt:0.2f}")
        print(f"max LL = {max_LL:0.2f}")

        ## 95% CI: points where the spline is DL below its maximum, one per side
        try:
            lres = root_scalar(lambda x: cs(x) - max_LL + DL, bracket=[sigs[0], sigma_opt])
            rres = root_scalar(lambda x: cs(x) - max_LL + DL, bracket=[sigma_opt, sigs[-1]])
        except ValueError:
            # brentq raises ValueError when a bracket has no sign change, i.e.
            # the spline does not drop DL below its maximum on that side
            print("unable to compute CI!")
        else:
            print(f"95% CI = [{lres.root:0.2f}, {rres.root:0.2f}]")

        ## plot the mean LLs used for the fit as dots, and mark tmax
        ax.scatter(lls - max_LL + tmaxs[i] + DL, sigs, color='k', s=5)
        ax.axvline(x=tmaxs[i], color='k', alpha=0.4)

In [None]:
## per-panel styling for the two data panels (deaths, mutant frequency)
data_markers = ['o', '|']
legend_locs = [1, 2]
data_colors = ['w', 'lightgray']
trajcolor = ["pink", "deepskyblue"]
varcolor = ['purple', 'tab:blue']
varnames = ["D", "Fm"]

## reference date: model time t is converted to a date as date0 + t days
date0 = datetime.datetime.strptime("01/01/2020", "%m/%d/%Y")

fig, axs = plt.subplots(3, 1, figsize=(7, 10), sharex=True)

## panels A and B: data, latent trajectories and filter predictions
plot_data(axs, fdatadicts)
for i, varname in enumerate(varnames):
    plot_trajectories(axs[i], pf_data, varname, date0, color=trajcolor[i])
plot_predictions(axs, pf_data, fdatadicts)

## dates on the shared x-axis
days = [dd["t"] for dd in fdatadicts]
dates = [date0 + datetime.timedelta(days=d) for d in days]
xticks = days[::2]  ## every second data row (every 2 weeks for weekly data)
xticklabels = [d.strftime("%b %d") for d in dates[::2]]

axs[-1].set_xlabel("date")
axs[-1].set_xticks(xticks)
axs[-1].set_xticklabels(xticklabels, fontsize='x-small', rotation=45, ha='right')

## legends for the data panels
## An Axes shows only one legend, so these custom handles fully define the
## legends of panels A and B. No default legend is created first; that would
## be replaced anyway, and its handle tweaking relied on
## ``Legend.legendHandles``, which was removed in Matplotlib 3.9.
for i, ax in enumerate(axs[:2]):
    legend_elements = [
        Line2D([0], [0], marker=data_markers[i], color=data_colors[i], label='data',
               markerfacecolor='k', markeredgecolor='k', markersize=7),
        Line2D([0], [0], color=varcolor[i], label='model'),
    ]
    ax.legend(handles=legend_elements, ncol=1, fontsize='x-small', loc=legend_locs[i])

## panel C: profile likelihoods
LLss = proflik_result_dict["LLss"]
tmaxs = proflik_result_dict["tmaxs"]
sigmas = proflik_result_dict["sigmas"]
plot_CIs(axs[2], LLss, tmaxs, sigmas)

## y-labels
ylabs = ["death incidence", "mutant frequency", "selection ($s$)"]
for ax, ylab in zip(axs, ylabs):
    ax.set_ylabel(ylab)

fig.align_ylabels(axs)

## subplot labels
dx_num = -0.2
subplot_labels = "ABC"

for i, ax in enumerate(axs):
    ax.text(dx_num, 1.05, subplot_labels[i], fontsize=24, transform=ax.transAxes)

fig.savefig("../data/out/SMCFitTmax.pdf", bbox_inches='tight')