In [1]:
#Author Kanishk Asthana kasthana@eng.ucsd.edu
import pysam
from datetime import datetime,timedelta
import numpy as np
from scipy.spatial.distance import pdist
import argparse
import os
import sys
from multiprocessing import shared_memory,Process,Manager
from multiprocessing.managers import SharedMemoryManager
#For writing output to the stdout in realtime
def write(*something):
    print(*something)
    sys.stdout.flush()

#Default Values
bamFileName="/stg3/data1/kanishk2/Temp/Sample_test.bam"
outBamFileName="/stg3/data1/kanishk2/Temp/SubCorrected_test.bam"
MIN_UMI_COUNT=20 #Minimum UMI count for Barcode to be considered for Collapse. Not counting UMIs for simplicity.Change this later because it breaks easily. #ToDO
SUBSTITUTION_ERROR_FREQ=0.85 #Minimum frequency of majority barcode in substitution pairs needed for collapse
num_cores=10 #Change this later to match same number of cores as STAR

In [2]:
%%time
'''
def parse_file(input_filename):
    if not os.path.isfile(input_filename):
        raise argparse.ArgumentTypeError("File does not exist. Please use a valid file path.")
    return(input_filename)
    
def check_umi(input_umi):
    value=int(input_umi)
    assert value>0,"Please enter positive values only for UMI Count Filter!"
    return value

parser=argparse.ArgumentParser(description="Script to detect and Correct Substitution errors in the Cell Barcodes. These errors are likely sequencing errors and not Bead Synthesis Errors")
parser.add_argument("INPUT_FILENAME",help="Aligned Merged and Gene labeled BAM file for Correcting Substitution Errors.", type=parse_file)
parser.add_argument("OUTPUT_FILENAME",help="Please enter a Valid Path for the Error Corrected output BAM file.",type=str)
parser.add_argument("-umi","--MIN_UMI", help="Minimum UMI count per barcode to be considered for Collapse. Default Value is 20.",type=check_umi)
args=parser.parse_args()
write(args)

if args.MIN_UMI is not None:
    MIN_UMI_COUNT=args.MIN_UMI

bamFileName=args.INPUT_FILENAME
outBamFileName=args.OUTPUT_FILENAME
'''
script_start_time=datetime.now()

class CellBarcode:
    CellBarcodesWithEnoughCounts=[]
    NumberofBarcodesCollapsed=0
    def __init__(self,barcode,umi):
        self.count=1 #Number of reads with that Barcode. When initializing you count the barcode you initialize with
        self.barcode=barcode
        self.barcodeToReturn=self #This will be updated with a different reference if this barcode is merged with another
        self.umi_dict={}
        self.umi_dict[umi]=1
        
    def increase_count(self,umi):
        self.count+=1 #Increase count if you see a barcode
        if umi in self.umi_dict:
            self.umi_dict[umi]+=1
        else:
            self.umi_dict[umi]=1
            
    def computeHasEnoughCounts(self):
        if len(self.umi_dict.keys())>MIN_UMI_COUNT:
            CellBarcode.CellBarcodesWithEnoughCounts.append(self.barcode)
            
    def combineIfSubstitution(self,CellBarcode2):
        largerBarcode=self.barcodeToReturn
        smallerBarcode=CellBarcode2.barcodeToReturn
        if smallerBarcode.count>largerBarcode.count:
            largerBarcode=CellBarcode2.barcodeToReturn
            smallerBarcode=self.barcodeToReturn
        freq=largerBarcode.count/(largerBarcode.count+smallerBarcode.count)
        if freq>SUBSTITUTION_ERROR_FREQ:
            smallerBarcode.barcodeToReturn=largerBarcode
            largerBarcode.count+=smallerBarcode.count
            CellBarcode.NumberofBarcodesCollapsed+=1

barcode_dict={}

#If you get an error here your file probably not correctly formated. Make sure you have a header.
bamFile=pysam.AlignmentFile(bamFileName,"rb")
BamRecords=bamFile.fetch(until_eof=True)

start_time=datetime.now()
prevMil=start_time
write("Started Processing BAM file at",start_time,". Getting Cell Barcodes to correct Illumina Sequencing Base Substitution Errors!")

