In [None]:
# Import modules
import pysam
import numpy as np
import matplotlib.pyplot as plt
import math
import itertools
from dataclasses import dataclass
from collections import defaultdict

In [None]:
def cigar_compute(read):
    """Summarise the CIGAR operations of one alignment record.

    Parameters
    ----------
    read : object with a ``cigartuples`` attribute
        ``cigartuples`` is an iterable of ``(op, length)`` pairs using the
        numeric BAM operation codes handled below (0=M, 1=I, 2=D, 4=S, 5=H,
        7='=', 8=X).  Any other code is ignored.  ``None`` (no CIGAR) is
        treated like an empty CIGAR.

    Returns
    -------
    tuple
        ``(align_qual, qleng, S, H)`` where

        * ``align_qual`` is ``eq / (eq + X + I + D)`` rounded to 2 decimals,
          or ``0.0`` when that denominator is zero (for example a CIGAR made
          only of M and clipping operations);
        * ``qleng`` is ``M + I + S + eq + X`` (hard clips are not counted);
        * ``S`` and ``H`` are the total soft- and hard-clipped lengths.
    """
    eq = X = I = D = M = S = H = 0
    # `or ()` keeps records without a CIGAR from raising TypeError.
    for op, length in read.cigartuples or ():
        if op == 7:
            eq += length
        elif op == 8:
            X += length
        elif op == 1:
            I += length
        elif op == 2:
            D += length
        elif op == 0:
            M += length
        elif op == 4:
            S += length
        elif op == 5:
            H += length
        # every other operation code is deliberately ignored

    qleng = M + I + S + eq + X
    scored = eq + X + I + D
    # Guard the ratio: without =/X/I/D operations there is nothing to score.
    align_qual = round(float(eq) / scored, 2) if scored else 0.0
    return align_qual, qleng, S, H


# This part plots the original data. Based on these observations, we decide which reads to include or exclude.

In [None]:
import pysam

# Count the primary alignments per bacterium and plot the counts.
# `file` is left open on purpose: later cells iterate over it again.
file = pysam.AlignmentFile("new_map_sorted.bam","rb")

# Maps the first two characters of the reference name to the number of
# primary alignments seen for it.
read_dict = {}
for read in file.fetch():
    # Keep primary alignments only.
    if read.is_unmapped or read.is_supplementary or read.is_secondary:
        continue
    # Counting needs nothing from the CIGAR, so it is not parsed here.
    bacteria = read.reference_name[0:2]
    read_dict[bacteria] = read_dict.get(bacteria, 0) + 1

xpoint = []
ypoint = []
for bacteria, n_alignments in read_dict.items():
    xpoint.append(bacteria)
    ypoint.append(n_alignments)
    print('The {} has {} alignments.'.format(bacteria , n_alignments))

plt.title('Distribution of bacteria: primary alignment')
plt.plot(xpoint,ypoint,'o')
plt.ylabel('Number of alignments')
plt.xlabel('Bacteria names')
plt.savefig('Bacteria_distribution.png')
#plt.show()


In [None]:
# Collect, per bacterium, the query length of every primary alignment and
# track the overall minimum / maximum length (used later as histogram range).
count = 1
qleng_distribution = {}
len_min = float('inf')
len_max = float('-inf')
for read in file.fetch():
    # Stop once the cap is reached.
    if count >= 3900000:
        break
    if read.is_unmapped or read.is_supplementary or read.is_secondary:
        # Skipped records do not advance `count`, so the cap above and the
        # progress message below are in units of primary alignments.
        continue

    qleng = cigar_compute(read)[1]
    len_min = min(len_min, qleng)
    len_max = max(len_max, qleng)

    bacteria = read.reference_name[0:2]
    qleng_distribution.setdefault(bacteria, []).append(qleng)

    if count % 10000 == 0:
        # Report the real position instead of a fixed, misleading label.
        print('Pass {} primary alignments'.format(count))
    count += 1

# Nothing is drawn in this cell, so no figure is saved here: saving would
# write an empty figure over 'Bacteria_distribution.png'.
#print('Done')

In [None]:
def find_median(sorted_list):
    """Return the median of an already-sorted sequence and the index/indices used.

    The input is not sorted here; the caller must pass ordered data.

    Returns a ``(median, indices)`` tuple:

    * odd length  -> the middle element itself and ``[middle]``;
    * even length -> the mean of the two middle elements (a float) and
      ``[middle - 1, middle]``.

    An empty sequence raises ``IndexError``.
    """
    n_items = len(sorted_list)
    upper_mid = int(n_items / 2.0)

    if n_items % 2:
        indices = [upper_mid]
        median = sorted_list[upper_mid]
    else:
        # index of the lower middle element is one below the upper one
        indices = [upper_mid - 1, upper_mid]
        median = (sorted_list[indices[0]] + sorted_list[indices[1]]) / 2.0

    return median, indices

