# Count SNVs in wastewater samples from ShoRAH output

In [None]:
from IPython.display import display, HTML

# Widen the notebook container to the full browser width.
_wide_container_css = "<style>.container { width:100% !important; }</style>"
display(HTML(_wide_container_css))

In [None]:
import numpy as np
import pandas as pd
import os
import glob
import re
import gzip
import csv
import strictyaml
from Bio import SeqIO
from BCBio import GFF
import subprocess
from termcolor import colored
from tqdm.notebook import tqdm, trange
import pysam

## Globals

A few general variables describing where to find the inputs and where to write the outputs. Adapt them to your own setup.

In [None]:
# Inputs
voc_dir = "../voc/"  # where COJAC stores its variants' YAMLs
vpipe_working = "working"  # V-pipe's working directory
ww_samples_tsv = f"{vpipe_working}/samples.wastewateronly.tsv"  # samples TSV file listing the wastewater samples

# Alternative samples TSV (by its name, presumably only last week's samples):
# ww_samples_tsv = f"{vpipe_working}/samples.wastewateronly.lastweek.tsv"
genes_gff = f"{vpipe_working}/references/gffs/Genes_NC_045512.2.GFF3"  # genes table

# Outputs
muttable_tsv = "mutlist.txt"  # TSV list of all the mutations searched for
tables_dir = "snv_tables"  # directory receiving one <sample>-<batch>_snv.csv per sample
# exist_ok makes this idempotent and race-free (no separate isdir check needed)
os.makedirs(tables_dir, mode=0o775, exist_ok=True)

## Define mutations to look for

This is done by parsing each variant's YAML file (as used by COJAC) and the mutation positions listed therein.

In [None]:
# Decode a mutation string from a variant YAML. Exactly one alternative matches:
#   "[REF>]ALT" -> groups ref (optional) and mut, e.g. "A>T", "AC>GT", "T"
#   "---"       -> group del (one "-" per deleted base)
#   "+ACG"      -> group ins (inserted bases)
# Raw string: the previous non-raw literal relied on invalid escape sequences
# ("\>", "\-", "\+"), which newer Python versions warn about.
rxmutdec = re.compile(
    r"^(?:(?:(?:(?P<ref>[ATCG]+)>)?(?P<mut>[ATCG]+))|(?P<del>-+)|(?:\+(?P<ins>[ATGC]+)))$"
)

In [None]:
def _expand_mutation(yp, pos, mutstr, category, short):
    """Expand one YAML mutation entry into one record per affected nucleotide.

    Parameters:
        yp: path of the YAML file (only used in messages)
        pos: start position of the mutation, as listed in the YAML (converted with int())
        mutstr: mutation string decoded with rxmutdec ("A>T", "T", "---", "+ACG", ...)
        category: YAML section the entry comes from ("mut", "extra", "shared", "subset")
        short: the variant's short name, used as the column holding the category
    Return:
        list of dicts with keys position/reference/variant/<short>; empty when the
        entry cannot be parsed or is an insertion (both are reported and skipped).
    """
    if not (res := rxmutdec.match(mutstr)):
        print(f"{yp}:{pos} cannot parse {mutstr}")
        return []
    match = res.groupdict()
    if match["ins"]:
        print(f"{yp}:{pos} insertions not supported (yet): {match['ins']}")
        return []
    if match["mut"]:
        alts = list(match["mut"])
    elif match["del"]:
        # TODO this is wrong and will be fixed in ShoRAH
        alts = ["-"] * len(match["del"])
    else:
        return []
    # the reference base is only known with "REF>ALT" notation; otherwise "N"
    ref = match["ref"] or ""
    return [
        {
            "position": int(pos) + i,
            "reference": ref[i] if i < len(ref) else "N",
            "variant": alt,
            short: category,
        }
        for i, alt in enumerate(alts)
    ]


_VAR_COLUMNS = ["position", "reference", "variant"]
vartable = pd.DataFrame(data={col: [] for col in _VAR_COLUMNS}).astype(
    {"position": "int"}
)
for yp in glob.glob(os.path.join(voc_dir, "*.yaml")):  # auto skips .hidden
    print(yp)
    with open(yp, "r", encoding="utf-8") as yf:
        yam = strictyaml.dirty_load(yf.read(), allow_flow_style=True).data
    short = yam["variant"]["short"]
    # collect all records first and build the DataFrame once, instead of growing
    # it with pd.concat for every single row (quadratic)
    records = []
    # all categories (we don't care, we will compare across samples)
    for c in ["mut", "extra", "shared", "subset"]:
        if c in yam:
            for pos, mutstr in yam[c].items():
                records.extend(_expand_mutation(yp, pos, mutstr, c, short))
    muts = pd.DataFrame(records, columns=[*_VAR_COLUMNS, short]).astype(
        {"position": "int"}
    )
    # outer merge on the columns both tables share (position/reference/variant,
    # assuming variant short names are unique): one category column per variant
    vartable = vartable.merge(how="outer", right=muts, sort=True)
