# Gene analysis utility functions

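This notebook relies on the `Header` and `LoadMutationJSONData` notebooks having already been run: the functions below use globals that are not defined in this notebook (e.g. `SEQS`, `seq2name`, `seq2len`, `cp2color`).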

In [22]:
import skbio
import pandas as pd
from statistics import mean
from parse_sco import parse_sco

def histogram_of_123(vals, title, axes, xlabel="Codon Position", ylabel="Number of Mutated Positions",
                     format_yaxis="thousands", max_val=None, show_ylabel=True, show_xlabel=True):
    """Draw a bar chart of per-codon-position values on ``axes``.

    Parameters
    ----------
    vals: sequence of numbers
        Either 3 values (codon positions 1, 2, 3) or 4 values (codon positions 1, 2, 3, then
        non-coding positions).
    title: str
        Axes title. Drawn normally when vals[1] < vals[0] < vals[2] (i.e. position 3 > 1 > 2);
        otherwise drawn in red/semibold to flag the plot.
    axes: matplotlib Axes
        Axes to draw on.
    xlabel, ylabel: str
        Axis labels; only drawn when show_xlabel / show_ylabel are True.
    format_yaxis: str
        "thousands" (y ticks formatted via use_thousands_sep(), bar labels with "," separators) or
        "percentages" (values are fractions in [0, 1], shown as percentages).
    max_val: number or None
        If given, used instead of max(vals) to scale the y-axis and the bar-label padding, so that
        several plots can share the same y-range.

    Also prints ``title`` and ``vals``.

    Raises
    ------
    ValueError
        If vals does not have 3 or 4 entries, or format_yaxis is unrecognized.
    """
    if len(vals) == 3:
        labels = ["1", "2", "3"]
    elif len(vals) == 4:
        labels = ["1", "2", "3", "Non-Coding"]
    else:
        raise ValueError("Not 3 or 4 vals passed?")
    # Validate before drawing anything, so a bad value doesn't leave a half-drawn plot behind.
    if format_yaxis not in ("thousands", "percentages"):
        raise ValueError("Unrecognized format_yaxis value: {}".format(format_yaxis))

    x = list(range(1, len(vals) + 1))
    axes.bar(
        x=x,
        height=vals,
        color=[cp2color[xi] for xi in x],
        edgecolor=BORDERCOLOR,
        tick_label=labels
    )

    # Scale the y-axis (and label padding) off max_val when given, so rows of plots can share a
    # y-range; otherwise scale off this plot's own tallest bar.
    scale_max = max(vals) if max_val is None else max_val

    ypad = 0.04 * scale_max
    # to make interpreting the plots easier, show text for each bar
    # just above it: https://stackoverflow.com/a/30229062
    # and https://stackoverflow.com/questions/30228069/how-to-display-the-value-of-the-bar-on-each-bar-with-pyplot-barh/30229062#comment86813015_30229062
    for xi, yval in zip(x, vals):
        # Used to make decision about normalized vs unnormalized values
        if format_yaxis == "percentages":
            text = f"{(yval * 100):.2f}%"
        else:
            text = f"{yval:,}"
        axes.text(xi, yval + ypad, text, ha="center")

    print(title, vals)
    one, two, three = vals[:3]
    if two < one < three:
        axes.set_title(title)
    else:
        axes.set_title(title, color="#cc3322", fontweight="semibold")

    if show_xlabel:
        axes.set_xlabel(xlabel)
    if show_ylabel:
        axes.set_ylabel(ylabel)

    if format_yaxis == "thousands":
        use_thousands_sep(axes.get_yaxis())
    else:
        # "percentages": make the y-axis show percentages, based on
        # https://old.reddit.com/r/learnpython/comments/7adhnk/matplotlib_setting_y_axis_labels_to_percent_yaxis/dp93fwq/
        axes.get_yaxis().set_major_formatter(matplotlib.ticker.PercentFormatter(xmax=1, decimals=2))

    # Shoddy way to add padding -- see https://stackoverflow.com/a/42804403
    # If every value is 0, a [0, 0] y-range would be singular, so fall back to a top of 1.
    axes.set_ylim(0, scale_max * 1.3 if scale_max > 0 else 1)
    
def get_pos_interval_from_gene(gene):
    """Return the 1-indexed positions covered by ``gene``, in the gene's reading direction.

    ``gene`` must expose ``LeftEnd``, ``RightEnd`` and ``Strand`` attributes (e.g. a row from
    ``parse_sco(...).itertuples()``). Each coordinate is parsed with int(); if that fails, its
    first character is dropped and parsing is retried, which handles coordinates like "<2" / ">N".

    For a "+" strand gene the range ascends LeftEnd..RightEnd; for a "-" strand gene it descends
    RightEnd..LeftEnd. Both ends are inclusive.

    Raises ValueError for any other strand value.
    """
    def parse_coord(raw):
        try:
            return int(raw)
        except ValueError:
            # Partial-gene marker such as "<2" or ">N": skip the leading symbol.
            return int(raw[1:])

    left = parse_coord(gene.LeftEnd)
    right = parse_coord(gene.RightEnd)

    strand = gene.Strand
    if strand == "+":
        return range(left, right + 1)
    if strand == "-":
        return range(right, left - 1, -1)
    raise ValueError("Unknown strand: {}".format(strand))

def get_val(seq, pos, p, fasta):
    """Calls simple p-mutations given a seq, position, value of p, and seq's DNA sequence ("fasta").

    Returns 1 if pos is a p-mutation, meaning at least one individual mismatching nucleotide at pos
    has (count / coverage) > p. Returns 0 otherwise.

    p should be in the range [0, 1] (i.e. if it's stored in the range [0, 100], you should divide it by
    100 first).

    pos should be 1-indexed, to be compatible with the JSONs (the per-position dicts are keyed by
    str(pos)).

    fasta is not currently consulted. It is kept in the signature so existing callers keep working.

    Note that this does not (currently) take the idea of minimum sufficient coverage into account. So you
    can use p = 0 or whatever without worrying about that causing math problems (although the ideas behind
    minimum sufficent coverage still apply, so when this function is used you should be careful to mention
    why we are not taking suff coverage into account here). If pos has a coverage of 0 (i.e. no aligned
    matching or mismatching reads), this'll always return 0 (i.e. "not a p-mutation").
    """
    key = str(pos)
    mismatchct = seq2pos2mismatchct[seq][key]
    matchct = seq2pos2matchct[seq][key]

    # Note that, as mentioned above, this isn't "complete" coverage -- it can be zero
    # if e.g. all the reads covering a position are deletions
    cov = mismatchct + matchct
    if cov > 0:
        alt_freqs = seq2pos2mismatches[seq][key]
        # Instead of aggregating all mismatching nucleotides, only consider the individual mismatching
        # nucleotides.
        # NOTE(review): this assumes seq2pos2mismatches never lists the reference base at pos; if it
        # can, the alt nucleotides would need filtering against fasta[pos - 1].
        if any(alt_freqs[alt] / cov > p for alt in alt_freqs):
            return 1
        # None of the alt nucleotide freqs were a p-mutation.
        return 0
    # 0x coverage (at least, considering matching/mismatching reads).
    # We thus do not have the data to detect if this position is a p-mutation.
    return 0
        
def _codon_position_mutation_vals(seq, p, add_noncoding_col):
    """Return per-category lists of 0/1 p-mutation calls (from get_val) for one seq.

    The first three lists hold the calls at codon positions 1, 2, 3 (positions cycle 1, 2, 3, 1, ...
    along each predicted gene's reading direction) across all genes in ../seqs/genes/{seq}.sco.
    If add_noncoding_col is True, a fourth list holds the calls for every position of the seq not
    covered by any predicted gene.

    Raises ValueError if a predicted gene's length is not divisible by 3.
    """
    genes = parse_sco(f"../seqs/genes/{seq}.sco")
    fasta = str(skbio.DNA.read(f"../seqs/{seq}.fasta"))

    m1 = []
    m2 = []
    m3 = []
    bases_in_genes = set()
    for gene in genes.itertuples():
        pos_interval = get_pos_interval_from_gene(gene)
        if len(pos_interval) % 3 != 0:
            raise ValueError("Gene length not divisible by 3.")
        bases_in_genes.update(pos_interval)
        gene_vals = [get_val(seq, pos, p, fasta) for pos in pos_interval]
        m1 += gene_vals[0::3]
        m2 += gene_vals[1::3]
        m3 += gene_vals[2::3]

    mutation_vals = [m1, m2, m3]
    if add_noncoding_col:
        # Every (1-indexed) position of the seq that no predicted gene covers.
        mutation_vals.append([
            get_val(seq, pos, p, fasta)
            for pos in range(1, seq2len[seq] + 1)
            if pos not in bases_in_genes
        ])
    return mutation_vals


def histogram_maker(
    p,
    title,
    axes,
    ylabel="Number of Mutated Positions",
    normalize=False,
    add_noncoding_col=True,
    make_yaxes_comparable=False,
    output_dict=False,
    show_xlabel=True,
    show_ylabel=True,
):
    """
    Produces histograms of p-mutation counts at the 1st, 2nd, and 3rd codon positions of predicted
    genes (and optionally at non-coding positions), one histogram per seq in SEQS.

    Parameters
    ----------
    p: float
        In the range [0, 1]; passed to get_val().

    title: str
        Will be included after the seq name in every histogram.

    axes: sequence of matplotlib Axes
        One Axes per seq in SEQS, in the same order.

    ylabel: str
        Label for the y-axis of the histogram. Only will be used if normalize is False.

    normalize: bool
        If True, divides each 1/2/3 value by (number of positions considered). This makes it easier to compare
        histograms between different sequences. A category with no positions at all is plotted as 0.
        This also sets ylabel to "Number of Mutated Positions / Number of Positions", ignoring whatever
        ylabel's default was (or even ignoring the already-specified ylabel).

    add_noncoding_col: bool
        If True, adds a 4th column to each histogram representing all of the positions not contained within
        predicted protein-coding genes. If normalize is True, this column's value is divided by the total number
        of these positions, so it's also a percentage.

    make_yaxes_comparable: bool
        If True, sets the ylim max to just over the max y-value across all seqs' plots (so the y-range is the
        same for each row). Could be useful, could be bad if the values wildly differ btwn seqs (in which case
        the seqs with lower values could be hard to read)

    output_dict: bool
        If True, return the dict mapping each seq to its list of plotted bar values.

    show_xlabel, show_ylabel: bool
        Whether the x-axis label (only ever drawn on the middle plot of the row) and the y-axis label
        (only ever drawn on the first plot of the row) may be shown.

    Returns
    -------
    dict or None
        seq -> list of bar values when output_dict is True; otherwise None.
    """
    if normalize:
        ylabel = r"$\dfrac{\mathrm{Number\ of\ Mutated\ Positions}}{\mathrm{Number\ of\ Positions}}$"
    format_yaxis = "percentages" if normalize else "thousands"

    seq2vals = {}
    max_val = 0
    for seq in SEQS:
        mutation_vals = _codon_position_mutation_vals(seq, p, add_noncoding_col)
        # Number of mutated positions in each category
        vals = [sum(m) for m in mutation_vals]
        if normalize:
            # Divide, to get (# mutated positions) / (# positions)
            # Note that we DON'T divide by just gene length / 3 (which would work for CP 1/2/3). Two reasons
            # for this:
            # 1. Overlapping genes can mess with this
            # 2. For non-coding positions (if we're adding a 4th col for these), this doesn't make sense!
            # An empty category has no mutated positions, so plot 0 instead of dividing by zero.
            vals = [v / len(m) if m else 0.0 for v, m in zip(vals, mutation_vals)]
        max_val = max(max_val, max(vals))
        seq2vals[seq] = vals

    # Delay creating a histogram for a given genome until we've computed values for all of the genomes --
    # this lets us be fancy and set all histograms in a row to the same max value on the y-axis if needed
    shared_max = max_val if make_yaxes_comparable else None
    # Only the middle plot of the row gets an x-axis label (index 1 when there are 3 seqs).
    xlabel_index = len(SEQS) // 2
    for i, seq in enumerate(SEQS):
        histogram_of_123(
            seq2vals[seq],
            "{}: {}".format(seq2name[seq], title),
            axes[i],
            ylabel=ylabel,
            format_yaxis=format_yaxis,
            max_val=shared_max,
            show_ylabel=(i == 0) and show_ylabel,
            show_xlabel=(i == xlabel_index) and show_xlabel,
        )

    if output_dict:
        return seq2vals
    return None
        
def histogram_matrix_maker(percentage_thresholds, normalize=False, make_yaxes_comparable=False, figfilename=None):
    """Draw a grid of codon-position histograms: one row per p threshold, one column per seq in SEQS.

    Parameters
    ----------
    percentage_thresholds: sized sequence
        Thresholds for p, given as percentages. Each is converted via get_p2pct() before being passed
        to histogram_maker(), and is shown in its row's titles as "$p$ = <value>%".
    normalize, make_yaxes_comparable: bool
        Passed through to histogram_maker() for every row.
    figfilename: str or None
        If given, the figure is saved to figs/<figfilename>.

    Raises
    ------
    ValueError
        If percentage_thresholds is empty.
    """
    num_rows = len(percentage_thresholds)
    if num_rows == 0:
        raise ValueError("At least one percentage threshold is needed.")

    # squeeze=False keeps axes 2-D even with a single row, so axes[row, :] always works.
    fig, axes = pyplot.subplots(
        num_rows,
        len(SEQS),
        squeeze=False,
        gridspec_kw={"hspace": 0.5, "wspace": 0.35}
    )
    # Row that gets the y-axis label: the exact middle row for an odd number of rows, and the
    # lower of the two middle rows for an even number.
    middle_row = num_rows // 2
    last_row = num_rows - 1
    p2pct = get_p2pct(percentage_thresholds)
    for row, p in enumerate(percentage_thresholds):
        histogram_maker(
            p2pct[p],
            f"$p$ = {p}%",
            axes=axes[row, :],
            normalize=normalize,
            make_yaxes_comparable=make_yaxes_comparable,
            show_xlabel=(row == last_row),
            show_ylabel=(row == middle_row),
        )

    title = "$p$-mutation frequencies across coding and non-coding positions"
    titley = 0.93
    if normalize:
        title += "\n(each bar normalized by total number of positions in that category)"
        # a bit extra y space to account for extra line
        titley = 0.94
    fig.suptitle(title, y=titley, x=0.5, fontsize=20)
    fig.set_size_inches(15, 20)
    if figfilename is not None:
        fig.savefig("figs/{}".format(figfilename), bbox_inches="tight")