# Profile Likelihood for the SARS-CoV-2 variant model

Compute the profile likelihood with data restricted such that $t \leq t_{\max}$.

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import re
import scipy.stats as sts
import xml.etree.ElementTree as ET
import scipy.stats as sts
import pickle
import glob
import os
from scipy.interpolate import UnivariateSpline
from scipy.optimize import minimize_scalar, root_scalar


import sys, importlib
sys.path.append("..")
from evpytools import evplot
from evpytools import auxiliary as aux
from evpytools import definitions as defn
for mod in [evplot, aux, defn]:
    importlib.reload(mod)

In [None]:
## use a larger default font so that figure labels remain legible
plt.rcParams['font.size'] = 18

In [None]:
## number of repeated LL estimates at end of IPF
dupl = 2

## grid of sigma values (36 points on [0, 0.7]) and of tmax values (3 points
## on [355, 383]) for which IPF result files are expected
sigmas = np.linspace(0.0, 0.7, 36)
tmaxs = np.linspace(355, 383, 3)

## filenames[i][j] is the result file for tmaxs[i] and sigmas[j]
filenames = [
    [
        f"../data/out/ipf_result-sars_model_Netherlands-B.1.351_tmax={tmax:g}_sigma={sigma:g}.xml"
        for sigma in sigmas
    ]
    for tmax in tmaxs
]

## print the path of every expected result file that is missing on disk
for files in filenames:
    for file in filter(lambda path: not os.path.isfile(path), files):
        print(file)

In [None]:
## extract final loglikes
## LLss[i][j] holds the last `dupl` log-likelihood values found in
## filenames[i][j]; LLvalidss[i][j] holds the matching "finite" flags.
LLss = []
LLvalidss = []

for files in filenames:
    ## LLs / LLvalids are deliberately left as notebook-level names: after the
    ## loop they refer to the last row of files, and other cells may read them
    LLs = []
    LLvalids = []
    for path in files:
        ## every <iterated_filtering_step> element carries a <log_lik> child
        ## whose attributes "val" and "finite" are read here
        steps = ET.parse(path).getroot().findall("iterated_filtering_step")
        attribs = [step.find("log_lik").attrib for step in steps]
        values = [float(attrib["val"]) for attrib in attribs]
        flags = [attrib["finite"] == 'true' for attrib in attribs]
        ## keep only the trailing `dupl` entries of each trace
        LLs.append(values[-dupl:])
        LLvalids.append(flags[-dupl:])
    LLss.append(LLs)
    LLvalidss.append(LLvalids)

In [None]:
## export the LL values to a file
## Each output row is "sigma <TAB> LL <TAB> valid-flag"; there is no tmax
## column, so the file can only describe a single tmax. The last tmax is
## selected explicitly here, instead of depending on loop variables that
## happen to be left behind by whichever other cell was executed most recently.
export_LLs = LLss[-1]
export_LLvalids = LLvalidss[-1]

with open("../data/out/prof-lik-tmax-test.tsv", 'w') as f:
    for s, xs, bs in zip(sigmas, export_LLs, export_LLvalids):
        ## one row per repeated LL estimate for this sigma
        for x, b in zip(xs, bs):
            f.write(f"{s}\t{x}\t{b}\n")

## Compute CIs for each tmax

In [None]:
## test figure: inspect raw results
## For every tmax one column of points is drawn: the x-coordinate is tmax plus
## the median log-likelihood relative to the best median (so the best sigma
## sits at x = tmax and all other points lie to its left); the y-coordinate
## is sigma.

fig, ax = plt.subplots(1, 1, figsize=(7,4))

## NB: the loop variable rebinds the notebook-level name LLs
for i, LLs in enumerate(LLss):
    ## compute medians: axis 1 runs over the repeated LL estimates kept for
    ## each sigma. The "finite" flags in LLvalidss are not consulted in the
    ## lines shown here.
    medianLLs = np.median(LLs, axis=1)
    maxLL = np.max(medianLLs)
    ## shift so that the maximum lands exactly on x = tmaxs[i]
    ax.scatter(medianLLs-maxLL+tmaxs[i], sigmas, color='k', s=2)
    ## reference line 2 log-likelihood units below the maximum
    ## NOTE(review): presumably a rough likelihood-ratio CI threshold -- confirm
    ax.axvline(x=tmaxs[i]-2, color='k', alpha=0.4)