with pd.option_context("display.max_rows", None):
    display(vartable)

## Add genes

In [None]:
# Reserve an empty "gene" column right after position/reference/variant.
if "gene" not in vartable.columns:
    vartable.insert(3, "gene", [""] * len(vartable))

In [None]:
# Annotate every mutation with the name of the gene (from the GFF) covering it.
vartable["gene"] = ""
if genes_gff:
    with open(genes_gff) as gf:
        for record in GFF.parse(gf):
            for feature in record.features:
                if feature.type != "gene":
                    continue
                # Biopython locations are 0-based and half-open [start, end), whereas
                # the mutation positions are assumed 1-based (as in GFF/VCF); a 1-based
                # position p therefore lies in the gene iff start < p <= end.
                # (The previous `>= start` also tagged the base just before the gene.)
                start = int(feature.location.start)
                end = int(feature.location.end)
                mask = (vartable["position"] > start) & (vartable["position"] <= end)
                # overlapping genes: a later feature overwrites an earlier one
                vartable.loc[mask, "gene"] = feature.qualifiers.get(
                    "Name", [feature.id]
                )[0]
display(vartable)

In [None]:
# Write the mutation list to the output path declared in the globals cell.
vartable.to_csv(muttable_tsv, sep="\t", index=False, na_rep="NA")

## Functions

In [None]:
_RX_WINDOW_RANGE = re.compile(r"([0-9]+)-([0-9]+)\.reads")


def extract_range(filename):
    """Extract the window range from a ShoRAH window filename.

    The range is read from the first "<start>-<stop>.reads" part of the name,
    e.g. a name containing "...-1-201.reads..." yields (1, 201). Full paths are
    accepted, since the whole string is searched.

    Returns:
        (start, stop) as a tuple of ints.
    Raises:
        ValueError: if the name contains no "<start>-<stop>.reads" part.
    """
    match = _RX_WINDOW_RANGE.search(filename)
    if match is None:
        raise ValueError(f"cannot extract a window range from {filename!r}")
    return (int(match.group(1)), int(match.group(2)))

In [None]:
def count_snvs(filename, shorah_table):
    """Produce a (local haplotypes x SNVs in the window) table of 0/1 SNV presence.

    Parameters:
        filename: str, path of the fasta.gz file of one ShoRAH window; the window
            range is parsed from the name with extract_range
        shorah_table: mutation table (e.g. vartable) with "position", "reference"
            and "variant" columns
    Return:
        df_out: pd.DataFrame with one row per local haplotype (indexed by the fasta
            record description) and one column per SNV whose position falls in the
            window (named <reference><position><variant>), holding 1 if the
            haplotype has the variant base at that position and 0 otherwise;
            None if no SNV falls in the window.
    """
    # extract range of window from filename
    seqstart, seqstop = extract_range(filename)
    # subset rows of the table for snv's falling in that range
    in_window = (shorah_table["position"] >= seqstart) & (
        shorah_table["position"] <= seqstop
    )
    shorah_table_subset = shorah_table[in_window]
    # stop there and return None if no snv's fall in that range
    if shorah_table_subset.shape[0] == 0:
        return None

    # offsets into the haplotype sequences and expected bases: loop invariants
    offsets = shorah_table_subset["position"].values - seqstart
    variants = shorah_table_subset["variant"].values

    window_names = []
    window_lst = []
    with gzip.open(filename, "rt") as f:
        # iterate through local haplos
        for record in SeqIO.parse(f, "fasta"):
            window_names.append(record.description)
            # test, for each snv in the window, whether this local haplo carries it
            window_lst.append(
                [record.seq[off] == var for off, var in zip(offsets, variants)]
            )

    # reshape keeps the 2-D layout even when the fasta file holds no haplotype
    haplos_array = (
        np.array(window_lst, dtype=bool).reshape(len(window_names), len(variants)) * 1
    )
    snv_names = (
        shorah_table_subset["reference"]
        + shorah_table_subset["position"].astype("str")
        + shorah_table_subset["variant"]
    )
    df_out = pd.DataFrame(haplos_array, columns=snv_names, index=window_names)
    return df_out

## List all wastewater samples

