# Filter viral barcodes in transcripts
This notebook filters viral barcodes in 10X transcriptome data to remove UMIs that are likely derived from leaked transcripts.

## Notebook setup
Import Python modules:

In [None]:
from IPython.display import display

import pandas as pd

import plotnine as p9

Get `snakemake` variables [as described here](https://snakemake.readthedocs.io/en/stable/snakefiles/rules.html#jupyter-notebook-integration):

In [None]:
# Input files supplied by the snakemake rule.
viral_tag_by_cell_csv = snakemake.input.viral_tag_by_cell_csv
viral_bc_by_cell_corrected_csv = snakemake.input.viral_bc_by_cell_corrected_csv

# Output files this notebook is expected to create.
viral_bc_by_cell_filtered_csv = snakemake.output.viral_bc_by_cell_filtered_csv
plot = snakemake.output.plot

# Wildcards and parameters for this run.
expt = snakemake.wildcards.expt
barcoded_viral_genes = snakemake.params.barcoded_viral_genes

## Organize data

Read the viral barcode UMI counts data into a pandas dataframe:

In [None]:
# Load the viral barcode counts and give the generic `count` column a
# descriptive name.
viral_bc_counts = (
    pd.read_csv(viral_bc_by_cell_corrected_csv)
    .rename(columns={'count': 'viral_bc_UMIs'})
)

# The genes found in the counts must be exactly the configured barcoded genes.
assert set(barcoded_viral_genes) == set(viral_bc_counts['gene']), (
    "Barcoded genes in barcode counts do not match expectation."
)
display(viral_bc_counts)

Read the total number of UMIs per cell into a pandas dataframe. Only keep relevant columns.

In [None]:
# Load the per-cell table and keep only the columns used downstream.
all_cells = (
    pd.read_csv(viral_tag_by_cell_csv)
    .loc[:, ['cell_barcode',
             'infected',
             'infecting_viral_tag',
             'total_UMIs',
             'viral_UMIs',
             'cellular_UMIs',
             'frac_viral_UMIs']]
)
display(all_cells)

Sanity check that `total_UMIs` is equal to `viral_UMIs + cellular_UMIs`:

In [None]:
# Every cell's total must equal its viral plus cellular UMI counts.
assert bool(
    all_cells['total_UMIs']
    .eq(all_cells['viral_UMIs'] + all_cells['cellular_UMIs'])
    .all()
), "UMI counts do not add up"

Merge dataframes:

In [None]:
# Replicate the per-cell table once per barcoded gene so that every
# (cell, gene) pair is present on the left side, then attach the barcode
# counts. The outer join keeps pairs with no observed barcode (their
# barcode columns are missing) and `validate` enforces that each
# (cell, gene) pair occurs only once on the left.
viral_bc_frac = pd.merge(
    left=pd.concat([all_cells.assign(gene=gene)
                    for gene in barcoded_viral_genes]),
    right=viral_bc_counts,
    how='outer',
    on=['cell_barcode', 'gene'],
    validate='one_to_many')

# Compare as sets: an outer merge sorts the join keys, so the order of
# barcodes in the merged frame need not match the order in `all_cells`,
# and an elementwise array comparison breaks when the lengths differ.
assert (set(viral_bc_frac['cell_barcode']) ==
        set(all_cells['cell_barcode'])), \
       "Cell barcodes in merged dataframe don't " \
       "match barcodes in source data."
assert (viral_bc_frac['viral_barcode'].nunique() ==
        viral_bc_counts['viral_barcode'].nunique()), \
       "Number of viral barcodes in merged dataframe don't " \
       "match number of barcodes in source data."

# Give `infecting_viral_tag` a fixed category order for plotting.
# `set_categories` (unlike `reorder_categories`) does not fail when one of
# the tags is absent from this dataset; because it would silently turn
# unknown tags into missing values, check for those explicitly first.
viral_tag_order = ['none', 'wt', 'syn', 'both']
unexpected_tags = (set(viral_bc_frac['infecting_viral_tag'].dropna()) -
                   set(viral_tag_order))
assert not unexpected_tags, \
       f"Unexpected values in infecting_viral_tag: {unexpected_tags}"
viral_bc_frac['infecting_viral_tag'] = (viral_bc_frac['infecting_viral_tag']
                                        .astype('category')
                                        .cat
                                        .set_categories(viral_tag_order))

display(viral_bc_frac)

Calculate **each barcode's** fraction of all UMIs per cell:

In [None]:
# Treat missing barcode counts as zero UMIs and store them as integers,
# then express each barcode's count as a fraction of the cell's total UMIs.
viral_bc_frac = viral_bc_frac.assign(
    viral_bc_UMIs=lambda df: (
        df['viral_bc_UMIs'].fillna(0).astype(int, errors='raise')
    ),
    frac_viral_bc_UMIs=lambda df: df['viral_bc_UMIs'] / df['total_UMIs'],
)

display(viral_bc_frac)

**For each cell, for each gene,** sum the number of viral barcode UMIs.

In [None]:
# Broadcast the per-(cell, gene) total of barcode UMIs back onto every row
# belonging to that cell-gene pair.
viral_bc_frac = viral_bc_frac.assign(
    sum_UMIs_with_viral_bc_for_cell_and_gene=(
        viral_bc_frac
        .groupby(['cell_barcode', 'gene'])['viral_bc_UMIs']
        .transform('sum')
    )
)

display(viral_bc_frac)

For each cell-gene combination, calculate the fraction of all UMIs and the fraction of viral UMIs that carry a viral barcode.

In [None]:
# Express the per cell-gene barcoded UMI total relative to the cell's
# total UMIs and relative to its viral UMIs.
viral_bc_frac = viral_bc_frac.assign(
    frac_total_UMIs_with_viral_bc_for_cell_and_gene=lambda df: (
        df['sum_UMIs_with_viral_bc_for_cell_and_gene']
        .div(df['total_UMIs'])
    ),
    frac_viral_UMIs_with_viral_bc_for_cell_and_gene=lambda df: (
        df['sum_UMIs_with_viral_bc_for_cell_and_gene']
        .div(df['viral_UMIs'])
    ),
)

# Show the first rows of each cell-gene group.
cell_gene_groups = viral_bc_frac.groupby(['cell_barcode', 'gene'])
display(cell_gene_groups.head())

## Plots
Set base plot style:

In [None]:
# Use plotnine's classic theme as the default for the figures below.
base_theme = p9.theme_classic()
p9.theme_set(base_theme)

### Per cell metrics
Summary figures that address number and fraction of barcoded viral UMIs **in aggregate for each cell and gene**.

Plot number of UMIs that contain a viral barcode per cell-gene combination and fraction of UMIs that contain a viral barcode per cell-gene combination.

In [None]:
# Histogram of barcoded UMI totals with one observation per cell-gene pair,
# faceted by gene, with a log-scaled count axis.
fig = (
    p9.ggplot(
        data=viral_bc_frac.drop_duplicates(subset=['cell_barcode', 'gene']),
        mapping=p9.aes(x='sum_UMIs_with_viral_bc_for_cell_and_gene'),
    )
    + p9.geom_histogram(bins=20, fill="#3a3a3a")
    + p9.facet_grid('~gene')
    + p9.ggtitle(f'{expt}')
    + p9.labs(x='sum of UMIs with viral bc per cell-gene')
    + p9.scale_y_log10()
    + p9.theme(
        figure_size=(4, 2),
        plot_title=p9.element_text(size=9),
        axis_title=p9.element_text(size=9),
        legend_title=p9.element_text(size=9),
        legend_title_align='center',
    )
)
display(fig)

In [None]:
# Histogram of the barcoded fraction of all UMIs with one observation per
# cell-gene pair, faceted by gene, with a log-scaled count axis.
fig = (
    p9.ggplot(
        data=viral_bc_frac.drop_duplicates(subset=['cell_barcode', 'gene']),
        mapping=p9.aes(x='frac_total_UMIs_with_viral_bc_for_cell_and_gene'),
    )
    + p9.geom_histogram(bins=20, fill="#3a3a3a")
    + p9.facet_grid('~gene')
    + p9.ggtitle(f'{expt}')
    + p9.labs(x='fraction of UMIs with viral bc per cell-gene')
    + p9.scale_y_log10()
    + p9.theme(
        figure_size=(4, 2),
        plot_title=p9.element_text(size=9),
        axis_title=p9.element_text(size=9),
        legend_title=p9.element_text(size=9),
        legend_title_align='center',
    )
)
display(fig)

Plot fraction of barcoded UMIs per cell-gene as a function of total UMI count:

In [None]:
# Scatter of the barcoded fraction of all UMIs against total UMIs per cell,
# faceted by gene. Both quantities are per cell-gene values repeated on
# every barcode row, so keep one row per cell-gene pair; otherwise each
# point is overdrawn once per barcode and the alpha shading reflects the
# number of barcodes rather than the number of cells.
fig = (p9.ggplot((viral_bc_frac
                  .drop_duplicates(subset=['cell_barcode', 'gene'])),
                 p9.aes(x='total_UMIs',
                        y='frac_total_UMIs_with_viral_bc_for_cell_and_gene')) +
       p9.geom_point(alpha=0.01) +
       p9.facet_grid('~gene') +
       p9.ggtitle(f'{expt}') +
       p9.labs(x='total UMIs per cell',
               y='fraction of UMIs with viral bc\nper cell-gene') +
       p9.theme(figure_size=(4, 2),
                plot_title=p9.element_text(size=9),
                axis_title=p9.element_text(size=9),
                legend_title=p9.element_text(size=9),
                legend_title_align='center'))
display(fig)

Plot fraction of barcoded UMIs per cell-gene as a function of viral UMI count and viral UMI fraction:

In [None]:
# Scatter of the barcoded fraction of all UMIs against viral UMIs per cell,
# faceted by gene. Both quantities are per cell-gene values repeated on
# every barcode row, so keep one row per cell-gene pair; otherwise each
# point is overdrawn once per barcode and the alpha shading reflects the
# number of barcodes rather than the number of cells.
fig = (p9.ggplot((viral_bc_frac
                  .drop_duplicates(subset=['cell_barcode', 'gene'])),
                 p9.aes(x='viral_UMIs',
                        y='frac_total_UMIs_with_viral_bc_for_cell_and_gene')) +
       p9.geom_point(alpha=0.01) +
       p9.facet_grid('~gene') +
       p9.ggtitle(f'{expt}') +
       p9.labs(x='viral UMIs per cell',
               y='fraction of UMIs with viral bc\nper cell-gene') +
       p9.theme(figure_size=(4, 2),
                plot_title=p9.element_text(size=9),
                axis_title=p9.element_text(size=9),
                legend_title=p9.element_text(size=9),
                legend_title_align='center'))
display(fig)

In [None]:
# Scatter of the barcoded fraction of all UMIs against the viral UMI
# fraction per cell, faceted by gene. Both quantities are per cell-gene
# values repeated on every barcode row, so keep one row per cell-gene
# pair; otherwise each point is overdrawn once per barcode and the alpha
# shading reflects the number of barcodes rather than the number of cells.
fig = (p9.ggplot((viral_bc_frac
                  .drop_duplicates(subset=['cell_barcode', 'gene'])),
                 p9.aes(x='frac_viral_UMIs',
                        y='frac_total_UMIs_with_viral_bc_for_cell_and_gene')) +
       p9.geom_point(alpha=0.01) +
       p9.facet_grid('~gene') +
       p9.ggtitle(f'{expt}') +
       p9.labs(x='viral UMI fraction per cell',
               y='fraction of UMIs with viral bc\nper cell-gene') +
       p9.theme(figure_size=(4, 2),
                plot_title=p9.element_text(size=9),
                axis_title=p9.element_text(size=9),
                legend_title=p9.element_text(size=9),
                legend_title_align='center'))
display(fig)

Plot relationship between viral UMI fraction and fraction of viral UMIs with barcode:

In [None]:
# Scatter of the barcoded fraction of viral UMIs against the viral UMI
# fraction per cell, faceted by gene. Both quantities are per cell-gene
# values repeated on every barcode row, so keep one row per cell-gene
# pair; otherwise each point is overdrawn once per barcode and the alpha
# shading reflects the number of barcodes rather than the number of cells.
fig = (p9.ggplot((viral_bc_frac
                  .drop_duplicates(subset=['cell_barcode', 'gene'])),
                 p9.aes(x='frac_viral_UMIs',
                        y='frac_viral_UMIs_with_viral_bc_for_cell_and_gene')) +
       p9.geom_point(alpha=0.01) +
       p9.facet_grid('~gene') +
       p9.ggtitle(f'{expt}') +
       # Axis label previously read "bcper" (missing space).
       p9.labs(x='viral UMI fraction per cell',
               y='fraction of viral UMIs\nwith viral bc per cell-gene') +
       p9.theme(figure_size=(4, 2),
                plot_title=p9.element_text(size=9),
                axis_title=p9.element_text(size=9),
                legend_title=p9.element_text(size=9),
                legend_title_align='center'))
display(fig)

### Per barcode metrics
Figures that show the outcomes **for each barcode individually.**

Plot distribution of each barcode's fraction of a cell's total UMIs.

In [None]:
# Histogram of each barcode row's fraction of the cell's total UMIs,
# faceted by infecting viral tag (rows) and gene (columns), with a
# log-scaled count axis.
fig = (
    p9.ggplot(data=viral_bc_frac,
              mapping=p9.aes(x='frac_viral_bc_UMIs'))
    + p9.geom_histogram(bins=50, fill="#3a3a3a")
    + p9.facet_grid('infecting_viral_tag~gene')
    + p9.ggtitle('viral barcode distribution\n'
                 f'{expt}')
    + p9.labs(x='each viral barcode\'s fraction of total UMIs')
    + p9.scale_y_log10()
    + p9.theme(
        figure_size=(6, 4),
        plot_title=p9.element_text(size=10),
        axis_title=p9.element_text(size=10),
        legend_title=p9.element_text(size=9),
        legend_title_align='center',
    )
)
display(fig)

Plot relationship between viral burden (frac_viral_UMIs) and each viral barcode's fraction of total UMIs:

In [None]:
# Scatter of the cell's viral UMI fraction against each barcode row's
# fraction of total UMIs, on log-log axes, faceted by infecting viral tag
# (rows) and gene (columns).
fig = (
    p9.ggplot(data=viral_bc_frac,
              mapping=p9.aes(x='frac_viral_bc_UMIs',
                             y='frac_viral_UMIs'))
    + p9.geom_point(alpha=0.05)
    + p9.facet_grid('infecting_viral_tag~gene')
    + p9.ggtitle('correlation between viral barcode fraction\n'
                 'and viral burden\n'
                 f'{expt}')
    + p9.labs(x='each viral barcode\'s fraction of total UMIs',
              y='viral UMI fraction per cell')
    + p9.scale_y_log10()
    + p9.scale_x_log10()
    + p9.theme(
        figure_size=(6, 4),
        plot_title=p9.element_text(size=10),
        axis_title=p9.element_text(size=10),
        legend_title=p9.element_text(size=9),
        legend_title_align='center',
    )
)
display(fig)

In [None]:
# Temporary cell to create the files expected by snakemake.
# Not the final version: only the figure currently bound to `fig` is
# saved, and the CSV holds three columns of the unfiltered table.

print(f"Saving plots to {plot}")
p9.ggsave(filename=plot, plot=fig, verbose=False)

print(f"Saving filtered barcodes to {viral_bc_by_cell_filtered_csv}")

(viral_bc_frac
 [['cell_barcode', 'gene', 'frac_viral_bc_UMIs']]
 .to_csv(viral_bc_by_cell_filtered_csv, index=False))