# Count SNVs in wastewater samples from ShoRAH output

In [None]:
import numpy as np
import pandas as pd
import os
import re
import gzip
import csv
from Bio import SeqIO
import subprocess
from IPython.core.display import display, HTML
from termcolor import colored
from tqdm.notebook import tqdm, trange
import pysam

## Functions

In [None]:
def extract_range(filename):
    '''Extract the window range from a ShoRAH window filename.

    The range is the first ``<start>-<stop>.reads`` pattern found in the base
    name of ``filename``, e.g. a name containing ``1-201.reads`` gives
    ``(1, 201)``.

    Parameters:
        filename: str (or path-like) path or bare name of a ShoRAH window file
    Return:
        (start, stop): tuple of int window boundaries as written in the name
    Raises:
        ValueError: if the base name contains no ``<start>-<stop>.reads`` range
    '''
    # Only look at the base name so digits in parent directories can never be
    # mistaken for the window range. Raw string + escaped dot: ".reads" is
    # matched literally and no invalid-escape warning is emitted.
    basename = os.path.basename(filename)
    match = re.search(r'([0-9]+)-([0-9]+)\.reads', basename)
    if match is None:
        raise ValueError(f"No '<start>-<stop>.reads' window range in {filename!r}")
    return (int(match.group(1)), int(match.group(2)))

In [None]:
def count_snvs(filename, shorah_table):
    '''Build a local-haplotype x SNV presence table for one ShoRAH window.

    Parameters:
        filename: str name of the fasta.gz file for the ShoRAH window; the
            window range is parsed from it with extract_range
        shorah_table: pd.DataFrame with "reference", "position" (int) and
            "variant" columns, one row per SNV to look for
    Return:
        df_out: pd.DataFrame of 0/1 values with one row per local haplotype
            (indexed by the fasta record description) and one column per SNV
            falling inside the window range (named reference+position+variant),
            or None if no SNV of shorah_table falls in the window range
    '''
    # extract range of window from filename
    seqstart, seqstop = extract_range(filename)
    # subset rows of shorah table for snv's falling in that range (inclusive)
    in_window = (shorah_table["position"] >= seqstart) & (shorah_table["position"] <= seqstop)
    shorah_table_subset = shorah_table[in_window]
    # stop there and return None if no snv's fall in that range
    if shorah_table_subset.shape[0] == 0:
        return None

    # Per-SNV values are the same for every haplo: compute them once.
    # Position p in the window is index p - seqstart in the haplo sequence.
    offsets = shorah_table_subset["position"].values - seqstart
    variants = shorah_table_subset["variant"].values

    window_names = []
    window_lst = []
    with gzip.open(filename, 'rt') as f:
        # iterate through local haplos
        for record in SeqIO.parse(f, "fasta"):
            window_names.append(record.description)
            # test, for every snv of the window, if it is present in this haplo
            window_lst.append([record.seq[off] == var for off, var in zip(offsets, variants)])

    # The reshape keeps a 2-D (n_haplos, n_snvs) array even when the file has
    # no records; np.array([]) alone is 1-D and DataFrame() would then raise.
    haplos_array = np.array(window_lst, dtype=int).reshape(len(window_names), len(variants))
    snv_names = shorah_table_subset["reference"] + \
        shorah_table_subset["position"].astype('str') + \
        shorah_table_subset["variant"]
    df_out = pd.DataFrame(haplos_array, columns=snv_names, index=window_names)

    return df_out

## Define mutations to look for

In [None]:
# Parses mutation strings such as "C3267T" or "N11288-" into three named
# groups; a variant of '-' marks a deletion.
rxmut = re.compile(
    '^'
    '(?P<reference>[NATGC])'
    '(?P<position>[0-9]+)'
    '(?P<variant>[ATGC-])'
    '$'
)

