# Figure for testing

- with SMC fit for single region
- import data from data files used for filtering
- plot SMC diagnostics
- list parameter estimates

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import re
import scipy.stats as sts
import xml.etree.ElementTree as ET
import warnings
import pickle
import copy
import csv
import datetime
import json

from matplotlib.patches import Patch
from matplotlib.lines import Line2D
import matplotlib.ticker as ticker

from mpl_toolkits.axes_grid1.inset_locator import inset_axes


import sys, importlib
sys.path.append("..")
from evpytools import evplot
from evpytools import pftools
from evpytools import auxiliary as aux
from evpytools import definitions as defn
for mod in [evplot, pftools, aux, defn]:
    importlib.reload(mod)

In [None]:
## use a larger default font for all figures in this notebook
plt.rcParams['font.size'] = 18

### Import data files

In [None]:
## Input data for the filter: leave exactly one `fdata_file` assignment
## un-commented to select the region / lineage to plot.
#fdata_file = "../data/in/sars2-seq-death-week-United_Kingdom-B.1.1.7.tsv"
#fdata_file = "../data/in/sars2-seq-death-week-Netherlands-B.1.1.7.tsv"
#fdata_file = "../data/in/sars2-seq-death-week-Netherlands-B.1.351.tsv"
fdata_file = "../data/in/sars2-seq-death-week-Japan-R.1.tsv"

## rows of the data file; the cells below read the keys
## "t", "deaths", "deaths_cc", "Nmut" and "Ntot" from each row
fdatadicts = pftools.import_filter_data(fdata_file)

### Import filter results

In [None]:
## Filter output to plot: leave exactly one `pfout_file` assignment un-commented.
## NOTE(review): the active file name carries no region/lineage tag, so it
## cannot be checked here against the selected `fdata_file` -- confirm that
## both refer to the same fit.
pfout_file = "../data/out/ipf_result-sars_model.xml"
#pfout_file = "../data/out/ipf_result-sars_model_Netherlands_B.1.351.xml"
#pfout_file = "../data/out/ipf_result-sars_model_United_Kingdom_B.1.1.7.xml"
#pfout_file = "../data/out/ipf_result-sars_model_Netherlands_B.1.1.7.xml"
#pfout_file = "../data/out/ipf_result-sars_model_Japan_R.1.xml"

## the cells below index pf_data by "pfIDs", "paths", "pred_medians",
## "filter_medians", "ranges", "parnames", "param_medians", "param_ranges",
## "particle_filters" and "iterf_steps"
pf_data = pftools.extract_pfilter_data(pfout_file)

### Create the figure

In [None]:
def plot_data(axs, dds):
    """Draw the observed data on the death and mutant-frequency panels.

    Parameters
    ----------
    axs : sequence of Axes
        ``axs[1]`` receives the death counts, ``axs[2]`` the mutant
        frequencies; ``axs[0]`` is not touched.
    dds : list of dict
        Data rows with keys "t", "deaths", "deaths_cc", "Nmut" and "Ntot".

    Deaths are drawn only for rows whose "deaths_cc" equals
    ``defn.uncensored_code``. Mutant frequencies are drawn only for rows
    with at least one sequence (``Ntot > 0``), together with the
    equal-tailed 95% interval of a Beta(Nmut + 0.5, Ntot - Nmut + 0.5)
    distribution (Jeffreys interval).
    """
    ## death counts: skip censored observations
    uncensored = [row for row in dds if row["deaths_cc"] == defn.uncensored_code]
    axs[1].scatter(
        [row["t"] for row in uncensored],
        [row["deaths"] for row in uncensored],
        color='k', edgecolor='k', zorder=4, label='data', s=20,
    )

    ## mutant frequency: a frequency only exists where sequences were sampled
    freq_ax = axs[2]
    sequenced = [row for row in dds if row["Ntot"] > 0]
    times = []
    freqs = []
    for row in sequenced:
        n_mut, n_tot = row["Nmut"], row["Ntot"]
        shape_a = n_mut + 0.5
        shape_b = n_tot - n_mut + 0.5
        lower = sts.beta.ppf(0.025, shape_a, shape_b)
        upper = sts.beta.ppf(0.975, shape_a, shape_b)
        ## vertical bar spanning the interval at this time point
        freq_ax.plot([row["t"], row["t"]], [lower, upper], color='k', alpha=0.3)
        times.append(row["t"])
        freqs.append(n_mut / n_tot)
    freq_ax.scatter(times, freqs, color='k', edgecolor='k', zorder=4,
                    label='data', s=20, marker='_')

    
