In [1]:
import pysam
import os
import sys
from sys import getsizeof
import time

directory_path = os.path.abspath(os.path.join('../src/'))
if directory_path not in sys.path:
    sys.path.append(directory_path)
    
from read_process import get_contig_lengths_dict,\
incorporate_replaced_pos_info,incorporate_insertions_and_deletions,\
get_positions_from_md_tag,reverse_complement,get_edit_information,get_edit_information_wrapper,\
has_edits,get_total_coverage_for_contig_at_position,\
print_read_info, update_coverage_array, get_read_information, get_hamming_distance

from utils import get_intervals, index_bam, write_rows_to_info_file, write_header_to_bam, \
write_read_to_bam_file, remove_file_if_exists, make_folder

import os, psutil


# ~~~~~~~~~~~~~~~~~~
# Multi-processing enabled
# ~~~~~~~~~~~~~~~~~~

# An example on a bam for 1500 cell barcodes (group0, group1, group2, group3, group4, group5, group6, group7 group8, group9, group10, group11 split from the original bam)

### Should be about 1500*30,000 = 45 million reads

#### in 10X's bam file, xf=25 means that read is uniquely mapped to a genome, and was used for counting UMI. So we should only look at reads with xf=25 from the 10X bam.

In [2]:
bampath = '/projects/ps-yeolab3/ekofman/sailor2/data/groups_0_1_2_3_4_5_6_7_8_9_10_11_merged.bam'
samfile = pysam.AlignmentFile(bampath, "rb")

In [3]:
samfile_header = str(samfile.header)

In [4]:
getsizeof(samfile_header)/1000

19338.323

# Helper functions

In [7]:
def find_edits(bampath, contig, split_index, start, end, output_folder, verbose=False):
    time_reporting = {}
    start_time = time.perf_counter()
    
    samfile = pysam.AlignmentFile(bampath, "rb")
        
    counts = defaultdict(lambda:defaultdict(lambda:0))
    total_reads = 0
    
    bam_handles_for_barcodes = {}
    read_lists_for_barcodes = defaultdict(lambda:[])
    
    reads_for_contig = samfile.fetch(contig, start, end, multiple_iterators=True)

    output_file = '{}/{}_{}_{}_{}_edit_info.tsv'.format(edit_info_subfolder, contig, split_index, start, end)
    remove_file_if_exists(output_file)

    with open(output_file, 'w') as f:        
        write_header_to_bam(f)

        for i, read in enumerate(reads_for_contig):
            total_reads += 1
            
            if total_reads % 1000 == 0:
                time_reporting[total_reads] = time.perf_counter() - start_time

            barcode = read.get_tag("CB")
            barcodes[contig][barcode] += 1

            verbose = False
            
            # This is to ID the weird read with Ns
            if 'AAACCCAAGAACTTCC-1' == barcode and split_index == '005' and contig == '17':
                if read.query_name == 'A01535:287:H3JJHDSX7:1:2176:29188:6324':
                    verbose = True
            
            error_code, list_of_rows, num_edits_of_each_type = get_read_information(read, contig, verbose=verbose)

                
            if error_code:
                counts[contig][error_code] += 1
            else:
                counts[contig][EDITED_CODE] += 1
                write_rows_to_info_file(list_of_rows, f)
            
            # Store each read using its string representation
            read_as_string = read.to_string()
            
            if 'AAACCCAAGAACTTCC-1' == barcode and split_index == '005':
                if read.query_name == 'A01535:287:H3JJHDSX7:1:2176:29188:6324':
                    print("\n\n~~~ Found read, is {}\nSplit index is {}; edit types is {}; \n{}~~~".format(read_as_string, 
                                                                                                     split_index,
                                                                                                     num_edits_of_each_type,
                                                                                                     list_of_rows
                                                                                                    ))
                    print("Read as string:\n{}".format(read_as_string))
                    
            read_lists_for_barcodes[barcode].append(read_as_string)
            
    
    # Add all reads to dictionary for contig and barcode, in their string representation
    num_barcodes = 0
    total_bams = len(read_lists_for_barcodes)
    
    
    barcode_to_concatted_reads = {}
    for barcode, read_list in read_lists_for_barcodes.items():
        num_barcodes += 1
        if num_barcodes % 100 == 0:
            #print('{}/{} processed'.format(num_barcodes, total_bams))
            pass
        # Concatenate the string representations of all reads for each bam-contig combination
        all_reads_concatted = '\n'.join(read_list)
            
        # Save this concatenated block of text to dictionary
        barcode_to_concatted_reads[barcode] = all_reads_concatted
        
    time_reporting[total_reads] = time.perf_counter() - start_time
    
    samfile.close()
    
    return barcode_to_concatted_reads, total_reads, barcodes, counts, time_reporting


