In [None]:
from IPython.display import display, HTML

# Stretch the notebook container to the full browser width, so that the wide
# figures produced further down are not clipped.
display(
    HTML("<style>.container { width:100% !important; }</style>")
)

# Reference width from which the figsize of the design-matrix heatmaps is derived.
plotwidth = 40

In [None]:
from WwDec.main import *

import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import ListedColormap

# Globals

In [None]:
# Colour associated with each variant name (hex RGB strings), including the
# catch-all "undetermined" entry.
#
# Source of inspiration from covariants, see:
# https://github.com/hodcroftlab/covariants/blob/master/web/data/clusters.json
#
# Keep in sync with covspectrum, see:
# https://github.com/GenSpectrum/cov-spectrum-website/blob/develop/src/models/wasteWater/constants.tsv
color_map = {
    "B.1.1.7": "#D16666",
    "B.1.351": "#FF6665",
    "P.1": "#FFB3B3",
    "B.1.617.1": "#66C265",
    "B.1.617.2": "#66A366",
    "BA.1": "#A366A3",
    "BA.2": "#cfafcf",
    "BA.4": "#8a66ff",
    "BA.5": "#585eff",
    "BA.2.12.1": "#0400e0",
    "BA.2.75": "#008fe0",
    # NOTE(review): "improv" presumably marks colours improvised locally rather
    # than taken from the upstream lists above -- confirm before syncing.
    "BA.2.75.2": "#208fe0",  # improv
    "BQ.1.1": "#8fe000",  # improv
    "XBB": "#dd6bff",
    "undetermined": "#969696",
}

In [None]:
# Overwrite globals set by WwDec.main (temporary globals).

import yaml

# Input tally table and output directory.
# Reading the .zst-compressed table needs python's Zstandard package; the
# uncompressed alternative is "./tallymut_line.tsv".
tally_data = "./work-vp-test/variants/tallymut.tsv.zst"
out_dir = "./out"

# Load the variants configuration.
with open("work-vp-test/variant_config.yaml", "r") as file:
    conf_yaml = yaml.safe_load(file)

# "variants_pangolin" is the only mandatory entry; every other entry is
# optional and falls back to None or to an empty list.
variants_list = conf_yaml.get("variants_list")
variants_pangolin = conf_yaml["variants_pangolin"]
variants_not_reported = conf_yaml.get("variants_not_reported", [])
start_date = conf_yaml.get("start_date")
# Optional: usually absent in ongoing surveillance, and present in articles
# working on a subset of historical data.
end_date = conf_yaml.get("end_date")

to_drop = conf_yaml.get("to_drop", [])
locations_list = conf_yaml.get("locations_list")

# Merge the variant dates into the same configuration dictionary.
with open("work-vp-test/var_dates.yaml", "r") as file:
    conf_yaml.update(yaml.safe_load(file))

if not variants_list:
    # No explicit list was configured: collect every variant named in any of
    # the var_dates entries, de-duplicated and sorted.
    variants_list = sorted(
        {var for lst in conf_yaml["var_dates"].values() for var in lst}
    )
    conf_yaml["variants_list"] = variants_list


# display the current config
conf_yaml

# Load and preprocess data

In [None]:
# Read the tab-separated tally table: the "date" column is parsed as dates and
# "location_code" is kept as a string.
df_tally = pd.read_csv(
    tally_data,
    sep="\t",
    dtype={"location_code": "str"},
    parse_dates=["date"],
)
# Preview of the first rows.
df_tally.head()

In [None]:
if not locations_list:
    # No explicit list in the configuration: fall back to every location that
    # occurs in the tally table.
    #
    # Missing values are removed with dropna() instead of subtracting np.nan
    # from a set: set membership of NaN relies on object identity, so the
    # subtraction can silently leave a NaN in the list. Empty strings are
    # removed as well.
    #
    # The result is sorted so that its order (and therefore the order of the
    # per-location panels plotted further down) is the same on every run,
    # rather than following arbitrary set iteration order.
    conf_yaml["locations_list"] = locations_list = sorted(
        set(df_tally["location"].dropna().unique()) - {""}
    )
    display(locations_list)

In [None]:
# Sanity check, shown as the cell output: columns of the tally table that are
# neither a key of variants_pangolin nor one of the expected fixed columns.
set(df_tally.columns).difference(
    variants_pangolin.keys(),
    {
        "base",
        "batch",
        "cov",
        "date",
        "frac",
        "gene",
        "location_code",
        "location",
        "pos",
        "proto",
        "sample",
        "var",
    },
)

In [None]:
# Generic preprocessing of the raw tally table.
# NOTE(review): the `to_drop` and `end_date` values read from the configuration
# are not passed on here (to_drop is hard-coded) -- confirm this is intended.
preproc = DataPreprocesser(df_tally).general_preprocess(
    variants_list=variants_list,
    variants_pangolin=variants_pangolin,
    variants_not_reported=variants_not_reported,
    to_drop=["subset"],
    start_date=start_date,
    remove_deletions=True,
)
t_df_tally = preproc.df_tally

