In [None]:
from IPython.display import HTML, display

# Widen the notebook container to the full browser width so the large
# multi-panel figures below are not squeezed.
display(HTML("<style>.container { width:100% !important; }</style>"))

# Base figure width; the figures below derive their size from it.
plotwidth = 40

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd
import yaml
from tqdm.notebook import tqdm, trange
import time
import os
import sys

from lollipop import *


# Globals

In [None]:
# Colour used for each variant in the plots.
#
# Source of inspiration: covariants, see:
# https://github.com/hodcroftlab/covariants/blob/master/web/data/clusters.json
#
# Keep in sync with cov-spectrum, see:
# https://github.com/cevo-public/cov-spectrum-website/blob/develop/src/models/wasteWater/constants.ts
#
# Entries tagged "improv" presumably do not come from the references above
# (improvised locally) — TODO confirm when syncing with cov-spectrum.
color_map = dict(
    [
        ("B.1.1.7", "#D16666"),
        ("B.1.351", "#FF6665"),
        ("P.1", "#FFB3B3"),
        ("B.1.617.1", "#66C265"),
        ("B.1.617.2", "#66A366"),
        ("BA.1", "#A366A3"),
        ("BA.2", "#cfafcf"),
        ("BA.4", "#8a66ff"),
        ("BA.5", "#585eff"),
        ("BA.2.12.1", "#0400e0"),
        ("BA.2.75", "#008fe0"),
        ("BA.2.75.2", "#208fe0"),  # improv
        ("BQ.1.1", "#8fe000"),  # improv
        ("undetermined", "#969696"),
    ]
)

In [None]:
# Load the data and configuration needed by Lollipop.
# NOTE: these are (temporary) notebook-wide globals.

# Inputs. Reading the `.zst` files requires python's `zstandard` package.
tally_data = "./work-vp-test/variants/tallymut.tsv.zst"  # alternative: "./tallymut_line.tsv"
vpipe_deconv_data = "./work-vp-test/variants/deconvoluted.tsv.zst"
vpipe_kdec_config = "work-vp-test/deconv_linear_logit_quasi_strat.yaml"  # alternative: "work-vp-test/deconv_linear_wald.yaml"

# Outputs.
out_dir = "./out"
plots_dir = out_dir
dump_csv = "linear_deconv_logit_quasi_strat.csv.zst"


def _read_yaml(path):
    """Parse the YAML file at `path` with PyYAML's safe loader and return its content."""
    with open(path, "r") as fh:
        return yaml.safe_load(fh)


# Date ranges and the variants to deconvolute within each of them.
var_dates = _read_yaml("work-vp-test/var_dates.yaml")

# Variants configuration.
conf_yaml = _read_yaml("work-vp-test/variant_config.yaml")
variants_list = conf_yaml["variants_list"]
variants_pangolin = conf_yaml["variants_pangolin"]
variants_not_reported = conf_yaml["variants_not_reported"]
start_date = conf_yaml["start_date"]
# Optional (None when absent): usually absent in ongoing surveillance, and
# present in articles that use a subset of historical data.
end_date = conf_yaml.get("end_date")
to_drop = conf_yaml["to_drop"]
cities_list = conf_yaml["locations_list"]

# Display the current configuration.
conf_yaml

In [None]:
### Outputs
# Directory receiving the generated plots. `exist_ok=True` makes re-running
# this cell harmless, while a pre-existing *file* with that name raises
# FileExistsError here instead of being silently ignored (which would only
# surface later as a confusing failure when saving a plot).
plots_dir = 'deconv_plots'
os.makedirs(plots_dir, mode=0o775, exist_ok=True)
# Path of the lollipop update-data JSON file, relative to the working directory.
update_data_lolli_file = os.path.join('.', 'ww_update_data_lollipop.json')

# Load and preprocess data

In [None]:
# Tally of mutations per sample; `location_code` is kept as a string.
df_tally = pd.read_csv(
    tally_data,
    sep="\t",
    dtype={"location_code": "str"},
    parse_dates=["date"],
)

# Lollipop preprocessing (configured variants, date range, dropped entries,
# deletions removed), followed by its mutation filter.
preproc = (
    DataPreprocesser(df_tally)
    .general_preprocess(
        variants_list=variants_list,
        variants_pangolin=variants_pangolin,
        variants_not_reported=variants_not_reported,
        to_drop=to_drop,
        start_date=start_date,
        end_date=end_date,
        remove_deletions=True,
    )
    .filter_mutations()
)
preproc.df_tally.head()

