In [1]:
from Bio import SeqIO, Align, Seq
import time
import numpy as np
import multiprocessing as mp
import argparse
from argparse import RawTextHelpFormatter
import os
import pdb
import pandas as pd
from contextlib import nullcontext
from gzip import open as gzopen
import pickle

In [2]:
def align_barcodes(read_name, read_sequence, barcodes, aligner, threshhold):
    qualities = {}
    qualities[read_name] = []
    for BC_name, BC_fwd in barcodes.items():
        alignments = aligner.align(read_sequence, BC_fwd)
        for alignment in sorted(alignments):
            if alignment.score >= threshhold:
                qualities[read_name].append([alignment.score, alignment, BC_name, read_sequence])
    return(qualities)

def filter_alignments_score(x):
    """
    Keeps only the best alignment.
    Reads with multiple best alignments are put into the multibarcoded category (weird flag).
    Reads with no alignments into the no barcodes category.
    """
    weird_reads = []
    unbarcoded_reads = []
    new_dict = {}
    for read_name, aln_list in x.items():
        weird_flag = False
        aln_list = np.array(aln_list, dtype=object)
        if len(aln_list) == 0: # if there are no alignments
            unbarcoded_reads.append(read_name)
            continue
        score_list = np.array([x[0] for x in aln_list])
        max_score = int(np.max(score_list))
        # check if the maximum score appears more than once
        max_score_count = np.sum(score_list == max_score)
        if max_score_count > 1:
            # make a list of all the barcodes with the maximum score
            max_score_aln_list = [x[2] for x in aln_list[score_list == max_score]]
            # when the aligned, top-score barcodes are the same then do nothing
            # the first alignment is then chosen
            if len(np.unique(max_score_aln_list)) != 1:
                # the read has at least 2 different barcodes aligned with the same quality
                # set the weird flag
                weird_flag = True
        if weird_flag:
            weird_reads.append(read_name)
            continue
        new_dict[read_name] = [aln_list[np.argmax(score_list)]]
    return(new_dict, unbarcoded_reads, weird_reads)

def prepare_arguments(x, barcodes, aligner, threshold):
    # prepare arguments for mp
    async_readnames_list = list(x.keys())
    async_sequences_list = [x[j] for j in async_readnames_list]
    async_barcode_list = [barcodes for j in range(len(x))]
    async_aligner_list = [aligner for j in range(len(x))]
    async_thresh_list = [threshold for i in range(len(x))]

    starmap_args = []
    for i in range(len(x)):
        starmap_args.append((
            async_readnames_list[i],
            async_sequences_list[i],
            async_barcode_list[i],
            async_aligner_list[i],
            async_thresh_list[i]))
    return(starmap_args)

def combine_dict(d1, d2):
    return {
        k: tuple(d[k] for d in (d1, d2) if k in d)
        for k in list(set(d1.keys()) | set(d2.keys()))
    }

In [3]:
args = {
    "variants": "/run/user/1000/gvfs/smb-share:server=home.isilon.bioquant.uni-heidelberg.de,share=nwg-grimm/scripts_NGSanalysis/ctest_data/BC_variants.txt",
    "input": "/run/user/1000/gvfs/smb-share:server=home.isilon.bioquant.uni-heidelberg.de,share=nwg-grimm/scripts_NGSanalysis/ctest_data/fastqs/test/",
    "open_gap": -1,
    "extend_gap": -10,
    "mismatch": -1,
    "match": 1,
    "threshold": 4,
    "out_directory": "example_out_HAC",
    "threads": 8,
    "barcode_length": None,
    "left": "GGCTGG",
    "right": "TGGGCC",
    "left_thresh": None,
    "right_thresh": None,
    "verbose": True
    }

In [4]:
use_reverse = False

