# Identify and cluster "linked" mutated positions

In [1]:
%run "Header.ipynb"
%run "LoadMutationJSONData.ipynb"

In [3]:
import bisect
import pickle
import time
from collections import defaultdict

import pysam
import skbio

#### Define various constants

In [None]:
# Maximum distance between two linked positions (non-inclusive):
# unless (pos j) - (pos i) < this, we do not consider i and j linked.
MAX_DIST_BTWN_LINKED_POSITIONS_NONINCLUSIVE = 3000

# unless at least this many reads have mutations at both pos i and pos j, we do not consider i and j linked.
# (The read-classification cell below also uses this value as a per-position coverage floor: a position is
# only considered for mutation calling if its (# mismatches + # matches) is >= this.)
MIN_COV_OF_MUTATIONS_AT_LINKED_POSITIONS = 1000

# Maximum fraction of "non-linked" mutated reads, relative to the doubly-mutated reads (non-inclusive):
# unless |Reads(i, -)| + |Reads(-, j)| < this fraction * |Reads(i, j)|, we do not consider i and j linked.
MAX_NONLINKED_MUTATED_FRACTION_NONINCLUSIVE = 0.2

# How we call a mutation: only if (# mismatches) / (# mismatches + # matches) > MINFREQ.
# Note that this is a strict inequality: a position whose mismatch frequency is exactly MINFREQ is not called.
# Defaults to 0.5%.
MINFREQ = 0.005

## Classify reads covering close-by positions into four groups

See the section of the paper on linked positions for details.

In [4]:
# Open the alignment file in binary read mode ("rb"). The loop below calls bf.fetch(seq) on this handle once per
# sequence in SEQS.
# NOTE(review): this handle is not closed anywhere in this cell -- confirm whether later cells still need it
# before adding a bf.close() (or wrapping the loop in a "with" block).
bf = pysam.AlignmentFile("../main-workflow/output/fully-filtered-and-sorted-aln.bam", "rb")

# Default factory for the pospair2groupcts defaultdict below.
# This has to be an ordinary named function rather than a lambda, because a defaultdict whose factory is a
# lambda breaks pickle: https://stackoverflow.com/a/16439720
def emptyListOf4():
    """Return a brand-new list containing four zero counts."""
    return [0] * 4