total_records=0
for record in BamRecords:
    

    #For printing progress
    total_records+=1
    if total_records%1000000==0:
        time_taken=datetime.now()-prevMil
        write("Finished processing ",total_records,"\trecords at",datetime.now(),". Previous 1000000 Records took ",time_taken.total_seconds(),"s")
        prevMil=datetime.now()
    

    #Main Logic
    cell_barcode=record.get_tag('XC')
    umi=record.get_tag('XM')
    
    if cell_barcode in barcode_dict:
        barcode_dict[cell_barcode].increase_count(umi)
    else:
        barcode_dict[cell_barcode]=CellBarcode(cell_barcode,umi)


total_time=datetime.now()-start_time
write("Finished processing BAM file at ",datetime.now(),". Total time taken ",total_time)

bamFile.close()

for barcode in barcode_dict.keys():
    barcode_dict[barcode].computeHasEnoughCounts()

write(len(CellBarcode.CellBarcodesWithEnoughCounts),"Cell barcodes have enough UMIs for further processing.")

Started Processing BAM file at 2022-07-24 16:47:32.334130 . Getting Cell Barcodes to correct Illumina Sequencing Base Substitution Errors!
Finished processing BAM file at  2022-07-24 16:47:35.113043 . Total time taken  0:00:02.778906
3384 Cell barcodes have enough UMIs for further processing.
CPU times: user 2.77 s, sys: 53.1 ms, total: 2.82 s
Wall time: 2.82 s


In [3]:
%%time
def hammingDistance(barcode1,barcode2):
    distance=0
    for i in range(0,len(barcode1)):
        if barcode1[i]!=barcode2[i]:
            distance+=1
    return distance

write("Collapsing Barcodes at a Hamming Distance of 1 and with the Major Barcode present above frequency",SUBSTITUTION_ERROR_FREQ)
indices=np.triu_indices(len(CellBarcode.CellBarcodesWithEnoughCounts),1)#Use diagnal offset=1. We don't want distance of barcodes with themself
for i in range(0,len(indices[0])):
    firstBarcode=CellBarcode.CellBarcodesWithEnoughCounts[indices[0][i]]
    secondBarcode=CellBarcode.CellBarcodesWithEnoughCounts[indices[1][i]]
    if hammingDistance(firstBarcode,secondBarcode)==1:
        barcode_dict[firstBarcode].combineIfSubstitution(barcode_dict[secondBarcode])

write(CellBarcode.NumberofBarcodesCollapsed,"barcodes pairs were collapsed!")

Collapsing Barcodes at a Hamming Distance of 1 and with the Major Barcode present above frequency 0.85
145 barcodes pairs were collapsed!
CPU times: user 7.51 s, sys: 17.8 ms, total: 7.53 s
Wall time: 7.53 s


In [14]:
indices=np.triu_indices(len(CellBarcode.CellBarcodesWithEnoughCounts),1)

In [15]:
lst1=[]
for i in range(0,len(indices[0])):
    lst1.append((indices[0][i],indices[1][i]))

In [16]:
vector_length=len(CellBarcode.CellBarcodesWithEnoughCounts)
lst2=[]
for i in range(0,vector_length-1):
    for j in range(i+1,vector_length):
        lst2.append((i,j))

In [17]:
len(lst1)

5724036

In [18]:
len(lst2)

5724036

In [19]:
lst1==lst2

True

In [3]:
%%time
def hammingDistance(barcode1,barcode2):
    distance=0
    for i in range(0,len(barcode1)):
        if barcode1[i]!=barcode2[i]:
            distance+=1
    return distance

vector_length=len(CellBarcode.CellBarcodesWithEnoughCounts)
write("Collapsing Barcodes at a Hamming Distance of 1 and with the Major Barcode present above frequency",SUBSTITUTION_ERROR_FREQ)
#Calculating Distances for the upper triangle of the distance matrix.
for i in range(0,vector_length-1):
    for j in range(i+1,vector_length):
        firstBarcode=CellBarcode.CellBarcodesWithEnoughCounts[i]
        secondBarcode=CellBarcode.CellBarcodesWithEnoughCounts[j]
        if hammingDistance(firstBarcode,secondBarcode)==1:
            barcode_dict[firstBarcode].combineIfSubstitution(barcode_dict[secondBarcode])

write(CellBarcode.NumberofBarcodesCollapsed,"barcodes pairs were collapsed!")

Collapsing Barcodes at a Hamming Distance of 1 and with the Major Barcode present above frequency 0.85
145 barcodes pairs were collapsed!
CPU times: user 6.43 s, sys: 2.87 ms, total: 6.43 s
Wall time: 6.43 s