In [None]:
# Quartiles (Q1, median, Q3) of the alignment length for each bacterium.
bac_key = ['lf','lm','cn','ef','ec','pa','bs','sc','sa','se']
Q1_plot = []
med_plot = []
Q3_plot = []
for bac in bac_key:
    samples = qleng_distribution[bac]
    samples.sort()  # sorts the stored list in place
    median, median_indices = find_median(samples)

    # Q1 / Q3 are the medians of the values strictly below / above the
    # element(s) that produced the overall median.
    lower_half = samples[:median_indices[0]]
    upper_half = samples[median_indices[-1] + 1:]
    Q1 = find_median(lower_half)[0]
    Q2 = find_median(upper_half)[0]

    quartiles = [Q1, median, Q2]
    Q1_plot.append(Q1)
    med_plot.append(median)
    Q3_plot.append(Q2)
    print("For bacteria {}, the quartiles are (Q1, median, Q3): {}".format(bac,quartiles))

plt.xlabel('Bacteria')
plt.ylabel('Alignment length')
plt.title('Quartiles')
for series, marker, legend_name in ((Q1_plot, 'x', 'Q1'), (med_plot, 'o', 'Med'), (Q3_plot, 'x', 'Q3')):
    plt.plot(bac_key, series, marker, label=legend_name)
plt.legend()
plt.savefig('Alignment_quartiles.png')
plt.show()

In [None]:
# One alignment-length histogram per bacterium, stacked vertically with a
# shared y axis and identical bins.
bac_key = ['lf','lm','cn','ef','ec','pa','bs','sc','sa','se']
fig, axi = plt.subplots(nrows=10, ncols=1, figsize = (24,14), tight_layout = True, sharey = True)

length_bins = np.linspace(len_min, len_max, 50)
for panel, bac in zip(axi, bac_key):
    panel.hist(x=qleng_distribution[bac], bins=length_bins, color='r', alpha=0.9, rwidth=0.85)
    # Label the panel with the bacterium key, centred in axes coordinates.
    panel.text(0.5, 0.5, bac, horizontalalignment='center', verticalalignment='center', transform=panel.transAxes)

# pyplot-level labels go to the current axes only.
plt.ylabel('Number of alignments')
plt.xlabel('Length of alignments')
plt.savefig('Bacteria_distribution_alignment_length.png')
plt.show()

In [None]:
# Alignment-length histogram for the 'lm' reads only, using 100 edges
# spanning the global length range.
lm_bins = np.linspace(len_min, len_max, 100)
n, bins, patches = plt.hist(
    x=qleng_distribution['lm'],
    bins=lm_bins,
    color='r',
    alpha=0.9,
    rwidth=0.85,
)
plt.grid(axis='y', alpha=0.75)
plt.xlabel('Alignment Length')
plt.ylabel('Number of alignments')
plt.title('Histogram of {}'.format('lm'))
# Fixed upper limit for the y axis.
plt.ylim(ymax=500000)
# NOTE(review): this path is also written by the multi-panel per-bacteria
# length histogram cell, so this call replaces that image — confirm intended.
plt.savefig('Bacteria_distribution_alignment_length.png')

In [None]:
# Histogram of the CIGAR quality of 'lf' alignments on edges 0.90 ... 0.99.
# NOTE(review): `align_qual_dist` is not assigned anywhere in the visible
# notebook; presumably a dict of bacterium key -> list of qualities — confirm
# where it is built before running this cell on a fresh kernel.
b = [hundredth / 100.0 for hundredth in range(90, 100)]
plt.hist(
    x=align_qual_dist['lf'],
    bins=b,
    color='r',
    alpha=0.9,
    rwidth=0.85,
)

plt.show()

In [None]:
# One CIGAR-quality histogram per bacterium (shared y axis, edges 0.90 ... 0.99).
# NOTE(review): `align_qual_dist` is not assigned in the visible notebook — confirm its source.
bac_key = ['lf','lm','cn','ef','ec','pa','bs','sc','sa','se']
fig, axi = plt.subplots(nrows=10, ncols=1, figsize = (24,20), sharey = True)
b = [hundredth / 100.0 for hundredth in range(90, 100)]
for panel, bac in zip(axi, bac_key):
    panel.hist(x=align_qual_dist[bac], bins=b, color='r', alpha=0.9, rwidth=0.85)
    # Bacterium key as an in-panel label (axes coordinates).
    panel.text(0.2, 0.5, bac, horizontalalignment='center', verticalalignment='center', transform=panel.transAxes)

