In [8]:
import numpy as np
import pandas as pd
from numba import jit, njit, prange, objmode
import threading
import concurrent.futures
import warnings
import pysam
from Bio import pairwise2
import copy
import pickle
from cigar import Cigar

%run EM_algo.py

In [9]:
# Inputs for sample SEM-2903.
# vsf: tab-separated variant table exported from Galaxy; columns REF, POS, ALT and AF are used below.
vsf = pd.read_csv("galaxy/SEM-2903/Galaxy20-[VCFtoTab-delimited__on_data_18] (1).csv", sep="\t")
# gisaid: table indexed by mutation name (REF + POS + ALT, e.g. "A23403G") with one column per
# entry of variant_of_intresst; values are compared against a tolerance (presumably frequencies).
gisaid = pd.read_csv("result_nextstrain/csv/result_supstitution.csv", index_col=0)
# bam: filtered alignments; left open because later cells query its index statistics.
bam = pysam.AlignmentFile("galaxy/SEM-2903/Galaxy24-[BAM_filter_on_data_16].bam", "rb")
# seq_ref: reference sequence as a plain string.
# NOTE(review): if the FASTA contains several records, only the last one is kept.
with pysam.FastxFile("Galaxy4-[Galaxy2-[RefSeq_SARS-CoV_NC_45512.fasta].fasta].fasta") as fh:
    for entry in fh:
        seq_ref = str(entry.sequence)
# Clade labels (gisaid column names) used to select candidate mutations, build M and fit the EM model.
variant_of_intresst = ["21A (Delta)", "20H (Beta, V2)", "20A", "20E (EU1)", "20B", "20I (Alpha, V1)", "20J (Gamma, V3)", "21D (Eta)"]

In [10]:
def get_all_mutations_to_check(vsf, variant_of_intresst, tol=0.005, gisaid_table=None, include_gisaid=False):
    """Return the candidate substitutions to genotype in every read.

    A variant from ``vsf`` is kept when its allele frequency ``AF`` exceeds ``tol``,
    REF and ALT have the same length (substitution, not indel), and its name
    ``REF + POS + ALT`` appears in the index of the GISAID table.

    Args:
        vsf: DataFrame with columns ``REF``, ``POS``, ``ALT`` and ``AF``.
        variant_of_intresst: column names of the GISAID table (clade labels); only
            used when ``include_gisaid`` is true.
        tol: minimum value (exclusive) for ``AF`` and, with ``include_gisaid``, for
            the GISAID columns.
        gisaid_table: GISAID table indexed by mutation name. Defaults to the
            notebook-global ``gisaid``.
        include_gisaid: also add every GISAID mutation whose value exceeds ``tol``
            in any of the ``variant_of_intresst`` columns.

    Returns:
        A list of unique ``(REF, POS, ALT)`` tuples with integer 1-based ``POS``,
        sorted by position, then REF, then ALT. Returns an empty list when nothing
        qualifies.
    """
    if gisaid_table is None:
        gisaid_table = gisaid  # notebook-global fallback keeps the original call sites working
    known = set(gisaid_table.index)  # built once instead of once per variant
    called = vsf.loc[vsf["AF"] > tol, ["REF", "POS", "ALT"]]
    mutations = set()
    for ref, pos, alt in called.itertuples(index=False):
        if len(ref) != len(alt):
            continue  # indels are not handled downstream
        if f"{ref}{pos}{alt}" in known:
            mutations.add((ref, int(pos), alt))
    if include_gisaid:
        for var in variant_of_intresst:
            for name in gisaid_table[gisaid_table[var] > tol].index:
                # Names look like "A23403G": first char REF, last char ALT, digits in between POS.
                mutations.add((name[0], int(name[1:-1]), name[-1]))
    # Deterministic order: by position, ties broken by REF/ALT.
    return sorted(mutations, key=lambda m: (m[1], m[0], m[2]))

In [11]:
# Memo for get_mutations_to_check, keyed by (start_idx, length).
dict_mut = dict()
def get_mutations_to_check(start_idx, length, all_mutations):
    """Return the candidate mutations whose position lies in ``[start_idx, start_idx + length)``.

    ``all_mutations`` must be sorted by position (element ``[1]``); the scan stops at the
    first mutation past the window.

    Returns:
        A set of ``(mutation, index_in_all_mutations)`` pairs.

    Results are memoised in the module-level ``dict_mut``. The key includes ``length``,
    because reads starting at the same position can span different lengths. Clear
    ``dict_mut`` if ``all_mutations`` changes. Concurrent callers may compute the same
    entry twice; the value is identical, so the race is harmless.
    """
    key = (start_idx, length)
    cached = dict_mut.get(key)
    if cached is not None:
        return cached
    stop = start_idx + length
    val = set()
    for ref_i, mut in enumerate(all_mutations):
        if mut[1] >= stop:
            break  # sorted input: nothing further can fall in the window
        if start_idx <= mut[1]:
            val.add((mut, ref_i))
    dict_mut[key] = val
    return val