def plot_trajectories(axs, pf_data, date0, xlim_inset=None, ylim_inset=None,
                      color_wt="tab:orange", color_mut="tab:blue"):
    """Plot the sampled trajectories of the first particle-filter ID.

    Parameters
    ----------
    axs : sequence of Axes
        ``axs[0]`` receives the latent paths "Iw" and "Im"; ``axs[1]`` and
        ``axs[2]`` receive the model trajectories "D" and "Fm".
    pf_data : dict
        Filter output; ``pf_data["pfIDs"][0]`` selects the ID and
        ``pf_data["paths"][ID]`` holds XML elements whose "state" children
        carry a "t" attribute and named "var_vec" children.
    date0 : datetime
        Date corresponding to t = 0; only used for the inset tick labels.
    xlim_inset : tuple of int, optional
        ``(xmin, xmax)`` of an inset drawn in the upper right corner of
        ``axs[0]``. The bounds must be integers because they are passed to
        ``range``. No inset is created when this is None.
    ylim_inset : tuple, optional
        y-limits of the inset; ignored when no inset is drawn.
    color_wt, color_mut : color
        Line colors for the "Iw" and "Im" trajectories.

    Returns
    -------
    The inset Axes, or None when ``xlim_inset`` is None.
    """
    ID = pf_data["pfIDs"][0] ## select single ID
    paths = pf_data["paths"][ID]
    alpha_traj = 0.7
    ## latent paths: (variable name, legend label, color)
    latent_vars = [
        ("Iw", "$I_{\\rm wt}$", color_wt),
        ("Im", "$I_{\\rm mt}$", color_mut),
    ]
    ## trajectories for model predictions: (variable name, color);
    ## the k-th entry is drawn on axs[k+1]
    obs_vars = [("D", "pink"), ("Fm", "deepskyblue")]

    def series(states, name):
        ## value of the named variable in each state of a path
        return [float(x.find(f"var_vec[@name='{name}']/var").attrib["val"])
                for x in states]

    ax = axs[0]
    axins = None
    if xlim_inset is not None:
        axins = inset_axes(ax, width="15%", height="35%", loc=1)

    for j, path in enumerate(paths):
        ## extract the time axis once per path; it is shared by all variables
        xs = path.findall("state")
        ts = [float(x.attrib["t"]) for x in xs]
        ## only the first path gets a label, so the legend has one entry per variable
        for X, lab, color in latent_vars:
            Xs = series(xs, X)
            kwargs = {"label" : lab} if j == 0 else {}
            ax.plot(ts, Xs, color=color, alpha=alpha_traj, linewidth=0.5, zorder=1, **kwargs)
            if axins is not None:
                axins.plot(ts, Xs, color=color, alpha=alpha_traj, linewidth=0.5, zorder=1, **kwargs)
        ## model predictions of the data
        for i, (X, color) in enumerate(obs_vars):
            axs[i+1].plot(ts, series(xs, X), color=color, alpha=alpha_traj,
                          linewidth=0.5, zorder=1)

    ## format the inset once, after all lines are drawn, so that plotting
    ## cannot re-trigger autoscaling of the restricted limits
    if axins is not None:
        ## restrict limits of axins
        axins.set_xlim(*xlim_inset)
        if ylim_inset is not None:
            axins.set_ylim(*ylim_inset)
        #axins.yaxis.set_label_position("right")
        #axins.yaxis.tick_right()
        axins.tick_params(axis='both', which='major', labelsize='xx-small')
        ## dates as xticklabels
        xmin, xmax = xlim_inset
        xticks = range(xmin+1, xmax, 4)
        xtickdates = [date0 + datetime.timedelta(days=x) for x in xticks]
        xticklabels = [d.strftime("%b %d") for d in xtickdates]
        axins.set_xticks(xticks)
        axins.set_xticklabels(xticklabels, rotation=45, ha='right')

    ## re-format ticklabels for population sizes
    ax.yaxis.set_major_formatter(ticker.FuncFormatter(evplot.y_fmt))
    ax.tick_params(axis="y", labelsize=12)
    return axins

            
