# Compute "link graph" for phasing

This analysis continues in the next notebook, `Phasing-02-VizGraph.ipynb`.

In [1]:
# Shared setup for this notebook. Presumably this defines SEQS and seq2name, which are used below but not
# defined anywhere in this notebook -- TODO confirm against Header.ipynb.
%run "Header.ipynb"

In [2]:
import time
import pickle
from bisect import bisect_left
from collections import defaultdict
from itertools import combinations

import pysam
import skbio
import networkx as nx

from linked_mutations_utils import find_mutated_positions, gen_ddi, MINSPAN, MINLINK_EXCLUSIVE

In [3]:
# Nucleotides are stored as small integers (0-3) rather than 1-character strings. This probably won't save
# a noticeable amount of memory, but humor me.
# i2n: integer -> nucleotide (just index into the string); n2i: nucleotide -> integer (its inverse).
i2n = "ACGT"
n2i = {nt: idx for idx, nt in enumerate(i2n)}

## 1. For each read, identify all nucleotides aligned to mutated positions spanned by this read

This takes about 1.8 hours for the three selected genomes. (That said, these genomes have very high coverage, so this will probably run faster for less well-covered genomes.)

In [4]:
t1 = time.time()

# Open the BAM as a context manager so the file handle is closed once every genome has been processed.
with pysam.AlignmentFile("../main-workflow/output/fully-filtered-and-sorted-aln.bam", "rb") as bf:
    for seq in SEQS:
        # Identify all mutated positions in this genome up front to save time.
        print(f"Identifying mutated positions in genome {seq2name[seq]}...")
        mutated_positions = find_mutated_positions(seq)
        print(f"Found {len(mutated_positions):,} mutated positions in {seq2name[seq]}.")
        print("Going through these positions...")

        # Both the pointer walk and bisect_left() below require ascending order, so be paranoid and sort.
        # These positions are compared directly against pysam's 0-based reference coordinates, so they are
        # assumed to be 0-indexed as well (TODO confirm against find_mutated_positions()).
        mutated_positions = sorted(mutated_positions)
        num_mutpos = len(mutated_positions)

        # Maps read name -> {mutated position -> aligned nucleotide, encoded as an int in [0, 3] via n2i}.
        # We build this up across all alignments at once so that supplementary alignments of the same read
        # are merged into one entry. Reads covering no mutated positions never get an entry.
        readname2mutpos2nt = defaultdict(dict)

        # Number of bases aligned to mutated positions that n2i can't encode (e.g. "N"); these are skipped.
        num_non_acgt = 0

        # Go through all linear alignments of each read to this genome...
        ts1 = time.time()
        for ai, aln in enumerate(bf.fetch(seq), 1):
            if ai % 1000 == 0:
                print(
                    f"\tOn aln {ai:,} in seq {seq2name[seq]}. "
                    f"Time spent on {seq2name[seq]} so far: {time.time() - ts1:,.2f} sec."
                )

            # mpi is a poor man's "pointer" to an index in mutated_positions. Start it at the first mutated
            # position at or after this alignment's leftmost aligned reference position, rather than walking
            # it up from index 0 for every alignment (which would cost O(#mutated positions) per alignment).
            # No aligned pair can have a reference position before reference_start, so nothing is missed.
            mpi = bisect_left(mutated_positions, aln.reference_start)
            if mpi == num_mutpos:
                # No mutated positions at or to the right of this alignment. This also covers genomes with
                # no mutated positions at all.
                continue

            readname = aln.query_name
            # Read the sequence once per alignment rather than once per mutated position hit.
            query_seq = aln.query_sequence

            # pysam yields aligned pairs as (query pos, reference pos) in alignment order, i.e. ascending
            # reference position, so mpi only ever moves forward and we pass over the pairs exactly once.
            # matches_only=True drops pairs where either side is a gap (indels / skips).
            for readpos, refpos in aln.get_aligned_pairs(matches_only=True):
                # Advance mpi to the first mutated position >= refpos.
                while mpi < num_mutpos and mutated_positions[mpi] < refpos:
                    mpi += 1
                if mpi == num_mutpos:
                    # No mutated positions remain to the right of this aligned pair.
                    break

                mutpos = mutated_positions[mpi]
                if refpos != mutpos:
                    # refpos < mutpos: the next mutated position lies further along this alignment.
                    continue

                # refpos == mutpos: record the allele (nucleotide) this read has here, so that alleles
                # co-occurring on the same read can be linked later.
                nt_index = n2i.get(query_seq[readpos])
                if nt_index is None:
                    num_non_acgt += 1
                    continue
                readname2mutpos2nt[readname][mutpos] = nt_index

        if num_non_acgt > 0:
            print(f"Skipped {num_non_acgt:,} non-ACGT base(s) aligned to mutated positions in {seq2name[seq]}.")

        with open(f"phasing-data/{seq}_readname2mutpos2nt.pickle", "wb") as dumpster:
            pickle.dump(readname2mutpos2nt, dumpster)

print(f"Time taken: {time.time() - t1:,} sec.")

