# Correct viral barcodes in progeny
This Python Jupyter notebook corrects errors in the viral barcode sequences detected in progeny samples. The input is a table of read counts for each viral barcode.

The notebook uses `UMI-tools` to correct the viral barcodes. For each progeny sample, UMI-tools groups similar barcodes and reports one corrected barcode per group. The notebook then maps each original viral barcode to its corrected barcode and sums the read counts of all original barcodes that share a corrected barcode. Finally, it exports the corrected viral barcode counts.

**Notes about UMI-tools**
* Uses the directional adjacency method, which has [been demonstrated on simulated data](https://cgatoxford.wordpress.com/2015/08/14/unique-molecular-identifiers-the-problem-the-solution-and-the-proof/) to estimate the true number of UMIs more accurately than other heuristics.
* Barcode sequences must be passed to UMI-tools as `bytes`, not `str`. See this explanation of the `b` prefix: https://stackoverflow.com/questions/6269765/what-does-the-b-character-do-in-front-of-a-string-literal
* The corrected barcode is the first barcode in each group list returned by the clusterer. See the UMI-tools API documentation: https://umi-tools.readthedocs.io/en/latest/API.html
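
To make the directional rule concrete, here is a simplified pure-Python sketch of directional adjacency clustering. It is not the UMI-tools implementation and does not call UMI-tools; it assumes equal-length barcodes and uses the edge rule of the directional method: barcode *a* absorbs barcode *b* when they differ at no more than `threshold` positions and `count(a) >= 2 * count(b) - 1`. Tie-breaking between equally abundant barcodes is an arbitrary choice here and may differ from UMI-tools. The barcodes and counts are invented.

```python
from collections import deque


def hamming(a, b):
    """Number of mismatched positions between two equal-length barcodes."""
    return sum(x != y for x, y in zip(a, b))


def directional_groups(counts, threshold=1):
    """Simplified sketch of directional adjacency clustering.

    An edge runs from barcode a to barcode b when they differ at
    <= threshold positions and counts[a] >= 2 * counts[b] - 1.
    Groups are grown from the most abundant barcodes, and each group
    lists its representative (most abundant) barcode first.
    """
    # Most abundant first; ties broken alphabetically (an arbitrary choice).
    order = sorted(counts, key=lambda bc: (-counts[bc], bc))
    edges = {
        a: [b for b in order
            if b != a and hamming(a, b) <= threshold
            and counts[a] >= 2 * counts[b] - 1]
        for a in order
    }
    assigned = set()
    groups = []
    for root in order:
        if root in assigned:
            continue
        # Breadth-first search along directed edges from this root.
        seen = {root}
        queue = deque([root])
        while queue:
            node = queue.popleft()
            for nbr in edges[node]:
                if nbr not in seen:
                    seen.add(nbr)
                    queue.append(nbr)
        # Keep only barcodes not already claimed by a more abundant group.
        members = sorted(seen - assigned, key=lambda bc: (-counts[bc], bc))
        assigned.update(members)
        groups.append(members)
    return groups


counts = {"AAAA": 100, "AAAT": 10, "TTTT": 50, "TTTA": 30}
print(directional_groups(counts))
# [['AAAA', 'AAAT'], ['TTTT'], ['TTTA']]
```

Here `TTTA` is one mismatch from `TTTT` but is not absorbed, because 50 < 2 × 30 − 1 = 59; with a count of 25 it would be (50 ≥ 49).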

## Notebook setup

Import Python modules:

In [None]:
import gzip

from IPython.display import display

from dms_variants.constants import CBPALETTE

import numpy as np

import pandas as pd

import plotnine as p9

from umi_tools import UMIClusterer

Get `snakemake` variables [as described here](https://snakemake.readthedocs.io/en/stable/snakefiles/rules.html#jupyter-notebook-integration):

In [None]:
viral_bc_in_progeny_csv = snakemake.input.viral_bc_in_progeny_csv
viral_bc_in_progeny_corrected_csv = snakemake.output.viral_bc_in_progeny_corrected_csv
plot = snakemake.output.plot
expt = snakemake.wildcards.expt
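
To run the notebook interactively outside Snakemake, one option is a minimal stand-in object that exposes the same attribute names read above. This is a sketch for interactive use only: the file paths and the `expt1` wildcard value are hypothetical placeholders, and `SimpleNamespace` mimics only attribute access, not the full Snakemake object.

```python
from types import SimpleNamespace

# Hypothetical stand-in for the object Snakemake injects into the notebook.
# Attribute names mirror those read above; paths and "expt1" are placeholders.
snakemake = SimpleNamespace(
    input=SimpleNamespace(
        viral_bc_in_progeny_csv="viral_bc_in_progeny.csv.gz"),
    output=SimpleNamespace(
        viral_bc_in_progeny_corrected_csv="viral_bc_in_progeny_corrected.csv",
        plot="viral_bc_correction.pdf"),
    wildcards=SimpleNamespace(expt="expt1"),
)
```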

Set plot style

In [None]:
p9.theme_set(p9.theme_classic())

Import barcode frequency data

In [None]:
viral_bc_df = pd.read_csv(gzip.open(viral_bc_in_progeny_csv))
display(viral_bc_df)
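
The later cells read the columns `source`, `tag`, `gene`, `replicate`, `barcode`, and `count`, and the one-to-one merge with the lookup table assumes each barcode appears only once per sample. A minimal sketch of an up-front check for both assumptions; the helper name `check_progeny_counts` and the toy values are hypothetical.

```python
import pandas as pd

SAMPLE_COLS = ["source", "tag", "gene", "replicate"]


def check_progeny_counts(df):
    """Raise ValueError if expected columns are missing or a barcode
    is listed more than once within a sample."""
    missing = set(SAMPLE_COLS + ["barcode", "count"]) - set(df.columns)
    if missing:
        raise ValueError(f"missing columns: {sorted(missing)}")
    dups = df.duplicated(subset=SAMPLE_COLS + ["barcode"])
    if dups.any():
        raise ValueError(f"{dups.sum()} duplicated barcode rows within samples")


# Toy table with hypothetical values.
toy_df = pd.DataFrame({
    "source": ["s1", "s1"], "tag": ["t1", "t1"], "gene": ["g1", "g1"],
    "replicate": [1, 1], "barcode": ["AAAA", "AAAT"], "count": [100, 10],
})
check_progeny_counts(toy_df)  # passes silently
```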

## Data processing

Cluster barcodes within each progeny sample (each unique combination of `source`, `tag`, `gene`, and `replicate`):

In [None]:
clusterer = UMIClusterer(cluster_method="directional")

lookup_list = []

for (source, tag, gene, replicate), df in (viral_bc_df
                                           .groupby(['source',
                                                     'tag',
                                                     'gene',
                                                     'replicate'])):
    # Convert the dataframe to a {barcode: count} dict, the input format umi_tools expects.
    viral_bc_dict = df.set_index('barcode')['count'].to_dict()

    # Convert barcode strings to bytes, the key type umi_tools requires.
    byte_dict = {}
    for key, value in viral_bc_dict.items():
        byte_dict[key.encode("utf-8")] = float(value)

    # Cluster barcodes
    bc_groups = clusterer(byte_dict, threshold=1)
    groups_df = pd.DataFrame(bc_groups)
    groups_df = (groups_df
                 .stack()
                 .str
                 .decode('utf-8')
                 .unstack())  # Convert bytes back to string
    groups_df = groups_df.rename(columns={0: 'corrected_viral_bc'})
    groups_df = groups_df.set_index('corrected_viral_bc', drop=False)

    # Generate lookup table for this sample
    temp_lookup_df = (groups_df.melt(ignore_index=False,
                                     value_name='original_viral_bc')
                      ['original_viral_bc']
                      .dropna()
                      .reset_index())
    temp_lookup_df['source'] = source
    temp_lookup_df['tag'] = tag
    temp_lookup_df['gene'] = gene
    temp_lookup_df['replicate'] = replicate
    lookup_list.append(temp_lookup_df)

lookup_df = pd.concat(lookup_list)
display(lookup_df)
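
The stack/decode/melt chain in the cell above turns each sample's list of groups into a two-column lookup table in which every original barcode, including the representative itself, points to the first barcode of its group. An equivalent and more explicit sketch, offered as an alternative rather than the notebook's code; the group list is invented and the helper name `groups_to_lookup` is hypothetical.

```python
import pandas as pd

# Hypothetical clusterer output: each group lists its representative first.
bc_groups = [[b"AAAA", b"AAAT"], [b"TTTT"], [b"TTTA"]]


def groups_to_lookup(groups):
    """Map every original barcode to the first (representative) barcode of its group."""
    records = [
        {"corrected_viral_bc": group[0].decode("utf-8"),
         "original_viral_bc": bc.decode("utf-8")}
        for group in groups
        for bc in group
    ]
    return pd.DataFrame(records, columns=["corrected_viral_bc", "original_viral_bc"])


lookup = groups_to_lookup(bc_groups)
print(lookup)
```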

Merge corrected barcode data with barcode frequency data.

In [None]:
viral_bc_df = pd.merge(viral_bc_df,
                       lookup_df,
                       how='outer',
                       left_on=['source', 'tag', 'gene', 'replicate', 'barcode'],
                       right_on=['source', 'tag', 'gene', 'replicate', 'original_viral_bc'],
                       validate='one_to_one')
assert viral_bc_df.isnull().sum().sum() == 0, \
       "Mismatch between barcode frequencies and " \
       "corrected barcodes dataframes"

viral_bc_df = viral_bc_df.drop('barcode', axis=1)
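
The null-count assertion above reports that a mismatch exists but not where. pandas' `merge` also accepts `indicator=True`, which adds a `_merge` column marking each row as `both`, `left_only`, or `right_only`. A sketch on invented data in which the lookup table is missing barcode `GGGG`:

```python
import pandas as pd

keys = ["source", "tag", "gene", "replicate"]

# Toy inputs (hypothetical values): the lookup is missing barcode "GGGG".
counts = pd.DataFrame({"source": "s1", "tag": "t1", "gene": "g1", "replicate": 1,
                       "barcode": ["AAAA", "AAAT", "GGGG"], "count": [100, 10, 5]})
lookup = pd.DataFrame({"source": "s1", "tag": "t1", "gene": "g1", "replicate": 1,
                       "corrected_viral_bc": ["AAAA", "AAAA"],
                       "original_viral_bc": ["AAAA", "AAAT"]})

merged = pd.merge(counts, lookup, how="outer",
                  left_on=keys + ["barcode"],
                  right_on=keys + ["original_viral_bc"],
                  validate="one_to_one", indicator=True)
# Rows not present in both tables are the ones the null-count assertion would flag.
unmatched = merged[merged["_merge"] != "both"]
print(unmatched[keys + ["barcode", "_merge"]])
```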

Sum the read counts of all original barcodes that map to the same corrected barcode, and attach that total to each original barcode.

In [None]:
viral_bc_df = pd.merge(viral_bc_df,
                       (viral_bc_df
                        .groupby(['source',
                                  'tag',
                                  'gene',
                                  'replicate',
                                  'corrected_viral_bc'])
                        ['count']  # sum only the numeric read counts, not barcode strings
                        .sum()
                        .reset_index()),
                       on=['source', 'tag', 'gene', 'replicate', 'corrected_viral_bc'],
                       suffixes=['_original_viral_bc', '_corrected_viral_bc'],
                       validate='many_to_one')
assert viral_bc_df.isnull().sum().sum() == 0, \
       "Mismatch between original barcode frequencies and " \
       "corrected barcodes frequencies dataframes"
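
The cell above computes group totals and merges them back onto each row. A sketch of an alternative that yields the same per-row totals without a second merge, using `groupby(...).transform("sum")`; the column names follow the notebook, while the data are invented.

```python
import pandas as pd

keys = ["source", "tag", "gene", "replicate", "corrected_viral_bc"]

# Toy merged table (hypothetical values): two original barcodes collapse onto AAAA.
df = pd.DataFrame({"source": "s1", "tag": "t1", "gene": "g1", "replicate": 1,
                   "original_viral_bc": ["AAAA", "AAAT", "TTTT"],
                   "corrected_viral_bc": ["AAAA", "AAAA", "TTTT"],
                   "count_original_viral_bc": [100, 10, 50]})

# Sum the count column within each corrected-barcode group and broadcast
# the group total back onto every row of that group.
df["count_corrected_viral_bc"] = (df.groupby(keys)["count_original_viral_bc"]
                                  .transform("sum"))
print(df[["original_viral_bc", "corrected_viral_bc", "count_corrected_viral_bc"]])
```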

Annotate original barcodes as corrected or not.

In [None]:
viral_bc_df['corrected'] = (viral_bc_df['corrected_viral_bc']
                            != viral_bc_df['original_viral_bc'])

display(viral_bc_df)

## Visualize Results

Plot the number of reads whose barcodes were corrected versus left unchanged in each sample, and save the plot for the report.

In [None]:
fig = (p9.ggplot(viral_bc_df, p9.aes(x='corrected',
                                     y='count_original_viral_bc',
                                     fill='corrected')) +
       p9.geom_bar(stat='identity') +
       p9.facet_grid('source+tag~gene+replicate') +
       p9.ggtitle('Number of reads corrected\n'
                  f'for {expt}') +
       p9.ylab('n reads corrected') +
       p9.theme(figure_size=(2*(viral_bc_df['gene'].nunique()),
                             1.5*(viral_bc_df['tag'].nunique()*
                                  viral_bc_df['source'].nunique())),
                axis_text_x=p9.element_text(angle=90),
                plot_title=p9.element_text(size=10),
                axis_title=p9.element_text(size=10),
                legend_title=p9.element_text(size=9),
                legend_title_align='center') +
       p9.scale_fill_manual(CBPALETTE[0:]))

# save plot
print(f"Saving plot to {plot}")
p9.ggsave(plot=fig, filename=plot, verbose=False)
print("Plot saved.")

# show plot
fig.draw()

Plot the fraction of reads whose barcodes were corrected **within each progeny sample**.

In [None]:
correction_ratios_df = (viral_bc_df
                        .groupby(['source',
                                  'tag',
                                  'gene',
                                  'replicate',
                                  'corrected'])
                        ['count_original_viral_bc']  # sum only the numeric read counts
                        .sum()
                        .reset_index()
                        .pivot_table(index=['source', 'tag', 'gene', 'replicate'],
                                     columns='corrected',
                                     values='count_original_viral_bc')
                        .reset_index()
                        .fillna(0))

correction_ratios_df['ratio'] = (correction_ratios_df[True] /
                                 (correction_ratios_df[True] +
                                  correction_ratios_df[False]))

display(correction_ratios_df)

fig = (p9.ggplot(correction_ratios_df, p9.aes(x='replicate', y='ratio')) +
       p9.geom_point() +
       p9.facet_grid('source~tag+gene') +
       p9.ggtitle('Fraction of reads corrected\n'
                  f'for {expt}') +
       p9.labs(x='replicate',
               y='frac reads corrected') +
       p9.theme(figure_size=(5,
                             1.25*viral_bc_df['source'].nunique()),
                plot_title=p9.element_text(size=10),
                axis_title=p9.element_text(size=10),
                legend_title=p9.element_text(size=9),
                legend_title_align='center',
                axis_text_x=p9.element_text(angle=90, hjust=1))
       )

# show plot
fig.draw()
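
The ratio above is, per sample, corrected reads / (corrected reads + uncorrected reads). If no read anywhere in the table was corrected, the pivot has no `True` column and `correction_ratios_df[True]` raises a `KeyError`. A sketch of a guarded version on invented data; the helper name `correction_ratios` and the column names `n_corrected` and `n_uncorrected` are illustrative, not part of the notebook.

```python
import pandas as pd

SAMPLE_COLS = ["source", "tag", "gene", "replicate"]


def correction_ratios(df):
    """Fraction of reads per sample whose barcode was corrected."""
    ratios = (df.pivot_table(index=SAMPLE_COLS, columns="corrected",
                             values="count_original_viral_bc", aggfunc="sum")
              # Make sure both columns exist even if no read (or every read)
              # in the table was corrected.
              .reindex(columns=[False, True], fill_value=0)
              .fillna(0)
              .rename(columns={False: "n_uncorrected", True: "n_corrected"})
              .reset_index())
    ratios["ratio"] = ratios["n_corrected"] / (ratios["n_corrected"]
                                               + ratios["n_uncorrected"])
    return ratios


# Toy per-barcode table (hypothetical values); replicate r2 has no corrected reads.
df = pd.DataFrame({"source": "s1", "tag": "t1", "gene": "g1",
                   "replicate": ["r1", "r1", "r2"],
                   "corrected": [False, True, False],
                   "count_original_viral_bc": [100, 10, 50]})
print(correction_ratios(df)[SAMPLE_COLS + ["ratio"]])
```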

## Export corrected barcode frequencies

In [None]:
export_df = (viral_bc_df
             [['source',
               'tag',
               'gene',
               'replicate',
               'corrected_viral_bc',
               'count_corrected_viral_bc']]
             .reset_index(drop=True)
             .drop_duplicates()
             .rename(columns={'corrected_viral_bc': 'barcode',
                              'count_corrected_viral_bc': 'count'})
             )

# Write to csv
export_df.to_csv(viral_bc_in_progeny_corrected_csv,
                 index=False)
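
Correction only moves reads between barcodes, so per-sample read totals should be the same before and after, and each corrected barcode should appear once per sample in the export. A minimal sanity-check sketch; the helper name `check_counts_conserved` and the toy tables are hypothetical. In the notebook it could be run on a fresh read of the input CSV and `export_df`.

```python
import pandas as pd

SAMPLE_COLS = ["source", "tag", "gene", "replicate"]


def check_counts_conserved(original_df, corrected_df):
    """Return True if per-sample read totals match and no corrected barcode
    is listed twice within a sample."""
    before = original_df.groupby(SAMPLE_COLS)["count"].sum().sort_index()
    after = corrected_df.groupby(SAMPLE_COLS)["count"].sum().sort_index()
    no_dups = not corrected_df.duplicated(SAMPLE_COLS + ["barcode"]).any()
    return before.equals(after) and no_dups


# Toy tables (hypothetical values): AAAT was corrected to AAAA.
original = pd.DataFrame({"source": "s1", "tag": "t1", "gene": "g1", "replicate": 1,
                         "barcode": ["AAAA", "AAAT", "TTTT"], "count": [100, 10, 50]})
corrected = pd.DataFrame({"source": "s1", "tag": "t1", "gene": "g1", "replicate": 1,
                          "barcode": ["AAAA", "TTTT"], "count": [110, 50]})
print(check_counts_conserved(original, corrected))  # True
```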