def plot_predictions(axs, pf_data, dds):
    """Draw prediction/filter summaries for "D" and "Fm" of the first ID.

    Parameters
    ----------
    axs : sequence of Axes
        ``axs[1]`` receives the "D" summaries, ``axs[2]`` the "Fm" ones.
    pf_data : dict
        Filter output; reads "pfIDs", "pred_medians", "filter_medians" and
        "ranges". The per-ID entries hold XML state elements with a "t"
        attribute and named "var_vec" children.
    dds : list of dict
        Data rows with keys "t" and "deaths_cc".

    The actual drawing is delegated to ``evplot.pfilter_boxplot``.
    """
    dt = 1
    varcolor = ['purple', 'tab:blue']
    obsvarnames = ['D', 'Fm']
    ID = pf_data["pfIDs"][0] ## select single ID
    pred_medians = pf_data["pred_medians"][ID]
    filter_medians = pf_data["filter_medians"][ID]
    rans = pf_data["ranges"][ID]
    ts = [float(x.attrib["t"]) for x in pred_medians]

    ## mask is True for time points without an uncensored death count; it
    ## does not depend on the variable, so compute it once (set membership
    ## instead of a list scan per time point)
    ## NOTE(review): the same death-based mask is also passed for "Fm" --
    ## confirm that this is intended
    observed_ts = {row["t"] for row in dds
                   if row["deaths_cc"] == defn.uncensored_code}
    mask = [t not in observed_ts for t in ts]

    def series(states, name):
        ## value of the named variable in each state element
        return [float(x.find(f"var_vec[@name='{name}']/var").attrib["val"])
                for x in states]

    for i, X in enumerate(obsvarnames):
        ax = axs[i+1]
        Xs_ran = [series(ran, X) for ran in rans]
        Xs_pred = series(pred_medians, X)
        Xs_filt = series(filter_medians, X)
        ## pass a fresh copy of the mask per call, as the original did
        evplot.pfilter_boxplot(ax, ts, Xs_ran, Xs_pred, Xs_filt, mask=list(mask),
                               color=varcolor[i], dt=dt)


In [None]:
## Three-panel figure: latent population sizes (A), death incidence (B)
## and mutant frequency (C), saved to ../data/out/SMCFit.pdf

## per-panel styling for panels B and C (indexed in that order)
data_markers = ['o', '|']
legend_locs = [1, 2]
data_colors = ['w', 'lightgray']
trajcolor = ["pink", "deepskyblue"]
varcolor = ['purple', 'tab:blue']

## date corresponding to t = 0
date0 = datetime.datetime.strptime("01/01/2020", "%m/%d/%Y")

fig, axs = plt.subplots(3,1, figsize=(7,10), sharex=True)

plot_data(axs, fdatadicts)
plot_trajectories(axs, pf_data, date0, color_wt='tab:blue', color_mut='tab:green')
plot_predictions(axs, pf_data, fdatadicts)

## dates in x-axis

days = [dd["t"] for dd in fdatadicts]
dates = [date0 + datetime.timedelta(days=d) for d in days]
## every second data row; presumably the rows are weekly, making this
## one tick every 2 weeks -- verify against the data file
xticks = days[::2]
xticklabels = [d.strftime("%b %d") for d in dates[::2]]

axs[-1].set_xlabel("date")
axs[-1].set_xticks(xticks)
axs[-1].set_xticklabels(xticklabels, fontsize='x-small', rotation=45, ha='right')
#axs[-1].set_ylim(-0.05, 1.05)


## add legends
leg = axs[0].legend(ncol=1, loc=1, fontsize='x-small')
## `Legend.legendHandles` was renamed to `legend_handles` in matplotlib 3.7
## and the old name was removed in 3.9; support both
legend_handles = getattr(leg, "legend_handles", None)
if legend_handles is None:
    legend_handles = leg.legendHandles
## the trajectories are thin and transparent; make the legend keys readable
for lh in legend_handles:
    lh.set_alpha(1)
    lh.set_linewidth(1)

for i, ax in enumerate(axs[1:]):
    ## Legend built from proxy artists instead of the plotted objects
    legend_elements = [
        Line2D([0], [0], marker=data_markers[i], color=data_colors[i], label='data',
               markerfacecolor='k', markeredgecolor='k', markersize=7),
        Line2D([0], [0], color=varcolor[i], label='model'),
    ]
    ax.legend(handles=legend_elements, ncol=1, fontsize='x-small', loc=legend_locs[i])

# y-labels
ylabs = ["population size", "death incidence", "mutant frequency"]
for ax, ylab in zip(axs, ylabs):
    ax.set_ylabel(ylab)

fig.align_ylabels(axs)

## add panel labels (A, B, C) to the upper left of each panel
dx_num = -0.15
subplot_labels = "ABC"

for i, ax in enumerate(axs):
    ax.text(dx_num, 1.05, subplot_labels[i], fontsize=24, transform=ax.transAxes)

fig.savefig("../data/out/SMCFit.pdf", bbox_inches='tight')

## Diagnostics and parameter estimates

In [None]:
parnames = pf_data["parnames"]

## one entry per particle-filter ID; the position r of the ID in
## pf_data["pfIDs"] is what pftools.makeParDict receives as first argument
fit_dicts = {
    ID: {
        "params" : pftools.makeParDict(r, parnames, pf_data["param_medians"],
                                       pf_data["param_ranges"])
        ## TODO: other elements?
    }
    for r, ID in enumerate(pf_data["pfIDs"])
}

## TODO: extract hyper parameters