In [12]:
# Per-contig read counts from the BAM index; the `total` field of the first entry is used
# further down as the number of rows of X.
bam.get_index_statistics()

[IndexStats(contig='NC_045512.2', mapped=721290, unmapped=0, total=721290)]

In [13]:
# Reference length; X (allocated further down) has one column per reference position.
len(seq_ref)

29903

In [14]:
def strat_parallele(start_index=0, stop_index=29903, nb_worker=8, chunk_size=100, worker=None):
    """Process the reference interval ``[start_index, stop_index)`` in parallel chunks.

    The interval is cut into consecutive half-open chunks of at most ``chunk_size``
    positions (smaller when the interval is short relative to ``nb_worker``). Each
    chunk is passed as ``worker(chunk_start, chunk_stop)`` to a thread pool of
    ``nb_worker`` threads.

    Args:
        start_index, stop_index: reference interval to cover (half-open).
        nb_worker: number of threads; tune to the machine.
        chunk_size: upper bound on the chunk width.
        worker: callable run on each chunk; defaults to ``start_align``.

    Raises:
        RuntimeError: if any chunk raised. All chunks still run; the failure with
            the lowest chunk start is chained as ``__cause__``.
    """
    if worker is None:
        worker = start_align
    # Keep the step >= 1: range() rejects a zero step, which happened when the interval
    # was shorter than nb_worker.
    step = max(1, min(chunk_size, (stop_index - start_index) // nb_worker))
    failures = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=nb_worker) as executor:
        futures = {
            executor.submit(worker, chunk_start, min(chunk_start + step, stop_index)): chunk_start
            for chunk_start in range(start_index, stop_index, step)
        }
        # Retrieve every outcome so worker exceptions are not silently dropped.
        for future in concurrent.futures.as_completed(futures):
            exc = future.exception()
            if exc is not None:
                failures.append((futures[future], exc))
    if failures:
        failures.sort(key=lambda item: item[0])
        first_start, first_exc = failures[0]
        raise RuntimeError(
            f"{len(failures)} chunk(s) failed; first failing chunk starts at {first_start}: {first_exc!r}"
        ) from first_exc

In [15]:
# Sanity check: per-mutation read counts from the reduced X matrix.
def get_result(all_mutations, matrix=None, out=None):
    """Tally, per candidate mutation, reads that carry it versus reads that cover it without it.

    Args:
        all_mutations: candidate mutations; column ``i`` of ``matrix`` belongs to
            ``all_mutations[i]``.
        matrix: read x mutation matrix (default: the global ``X``, after it has been
            reduced to mutation columns). Cells: 1 = read carries the mutation,
            0 = read covers the position without it, -1 = not covered.
        out: list to append rows to (default: the global ``data``).

    For each mutation the row ``[mutation, found, not_found, covered, frequency]`` is
    appended to ``out`` and printed. ``frequency`` is ``found / covered`` and is NaN
    when no read covers the position.
    """
    if matrix is None:
        matrix = X
    if out is None:
        out = data
    cols = matrix[:, :len(all_mutations)]
    found = np.count_nonzero(cols == 1, axis=0)
    not_found = np.count_nonzero(cols == 0, axis=0)
    for col, mutation in enumerate(all_mutations):
        n_found = int(found[col])
        n_not_found = int(not_found[col])
        covered = n_found + n_not_found
        # Explicit NaN instead of a 0/0 division (which emits a RuntimeWarning).
        frequency = n_found / covered if covered else float("nan")
        out.append([mutation, n_found, n_not_found, covered, frequency])
        print(out[-1])

In [16]:
# Accumulator for get_result(): one [mutation, found, not_found, covered, frequency] row per mutation.
data = []
# Disabled: get_result indexes X by mutation column, so it only makes sense after X is reduced.
# get_result(all_mutations)

In [17]:
all_mutations = get_all_mutations_to_check(vsf, variant_of_intresst)
# Read x reference-position matrix, every cell initialised to -1 ("not covered by this read").
# Rows: `total` reads of the first contig in the BAM index; columns: reference positions.
# NOTE(review): int8 cells -> rows * 29903 bytes, roughly 21.6 GB for the 721290 reads reported
# by the index statistics; allocating only the mutation columns would avoid this.
X = np.full((int(bam.get_index_statistics()[0].total), len(seq_ref)), -1, dtype=np.int8)
# 1-based positions of the candidate mutations, in all_mutations order.
all_mutations_pos = [pos for _ref, pos, _alt in all_mutations]
def read_mut_calling(ref, read_start_idx, read_cigar, read_seq, all_mutations, read_idx):
    """Record one aligned read in row ``read_idx`` of the global matrix ``X``.

    Every reference position the read aligns to (CIGAR ``M``/``=``/``X``) or deletes (``D``)
    is set to 0. Within aligned blocks, the position of each candidate substitution
    that the read carries (read base == ALT and reference base == REF) is then set to 1.

    Args:
        ref: reference sequence string.
        read_start_idx: 0-based reference position of the first aligned base
            (``AlignedSegment.reference_start``); leading soft clips do not shift it.
        read_cigar: CIGAR string of the read.
        read_seq: read sequence as stored in the BAM, including soft-clipped bases.
        all_mutations: ``(REF, POS, ALT)`` tuples with 1-based ``POS``, sorted by ``POS``.
        read_idx: row of ``X`` to write.
    """
    cigar_ops = list(Cigar(read_cigar).items())
    # Number of reference bases spanned by the alignment (ops that consume the reference).
    ref_span = sum(n for n, op in cigar_ops if op in ("M", "D", "N", "=", "X"))
    # Candidates have 1-based positions in [read_start_idx + 1, read_start_idx + ref_span].
    candidates = get_mutations_to_check(read_start_idx, ref_span + 1, all_mutations)
    ref_pos = read_start_idx  # 0-based reference coordinate where the current op starts
    query_pos = 0  # index into read_seq where the current op starts
    for n, op in cigar_ops:
        if op in ("M", "=", "X"):
            X[read_idx, ref_pos:ref_pos + n] = 0
            for (mut_ref, mut_pos, mut_alt), _ in candidates:
                # This block covers 1-based positions ref_pos+1 .. ref_pos+n.
                if ref_pos < mut_pos <= ref_pos + n:
                    base = read_seq[query_pos + (mut_pos - 1 - ref_pos)]
                    if base == mut_alt and ref[mut_pos - 1] == mut_ref:
                        X[read_idx, mut_pos - 1] = 1
            ref_pos += n
            query_pos += n
        elif op in ("I", "S"):
            # Inserted/soft-clipped bases consume the read only.
            query_pos += n
        elif op == "D":
            # Deleted bases: position is covered by the read but the base is absent.
            X[read_idx, ref_pos:ref_pos + n] = 0
            ref_pos += n
        elif op == "N":
            # Skipped region consumes the reference but is not covered.
            ref_pos += n
        # "H" and "P" consume neither the read nor the reference.

In [18]:
nb_error = 0  # count of rows that already held data when a read was about to be stored in them
marge = 10  # unused in the visible code
decalee = 0  # unused in the visible code
read_idx = 0  # next free row of X, shared by all worker threads
# Serialises row allocation (read_idx) and the nb_error counter across worker threads.
_read_idx_lock = threading.Lock()
BAM_PATH = "galaxy/SEM-2903/Galaxy24-[BAM_filter_on_data_16].bam"


def start_align(start_index, stop_index, bam_path=BAM_PATH, contig="NC_045512.2"):
    """Record every read that starts in ``[start_index, stop_index)`` into the global ``X``.

    ``fetch`` also returns reads that merely overlap the chunk. Each read is therefore
    attributed to the chunk containing its 0-based ``reference_start``, so a read that
    spans several chunks is recorded exactly once. Each recorded read gets its own row
    of ``X`` (allocated under a lock) and is passed to ``read_mut_calling``.

    Args:
        start_index, stop_index: 0-based half-open reference chunk.
        bam_path: indexed BAM file to read.
        contig: reference name to fetch from.
    """
    global nb_error, read_idx
    print(str(start_index) + " " + str(stop_index) + "\n")
    # One handle per call, closed on exit (previously a handle leaked per chunk); worker
    # threads therefore never share an iterator.
    with pysam.AlignmentFile(bam_path, "rb") as bam_file:
        for aln in bam_file.fetch(contig, start_index, stop_index):
            read_start_idx = aln.reference_start
            if read_start_idx >= stop_index:
                break
            if read_start_idx < start_index:
                continue  # recorded by the chunk in which the read starts
            read_cigar = aln.cigarstring
            read_seq = aln.query_sequence
            if read_cigar is None or read_seq is None:
                continue  # nothing to genotype (e.g. SEQ stored as "*")
            with _read_idx_lock:
                row = read_idx
                read_idx += 1
                if X[row, read_start_idx] == 0:
                    nb_error += 1
            read_mut_calling(seq_ref, read_start_idx, read_cigar, read_seq, all_mutations, row)

In [19]:
# Fill X by running start_align over the whole reference on 8 threads.
# NOTE(review): this cell is not idempotent. Running it again re-runs the alignment against the
# already column-reduced X, and read_idx is not reset, so re-run from the X allocation cell instead.
strat_parallele(nb_worker=8)
# Keep only the candidate-mutation columns (mut[1] is 1-based, X columns are 0-based), in
# all_mutations order; from here on column i of X belongs to all_mutations[i].
X = X[:, [mut[1] - 1 for mut in all_mutations]]#To keep mutation pos only

0 100
100 200
200 300



300 400

400 500

500 600

600 700

700 800

800 900

900 1000

1000 1100

1100 1200

1200 1300

1300 1400

1400 1500

1500 1600

1600 1700

1700 1800

1800 1900

1900 2000

2000 2100

2100 2200

2200 2300

2300 2400
2400 2500


2500 2600

2600 2700

2700 2800

2800 2900
2900 3000
3000 3100
3100 3200




3200 3300

3300 3400

3400 3500

3500 3600

3600 3700

3700 3800

3800 3900

3900 4000
4000 4100
4100 4200



4200 4300

4300 4400
4400 4500


4500 4600
4600 4700


4700 4800
4800 4900


4900 5000
5000 5100


5100 5200

5200 5300
5300 5400


5400 5500
5500 5600


5600 5700
5700 5800


5800 5900

5900 6000

6000 6100

6100 6200

6200 6300
6300 6400
6400 6500



6500 6600

6600 6700
6700 6800


6800 6900
6900 7000


7000 7100

7100 7200

7200 7300
7300 7400

7400 7500


7500 7600

7600 7700

7700 7800
7800 7900

7900 8000


8000 8100

8100 8200

8200 8300

8300 8400

8400 8500

8500 8600
8600 8700
8700 8800



8800 8900

8900 9000
9000 9100


9100 9200

9200 9300

In [20]:
# Persist the per-mutation summary (empty unless get_result() was run), the reduced
# read x mutation matrix and the mutation list for later sessions.
df = pd.DataFrame(data)
df.to_csv("SEM-2903_vsf_mine.csv")
for pickle_path, payload in (
    ("X_SEM-2903.pickle", X),
    ("all_mutations_SEM-2903.pickle", all_mutations),
):
    with open(pickle_path, "wb") as handle:
        pickle.dump(payload, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [21]:
# Generate M
# Mutation x variant matrix M: rows follow all_mutations (named REF + POS + ALT),
# columns follow variant_of_intresst.
gisaid_file = 'result_nextstrain/csv/result_supstitution.csv'
M = pd.read_csv(gisaid_file, index_col=0).loc[
    [f"{ref}{pos}{alt}" for ref, pos, alt in all_mutations],
    variant_of_intresst,
]


In [None]:
np.set_printoptions(formatter={'float': lambda x: "{0:0.3f}".format(x)}) # Display only: print floats with 3 decimals
# Fit the EM model (algo_EM is defined by `%run EM_algo.py`). It is given the number of rows of X,
# the number of variants, the number of mutations (rows of M), X itself and M as a numpy array.
# NOTE(review): X.shape[0] counts every allocated row, including rows never filled by a read
# (all -1) if fewer reads were recorded than the index total. Also, this cell is marked
# In [None] while the plot cell below ran as In [17], so the saved plot may come from an earlier kernel state.
resultat = algo_EM(X.shape[0], len(variant_of_intresst), len(M.index), X, np.array(M), max_iter=100)
print(variant_of_intresst)
resultat

In [17]:
import plotly.express as px

# Pie chart of resultat[0] labelled with the variant names.
# NOTE(review): assumes resultat[0] holds one weight per entry of variant_of_intresst, in the
# same order. Confirm against algo_EM's return value in EM_algo.py.
fig = px.pie(values=resultat[0], names=variant_of_intresst)
fig.show()