# Mutation filtering is applied to every protocol except "v41": set the v41
# rows aside, filter the remaining rows, then put both parts back together.
# The three printed shapes are: v41 rows, filtered non-v41 rows, recombined.
df_tally_v41 = t_df_tally[t_df_tally.proto == "v41"]
print(df_tally_v41.shape)

preproc.df_tally = t_df_tally[t_df_tally.proto != "v41"]
preproc = preproc.filter_mutations()
print(preproc.df_tally.shape)

preproc.df_tally = pd.concat([preproc.df_tally, df_tally_v41])
print(preproc.df_tally.shape)

# Look at design of mutations

In [None]:
# Design matrix: one row per distinct entry of the "mutations" column, one
# column per variant of variants_list plus "undetermined".
des_matrix = (
    preproc.df_tally.loc[:, variants_list + ["undetermined", "mutations"]]
    .drop_duplicates(subset="mutations")
    .set_index("mutations")
)
# Split the rows on the "-" prefix of their mutation label.
# NOTE(review): going by the variable names, "-"-prefixed labels presumably
# denote the wild-type counterpart of a mutation -- confirm.
des_matrix_wt = des_matrix.loc[lambda m: m.index.str.startswith("-")]
des_matrix_mut = des_matrix.loc[lambda m: ~m.index.str.startswith("-")]

In [None]:
# Two stacked heatmaps of the transposed design matrix (variants as rows,
# mutations as columns), drawn with a white/red two-colour map: the top panel
# holds the labels without the "-" prefix, the bottom panel those with it.
fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(plotwidth * 0.5, plotwidth / 8))
cmap_binary = ListedColormap(["white", "red"])

for ax, matrix in zip(axes, (des_matrix_mut, des_matrix_wt)):
    sns.heatmap(matrix.T, square=True, cmap=cmap_binary, cbar=False, ax=ax)

plt.show()

In [None]:
# Condition number of the full design matrix (all rows, all variant columns
# including "undetermined"), printed with two decimals.
print(f"condition number = {np.linalg.cond(des_matrix):.2f}")

In [None]:
from sklearn.metrics.pairwise import pairwise_distances

# Three views of how similar the variants' mutation profiles are, computed on
# the rows of the design matrix whose label does not start with "-":
#   1. number of mutations shared by each pair of variants,
#   2. correlation between the variants' columns,
#   3. Jaccard similarity between the variants' mutation sets.
fig, axes = plt.subplots(1, 3, figsize=(22, 7))

# (variants x mutations) . (mutations x variants): entry (i, j) is the sum over
# mutations of the product of the two variants' design-matrix entries.
common_mut = des_matrix_mut.T.dot(des_matrix_mut)
sns.heatmap(common_mut, square=True, cmap="viridis", annot=common_mut, ax=axes[0])
axes[0].set_title("common mutations")

corr_mut = des_matrix_mut.corr()
sns.heatmap(
    corr_mut, square=True, cmap="viridis", annot=corr_mut, ax=axes[1], fmt=".1g"
)
axes[1].set_title("correlation")

# Jaccard similarity |A∩B| / |A∪B|, as announced by the panel title.
# This used to be computed as 1 - hamming distance, which also counts the
# mutations that BOTH variants lack as agreement and therefore is not the
# Jaccard index. The matrix is cast to bool because "jaccard" is a boolean
# metric (assumes 0/1 indicator entries -- TODO confirm).
jac_sim = 1 - pairwise_distances(
    des_matrix_mut.T.astype(bool).values, metric="jaccard"
)
jac_sim = pd.DataFrame(
    jac_sim, index=des_matrix_mut.columns, columns=des_matrix_mut.columns
)
sns.heatmap(jac_sim, square=True, cmap="viridis", annot=jac_sim, ax=axes[2])
axes[2].set_title("jaccard similarity ((A∩B)/(A∪B))")

# plt.show() rather than fig.show(): the latter only emits a "non-GUI backend"
# warning with the notebook's inline backend.
plt.show()

In [None]:
# Locations for which the per-date design-matrix diagnostics are computed in
# the following cell; by default, every entry of locations_list.
locations_1 = locations_list
# Alternative kept for reference: restrict the analysis to a hand-picked subset
# of locations, and print the entries missing from a reference list.
# NOTE(review): `cities_list` is not defined anywhere in the visible notebook.
# locations_1 = ['Lugano (TI)',
#               'Zürich (ZH)',
#               'Genève (GE)',
#               'Chur (GR)',
#               'Altenrhein (SG)',
#               'Laupen (BE)',
#               'Lausanne (Vidy)',
#               'Sion (VS)',
#               'Porrentruy (JU)',
#               'Basel (catchment area ARA Basel)']
# print(set(locations_1)-set(cities_list))

In [None]:
# Imported here as well so that this cell does not silently depend on an
# import located in the middle of an earlier cell.
from sklearn.metrics.pairwise import pairwise_distances