In [None]:
## Collect per-time-point particle filter statistics for every ID.
## Each column is (key in stat_dicts[ID], XML attribute, conversion).
pf_stat_columns = [
    ("ts", "t", float),
    ("Jeffs", "J_eff", int),
    ("Jsims", "J_inv_simpson", float),
    ("Jcofs", "J_coffin", int),
    ("cLLs", "cond_LL_hat", float),
]

stat_dicts = {}

for ID in pf_data["pfIDs"]:
    filters = pf_data["particle_filters"][ID]
    stat_dicts[ID] = {
        key : [convert(pf.attrib[attr]) for pf in filters]
        for key, attr, convert in pf_stat_columns
    }

In [None]:
## plot PF statistics: one panel per ID with the conditional log-likelihood
## on the left axis and the effective swarm size on a twin right axis

R = len(pf_data["pfIDs"])

fig, axs = plt.subplots(R, 1, figsize=(10,3*R))

## plt.subplots returns a bare Axes (not an array) for a single panel
if R == 1:
    axs = np.array([axs])

bxs = [ax.twinx() for ax in axs]

for ID, ax, bx in zip(pf_data["pfIDs"], axs, bxs):
    stat_dict = stat_dicts[ID]
    ax.plot(stat_dict["ts"], stat_dict["cLLs"], color='k')
    ## plot PF statistics
    bx.plot(stat_dict["ts"], stat_dict["Jeffs"], color='tab:red', alpha=0.7)
    #bx.plot(stat_dict["ts"], stat_dict["Jsims"], color='tab:purple')
    ax.set_ylabel("conditional\nlog-likelihood")
    bx.set_ylabel("effective\nswarm size", color='red')

## share the right-hand y-axes after the fact.
## `get_shared_y_axes().join(...)` was deprecated in matplotlib 3.6 and
## removed in 3.8; `Axes.sharey` (matplotlib >= 3.3) is the replacement
for bx in bxs[1:]:
    bx.sharey(bxs[0])
bxs[0].autoscale(axis='y')

In [None]:
## Parameter estimates per iteration (one panel per parameter) and the
## log-likelihood per iteration (last panel) with a smoothing spline

from scipy.interpolate import UnivariateSpline

sel_parnames = parnames

fig, axs = plt.subplots(len(sel_parnames)+1, 1,
                        figsize=(14,2*len(sel_parnames)),
                        sharex=True)

## FIXME: plot a single run for the evolution of a parameter

for ID in pf_data["pfIDs"]:
    paramss = fit_dicts[ID]["params"]
    for i, pn in enumerate(sel_parnames):
        ax = axs[i]
        meds = paramss[pn]["median"]
        rans = paramss[pn]["range"]
        ms = range(len(meds))
        evplot.range_plots(ax, ms, *aux.unzip(rans), dt=0.1, zorder=1)
        ax.scatter(ms, meds, color='r', marker='_', zorder=2)
        ax.set_ylabel(pn)


## plot log-likelihood

ll_dicts = [xs.find("log_lik").attrib for xs in pf_data["iterf_steps"]]
ll_vals = [float(d["val"]) for d in ll_dicts]
## estimates flagged as non-finite in the filter output are drawn in red
ll_colors = ['k' if d["finite"] == 'true' else 'red' for d in ll_dicts]
ms = range(len(ll_dicts))
axs[-1].scatter(ms, ll_vals, color=ll_colors)
axs[-1].set_ylabel("LL")
axs[-1].set_xlabel("iteration")
#axs[-1].set_ylim(-500, -300)

## lower bound for log-likelihoods used in the spline; raise it to filter
## out some mistakes (np.infty was removed in NumPy 2.0, np.inf is the
## portable spelling)
lb = -np.inf
spline_points = [(m, ll) for m, ll in zip(ms, ll_vals)
                 if ll > lb and np.isfinite(ll)]

## the default cubic spline (k=3) needs at least 4 points, and the count
## that matters is the one after filtering
if len(spline_points) > 3:
    fms = [m for m, ll in spline_points]
    flls = [ll for m, ll in spline_points]
    cs = UnivariateSpline(fms, flls, s=1e4)
    xs = np.linspace(ms[0], ms[-1], 1000)
    axs[-1].plot(xs, cs(xs), label='spline', color='red', linewidth=2)

print("final LL:", ll_vals[-1])

In [None]:
## Print some parameter estimates...

## for every ID, keep the last entry of each parameter's list of medians
par_ests = {}
for ID in pf_data["pfIDs"]:
    fitted = fit_dicts[ID]["params"]
    par_ests[ID] = {name : est["median"][-1] for name, est in fitted.items()}

for ID, estimates in par_ests.items():
    for name, value in estimates.items():
        print(f"{ID} -> {name}: {value:0.3g}")
    print('-------')