def find_edits_and_split_bams(bampath, contig, split_index, start, end, output_folder, verbose=False):
    barcode_to_concatted_reads, total_reads, barcodes, counts, time_reporting = find_edits(bampath, contig, split_index,
                                                                         start, end, output_folder, verbose=verbose)    
    return barcode_to_concatted_reads, total_reads, barcodes, counts, time_reporting
    
def find_edits_and_split_bams_wrapper(parameters):
    try:
        start_time = time.perf_counter()
        bampath, contig, split_index, start, end, output_folder, verbose = parameters
        label = '{}({}):{}-{}'.format(contig, split_index, start, end)

        #print("{} ({}):{}-{}\tfind_edits_and_split_bams".format(contig, split_index, start, end))
        barcode_to_concatted_reads, total_reads, barcodes, counts, time_reporting = find_edits_and_split_bams(bampath, contig, split_index, start, end, output_folder, verbose=False)
        barcodes_df = pd.DataFrame.from_dict(barcodes)
        counts_df = pd.DataFrame.from_dict(counts)
        time_df = pd.DataFrame.from_dict(time_reporting, orient='index')
        barcode_to_concatted_reads_df = pd.DataFrame.from_dict(barcode_to_concatted_reads, orient='index')
        
        total_time = time.perf_counter() - start_time
        return barcode_to_concatted_reads_df, total_reads, barcodes_df, label, counts_df, time_df, total_time
    except Exception as e:
        print('Contig {}: {}'.format(label, e))
        return 0, pd.DataFrame(), label, pd.DataFrame()

# Go through every read and identify all edits

In [8]:
from collections import defaultdict
import pandas as pd
from matplotlib import pyplot as plt
import numpy as np
import time
from multiprocessing import Pool
import multiprocessing
from tqdm import tqdm

start_time = time.perf_counter()

print("CPU count: {}".format(multiprocessing.cpu_count()))

output_folder = '/projects/ps-yeolab3/ekofman/sailor2/scripts/check_against_pileup_all_cells_threaded_outs_bigger'

contig_lengths_dict = get_contig_lengths_dict(samfile)

# Print info?
verbose = False 
EDITED_CODE = 'edited'

# How many subcontigs to split each contig into to leverage multi-processing
num_intervals = 16

num_reads_to_coverage_dict_kb = {}
num_reads_to_seconds = {}


start_time = time.perf_counter()
total_seconds_for_reads = {0: 1}

barcodes = defaultdict(lambda:defaultdict(lambda:0))

jobs = []
for contig in contig_lengths_dict.keys():
    # Skip useless contigs
    if len(contig) > 5 or contig == 'Stamp' or contig != '17':
        continue
        
    print("Contig {}".format(contig))
    contig_length = contig_lengths_dict.get(contig)
    intervals_for_contig = get_intervals(contig, contig_lengths_dict, num_intervals)
    
    # Make subfolder in which to information about edits
    edit_info_subfolder = '{}/edit_info'.format(output_folder)
    make_folder(edit_info_subfolder)
        
    # Set up for pool
    for split_index, interval in enumerate(intervals_for_contig):
        split_index = str(split_index).zfill(3)
        parameters = [bampath, contig, split_index, interval[0], interval[1], output_folder, verbose]
        jobs.append(parameters)
    
print("{} total jobs".format(len(jobs)))