In [None]:
# Shape of the array of distinct sample names left after preprocessing,
# i.e. (number of samples,).
preproc.df_tally["sample"].unique().shape

## Have a look at the design matrix / variants list

In [None]:
# Design matrix: one row per mutation, one column per variant plus "undetermined".
des_matrix = (
    preproc.df_tally[variants_list + ["undetermined", "mutations"]]
    .drop_duplicates("mutations")
    .set_index("mutations")
)

# Mutations whose name starts with "-" (presumably reverted / wild-type
# positions — TODO confirm) are separated from the actual mutations.
_is_wt = des_matrix.index.str.startswith("-")
des_matrix_mut = des_matrix[~_is_wt]
des_matrix_wt = des_matrix[_is_wt]


In [None]:
# Heatmap of the full design matrix: variants as rows, mutations as columns.
fig_des, ax_des = plt.subplots(figsize=(plotwidth, plotwidth / 4))
sns.heatmap(des_matrix.T, square=False, cmap="viridis", ax=ax_des)

In [None]:
# Three views of how much the variants' mutation sets overlap (only the
# mutations whose name does not start with "-").
fig, axes = plt.subplots(1, 3, figsize=(plotwidth, plotwidth / 4))

# Pairwise dot products of the variants' columns: with 0/1 entries (assumed),
# the number of shared mutations, and mutations per variant on the diagonal.
common_mut = des_matrix_mut.T.dot(des_matrix_mut)
# fmt="g" keeps counts >= 100 readable (seaborn's default ".2g" prints 1.5e+02).
sns.heatmap(common_mut, square=True, cmap="viridis", annot=common_mut, fmt="g", ax=axes[0])
axes[0].set_title("common mutations")

# Pearson correlation between the variants' columns (pandas' default method).
corr_mut = des_matrix_mut.corr()
sns.heatmap(corr_mut, square=True, cmap="viridis", annot=corr_mut, ax=axes[1], fmt=".1g")
axes[1].set_title("correlation")

# Jaccard similarity |A∩B| / |A∪B| between the variants' mutation sets.
# Note: `1 - hamming` would be the simple-matching coefficient, which also
# counts mutations absent from *both* variants as agreement and so inflates
# the similarity of every pair. Entries > 0 are treated as "mutation present".
from sklearn.metrics.pairwise import pairwise_distances
mut_presence = des_matrix_mut.T.to_numpy() > 0
jac_sim = 1 - pairwise_distances(mut_presence, metric="jaccard")
jac_sim = pd.DataFrame(jac_sim, index=des_matrix_mut.columns, columns=des_matrix_mut.columns)
sns.heatmap(jac_sim, square=True, cmap="viridis", annot=jac_sim, ax=axes[2])
axes[2].set_title("jaccard similarity ((A∩B)/(A∪B))")

# plt.show() renders through the inline backend; fig.show() only emits a
# "non-GUI backend" warning there.
plt.show()

# Multiple-choice time:

## CHOICE 1: Load V-pipe's deconvolution 

Use this choice to load the TSV dump of the deconvolution produced by running Lollipop inside V-pipe.

In [None]:
# Show which V-pipe deconvolution dump is reused and when it was last modified.
print(f"reusing {vpipe_deconv_data} last modified: {time.ctime(os.path.getmtime(vpipe_deconv_data))}")

In [None]:
# Load V-pipe's deconvolution dump (TSV); its first column is the date index.
linear_deconv2_quasi2_df_flat = pd.read_csv(
    vpipe_deconv_data,
    sep="\t",
    index_col=0,
    parse_dates=[0],
).rename_axis("date")
linear_deconv2_quasi2_df_flat

In [None]:
# Show when V-pipe's kernel-deconvolution configuration was last modified.
print(f"vpipe's {vpipe_kdec_config} last modified: {time.ctime(os.path.getmtime(vpipe_kdec_config))}")

In [None]:
# Parse V-pipe's kernel-deconvolution configuration (safe YAML loader) and display it.
with open(vpipe_kdec_config, "r") as fh:
    kdconf_yaml = yaml.safe_load(fh)
kdconf_yaml

## CHOICE 2: Run deconvolution

Use this choice if something has gone wrong with V-pipe and you need to tweak the model.
Once you are happy with your fixes, do not forget to update the kernel deconvolution configuration used by V-pipe!

### Logit parametrised confidence intervals with quasilikelihood, stratified

