# Pacbio consensus UMI
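This notebook calls consensus mutations on PacBio data for each cell_barcode-gene, combining the per-UMI consensus mutations (grouped by cell_barcode, gene, and UMI) from the input CSV. Then, it exports a processed CSV with the following columns:
* cell_barcode
* gene
* consensus_mutations
* support
* total_UMI
* consensus_mutations_annotated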

In [None]:
import Bio.Seq
import Bio.SeqIO

from IPython.display import display

from dms_variants.constants import CBPALETTE

import numpy as np

import pandas as pd

import plotnine as p9

In [None]:
# Inputs, wildcards and outputs come from the `snakemake` object, which is not
# defined in this notebook (presumably injected when Snakemake runs it).
consensus_UMI_csv, viral_genbank = (
    snakemake.input.consensus_UMI_csv,
    snakemake.input.viral_genbank,
)
expt = snakemake.wildcards.expt
consensus_gene_csv = snakemake.output.consensus_gene_csv

Style parameters:

In [None]:
# Make the classic plotnine theme the default for every plot below.
classic_theme = p9.theme_classic()
p9.theme_set(classic_theme)

## Load Data

In [None]:
# Load the per-UMI consensus mutations; empty `mutation` entries become the
# string 'None'.
# NOTE(review): unlike "WT", a 'None' mutation is not excluded below, and it
# would fail to parse in `annotate_mutations` if it were ever called as a
# consensus. Confirm that the upstream CSV never leaves `mutation` empty.
mutations = pd.read_csv(consensus_UMI_csv)
mutations = mutations.assign(mutation=mutations['mutation'].fillna('None'))
display(mutations)

## Process Data

Generate list of all cell_barcode-gene:

In [None]:
# Every distinct cell_barcode-gene pair present in the input.
cb_gene = mutations[['cell_barcode', 'gene']].drop_duplicates()
display(cb_gene)

Count total UMI for cell_barcode-gene:

In [None]:
# Number of distinct UMIs (WT included) observed for each cell_barcode-gene.
total_UMI_df = (
    mutations
    .groupby(['cell_barcode', 'gene'])['UMI']
    .nunique()
    .rename('total_UMI')
    .reset_index()
)
display(total_UMI_df)

Exclude WT UMI:

In [None]:
# Keep only rows whose mutation is not "WT"; these are the UMIs tallied below.
mutations_noWT = mutations.loc[mutations['mutation'].ne('WT')]

display(mutations_noWT)

Count number of UMI supporting each mutation for cell_barcode-gene:

In [None]:
# For each cell_barcode-gene-mutation, count the distinct UMIs carrying it.
mutation_UMI_df = (
    mutations_noWT
    .groupby(['cell_barcode', 'gene', 'mutation'])['UMI']
    .nunique()
    .rename('mutation_UMI')
    .reset_index()
)
display(mutation_UMI_df)

Bring in total UMI counts and calculate fraction of total UMI with each mutation:

In [None]:
# Attach each cell_barcode-gene's total UMI count to its mutation rows, then
# express support for each mutation as a fraction and as an "n/total" string.
mutation_frac_df = (
    mutation_UMI_df
    .merge(total_UMI_df,
           on=['cell_barcode', 'gene'],
           how='left',
           validate='many_to_one')
    .assign(
        frac_UMI=lambda df: df['mutation_UMI'] / df['total_UMI'],
        support=lambda df: (df['mutation_UMI'].astype(str)
                            + '/'
                            + df['total_UMI'].astype(str)),
    )
)
display(mutation_frac_df)

**Call Consensus Mutations**  
Label mutation as consensus if it is found in more than half of UMIs and at least 2 UMIs. These correspond to `frac_UMI > 0.5` and `mutation_UMI >= 2` in our dataframe.

In [None]:
# A mutation is consensus when a strict majority of the cell_barcode-gene's
# UMIs carry it AND it is seen in at least two UMIs.
mutation_frac_df = mutation_frac_df.assign(
    consensus=lambda df: df['frac_UMI'].gt(0.5) & df['mutation_UMI'].ge(2))
display(mutation_frac_df)

Plot outcome of consensus calling

In [None]:
# Histogram of mutation fractions, filled by whether each mutation was called
# as consensus (WT UMIs were excluded before counting).
frac_plot_df = (
    mutation_frac_df
    .loc[:, ['cell_barcode', 'gene', 'mutation', 'frac_UMI', 'consensus']]
    .drop_duplicates()
)
mutation_frac_histo = (
    p9.ggplot(frac_plot_df, p9.aes(x='frac_UMI', fill='consensus'))
    + p9.geom_histogram(bins=20)
    + p9.ggtitle(f'Mutation fractions\n(excludes WT UMI)\n{expt}')
    + p9.labs(x='fraction of total UMI for cell_barcode-gene')
    + p9.theme(
        figure_size=(4, 3),
        plot_title=p9.element_text(size=9),
        axis_title=p9.element_text(size=9),
        legend_title=p9.element_text(size=9),
        legend_title_align='center',
    )
    + p9.scale_fill_manual([CBPALETTE[1], CBPALETTE[0]])
)

display(mutation_frac_histo)

## Merge mutant and WT UMIs
Merge data into single dataframe, `gene_mutations`.  
Steps:  
1. Filter `mutation_frac_df` for only consensus mutations
2. Merge with `cb_gene` dataframe so every `cell_barcode-gene` has at least one row.
3. Fill `mutation` column with "WT" if a `cell_barcode-gene` does not have any consensus mutations. 

In [None]:
# Merge in consensus mutations: every cell_barcode-gene keeps at least one row.
# Pairs without a consensus mutation get NaN in the mutation-level columns.
consensus_only = mutation_frac_df.loc[mutation_frac_df['consensus']]
gene_mutations = cb_gene.merge(
    consensus_only,
    on=['cell_barcode', 'gene'],
    how='left',
    validate='one_to_many',
)

# Pairs left without any consensus mutation are labelled wild type.
gene_mutations['mutation'] = gene_mutations['mutation'].fillna('WT')

display(gene_mutations)

Check that every `cell_barcode-gene` is represented in the final `gene_mutations` dataframe.

In [None]:
# Every cell_barcode-gene from the input must still be present after merging.
assert (gene_mutations.drop_duplicates(subset=['cell_barcode', 'gene']).shape[0]
        == cb_gene.shape[0]), "Missing cell_barcode-gene from df"

## Plot outcomes

Plot distribution of total UMI per cell_barcode-gene:

In [None]:
# Histogram of how many distinct UMIs were observed per cell_barcode-gene.
# Each observation is one cell_barcode-gene row of total_UMI_df, so the title
# names that unit (matching the y-axis label).
total_UMI_histo = (
    p9.ggplot(
        total_UMI_df,
        p9.aes(x='total_UMI')) +
    p9.geom_histogram(bins=20) +
    p9.ggtitle('n UMI per cell_barcode-gene\n'
               f'{expt}') +
    p9.labs(x='n UMI',
            y='n cell_barcode-gene') +
    p9.theme(figure_size=(4, 3),
             plot_title=p9.element_text(size=9),
             axis_title=p9.element_text(size=9),
             legend_title=p9.element_text(size=9),
             legend_title_align='center'))

display(total_UMI_histo)

Plot distribution of UMI per mutation:

In [None]:
# Histogram of UMI counts per cell_barcode-gene-mutation (WT UMIs excluded).
UMI_plot_df = (
    mutation_UMI_df
    .loc[:, ['cell_barcode', 'gene', 'mutation', 'mutation_UMI']]
    .drop_duplicates()
)
mutation_UMI_histo = (
    p9.ggplot(UMI_plot_df, p9.aes(x='mutation_UMI'))
    + p9.geom_histogram(bins=20)
    + p9.ggtitle(f'n UMI per mutation\n(excludes WT UMI)\n{expt}')
    + p9.labs(x='n UMI', y='n cell_barcode-gene-mutation')
    + p9.theme(
        figure_size=(4, 3),
        plot_title=p9.element_text(size=9),
        axis_title=p9.element_text(size=9),
        legend_title=p9.element_text(size=9),
        legend_title_align='center',
    )
)

display(mutation_UMI_histo)

Plot fraction of UMI for each consensus mutation:

In [None]:
# Histogram of the UMI fraction supporting each consensus mutation.
# cell_barcode-gene pairs without a consensus mutation (labelled "WT") have
# NaN frac_UMI from the left merge. They are dropped explicitly here, instead
# of relying on plotnine to discard non-finite values with a warning.
# No colour/fill aesthetic is mapped, so no manual colour scale is needed.
consensus_mutations_histo = (
    p9.ggplot(
        (gene_mutations
         .dropna(subset=['frac_UMI'])
         [['cell_barcode', 'gene', 'mutation', 'frac_UMI']]
         .drop_duplicates()),
        p9.aes(x='frac_UMI')) +
    p9.geom_histogram(bins=20) +
    p9.ggtitle('Consensus mutation fractions\n'
               '(excludes WT genes)\n'
               f'{expt}') +
    p9.labs(x='fraction of total UMI for cell_barcode-gene') +
    p9.theme(figure_size=(4, 3),
             plot_title=p9.element_text(size=9),
             axis_title=p9.element_text(size=9),
             legend_title=p9.element_text(size=9),
             legend_title_align='center'))

display(consensus_mutations_histo)

## Organize output and annotate protein mutations 
Integrate mutations into genotype

In [None]:
# Collapse each cell_barcode-gene's consensus mutations, and their "n/total"
# support strings, into space-separated strings in matching order.
genotypes = (
    gene_mutations
    .groupby(['cell_barcode', 'gene'])
    # Select the columns with a list. Tuple-style selection
    # (`['mutation', 'support']`) on a GroupBy is deprecated in pandas 1.x
    # and raises ValueError in pandas >= 2.0.
    [['mutation', 'support']]
    .agg(list)
    .reset_index()
    .rename(columns={'mutation': 'consensus_mutations'})
)
# NOTE: pairs without consensus mutations have NaN support (from the left
# merge that built gene_mutations), so their support string becomes 'nan'.
genotypes['consensus_mutations'] = [
    ' '.join(map(str, muts)) for muts in genotypes['consensus_mutations']]
genotypes['support'] = [
    ' '.join(map(str, sups)) for sups in genotypes['support']]

display(genotypes)

Bring in support information (total UMI per `cell_barcode-gene`):

In [None]:
# Attach the total UMI count for each cell_barcode-gene. validate='one_to_one'
# ensures the keys are unique on both sides.
output_df = genotypes.merge(
    total_UMI_df,
    how='left',
    on=['cell_barcode', 'gene'],
    validate='one_to_one',
)

display(output_df)

Annotate mutations as to synonymous / nonsynonymous:

In [None]:
def _first_cds_seq(record):
    """Return the nucleotide sequence of the first CDS feature on ``record``."""
    cds_features = [feature for feature in record.features
                    if feature.type == "CDS"]
    return cds_features[0].extract(record.seq)


# Map each GenBank record id to the sequence of its first annotated CDS.
viral_cds = {
    record.id: _first_cds_seq(record)
    for record in Bio.SeqIO.parse(viral_genbank, "genbank")
}

def annotate_mutations(row):
    """Annotate a row's consensus nucleotide mutations with their coding effect.

    Parameters
    ----------
    row : pandas.Series
        Must provide ``gene`` (a key of the module-level ``viral_cds``) and
        ``consensus_mutations`` (space-separated mutations, or ``"WT"``).

    Returns
    -------
    str
        ``"WT"`` unchanged. Otherwise, the mutations space-separated, each
        substitution (``<wt nt><position><mut nt>``, with the position
        1-based within the CDS) suffixed with ``_synonymous``,
        ``_<wt aa><aa position><mut aa>``, or ``_noncoding`` when the
        position falls outside the CDS. Mutations starting with ``ins`` or
        ``del`` are passed through unannotated.

    Raises
    ------
    KeyError
        If ``gene`` is not in ``viral_cds``.
    AssertionError
        If the reference nucleotide does not match, or a substitution does
        not change exactly one amino acid when it is nonsynonymous.
    """
    gene = row["gene"]
    # Looked up before the WT check so unknown genes still fail loudly.
    cds = viral_cds[gene]
    muts = row["consensus_mutations"]
    if muts == "WT":
        return muts
    # Only translate the reference when there is something to annotate.
    prot = str(cds.translate())
    annotated_muts = []
    for mut in muts.split():
        if mut.startswith(("ins", "del")):
            annotated_muts.append(mut)
            continue
        wt = mut[0]
        i = int(mut[1:-1])
        m = mut[-1]
        if not 1 <= i <= len(cds):
            annotated_muts.append(f"{mut}_noncoding")
            continue
        mut_cds = list(str(cds))
        assert mut_cds[i - 1] == wt, (
            f"{gene} {mut}: reference has {mut_cds[i - 1]} at site {i}")
        mut_cds[i - 1] = m
        mut_prot = str(Bio.Seq.Seq("".join(mut_cds)).translate())
        if prot == mut_prot:
            annotated_muts.append(f"{mut}_synonymous")
        else:
            aamut = [
                f"{x}{r}{y}"
                for r, (x, y) in enumerate(zip(prot, mut_prot), start=1)
                if x != y
            ]
            # A single-nucleotide substitution alters exactly one codon.
            assert len(aamut) == 1, (
                f"{gene} {mut}: expected one amino-acid change, got {aamut}")
            annotated_muts.append(f"{mut}_{aamut[0]}")
    return " ".join(annotated_muts)


# Add the coding-effect annotation of each row's consensus mutations.
output_df = output_df.assign(
    consensus_mutations_annotated=output_df.apply(annotate_mutations, axis=1)
)

output_df

Double check that every cell_barcode-gene is represented in the final `output_df`.

In [None]:
# Every input cell_barcode-gene must survive into the final table.
assert (output_df[['cell_barcode', 'gene']].drop_duplicates().shape[0]
        == cb_gene.shape[0]), "Missing cell_barcode-gene from df"

Make sure no NA values are included in the final `output_df`. Every row should either have a consensus mutation or be annotated as `"WT"`.

In [None]:
# Require every cell to be non-null: `.all()` per column, then across columns.
# (`.any().any()` would only check that at least one value is non-null.)
assert output_df.notnull().all().all(), \
    "Found null value in output_df"

In [None]:
# Write the final per-cell_barcode-gene table without the pandas index.
print('Saving gene consensus mutations to', consensus_gene_csv)
output_df.to_csv(consensus_gene_csv, index=False)