In [None]:
# UK variant mutations as <reference><position><variant>;
# entries with variant '-' are deletions.
UK_varlist = [
    "C3267T", "C5388A", "T6954C",
    # deletion 11288-11296
    "N11288-", "N11289-", "N11290-", "N11291-", "N11292-",
    "N11293-", "N11294-", "N11295-", "N11296-",
    # deletion 21765-21770
    "N21765-", "N21766-", "N21767-", "N21768-", "N21769-", "N21770-",
    # deletion 21991-21993
    "N21991-", "N21992-", "N21993-",
    "A23063T", "C23271A", "C23604A", "C23709T", "T24506G", "G24914C",
    "C27972T", "G28048T", "A28111G",
    # 28280 GAT->CTA, listed as three single-base substitutions
    "G28280C", "A28281T", "T28282A",
    "C28977T",
]
# one row per mutation: reference / position / variant columns
UK_vartable = pd.DataFrame(data=[rxmut.match(mut).groupdict() for mut in UK_varlist])
UK_vartable = UK_vartable.astype({"position": 'int'})
UK_vartable

In [None]:
# SA variant mutations as <reference><position><variant>.
SA_varlist = [
    "C1059T", "G5230T", "A10323G", "A21801C",
    "G22813T", "G23012A", "A23063T", "C23664T",
    "G25563T", "C25904T", "C26456T", "C28887T",
]
# one row per mutation: reference / position / variant columns
SA_vartable = pd.DataFrame(data=[rxmut.match(mut).groupdict() for mut in SA_varlist])
SA_vartable = SA_vartable.astype({"position": 'int'})
SA_vartable

## List all wastewater samples

In [None]:
# Build the ShoRAH support directory of every wastewater sample listed in
# ww.tsv (tab-separated, with "sample" and "batch" columns).
t = 'ww.tsv'
# newline='' is required by the csv module so quoted fields and line endings
# are handled by the reader itself.
with open(t, 'rt', encoding='utf-8', newline='') as tf:  # this file has the same content as the original experiment
    ww_sampledirs = [
        f"working/samples/{r['sample']}/{r['batch']}/variants/SNVs/REGION_1/support/"
        for r in csv.DictReader(tf, dialect='excel-tab')
    ]
ww_sampledirs


In [None]:
# Full paths of every file in the first sample's support directory.
temp_dirlist = [f"{ww_sampledirs[0]}{entry}" for entry in os.listdir(ww_sampledirs[0])]
temp_dirlist


In [None]:
# Scratch check: looking up a column that does not exist ("helo") is a
# KeyError, so report it instead.
if "helo" not in UK_vartable:
    print("NO")

## Do it for a single mutation

In [None]:
# Exploratory run of the per-mutation search, on the first sample and the
# first mutation of the UK table only.
min_posterior = 0.8
mut_number = 0
temp_dirlist = [ww_sampledirs[0] + i for i in os.listdir(ww_sampledirs[0])]
vartable = UK_vartable

# find all snv tables for one mutation
tmp_snvcounts = []
mut_name = vartable.iloc[mut_number]["reference"] + \
    str(vartable.iloc[mut_number]["position"]) + \
    vartable.iloc[mut_number]["variant"]
# positional lookup, consistent with the iloc calls used for mut_name
mut_position = vartable["position"].iloc[mut_number]
candidate_windows = 0  # windows whose range covers the mutation
for win in temp_dirlist:
    strt, stp = extract_range(win)
    if strt <= mut_position <= stp:
        candidate_windows += 1
        try:
            snv_tab = count_snvs(win, vartable)[mut_name]
        except KeyError:
            snv_tab = None
        if snv_tab is not None:
            tmp_snvcounts.append(snv_tab)

# per window, sum ave_reads of the confident haplos carrying the SNV
ave_reads_full_lst = []
for snv_col in tmp_snvcounts:
    ave_reads_lst = []
    # The Series index holds the haplo fasta descriptions; iterate label/value
    # pairs instead of integer [] lookups (a deprecated positional fallback).
    for haplo_name, has_snv in snv_col.items():
        posterior = float(re.search(r"posterior=([0-1][\.]{0,1}[0-9]{0,})", haplo_name).group(1))
        ave_reads = float(re.search(r"ave_reads=([0-9]+[\.]{0,1}[0-9]{0,})", haplo_name).group(1))
        if posterior > min_posterior and has_snv == 1:
            ave_reads_lst.append(ave_reads)
    ave_reads_full_lst.append(sum(ave_reads_lst))
# same definitions as search_one_mut: a window is effective when some
# confident haplo carries the SNV, and the average of nothing is 0
effective_windows = sum(r > 0 for r in ave_reads_full_lst)
ave_r = np.average(ave_reads_full_lst) if ave_reads_full_lst else 0