## 2. Compute frequency information for individual nucleotides and pairs of nucleotides at mutated positions

We could use Hansel to store the pairs-of-nucleotides data as a matrix, but I opted to use a custom solution (for now, at least) for a few reasons:

1. We don't need anything fancy -- we just need to store these counts, not use Hansel's probabilistic weighting features
2. I don't have time right now to learn Hansel's API (I've read through the docs and am still a bit confused)
3. I think we could probably use less storage (e.g. we only need to store one "triangle" of the matrix; as far as I can tell, Hansel treats H\[a, b, i, j\] as independent of H\[b, a, j, i\], which isn't necessary for haplotyping IMO).

In [5]:
t1 = time.time()
for seq in SEQS:
    # NOTE: pickle.load() can execute arbitrary code; loading is only safe because these files are written
    # by step 1 of this notebook.
    with open(f"phasing-data/{seq}_readname2mutpos2nt.pickle", "rb") as loadster:
        # NOTE: this won't necessarily include ALL reads aligned to a sequence -- for example, if a read
        # doesn't cover any mutated positions, it will be omitted from the top level of this dict. (This is
        # because these reads won't be useful for linking the mutated positions.)
        readname2mutpos2nt = pickle.load(loadster)
        print(f"{len(readname2mutpos2nt):,} unique reads described in the data for seq {seq2name[seq]}.")

    ts1 = time.time()

    # Now that we've seen all alignments of each read, we can go through readname2mutpos2nt and compute
    # co-occurrence information.

    # Maps mutated position -> nucleotide (int in [0, 3]) seen at this position -> number of reads with that
    # nucleotide there. This corresponds to Reads(i, N) as described in the paper. (gen_ddi comes from
    # linked_mutations_utils; it is assumed to build a defaultdict(int), since counts are incremented from 0.)
    pos2nt2freq = defaultdict(gen_ddi)

    # This defaultdict has two levels:
    # OUTER: Keys are sorted (in ascending order) 0-indexed pairs (tuples) of mutated positions. The
    #        inclusion of a pair of mutated positions in this defaultdict implies that these two mutated
    #        positions were spanned by at least one read. The value of each pair is another defaultdict:
    #
    # INNER: The keys of this inner defaultdict are pairs of integers, each in the range [0, 3].
    #        These represent the 4 nucleotides (0 -> A, 1 -> C, 2 -> G, 3 -> T): the first entry represents
    #        the nucleotide seen at the first position in the pair (aka the position "earlier" in the genome),
    #        and the second entry represents the nucleotide seen at the second position in the pair (aka
    #        the position "later" in the genome). Of course, many bacterial genomes are circular, so "earlier"
    #        and "later" are kinda arbitrary. There are 4^2 = 16 possible nucleotide pairs (ignoring
    #        deletions, degenerate nucleotides, etc.), though in practice only a handful should be present for
    #        a given position pair. The value of each nucleotide pair is the number of spanning reads on
    #        which this pair of nucleotides was observed at this pair of positions.
    #
    # So, as an example, if we only have two mutated positions in a genome (at 0-indexed positions 100 and 500),
    # and we saw:
    #
    # - 30    reads with an A at both positions
    # - 1,000 reads with an A at position 100 and a T at position 500
    # - 5     reads with a T at position 100 and an A at position 500
    # - 100   reads with a T at both positions
    # - 3     reads with a C at position 100 and a T at position 500
    # - 1     read  with a G at position 100 and a T at position 500
    #
    # ... then pospair2ntpair2freq would look like
    # {
    #     (100, 500): {
    #         (0, 0): 30,
    #         (0, 3): 1000,
    #         (3, 0): 5,
    #         (3, 3): 100,
    #         (1, 3): 3,
    #         (2, 3): 1
    #     }
    # }
    pospair2ntpair2freq = defaultdict(gen_ddi)
    for ri, (readname, mutpos2nt) in enumerate(readname2mutpos2nt.items(), 1):
        if ri % 100000 == 0:
            print(
                f"\tOn read {ri:,} in seq {seq2name[seq]}. "
                f"Time spent on {seq2name[seq]} so far: {time.time() - ts1:,.2f} sec."
            )

        # (mutated position, nucleotide) pairs for every mutated position this read covers, sorted by
        # position. Sorting is required: the dict's insertion order follows the order in which this read's
        # alignments were processed (supplementary alignments can contribute positions out of order), but the
        # pair keys below must be (earlier position, later position). Positions are unique dict keys, so
        # sorting the items sorts by position alone.
        covered_alleles = sorted(mutpos2nt.items())

        # Single-position counts. Kept separate from the pairwise loop below, because folding it into the
        # combinations() loop would need extra logic to avoid counting a position more than once.
        for mutpos, nt in covered_alleles:
            pos2nt2freq[mutpos][nt] += 1

        # combinations() emits elements in input order, so i < j for every pair; see
        # https://docs.python.org/3.10/library/itertools.html#itertools.combinations.
        for (i, i_nt), (j, j_nt) in combinations(covered_alleles, 2):
            # Guaranteed by the above, but let's be paranoid just in case.
            if j <= i:
                raise ValueError("Something went horribly wrong with combinations()")

            # Both alleles (ints in [0, 3]) were observed on the same read: count this nucleotide pair at
            # this position pair.
            pospair2ntpair2freq[(i, j)][(i_nt, j_nt)] += 1

    print(f"Finished going through reads in {seq2name[seq]}.")

    # We use the file suffix ".pickle" and binary mode based on the conventions described in
    # https://stackoverflow.com/a/40433504 (...which in turn just reference the python docs).
    with open(f"phasing-data/{seq}_pospair2ntpair2freq.pickle", "wb") as dumpster:
        pickle.dump(pospair2ntpair2freq, dumpster)

    with open(f"phasing-data/{seq}_pos2nt2freq.pickle", "wb") as dumpster:
        pickle.dump(pos2nt2freq, dumpster)

print(f"Time taken: {time.time() - t1:,} sec.")

348,612 unique reads described in the data for seq CAMP.
	On read 100,000 in seq CAMP. Time spent on CAMP so far: 0.40 sec.
	On read 200,000 in seq CAMP. Time spent on CAMP so far: 0.83 sec.
	On read 300,000 in seq CAMP. Time spent on CAMP so far: 3.24 sec.
Finished going through reads in CAMP.
257,428 unique reads described in the data for seq BACT1.
	On read 100,000 in seq BACT1. Time spent on BACT1 so far: 1,198.51 sec.
	On read 200,000 in seq BACT1. Time spent on BACT1 so far: 2,807.78 sec.
Finished going through reads in BACT1.
700,066 unique reads described in the data for seq BACT2.
	On read 100,000 in seq BACT2. Time spent on BACT2 so far: 6.45 sec.
	On read 200,000 in seq BACT2. Time spent on BACT2 so far: 8.70 sec.
	On read 300,000 in seq BACT2. Time spent on BACT2 so far: 10.09 sec.
	On read 400,000 in seq BACT2. Time spent on BACT2 so far: 11.34 sec.
	On read 500,000 in seq BACT2. Time spent on BACT2 so far: 12.47 sec.
	On read 600,000 in seq BACT2. Time spent on BACT2 so f

## 3. Convert position pair + nucleotide pair information to a graph structure

We now know, for every pair of positions spanned by at least one read, the frequencies of nucleotide pairs seen together at these positions.

We can now construct a graph where nodes represent _alleles_ (position + nucleotide seen at this position), and edges connect alleles seen together.

We only connect two allele nodes if (ignoring exact nucleotides) at least _minSpan_ reads cover both positions, and the __link__ between the two allele nodes (defined in the paper) is strictly greater than _minLink_ (hence the constant's name, `MINLINK_EXCLUSIVE`). These parameters' values are given in `linked_mutations_utils.py`.

In [11]:
t1 = time.time()
for seq in SEQS:
    print(f"Generating link graph for seq {seq}...")
    # These pickles are written by step 2 of this notebook (pickle.load is only safe on trusted files).
    with open(f"phasing-data/{seq}_pos2nt2freq.pickle", "rb") as loadster:
        pos2nt2freq = pickle.load(loadster)

    g = nx.Graph()

    # Add nodes to the graph -- one per seen nucleotide at every mutated position. Only nucleotides that were
    # actually counted are present as keys, so each node is an observed allele (pos, nt) with nt in [0, 3].
    # The "freq" attribute is the number of reads with this nucleotide at this position: Reads(i, N).
    for pos, nt2freq in pos2nt2freq.items():
        for nt, freq in nt2freq.items():
            g.add_node((pos, nt), freq=freq)

    # Next step: add edges to the graph based on co-occurrence information
    with open(f"phasing-data/{seq}_pospair2ntpair2freq.pickle", "rb") as loadster:
        pospair2ntpair2freq = pickle.load(loadster)

    for (i, j), ntpair2freq in pospair2ntpair2freq.items():
        # "Spanning reads" only counts reads with an aligned (match) base at both positions -- reads with an
        # indel/skip at either position were never recorded in step 1.
        num_spanning_reads = sum(ntpair2freq.values())

        # Require at least MINSPAN spanning reads (inclusive, unlike the exclusive link threshold below).
        if num_spanning_reads < MINSPAN:
            continue

        for (i_nt, j_nt), pair_freq in ntpair2freq.items():
            # link = (# reads with both alleles) / max(Reads(i, i_nt), Reads(j, j_nt)).
            # Both denominators are >= pair_freq >= 1, so this can't divide by zero.
            link = pair_freq / max(pos2nt2freq[i][i_nt], pos2nt2freq[j][j_nt])
            if link > MINLINK_EXCLUSIVE:
                # Yay, add an edge between these alleles!
                g.add_edge((i, i_nt), (j, j_nt), link=link)

    with open(f"phasing-data/{seq}_linkgraph.pickle", "wb") as dumpster:
        pickle.dump(g, dumpster)

print(f"Time taken: {time.time() - t1:,} sec.")