In [5]:
# parameters
## aligner parameters
aligner = Align.PairwiseAligner()
aligner.mode = 'local'
aligner.open_gap_score = int(args["open_gap"])
aligner.extend_gap_score = int(args["extend_gap"])
aligner.mismatch_score = int(args["mismatch"])
aligner.match_score = int(args["match"])
## barcodes threshold
barcode_threshold = int(args["threshold"])

## peptide threshold
left_threshold = args["left_thresh"]
if left_threshold == None:
    left_threshold = len(args["left"])-1
right_threshold = args["right_thresh"]
if right_threshold == None:
    right_threshold = len(args["right"])-1

In [6]:
print(left_threshold)
print(right_threshold)

5
5


In [7]:
barcodes_FWD = pd.read_csv(args["variants"], names = ["sequence", "name"], sep="\t").set_index("name").to_dict()["sequence"]
if use_reverse:
    barcodes_RC = {name+"RC":str(Seq.Seq(bc).reverse_complement()) for name, bc in barcodes_FWD.items()}
    barcodes = barcodes_RC
else:
    barcodes = barcodes_FWD

if args["barcode_length"] == None:
    barcode_length = len(list(barcodes.values())[-1])
else:
    barcode_length = int(args["barcode_length"])

In [8]:
barcodes

{'AAV2': 'GCTCTGGATGTAGTA',
 'S0112coAAV2': 'TATCAAGCTAACGTT',
 'S0312coAAV2': 'GTCAACATCGTTACA',
 'S0412coAAV2': 'GGGCCCTAGCGCGTG',
 'S0512coAAV2': 'GATAGGCTGGTCCAA',
 'S0912coAAV2': 'TATTTGTGTCGTTCC',
 'S1112coAAV2': 'AGTTAGGGCGCTGCG',
 'S1312coAAV2': 'GCCCTTCAGTCAGCT',
 'S1612coAAV2': 'CGGTCGCGTGACGTG',
 'S1812coAAV2': 'GCCGGAGTCCCGGTA',
 'AAV9': 'TGTTGGAAGGTATCA',
 'S0112coAAV9': 'GACTTGGTTGTGACG',
 'S0312coAAV9': 'TTGTTGTATGAGCAG',
 'S0412coAAV9': 'CTACCTATTTACTCT',
 'S0512coAAV9': 'ACCGGGCGTTGAGGC',
 'S0912coAAV9': 'TGGTTTACAAATTAT',
 'S1112coAAV9': 'GTTGTGCCCTGAGTG',
 'S1312coAAV9': 'ACCGTATCTCTCCGG',
 'S1612coAAV9': 'TTGGAACGTGGGCTT',
 'S1812coAAV9': 'AGATTCAAAGCTGCG',
 'AAV5': 'AGCCTAATCTTTGAC',
 'AAV8': 'AAGCACTAAAGAACA',
 'AAV_DJ': 'GGTATGGCCTGCCGC'}

In [9]:
## out directory
out_directory = args["out_directory"]
if out_directory[-1] != "/":
    out_directory += "/"
if not os.path.exists(out_directory):
    os.mkdir(out_directory)
## multiprocessing
n_workers = args["threads"]

# BC

In [None]:
total_start_time = time.time()

