# Figure for manuscript

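- with sequential Monte Carlo (SMC, particle filter) fits for two regions: the United Kingdom and the Netherlands
- the data are imported from the same data files that were used as input for the filtering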

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import re
import scipy.stats as sts
import xml.etree.ElementTree as ET
import warnings
import pickle
import copy
import csv
import datetime
import json

from matplotlib.patches import Patch
from matplotlib.lines import Line2D
import matplotlib.ticker as ticker

from mpl_toolkits.axes_grid1.inset_locator import inset_axes


import sys, importlib
sys.path.append("..")
from evpytools import evplot
from evpytools import pftools
from evpytools import auxiliary as aux
from evpytools import definitions as defn
for mod in [evplot, pftools, aux, defn]:
    importlib.reload(mod)

In [None]:
## use a large base font size for all figure text
plt.rc('font', size=18)

### Import data files

In [None]:
def import_filter_data(filename):
    """Read a whitespace-separated data table that was used as filter input.

    Each non-blank line is split on whitespace. It is converted into a dict
    with these keys (source column index in parentheses):
    ``region`` (0, str), ``t`` (1, int), ``event`` (2, str),
    ``deaths`` (5, int), ``deaths_cc`` (6, int), ``Nmut`` (7, int) and
    ``Ntot`` (9, int). Columns 3, 4 and 8 are ignored.

    Args:
        filename: path to the data file.

    Returns:
        A list of dicts, one per non-blank line, in file order.

    Raises:
        IndexError: if a non-blank line has fewer than 10 columns.
        ValueError: if a numeric column cannot be parsed as an int.
    """
    with open(filename, encoding="utf-8") as f:
        ## skip blank AND whitespace-only lines (e.g. trailing spaces at the
        ## end of the file); these would otherwise produce an empty row and
        ## crash with an IndexError below
        table = [fields for line in f if (fields := line.split())]
    ## build list of dicts
    return [
        {
            "region": row[0],
            "t": int(row[1]),
            "event": row[2],
            "deaths": int(row[5]),
            "deaths_cc": int(row[6]),
            "Nmut": int(row[7]),
            "Ntot": int(row[9]),
        }
        for row in table
    ]

In [None]:
## input data files (UK, NL) for the N501Y analysis.
## For D614G, use these files instead:
##   "../data/in/sars2-seq-death-week-United_Kingdom-614.tsv"
##   "../data/in/sars2-seq-death-week-Netherlands-614.tsv"
fdata_file_UK, fdata_file_NL = (
    "../data/in/sars2-seq-death-week-United_Kingdom-501.tsv",
    "../data/in/sars2-seq-death-week-Netherlands-501.tsv",
)

## parse both tables (UK first, then NL)
fdatadicts_UK, fdatadicts_NL = (
    import_filter_data(fname) for fname in (fdata_file_UK, fdata_file_NL)
)

### Import filter results

In [None]:
## particle-filter output files (UK, NL) for the N501Y analysis.
## For D614G, use these files instead:
##   "../data/out/wk-seq/UK/D614G/ipf_result-sars_model_UK-614-wk.xml"
##   "../data/out/wk-seq/NL/D614G/ipf_result-sars_model_NL-614-wk.xml"
pfout_file_UK, pfout_file_NL = (
    "../data/out/wk-seq/UK/B117/ipf_result-sars_model_UK-501-long-wk.xml",
    "../data/out/wk-seq/NL/B117/ipf_result-sars_model_NL-501-long-wk.xml",
)

## select one of the PF iterations
## NOTE(review): idx appears unused in the cells below -- confirm before removing
idx = -1

## extract the filter results (UK first, then NL)
pf_data_UK, pf_data_NL = (
    pftools.extract_pfilter_data(fname) for fname in (pfout_file_UK, pfout_file_NL)
)

### Create the figure

In [None]:
def plot_data(axs, dds):
    """Scatter the observed data onto the death and mutant-frequency panels.

    Args:
        axs: sequence of axes; ``axs[1]`` receives the death counts and
            ``axs[2]`` receives the mutant frequencies.
        dds: list of row dicts as returned by ``import_filter_data``.
    """
    death_ax, freq_ax = axs[1], axs[2]

    ## deaths: only rows whose censoring code marks them as uncensored
    uncensored = [rec for rec in dds if rec["deaths_cc"] == defn.uncensored_code]
    death_ax.scatter(
        [rec["t"] for rec in uncensored],
        [rec["deaths"] for rec in uncensored],
        color='k', edgecolor='k', zorder=4, label='data', s=20,
    )

    ## mutant frequency k/n with a 95% Jeffreys interval
    ## (quantiles of Beta(k + 1/2, n - k + 1/2)), only where n > 0
    sequenced = [rec for rec in dds if rec["Ntot"] > 0]
    times, freqs = [], []
    for rec in sequenced:
        k, n = rec["Nmut"], rec["Ntot"]
        lower = sts.beta.ppf(0.025, k + 0.5, n - k + 0.5)
        upper = sts.beta.ppf(0.975, k + 0.5, n - k + 0.5)
        ## vertical bar showing the interval
        freq_ax.plot([rec["t"]] * 2, [lower, upper], color='k', alpha=0.3)
        times.append(rec["t"])
        freqs.append(k / n)
    freq_ax.scatter(times, freqs, color='k', edgecolor='k', zorder=4, label='data',
                    s=20, marker='_')

    
def plot_trajectories(axs, pf_data, date0, xlim_inset=None, ylim_inset=None,
                      color_wt="tab:orange", color_mut="tab:blue"):
    """Plot the sampled particle-filter paths for latent and observed variables.

    Only the first ID in ``pf_data["pfIDs"]`` is used. Each path in
    ``pf_data["paths"][ID]`` is drawn as follows:

    * the latent variables ``Iw`` and ``Im`` go on ``axs[0]``, and also on
      an inset axes when ``xlim_inset`` is given;
    * the model predictions of the observables ``D`` and ``Fm`` go on
      ``axs[1]`` and ``axs[2]``, respectively.

    Args:
        axs: sequence of three axes (population sizes, deaths, mutant frequency).
        pf_data: dict from ``pftools.extract_pfilter_data``. It must provide
            ``"pfIDs"`` and ``"paths"``; each path is an XML element with
            ``state`` children carrying a ``t`` attribute and ``var_vec`` entries.
        date0: ``datetime`` added to the time values (in days) to produce the
            inset's date tick labels.
        xlim_inset: optional ``(xmin, xmax)`` pair of ints. When given, an
            inset showing this x-range is added to ``axs[0]``.
        ylim_inset: optional ``(ymin, ymax)`` for the inset. It is only used
            together with ``xlim_inset``.
        color_wt, color_mut: colors of the wild-type and mutant trajectories.

    Returns:
        The inset axes, or None when ``xlim_inset`` is None.
    """
    ID = pf_data["pfIDs"][0]  ## select single ID
    paths = pf_data["paths"][ID]
    alpha_traj = 0.7
    ## latent paths: (variable name, legend label, color)
    latent_vars = [
        ("Iw", "$I_{\\rm wt}$", color_wt),
        ("Im", "$I_{\\rm mt}$", color_mut),
    ]
    ## model predictions of the data: (variable name, color), drawn on axs[1], axs[2]
    obs_vars = [("D", "pink"), ("Fm", "deepskyblue")]

    def series(states, name):
        """Return the values of variable ``name`` across ``states``."""
        return [float(s.find(f"var_vec[@name='{name}']/var").attrib["val"])
                for s in states]

    ax = axs[0]
    axins = inset_axes(ax, width="15%", height="35%", loc=1) if xlim_inset is not None else None

    for j, path in enumerate(paths):
        ## parse the states of this path once and reuse them for every variable
        states = path.findall("state")
        ts = [float(s.attrib["t"]) for s in states]
        for X, lab, color in latent_vars:
            Xs = series(states, X)
            ## label only the first path so the legend has one entry per variable
            kwargs = {"label": lab} if j == 0 else {}
            ax.plot(ts, Xs, color=color, alpha=alpha_traj, linewidth=0.5, zorder=1, **kwargs)
            if axins is not None:
                axins.plot(ts, Xs, color=color, alpha=alpha_traj, linewidth=0.5, zorder=1, **kwargs)
        for i, (X, color) in enumerate(obs_vars):
            axs[i+1].plot(ts, series(states, X), color=color, alpha=alpha_traj,
                          linewidth=0.5, zorder=1)

    ## configure the inset once (previously repeated for every path and variable)
    if axins is not None:
        ## restrict limits of axins
        axins.set_xlim(*xlim_inset)
        if ylim_inset is not None:
            axins.set_ylim(*ylim_inset)
        axins.tick_params(axis='both', which='major', labelsize='xx-small')
        ## dates as xticklabels, every 4 days
        xmin, xmax = xlim_inset
        xticks = range(xmin+1, xmax, 4)
        xtickdates = [date0 + datetime.timedelta(days=x) for x in xticks]
        axins.set_xticks(xticks)
        axins.set_xticklabels([d.strftime("%b %d") for d in xtickdates],
                              rotation=45, ha='right')

    ## re-format ticklabels for population sizes
    ax.yaxis.set_major_formatter(ticker.FuncFormatter(evplot.y_fmt))
    ax.tick_params(axis="y", labelsize=12)
    return axins

            
def plot_predictions(axs, pf_data, dds):
    """Draw the particle-filter prediction summaries for the observables.

    For ``D`` and ``Fm`` this calls ``evplot.pfilter_boxplot`` on ``axs[1]``
    and ``axs[2]``, respectively. It passes the ranges, predictive medians
    and filter medians stored in ``pf_data`` for the first ID in
    ``pf_data["pfIDs"]``.

    Args:
        axs: sequence of axes; only ``axs[1]`` and ``axs[2]`` are used.
        pf_data: dict from ``pftools.extract_pfilter_data`` providing
            ``"pfIDs"``, ``"pred_medians"``, ``"filter_medians"`` and ``"ranges"``.
        dds: row dicts from ``import_filter_data``, used to build the mask.
    """
    dt = 1
    varcolor = ['purple', 'tab:blue']
    obsvarnames = ['D', 'Fm']
    ID = pf_data["pfIDs"][0]  ## select single ID
    pred_medians = pf_data["pred_medians"][ID]
    filter_medians = pf_data["filter_medians"][ID]
    rans = pf_data["ranges"][ID]
    ts = [float(x.attrib["t"]) for x in pred_medians]

    ## mask[k] is True when there is no uncensored death observation at ts[k].
    ## This is loop-invariant, so compute it once; a set makes the lookup O(1).
    ## NOTE(review): the same death-based mask is also passed for Fm, although
    ## frequency data exist wherever Ntot > 0 -- confirm this is intended.
    ws = {row["t"] for row in dds if row["deaths_cc"] == defn.uncensored_code}
    mask = [t not in ws for t in ts]

    def values(states, X):
        """Return the values of variable ``X`` across ``states``."""
        return [float(x.find(f"var_vec[@name='{X}']/var").attrib["val"]) for x in states]

    for i, X in enumerate(obsvarnames):
        Xs_ran = [values(ran, X) for ran in rans]
        Xs_pred = values(pred_medians, X)
        Xs_filt = values(filter_medians, X)
        ## pass a fresh copy of the mask per call, as before, in case the
        ## plotting helper modifies it
        evplot.pfilter_boxplot(axs[i+1], ts, Xs_ran, Xs_pred, Xs_filt, mask=list(mask),
                               color=varcolor[i], dt=dt)


In [None]:
## markers used in the legends of the death and mutant-frequency panels
data_markers = ['o', '|']
## legend locations for the death and mutant-frequency panels
## D614G
#legend_locs = [1, 4]
## N501Y
legend_locs = [2, 2]
data_colors = ['w', 'lightgray']
trajcolor = ["pink", "deepskyblue"]
varcolor = ['purple', 'tab:blue']

## x-ranges of the insets (only used by the D614G variant below)
xlim_inset_UK = (65,75)
xlim_inset_NL = (55,65)

## calendar date that the time values (in days) are counted from
date0 = datetime.datetime.strptime("01/01/2020", "%m/%d/%Y")

## rows: population sizes, deaths, mutant frequency; columns: UK, NL
fig, axs = plt.subplots(3,2, figsize=(14,10), sharex=True)

## UK
plot_data(axs[:,0], fdatadicts_UK)
## D614G: add inset
#axins_UK = plot_trajectories(axs[:,0], pf_data_UK, date0, xlim_inset=xlim_inset_UK, ylim_inset=(0,20000))
## N501Y: no inset
axins_UK = plot_trajectories(axs[:,0], pf_data_UK, date0, color_wt='tab:blue', color_mut='tab:green')
plot_predictions(axs[:,0], pf_data_UK, fdatadicts_UK)

## NL
plot_data(axs[:,1], fdatadicts_NL)
## D614G: add inset
#axins_NL = plot_trajectories(axs[:,1], pf_data_NL, date0, xlim_inset=xlim_inset_NL, ylim_inset=(0,1000))
## N501Y: no inset
axins_NL = plot_trajectories(axs[:,1], pf_data_NL, date0, color_wt='tab:blue', color_mut='tab:green')
plot_predictions(axs[:,1], pf_data_NL, fdatadicts_NL)

axs[0,0].set_title("United Kingdom")
axs[0,1].set_title("Netherlands")

## dates in x-axis (taken from the UK data; the x-axes are shared)

days = [dd["t"] for dd in fdatadicts_UK]
dates = [date0 + datetime.timedelta(days=d) for d in days]
xticks = days[::2] ## every 2 weeks
xticklabels = [d.strftime("%b %d") for d in dates[::2]]

for i in range(2):
    axs[-1,i].set_xlabel("date")
    axs[-1,i].set_xticks(xticks)
    axs[-1,i].set_xticklabels(xticklabels, fontsize='x-small', rotation=45, ha='right')
    axs[-1,i].set_ylim(-0.05, 1.05)

## add legends
leg = axs[0,0].legend(ncol=1, loc=2, fontsize='x-small')
## Legend.legendHandles was renamed to Legend.legend_handles in Matplotlib 3.7
## and the old name was removed in 3.9; support both versions
legend_handles = getattr(leg, "legend_handles", None)
if legend_handles is None:
    legend_handles = leg.legendHandles
for lh in legend_handles:
    ## the trajectories are thin and transparent; make the legend entries visible
    lh.set_alpha(1)
    lh.set_linewidth(1)

for i, ax in enumerate(axs[1:,0]):
    ## Legend: a proxy "data" marker and a "model" line for each panel
    legend_elements = [
        Line2D([0], [0], marker=data_markers[i], color=data_colors[i], label='data', 
               markerfacecolor='k', markeredgecolor='k', markersize=7),
        Line2D([0], [0], color=varcolor[i], label='model'),
    ]
    ax.legend(handles=legend_elements, ncol=1, fontsize='x-small', loc=legend_locs[i])  

# y-labels
ylabs = ["population size", "death incidence", "mutant frequency"]
for ax, ylab in zip(axs[:,0], ylabs):
    ax.set_ylabel(ylab)

fig.align_ylabels(axs[:,0])

## add panel labels A-F (row-major order)
dx_num = -0.15
subplot_labels = "ABCDEF"

for i, ax in enumerate(axs.flatten()):
    ax.text(dx_num, 1.05, subplot_labels[i], fontsize=24, transform=ax.transAxes)


fig.savefig("../data/out/figures/Fig2RegionsFit.pdf", bbox_inches='tight')