# Profile Likelihood for the SARS-CoV-2 variant model

Compute the profile likelihood using only the data with $t \leq t_{\max}$, repeated for a range of values of $t_{\max}$.

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import re
import scipy.stats as sts
import xml.etree.ElementTree as ET
import pickle
import glob
import os
from scipy.interpolate import UnivariateSpline
from scipy.optimize import minimize_scalar, root_scalar
import json


import sys, importlib
sys.path.append("..")
from evpytools import evplot
from evpytools import auxiliary as aux
from evpytools import definitions as defn
## re-import the local evpytools modules so that edits made to them on disk
## are picked up without restarting the kernel
for mod in (evplot, aux, defn):
    importlib.reload(mod)

In [None]:
## use a larger default font for all figures in this notebook
plt.rcParams['font.size'] = 18

In [None]:
## number of repeated LL estimates at end of IPF
dupl = 5

## grid of sigma values; one result file is expected per (tmax, sigma) pair
sigmas = np.linspace(-0.1, 0.8, 19)

## dataset selection: keep exactly one (tmaxs, filename_base) pair uncommented

#tmaxs = np.linspace(355, 439, 7)
#filename_base = "../data/out/ipf_result-sars_model_Netherlands-B.1.351"

tmaxs = np.linspace(278, 362, 7)
filename_base = "../data/out/ipf_result-sars_model_United_Kingdom-B.1.1.7"

#tmaxs = np.linspace(326, 438, 9)
#filename_base = "../data/out/ipf_result-sars_model_Netherlands-B.1.1.7"


def _result_file(tmax, sigma):
    """Return the path of the result file for one (tmax, sigma) pair."""
    return f"{filename_base}_tmax={tmax:g}_sigma={sigma:g}.xml"


## outer index runs over tmaxs, inner index over sigmas
filenames = [[_result_file(t, s) for s in sigmas] for t in tmaxs]

## check that all files exist: print the path of every file that is missing
for path in (p for row in filenames for p in row):
    if os.path.isfile(path):
        continue
    print(path)

In [None]:
## extract final loglikes
def _final_loglikes(xml_path, n_final):
    """Read the last `n_final` log-likelihood estimates from one result file.

    Returns a pair (values, finite_flags): the "val" attribute of each
    log_lik element converted to float, and whether its "finite" attribute
    equals 'true', both restricted to the last `n_final`
    iterated_filtering_step elements of the file.
    """
    steps = ET.parse(xml_path).getroot().findall("iterated_filtering_step")
    attribs = [step.find("log_lik").attrib for step in steps]
    values = [float(a["val"]) for a in attribs]
    finite_flags = [a["finite"] == 'true' for a in attribs]
    return values[-n_final:], finite_flags[-n_final:]


## LLss[i][j] holds the final LL estimates for tmaxs[i] and sigmas[j];
## LLvalidss has the same layout and holds the matching "finite" flags
LLss = []
LLvalidss = []

for files in filenames:
    per_file = [_final_loglikes(file, dupl) for file in files]
    LLss.append([values for values, _ in per_file])
    LLvalidss.append([flags for _, flags in per_file])

In [None]:
## write the extracted log-likelihoods and both parameter grids to a json
## file so that another notebook can load them

## keep the output file in sync with the dataset selected via filename_base
#output_file = "../data/out/profile-lik-tmax-results_Netherlands_B.1.351.json"
output_file = "../data/out/profile-lik-tmax-results_United_Kingdom-B.1.1.7.json"
#output_file = "../data/out/profile-lik-tmax-results_Netherlands-B.1.1.7.json"

## numpy arrays are not json-serializable, so the grids are converted to lists
result_dict = dict(
    LLss=LLss,
    tmaxs=list(tmaxs),
    sigmas=list(sigmas),
)

with open(output_file, 'w') as f:
    json.dump(result_dict, f)

## Compute CIs for each tmax

In [None]:
## test figure: inspect raw results
##
## For each tmax, fit a smoothing spline to the mean log-likelihood as a
## function of sigma, locate its maximum, and find where the spline drops DL
## below that maximum on either side (the 95% CI bounds). The LL profile is
## drawn sideways, anchored at x = tmax.

fig, ax = plt.subplots(1, 1, figsize=(7,4), sharex=True)


## compute and plot CIs

## grid points with a mean LL more than max_diff below the best LL are dropped
max_diff = 10
## LL drop that bounds the 95% CI: half the 0.95 quantile of chi2 with 1 df
DL = sts.chi2.ppf(0.95,1)/2

for i, LLs in enumerate(LLss):
    ## compute means over the repeated LL estimates of each sigma
    meanLLs = np.mean(LLs, axis=1)

    ## remove very small LLs
    ## NOTE(review): the cutoff uses the max over all individual estimates
    ## (np.max(LLs)), not the max of the means -- confirm this is intended
    sigs, lls = aux.unzip([(s, l) for s, l in zip(sigmas, meanLLs)
                           if l >= np.max(LLs)-max_diff])

    bounds = (sigs[0], sigs[-1])

    ## ext='raise' makes the spline refuse to extrapolate outside bounds
    cs = UnivariateSpline(sigs, lls, s=10, ext='raise')
    xs = np.linspace(*bounds, 250)

    ## find max of spline and CI
    res = minimize_scalar(lambda x: -cs(x), bounds=bounds, method='bounded')
    max_LL = -res.fun
    sigma_opt = res.x

    ## shifted so that the spline maximum sits at x = tmax + DL; the curve
    ## then crosses the vertical line x = tmax exactly at the CI bounds
    ax.plot(cs(xs)-max_LL+tmaxs[i]+DL, xs, label='spline', color='k', linewidth=2)

    print(f"s_opt = {sigma_opt:0.2f}")
    print(f"max LL = {max_LL:0.2f}")

    def ll_above_threshold(x):
        """Spline LL minus the CI threshold (max_LL - DL); zero at a CI bound."""
        return cs(x) - max_LL + DL

    ## root_scalar raises ValueError when the bracket does not enclose a sign
    ## change (e.g. the profile never drops DL below the maximum inside the
    ## sigma range). Only these expected failures are caught, so interrupts
    ## and genuine programming errors are no longer silently swallowed.
    try:
        lres = root_scalar(ll_above_threshold, bracket=[sigs[0], sigma_opt])
        rres = root_scalar(ll_above_threshold, bracket=[sigma_opt, sigs[-1]])
    except (ValueError, RuntimeError):
        print("unable to compute CI!")
    else:
        lCI = lres.root
        rCI = rres.root

        print(f"95% CI = [{lCI:0.2f}, {rCI:0.2f}]")

        #ax.axhspan(lCI, rCI, color='k', alpha=0.2, linewidth=0)

    ## plot dots: the mean LLs the spline was fitted to, shifted like the spline
    ax.scatter(np.array(lls)-max_LL+tmaxs[i]+DL, sigs, color='k', s=5)
    ax.axvline(x=tmaxs[i], color='k', alpha=0.4)


ax.set_ylabel("$s$")
ax.set_xlabel("max time (days)")

fig.savefig("../data/prof-lik-maxt.pdf")