for file in os.listdir(args["input"]):

    start_time = time.time()
    
    file_path = os.path.join(args["input"], file)
    # conditional with clause
    with gzopen(file_path, "rt") if file_path.endswith(".gz") else nullcontext() as file_path:

        if file_path == None:
            file_path = os.path.join(args["input"], file)
    
        ## fastq input
        fastq = {entry.description: str(entry.seq) for entry in SeqIO.parse(file_path, "fastq")}

        
        ## info
        read_lengths = [len(x) for x in fastq.values()]
        mean_read_length = sum(read_lengths) / len(read_lengths)
        n_total = len(fastq)


        global fastq_input
        global results
        with open("tmp", "wb") as tmp:
            for item in prepare_arguments(fastq, barcodes, aligner, barcode_threshold):
                pickle.dump(item, tmp)
                tmp.write(os.linesep.encode("UTF-8"))
        del fastq

        # Main computation
        ## parallel workers
        p = mp.Pool(n_workers)
        start_time = time.time()
        ## do the calculation
        with open("tmp", "rb") as tmp:
            results = p.starmap(
                align_barcodes,
                (pickle.loads(line) for line in tmp)
            )
        
            ## combine the results
            qualities = {}
            for r in results:
                qualities.update(r)
    
            del results
        
            ## filter the alignments of every read, to only retain the best one
            filt_qualities, unbarcoded_reads, weird_reads = filter_alignments_score(qualities)
        
            # output
            ## make output dictionary
            out_dictionary = {}
            for barcode_name in list(barcodes.keys()):
                out_dictionary[barcode_name] = []
        
            ## populate output dict
            for read_name, aln_list in filt_qualities.items():
                out_dictionary[aln_list[0][2]].append([read_name, aln_list[0][3]])



    df_out = pd.DataFrame({key:[barcodes[key], len(x)] for key, x in out_dictionary.items()}, index = ["sequence", "count"]).transpose()
    # meta info
    n_recovered = df_out["count"].sum()
    

    time_taken = round(time.time() - start_time, 4)
    
    # Output
    df_out.to_csv(os.path.join(args["out_directory"], file).split(".")[0]+".BC.csv")

    
    # write log file
    log_output_file = os.path.join(args["out_directory"], file).split(".")[0]+".BC.log.txt"
    f=open(log_output_file,'w')
    f.write("File: " + file + "s\n")
    f.write("Time taken: " + str(time_taken) +  "\n")
    f.write("Total number of reads: " + str(n_total) + "\n")
    f.write("Mean sequence length: " + str(mean_read_length) + " bp\n")
    f.write("Reads recovered: " + str(n_recovered) + " (" + str(round(n_recovered/n_total*100, 2)) + "%)\n")
    #f.write("\nMean sequence quality: "+str(round(mean_quality, 2)))
    f.close()

    # terminal output
    if args["verbose"]:
        print("\nFile: " + file)
        print("Time taken: " + str(time_taken) + "s")
        print("Total number of reads: " + str(n_total))
        print("Mean sequence length: " + str(mean_read_length) + " bp")
        print("Reads recovered: " + str(n_recovered) + " (" + str(round(n_recovered/n_total*100, 2)) + "%)\n")

    del df_out, filt_qualities, unbarcoded_reads, weird_reads, qualities, out_dictionary


total_time_taken = round(time.time() - total_start_time, 4)
print("Total time taken: " + str(total_time_taken))

# PV

In [324]:
# prepare flanking regions
if use_reverse:
    Ldic = {"left": str(Seq.Seq(args["right"]).reverse_complement())}
    Cleft_threshold = right_threshold
    Rdic = {"right": str(Seq.Seq(args["left"]).reverse_complement())}
    Cright_threshold = left_threshold
else:
    Ldic = {"left": args["left"]}
    Cleft_threshold = left_threshold
    Rdic = {"right": args["right"]}
    Cright_threshold = right_threshold