In [None]:
# Build consecutive (start, end) pairs from the keys of var_dates["var_dates"];
# the last interval is open-ended (end = None). The keys must be listed in
# chronological order in the YAML file, which is checked below.
interval_starts = list(var_dates["var_dates"].keys())
date_intervals = list(zip(interval_starts, interval_starts[1:] + [None]))
for mindate, maxdate in date_intervals:
    if maxdate is None:
        print(f"from {mindate} onward: {var_dates['var_dates'][mindate]}")
        continue
    # The message is only evaluated when the check fails, so it must not use
    # undefined names (that would mask the AssertionError with a NameError).
    assert mindate < maxdate, (
        f"out of order dates: {mindate} >= {maxdate}. "
        "Please fix the content of work-vp-test/var_dates.yaml"
    )
    print(f"from {mindate} to {maxdate}: {var_dates['var_dates'][mindate]}")


In [None]:
%%time

# Deconvolute each location separately and, within it, each date interval
# using the variant set configured for that interval in var_dates.
np.random.seed(42)
linear_deconv2_quasi2 = []

for city in tqdm(cities_list):
    temp_dfb = preproc.df_tally[preproc.df_tally["location"] == city]

    for mindate, maxdate in tqdm(date_intervals, desc=city):
        # Rows of this location in [mindate, maxdate); the last interval is
        # open-ended. pd.Timestamp() accepts date strings as well as the
        # datetime.date objects YAML yields for unquoted dates, and compares
        # cleanly with the datetime64 "date" column (newer pandas versions may
        # reject comparisons between datetime64 data and plain datetime.date).
        in_interval = temp_dfb.date >= pd.Timestamp(mindate)
        if maxdate is not None:
            in_interval = in_interval & (temp_dfb.date < pd.Timestamp(maxdate))
        temp_df2 = temp_dfb[in_interval]
        if temp_df2.empty:
            continue

        t_kdec = KernelDeconv(
            temp_df2[var_dates['var_dates'][mindate] + ["undetermined"]],
            temp_df2["frac"],
            temp_df2["date"],
            # weights=temp_df2["resample_value"],  # optional weighting, currently disabled
            kernel=GaussianKernel(30),
            reg=NnlsReg(),
            confint=WaldConfint(
                scale="logit",
                pseudofrac=0.01,
                quasi=True,
                method="strat",
            ),
        )
        t_kdec = t_kdec.deconv_all(min_tol=1e-3)

        # Keep the point estimate and both Wald confidence bands, each tagged
        # with its location and estimate name.
        for estimate, result in (
            ("MSE", t_kdec.fitted),
            ("Wald_lower", t_kdec.conf_bands["lower"]),
            ("Wald_upper", t_kdec.conf_bands["upper"]),
        ):
            res = result.copy()
            res["location"] = city
            res["estimate"] = estimate
            linear_deconv2_quasi2.append(res)

if not linear_deconv2_quasi2:
    raise ValueError("no deconvolution result: no data for any location/date interval")

# DataFrame has no `sort` attribute (that method was removed from pandas long
# ago); sort on the date index instead. mergesort is stable, so rows sharing a
# date keep their location/estimate order.
linear_deconv2_quasi2_df = pd.concat(linear_deconv2_quasi2).sort_index(kind="mergesort")

# Variants actually present in the results, in the configured order (a set
# intersection would give an arbitrary, run-dependent column order).
found_var = [v for v in variants_list if v in linear_deconv2_quasi2_df.columns]

# Long format: one row per (date, location, estimate, variant).
linear_deconv2_quasi2_df_flat = linear_deconv2_quasi2_df.melt(
    id_vars=["location", "estimate"],
    value_vars=found_var + ["undetermined"],
    var_name="variant",
    value_name="frac",
    ignore_index=False,
)
linear_deconv2_quasi2_df_flat.index.names = ["date"]

# Back up the results (pandas infers the compression from the file extension).
linear_deconv2_quasi2_df_flat.to_csv(dump_csv, index_label="date")
deconv_backup_df = linear_deconv2_quasi2_df_flat

In [None]:
# Display the flattened deconvolution results (from whichever choice produced them).
linear_deconv2_quasi2_df_flat

In [None]:
# Column dtypes of the flattened results (the "date" index is not listed here).
linear_deconv2_quasi2_df_flat.dtypes

### Reload the dump

In [None]:
# Show when the local CSV dump of the deconvolution was last written.
print(f"dump {dump_csv} last modified: {time.ctime(os.path.getmtime(dump_csv))}")