# pyplot-level labels go to the current axes only.
plt.ylabel('Number of alignments')
plt.xlabel('Alignment CIGAR quality')
plt.savefig('Bacteria_distribution_CIGAR_quality.png')
plt.show()

In [None]:
# Histogram of `qlength` with automatic binning.
# NOTE(review): `qlength` is not assigned in the visible notebook — confirm its source.
n, bins, patches = plt.hist(
    x=qlength,
    bins='auto',
    color='r',
    alpha=0.9,
    rwidth=0.85,
)
plt.grid(axis='y', alpha=0.75)
plt.xlabel('Alignment Length')
plt.ylabel('Number of alignments')
plt.title('Bacteria lf in range 1000')

# Upper y limit: the tallest bar rounded up to the next multiple of 10, or
# ten above it when it already is a multiple of 10.
maxfreq = n.max()
if maxfreq % 10:
    y_top = np.ceil(maxfreq / 10) * 10
else:
    y_top = maxfreq + 10
plt.ylim(ymax=y_top)

In [None]:
# Collect the soft-clipped length of reads that have a primary alignment and
# no other (secondary / supplementary / unmapped) record.
# NOTE(review): `prim` and `AlignInfo` are defined in cells further down the
# notebook; those cells must be executed before this one.
file = pysam.AlignmentFile("new_map_sorted.bam","rb")
cq = 1
soft_dict = {}
primary_only = AlignInfo(True,False)  # built once instead of once per record
for read in file.fetch():
    # `cq` counts every record; stop when the cap is reached.
    if cq >= 4000000:
        break

    if prim[read.qname] == primary_only:
        Cigar_qual, qleng, soft, hard = cigar_compute(read)
        bacteria = read.reference_name[0:2]
        soft_dict.setdefault(bacteria, []).append(soft)

    if cq%100000 == 0:
        print('Pass 100K')
    cq+=1

# Nothing is drawn in this cell, so no figure is saved here: saving would
# write an empty figure over 'Bacteria_distribution.png'.

In [None]:
# Collect the mapping quality of every primary alignment per bacterium and
# count (in `check`) how many of them differ from 60.
cq = 1
map_qual = {}
check = 0
for read in file.fetch():
    if cq >= 4000000:
        break
    if read.is_unmapped or read.is_supplementary or read.is_secondary:
        # Skipped records do not advance `cq`, so the cap and the progress
        # message are in units of primary alignments.
        continue

    mqual = read.mapping_quality
    if mqual != 60:
        check += 1
    bacteria = read.reference_name[0:2]
    map_qual.setdefault(bacteria, []).append(mqual)

    if cq%100000 == 0:
        print('Pass 100K')
    cq+=1

# Nothing is drawn in this cell, so no figure is saved here: saving would
# write an empty figure over 'Bacteria_distribution.png'.

In [None]:
# One mapping-quality histogram per bacterium, stacked with a shared y axis.
bac_key = ['lf','lm','cn','ef','ec','pa','bs','sc','sa','se']
fig, axi = plt.subplots(nrows=10, ncols=1, figsize = (24,20), sharey = True)
for panel, bac in zip(axi, bac_key):
    panel.hist(x=map_qual[bac], bins='auto', color='r', alpha=0.9, rwidth=0.85)
    # Bacterium key as an in-panel label (axes coordinates).
    panel.text(0.2, 0.5, bac, horizontalalignment='center', verticalalignment='center', transform=panel.transAxes)

# pyplot-level labels go to the current axes only.
plt.ylabel('Number of alignments')
plt.xlabel('Mapping quality')
plt.savefig('Bacteria_distribution_mapqual.png')
plt.show()

In [None]:
# Count, per bacterium, the primary alignments that survive every filter:
#   * reference is not 'lf', 'sc' or 'cn'
#   * 3000 <= query length <= 7000
#   * CIGAR quality >= 0.94
#   * mapping quality == 60
file = pysam.AlignmentFile("new_map_sorted.bam","rb")
cq = 1
num_read = 0
truth_distri = {}

for read in file.fetch():
    if cq >= 4000000:
        break

    # Every `continue` below also skips the `cq` increment, so `cq` (and the
    # progress message) counts only records that passed all filters.
    if read.is_unmapped or read.is_supplementary or read.is_secondary:
        continue

    bacteria = read.reference_name[0:2]
    if bacteria in ('lf', 'sc', 'cn'):
        continue

    # Parse the CIGAR once and reuse both values.
    align_qual, qleng = cigar_compute(read)[:2]
    if qleng < 3000 or qleng > 7000:
        continue
    if align_qual < 0.94:
        continue
    if read.mapping_quality != 60:
        continue

    num_read += 1
    truth_distri[bacteria] = truth_distri.get(bacteria, 0) + 1

    if cq%100000 == 0:
        print('Pass 100K')
    cq+=1