t1 = time.time()
for seq in SEQS:
    fasta = skbio.DNA.read("../seqs/{}.fasta".format(seq))
    # Indexing a skbio.DNA object returns another DNA object (not a str), so we convert the whole reference to
    # a plain string once up front. ref_seq[pos] can then be compared directly against a character of a read,
    # and we avoid building a new DNA object for every mutated position of every read.
    ref_seq = str(fasta)

    # Maps tuple of (left integer pos, right integer pos) to a list of four counts, initially [0, 0, 0, 0].
    # (Since a pair (i, j) is equal to a pair (j, i), we just index this so that the leftmost position is the
    # first element in the tuple and the rightmost position is the second element. This seems like a more intuitive
    # way of structuring this than as a nested dict of leftpos2rightpos2groupcts.)
    # https://stackoverflow.com/a/13065439
    #
    # Each entry in the list indicates counts of types of reads connecting these two positions we've seen thus
    # far. By (0-indexed) index in the list:
    #
    # 0. Reads(i, j): reads that support mutations at both positions
    # 1. Reads(i, -): reads that only support mutations at i
    # 2. Reads(-, j): reads that only support mutations at j
    # 3. Reads(-, -): reads that do not support mutations at either position
    #
    # This matches the definitions in the paper (currently that is section 3.6.2, but that number may change as
    # the paper is edited and restructured).
    pospair2groupcts = defaultdict(emptyListOf4)

    # Identify all mutated positions in this genome up front to save time.
    #
    # NOTE(review): the positions collected here are compared directly against the reference positions returned
    # by pysam's get_aligned_pairs(), which are 0-based. Confirm that the keys of seq2pos2matchct use the same
    # coordinate system.
    print(f"Identifying mutated positions in genome {seq2name[seq]}...")
    mutated_positions = []
    for pos, matchct in seq2pos2matchct[seq].items():
        mismatchct = seq2pos2mismatchct[seq][pos]
        cov = mismatchct + matchct

        # We can be strict and filter out positions that don't pass the coverage filter for linked reads -- no
        # sense including these. (This check also guarantees cov > 0 before we divide by it.)
        #
        # Then actually "call" mutations, the same way we do elsewhere in these analyses (albeit maybe with
        # different values of MINFREQ). Of course, this isn't the only way to do this.
        if cov >= MIN_COV_OF_MUTATIONS_AT_LINKED_POSITIONS and (mismatchct / cov) > MINFREQ:
            mutated_positions.append(int(pos))
    print(f"Found {len(mutated_positions)} mutated positions in {seq2name[seq]}.")

    # The code below (the bisect call, the "pointer" walk, and the early break in the pair loop) relies on
    # mutated_positions being in ascending order, so sort explicitly rather than trusting the dict's key order.
    mutated_positions.sort()
    num_mutated_positions = len(mutated_positions)

    # Go through all aligned segments for this genome...
    # If no position passed the filters above then there are no pairs to count, so we skip the (expensive) pass
    # over the reads entirely and just write out an empty result.
    ts1 = time.time()
    reads = bf.fetch(seq) if mutated_positions else []
    for ri, read in enumerate(reads, 1):
        if ri % 100 == 0:
            print(f"On read {ri} in seq {seq2name[seq]}. Time spent so far: {time.time() - ts1:.2f} sec.")
        ap = read.get_aligned_pairs(matches_only=True)
        if not ap:
            # Nothing in this segment is aligned to the reference, so it can't cover any mutated position.
            continue
        query_seq = read.query_sequence

        # Maps mutated positions seen in this segment to True (this read is mutated at this position, compared
        # to the reference) or False (this read is not mutated at this position, compared to the reference).
        # The absence of a mutated position from this dict implies that this position is not seen in this aligned
        # segment (either due to indels/skips or this read just not being aligned to cover it).
        #
        # After we compute this we can increment pospair2groupcts accordingly for every pair of mutated positions
        # present in this dict.
        mutpos2ismutated = {}

        # The keys of mutpos2ismutated, in ascending order of position (aligned pairs are visited left to right).
        mutated_positions_covered_in_read = []

        # Iterating through the aligned pairs is expensive. Since read lengths are generally in the thousands
        # to tens of thousands of bp (which is much less than the > 1 million bp length of any bacterial genome),
        # we set things up so that we only iterate through the aligned pairs once. We maintain an integer, mpi,
        # that is a poor man's "pointer" to an index in mutated_positions.
        #
        # Start the pointer at the first mutated position at or after this read's first aligned reference
        # position. (A binary search finds this directly, instead of walking past every mutated position located
        # to the left of the read.)
        mpi = bisect.bisect_left(mutated_positions, ap[0][1])

        # Go through this read's aligned pairs. As we see each pair, compare the pair's reference position
        # (refpos) to the mpi-th mutated position (herein referred to as "mutpos").
        #
        # If refpos >  mutpos, increment mpi until refpos <= mutpos (stopping as early as possible).
        # If refpos == mutpos, we have a match! Update mutpos2ismutated[mutpos] based on comparing this read's
        #                      aligned value at this position to the reference at this position.
        # If refpos <  mutpos, continue to the next pair.
        for readpos, refpos in ap:

            # Increment mpi until we get to the next mutated position at or after the reference pos for this
            # aligned pair (or until we run out of mutated positions).
            while mpi < num_mutated_positions and mutated_positions[mpi] < refpos:
                mpi += 1

            # No mutated positions at or to the right of here, so we're done with this read.
            # I expect this should happen only for reads aligned near the right end of the genome.
            if mpi == num_mutated_positions:
                break

            mutpos = mutated_positions[mpi]

            # If the next mutation occurs after this aligned pair, continue on to a later pair.
            if refpos < mutpos:
                continue

            # If we've made it here, refpos == mutpos (the while loop above guarantees mutpos >= refpos).
            # Compare the read to the reference at this position, and record the result.
            mutpos2ismutated[mutpos] = (query_seq[readpos] != ref_seq[refpos])
            mutated_positions_covered_in_read.append(mutpos)

        # Now that we've seen all mutated positions covered by this read, update pair information.
        # The naive way to do this is to iterate over
        # itertools.combinations(mutated_positions_covered_in_read, 2) -- however, that is super expensive
        # for even relatively small numbers of positions (5000 choose 2 is almost 12.5 million!)
        #
        # We can save time and effort by using a few tricks to only consider subsets of these pairs corresponding
        # to close-together mutated positions (within the max distance btwn linked positions).
        #
        # We walk the list by index, rather than slicing out [pindex + 1:] for every position, because each
        # slice would copy the remainder of the list and bring back the quadratic cost we're trying to avoid.
        num_covered = len(mutated_positions_covered_in_read)

        # For each mutated position...
        for pindex, pi in enumerate(mutated_positions_covered_in_read):
            piv = mutpos2ismutated[pi]

            # For each mutated position located to the right of this one:
            # (If pi is the last mutated position, then this range is empty and we will implicitly skip the
            # entire body of the for loop below here.)
            for qindex in range(pindex + 1, num_covered):
                pj = mutated_positions_covered_in_read[qindex]

                # No need to take abs(), since we know pi < pj.
                if pj - pi >= MAX_DIST_BTWN_LINKED_POSITIONS_NONINCLUSIVE:
                    # Since mutated_positions_covered_in_read is monotonically increasing, we can break as
                    # soon as we get to the max distance away from pi.
                    break

                pjv = mutpos2ismutated[pj]
                if piv:
                    # 0: read supports mutations at both i and j; 1: at i but not j
                    group = 0 if pjv else 1
                else:
                    # 2: read supports a mutation at j but not i; 3: at neither i nor j
                    group = 2 if pjv else 3
                pospair2groupcts[(pi, pj)][group] += 1

    print(f"Finished going through reads in {seq2name[seq]}.")

    # Write out pospair2groupcts to a safe location, just because this is probably going to take a while
    # and I don't want to risk losing this work.
    #
    # We use pickle instead of JSON because JSON can't handle tuples as the index of pospair2groupcts:
    # see https://stackoverflow.com/a/16439720.
    #
    # We use the file suffix ".pickle" and "wb" based on the conventions described in
    # https://stackoverflow.com/a/40433504 (...which in turn just reference the python docs).
    with open(f"tmp/{seq}_pospair2groupcts.pickle", "wb") as pickle_file:
        pickle.dump(pospair2groupcts, pickle_file)

print(f"Time taken: {time.time() - t1} sec.")

NameError: name 'pysam' is not defined

## Using the computed groups, define mutated positions as "linked" or not