In [None]:
# Reload the CSV dump written by the deconvolution run. Parse the "date" index
# as datetimes, like the loader of V-pipe's TSV does; otherwise the dates stay
# plain strings and the plots' time axis treats them as text.
linear_deconv2_quasi2_df_flat = pd.read_csv(dump_csv, index_col=0, parse_dates=[0])
linear_deconv2_quasi2_df_flat.index.names = ["date"]
linear_deconv2_quasi2_df_flat

# End of multiple choices

Now move on to the plots, etc.

## Plots

In [None]:
def logit_inv(x):
    """Inverse of the logit transform (logistic sigmoid): map logit-scale values to (0, 1)."""
    return np.exp(x)/(1+np.exp(x))


locations = linear_deconv2_quasi2_df_flat["location"].unique()
variants = linear_deconv2_quasi2_df_flat["variant"].unique()

# One panel per location, three per row: at least a 3x3 grid, with extra rows
# when there are more than nine locations (a fixed 3x3 grid raised IndexError).
ncols = 3
nrows = max(3, int(np.ceil(len(locations) / ncols)))
fig, axes = plt.subplots(
    nrows=nrows,
    ncols=ncols,
    figsize=(plotwidth, plotwidth / 6 * nrows),
    sharex=True,
    squeeze=False,
)
axes = axes.flatten()

# Colour per variant; variants missing from color_map fall back to grey
# instead of failing in seaborn's palette lookup or in color_map[var].
palette = {variant: color_map.get(variant, "grey") for variant in variants}

for i, city in enumerate(locations):
    ax = axes[i]
    ax.set_title(city)

    for var in variants:
        subset = linear_deconv2_quasi2_df_flat[
            (linear_deconv2_quasi2_df_flat["variant"] == var)
            & (linear_deconv2_quasi2_df_flat["location"] == city)
        ]
        if subset.empty:
            continue  # this variant has no estimates at this location

        # One row per date, with a ("frac", <estimate>) column per estimate.
        tt_df = subset.reset_index().pivot(
            index=["date", "location", "variant"],
            columns="estimate",
        ).reset_index().sort_values(by='date')

        sns.lineplot(
            x=tt_df["date"],
            y=tt_df["frac"]["MSE"],
            hue=tt_df["variant"],
            ax=ax,
            palette=palette,
        )
        # The Wald bands are inverted from the logit scale (clipped so np.exp
        # stays finite); skipped when the results carry no bands.
        if {"Wald_lower", "Wald_upper"} <= set(tt_df["frac"].columns):
            ax.fill_between(
                x=tt_df["date"],
                y1=logit_inv(np.clip(tt_df["frac"]["Wald_lower"], -100, 100)),
                y2=logit_inv(np.clip(tt_df["frac"]["Wald_upper"], -100, 100)),
                alpha=0.2,
                color=palette[var],
            )

# One legend entry per variant, collected over all panels (the last panel
# alone may lack some variants). The per-panel legends seaborn adds are
# removed because the figure legend covers them.
legend_entries = {}
for ax in axes[:len(locations)]:
    for handle, label in zip(*ax.get_legend_handles_labels()):
        legend_entries.setdefault(label, handle)
    if ax.get_legend() is not None:
        ax.get_legend().remove()
fig.legend(
    list(legend_entries.values()),
    list(legend_entries),
    loc='lower center',
    ncol=max(1, len(legend_entries)),
    bbox_to_anchor=(0.5, 0.05),
)
# NOTE(review): "$k=10$" is hard-coded; confirm it matches the kernel actually
# used (choice 2 runs GaussianKernel(30); choice 1 uses V-pipe's configuration).
fig.suptitle('Gaussian Kernel Deconvolution ($k=10$)')
# plt.savefig(os.path.join(plots_dir, f"combined-linear.pdf"))

In [None]:
# Scratch note (pd.to_datetime usage example), kept for reference:
# df["Date of Birth"] = pd.to_datetime(df["Date of Birth"])

In [None]:
# Dtypes of `tt_df`, the last location/variant frame built by the plotting cell
# (e.g. to check whether "date" is a datetime column).
tt_df.dtypes

In [None]:
# Experiment: coerce the "date" column of the last plotted frame to datetime.
# Work on a copy: a plain `fix = tt_df` only aliases the frame, so the
# assignment below would silently mutate `tt_df` as well.
fix = tt_df.copy()
fix["date"] = pd.to_datetime(fix["date"])
fix.dtypes

In [None]:
# Scratch note (pd.to_datetime usage example). The cell used to start with a
# dangling `fix=`, a SyntaxError that stopped "Restart & Run All".
# df["Date of Birth"] = pd.to_datetime(df["Date of Birth"])