In [8]:
%%time
write("Collapsing Barcodes at a Hamming Distance of 1 and with the Major Barcode present above frequency",SUBSTITUTION_ERROR_FREQ)
indices=np.triu_indices(len(CellBarcode.CellBarcodesWithEnoughCounts),1)#Use diagnal offset=1. We don't want distance of barcodes with themself
indices=np.array(indices)
shm = shared_memory.SharedMemory(create=True, size=indices.nbytes)
indices_shared=np.ndarray(indices.shape, dtype=indices.dtype, buffer=shm.buf)
indices_shared[:]=indices[:]

smm = SharedMemoryManager()
smm.start()
CellBarcodesWithEnoughCounts_shared=smm.ShareableList(CellBarcode.CellBarcodesWithEnoughCounts)

#Generating Indices for Iterating over for Calculating Hamming Distance. This will help split the matrix into equal chunks to compute on.
chunk_size=int(len(indices[0])/num_cores)
chunk_lst=[]
for i in range(0,num_cores-1):
    if i!=0:
        chunk_lst.append((i*chunk_size+1,(i+1)*chunk_size))
    else:
        chunk_lst.append((0,chunk_size))
chunk_lst.append((chunk_size*(num_cores-1)+1,len(indices[0])-1))
print("Spliting Compressed Hamming Distance Matrix computation over multiple CPUs by dividing into chunks:")
print(chunk_lst)

def ParallelHamming(chunk_tuple,Barcodes,indices,results):
    start,end=chunk_tuple
    print(start,end)
    print(len(Barcodes))
    print(indices.shape)
    for i in range(start,end+1):
        distance=0
        firstBarcode=Barcodes[indices[0][i]]
        secondBarcode=Barcodes[indices[1][i]]
        for j in range(0,len(firstBarcode)):
            if firstBarcode[j]!=secondBarcode[j]:
                distance+=1
        if distance==1:
            results.append((indices[0][i],indices[1],[i]))

#Start Separate Processes on Chunks            
result_indices_from_cpus=[]
process_for_cpus=[]
for i in range(0,num_cores):
    results_for_cpu=Manager().list()
    result_indices_from_cpus.append(results_for_cpu)
    pr=Process(target=ParallelHamming,args=(chunk_lst[i],CellBarcodesWithEnoughCounts_shared,indices_shared,results_for_cpu))
    process_for_cpus.append(pr)
    pr.start()
    
for i in range(0,num_cores):
    process_for_cpus[i].join()

Collapsing Barcodes at a Hamming Distance of 1 and with the Major Barcode present above frequency 0.85
Spliting Compressed Hamming Distance Matrix computation over multiple CPUs by dividing into chunks:
[(0, 572403), (572404, 1144806), (1144807, 1717209), (1717210, 2289612), (2289613, 2862015), (2862016, 3434418), (3434419, 4006821), (4006822, 4579224), (4579225, 5151627), (5151628, 5724035)]
0 572403
3384
(2, 5724036)
572404 1144806
3384
(2, 5724036)
1144807 1717209
3384
(2, 5724036)
1717210 2289612
3384
(2, 5724036)
2289613 2862015
3384
(2, 5724036)
2862016 3434418
3384
(2, 5724036)
3434419 4006821
3384
(2, 5724036)
4006822 4579224
3384
(2, 5724036)
4579225 5151627
3384
(2, 5724036)
5151628 5724035
3384
(2, 5724036)


Process Process-87:
Traceback (most recent call last):
  File "/stg1/data2/kanishk/anaconda3/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/stg1/data2/kanishk/anaconda3/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "<timed exec>", line 31, in ParallelHamming
  File "/stg1/data2/kanishk/anaconda3/lib/python3.8/multiprocessing/shared_memory.py", line 421, in __getitem__
    back_transform = self._get_back_transform(position)
  File "/stg1/data2/kanishk/anaconda3/lib/python3.8/multiprocessing/shared_memory.py", line 380, in _get_back_transform
    self._offset_back_transform_codes + position
  File "/stg1/data2/kanishk/anaconda3/lib/python3.8/multiprocessing/shared_memory.py", line 497, in _offset_back_transform_codes
    return self._offset_packing_formats + self._list_len * 8
  File "/stg1/data2/kanishk/anaconda3/lib/python3.8/multiprocessing/shared_memory.py", line 493, in _of