In [None]:
# Haplo fasta descriptions carry "posterior=<p>" and "ave_reads=<r>" fields.
_POSTERIOR_RX = re.compile(r"posterior=([0-1][\.]{0,1}[0-9]{0,})")
_AVE_READS_RX = re.compile(r"ave_reads=([0-9]+[\.]{0,1}[0-9]{0,})")


def _parse_haplo_header(haplo_name):
    '''Return (posterior, ave_reads) as floats parsed from a haplo fasta description.

    Raises AttributeError if either field is missing from the description.
    '''
    posterior = float(_POSTERIOR_RX.search(haplo_name).group(1))
    ave_reads = float(_AVE_READS_RX.search(haplo_name).group(1))
    return posterior, ave_reads


def search_one_mut(temp_dirlist, vartable, mut_number, min_posterior=0.9):
    '''Look for mutation number (mut_number) of (vartable) in (temp_dirlist).

    Parameters:
        temp_dirlist: list of paths to the ShoRAH window files of one sample
        vartable: pd.DataFrame with "reference", "position" and "variant" columns
        mut_number: positional row number of the mutation in vartable
        min_posterior: haplos with posterior <= this value are ignored
    Return:
        (candidate_windows, effective_windows, ave_r):
            candidate_windows: number of files whose range covers the mutation
            effective_windows: number of those windows in which at least one
                haplo with posterior > min_posterior carries the mutation
            ave_r: mean, over covering windows with a SNV table, of the summed
                ave_reads of such haplos; 0 (with a printed warning) if there
                is no such window
    '''
    mut_row = vartable.iloc[mut_number]
    mut_position = mut_row["position"]
    mut_name = mut_row["reference"] + str(mut_position) + mut_row["variant"]

    # find all snv tables for one mutation
    snv_columns = []
    candidate_windows = 0
    for win in temp_dirlist:
        strt, stp = extract_range(win)
        if not strt <= mut_position <= stp:
            continue
        candidate_windows += 1
        counts = count_snvs(win, vartable)
        # count_snvs returns None when no vartable SNV falls in the window
        if counts is not None and mut_name in counts.columns:
            snv_columns.append(counts[mut_name])

    # per window, sum ave_reads of the confident haplos carrying the SNV
    ave_reads_full_lst = []
    for snv_col in snv_columns:
        supporting_reads = []
        # iterate label/value pairs: the index holds haplo names (strings), so
        # integer [] lookups would rely on a deprecated positional fallback
        for haplo_name, has_snv in snv_col.items():
            posterior, ave_reads = _parse_haplo_header(haplo_name)
            if posterior > min_posterior and has_snv == 1:
                supporting_reads.append(ave_reads)
        ave_reads_full_lst.append(sum(supporting_reads))
    effective_windows = sum(r > 0 for r in ave_reads_full_lst)

    # compute average
    if ave_reads_full_lst:
        ave_r = np.average(ave_reads_full_lst)
    else:
        ave_r = 0
        if temp_dirlist:
            # drop the file name and its "support" directory
            warnname = os.path.dirname(os.path.dirname(str(temp_dirlist[0])))
        else:
            warnname = "<empty window list>"
        print(f"Warning! Can't average in {warnname}")

    return (candidate_windows, effective_windows, ave_r)
    

In [None]:
def search_all_mut(temp_dirlist, vartable, min_posterior=0.9):
    '''Run search_one_mut for every mutation of vartable and append the results.

    Parameters:
        temp_dirlist: list of paths to the ShoRAH window files of one sample
        vartable: pd.DataFrame of mutations ("reference", "position", "variant")
        min_posterior: forwarded to search_one_mut
    Return:
        pd.DataFrame: vartable with extra "candidate_windows",
        "effective_windows" and "ave_reads" columns, one row per mutation
    '''
    rows = [list(search_one_mut(temp_dirlist, vartable, i, min_posterior))
            for i in range(vartable.shape[0])]
    # reshape keeps three columns even when vartable is empty (np.array([]) is 1-D)
    arr1 = np.array(rows).reshape(-1, 3)
    # reuse vartable's index so the column-wise concat pairs rows correctly even
    # when vartable does not carry a default RangeIndex
    temp_df = pd.DataFrame(arr1, columns=["candidate_windows", "effective_windows", "ave_reads"],
                           index=vartable.index)
    return pd.concat([vartable, temp_df], axis=1)
    

