# Generate the data for "mutation matrices"

Both for codons (e.g. ATG, TGC) and amino acids (e.g. M, C).

**NOTE 1:** Because this boils down to calling [`AlignedSegment.get_aligned_pairs()`](https://pysam.readthedocs.io/en/latest/api.html#pysam.AlignedSegment.get_aligned_pairs) once for every read aligned to the selected genomes, **this notebook is currently pretty slow**! I've optimized things to the point where (assuming there are roughly 1,470,000 reads aligned to each genome, and that the runtime for other genomes is similar to that of the CAMP genome) this notebook should take around 8-12 hours to run on our cluster.

If necessary / desired, it should be possible to speed this up even more, using stuff like parallelization / writing this in a faster language like C / etc. There may also be methods in pysam I've overlooked that will help do this faster.
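As one possible speedup along those lines (a sketch, not something this notebook currently does): the CIGAR string already groups aligned bases into gap-free runs, so the codons a read fully covers can be found with arithmetic on each run instead of a per-base loop over `get_aligned_pairs()`. This assumes the inputs come from pysam's `AlignedSegment.reference_start` and `AlignedSegment.cigartuples` (a list of `(operation, length)` pairs using the SAM spec's integer codes); `aligned_runs` and `fully_covered_codons` are hypothetical names, and the code hasn't been checked against real BAM data. Like `matches_only=True`, it treats mismatches inside `M`/`X` operations as aligned; a padding (`P`) operation splits a run here, which may be slightly more conservative than the per-base check.

```python
# CIGAR operation codes from the SAM spec (pysam's cigartuples uses the same integers)
ALIGNED_OPS = {0, 7, 8}   # M, =, X: consume both query and reference
QUERY_ONLY_OPS = {1, 4}   # I, S: consume query only
REF_ONLY_OPS = {2, 3}     # D, N: consume reference only (H and P consume neither)


def aligned_runs(reference_start, cigartuples):
    """Yield (query_start, ref_start, length) for each gap-free aligned run (0-indexed)."""
    qpos, rpos, run = 0, reference_start, None
    for op, length in cigartuples:
        if op in ALIGNED_OPS:
            if run is None:
                run = [qpos, rpos, 0]
            run[2] += length  # adjacent M/=/X operations extend the same run
            qpos += length
            rpos += length
            continue
        if run is not None:
            yield tuple(run)
            run = None
        if op in QUERY_ONLY_OPS:
            qpos += length
        elif op in REF_ONLY_OPS:
            rpos += length
    if run is not None:
        yield tuple(run)


def fully_covered_codons(reference_start, cigartuples, gene_left, gene_right):
    """Yield (codon_left, query_pos) for in-frame codons of a gene that the read fully covers.

    gene_left, gene_right and codon_left are 1-indexed (like the .sco file); query_pos is 0-indexed.
    """
    for q0, r0, length in aligned_runs(reference_start, cigartuples):
        first = max(r0 + 1, gene_left)               # first 1-indexed base inside both run and gene
        first += (gene_left - first) % 3             # advance to the gene's reading frame
        last = min(r0 + length - 2, gene_right - 2)  # last codon left end that fits in both
        for codon_left in range(first, last + 1, 3):
            yield codon_left, q0 + (codon_left - 1 - r0)


# 10 matched bases starting at 1-indexed position 100, gene spanning [100, 120]
print(list(fully_covered_codons(99, [(0, 10)], 100, 120)))  # [(100, 0), (103, 3), (106, 6)]
```

Each yielded `query_pos` could then be used to slice the codon out of `read.query_sequence`, just as the per-base loop does.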

**NOTE 2:** This doesn't actually generate the figures for the matrices -- it just outputs JSON files to a folder named `matrix-jsons/`, and another notebook will generate those. This is intended to make it easier to regenerate the figures using different styles / etc. without having to wait hours for this stuff to finish.
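Since another notebook reads these JSON files back in, one thing worth knowing: JSON object keys are always strings, so the integer gene indices and codon positions in `seq2gene2codon2alignedcodons` come back as `"1"`, `"265"`, etc. after a round trip. A minimal sketch of converting them back, assuming the nested layout described further down in this notebook; `restore_int_keys` is a hypothetical helper and the data is made up.

```python
import json
from collections import defaultdict


def restore_int_keys(seq2gene2codon2alignedcodons):
    """Convert gene-index and codon-position keys back to ints after json.load()."""
    return {
        seq: {
            int(gene): {int(pos): counts for pos, counts in codons.items()}
            for gene, codons in genes.items()
        }
        for seq, genes in seq2gene2codon2alignedcodons.items()
    }


# Made-up data with the same nesting as seq2gene2codon2alignedcodons
original = {"edge_6104": {1: {265: defaultdict(int, {"TTA": 1000, "TTT": 1})}}}
loaded = json.loads(json.dumps(original))
print(list(loaded["edge_6104"]))  # ['1']: the int key became a string
restored = restore_int_keys(loaded)
print(restored["edge_6104"][1][265]["TTA"])  # 1000
```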

In [1]:
%run "Header.ipynb"

In [None]:
import copy
import time
import json
import pysam
import skbio
from collections import defaultdict, Counter
from statistics import mean
from parse_sco import parse_sco

## Initialize data structures that we'll store frequency data in

In [None]:
# 64x63 dict: each key is a triplet of {A, C, G, T}, and each value is another dict mapping each of
# the other 63 codons to a count
codon2codon2freq = {}

# 21x20 dict: each key is a proteinogenic amino acid (A, C, D, E, F, ...), limited to just
# those in the standard genetic code (i.e. ignoring selenocysteine and pyrrolysine) but including
# "*", representing a stop codon.
aa2aa2freq = {}

# 64-key dict: maps each triplet to an integer indicating how frequently this triplet occurs in all genes
# in the genomes (i.e. not counting mutations into this triplet).
codon2freq = {}

# 21-key dict: maps amino acid/stop codon to integer indicating frequency across all genes.
aa2freq = {}

# There's probably a fancier way of generating this list, but this is fine.
codons = []
# Also, we figure out the reverse complements of each of the 64 3-mers in advance -- this avoids
# us having to call str(skbio.DNA(c).reverse_complement()) every time we see a codon, and saves a tiny
# amount of time per read (the skbio approach took ~9e-5 seconds every time; the new approach takes ~9e-7
# seconds every time). Considering we're going through well over a million reads, the time savings comes out
# to ... 130.977 seconds, aka 2 minutes 10 seconds, if I'm computing this correctly. So, not much, but it's
# something!
codon2revcomp = {}
nts = "ACGT"
for i in nts:
    for j in nts:
        for k in nts:
            c = "{}{}{}".format(i, j, k)
            codons.append(c)
            codon2revcomp[c] = str(skbio.DNA(c).reverse_complement())

aas = set([])
for c in codons:
    aas.add(str(skbio.DNA(c).translate()))
    
# Initialize dicts to 0s
for c1 in codons:
    codon2codon2freq[c1] = {c2: 0 for c2 in set(codons) - set([c1])}
    codon2freq[c1] = 0
    
for aa1 in aas:
    aa2aa2freq[aa1] = {aa2: 0 for aa2 in set(aas) - set([aa1])}
    aa2freq[aa1] = 0
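The comment in the cell above guesses there's a fancier way to build the codon list. One option, sketched here, is `itertools.product` for the 64 triplets and `str.translate` for reverse complements. This swaps scikit-bio's translation for a hard-coded copy of NCBI translation table 1 (the standard code) and assumes codons contain only uppercase `A`, `C`, `G`, `T`; `codon2aa` and `COMPLEMENT` are names introduced only for illustration.

```python
from itertools import product

# NCBI translation table 1 (standard code), codons ordered as product("TCAG", repeat=3)
STANDARD_AAS = (
    "FFLLSSSSYY**CC*W"
    "LLLLPPPPHHQQRRRR"
    "IIIMTTTTNNKKSSRR"
    "VVVVAAAADDEEGGGG"
)
assert len(STANDARD_AAS) == 64
codon2aa = {"".join(c): aa for c, aa in zip(product("TCAG", repeat=3), STANDARD_AAS)}

# Same ACGT ordering as the nested loops above (AAA, AAC, ..., TTT)
codons = ["".join(c) for c in product("ACGT", repeat=3)]

# Reverse complement: complement each base, then reverse the string
COMPLEMENT = str.maketrans("ACGT", "TGCA")
codon2revcomp = {c: c.translate(COMPLEMENT)[::-1] for c in codons}

aas = set(codon2aa.values())  # 20 amino acids plus "*" for stop codons
print(len(codons), len(aas), codon2revcomp["ATG"], codon2aa["ATG"])  # 64 21 CAT M
```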

## Go through all reads aligned to each genome and figure out which genes they intersect and which codons in these genes they fully cover

Define a dict which we'll use to keep track of aligned codon frequencies for each codon, for each gene, for each genome.

- For each read, see which predicted genes (if any) this read intersects within the genome. Note that "intersects" doesn't mean "fully covers".

- For each of these genes, see which codons (if any) this read fully covers within the gene.

- Increment aligned codon frequencies for all codons accordingly.

The reason we do things this way, as opposed to iterating over just the reads overlapping each codon in each gene, is that doing things that way is really slow! I'm pretty sure it's because "find out which reads overlap this region" is a pretty slow operation when working with large datasets -- and also since these are long reads, doing this on the level of each codon means we're effectively doing a lot of redundant work (you can imagine that, for a given codon, the odds are pretty good that most reads overlapping it will also overlap adjacent codon(s)).
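The three steps above, condensed into a pure-Python sketch on toy data so the logic can be checked without a BAM file. It assumes genes are `(gene_id, left, right, strand)` tuples with 1-indexed inclusive coordinates like the .sco file, and that `aligned_pairs` looks like the output of `get_aligned_pairs(matches_only=True)`; `count_read` is a hypothetical name. Unlike the notebook's loop, it checks that the three pairs are consecutive before looping over genes, which is simpler but may be a bit slower.

```python
from collections import defaultdict

COMPLEMENT = str.maketrans("ACGT", "TGCA")


def revcomp(seq):
    return seq.translate(COMPLEMENT)[::-1]


def count_read(genes, ref_start, ref_end, aligned_pairs, read_seq, counts):
    """Tally the in-frame codons that one aligned read fully covers.

    genes: (gene_id, left, right, strand) tuples, 1-indexed and inclusive
    ref_start, ref_end: like pysam's reference_start / reference_end (0-indexed, end exclusive)
    aligned_pairs: (query_pos, ref_pos) tuples, 0-indexed
    counts: gene_id -> codon left end (1-indexed) -> aligned codon -> count
    """
    # Step 1: which genes does this read intersect? (1-indexed, inclusive bounds)
    seg_left, seg_right = ref_start + 1, ref_end
    overlapping = [g for g in genes if g[2] >= seg_left and g[1] <= seg_right]
    if not overlapping:
        return
    for i in range(len(aligned_pairs) - 2):
        (q1, r1), (q2, r2), (q3, r3) = aligned_pairs[i:i + 3]
        # Step 2: three consecutive read bases must align to three consecutive reference bases
        if not (q2 == q1 + 1 and q3 == q2 + 1 and r2 == r1 + 1 and r3 == r2 + 1):
            continue
        left = r1 + 1  # 1-indexed left end of the candidate codon
        codon = read_seq[q1:q1 + 3]
        # Step 3: count it for every overlapping gene in which it is an in-frame codon
        for gene_id, gl, gr, strand in overlapping:
            if gl <= left <= gr - 2 and (left - gl) % 3 == 0:
                counts[gene_id][left][revcomp(codon) if strand == "-" else codon] += 1


counts = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))
genes = [("g1", 1, 9, "+"), ("g2", 20, 28, "-")]
# A read matching bases 1-9 exactly, then a read with a deletion at 0-indexed reference position 3
count_read(genes, 0, 9, [(i, i) for i in range(9)], "ATGAAATAG", counts)
count_read(genes, 0, 7, [(0, 0), (1, 1), (2, 2), (3, 4), (4, 5), (5, 6)], "ATGAAA", counts)
print(counts["g1"][1]["ATG"], counts["g1"][4]["AAA"], counts["g1"][7]["TAG"])  # 2 1 1
```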

In [1]:
# Maps sequence IDs to genes (keyed by their Index in the .sco file) to codons (keyed by the 1-indexed
# left end, i.e. the lower of the two positional boundaries of the codon, regardless of whether its gene
# is on the + or - strand; this matches the .sco coordinates) to observed aligned codon frequencies
# (keyed by just the triplet, e.g. "AAA").
#
# Example:
# {"edge_6104":                                Sequence
#     {1:                                      Gene index in the .sco file
#         {265:                                Left codon position
#             {"TTA": 1000, "TTT": 1, ... }    Aligned codon frequencies for this particular codon
#         }
#     }
# }
seq2gene2codon2alignedcodons = {}
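Because codons are keyed by their left end regardless of strand, the codon stored at a minus-strand gene's leftmost key is read by reverse-complementing the forward-strand bases, so it is actually that gene's last codon. A small sketch with a made-up 9 bp genome; `reference_codons` is hypothetical, but the slicing mirrors the `fasta[cpleft - 1: cpleft + 2]` conversion used when calling mutations below.

```python
COMPLEMENT = str.maketrans("ACGT", "TGCA")


def reference_codons(genome, left, right, strand):
    """Map each codon's 1-indexed left end to the codon, read in the gene's own direction."""
    codons = {}
    for cpleft in range(left, right + 1, 3):
        codon = genome[cpleft - 1:cpleft + 2]  # 1-indexed [cpleft, cpleft + 2] as a 0-indexed slice
        if strand == "-":
            codon = codon.translate(COMPLEMENT)[::-1]  # reverse complement
        codons[cpleft] = codon
    return codons


print(reference_codons("ATGAAATAG", 1, 9, "+"))  # {1: 'ATG', 4: 'AAA', 7: 'TAG'}
# The same gene on the opposite strand: its start codon now sits under the rightmost key (7)
print(reference_codons("CTATTTCAT", 1, 9, "-"))  # {1: 'TAG', 4: 'AAA', 7: 'ATG'}
```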

In [4]:
bf = pysam.AlignmentFile("../main-workflow/output/fully-filtered-and-sorted-aln.bam", "rb")

tT1 = time.time()
for seq in SEQS:
    df = parse_sco("../seqs/genes/{}.sco".format(seq))
    
    # We don't actually store any results in this, but we do use it for a slight optimization
    gene2isrev = {}
    
    # Initialize some of the data structures
    # NOTE: this is kind of slow. However, it still finishes within a few seconds, so not the most
    # important thing to optimize
    seq2gene2codon2alignedcodons[seq] = {}
    for gene_data in df.itertuples():
        
        # Should never happen, but check this so that we can compute overlap easily and with peace of mind later
        if gene_data.LeftEnd >= gene_data.RightEnd:
            raise ValueError("Gene {}'s coordinates seem messed up: left = {}, right = {}".format(
                gene_data.Index, gene_data.LeftEnd, gene_data.RightEnd
            ))
        
        seq2gene2codon2alignedcodons[seq][gene_data.Index] = {}
        gene2isrev[gene_data.Index] = (gene_data.Strand == "-")
        
        codon_positions = [
            i for i in range(gene_data.LeftEnd, gene_data.RightEnd + 1, 3)
        ]

        # For each codon in this gene...
        for cpleft in codon_positions:
            seq2gene2codon2alignedcodons[seq][gene_data.Index][cpleft] = defaultdict(int)
            
    print("Finished initialization for seq = {}".format(seq))
    readtimes = []
    
    # Note that this isn't really a "read" so much as it is an aligned linear segment (and a read
    # can have multiple such segments derived from it, as discussed above).
    for ri, read in enumerate(bf.fetch(seq), 1):
        
        t1 = time.time()
        
        # Find all genes that this read intersects in this genome
        
        # These are 0-indexed coordinates (and segright is offset to the right by one; see
        # https://pysam.readthedocs.io/en/latest/api.html#pysam.AlignedSegment.reference_end)
        segleft = read.reference_start
        segright = read.reference_end
        
        if segleft is None or segright is None:
            raise ValueError("Read {} is unmapped? This shouldn't happen!".format(read.query_name))
        
        if segleft >= segright:
            raise ValueError("Read {}'s coordinates in pysam seem messed up: left = {}, right = {}".format(
                read.query_name, segleft, segright
            ))
        
        # Convert the aligned segment boundaries to 1-indexed coordinates, to make comparing them with
        # gene coordinates from the .sco file easier. The gene coordinates are inclusive: a gene from
        # [266, 712] starts at base 266 and ends at base 712, using 1-indexing. pysam's reference_start
        # is a 0-indexed inclusive start, so we add 1 to it. reference_end is a 0-indexed *exclusive*
        # end, which is the same number as the 1-indexed inclusive end, so segright needs no change.
        segleft += 1
        
        # Use vectorization to find genes overlapping this read: see https://stackoverflow.com/a/17071908
        # for details on why parentheses, etc., and
        # https://engineering.upside.com/a-beginners-guide-to-optimizing-pandas-code-for-speed-c09ef2c6a4d6
        # for justification on why this is useful (tldr: makes code go fast)
        genes_overlapping_read = list(
            df.loc[(df["RightEnd"] >= segleft) & (df["LeftEnd"] <= segright)].itertuples()
        )
        # Note about the above thing that just happened: you may be shaking your fist and saying "wait
        # itertuples is slow!" And yeah, kinda. But for whatever reason I've tried multiple times to keep
        # genes_overlapping_read as a DataFrame (and then later vectorize stuff like checking that a given
        # aligned pair covers a codon within the genes, etc) and the overhead costs seem to slow things down.
        # I am sure it's possible to speed things up more, but right now things seem good enough.

        # (Debugging code)
        # print("{} genes overlap read {}".format(len(genes_overlapping_read), ri))
        # print("Read {}, which ranges from {} to {}, overlaps these genes:".format(ri, segleft, segright))
        # print(genes_overlapping_read)
                
        # If no genes overlap this read, we are free to move on to the next read.
        if len(genes_overlapping_read) > 0:
            
            # Computing this is relatively slow, which is why we jump through so many hoops before we do this.
            # Each entry in get_aligned_pairs() is a tuple with 2 elements:
            # the first is the read pos and the second is the reference pos.
            # TODO: would it be possible to only do this for certain positions we care about? get_aligned_pairs()
            # returns a lot of stuff we don't need, e.g. regions of the read that don't intersect with any genes.
            ap = read.get_aligned_pairs(matches_only=True)
            
            # Doesn't look like getting this in advance saves much time, but I don't think it hurts.
            read_seq = read.query_sequence
            
            # We only consider the leftmost position of each codon, so we don't need to bother checking the last
            # two pairs of positions (since neither could be the leftmost position of a codon that this read
            # fully covers).
            for api, pair1 in enumerate(ap[:-2]):

                # Convert to 1-indexed position for ease of comparison with gene coordinates
                pair1_refpos = pair1[1] + 1
                  
                havent_checked_next_pairs = True
                for gene_data in genes_overlapping_read:
                    gl = gene_data.LeftEnd
                    
                    # Check that this pair is located within this gene and is the leftmost position of a
                    # codon in the gene. (Note that this check works for both + and - strand genes.
                    # Whether the leftmost position is the "start" [i.e. CP 1] or "end" [i.e. CP 3] of
                    # the codon changes with the strand of the gene, but we account for that later on
                    # when we reverse-complement the codon if needed.)
                    if pair1_refpos >= gl and pair1_refpos <= gene_data.RightEnd - 2 and ((pair1_refpos - gl) % 3 == 0):
                        
                        # Nice! Looks like this read fully covers this codon.
                        
                        # If we haven't yet, check that this read doesn't skip over parts of the codon,
                        # or stuff like that. The reason this check is located *here* (and not
                        # before we loop over the genes) is that it seems like this is a faster strategy:
                        # only run these checks once we KNOW that this pair looks like it fully covers a
                        # codon, since many pairs might not meet that criterion.
                        #
                        # (And by recording that we've run this check once, in havent_checked_next_pairs,
                        # we can save the time cost of running the check multiple times.)
                        #
                        # I feel like an insane person trying to optimize this so much lmao.
                        if havent_checked_next_pairs:
                            # Check that the pairs are all consecutive (i.e. no "jumps" in the read,
                            # and no "jumps" in the reference)
                            # Since we don't consider the last two pairs in ap, pair2 and pair3 should
                            # always be available.
                            pair2 = ap[api + 1]
                            pair3 = ap[api + 2]

                            # For an aligned pair, [0] is the read pos and [1] is the reference pos.
                            # Result of caching this is probably negligible but... may as well
                            p10 = pair1[0]
                            p20 = pair2[0]
                            readpos_consecutive = p20 == (p10 + 1) and pair3[0] == (p20 + 1)
                            if not readpos_consecutive:
                                break

                            # (pair1_refpos is already off by 1 so no need to redo the addition operation)
                            p21 = pair2[1]
                            refpos_consecutive = p21 == pair1_refpos and pair3[1] == (p21 + 1)
                            if not refpos_consecutive:
                                break
                            havent_checked_next_pairs = False
                        
                        # Figure out what the read actually *says* in the alignment here.
                        # (It'll probably be a complete match most of the time, but there will
                        # be some occasional mismatches -- and seeing those is ... the whole point
                        # of this notebook.)

                        # We make sure to index the read by read coords, not reference coords!
                        aligned_codon = read_seq[p10: p10 + 3]

                        # Finally, update information about codon frequencies.
                        gi = gene_data.Index
                        if gene2isrev[gi]:
                            seq2gene2codon2alignedcodons[seq][gi][pair1_refpos][codon2revcomp[aligned_codon]] += 1
                        else:
                            seq2gene2codon2alignedcodons[seq][gi][pair1_refpos][aligned_codon] += 1

        t2 = time.time()
        readtimes.append(t2 - t1)
        if ri % 100 == 0:
            print("Seen {} reads so far in {}.".format(ri, seq))
            print("Total time taken thus far:          {} sec.".format(t2 - tT1))
            print("Average time per read for this seq: {} sec.".format(mean(readtimes)))

    # At this point, we've seen all the reads aligned to all the codons in this genome.
    # We can now "call" mutations based on the frequencies we've counted.
tT2 = time.time()

print("Figuring all that out took a total of {} seconds.".format(tT2 - tT1))

bf.close()
with open("matrix-jsons/seq2gene2codon2alignedcodons.json", "w") as dumpster:
    dumpster.write(json.dumps(seq2gene2codon2alignedcodons))


## Using the information we just computed for each genome, "call" mutations and store this information in the frequency data structures we set up earlier

In [5]:
for seq in SEQS:
    fasta = skbio.DNA.read("../seqs/{}.fasta".format(seq))
    df = parse_sco("../seqs/genes/{}.sco".format(seq))
    for gene_data in df.itertuples():
        for cpleft in range(gene_data.LeftEnd, gene_data.RightEnd + 1, 3):
            
            # Make note of the codon sequence and amino acid encoded by this codon in the "reference" genome.
            # (Keep in mind that the gene data in the .sco file uses 1-indexed coords, so we need to convert
            # accordingly.)
            codon_dna = fasta[cpleft - 1: cpleft + 2]
            if gene_data.Strand == "-":
                codon_dna = codon_dna.reverse_complement()

            codon_seq = str(codon_dna)
            aa = str(codon_dna.translate())
            
            # Update frequencies accordingly.
            codon2freq[codon_seq] += 1
            aa2freq[aa] += 1
            
            # We can finally compute stats re: number of mismatching and matching codons.
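            aligned_codons = seq2gene2codon2alignedcodons[seq][gene_data.Index][cpleft]
            num_aligned_codons = sum(aligned_codons.values())
            # Skip codons that no read fully covers (otherwise the division below raises
            # a ZeroDivisionError).
            if num_aligned_codons == 0:
                continue
            alt_codon_frac = (num_aligned_codons - aligned_codons[codon_seq]) / num_aligned_codons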
            
            # print("sum of vals of ac is {}".format(sum(aligned_codons.values())))
            # print("Codon {} from {} to {} in gene {} in seq {} has mutations: {}".format(
            #     codon_seq, cpleft, cpleft + 2, gene_data.Index, seq, aligned_codons
            # ))
            
            # Using minfreq = 0.5%
            if alt_codon_frac > 0.005:
                
                # Subset aligned_codons to just the alternate codons. I guess we could also just use "del".
                alt_codons = {c: aligned_codons[c] for c in aligned_codons if c != codon_seq}
                
                # Retrieve max-freq alternate codon.
                # Based on https://stackoverflow.com/a/280156.
                # (Note that if there's a tie, the result is arbitrary. Shouldn't be a big deal. Making note
                # of this in the paper.)
                max_freq_alt_codon = max(alt_codons, key=alt_codons.get)
                codon2codon2freq[codon_seq][max_freq_alt_codon] += 1
                
                # print("Is mutation! And max freq alt codon is {}".format(max_freq_alt_codon))
                
                # NOTE: I guess you could argue that we should do this another way, where we actually compute
                # the translations of all the alt codons and then pick the most common AA/stop codon from there?
                #
                # You could argue this either way: doing it based on just the mutated codon keeps the matrices
                # consistent and lessens the impact of small errors, while taking into account all alt codon
                # translations could help show weird things where multiple mutations have similar consequences.
                # Hmm.
                #
                # TODO: think about!
                alt_codon_aa = str(skbio.DNA(max_freq_alt_codon).translate())
                if alt_codon_aa != aa:
                    aa2aa2freq[aa][alt_codon_aa] += 1
                    # print("Is nonsyn mutation! Alt {} codes for {}; orig coded for {}".format(
                    #     max_freq_alt_codon, alt_codon_aa, aa
                    # ))

# Write out stuff for further analysis / in case of crisis
with open("matrix-jsons/codon2codon2freq.json", "w") as dumpster:
    dumpster.write(json.dumps(codon2codon2freq))
    
with open("matrix-jsons/codon2freq.json", "w") as dumpster:
    dumpster.write(json.dumps(codon2freq))
    
with open("matrix-jsons/aa2aa2freq.json", "w") as dumpster:
    dumpster.write(json.dumps(aa2aa2freq))

with open("matrix-jsons/aa2freq.json", "w") as dumpster:
    dumpster.write(json.dumps(aa2freq))
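The per-codon calling rule in the cell above (flag a codon if more than 0.5% of the reads fully covering it disagree with the reference, then record the most frequent alternate codon) can be isolated into a small function for testing. This is a sketch: `call_codon_mutation` is a hypothetical name, and it returns `None` both for codons below the threshold and for codons that no read fully covers.

```python
def call_codon_mutation(ref_codon, aligned_codons, min_alt_frac=0.005):
    """Return the most frequent alternate codon if the alternate fraction exceeds min_alt_frac."""
    total = sum(aligned_codons.values())
    if total == 0:
        return None  # no read fully covers this codon
    alt_frac = (total - aligned_codons.get(ref_codon, 0)) / total
    if alt_frac <= min_alt_frac:
        return None
    alt_codons = {c: n for c, n in aligned_codons.items() if c != ref_codon}
    # On a tie, max() keeps the first maximal codon in insertion order
    return max(alt_codons, key=alt_codons.get)


print(call_codon_mutation("TTA", {"TTA": 1000, "TTT": 1}))           # None (1/1001 is below 0.5%)
print(call_codon_mutation("TTA", {"TTA": 990, "TTT": 8, "TTG": 2}))  # TTT (10/1000 = 1%)
print(call_codon_mutation("TTA", {}))                                 # None
```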