# Pooling
results = []
overall_total_reads = 0
with Pool(processes=16) as p:
    max_ = len(jobs)
    with tqdm(total=max_) as pbar:
        for _ in p.imap_unordered(find_edits_and_split_bams_wrapper, jobs):
            pbar.update()
            results.append(_)
            
            total_reads = _[1]
            total_time = time.perf_counter() - start_time

            overall_total_reads += total_reads
            total_seconds_for_reads[overall_total_reads] = total_time

overall_time = time.perf_counter() - start_time 


CPU count: 36
Contig 17
16 total jobs


 56%|█████▋    | 9/16 [00:01<00:01,  5.71it/s]

~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Read ID: A01535:287:H3JJHDSX7:1:2176:29188:6324
----------------------------
MD tag: 

 62%|██████▎   | 10/16 [00:02<00:01,  4.99it/s]

89C11
CIGAR string 4M102N92M90N5M
Reference seq: 

 75%|███████▌  | 12/16 [00:02<00:00,  6.52it/s]

AAGGCCCCAAAAGTGGCGCAGCCCTCTATGGGCTCGAATTTTCTTCAGCCTCTCCAGGTCCTCACGCAGCTTGTTGTCTAGACCGTTGGCCAGAACCTGGC
Aligned seq: AAGGCCCCAAAAGTGGCGCAGCCCTCTATGGGCTCGAATTTTCTTCAGCCTCTCCAGGTCCTCACGCAGCTTGTTGTCTAGACCGTTGGTCAGAACCTGGC
['89', '11']
[89, 101]
CIGAR tuples [(0, 4), (3, 102), (0, 92), (3, 90), (0, 5)]
AAGGNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNCCCCAAAAGTGGCGCAGCCCTCTATGGGCTCGAATTTTCTTCAGCCTCTCCAGGTCCTCACGCAGCTTGTTGTCTAGACCGTTGGTCAGAACNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNCTGGC
aaggccccaaaagtggcgcagccctctatgggctcgaattttcttcagcctctccaggtcctcacgcagcttgttgtctagaccgttggCcagaacctggc
aaggnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnNnnnnnnnnnnnNnnnnccccaaaagtggcgcagccctctatgggctcgaattttcttcagcctctccaggtcctcacgcagcttgttgtctagaccgttggtcagaacnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnctggc
alt bases ['

100%|██████████| 16/16 [00:20<00:00,  1.27s/it]


In [12]:
read_as_string = 'A01535:287:H3JJHDSX7:1:2176:29188:6324	16	17	33952115	255	4M102N92M90N5M	*	0	0	AAGGCCCCAAAAGTGGCGCAGCCCTCTATGGGCTCGAATTTTCTTCAGCCTCTCCAGGTCCTCACGCAGCTTGTTGTCTAGACCGTTGGTCAGAACCTGGC	:FFFFFFF:FFFFFFFFF,FFFFFFFF,FFFF,FFF::FF:FF:FFFFFFFFFFFFFFFFFFF,FFFFFFFFF,FFFFFFFFFFFFFFFFFFFFFFFFFFF	NH:i:1	HI:i:1	AS:i:101	nM:i:1	TX:Z:ENSMUST00000008812,+383,101M;ENSMUST00000172799,+935,101M;ENSMUST00000174745,+319,101M;ENSMUST00000174758,+282,101M	GX:Z:ENSMUSG00000008668	GN:Z:Rps18	fx:Z:ENSMUSG00000008668	RE:A:E	xf:i:25	CR:Z:AAACCCAAGAACTTCC	CY:Z:FF:F,FFFFFFFF:FF	CB:Z:AAACCCAAGAACTTCC-1	UR:Z:CGCCCCTGCGCA	UY:Z:FFFFFFFF:,FF	UB:Z:CGCCCCTGCTCA	NM:i:1	MD:Z:89C11	RG:Z:ms_hippo_stamp_EIF4A_batch2:0:1:H3JJHDSX7:1-32F01FC'
read = pysam.AlignedSegment.fromstring(read_as_string, samfile.header)


In [39]:
reverse = False
md_tag = read.get_tag('MD')
cigarstring = read.cigarstring

cigar_tuples = read.cigartuples
aligned_seq = read.get_forward_sequence()
query_qualities = read.query_qualities

if not reverse:
    aligned_seq = reverse_complement(aligned_seq)

reference_seq = read.get_reference_sequence().lower()

if verbose:
    print("MD tag:", md_tag)
    print("CIGAR string", cigarstring)
    print("Reference seq:", reference_seq.upper())
    print("Aligned seq:", aligned_seq)


indicated_aligned_seq, alt_bases = incorporate_replaced_pos_info(aligned_seq, positions_replaced)
print(indicated_aligned_seq)
fixed_aligned_seq = incorporate_insertions_and_deletions(indicated_aligned_seq, cigar_tuples)
print('\n')
print(fixed_aligned_seq)

aaggccccaaaagtggcgcagccctctatgggctcgaattttcttcagcctctccaggtcctcacgcagcttgttgtctagaccgttggTcagaacctggc


aaggNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNccccaaaagtggcgcagccctctatgggctcgaattttcttcagcctctccaggtcctcacgcagcttgttgtctagaccgttggTcagaacNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNctggc


In [42]:

indicated_reference_seq, alt_bases = incorporate_replaced_pos_info(reference_seq, positions_replaced)
print(indicated_reference_seq)
fixed_reference_seq = incorporate_insertions_and_deletions(indicated_reference_seq, cigar_tuples)
print()
print(fixed_reference_seq)

aaggccccaaaagtggcgcagccctctatgggctcgaattttcttcagcctctccaggtcctcacgcagcttgttgtctagaccgttggCcagaacctggc

aaggNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNccccaaaagtggcgcagccctctatgggctcgaattttcttcagcctctccaggtcctcacgcagcttgttgtctagaccgttggCcagaacNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNctggc


In [32]:
fixed_aligned_seq

'aaggNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNccccaaaagtggcgcagccctctatgggctcgaattttcttcagcctctccaggtcctcacgcagcttgttgtctagaccgttggTcagaacNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNctggc'

In [25]:
fixed_aligned_seq

'AAGGNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNCCCCAAAAGTGGCGCAGCCCTCTATGGGCTCGAATTTTCTTCAGCCTCTCCAGGTCCTCACGCAGCTTGTTGTCTAGACCGTTGGTCAGAACNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNCTGGC'

In [27]:
positions_replaced = get_positions_from_md_tag(md_tag, verbose=verbose)
positions_replaced

[89, 101]

In [21]:
positions_replaced = get_positions_from_md_tag(md_tag, verbose=verbose)


indicated_aligned_seq, alt_bases = incorporate_replaced_pos_info(fixed_aligned_seq, positions_replaced)


'aaggnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnNnnnnnnnnnnnNnnnnccccaaaagtggcgcagccctctatgggctcgaattttcttcagcctctccaggtcctcacgcagcttgttgtctagaccgttggtcagaacnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnctggc'

aaggccccaaaagtggcgcagccctctatgggctcgaattttcttcagcctctccaggtcctcacgcagcttgttgtctagaccgttggTcagaacctggc
aaggccccaaaagtggcgcagccctctatgggctcgaattttcttcagcctctccaggtcctcacgcagcttgttgtctagaccgttggccagaacctggc


In [44]:
print("Total time: {} seconds".format(overall_time))
print("Total time: {} minutes".format(overall_time/60))

Total time: 26.106521040201187 seconds
Total time: 0.4351086840033531 minutes


# More helper functions

In [9]:
def sort_bam(bam_file_name):
    output_name = bam_file_name.split("bam")[0] + ".sorted.bam"
    pysam.sort("-o", output_name, bam_file_name)  
    return output_name

def write_reads_to_file(reads, bam_file_name, header_string):
    with pysam.AlignmentFile(bam_file_name, "wb", text=header_string) as bam_handle:
        for read_str in reads:
            read = pysam.AlignedSegment.fromstring(read_str, bam_handle.header)
            bam_handle.write(read) 
    bam_handle.close()
            
def write_reads_to_file_wrapper(parameters):
    reads, bam_file_name, header_string = parameters
    write_reads_to_file(reads, bam_file_name, header_string)
    
    try:
        index_bam(bam_file_name)
    except Exception as e:
        print("Failed at indexing {}".format(bam_file_name))
        
    

# Combine all of the reads (string representation) for each barcode
## Groups the results from each sub-contig segment above, for example the reads from the first half of chr1 and those from the second half.

In [10]:
from copy import deepcopy 

results_copy = deepcopy(results)

num_barcodes_to_time = {}
num_contigs_to_time = {}

overall_label_to_list_of_contents = defaultdict(lambda:{})

for barcode_to_concatted_reads_df, total_reads, barcodes_df, label, counts_df, time_df, total_time in results_copy:
    try:
        barcode_to_concatted_reads_df.columns = ['contents']
        barcode_to_concatted_reads_df['barcode'] = [b.split('/')[-1].split('.bam')[0] for b in barcode_to_concatted_reads_df.index]
        contig = label.split('(')[0]
        barcode_to_concatted_reads_df['barcode_contig'] = barcode_to_concatted_reads_df['barcode'] + '_' + contig
        overall_label_to_list_of_contents[contig][label] = barcode_to_concatted_reads_df
    except Exception as e:
        print(e, label)

Length mismatch: Expected axis has 0 elements, new values have 1 elements MT(000):0-1019
Length mismatch: Expected axis has 0 elements, new values have 1 elements MT(001):1019-2038
Length mismatch: Expected axis has 0 elements, new values have 1 elements Y(004):22936176-28670220
Length mismatch: Expected axis has 0 elements, new values have 1 elements Y(005):28670220-34404264
Length mismatch: Expected axis has 0 elements, new values have 1 elements Y(008):45872352-51606396
Length mismatch: Expected axis has 0 elements, new values have 1 elements Y(009):51606396-57340440
Length mismatch: Expected axis has 0 elements, new values have 1 elements Y(012):68808528-74542572
Length mismatch: Expected axis has 0 elements, new values have 1 elements Y(014):80276616-86010660


In [11]:
print("Overall contigs:\n\n\t",overall_label_to_list_of_contents.keys())
print("\nSubcontig regions for an example contig (17):\n\n\t",sorted(overall_label_to_list_of_contents.get('17').keys()))

Overall contigs:

	 dict_keys(['1', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '2', '3', '4', '5', '6', '7', '8', '9', 'MT', 'X', 'Y'])

Subcontig regions for an example contig (17):

	 ['17(000):0-5936705', '17(001):5936705-11873410', '17(002):11873410-17810115', '17(003):17810115-23746820', '17(004):23746820-29683525', '17(005):29683525-35620230', '17(006):35620230-41556935', '17(007):41556935-47493640', '17(008):47493640-53430345', '17(009):53430345-59367050', '17(010):59367050-65303755', '17(011):65303755-71240460', '17(012):71240460-77177165', '17(013):77177165-83113870', '17(014):83113870-89050575', '17(015):89050575-94987280']


### Generate list of jobs to be multiprocessed

In [12]:
from collections import OrderedDict

start_time = time.perf_counter() 

# Get the bam header, which will be used for each of the split bams too
header_string = str(samfile.header)

# Make a subfolder into which the split bams will be placed
split_bams_folder = '{}/split_bams'.format(output_folder)
if not os.path.exists(split_bams_folder):
    os.mkdir(split_bams_folder)

    
num_contigs = 0
jobs_params = []

for contig, df_dict in overall_label_to_list_of_contents.items():
    num_contigs += 1
    print("Contig: {}".format(contig))
    
    # Make a sub-subfolder to put the bams for this specific contig
    contig_folder = '{}/{}/'.format(split_bams_folder, contig)
    if not os.path.exists(contig_folder):
        os.mkdir(contig_folder)
    
    # Sort the subcontig regions such that the reads are properly ordered 
    sorted_subcontig_names = sorted(df_dict.keys())
    sorted_subcontig_dfs = []
    for n in sorted_subcontig_names:
        sorted_subcontig_dfs.append(df_dict.get(n))
        
    # All of the reads for all of the barcodes are in this dataframe
    all_contents_df = pd.concat(sorted_subcontig_dfs)
    
    # Get all the unique barcodes
    all_barcodes = list(all_contents_df.barcode.unique())
    
    
    for i, barcode in enumerate(all_barcodes):            
        if i % 100 == 0:
            print('{}/{} barcodes'.format(i, len(all_barcodes)))

        # Combine the reads (in string representation) for all rows corresponding to a barcode
        contents_for_barcode = all_contents_df[all_contents_df.barcode == barcode]
                
        all_contents_text = '\n'.join(contents_for_barcode.contents)
        
        # Turn the newline-delimited block of text back into list of reads as strings
        reads = all_contents_text.split('\n')
            
        # Remove duplicates
        reads_deduped = list(OrderedDict.fromkeys(reads))
                
        # Establish the name of the split bam that will be generated
        bam_file_name = '{}/{}_{}.bam'.format(contig_folder, contig, barcode)
        
        # Add parameters to list of jobs
        jobs_params.append([reads_deduped, bam_file_name, header_string])
        
            
total_time = time.perf_counter() - start_time


Contig: 1
0/1500 barcodes
100/1500 barcodes
200/1500 barcodes
300/1500 barcodes
400/1500 barcodes
500/1500 barcodes
600/1500 barcodes
700/1500 barcodes
800/1500 barcodes
900/1500 barcodes
1000/1500 barcodes
1100/1500 barcodes
1200/1500 barcodes
1300/1500 barcodes
1400/1500 barcodes
Contig: 10
0/1500 barcodes
100/1500 barcodes
200/1500 barcodes
300/1500 barcodes
400/1500 barcodes
500/1500 barcodes
600/1500 barcodes
700/1500 barcodes
800/1500 barcodes
900/1500 barcodes
1000/1500 barcodes
1100/1500 barcodes
1200/1500 barcodes
1300/1500 barcodes
1400/1500 barcodes
Contig: 11
0/1500 barcodes
100/1500 barcodes
200/1500 barcodes
300/1500 barcodes
400/1500 barcodes
500/1500 barcodes
600/1500 barcodes
700/1500 barcodes
800/1500 barcodes
900/1500 barcodes
1000/1500 barcodes
1100/1500 barcodes
1200/1500 barcodes
1300/1500 barcodes
1400/1500 barcodes
Contig: 12
0/1500 barcodes
100/1500 barcodes
200/1500 barcodes
300/1500 barcodes
400/1500 barcodes
500/1500 barcodes
600/1500 barcodes
700/1500 barco

In [43]:
print("Total time to prepare list for multiprocess-writing bams: {} minutes".format(round(total_time/60)))

Total time to prepare list for multiprocess-writing bams: 25 minutes


# Generate bams

In [None]:
start_time = time.perf_counter()

with Pool(processes=16) as p:
    max_ = len(jobs_params)
    with tqdm(total=max_) as pbar:
        for _ in p.imap_unordered(write_reads_to_file_wrapper, jobs_params):
            pbar.update()

total_time = time.perf_counter() - start_time


  7%|▋         | 2268/32562 [00:33<06:59, 72.25it/s]

In [None]:
print("Total time to write bams: {} minutes".format(round(total_time/60)))

# Time profiling of the edit-counting step

In [None]:
total_contig_times = {}
all_read_info_dfs = []
all_time_dfs = []

total_times = {}
for result in results:
    
    label = result[3]

    try:
        total_time = result[6]
        total_times[label] = total_time
        
        time_df = result[5]
        all_time_dfs.append(time_df)
        total_time_for_contig = (float(time_df.max()))
        total_contig_times[label] = total_time_for_contig
        
        read_info_df = result[4]
        read_info_df.columns = [label]
        all_read_info_dfs.append(read_info_df)
    except Exception as e:
        print(e, label)

In [None]:
total_contig_times_df = pd.DataFrame.from_dict(total_contig_times, orient='index', columns=['seconds']).sort_values('seconds')

In [None]:
'Total time without threading: {} minutes'.format(round(total_contig_times_df.seconds.sum()/60, 2))

In [None]:
total_reads_df = pd.concat(all_read_info_dfs,axis=1).T

In [None]:
total_reads_and_times_df = total_reads_df.join(total_contig_times_df)

plt.scatter(total_reads_and_times_df.edited, total_reads_and_times_df.seconds, s=1)
plt.scatter(total_reads_and_times_df.no_edits, total_reads_and_times_df.seconds, s=1)

plt.title("Total processing time vs number of reads")
plt.ylabel("Time (seconds)")
plt.xlabel("Reads")
plt.legend(['Reads with edits', 'Read without edits'])

In [None]:
pd.DataFrame.from_dict(total_seconds_for_reads, orient='index').sort_index().plot(legend=False)
plt.xlabel("Reads")
plt.ylabel("Time (seconds)")
plt.title("Runtime vs number of reads processed")

In [None]:
rates = []
for reads, secs in zip(list(total_seconds_for_reads.keys()), list(total_seconds_for_reads.values())):
    rate = reads/secs
    rates.append(rate)

In [None]:
plt.plot(range(len(rates)), rates)
plt.title("Mean rate (reads per second)")
plt.ylabel("Reads/Second")
plt.xlabel("Number of reads (e10^6)")

average_rate = np.mean(rates)
plt.axhline(average_rate, color='r')
print("Average of {} reads/second".format(average_rate))

In [None]:
seconds_per_read = 1/average_rate

In [None]:
import math

reads_per_cell = 50000
total_cells = 10000
total_reads = reads_per_cell * total_cells
print(total_reads)

total_estimated_time = total_reads * seconds_per_read
print('Estimated total time in minutes for {} reads: {} minutes'.format(total_reads, math.ceil(total_estimated_time/60), 3))


# Second loop to get coverage at sites with edits

In [None]:
# Todo: get all edit info files for each contig and group them by contig, before processing.
# Todo: Multiprocess that too.

In [None]:
from glob import glob

splits = [i.split("/")[-1].split('_edit')[0] for i in glob('{}/edit_info/*'.format(output_folder))]
print("Accessing split bams: {}".format(', '.join(sorted(splits))))

### Gather the edit information generated for each subcontig, and group by contig so we only have 1 edit information dataframe to process per contig

In [None]:
all_edit_info_for_barcodes = []

edit_info_grouped_per_contig = defaultdict(lambda:[])
edit_info_grouped_per_contig_combined = defaultdict(lambda:[])

num_splits = len(splits)
print("Grouping edit information outputs by contig...")
for i, split in enumerate(splits):
    if i%10 == 0:
        print("\t{}/{}...".format(i, num_splits))
    contig = split.split("_")[0]
    
    barcode_to_coverage_dict = defaultdict()    
    
    barcode_to_coverage_dict = defaultdict()
    edit_info_file = '{}/edit_info/{}_edit_info.tsv'.format(output_folder, split)
    edit_info = pd.read_csv(edit_info_file, sep='\t')
    edit_info_grouped_per_contig[contig].append(edit_info)
print("Done grouping! Concatenating ...")

for contig, list_of_edit_info_dfs in edit_info_grouped_per_contig.items():
    edit_info_grouped_per_contig_combined[contig] = pd.concat(list_of_edit_info_dfs)

print("Done concatenating!")

### Get coverage at edit positions for each contig

In [None]:
import pandas as pd
pd.options.mode.chained_assignment = None 

def get_edit_info_for_barcode_in_contig(edit_info, contig, barcode, output_folder):
    
    bam_subfolder = "{}/split_bams/{}".format(output_folder, contig)
    barcode_bam = '{}/{}_{}.bam'.format(bam_subfolder, contig, barcode)

    samfile_for_barcode = pysam.AlignmentFile(barcode_bam, "rb")

    edit_info_for_barcode = edit_info[edit_info.barcode == barcode]
    positions_for_barcode = edit_info_for_barcode.position.tolist()

    coverage = []
    for pos in positions_for_barcode:
        coverage_at_pos = np.sum(samfile_for_barcode.count_coverage(contig, pos-1, pos, quality_threshold=0))
        coverage.append(coverage_at_pos)

    edit_info_for_barcode['coverage'] = coverage
    edit_info_for_barcode['contig'] = edit_info_for_barcode.contig.astype(str)

    return edit_info_for_barcode


def get_edit_info_for_barcode_in_contig_wrapper(parameters):
    edit_info, contig, barcode, output_folder = parameters
    edit_info_for_barcode = get_edit_info_for_barcode_in_contig(edit_info, contig, barcode, output_folder)
    return edit_info_for_barcode


def get_coverage_for_edits_in_contig(edit_info_grouped_per_contig_combined, output_folder):
    job_params = []
    
    for contig, edit_info in edit_info_grouped_per_contig_combined.items():
        unique_barcodes = sorted(edit_info.barcode.unique())

        for i, barcode in enumerate(unique_barcodes):                 
            job_params.append([edit_info, contig, barcode, output_folder])  
    return job_params
    
coverage_counting_job_params = get_coverage_for_edits_in_contig(edit_info_grouped_per_contig_combined, output_folder)

In [None]:
getsizeof(coverage_counting_job_params)

In [None]:
start_time = time.perf_counter()

results = []
with Pool(processes=16) as p:
    max_ = len(coverage_counting_job_params)
    with tqdm(total=max_) as pbar:
        for _ in p.imap_unordered(get_edit_info_for_barcode_in_contig_wrapper, coverage_counting_job_params):
            pbar.update()
            results.append(_)
            
total_time = time.perf_counter() - start_time

In [None]:
print(total_time)

In [None]:
all_edit_info = pd.concat(results)

# Group by site to get final total edit and coverage counts at each site

# Verify C>T ratios

In [None]:
all_edit_info.groupby(['ref', 'alt']).count().plot(kind='barh', legend=False)
plt.title("All edits")

base_quality_thresh = 15
all_edit_info[all_edit_info.base_quality > base_quality_thresh].groupby(['ref', 'alt']).count().plot(kind='barh', legend=False)
plt.title("Edits with base quality > {}".format(base_quality_thresh))

all_edit_info_filtered = all_edit_info[all_edit_info.base_quality > base_quality_thresh]


In [None]:
example_new_ct =  all_edit_info_filtered[(all_edit_info_filtered.ref == 'C') & (all_edit_info_filtered.alt == 'T')].sort_values('position')

In [None]:
len(example_new_ct)

In [None]:
example_new_ct

# Cells that do have STAMP expressed versus don't...?

In [None]:
stamp_expression_path = \
'/projects/ps-yeolab3/ekofman/Sammi/MouseBrainEF1A_SingleCell_EPR_combined/\
4.1_cells_with_middling_stamp/stamp_expression_for_all_used_cells.tsv'

In [None]:
stamp_expression_df = pd.read_csv(stamp_expression_path, sep='\t', index_col=0)

In [None]:
stamp_expression_df.Stamp.hist(bins=50)

In [None]:
for thresh in [0, 1, 2, 3, 4, 5, 6, 6.5]:
    print(thresh)
    barcodes_at_stamp_thresh = stamp_expression_df[stamp_expression_df.Stamp > thresh].index
    
    all_edit_info_filtered[
        all_edit_info_filtered.barcode.isin(barcodes_at_stamp_thresh)].groupby(['ref', 'alt']).count().plot(kind='barh', legend=False)
    plt.title("Edit Type Distribution for Cells with STAMP expression above {}".format(thresh))

In [None]:
for thresh in [1.5,2, 3, 4, 5, 6]:
    print(thresh)
    barcodes_at_stamp_thresh = stamp_expression_df[stamp_expression_df.Stamp < thresh].index
    
    all_edit_info_filtered[
        all_edit_info_filtered.barcode.isin(barcodes_at_stamp_thresh)].groupby(['ref', 'alt']).count().plot(kind='barh', legend=False)
    plt.title("Edit Type Distribution for Cells with STAMP expression below {}".format(thresh))