In [None]:
# Collect, for every (sample, batch) row of the wastewater samples TSV, the ShoRAH
# support directory, keeping only those that exist.
with open(
    ww_samples_tsv, "rt", encoding="utf-8", newline=""
) as tf:  # this file has the same content as the original experiment
    ww_sampledirs = [
        d
        for d in (
            f"{vpipe_working}/samples/{row[0]}/{row[1]}/variants/SNVs/REGION_1/support/"
            for row in csv.reader(tf, delimiter="\t")
            # skip blank/short lines instead of failing to unpack (sample, batch)
            if len(row) >= 2
        )
        if os.path.isdir(d)
    ]
ww_sampledirs

In [None]:
# Paths of every entry in the first sample's support directory (shown as cell output).
first_sample_dir = ww_sampledirs[0]
temp_dirlist = [first_sample_dir + entry for entry in os.listdir(first_sample_dir)]
temp_dirlist

In [None]:
# Prints "NO" when vartable has no "helo" column (looks like a leftover debugging check).
if "helo" not in vartable.columns:
    print("NO")

## Do it for one mutation

In [None]:
min_posterior = 0.8
mut_number = 0
temp_dirlist = [ww_sampledirs[0] + i for i in os.listdir(ww_sampledirs[0])]

# address the mutation positionally (iloc) everywhere, so that its name and its
# position always come from the same row
mut_row = vartable.iloc[mut_number]
mut_pos = mut_row["position"]
mut_name = mut_row["reference"] + str(mut_pos) + mut_row["variant"]

# find all snv tables for one mutation
tmp_snvcounts = []
candidate_windows = 0  # keep track of candidate windows
for win in temp_dirlist:
    strt, stp = extract_range(win)
    if strt <= mut_pos <= stp:
        candidate_windows += 1
        try:
            tmp_snvcounts.append(count_snvs(win, vartable)[mut_name])
        except KeyError:
            pass  # this window's table has no column for the mutation

# per window, sum the ave_reads of haplotypes with posterior > min_posterior that
# carry the mutation
ave_reads_full_lst = []
for snv_col in tmp_snvcounts:
    ave_reads_lst = []
    # iterate (name, value) pairs: an integer key on this string-indexed Series
    # would be a deprecated positional lookup
    for haplo_name, present in snv_col.items():
        posterior = float(
            re.search(r"posterior=([0-1][\.]{0,1}[0-9]{0,})", haplo_name).group(1)
        )
        ave_reads = float(
            re.search(r"ave_reads=([0-9]+[\.]{0,1}[0-9]{0,})", haplo_name).group(1)
        )
        if posterior > min_posterior and present == 1:
            ave_reads_lst.append(ave_reads)
    ave_reads_full_lst.append(sum(ave_reads_lst))
# as in search_one_mut: only windows where the mutation got support are effective
effective_windows = sum(1 for total in ave_reads_full_lst if total > 0)
# an average over zero windows is meaningless; fall back to 0 like search_one_mut
ave_r = np.average(ave_reads_full_lst) if ave_reads_full_lst else 0

In [None]:
_RX_POSTERIOR = re.compile(r"posterior=([0-1][\.]{0,1}[0-9]{0,})")
_RX_AVE_READS = re.compile(r"ave_reads=([0-9]+[\.]{0,1}[0-9]{0,})")


def _supporting_ave_reads(snv_col, min_posterior):
    """Sum the ave_reads of haplotypes with posterior > min_posterior carrying the SNV.

    snv_col is one column of count_snvs' output: 0/1 values indexed by the ShoRAH
    haplotype descriptions, from which "posterior=" and "ave_reads=" are parsed.
    """
    total = 0
    # iterate (name, value) pairs: an integer key on this string-indexed Series
    # would be a deprecated positional lookup
    for haplo_name, present in snv_col.items():
        posterior = float(_RX_POSTERIOR.search(haplo_name).group(1))
        ave_reads = float(_RX_AVE_READS.search(haplo_name).group(1))
        if posterior > min_posterior and present == 1:
            total += ave_reads
    return total


