# demultiplex

> A python program for trimming and demultiplexing nanopore reads

In [1]:
#| default_exp demultiplex

In [2]:
#| hide
from nbdev.showdoc import *
from Bio import SeqIO
from Bio import Align
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
import numpy as np

In [3]:
#| export
from Bio import SeqIO
from Bio import Align
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
import numpy as np

In [4]:
#| export
def findAlingments(seq_record, primer_dict, inward_end, max_alignments):
    "Find alignments for each primer in a sequence record"
  
    primer_keys = list(primer_dict.keys())
    
    aligner = Align.PairwiseAligner()
    aligner.match_score = 1.0
    aligner.mismatch_score = 0
    aligner.gap_score = -2
    aligner.mode = "local"

    n_sequences = len(primer_keys)

    array_cols = max_alignments + 3
    al_array = np.zeros( (n_sequences, array_cols) )

    for i in list(range(0, n_sequences, 1)):
        print(primer_keys[i])
        al = []
        seq = primer_dict[primer_keys[i]].seq        
        alignments = aligner.align(seq_record[0:inward_end], seq)
        len_alignments = len(alignments)
        if(len_alignments <= max_alignments):
            score = alignments.score
            al = [j.aligned for j in alignments]
            len_al = len(al)
            for k in range(0, len_al):
                al[k] = (al[k][0][0][1])
            al_array[i, 0:len(al)] = al # ends of each alignment
            al_array[i, -3] = max(al) # maximum posistion of each alignment
            al_array[i, -2] = len_alignments # number of alingments
            al_array[i, -1] = np.around(alignments.score/len(seq)*100, 0) # normalized local alingnment score
            
    return(al_array)

In [5]:
#| hide
primer_dict = list(
    [SeqRecord(
        Seq("CGCTCAGTTC"),
        id="barcode_1",
        name="barcode_1"),
    SeqRecord(
        Seq("TATCTGACCT"),
        id="barcode_2",
        name="barcode_2"),
    SeqRecord(
        Seq("ATATGAGACG"),
        id="barcode_3",
        name="barcode_3")]
)
primer_dict = SeqIO.to_dict(primer_dict)

seq_record = Seq("TGATGTAAGTACGCTCAGTTCGATATCGATATGAGACGGATTAGGAGGGGGCGCGATGTTGTGTGGGAAAA")
ends = findAlingments(seq_record, primer_dict, 200, 3)

print(ends)

ends == [[ 21.,   0.,   0.,  21.,   1., 100.],
 [ 37.,   0.,   0.,  37.,   1.,  60.],
 [ 38.,   0.,   0.,  38.,   1., 100.]]


barcode_1
barcode_2
barcode_3
[[ 21.   0.   0.  21.   1. 100.]
 [ 37.   0.   0.  37.   1.  60.]
 [ 38.   0.   0.  38.   1. 100.]]


array([[ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True]])

In [6]:
#| export
def align_barcodes(primer_dict, record_dict, inward_end, max_alignments):
    "Aligne all barcodes in a list of seq records"

    record_keys = list(record_dict.keys())
    n_sequences = len(record_keys)
    
    alingments = list( range(0, n_sequences) )
    for i in list(range(0, n_sequences, 1)):
        print(record_keys[i])
        seq_i = record_dict[record_keys[i]].seq
    
        alingments[i] = findAlingments(seq_i, primer_dict, inward_end, max_alignments)
    return(alingments)

In [7]:
#| hide
primer_dict = list(
    [SeqRecord(
        Seq("CGCTCAGTTC"),
        id="barcode_1",
        name="barcode_1"),
    SeqRecord(
        Seq("TATCTGACCT"),
        id="barcode_2",
        name="barcode_2"),
    SeqRecord(
        Seq("ATATGAGACG"),
        id="barcode_3",
        name="barcode_3")]
)
primer_dict = SeqIO.to_dict(primer_dict)

record_dict = list(
    [SeqRecord(
        Seq("AGTGCCCCGCGCCACGCTCAGTTCCTCCCGCGCCGCCTGCCCTGCAGCCTGCCCGCGGCGCCTTTATACCCAGCGGGCTCGCGGGCTCGCGCGCTCACTAATGTTT"),
        id="seq_1",
        name="seq_1"),
    SeqRecord(
        Seq("ATGAACCGGGGAGTCCCTTTTTATCTGACCTTTCTGGTGCTGCAACTGGCGCTCCTCCCAGCAGCCACTCAGGGAAATAAAGTGGTGCTGGGCAAAAAAGGGGATAC"),
        id="seq_2",
        name="seq_2"),
    SeqRecord(
        Seq("GCCCAGGGACAGAGGAACAATATGAGACGCAGGTTCCTTAACAGGAACATGAAGCACCCCCAGGAGGGACAGCCGCTGGAGCTGGAGTGCCTGCCTTTCAACATCG"),
        id="seq_3",
        name="seq_3")]
)
record_dict = SeqIO.to_dict(record_dict)


