# Generate the data for "mutation matrices"

Both for codons (e.g. ATG, TGC) and amino acids (e.g. M, C).

**NOTE 1:** Because this boils down to calling [`AlignedSegment.get_aligned_pairs()`](https://pysam.readthedocs.io/en/latest/api.html#pysam.AlignedSegment.get_aligned_pairs) once for every linear alignment to the selected MAGs, **this notebook is currently pretty slow**! I've optimized things to the point where this notebook takes around 20 hours on our cluster.

If necessary / desired it should be possible to speed this up even more, using stuff like parallelization / writing this in a faster language like C / etc. There may also be methods in pysam I've overlooked that will help do this faster (from what I can tell the only way to do this in pysam involves calling `get_aligned_pairs()`, but maybe I've missed something...)
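One possible direction, sketched here under assumptions: pysam exposes an alignment's CIGAR string as `AlignedSegment.cigartuples`, a list of `(operation, length)` pairs. In principle, aligned pairs could then be generated lazily and only for a reference window of interest, such as the span of the genes overlapping the alignment. The helper `aligned_pairs_in_window` below is hypothetical and has not been tested against real BAM files. It tries to mimic `get_aligned_pairs(matches_only=True)` for the standard CIGAR operations. Whether a pure-Python version would actually beat pysam's compiled implementation is not established.

```python
# CIGAR operation codes in SAM/BAM order (MIDNSHP=X), as used by pysam's cigartuples
M, I, D, N, S, H, P, EQ, X = range(9)


def aligned_pairs_in_window(reference_start, cigartuples, win_lo, win_hi):
    """Yield (query_pos, ref_pos) for aligned (M/=/X) bases with win_lo <= ref_pos < win_hi.

    All coordinates are 0-indexed, like pysam. Query positions index into the
    full query sequence (soft-clipped bases included, hard-clipped excluded).
    Hypothetical helper, not part of pysam.
    """
    qpos = 0
    rpos = reference_start
    for op, length in cigartuples:
        if op in (M, EQ, X):
            # Only walk the part of this aligned block that intersects the window
            lo = max(rpos, win_lo)
            hi = min(rpos + length, win_hi)
            for r in range(lo, hi):
                yield (qpos + (r - rpos), r)
            qpos += length
            rpos += length
        elif op in (I, S):
            qpos += length  # consumes query only
        elif op in (D, N):
            rpos += length  # consumes reference only
        # H and P consume neither query nor reference
        if rpos >= win_hi:
            break  # nothing further right can fall inside the window
```

In the notebook this might be called as `aligned_pairs_in_window(aln.reference_start, aln.cigartuples, lo, hi)`, with `lo`/`hi` set to the 0-indexed bounds of the overlapping genes.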

**NOTE 2:** This doesn't actually generate the figures for the matrices -- it just outputs JSON files to a folder named `matrix-jsons/`, and another notebook will generate the figures based on these JSON files. This is intended to make it easier to regenerate the figures using different styles / etc. without having to wait hours for this stuff to finish.

**NOTE 3:** We're redoing things so that the overall matrices are only based on individual MAGs -- so it should be possible to either parallelize this at the level of each MAG (and probably parallelize things even further), or just run this for individual MAGs. HOWEVER: for now, I'm going to keep generating this data for all three selected MAGs, because (i) we need this data for the Syn/NonSyn and Non-nonsense / Nonsense barplots, and (ii) I don't really want to rewrite this entire notebook since this is currently a one-off analysis...
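The per-MAG computations are independent, so one way to parallelize them is an executor with one task per MAG. This is a minimal sketch under assumptions. `worker` stands in for a hypothetical function, such as `count_codons_for_mag(seq)`, wrapping the body of the main loop below: it would open its own `pysam.AlignmentFile`, fetch that MAG's alignments and return that MAG's counts dict. With `ProcessPoolExecutor`, the worker must be picklable. In practice that means a top-level function in an importable module, since functions defined in notebook cells generally can't be sent to spawned worker processes.

```python
from concurrent.futures import Executor, ProcessPoolExecutor, ThreadPoolExecutor
from typing import Callable, Dict, Iterable, Type


def process_all_mags(
    seqs: Iterable[str],
    worker: Callable[[str], dict],
    max_workers: int = 3,
    executor_cls: Type[Executor] = ProcessPoolExecutor,
) -> Dict[str, dict]:
    """Run worker(seq) for each MAG in parallel and collect the results by sequence ID."""
    seqs = list(seqs)
    with executor_cls(max_workers=max_workers) as pool:
        # Executor.map preserves input order, so zip() lines the results up with seqs
        results = list(pool.map(worker, seqs))
    return dict(zip(seqs, results))
```

Usage might look like `process_all_mags(SEQS, count_codons_for_mag)`. Passing `executor_cls=ThreadPoolExecutor` is handy for quick tests, although threads are unlikely to speed up this CPU-bound loop.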

In [1]:
%run "Header.ipynb"
%run "GeneUtils.ipynb"

In [2]:
import copy
import time
import json
import pysam
import skbio
from collections import defaultdict, Counter
from statistics import mean
from parse_sco import parse_sco

## Go through all alignments to each genome and figure out which genes they intersect and which codons in these genes they fully cover

__This is the main bottleneck of this notebook__, at least as of writing.

Define a dict which we'll use to keep track of aligned codon frequencies for each codon, for each gene, for each genome.

- For each alignment, see which predicted genes (if any) this alignment intersects within the genome. Note that "intersects" doesn't mean "fully covers".

- For each of these genes, see which codons (if any) this alignment fully covers within the gene.

- Increment aligned codon frequencies for all codons accordingly.
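The steps above can be sketched on toy data without pysam. This is a simplified illustration, not the notebook's exact code. Genes are hypothetical `(index, left, right, strand)` tuples in 1-indexed closed coordinates, as in the .sco files. The alignment is a list of `(query_pos, ref_pos)` pairs in 0-indexed coordinates, like the output of `get_aligned_pairs(matches_only=True)`. A codon counts as fully covered only if its three bases are aligned consecutively in both the read and the reference. (The notebook defers that consecutiveness check until a pair already looks like a codon start, for speed; the result is the same.)

```python
from collections import defaultdict

# Complement table for plain A/C/G/T strings
COMP = str.maketrans("ACGT", "TGCA")


def count_covered_codons(genes, pairs, read_seq, counts):
    """Increment counts[gene_index][codon_left][observed_codon] for each fully covered codon.

    genes: (index, left, right, strand) tuples, 1-indexed closed coordinates.
    pairs: (query_pos, ref_pos) aligned pairs, 0-indexed, sorted by ref_pos.
    """
    if not pairs:
        return
    # Step 1: find genes that intersect (not necessarily fully cover) this alignment
    seg_left = pairs[0][1] + 1  # 1-indexed
    seg_right = pairs[-1][1] + 1
    overlapping = [g for g in genes if g[2] >= seg_left and g[1] <= seg_right]
    if not overlapping:
        return
    for i in range(len(pairs) - 2):
        (q1, r1), (q2, r2), (q3, r3) = pairs[i], pairs[i + 1], pairs[i + 2]
        # Step 2a: no jumps in the read or the reference across the three bases
        if not (q2 == q1 + 1 and q3 == q2 + 1 and r2 == r1 + 1 and r3 == r2 + 1):
            continue
        refpos = r1 + 1  # 1-indexed leftmost position of the candidate codon
        for index, left, right, strand in overlapping:
            # Step 2b: refpos must be the leftmost base of a codon in this gene
            if left <= refpos <= right - 2 and (refpos - left) % 3 == 0:
                codon = read_seq[q1:q1 + 3]
                if strand == "-":
                    codon = codon.translate(COMP)[::-1]  # reverse complement
                # Step 3: increment the aligned codon frequency
                counts[index][refpos][codon] += 1
```

Here `counts` could be created as `defaultdict(lambda: defaultdict(lambda: defaultdict(int)))`.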

The reason we do things this way, as opposed to iterating over just the alignments overlapping each codon in each gene, is that doing things that way is really slow! I'm pretty sure it's because "find out which alignments overlap this region" is a pretty slow operation when working with large datasets -- and also since these are long reads, doing this on the level of each codon means we're effectively doing a lot of redundant work (you can imagine that, for a given codon, the odds are pretty good that most alignments overlapping it will also overlap adjacent codon(s)).

**Note that we purposefully say "alignment" instead of "read"** since a single read can correspond to multiple distinct linear alignments (e.g. if this read is aligned to both the start and end of a genome, due to a chimeric / supplementary alignment). We assume that no two linear alignments from a read cover the _same_ region of the genome (the upstream filtering stuff should already guarantee this), but this is about as far as we go. (We could also probably filter out supplementary alignments entirely and I suspect that wouldn't change the results here much.)
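If we did want to drop supplementary alignments (and secondary ones, to be safe), pysam exposes the `AlignedSegment.is_supplementary` and `AlignedSegment.is_secondary` flags. This is a minimal sketch of something this notebook does not currently do. `primary_alignments` is a hypothetical helper that could wrap `bf.fetch(seq)`.

```python
from typing import Iterable, Iterator


def primary_alignments(alignments: Iterable) -> Iterator:
    """Yield only alignments that are neither supplementary nor secondary.

    Works on pysam.AlignedSegment objects (or anything with the same two
    boolean attributes), e.g. primary_alignments(bf.fetch(seq)).
    """
    for aln in alignments:
        if aln.is_supplementary or aln.is_secondary:
            continue
        yield aln
```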

In [3]:
# There's probably a fancier way of generating this list, but this is fine.
codons = []
# Also, we figure out the reverse complements of each of the 64 3-mers in advance -- this avoids
# us having to call str(skbio.DNA(c).reverse_complement()) every time we see a codon, and saves a tiny
# amount of time per alignment (the skbio approach took ~9e-5 seconds every time; the new approach takes ~9e-7
# seconds every time). Considering we're going through well over a million alignments, the time savings comes out
# to ... 130.977 seconds, aka 2 minutes 10 seconds, if I'm computing this correctly. So, not much, but it's
# something!
codon2revcomp = {}
nts = "ACGT"
for i in nts:
    for j in nts:
        for k in nts:
            c = "{}{}{}".format(i, j, k)
            codons.append(c)
            codon2revcomp[c] = str(skbio.DNA(c).reverse_complement())
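As the comment above says, there is a shorter way to build these. `itertools.product` generates the 64 triplets in the same order as the nested loops, and `str.translate` can reverse-complement plain ACGT strings without skbio. This is a sketch that assumes codons only ever contain A/C/G/T; it should produce the same `codons` list and `codon2revcomp` dict as the cell above.

```python
import itertools

NTS = "ACGT"
# Complement table for plain DNA (no IUPAC ambiguity codes)
COMPLEMENT = str.maketrans("ACGT", "TGCA")

# Same order as the nested loops: AAA, AAC, AAG, AAT, ACA, ...
codons = ["".join(p) for p in itertools.product(NTS, repeat=3)]
codon2revcomp = {c: c.translate(COMPLEMENT)[::-1] for c in codons}
```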

In [4]:
# Maps sequence IDs to genes (keyed by their Index in the .sco file) to codons (keyed by (1-indexed!)
# left end, i.e. the lower of the two positional boundaries of the codon, regardless of whether its gene
# is on the + or - strand) to observed aligned codon frequencies (keyed by just the triplet, e.g. "AAA").
#
# Example:
# {"edge_6104":                                Sequence
#     {1:                                      Gene index in the .sco file
#         {265:                                Left codon position
#             {"TTA": 1000, "TTT": 1, ... }    Aligned codon frequencies for this particular codon
#         }
#     }
# }
seq2gene2codon2alignedcodons = {}

In [5]:
bf = pysam.AlignmentFile("../main-workflow/output/fully-filtered-and-sorted-aln.bam", "rb")

tT1 = time.time()
for seq in SEQS:
    df = parse_sco("../seqs/genes/{}.sco".format(seq))
    
    # We don't actually store any results in this, but we do use it for a slight optimization
    gene2isrev = {}
    
    # Initialize some of the data structures
    # NOTE: this is kind of slow. However, it still finishes within a few seconds, so not the most
    # important thing to optimize
    seq2gene2codon2alignedcodons[seq] = {}
    for gene_data in df.itertuples():
        
        validate_gene_coords(gene_data)        
        seq2gene2codon2alignedcodons[seq][gene_data.Index] = {}
        gene2isrev[gene_data.Index] = (gene_data.Strand == "-")
        
        codon_positions = get_gene_left_codon_positions(gene_data)

        # For each codon in this gene, keep track of all the codons spanning it from the various
        # alignments to this genome.
        for cpleft in codon_positions:
            seq2gene2codon2alignedcodons[seq][gene_data.Index][cpleft] = defaultdict(int)
            
    print("Finished initialization for seq = {}".format(seq))
    alntimes = []
    
    for ri, aln in enumerate(bf.fetch(seq), 1):
        
        t1 = time.time()
        
        # Find all genes that this aln intersects in this genome
        
        # These are 0-indexed coordinates (and segright is offset to the right by one; see
        # https://pysam.readthedocs.io/en/latest/api.html#pysam.AlignedSegment.reference_end)
        segleft = aln.reference_start
        segright = aln.reference_end
        
        if segleft is None or segright is None:
            raise ValueError("Read {} is unmapped? This shouldn't happen!".format(aln.query_name))
        
        if segleft >= segright:
            raise ValueError("Read {}'s coordinates in pysam seem messed up: left = {}, right = {}".format(
                aln.query_name, segleft, segright
            ))
        
        # Convert aligned segment boundaries to 1-indexed coordinates to make comparing with gene
        # coordinates from the .sco file easier. The gene coordinates are exact and closed: a gene
        # from [266, 712] starts at base 266 and ends at base 712, using 1-indexing.
        # segleft just needs +1. segright needs no change: the last aligned base is at 0-indexed
        # position segright - 1, and converting that to 1-indexing adds the 1 back
        # (segright - 1 + 1 = segright).
        segleft += 1
        
        # Use vectorization to find genes overlapping this aln: see https://stackoverflow.com/a/17071908
        # for details on why parentheses, etc., and
        # https://engineering.upside.com/a-beginners-guide-to-optimizing-pandas-code-for-speed-c09ef2c6a4d6
        # for justification on why this is useful (tldr: makes code go fast)
        genes_overlapping_aln = list(
            df.loc[(df["RightEnd"] >= segleft) & (df["LeftEnd"] <= segright)].itertuples()
        )
        # Note about the above thing that just happened: you may be shaking your fist and saying "wait
        # itertuples is slow!" And yeah, kinda. But for whatever reason I've tried multiple times to keep
        # genes_overlapping_aln as a DataFrame (and then later vectorize stuff like checking that a given
        # aligned pair covers a codon within the genes, etc) and the overhead costs seem to slow things down.
        # I am sure it's possible to speed things up more, but right now things seem good enough.

        # (Debugging code)
        # print("{} genes overlap aln {}".format(len(genes_overlapping_aln), ri))
        # print("Read {}, which ranges from {} to {}, overlaps these genes:".format(ri, segleft, segright))
        # print(genes_overlapping_aln)
                
        # If no genes overlap this aln, we are free to move on to the next aln.
        if len(genes_overlapping_aln) > 0:
            
            # Computing this is relatively slow, which is why we jump through so many hoops before we do this.
            # Each entry in get_aligned_pairs() is a tuple with 2 elements:
            # the first is the query/read pos and the second is the reference pos.
            # TODO: would it be possible to only do this for certain positions we care about? get_aligned_pairs()
            # returns a lot of stuff we don't need, e.g. regions of the aln that don't intersect with any genes.
            ap = aln.get_aligned_pairs(matches_only=True)
            
            # Doesn't look like getting this in advance saves much time, but I don't think it hurts.
            read_seq = aln.query_sequence
            
            # We only consider the leftmost position of each codon, so we don't need to bother checking the last
            # two pairs of positions (since neither could be the leftmost position of a codon that this aln
            # fully covers).
            for api, pair1 in enumerate(ap[:-2]):

                # Convert to 1-indexed position for ease of comparison with gene coordinates
                pair1_refpos = pair1[1] + 1
                  
                havent_checked_next_pairs = True
                for gene_data in genes_overlapping_aln:
                    gl = gene_data.LeftEnd
                    
                    # Check that this pair is located within this gene and is the leftmost position of a
                    # codon in the gene. (Note that this check works for both + and - strand genes.
                    # Whether the leftmost position is the "start" [i.e. CP 1] or "end" [i.e. CP 3] of
                    # the codon changes with the strand of the gene, but we'll account for that later on
                    # when we reverse-complement the codon if needed.)
                    if gl <= pair1_refpos <= gene_data.RightEnd - 2 and (pair1_refpos - gl) % 3 == 0:
                        
                        # Nice! Looks like this aln fully covers this codon.
                        
                        # If we haven't yet, check that this aln doesn't skip over parts of the codon,
                        # or stuff like that. The reason this check is located *here* (and not
                        # before we loop over the genes) is that it seems like this is a faster strategy:
                        # only run these checks once we KNOW that this pair looks like it fully covers a
                        # codon, since many pairs might not meet that criterion.
                        #
                        # (And by recording that we've run this check once, in havent_checked_next_pairs,
                        # we can save the time cost of running the check multiple times.)
                        #
                        # I feel like an insane person trying to optimize this so much lmao.
                        if havent_checked_next_pairs:
                            # Check that the pairs are all consecutive (i.e. no "jumps" in the read,
                            # and no "jumps" in the reference)
                            # Since we don't consider the last two pairs in ap, pair2 and pair3 should
                            # always be available.
                            pair2 = ap[api + 1]
                            pair3 = ap[api + 2]

                            # For an aligned pair, [0] is the read pos and [1] is the reference pos.
                            # Result of caching this is probably negligible but... may as well
                            p10 = pair1[0]
                            p20 = pair2[0]
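                            readpos_consecutive = p20 == (p10 + 1) and pair3[0] == (p20 + 1)
                            if not readpos_consecutive:
                                # This pair can't start a fully covered codon in *any* gene, so stop
                                # checking genes and move on to the next pair. (Same for the refpos check.)
                                break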

                            # (pair1_refpos is already off by 1 so no need to redo the addition operation)
                            p21 = pair2[1]
                            refpos_consecutive = p21 == pair1_refpos and pair3[1] == (p21 + 1)
                            if not refpos_consecutive:
                                break
                            havent_checked_next_pairs = False
                        
                        # Figure out what the read actually *says* in the alignment here.
                        # (It'll probably be a complete match most of the time, but there will
                        # be some occasional mismatches -- and seeing those is ... the whole point
                        # of this notebook.)

                        # We make sure to index the read by read coords, not reference coords!
                        aligned_codon = read_seq[p10: p10 + 3]

                        # Finally, update information about codon frequencies.
                        gi = gene_data.Index
                        if gene2isrev[gi]:
                            seq2gene2codon2alignedcodons[seq][gi][pair1_refpos][codon2revcomp[aligned_codon]] += 1
                        else:
                            seq2gene2codon2alignedcodons[seq][gi][pair1_refpos][aligned_codon] += 1

        t2 = time.time()
        alntimes.append(t2 - t1)
        if ri % 100 == 0:
            print(f"Seen {ri:,} alignments so far in {seq}.")
            print(f"Total time taken thus far:               {t2-tT1:,.4f} sec.")
            print(f"Average time per alignment for this seq: {mean(alntimes):,.4f} sec.")

    # At this point, we've seen all the alignments to all the codons in this genome.
    # We can now "call" mutations based on the frequencies we've counted.
tT2 = time.time()

print("Figuring all that out took a total of {} seconds.".format(tT2 - tT1))

bf.close()
# TODO: It'd probably be safer to output this after each sequence, rather than at the very end.
with open("matrix-jsons/seq2gene2codon2alignedcodons.json", "w") as dumpster:
    dumpster.write(json.dumps(seq2gene2codon2alignedcodons))

Finished initialization for seq = edge_6104
Seen 100 alignments so far in edge_6104.
Total time taken thus far:               4.6695 sec.
Average time per alignment for this seq: 0.0353 sec.
Seen 200 alignments so far in edge_6104.
Total time taken thus far:               8.4690 sec.
Average time per alignment for this seq: 0.0366 sec.
Seen 300 alignments so far in edge_6104.
Total time taken thus far:               12.1794 sec.
Average time per alignment for this seq: 0.0367 sec.
Seen 400 alignments so far in edge_6104.
Total time taken thus far:               16.9838 sec.
Average time per alignment for this seq: 0.0395 sec.
Seen 500 alignments so far in edge_6104.
Total time taken thus far:               20.1875 sec.
Average time per alignment for this seq: 0.0380 sec.
Seen 600 alignments so far in edge_6104.
Total time taken thus far:               22.7697 sec.
Average time per alignment for this seq: 0.0359 sec.
Seen 700 alignments so far in edge_6104.
Total time taken thus far:   
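The coordinate bookkeeping in the main loop can be checked in isolation. pysam's `reference_start` is 0-indexed and `reference_end` points one past the last aligned base, so an alignment covers the half-open interval `[reference_start, reference_end)`. The equivalent 1-indexed closed interval is `[reference_start + 1, reference_end]`. Two closed intervals overlap exactly when each one starts no later than the other ends, which is the test the `df.loc[...]` mask applies. This is a small sketch; both helper names are hypothetical.

```python
def pysam_to_one_based_closed(reference_start, reference_end):
    """Convert pysam's 0-indexed half-open [start, end) to 1-indexed closed [left, right]."""
    return reference_start + 1, reference_end


def overlaps(gene_left, gene_right, seg_left, seg_right):
    """True if closed intervals [gene_left, gene_right] and [seg_left, seg_right] share a base."""
    return gene_right >= seg_left and gene_left <= seg_right
```

For example, an alignment with `reference_start = 265` and `reference_end = 712` maps to `(266, 712)`, which lines up exactly with a gene spanning [266, 712].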

## Using the information we just computed for each genome, "call" mutations and store this information in the frequency data structures we set up earlier

This section is kind of slow, but on the order of "takes a few minutes on the cluster" and not on the order of "takes literally hours to run". We could make it more efficient if desired; most of my energy on optimization here thus far has been spent on the earlier step in this notebook.

Note that the main output of the above section (`seq2gene2codon2alignedcodons`) has already been written out to a JSON file -- in practice, it's useful to be able to start execution again from this bottom section after stopping the above section (e.g. if something goes wrong _here_, then we don't have to rerun the top part of this notebook again). Even if the entire notebook is run in a single shot, we still save and then load the `seq2gene2codon2alignedcodons` object anyway.
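One consequence of this save/load round trip is worth keeping in mind. JSON object keys are always strings, so the integer gene indices and codon positions used as keys above come back as strings after `json.load`. That's why `generate_mutmatrix_data()` looks things up with `str(gene_data.Index)` and `str(cpleft)`. Here is a small illustration using the example counts from the structure comment above:

```python
import json

# Integer keys, as in seq2gene2codon2alignedcodons before it's written out
before = {"edge_6104": {1: {265: {"TTA": 1000, "TTT": 1}}}}
after = json.loads(json.dumps(before))

# json.dumps turned the int keys into strings
print(list(after["edge_6104"]))       # ['1']
print(list(after["edge_6104"]["1"]))  # ['265']
```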

### Initialize data structures that we'll store frequency data in

In [6]:
# 64x63 dict: each key is a triplet of {A, C, G, T}, and each value is another dict with all the other codons
codon2codon2freq = {}

# 21x20 dict: each key is a proteinogenic amino acid (A, C, D, E, F, ...), limited to just
# stuff in the standard genetic code (i.e. ignoring selenocysteine and pyrrolysine) but including
# "*", representing a stop codon.
aa2aa2freq = {}

# 64-key dict: maps each triplet to an integer indicating how frequently this triplet occurs in all genes
# in the genomes (i.e. not counting mutations into this triplet).
codon2freq = {}

# 21-key dict: maps amino acid/stop codon to integer indicating frequency across all genes.
aa2freq = {}

aas = set([])
for c in codons:
    aas.add(str(skbio.DNA(c).translate()))
    
# Initialize dicts to 0s
for c1 in codons:
    codon2codon2freq[c1] = {c2: 0 for c2 in set(codons) - set([c1])}
    codon2freq[c1] = 0
    
for aa1 in aas:
    aa2aa2freq[aa1] = {aa2: 0 for aa2 in set(aas) - set([aa1])}
    aa2freq[aa1] = 0
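For reference, the 21 symbols in `aas` are the 20 standard amino acids plus `*`. The sketch below builds the standard genetic code (NCBI translation table 1) without skbio, using the table's usual compact string encoding and assuming only A/C/G/T codons. The names `codon2aa` and `aa_symbols` are hypothetical and chosen to avoid overwriting the notebook's `aas`.

```python
import itertools

# NCBI translation table 1, with bases ordered T, C, A, G at each codon position
BASES = "TCAG"
AAS_TABLE1 = "FFLLSSSSYY**CC*WLLLLPPPPHHQQRRRRIIIMTTTTNNKKSSRRVVVVAAAADDEEGGGG"

codon2aa = {
    "".join(triplet): aa
    for triplet, aa in zip(itertools.product(BASES, repeat=3), AAS_TABLE1)
}
aa_symbols = set(codon2aa.values())  # should match the notebook's `aas`
```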

In [7]:
with open("matrix-jsons/seq2gene2codon2alignedcodons.json", "r") as loadster:
    seq2gene2codon2alignedcodons = json.load(loadster)

In [8]:
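def generate_mutmatrix_data(seq):
    # Reset the (global) frequency dicts first, so that each MAG's output JSON files only
    # reflect that MAG's genes instead of accumulating counts across calls.
    for freq_dict in (codon2freq, aa2freq):
        for key in freq_dict:
            freq_dict[key] = 0
    for nested_dict in (codon2codon2freq, aa2aa2freq):
        for inner in nested_dict.values():
            for key in inner:
                inner[key] = 0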
    fasta = skbio.DNA.read("../seqs/{}.fasta".format(seq))
    df = parse_sco("../seqs/genes/{}.sco".format(seq))
    for gene_data in df.itertuples():
        print("On gene {} in seq {}.".format(gene_data.Index, seq))
        for cpleft in range(gene_data.LeftEnd, gene_data.RightEnd + 1, 3):
            
            # Make note of the codon sequence and amino acid encoded by this codon in the "reference" genome.
            # (Keep in mind that the gene data in the .sco file uses 1-indexed coords, so we need to convert
            # accordingly.)
            codon_dna = fasta[cpleft - 1: cpleft + 2]
            if gene_data.Strand == "-":
                codon_dna = codon_dna.reverse_complement()

            codon_seq = str(codon_dna)
            aa = str(codon_dna.translate())
            
            # Update frequencies accordingly.
            codon2freq[codon_seq] += 1
            aa2freq[aa] += 1
            
            # We can finally compute stats re: number of mismatching and matching codons.
            aligned_codons = seq2gene2codon2alignedcodons[seq][str(gene_data.Index)][str(cpleft)]
            
            # Ignore weird, low-coverage cases. Copying from the SynAndNonsense barplots notebook, which
            # uses similar logic:
            #  If the reference codon sequence isn't even included in the aligned codons (but if
            #  there are still other codons included in the alignment???), we're likely
            #  at a weird low-coverage portion of the alignment. Ignore these cases (there are 116 such
            #  problematic codons as of writing this, mostly in edge_1671 -- something might be wrong with
            #  the alignment, which I had to recover from a backup recently -- will rerun and check).
            #
            #  This implicitly accounts for the case where 0 codons are present in aligned_codons, as well.
            if codon_seq in aligned_codons:
                
                num_aligned_codons = sum(aligned_codons.values())
                
                # Only call a mutation using the max-freq alt codon, not the sum of all alternate codon freqs.
                # The max(d, key=d.get) trick is from https://stackoverflow.com/a/280156 (I've used it a lot
                # here :) Notably, ties go to whichever tied codon comes first in the dict.
                alt_codons = {c: aligned_codons[c] for c in aligned_codons if c != codon_seq}
                if len(alt_codons) > 0:
                    max_freq_alt_codon = max(alt_codons, key=alt_codons.get)
                    max_freq_alt_codon_frac = alt_codons[max_freq_alt_codon] / num_aligned_codons
                else:
                    max_freq_alt_codon_frac = 0
                    
                # print("sum of vals of ac is {}".format(sum(aligned_codons.values())))
                # print("Codon {} from {} to {} in gene {} in seq {} has mutations: {}".format(
                #     codon_seq, cpleft, cpleft + 2, gene_data.Index, seq, aligned_codons
                # ))

                # Just for reference: this will print cases where aggregating alternate codons results in
                # "false positives," at least compared to the method of only considering the maximum frequency
                # alternate codon.
                #
                #if max_freq_alt_codon_frac <= 0.005 and sum(alt_codons.values()) / num_aligned_codons > 0.005:
                #    print("Found contrary case.")
                #    print(f"{codon_seq}, {aligned_codons}, only mutation in aggregate... max freq alt codon is {max_freq_alt_codon}")
                    
                # Using p = 0.5%
                if max_freq_alt_codon_frac > 0.005:

                    codon2codon2freq[codon_seq][max_freq_alt_codon] += 1

                    print(f"{codon_seq}, {aligned_codons}, is mutation! max freq alt codon is {max_freq_alt_codon}")

                    # NOTE: I guess you could argue that we should do this another way, where we actually compute
                    # the translations of all the alt codons and then pick the most common AA/stop codon from there?
                    #
                    # You could argue this either way: doing it based on just the mutated codon keeps the matrices
                    # consistent and lessens the impact of small errors, while taking into account all alt codon
                    # translations could help show weird things where multiple mutations have similar consequences.
                    # Hmm.
                    alt_codon_aa = str(skbio.DNA(max_freq_alt_codon).translate())
                    if alt_codon_aa != aa:
                        aa2aa2freq[aa][alt_codon_aa] += 1
                        # print("Is nonsyn mutation! Alt {} codes for {}; orig coded for {}".format(
                        #     max_freq_alt_codon, alt_codon_aa, aa
                        # ))

    # Write out stuff for further analysis / in case of crisis
    with open(f"matrix-jsons/{seq}-codon2codon2freq.json", "w") as dumpster:
        dumpster.write(json.dumps(codon2codon2freq))

    with open(f"matrix-jsons/{seq}-codon2freq.json", "w") as dumpster:
        dumpster.write(json.dumps(codon2freq))

    with open(f"matrix-jsons/{seq}-aa2aa2freq.json", "w") as dumpster:
        dumpster.write(json.dumps(aa2aa2freq))

    with open(f"matrix-jsons/{seq}-aa2freq.json", "w") as dumpster:
        dumpster.write(json.dumps(aa2freq))
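The calling rule inside the loop can be isolated as a small pure function, which makes it easy to sanity-check against the printed output below. This sketch uses the same assumptions as the notebook: the reference codon must be observed, only the single most frequent alternate codon is considered, and the threshold is 0.5%. `call_mutation` is a hypothetical name. For comparison, the block also sketches the alternative from the NOTE above (`call_aa_by_aggregation`, also hypothetical), which sums alternate-codon counts per translated amino acid before choosing. The notebook does not use that alternative.

```python
from collections import Counter
from typing import Callable, Dict, Optional


def call_mutation(ref_codon: str, aligned: Dict[str, int], threshold: float = 0.005) -> Optional[str]:
    """Return the max-frequency alternate codon if its fraction exceeds threshold, else None."""
    if ref_codon not in aligned:
        return None  # weird low-coverage case: skipped, as in the notebook
    total = sum(aligned.values())
    alts = {c: n for c, n in aligned.items() if c != ref_codon}
    if not alts:
        return None
    # max() returns the first maximal key in iteration order, so ties go to
    # whichever alternate codon appears first in the dict
    best = max(alts, key=alts.get)
    return best if alts[best] / total > threshold else None


def call_aa_by_aggregation(ref_aa: str, aligned: Dict[str, int], translate: Callable[[str], str],
                           threshold: float = 0.005) -> Optional[str]:
    """Alternative (not used in the notebook): sum alternate-codon counts per amino acid."""
    total = sum(aligned.values())
    per_aa = Counter()
    for codon, n in aligned.items():
        aa = translate(codon)
        if aa != ref_aa:  # synonymous codons are not counted as alternates
            per_aa[aa] += n
    if not per_aa:
        return None
    best_aa, n = per_aa.most_common(1)[0]
    return best_aa if n / total > threshold else None
```

The two approaches can disagree. With counts like `{"ACT": 1000, "ACG": 50, "AAT": 3, "AAC": 3}`, the per-codon rule calls the synonymous ACT to ACG change. The aggregated rule instead reports T to N, because the two N codons together pass the threshold.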

In [9]:
generate_mutmatrix_data("edge_1671")
generate_mutmatrix_data("edge_6104")
generate_mutmatrix_data("edge_2358")

On gene 1 in seq edge_1671.
TCA, {'TCA': 1385, 'TCG': 11}, is mutation! max freq alt codon is TCG
ACT, {'ACT': 1398, 'ACG': 11, 'AAT': 1}, is mutation! max freq alt codon is ACG
CAT, {'CAT': 1400, 'CAC': 11}, is mutation! max freq alt codon is CAC
CAC, {'CAC': 1415, 'CAT': 12}, is mutation! max freq alt codon is CAT
GGG, {'GGG': 1321, 'GGA': 12}, is mutation! max freq alt codon is GGA
GGT, {'GGT': 1416, 'TGC': 11, 'GCC': 1}, is mutation! max freq alt codon is TGC
AGA, {'AGA': 1419, 'AAA': 12}, is mutation! max freq alt codon is AAA
GGG, {'GGG': 1410, 'GGC': 17}, is mutation! max freq alt codon is GGC
GGA, {'GGA': 1418, 'GGT': 12}, is mutation! max freq alt codon is GGT
AAC, {'AAC': 1422, 'AAT': 12}, is mutation! max freq alt codon is AAT
GCC, {'GCC': 1414, 'ACC': 12}, is mutation! max freq alt codon is ACC
GGC, {'GGC': 1422, 'GGT': 9, 'GCC': 1}, is mutation! max freq alt codon is GGT
GAA, {'GAA': 1415, 'GAG': 9}, is mutation! max freq alt codon is GAG
GAC, {'GAC': 1410, 'GAT': 11, 'GTC