def search_one_mut(temp_dirlist, vartable, mut_number, min_posterior=0.9):
    """Look for mutation number (mut_number) of (vartable) in (temp_dirlist).

    Parameters:
        temp_dirlist: paths of one sample's ShoRAH window files
        vartable: mutation table with "position", "reference" and "variant" columns
        mut_number: positional (0-based) row number of the mutation in vartable
        min_posterior: haplotypes with a posterior <= this value are ignored
    Return:
        (candidate_windows, effective_windows, ave_r):
            candidate_windows: number of windows whose range covers the position
            effective_windows: number of windows where the summed ave_reads of the
                retained haplotypes carrying the mutation is > 0
            ave_r: mean of those per-window sums over the windows whose table has
                a column for the mutation; 0 (with a warning) if there is none
    """
    mut_row = vartable.iloc[mut_number]
    mut_pos = mut_row["position"]
    mut_name = mut_row["reference"] + str(mut_pos) + mut_row["variant"]

    # find all snv tables for one mutation
    candidate_windows = 0  # keep track of candidate windows
    snv_cols = []
    for win in temp_dirlist:
        strt, stp = extract_range(win)
        if not strt <= mut_pos <= stp:
            continue
        candidate_windows += 1
        try:
            snv_cols.append(count_snvs(win, vartable)[mut_name])
        except KeyError:
            pass  # this window's table has no column for the mutation

    ave_reads_full_lst = [_supporting_ave_reads(col, min_posterior) for col in snv_cols]
    effective_windows = sum(1 for total in ave_reads_full_lst if total > 0)

    if not ave_reads_full_lst:
        # name the sample directory (two levels above a window file) in the warning
        warnname = (
            os.sep.join(str(temp_dirlist[0]).split(os.sep)[:-2])
            if temp_dirlist
            else "<no windows>"
        )
        print(f"Warning! Can't average in {warnname}")
        return (candidate_windows, effective_windows, 0)

    return (candidate_windows, effective_windows, np.average(ave_reads_full_lst))

In [None]:
def search_all_mut(temp_dirlist, vartable, min_posterior=0.9):
    """Search every mutation of vartable in one sample's ShoRAH windows.

    Parameters:
        temp_dirlist: paths of the sample's ShoRAH window files
        vartable: mutation table, one row per SNV
        min_posterior: passed on to search_one_mut
    Return:
        a new DataFrame with vartable's columns followed by "candidate_windows",
        "effective_windows" and "ave_reads", one row per mutation.
    """
    rows = [
        search_one_mut(temp_dirlist, vartable, i, min_posterior)
        for i in range(vartable.shape[0])
    ]
    # reshape keeps the (n_mutations, 3) layout even for an empty vartable
    arr1 = np.array(rows).reshape(len(rows), 3)
    # reuse vartable's index so the column-wise concat aligns row by row even
    # when vartable does not carry a default RangeIndex
    temp_df = pd.DataFrame(
        arr1,
        columns=["candidate_windows", "effective_windows", "ave_reads"],
        index=vartable.index,
    )
    return pd.concat([vartable, temp_df], axis=1)

## Make outputs for all mutations

In [None]:
# Search all mutations in every wastewater sample and write one CSV per sample.
all_mut_dfs = []
for sample in tqdm(ww_sampledirs):
    # check if ShoRAH did output windows there
    if not os.path.isdir(sample):
        print(f"Warning! No windows in {sample}!!!")
        continue

    temp_dirlist = [os.path.join(sample, entry) for entry in os.listdir(sample)]
    if not temp_dirlist:
        print(f"Warning! No windows in {sample}!!!")
        continue

    mut_df = search_all_mut(temp_dirlist, vartable, min_posterior=0.9)
    all_mut_dfs.append(mut_df)
    # sample dirs end with .../{sample}/{batch}/variants/SNVs/REGION_1/support/;
    # counting from the end is independent of the working-dir prefix and of os.sep
    spl = os.path.normpath(sample).split(os.sep)
    mut_df.to_csv(
        os.path.join(tables_dir, f"{spl[-6]}-{spl[-5]}_snv.csv"), na_rep="NA"
    )

# Double-checking code snippet

In [None]:
# Display the second wastewater sample directory for a quick manual check
# (raises IndexError when fewer than two sample directories were found).
ww_sampledirs[1]

### variants

In [None]:
# all_mut_dfs only holds tables for the samples that the output loop did not skip,
# so rebuild the matching list of sample directories with the same skip conditions
# (missing directory or no entries) before pairing them up; indexing ww_sampledirs
# directly misaligns as soon as one sample was skipped.
processed_sampledirs = [
    d for d in ww_sampledirs if os.path.isdir(d) and len(os.listdir(d)) > 0
]
for target_dir, mut_df in zip(processed_sampledirs, all_mut_dfs):
    window_files = [target_dir + f for f in os.listdir(target_dir)]
    # print positions lacking an ave_reads value and flag every window covering them
    for pos_to_check in mut_df.loc[mut_df["ave_reads"].isna(), "position"]:
        print(pos_to_check)
        for d in window_files:
            strt, stop = extract_range(d)
            if strt <= pos_to_check <= stop:
                print("PROBLEM")