# Classify reads covering close-by positions into four groups

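**Part 1 of the "linked mutations" analyses.** For pairs of nearby mutated positions $(i, j)$ in a genome, each read covering both positions is classified into one of four groups: it supports a mutation at both $i$ and $j$, at $i$ only, at $j$ only, or at neither.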

In [1]:
%run "Header.ipynb"

In [2]:
import pickle
import time
from bisect import bisect_left
from collections import defaultdict
from itertools import combinations

import hansel
import pysam
import skbio

from linked_mutations_utils import find_mutated_positions

In [7]:
# Nucleotides are stored as small integers (A=0, C=1, G=2, T=3) rather than as characters; i2n decodes
# them back. This probably won't save a noticeable amount of memory, but humor me.
i2n = "ACGT"
n2i = {"A": 0, "C": 1, "G": 2, "T": 3}

# Per-genome results, keyed by seq. mp2nts and readname2mutpos2nt are rebuilt for every genome in the
# loop below, so without these dicts only the last genome's results would survive this cell.
seq2mp2nts = {}
seq2readname2mutpos2nt = {}

t1 = time.time()
# The `with` block guarantees the BAM file handle is closed, even if the loop is interrupted.
with pysam.AlignmentFile("../main-workflow/output/fully-filtered-and-sorted-aln.bam", "rb") as bf:
    for seq in SEQS:
        # Identify all mutated positions in this genome up front to save time.
        print(f"Identifying mutated positions in genome {seq2name[seq]}...")
        # Sort explicitly: the bisect_left() lookups below require ascending order.
        mutated_positions = sorted(find_mutated_positions(seq))
        print(f"Found {len(mutated_positions):,} mutated positions in {seq2name[seq]}.")
        print("Going through these positions...")
        # Set copy of mutated_positions, for constant-time membership checks per aligned pair.
        mutated_position_set = set(mutated_positions)

        # Count the occurrences of each nucleotide (A, C, G, T) at each mutated position, indexed via n2i.
        # This is reads(i, N), as described in the paper. Four counters per position wastes some space on
        # zeroes, but simplicity matters more than memory here.
        mp2nts = {mp: [0, 0, 0, 0] for mp in mutated_positions}

        # Maps read name -> {mutated position -> aligned nucleotide (encoded via n2i)}. Keyed by read name
        # rather than by alignment, so supplementary alignments of the same read are merged together.
        readname2mutpos2nt = defaultdict(dict)

        seq2mp2nts[seq] = mp2nts
        seq2readname2mutpos2nt[seq] = readname2mutpos2nt

        # Data we cannot use is counted and reported after each genome, instead of crashing the loop.
        num_seqless_alns = 0
        num_non_acgt_nts = 0

        # Go through all linear alignments of each read to this genome...
        ts1 = time.time()
        for ai, aln in enumerate(bf.fetch(seq), 1):
            if ai % 1000 == 0:
                print(
                    f"\tOn aln {ai} in seq {seq2name[seq]}. "
                    f"Time spent on {seq2name[seq]} so far: {time.time() - ts1:,.2f} sec."
                )

            # Building the aligned pairs is expensive. So first binary-search for the mutated positions
            # inside this alignment's reference span [reference_start, reference_end), and skip
            # alignments that cover none of them without ever building their aligned pairs.
            if aln.reference_end is None:
                # No aligned reference span (e.g. an unmapped record), hence no matched pairs either.
                continue
            first_mpi = bisect_left(mutated_positions, aln.reference_start)
            end_mpi = bisect_left(mutated_positions, aln.reference_end)
            if first_mpi == end_mpi:
                continue
            last_mutpos_in_span = mutated_positions[end_mpi - 1]

            readseq = aln.query_sequence
            if readseq is None:
                # The record stores no sequence (SEQ is "*"), so there are no nucleotides to read.
                num_seqless_alns += 1
                continue
            readname = aln.query_name

            # Each pair is (query position, reference position), in increasing reference order. With
            # matches_only=True, neither side is None.
            for readpos, refpos in aln.get_aligned_pairs(matches_only=True):
                if refpos > last_mutpos_in_span:
                    # No later pair in this alignment can land on a mutated position.
                    break
                if refpos not in mutated_position_set:
                    continue

                readval = n2i.get(readseq[readpos])
                if readval is None:
                    # e.g. an ambiguous "N" base. It has no slot among A/C/G/T, so it is not counted.
                    num_non_acgt_nts += 1
                    continue

                # Record this read's "allele" at this position, so that alleles co-occurring on the same
                # read can be linked later.
                readname2mutpos2nt[readname][refpos] = readval
                mp2nts[refpos][readval] += 1

        if num_seqless_alns or num_non_acgt_nts:
            print(
                f"Skipped {num_seqless_alns:,} alignment(s) without a stored sequence and "
                f"{num_non_acgt_nts:,} non-ACGT nucleotide(s) at mutated positions in {seq2name[seq]}."
            )
        
# NOTE(review): The commented-out pair-classification code below is stale and cannot be re-enabled as-is.
# It reads readname2mutpos2ismutated and writes pospair2groupcts, and neither is defined in this cell:
# the per-genome loop records each read's aligned nucleotide (not a "mutated?" flag) in readname2mutpos2nt.
# It is also written to run once per genome, inside that loop (it uses `seq`).
# TODO: port it to readname2mutpos2nt (deciding which nucleotides count as "mutated"), and create
# pospair2groupcts (e.g. defaultdict(lambda: [0, 0, 0, 0])) before uncommenting.
#     # Now we've seen all alignments of each read.
#     for readname in readname2mutpos2nt:
#         mutated_positions_covered_in_read = readname2mutpos2nt[readname].keys()
#         # Now that we've seen all mutated positions covered by this read, update pair information.

#         for (ii, jj) in combinations(mutated_positions_covered_in_read, 2):

#             # To make life easier, just sort the pair and save that as i and j.
#             # I *think* we could sort mutated_positions_covered_in_read and then combinations() should
#             # automatically generate sorted combinations, but I'm not sure if that is guaranteed -- so to
#             # reduce the probability of weird bugs we can just sort things here.
#             i, j = sorted([ii, jj])

#             # See if i and j are close enough to each other. There are two ways this can happen (these aren't
#             # necessarily mutually exclusive but in practice probs will be, depending on genome size and
#             # MAX_DIST_BTWN_LINKED_POSITIONS_NONINCLUSIVE)
#             #
#             # 1. i and j are close to each other without looping around the genome
#             #    (e.g. i = 15,000; j = 15,001)
#             #
#             # 2. i and j are close to each other when you loop around the genome
#             #    (e.g. genome length = 1,000,000; i = 0; j = 999,999)
#             #    This case is only allowed when seq2iscircular[seq] is True. (For edges that aren't circular --
#             #    e.g. edge 6104 [CAMP], as of writing, which is a linear edge within a circular component --
#             #    we don't allow this case to ever be True.)

#             # Case 1
#             close_enough_nolooping = (j - i) < MAX_DIST_BTWN_LINKED_POSITIONS_NONINCLUSIVE

#             # Case 2
#             if seq2iscircular[seq]:
#                 close_enough_looping = (seq2len[seq] + i - j) < MAX_DIST_BTWN_LINKED_POSITIONS_NONINCLUSIVE
#             else:
#                 close_enough_looping = False

#             if close_enough_nolooping or close_enough_looping:
#                 im = readname2mutpos2ismutated[readname][i]
#                 jm = readname2mutpos2ismutated[readname][j]
#                 if im:
#                     if jm:
#                         # Read supports mutations at both i and j
#                         pospair2groupcts[(i, j)][0] += 1
#                     else:
#                         # Read supports a mutation at i but not j
#                         pospair2groupcts[(i, j)][1] += 1
#                 else:
#                     if jm:
#                         # Read supports a mutation at j but not i
#                         pospair2groupcts[(i, j)][2] += 1
#                     else:
#                         # Read doesn't support mutations at either i or j
#                         pospair2groupcts[(i, j)][3] += 1

#     print(f"Finished going through reads in {seq2name[seq]}.")

#     # Write out pospair2groupcts to a safe location, just because this is probably going to take a while
#     # and I don't want to risk losing this work.
#     #
#     # We use pickle instead of JSON because JSON can't handle tuples as the keys of pospair2groupcts:
#     # see https://stackoverflow.com/a/16439720.
#     #
#     # We use the file suffix ".pickle" and "wb" based on the conventions described in
#     # https://stackoverflow.com/a/40433504 (...which in turn just reference the python docs).

#     with open(f"pospair2groupcts/{seq}_pospair2groupcts.pickle", "wb") as dumpster:
#         dumpster.write(pickle.dumps(pospair2groupcts))

# Total wall-clock time for the whole cell, formatted like the per-genome progress messages.
print(f"Time taken: {time.time() - t1:,.2f} sec.")

Identifying mutated positions in genome CAMP...
Found 283 mutated positions in CAMP.
Going through these positions...
	On aln 1000 in seq CAMP. Time spent on CAMP so far: 1.91 sec.
	On aln 2000 in seq CAMP. Time spent on CAMP so far: 3.55 sec.
	On aln 3000 in seq CAMP. Time spent on CAMP so far: 5.41 sec.
	On aln 4000 in seq CAMP. Time spent on CAMP so far: 7.08 sec.
	On aln 5000 in seq CAMP. Time spent on CAMP so far: 9.22 sec.
	On aln 6000 in seq CAMP. Time spent on CAMP so far: 12.31 sec.
	On aln 7000 in seq CAMP. Time spent on CAMP so far: 15.35 sec.
	On aln 8000 in seq CAMP. Time spent on CAMP so far: 18.45 sec.
	On aln 9000 in seq CAMP. Time spent on CAMP so far: 21.61 sec.
	On aln 10000 in seq CAMP. Time spent on CAMP so far: 24.67 sec.
	On aln 11000 in seq CAMP. Time spent on CAMP so far: 27.79 sec.


KeyboardInterrupt: 