In [325]:
for file in os.listdir(args["input"]):
    file_path = os.path.join(args["input"], file)
    # conditional with clause
    with gzopen(file_path, "rt") if file_path.endswith(".gz") else nullcontext() as file_path:

        if file_path == None:
            file_path = os.path.join(args["input"], file)

        ## fastq input
        global fastq_input
        fastq = {entry.description: str(entry.seq) for entry in SeqIO.parse(file_path, "fastq")}
        fastq_input = prepare_arguments(fastq, Ldic, aligner, Cleft_threshold)

        ## info
        read_lengths = [len(x) for x in fastq.values()]
        mean_read_length = sum(read_lengths) / len(read_lengths)


        # Main computation
        ## parallel workers
        p = mp.Pool(n_workers)
        start_time = time.time()
        ## do the calculation
        Lresults = p.starmap(
            align_barcodes,
            fastq_input
        )

        Lqualities = {}
        for r in Lresults:
            Lqualities.update(r)
        Lfilt_qualities, Lunbarcoded_reads, Lweird_reads = filter_alignments_score(Lqualities)

        Rresults = p.starmap(
            align_barcodes,
            fastq_input
        )

        Rqualities = {}
        for r in Rresults:
            Rqualities.update(r)
        Rfilt_qualities, Runbarcoded_reads, Rweird_reads = filter_alignments_score(Rqualities)

        # if a read has both left and right flanks found combine them together 
        # the new dictionary contains the readname as key and left and right alignment in a tuple
        reads_with_flank = combine_dict(Lfilt_qualities, Rfilt_qualities)
        reads_with_peptide = {}
        n_reads_only_1_flanking = 0
        for readname in reads_with_flank:
            # check if both flanking sequences were found
            if len(reads_with_flank[readname]) == 2:
                # get coordinates of the insert between the flanking sequences
                start_peptide = reads_with_flank[readname][0][0][1].path[1][0]
                end_peptide = reads_with_flank[readname][1][0][1].path[0][0]
                # extract peptide and record
                peptide = fastq[readname][start_peptide:end_peptide]
                if len(peptide) > 0:
                    reads_with_peptide[readname] = fastq[readname][start_peptide:end_peptide]
                else:
                    n_reads_only_1_flanking += 1
            else:
                n_reads_only_1_flanking += 1


        n_recovered = len(reads_with_peptide)
        n_total = len(fastq)
        n_correct_length = sum(len(pep) == 21 for pep in reads_with_peptide.values())

        pep_length = [len(x) for x in reads_with_peptide.values()]
        mean_pep_length = round(sum(pep_length) / len(pep_length), 3)

    
    # add up variants
    df_out = pd.DataFrame(pd.Series(reads_with_peptide.values()).value_counts())

    # output
    df_out.to_csv(os.path.join(args["out_directory"], file).split(".")[0] + ".PV.csv")

    # write log file
    log_output_file = os.path.join(args["out_directory"], file).split(".")[0]+".PV.log.txt"
    f=open(log_output_file,'w')
    f.write("File: " + file + "\n")
    f.write("Total number of reads: " + str(n_total) + "\n")
    f.write("Mean sequence length: " + str(mean_read_length) + " bp\n")
    f.write("Reads recovered: " + str(n_recovered) + " (" + str(round(n_recovered/n_total*100, 2)) + "%)\n")
    f.write("Mean peptide length: " + str(mean_pep_length) + " bp\n")
    #f.write("\nMean sequence quality: "+str(round(mean_quality, 2)))
    f.close()

    # terminal output
    if args["verbose"]:
        print("\nFile: " + file)
        print("Total number of reads: " + str(n_total))
        print("Mean sequence length: " + str(mean_read_length) + " bp")
        print("Reads recovered: " + str(n_recovered) + " ("+str(round(n_recovered/n_total*100, 2))+"%)")
        print("Mean peptide length: " + str(mean_pep_length) + " bp")


File: m1_kidney.fastq.gz
Total number of reads: 47839
Mean sequence length: 122.0 bp
Reads recovered: 45078 (94.23%)
Mean peptide length: 21.028 bp

File: m1_input.fastq.gz
Total number of reads: 47839
Mean sequence length: 122.0 bp
Reads recovered: 45033 (94.13%)
Mean peptide length: 21.039 bp

File: m1_liver.fastq.gz
Total number of reads: 47839
Mean sequence length: 122.0 bp
Reads recovered: 45175 (94.43%)
Mean peptide length: 21.039 bp

File: m1_lung.fastq.gz
Total number of reads: 47839
Mean sequence length: 122.0 bp
Reads recovered: 45117 (94.31%)
Mean peptide length: 21.04 bp