def _max_off_diagonal(square):
    """Return the largest off-diagonal entry of a square matrix, ignoring NaNs.

    The computation is done on a float copy: the caller's data is never
    modified, and it keeps working when pandas hands out read-only arrays
    (in-place edits of ``DataFrame.values`` fail under Copy-on-Write).
    """
    arr = np.array(square, dtype=float)  # np.array copies by default
    np.fill_diagonal(arr, np.nan)
    return np.nanmax(arr)


# For every (protocol, location, date): build the design matrix restricted to
# the mutations observed with a coverage of at least 5, and record how well
# conditioned it is and how similar the most similar pair of variants is.
# The result is a list with one DataFrame (indexed by date) per
# (protocol, location) pair.
all_conds_df = []
for proto in preproc.df_tally.proto.unique():
    for location in locations_1:
        # Despite its "_zh" suffix, this holds the rows of whichever location
        # is currently being processed.
        t_df_tally_zh = preproc.df_tally[
            (preproc.df_tally.location == location)
            & (preproc.df_tally.proto == proto)
            & (preproc.df_tally["cov"] >= 5)
        ]
        sample_dates = t_df_tally_zh.date.unique()

        conds = []
        for date in sample_dates:
            # Loop-local names on purpose: the per-date matrices must not
            # overwrite the global des_matrix / des_matrix_mut / jac_sim /
            # corr_mut that the cells above build and display.
            design = (
                t_df_tally_zh[t_df_tally_zh.date == date][
                    variants_list + ["undetermined", "mutations"]
                ]
                .drop_duplicates("mutations")
                .set_index("mutations")
            )
            design_mut = design[~design.index.str.startswith("-")]

            if design_mut.shape[0] == 0:
                # No row without the "-" prefix on this date: the pairwise
                # similarities are undefined (and pairwise_distances would
                # raise on a matrix without features).
                max_jac = np.nan
                max_corr = np.nan
            else:
                # Jaccard similarity |A∩B| / |A∪B| between the variants'
                # mutation sets. This used to be 1 - hamming distance, which
                # also counts mutations absent from both variants and is not
                # the Jaccard index the "max_jac" column is named after.
                # Assumes 0/1 indicator entries -- TODO confirm.
                jaccard = 1 - pairwise_distances(
                    design_mut.T.astype(bool).values, metric="jaccard"
                )
                max_jac = _max_off_diagonal(jaccard)
                max_corr = _max_off_diagonal(design_mut.corr())

            conds.append(
                {
                    "n_mut": design_mut.shape[0],
                    "cond_number": np.linalg.cond(design),
                    # Needs BA.1, BA.2, BA.4 and BA.5 to be columns of the
                    # design matrix, i.e. to be part of variants_list.
                    "cond_number_omicron": np.linalg.cond(
                        design[["BA.1", "BA.2", "BA.4", "BA.5"]]
                    ),
                    "max_jac": max_jac,
                    "max_corr": max_corr,
                    "location": location,
                }
            )

        # One row per date; the dates become the (unnamed) index.
        conds_df = pd.DataFrame(conds, index=sample_dates)
        conds_df["proto"] = proto
        all_conds_df.append(conds_df)

In [None]:
# Stack the per-(protocol, location) tables into a single frame; reset_index()
# moves the unnamed date index into a regular column called "index".
all_conds_df_conc = pd.concat(all_conds_df).reset_index()
all_conds_df_conc.head()

In [None]:
# One panel per location showing the maximum pairwise similarity ("max_jac")
# over time, with one line per protocol.
#
# The panels are laid out on a two-column grid whose number of rows follows
# the number of locations. The grid used to be fixed at 5x2, which raised an
# IndexError with more than 10 locations and left blank panels with fewer.
# With exactly 10 locations the figure is the same as before (14 x 12).
plot_locations = all_conds_df_conc.location.unique()
n_cols = 2
n_rows = max(1, -(-len(plot_locations) // n_cols))  # ceiling division
fig, axes = plt.subplots(
    n_rows, n_cols, figsize=(14, 2.4 * n_rows), squeeze=False
)
axes = axes.flatten()

for i, location in enumerate(plot_locations):
    tmp_df = all_conds_df_conc[all_conds_df_conc.location == location]
    # "index" is the date column created by reset_index() in the cell above.
    h = sns.lineplot(data=tmp_df, x="index", y="max_jac", hue="proto", ax=axes[i])
    # Only show the period starting on 2021-12-01.
    h.set_xlim(left=np.datetime64("2021-12-01"))
    axes[i].set_title(location)
    axes[i].set_ylabel("max jaccard sim")
    axes[i].set_xlabel("")

    for tick in axes[i].get_xticklabels():
        tick.set_rotation(45)

# Hide the grid cells that did not receive a location.
for unused_ax in axes[len(plot_locations):]:
    unused_ax.set_visible(False)

fig.tight_layout()  # Or equivalently,  "plt.tight_layout()"