## Make all UK outputs

In [None]:
# Run the UK mutation search on every sample and write one CSV per sample/batch.
all_UK_dfs = []
# to_csv does not create missing directories
os.makedirs('uk_snv_tables', exist_ok=True)

for sample in tqdm(ww_sampledirs):
    # check if ShoRAH did output windows there
    if not os.path.isdir(sample):
        print(f"Warning! No windows in {sample}!!!")
        continue

    temp_dirlist = [sample + i for i in os.listdir(sample)]
    mut_df_UK = search_all_mut(temp_dirlist, UK_vartable, min_posterior=0.9)
    all_UK_dfs.append(mut_df_UK)
    # sample paths are built with '/' (working/samples/<sample>/<batch>/...),
    # so split on '/' rather than os.sep, which differs on Windows
    spl = sample.split('/')
    mut_df_UK.to_csv(os.path.join('uk_snv_tables', f"{spl[2]}-{spl[3]}_uk_snv.csv"), na_rep="NA")


## Make all SA outputs

In [None]:
# Run the SA mutation search on every sample and write one CSV per sample/batch.
all_SA_dfs = []
# to_csv does not create missing directories
os.makedirs('sa_snv_tables', exist_ok=True)

for sample in tqdm(ww_sampledirs):
    # check if ShoRAH did output windows there
    if not os.path.isdir(sample):
        print(f"Warning! No windows in {sample}!!!")
        continue

    temp_dirlist = [sample + i for i in os.listdir(sample)]
    mut_df_SA = search_all_mut(temp_dirlist, SA_vartable, min_posterior=0.9)
    all_SA_dfs.append(mut_df_SA)
    # sample paths are built with '/' (working/samples/<sample>/<batch>/...),
    # so split on '/' rather than os.sep, which differs on Windows
    spl = sample.split('/')
    mut_df_SA.to_csv(os.path.join('sa_snv_tables', f"{spl[2]}-{spl[3]}_sa_snv.csv"), na_rep="NA")


# Double-checking the results

In [None]:
# Display the support directory path of the second sample.
ww_sampledirs[1]

### UK

In [None]:
# For every UK mutation whose ave_reads is NA, print its position and print
# "PROBLEM" for each window of that sample whose range covers it.
# all_UK_dfs only holds samples whose support directory existed, so pair the
# tables with those directories instead of indexing ww_sampledirs by position
# (assumes no directory appeared or vanished since the run).
uk_checked_dirs = [d for d in ww_sampledirs if os.path.isdir(d)]
for target_dir, uk_df in zip(uk_checked_dirs, all_UK_dfs):
    windows = [target_dir + f for f in os.listdir(target_dir)]
    for pos_to_check, ave_reads in zip(uk_df["position"], uk_df["ave_reads"]):
        if pd.isna(ave_reads):
            print(pos_to_check)
            for d in windows:
                strt, stop = extract_range(d)
                if strt <= pos_to_check <= stop:
                    print("PROBLEM")
        
        

### SA

In [None]:
# For every SA mutation whose ave_reads is NA, print its position and print
# "PROBLEM" for each window of that sample whose range covers it.
# all_SA_dfs only holds samples whose support directory existed, so pair the
# tables with those directories instead of indexing ww_sampledirs by position
# (assumes no directory appeared or vanished since the run).
sa_checked_dirs = [d for d in ww_sampledirs if os.path.isdir(d)]
for target_dir, sa_df in zip(sa_checked_dirs, all_SA_dfs):
    windows = [target_dir + f for f in os.listdir(target_dir)]
    for pos_to_check, ave_reads in zip(sa_df["position"], sa_df["ave_reads"]):
        if pd.isna(ave_reads):
            print(pos_to_check)
            for d in windows:
                strt, stop = extract_range(d)
                if strt <= pos_to_check <= stop:
                    print("PROBLEM")

In [None]:
# Display the list of UK mutation strings searched for above.
UK_varlist