# Classify reads covering close-by positions into four groups

**Part 1 of the "linked mutations" analyses.**

See the section of the paper on linked positions for details. For each genome, this produces a mapping named
`pospair2groupcts`: a `defaultdict` that maps each pair of positions (a tuple) to a list of four counts, initially `[0, 0, 0, 0]`, where each entry is the number of reads in one of the four groups for this pair.
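A minimal sketch of how one read's calls at two positions map onto the four group indices and into `pospair2groupcts`. It assumes `emptyListOf4` (imported from `linked_mutations_utils` below) simply returns a fresh `[0, 0, 0, 0]`; `empty_list_of_4`, `group_index`, and `record_read` are hypothetical names used only for illustration.

```python
from collections import defaultdict


def empty_list_of_4():
    # Hypothetical stand-in for linked_mutations_utils.emptyListOf4; assumed to
    # return a new list on each call, so no two pairs share a counts list.
    return [0, 0, 0, 0]


def group_index(i_mutated, j_mutated):
    """Return the group index for one read at a pair of positions (i, j).

    0: Reads(i, j)  mutated at both positions
    1: Reads(i, -)  mutated at i only
    2: Reads(-, j)  mutated at j only
    3: Reads(-, -)  mutated at neither position
    """
    if i_mutated:
        return 0 if j_mutated else 1
    return 2 if j_mutated else 3


def record_read(pospair2groupcts, pos_a, pos_b, a_mutated, b_mutated):
    # Key pairs as (leftmost, rightmost) so (i, j) and (j, i) share one entry,
    # swapping the mutation calls along with the positions.
    if pos_a <= pos_b:
        i, j, im, jm = pos_a, pos_b, a_mutated, b_mutated
    else:
        i, j, im, jm = pos_b, pos_a, b_mutated, a_mutated
    pospair2groupcts[(i, j)][group_index(im, jm)] += 1


pospair2groupcts = defaultdict(empty_list_of_4)
record_read(pospair2groupcts, 100, 120, True, True)
record_read(pospair2groupcts, 120, 100, False, True)  # same pair, given in reverse
record_read(pospair2groupcts, 100, 120, False, False)
print(dict(pospair2groupcts))  # {(100, 120): [1, 1, 0, 1]}
```

The second call describes a read mutated at 100 but not at 120, so after the swap it lands in group 1, just as it would if the pair had been given in left-to-right order.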

This section takes a while (roughly 1 hour per genome, as of writing). So that later steps can be rerun without rerunning this one, we write each `pospair2groupcts` object to the `pospair2groupcts/` folder within the `notebooks` directory.
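Pickling a `defaultdict` also pickles its default factory, and `pickle` stores functions by importable name. So a module-level factory such as `emptyListOf4` round-trips (as long as `linked_mutations_utils` is importable when the file is loaded), while a `lambda` factory would fail. A minimal sketch of the write/reload round trip, using a temporary directory and a hypothetical genome ID `example_seq` in place of the real `pospair2groupcts/` folder and `SEQS` entries:

```python
import os
import pickle
import tempfile
from collections import defaultdict


def empty_list_of_4():
    # Hypothetical stand-in for linked_mutations_utils.emptyListOf4.
    return [0, 0, 0, 0]


pospair2groupcts = defaultdict(empty_list_of_4)
pospair2groupcts[(15000, 15001)][0] += 3

with tempfile.TemporaryDirectory() as outdir:
    # "example_seq" is a hypothetical genome ID used only for this sketch.
    path = os.path.join(outdir, "example_seq_pospair2groupcts.pickle")
    with open(path, "wb") as f:
        pickle.dump(pospair2groupcts, f)
    with open(path, "rb") as f:
        reloaded = pickle.load(f)

print(reloaded[(15000, 15001)])  # [3, 0, 0, 0]
print(reloaded[(1, 2)])          # [0, 0, 0, 0] (missing keys still get the default)

# A lambda factory cannot be pickled, because pickle looks functions up by name.
lambda_pickled = True
try:
    pickle.dumps(defaultdict(lambda: [0, 0, 0, 0]))
except (pickle.PicklingError, AttributeError):
    lambda_pickled = False
print("lambda factory pickled:", lambda_pickled)  # lambda factory pickled: False
```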

In [1]:
%run "Header.ipynb"
%run "LoadMutationJSONData.ipynb"

In [2]:
import time
import pickle
import pysam
import skbio
from collections import defaultdict
from itertools import combinations
from linked_mutations_utils import (
    MAX_DIST_BTWN_LINKED_POSITIONS_NONINCLUSIVE, emptyListOf4, find_mutated_positions
)
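The main cell below walks each aligned segment's `(query position, reference position)` pairs alongside the sorted `mutated_positions` list, advancing a pointer `mpi` rather than searching the list for every pair. A minimal, pysam-free sketch of that merge-style walk, assuming the pairs arrive as `get_aligned_pairs(matches_only=True)` returns them (increasing reference position, no `None` entries); `calls_at_mutated_positions` and the toy sequences are hypothetical:

```python
def calls_at_mutated_positions(aligned_pairs, mutated_positions, read_seq, ref_seq):
    """Return {mutpos: True/False} for the mutated positions this alignment covers.

    aligned_pairs: (query_pos, ref_pos) tuples in increasing ref_pos order.
    mutated_positions: sorted list of reference positions.
    """
    calls = {}
    mpi = 0  # "pointer" into mutated_positions; it only ever moves rightward
    for readpos, refpos in aligned_pairs:
        # Skip mutated positions that lie to the left of this aligned pair.
        while mpi < len(mutated_positions) and mutated_positions[mpi] < refpos:
            mpi += 1
        if mpi == len(mutated_positions):
            break  # no mutated positions remain to the right of this pair
        if mutated_positions[mpi] == refpos:
            calls[refpos] = read_seq[readpos] != ref_seq[refpos]
    return calls


ref_seq = "ACGTACGTAC"
read_seq = "CGAACG"  # aligned to ref positions 1-6; differs from ref only at position 3
aligned_pairs = [(q, q + 1) for q in range(len(read_seq))]
mutated_positions = [0, 3, 5, 9]
print(calls_at_mutated_positions(aligned_pairs, mutated_positions, read_seq, ref_seq))
# {3: True, 5: False}
```

Position 0 is skipped (the read starts at 1) and position 9 is never reached, so only the covered positions 3 and 5 get calls. In the original, each alignment's calls are written into `readname2mutpos2ismutated[read.query_name]`, so a later alignment of the same read overwrites an earlier one at a shared position. Because `mpi` restarts at 0 for every read, each read still scans past the mutated positions to its left; starting from `bisect.bisect_left(mutated_positions, read.reference_start)` would avoid that, though with a few hundred mutated positions the gain is probably small.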

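The pair filter in the main cell accepts a pair `(i, j)` with `i < j` if the positions are fewer than `MAX_DIST_BTWN_LINKED_POSITIONS_NONINCLUSIVE` apart, either directly or, for circular sequences only, by wrapping around the end of the genome. A minimal sketch of that test as a standalone function; `close_enough` and the `max_dist=5` used in the examples are hypothetical (the real threshold comes from `linked_mutations_utils`):

```python
def close_enough(i, j, genome_len, is_circular, max_dist):
    """True if positions i < j are fewer than max_dist apart (noninclusive)."""
    if (j - i) < max_dist:
        return True  # case 1: close without looping around the genome
    if is_circular:
        # case 2: distance from j, past the end of the genome, back around to i
        return (genome_len + i - j) < max_dist
    return False


print(close_enough(15000, 15001, 1_000_000, True, 5))  # True (case 1)
print(close_enough(0, 999_999, 1_000_000, True, 5))    # True (case 2)
print(close_enough(0, 999_999, 1_000_000, False, 5))   # False (linear sequence)
print(close_enough(100, 105, 1_000_000, True, 5))      # False (distance 5 is not < 5)
```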
In [3]:
bf = pysam.AlignmentFile("../main-workflow/output/fully-filtered-and-sorted-aln.bam", "rb")

t1 = time.time()
for seq in SEQS:
    fasta = skbio.DNA.read("../seqs/{}.fasta".format(seq))
    # Maps tuple of (left integer pos, right integer pos) to a list of [0, 0, 0, 0].
    # (Since a pair (i, j) is equal to a pair (j, i), we just index this so that the leftmost position is the
    # first element in the tuple and the rightmost position is the second element. This seems like a more intuitive
    # way of structuring this than a nested dict of leftpos2rightpos2groupcts.)
    # https://stackoverflow.com/a/13065439
    #
    # Each entry in the list counts the reads seen thus far that cover both positions, split by type. By list
    # index:
    #
    # 0. Reads(i, j): reads that support mutations at both positions
    # 1. Reads(i, -): reads that support a mutation at i but not at j
    # 2. Reads(-, j): reads that support a mutation at j but not at i
    # 3. Reads(-, -): reads that do not support mutations at either position
    #
    # This matches the definitions in the paper (currently that is section 3.6.2, but that number may change as
    # the paper is edited and restructured).
    
    pospair2groupcts = defaultdict(emptyListOf4)
    
    # Identify all mutated positions in this genome up front to save time.
    print(f"Identifying mutated positions in genome {seq2name[seq]}...")
    mutated_positions = find_mutated_positions(seq, seq2pos2matchct, seq2pos2mismatchct)
    print(f"Found {len(mutated_positions)} mutated positions in {seq2name[seq]}.")
    
    # This should already be implicitly sorted, I think, but the code below relies on mutated_positions being
    # sorted in ascending order. So we may as well be paranoid.
    mutated_positions = sorted(mutated_positions)
    
    # Maps read.query_name to a dict mapping each mutated position to a bool: True (this read is mutated at this
    # position, compared to the reference) or False (this read is not mutated at this position).
    # The absence of a mutated position from the inner dict implies that this position is not seen in this read
    # (either due to indels/skips or this read just not being aligned to cover it).
    # Updated as we go through bf.fetch(), since we want to count supplementary alignments of a given read
    # together.
    readname2mutpos2ismutated = defaultdict(dict)
    
    # Go through all aligned segments for this genome...
    # (NOTE that "read" in the below loop really means "aligned segment", since a read can have multiple
    # aligned segments in the case of supplementary alignments)
    ts1 = time.time()
    for ri, read in enumerate(bf.fetch(seq), 1):
        if ri % 1000 == 0:
            print(
                f"On read {ri} in seq {seq2name[seq]}. "
                f"Time spent on this seq so far: {time.time() - ts1:.2f} sec."
            )
        ap = read.get_aligned_pairs(matches_only=True)
        
        # Iterating through the aligned pairs is expensive. Since read lengths are generally in the thousands
        # to tens of thousands of bp (which is much less than the > 1 million bp length of any bacterial genome),
        # we set things up so that we only iterate through the aligned pairs once. We maintain an integer, mpi,
        # that is a poor man's "pointer" to an index in mutated_positions.
        
        mpi = 0
        
        # Go through this read's aligned pairs. As we see each pair, compare the pair's reference position
        # (refpos) to the mpi-th mutated position (herein referred to as "mutpos").
        #
        # If refpos >  mutpos, increment mpi until refpos <= mutpos (stopping as early as possible).
        # If refpos == mutpos, we have a match! Update readname2mutpos2ismutated[mutpos] based on
        #                      comparing the read to the reference at the aligned positions.
        # If refpos <  mutpos, continue to the next pair.
        
        readname = read.query_name
        for pair in ap:
            
            refpos = pair[1]
            mutpos = mutated_positions[mpi]
            
            no_mutations_to_right_of_here = False
            
            # Increment mpi until we get to the next mutated position at or after the reference pos for this
            # aligned pair (or until we run out of mutated positions).
            while refpos > mutpos:
                mpi += 1
                if mpi < len(mutated_positions):
                    mutpos = mutated_positions[mpi]
                else:
                    no_mutations_to_right_of_here = True
                    break
            
            # This happens once the read extends past the last (rightmost) mutated position, e.g. for reads
            # aligned near the right end of the genome.
            if no_mutations_to_right_of_here:
                break
            
            # If the next mutation occurs after this aligned pair, continue on to a later pair.
            if refpos < mutpos:
                continue
                
            # If we've made it here, refpos == mutpos!
            # (...unless I messed something up in how I designed this code.)
            if refpos != mutpos:
                raise ValueError("This should never happen!")
                
            # Finally, compare the read to the reference at this position, and update readname2mutpos2ismutated.
            readpos = pair[0]
            
            # WE NEED TO CONVERT TO A STRING because indexing into a skbio.DNA object returns another DNA object.
            # May or may not have spent an hour staring at the screen due to that ._.
            refval = str(fasta[refpos])
            readval = read.query_sequence[readpos]

            mutated = (readval != refval)

            # (For debugging -- this is the first "highly mutated" position in G1217, the binary CAMP gene)
            # if mutpos == 1209000:
            #    text += (f"{read.query_name} @ {refpos}. ref = {refval}, read = {readval}, mutated = {mutated}")
            
            # This means that if a given mutated position is covered by multiple alignments of a single read,
            # the _last_ alignment we see will trump previous alignments. It's arbitrary, but we could modify this
            # if desired (e.g. respecting the primary alignment if possible -- but then what to do when there are
            # 2 supplementary alignments that overlap with each other?)
            readname2mutpos2ismutated[readname][mutpos] = mutated
        
    # Now that we've seen all alignments of each read, update the pair counts using each read's combined calls.
    for readname in readname2mutpos2ismutated:
        mutated_positions_covered_in_read = readname2mutpos2ismutated[readname].keys()
        # Now that we've seen all mutated positions covered by this read, update pair information.
        
        for (ii, jj) in combinations(mutated_positions_covered_in_read, 2):
            
            # To make life easier, just sort the pair and save that as i and j.
            # (The itertools docs do guarantee that combinations() of a sorted input yields sorted tuples, but
            # mutated_positions_covered_in_read is in dict insertion order, which is not guaranteed to be sorted
            # when a read has multiple alignments. Sorting each pair here is cheap and avoids relying on that.)
            i, j = sorted([ii, jj])
            
            # See if i and j are close enough to each other. There are two ways this can happen (these aren't
            # necessarily mutually exclusive, but in practice they probably will be, depending on genome size and
            # MAX_DIST_BTWN_LINKED_POSITIONS_NONINCLUSIVE):
            #
            # 1. i and j are close to each other without looping around the genome
            #    (e.g. i = 15,000; j = 15,001)
            #
            # 2. i and j are close to each other when you loop around the genome
            #    (e.g. genome length = 1,000,000; i = 0; j = 999,999)
            #    This case is only allowed when seq2iscircular[seq] is True. (For edges that aren't circular --
            #    e.g. edge 6104 [CAMP], as of writing, which is a linear edge within a circular component --
            #    we don't allow this case to ever be True.)
            
            # Case 1
            close_enough_nolooping = (j - i) < MAX_DIST_BTWN_LINKED_POSITIONS_NONINCLUSIVE
            
            # Case 2
            if seq2iscircular[seq]:
                close_enough_looping = (seq2len[seq] + i - j) < MAX_DIST_BTWN_LINKED_POSITIONS_NONINCLUSIVE
            else:
                close_enough_looping = False
                
            if close_enough_nolooping or close_enough_looping:
                im = readname2mutpos2ismutated[readname][i]
                jm = readname2mutpos2ismutated[readname][j]
                if im:
                    if jm:
                        # Read supports mutations at both i and j
                        pospair2groupcts[(i, j)][0] += 1
                    else:
                        # Read supports a mutation at i but not j
                        pospair2groupcts[(i, j)][1] += 1
                else:
                    if jm:
                        # Read supports a mutation at j but not i
                        pospair2groupcts[(i, j)][2] += 1
                    else:
                        # Read doesn't support mutations at either i or j
                        pospair2groupcts[(i, j)][3] += 1

    print(f"Finished going through reads in {seq2name[seq]}.")
    
    # Write out pospair2groupcts to a safe location, since this takes a while
    # and I don't want to risk losing this work.
    #
    # We use pickle instead of JSON because JSON can't use tuples as the keys of pospair2groupcts:
    # see https://stackoverflow.com/a/16439720.
    # 
    # We use the file suffix ".pickle" and "wb" based on the conventions described in
    # https://stackoverflow.com/a/40433504 (...which in turn just reference the python docs).
    
    with open(f"pospair2groupcts/{seq}_pospair2groupcts.pickle", "wb") as dumpster:
        pickle.dump(pospair2groupcts, dumpster)
        
print(f"Time taken: {time.time() - t1} sec.")

Identifying mutated positions in genome CAMP...
Found 470 mutated positions in CAMP.
On read 1000 in seq CAMP.Time spent on this seq so far: 1.90 sec.
On read 2000 in seq CAMP.Time spent on this seq so far: 3.64 sec.
On read 3000 in seq CAMP.Time spent on this seq so far: 5.60 sec.
On read 4000 in seq CAMP.Time spent on this seq so far: 7.35 sec.
On read 5000 in seq CAMP.Time spent on this seq so far: 9.53 sec.
On read 6000 in seq CAMP.Time spent on this seq so far: 12.82 sec.
On read 7000 in seq CAMP.Time spent on this seq so far: 16.02 sec.
On read 8000 in seq CAMP.Time spent on this seq so far: 19.27 sec.
On read 9000 in seq CAMP.Time spent on this seq so far: 22.52 sec.
On read 10000 in seq CAMP.Time spent on this seq so far: 26.29 sec.
On read 11000 in seq CAMP.Time spent on this seq so far: 29.61 sec.
On read 12000 in seq CAMP.Time spent on this seq so far: 32.93 sec.
On read 13000 in seq CAMP.Time spent on this seq so far: 36.18 sec.
On read 14000 in seq CAMP.Time spent on this 