# Nothing is drawn in this cell, so no figure is saved here: saving would
# write an empty figure over 'Bacteria_distribution.png'.
print(num_read)


# After restricting these parameters, we write the data to a text file, which will be used for the model.

In [None]:
@dataclass
class AlignInfo:
    """Per-read-name flags recording which kinds of alignment records exist.

    Instances compare by field values (dataclass equality), so
    ``AlignInfo(True, False)`` can be used to test for "primary only".
    """
    # True once a record that is neither unmapped, supplementary nor
    # secondary has been seen for the read name.
    primary: bool
    # True once any other record (unmapped, supplementary or secondary)
    # has been seen for the read name.
    other: bool
        
@dataclass
class Overlap:
    """Pair of string fields, ``current`` and ``after``.

    NOTE(review): not referenced anywhere in the visible notebook; presumably
    describes the overlap of a read with the one following it — confirm.
    """
    current: str
    after : str

In [None]:
# Build, for every read name, an AlignInfo telling whether a primary record
# and/or any other record (unmapped, supplementary, secondary) exists.
file = pysam.AlignmentFile("new_map_sorted.bam","rb")

# Unknown read names start out with both flags False.
prim = defaultdict(lambda: AlignInfo(False,False))

for read in file.fetch():
    qname = read.qname
    is_primary = not (read.is_unmapped or read.is_supplementary or read.is_secondary)
    if is_primary:
        prim[qname].primary = True
    else:
        prim[qname].other = True

In [None]:
# Gather the reads that only have a primary alignment, belong to one of the
# kept bacteria and pass the length / CIGAR-quality / mapping-quality filters.
bac_key = ['lm','ef','ec','pa','bs','sa','se']
num_read = 0
check = 0  # number of records visited so far
name_ref = []
for read in file.fetch():
    if check >= 4000000:
        break

    if prim[read.qname] == AlignInfo(True, False):
        bacteria = read.reference_name[0:2]
        if bacteria in bac_key:
            align_qual, qleng, soft, hard = cigar_compute(read)
            if 3000 <= qleng <= 7000 and align_qual >= 0.94 and read.mapping_quality == 60:
                num_read += 1
                # [name, reference id, start, end, aligned sequence, bacterium]
                name_ref.append([
                    read.qname,
                    read.rname,
                    read.reference_start,
                    read.reference_end,
                    read.query_alignment_sequence,
                    bacteria,
                ])

    check += 1
print(num_read)

In [None]:
def _by_reference_span(record):
    """Sort key: elements 1, 2 and 3 of a record (reference, start, end in `name_ref`)."""
    return record[1], record[2], record[3]


# New list ordered by reference, then start, then end; `name_ref` itself is untouched.
name_ref_sorted = sorted(name_ref, key=_by_reference_span)

In [None]:
def percen(l1,l2):
    """Return the overlap of two intervals as a fraction of an interval length.

    Only elements 2 (start) and 3 (end) of each argument are read.  The
    caller is expected to pass intervals that already intersect, with
    ``l1`` starting no later than ``l2``; neither is checked here.

    * If ``l2`` ends before ``l1`` does, the overlap is all of ``l2`` and the
      result is ``len(l2) / len(l1)``.
    * Otherwise the overlap is ``l1[3] - l2[2]`` and it is divided by the
      length of ``l1`` when ``l2`` is strictly shorter, else by the length
      of ``l2``.
    """
    span1 = l1[3] - l1[2]
    span2 = l2[3] - l2[2]

    if l2[3] < l1[3]:
        # l2 lies inside l1
        return span2 / span1

    shared = l1[3] - l2[2]
    divisor = span1 if span2 < span1 else span2
    return shared / divisor

In [None]:
import random
# Draw 60000 distinct record indices and write the chosen records to
# 'train60k.txt' as tab-separated lines: index, sequence, bacterium, overlap.
# NOTE(review): `see` is not assigned in the visible notebook; presumably a
# sequence of 670951 tab-separated strings — confirm before running.
ran_pick = random.sample(range(670951), 60000)
with open('train60k.txt', 'w') as training:
    for pick in ran_pick:
        see_list = see[pick].split('\t')
        seq = see_list[1]
        bac = see_list[2]
        over = see_list[3]
        # Real tab characters as separators ('/t' wrote a slash and a "t").
        training.write('\t'.join([str(pick), seq, bac, over]))
        training.write('\n')