KeyboardInterrupt: 

In [4]:
#Combine the Chunks into a single list that will be iterated on to combined the barcodes
final_result=[]
for result in result_indices_from_cpus:
    for index_tuple in result:
        final_result.append(index_tuple)

#Iterate on Indices Tuples to combine barcodes.
for indices_tuple in final_result:
    firstIndex,secondIndex=indices_tuple
    firstBarcode=CellBarcode.CellBarcodesWithEnoughCounts[firstIndex]
    secondBarcode=CellBarcode.CellBarcodesWithEnoughCounts[secondIndex]
    barcode_dict[firstBarcode].combineIfSubstitution(barcode_dict[secondBarcode])

write(CellBarcode.NumberofBarcodesCollapsed,"barcodes pairs were collapsed!")

ValueError: too many values to unpack (expected 2)

In [4]:
smm.shutdown()
shm.unlink()

In [3]:
%%time
def hammingDistance(barcode1,barcode2):
    distance=np.uint8(0)
    barcode1=barcode1[0]
    barcode2=barcode2[0]
    for i in range(0,len(barcode1)):
        if barcode1[i]!=barcode2[i]:
            distance+=np.uint8(1)
    return distance

write("Calculating Pairwise Hamming Distances between Barcodes! This may take some time. O(n^2) operation")
barcodes_array=np.array(CellBarcode.CellBarcodesWithEnoughCounts)
barcodes_array=barcodes_array.reshape(-1,1)
compressed_distances=pdist(barcodes_array,hammingDistance)

write("Collapsing Barcodes at a Hamming Distance of 1 and with the Major Barcode present above frequency",SUBSTITUTION_ERROR_FREQ)
#Getting Truth vector for all Barcode pairs with Hamming Distance of 1
truth_vector=compressed_distances==1
#Getting Indices for Barcodes pairs in the Truth Vector
indices=np.triu_indices(len(CellBarcode.CellBarcodesWithEnoughCounts),1)#Use diagnal offset=1. We don't want distance of barcodes with themself
for i in range(0,len(truth_vector)):
    if truth_vector[i]:
        firstBarcode=CellBarcode.CellBarcodesWithEnoughCounts[indices[0][i]]
        secondBarcode=CellBarcode.CellBarcodesWithEnoughCounts[indices[1][i]]
        barcode_dict[firstBarcode].combineIfSubstitution(barcode_dict[secondBarcode])

write(CellBarcode.NumberofBarcodesCollapsed,"barcodes pairs were collapsed!")



Calculating Pairwise Hamming Distances between Barcodes! This may take some time. O(n^2) operation
Collapsing Barcodes at a Hamming Distance of 1 and with the Major Barcode present above frequency 0.85
145 barcodes pairs were collapsed!
CPU times: user 28.6 s, sys: 16.9 ms, total: 28.7 s
Wall time: 28.7 s


In [3]:
#Writing to new File with Updated Barcodes
bamFile=pysam.AlignmentFile(bamFileName,"rb")
BamRecords=bamFile.fetch(until_eof=True)
outBamFile=pysam.AlignmentFile(outBamFileName,"wb",template=bamFile)

start_time=datetime.now()
prevMil=start_time
write("Started Writing Updated BAM file at",start_time,".")

total_records=0
for record in BamRecords:    

    #For printing progress
    total_records+=1
    if total_records%1000000==0:
        time_taken=datetime.now()-prevMil
        write("Finished processing ",total_records,"\trecords at",datetime.now(),". Previous 1000000 Records took ",time_taken.total_seconds(),"s")
        prevMil=datetime.now()
    
    #Main Logic
    cell_barcode=record.get_tag('XC')
    updated_barcode=barcode_dict[cell_barcode].barcodeToReturn.barcode #Getting Updated Barcode
    record.set_tag('XC',updated_barcode)
    outBamFile.write(record)

total_time=datetime.now()-start_time
write("Finished processing BAM file at ",datetime.now(),". Total time taken ",total_time)

bamFile.close()
outBamFile.close()

write("Total Execution Time:",datetime.now()-script_start_time)

Started Writing Updated BAM file at 2022-07-23 20:03:41.490016 .
Finished processing BAM file at  2022-07-23 20:03:52.275324 . Total time taken  0:00:10.785258
Total Execution Time: 0:01:25.270978