alginment_arrays = align_barcodes(primer_dict, record_dict, 200, 5)
alginment_arrays

seq_1
barcode_1
barcode_2
barcode_3
seq_2
barcode_1
barcode_2
barcode_3
seq_3
barcode_1
barcode_2
barcode_3


[array([[ 24.,   0.,   0.,   0.,   0.,  24.,   1., 100.],
        [ 43.,   0.,   0.,   0.,   0.,  43.,   1.,  60.],
        [ 75.,   0.,   0.,   0.,   0.,  75.,   1.,  50.]]),
 array([[ 16.,  59.,  72.,   0.,   0.,  72.,   3.,  60.],
        [ 31.,   0.,   0.,   0.,   0.,  31.,   1., 100.],
        [  8.,  29.,   0.,   0.,   0.,  29.,   2.,  60.]]),
 array([[  0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],
        [ 28.,  97.,   0.,   0.,   0.,  97.,   2.,  60.],
        [ 29.,   0.,   0.,   0.,   0.,  29.,   1., 100.]])]

In [8]:
#| hide
x = np.zeros( (1, len(alginment_arrays)) )
for i in alginment_arrays:
    print(np.sum((i[:,-1] >= 85) & (i[:,-2] == 1)))
    print(i[np.where(i[:,-1] >= 85),:])

1
[[[ 24.   0.   0.   0.   0.  24.   1. 100.]]]
1
[[[ 31.   0.   0.   0.   0.  31.   1. 100.]]]
1
[[[ 29.   0.   0.   0.   0.  29.   1. 100.]]]


In [9]:
#| export
def decide_barcode_id(alginment_arrays):
    "Decide which barcode is best hit; remove if tie"
    
    id_array = np.zeros((np.shape(alginment_arrays)[0],3), dtype=np.int64)
    for i in range(0, np.shape(alginment_arrays)[0]):
        array_i = alginment_arrays[i]
        id_i = np.where(array_i[:,-1] == np.max(array_i[:,-1]))[0]
        if len(id_i) == 1:
            id_array[i,0] = id_i
            id_array[i,1] = array_i[id_i,-3]
            id_array[i,2] = array_i[id_i,-1]
        elif len(id_i) >= 1:
            id_array[i,0] =  -1
            id_array[i,1] = 0
            id_array[i,2] = 0
        
        
    return(id_array)

In [10]:
#| hide
seq_barcode_ids = decide_barcode_id(alginment_arrays)
seq_barcode_ids

array([[  0,  24, 100],
       [  1,  31, 100],
       [  2,  29, 100]])

In [11]:
#| export
def trim_record(seq_record, primer_end_position):
    "Trim barcodes"

    x = seq_record
    x =  x[primer_end_position:]
    return(x)

In [12]:
#| hide
record_keys = list(record_dict.keys())
record_x = record_dict[record_keys[0]]
len_old = len(record_x.seq)

record_x_new = trim_record(record_x, 3)

assert len(record_x_new.seq) + 3 == len_old


In [13]:
#| export
def sort_records_to_file(record_dict, primer_dict, output_folder, alginment_arrays, input_file_type):
    "Sort records into new files based on barcodes and name files after barcodes"
    
    seq_barcode_res = decide_barcode_id(alginment_arrays)
    seq_barcode_ids = seq_barcode_res[:,0]
    seq_barcode_end_pos = seq_barcode_res[:,1]
    seq_barcode_match = seq_barcode_res[:,2]
    primer_keys = list(primer_dict.keys())
    record_keys = list(record_dict.keys())
    record_numbers = range(0, len(record_keys))

    for k in range(0, len(primer_dict)):
        seq_iterator_k = (trim_record(record_dict[record_keys[i]], seq_barcode_end_pos[i]) for i in record_numbers if seq_barcode_ids[i] == k if seq_barcode_match[i] >= 85)
        SeqIO.write(seq_iterator_k, output_folder + "/" + primer_dict[primer_keys[k]].name + "_seqs." + input_file_type, input_file_type)

In [14]:
#| hide
sort_records_to_file(record_dict, primer_dict, "test_data/test_out", alginment_arrays, "fasta")

In [15]:
#| export
def demultiplex(input_file, input_file_type, primer_file, primer_file_type, output_folder, max_distance, max_alignments):
    "Trim and demultiplex sequencing reads"
    
    print("Create barcode dictionary")
    primer_dict = SeqIO.index(primer_file, primer_file_type)
    
    print("Create sequence dictionary")
    record_dict = SeqIO.index(input_file, input_file_type)
    
    print("Align barcodes")
    alginment_arrays = align_barcodes(primer_dict, record_dict, max_distance, max_alignments)
    
    print("Sort records to file")
    sort_records_to_file(record_dict, primer_dict, output_folder, alginment_arrays, input_file_type)

In [16]:
#| hide
# demultiplex("test_data/test.fasta", "fasta", "test_data/test_primer.fasta", "fasta", "test_data/test_out", 200, 5)

In [17]:
#| hide
import nbdev; nbdev.nbdev_export()