# Filter viral barcodes in transcripts
This notebook filters viral barcodes in 10X transcriptome data to remove UMIs that are likely derived from leaked transcripts.

## Notebook setup
Import python modules:

In [None]:
from IPython.display import display

from dms_variants.constants import CBPALETTE

import kneed

import pandas as pd

import plotnine as p9

Get `snakemake` variables [as described here](https://snakemake.readthedocs.io/en/stable/snakefiles/rules.html#jupyter-notebook-integration):

In [None]:
# Input files supplied by the snakemake rule
viral_tag_by_cell_csv = (snakemake
                         .input
                         .viral_tag_by_cell_csv)
viral_bc_by_cell_corrected_csv = (snakemake
                                  .input
                                  .viral_bc_by_cell_corrected_csv)

# Output files this notebook must create
viral_bc_by_cell_filtered_csv = (snakemake
                                 .output
                                 .viral_bc_by_cell_filtered_csv)
plot = snakemake.output.plot

# Experiment wildcard, used to label the figures
expt = snakemake.wildcards.expt


## Organize data

Read the viral barcode UMI counts data into a pandas dataframe:

In [None]:
# Load the viral barcode UMI counts and give the generic `count` column
# the more descriptive name `bc_UMIs`.
viral_bc_counts = (
    pd.read_csv(viral_bc_by_cell_corrected_csv)
    .rename(columns={'count': 'bc_UMIs'})
)
display(viral_bc_counts)

Read the total number of UMIs per cell into a pandas dataframe. Only keep relevant columns.

In [None]:
# Columns of the per-cell table that are used in this notebook
cell_columns = ['cell_barcode',
                'infected',
                'infecting_viral_tag',
                'total_UMIs',
                'viral_UMIs',
                'cellular_UMIs',
                'frac_viral_UMIs']

# Load the per-cell table and keep only those columns
all_cells = pd.read_csv(viral_tag_by_cell_csv)[cell_columns]
display(all_cells)

Sanity check that `total_UMIs` is equal to `viral_UMIs + cellular_UMIs`:

In [None]:
# Viral plus cellular UMIs must reproduce the total UMI count of every cell
viral_plus_cellular = all_cells['viral_UMIs'] + all_cells['cellular_UMIs']
totals_agree = all_cells['total_UMIs'].eq(viral_plus_cellular).all()
assert bool(totals_agree), "UMI counts do not add up"

Merge dataframes:

In [None]:
# Outer merge keeps every row of both tables, so cells without any viral
# barcode are retained (their barcode columns are NaN until filled below).
viral_bc_frac = pd.merge(left=viral_bc_counts,
                         right=all_cells,
                         how='outer',
                         on=['cell_barcode'])

# Make `infecting_viral_tag` a categorical with a fixed level order.
# `set_categories` is used instead of `reorder_categories` because the latter
# raises if any of the four tags is absent from this experiment. Values
# outside the expected set would silently become NaN, so check for them first.
viral_tag_order = ['none', 'wt', 'syn', 'both']
unexpected_tags = (set(viral_bc_frac['infecting_viral_tag'].dropna())
                   - set(viral_tag_order))
assert not unexpected_tags, f"unexpected viral tags: {unexpected_tags}"
viral_bc_frac['infecting_viral_tag'] = (viral_bc_frac['infecting_viral_tag']
                                        .astype('category')
                                        .cat
                                        .set_categories(viral_tag_order))

# Fill NaN values with meaningful description
viral_bc_frac['gene'] = viral_bc_frac['gene'].fillna('none')
viral_bc_frac['viral_barcode'] = viral_bc_frac['viral_barcode'].fillna('none')
viral_bc_frac['bc_UMIs'] = viral_bc_frac['bc_UMIs'].fillna(0)

display(viral_bc_frac)

Sum total number of barcoded UMIs per cell:

In [None]:
# Sum the barcode UMIs within each cell ...
bc_UMIs_per_cell = (viral_bc_frac
                    .groupby('cell_barcode')['bc_UMIs']
                    .sum()
                    .rename('bc_UMIs_in_cell')
                    .reset_index())

# ... and attach that per-cell total to every barcode row of the cell
viral_bc_frac = viral_bc_frac.merge(bc_UMIs_per_cell, on=['cell_barcode'])

display(viral_bc_frac)

Calculate the fraction of all UMIs and the fraction of viral UMIs in each cell that carry a viral barcode:

In [None]:
# Barcoded UMIs in the cell as a fraction of all UMIs in the cell
viral_bc_frac['frac_total_UMIs_wBC'] = (
    viral_bc_frac['bc_UMIs_in_cell'].div(viral_bc_frac['total_UMIs']))

# Barcoded UMIs in the cell as a fraction of the cell's viral UMIs
viral_bc_frac['frac_viral_UMIs_wBC'] = (
    viral_bc_frac['bc_UMIs_in_cell'].div(viral_bc_frac['viral_UMIs']))

display(viral_bc_frac)

Calculate **each barcode's** fraction of all UMIs per cell:

In [None]:
# UMIs of this individual barcode as a fraction of all UMIs in its cell
viral_bc_frac['frac_bc_UMIs'] = (
    viral_bc_frac['bc_UMIs'].div(viral_bc_frac['total_UMIs']))

display(viral_bc_frac)

## Plots
Set base plot style:

In [None]:
# Make plotnine's classic theme the default for all figures below
classic_theme = p9.theme_classic()
p9.theme_set(classic_theme)

### Per cell metrics
Summary figures that address number and fraction of barcoded viral UMIs **in aggregate for each cell**.

Plot number of viral barcode UMIs per cell and fraction of all UMIs with viral barcode per cell.

In [None]:
# Collapse to one row per cell before plotting. Only the plotted column is
# aggregated: `max` over every column raises in pandas >= 2.0 because
# `infecting_viral_tag` is an unordered categorical.
fig = (p9.ggplot((viral_bc_frac
                  .groupby('cell_barcode')
                  [['bc_UMIs_in_cell']]
                  .max()
                  .reset_index()), p9.aes(x='bc_UMIs_in_cell')) +
       p9.geom_histogram(bins=20) +
       p9.ggtitle('barcode UMI counts per cell\n'
                  f'{expt}') +
       p9.theme(figure_size=(2, 2),
                plot_title=p9.element_text(size=9),
                axis_title=p9.element_text(size=9),
                legend_title=p9.element_text(size=9),
                legend_title_align='center'))
display(fig)

In [None]:
# Collapse to one row per cell before plotting. Only the plotted column is
# aggregated: `max` over every column raises in pandas >= 2.0 because
# `infecting_viral_tag` is an unordered categorical.
fig = (p9.ggplot((viral_bc_frac
                  .groupby('cell_barcode')
                  [['frac_total_UMIs_wBC']]
                  .max()
                  .reset_index()), p9.aes(x='frac_total_UMIs_wBC')) +
       p9.geom_histogram(bins=20) +
       p9.ggtitle('fraction of barcoded UMIs per cell\n'
                  f'{expt}') +
       p9.theme(figure_size=(2, 2),
                plot_title=p9.element_text(size=9),
                axis_title=p9.element_text(size=9),
                legend_title=p9.element_text(size=9),
                legend_title_align='center'))
display(fig)

Plot the fraction of barcoded UMIs per cell as a function of the viral UMI count and of the viral UMI fraction:

In [None]:
# Collapse to one row per cell before plotting. Only the plotted columns are
# aggregated: `max` over every column raises in pandas >= 2.0 because
# `infecting_viral_tag` is an unordered categorical.
fig = (p9.ggplot((viral_bc_frac
                  .groupby('cell_barcode')
                  [['viral_UMIs', 'frac_total_UMIs_wBC']]
                  .max()
                  .reset_index()), p9.aes(x='viral_UMIs',
                                          y='frac_total_UMIs_wBC')) +
       p9.geom_point(alpha=0.3) +
       p9.ggtitle('fraction of barcoded UMIs per cell\n'
                  'vs total viral UMI count\n'
                  f'{expt}') +
       p9.theme(figure_size=(2, 2),
                plot_title=p9.element_text(size=9),
                axis_title=p9.element_text(size=9),
                legend_title=p9.element_text(size=9),
                legend_title_align='center'))
display(fig)

In [None]:
# Collapse to one row per cell before plotting. Only the plotted columns are
# aggregated: `max` over every column raises in pandas >= 2.0 because
# `infecting_viral_tag` is an unordered categorical.
fig = (p9.ggplot((viral_bc_frac
                  .groupby('cell_barcode')
                  [['frac_viral_UMIs', 'frac_total_UMIs_wBC']]
                  .max()
                  .reset_index()), p9.aes(x='frac_viral_UMIs',
                                          y='frac_total_UMIs_wBC')) +
       p9.geom_point(alpha=0.3) +
       p9.ggtitle('fraction of barcoded UMIs per cell\n'
                  'vs fraction viral UMIs\n'
                  f'{expt}') +
       p9.theme(figure_size=(2, 2),
                plot_title=p9.element_text(size=9),
                axis_title=p9.element_text(size=9),
                legend_title=p9.element_text(size=9),
                legend_title_align='center'))
display(fig)

Plot relationship between viral UMI fraction and fraction of viral UMIs with barcode:

In [None]:
# Collapse to one row per cell before plotting. Only the plotted columns are
# aggregated: `max` over every column raises in pandas >= 2.0 because
# `infecting_viral_tag` is an unordered categorical.
# The title is worded to match the axes: y is the fraction of viral UMIs
# that carry a barcode, x is the fraction of the cell's UMIs that are viral.
fig = (p9.ggplot((viral_bc_frac
                  .groupby('cell_barcode')
                  [['frac_viral_UMIs', 'frac_viral_UMIs_wBC']]
                  .max()
                  .reset_index()), p9.aes(x='frac_viral_UMIs',
                                          y='frac_viral_UMIs_wBC')) +
       p9.geom_point(alpha=0.3) +
       p9.ggtitle('fraction of viral UMIs with barcode\n'
                  'vs fraction viral UMIs per cell\n'
                  f'{expt}') +
       p9.theme(figure_size=(2, 2),
                plot_title=p9.element_text(size=9),
                axis_title=p9.element_text(size=9),
                legend_title=p9.element_text(size=9),
                legend_title_align='center'))
display(fig)

### Per barcode metrics
Figures that show the outcomes **for each barcode individually.**

Plot distribution of each barcode's fraction of a cell's total UMIs.

In [None]:
# One facet row per viral tag observed in the data; the figure height
# grows by one unit for each of them.
n_tag_rows = viral_bc_frac['infecting_viral_tag'].nunique()
hist_title = ('fraction of all UMIs in cell\n'
              'from each viral barcode\n'
              f'{expt}')
hist_aes = p9.aes(x='frac_bc_UMIs', fill='infecting_viral_tag')
hist_theme = p9.theme(figure_size=(5, 1*n_tag_rows),
                      plot_title=p9.element_text(size=10),
                      axis_title=p9.element_text(size=10),
                      legend_title=p9.element_text(size=9),
                      legend_title_align='center')

# Stacked histogram of per-barcode UMI fractions, counts on a log scale
fig = (p9.ggplot(viral_bc_frac, hist_aes)
       + p9.geom_histogram(bins=200, position='stack')
       + p9.facet_grid('infecting_viral_tag~')
       + p9.ggtitle(hist_title)
       + p9.scale_y_log10()
       + hist_theme
       + p9.scale_fill_manual(CBPALETTE[0:]))
display(fig)

In [None]:
# Each barcode's share of its cell's UMIs against the cell's viral UMI
# fraction, on log-log axes, with one facet row per infecting viral tag.
# The figure height scales with the number of facet rows, so it is derived
# from `infecting_viral_tag` (the faceting variable) rather than `infected`.
fig = (p9.ggplot(viral_bc_frac, p9.aes(x='frac_bc_UMIs',
                                       y='frac_viral_UMIs',
                                       color='infecting_viral_tag')) +
       p9.geom_point(alpha=0.3) +
       p9.facet_grid('infecting_viral_tag~') +
       p9.ggtitle('fraction of all UMIs in cell\n'
                  'from each viral barcode\n'
                  'vs fraction of viral UMIs in that cell\n'
                  f'{expt}') +
       p9.scale_y_log10() +
       p9.scale_x_log10() +
       p9.theme(figure_size=(4,
                             2*viral_bc_frac['infecting_viral_tag'].nunique()),
                plot_title=p9.element_text(size=10),
                axis_title=p9.element_text(size=10),
                legend_title=p9.element_text(size=9),
                legend_title_align='center') +
       p9.scale_color_manual(CBPALETTE[0:]))
display(fig)

### Knee plots

Rank order barcodes by fraction of total UMIs in their cell:

In [None]:
# Sort barcodes from smallest to largest UMI fraction, then number them in
# that order; `method='first'` breaks ties by row position so ranks are unique.
viral_bc_frac = (
    viral_bc_frac
    .sort_values('frac_bc_UMIs')
    .reset_index(drop=True)
    .assign(bc_rank=lambda df: df['frac_bc_UMIs'].rank(method='first'))
)

display(viral_bc_frac)

Calculate knee using [kneed](https://pypi.org/project/kneed/) package:

In [None]:
# Locate the knee of the rank-ordered barcode fraction curve.
# NOTE(review): `S` is kneed's sensitivity setting; 10 looks hand-tuned for
# these data -- confirm before reusing on other experiments.
kl = kneed.KneeLocator(x=viral_bc_frac['bc_rank'].tolist(),
                       y=viral_bc_frac['frac_bc_UMIs'].tolist(),
                       curve='convex',
                       direction='increasing',
                       S=10
                       )

# kneed reports `None` when it cannot detect a knee. Stop with a clear
# message in that case: the threshold comparison and the knee line in the
# plots below would not be meaningful without one.
if kl.knee is None or kl.knee_y is None:
    raise ValueError('kneed could not find a knee in the viral barcode '
                     'UMI fractions')

# Barcodes at or below the knee's UMI fraction are flagged for removal
viral_bc_frac['below_knee'] = viral_bc_frac['frac_bc_UMIs'] <= kl.knee_y

display(viral_bc_frac)

n_retained = int((~viral_bc_frac['below_knee']).sum())
print(f'knee rank: {kl.knee}')
print(f'fraction of all UMIs at knee: {kl.knee_y}')
print(f'barcodes retained: {n_retained}')

Plot knee plot of each barcode's fraction of all UMIs in cell:

In [None]:
# Pieces of the knee plot, assembled below
knee_aes = p9.aes(x='bc_rank',
                  y='frac_bc_UMIs',
                  color='infecting_viral_tag')
# Dashed vertical line marking the rank at which the knee was found
knee_line = p9.geom_vline(xintercept=kl.knee,
                          linetype='dashed',
                          color='#3A3B3C',
                          size=0.5)
knee_title = ('knee plot of viral barcodes\n'
              'fraction of all UMIs in cell\n'
              f'{expt}')
knee_ylab = ('fraction of all UMIs in cell\n'
             'assigned to barcode')
knee_theme = p9.theme(figure_size=(6, 2),
                      plot_title=p9.element_text(size=10),
                      axis_title=p9.element_text(size=10),
                      legend_title=p9.element_text(size=9),
                      legend_title_align='center')

# Ranked barcode fractions, coloured by viral tag, one panel per
# `infected` value
fig = (p9.ggplot(viral_bc_frac, knee_aes)
       + p9.geom_point()
       + knee_line
       + p9.facet_grid('~infected')
       + p9.ggtitle(knee_title)
       + p9.xlab('barcode fraction rank')
       + p9.ylab(knee_ylab)
       + knee_theme
       + p9.scale_color_manual(CBPALETTE[0:]))

display(fig)

Break the plot out by infecting viral tag and viral gene:

In [None]:
# Pieces of the faceted knee plot, assembled below
facet_knee_aes = p9.aes(x='bc_rank',
                        y='frac_bc_UMIs',
                        color='below_knee')
# Dashed vertical line marking the rank at which the knee was found
facet_knee_line = p9.geom_vline(xintercept=kl.knee,
                                linetype='dashed',
                                color='#3A3B3C',
                                size=0.5)
facet_knee_title = ('knee plot of viral barcodes\n'
                    'fraction of all UMIs in cell\n'
                    f'{expt}')
facet_knee_ylab = ('fraction of all UMIs in cell\n'
                   'assigned to barcode')
facet_knee_theme = p9.theme(figure_size=(6, 5),
                            plot_title=p9.element_text(size=10),
                            axis_title=p9.element_text(size=10),
                            legend_title=p9.element_text(size=9),
                            legend_title_align='center')

# Ranked barcode fractions coloured by knee status, with viral tags as
# facet rows and viral genes as facet columns
fig = (p9.ggplot(viral_bc_frac, facet_knee_aes)
       + p9.geom_point()
       + facet_knee_line
       + p9.facet_grid('infecting_viral_tag~gene')
       + p9.ggtitle(facet_knee_title)
       + p9.xlab('barcode fraction rank')
       + p9.ylab(facet_knee_ylab)
       + facet_knee_theme
       + p9.scale_color_manual(CBPALETTE[0:]))

display(fig)

In [None]:
#  Temporary cell to create files expected by snakemake.
#  Not final version
#
# NOTE(review): the CSV written here does not apply the `below_knee` filter
# and does not include the `viral_barcode` column. Left as-is because this
# cell is marked temporary; confirm the intended output schema before
# finalizing.

# Only the most recently built figure (`fig`) is saved
print(f"Saving plots to {plot}")
p9.ggsave(plot=fig, filename=plot, verbose=False)

print(f"Saving filtered barcodes to {viral_bc_by_cell_filtered_csv}")
output_columns = ['cell_barcode',
                  'gene',
                  'frac_bc_UMIs']

# Build the output in a new variable rather than overwriting `viral_bc_frac`,
# so this cell can be re-run (the query needs the `viral_barcode` column,
# which the column selection drops). The former rename of
# `mean_freq_corrected_bc` was removed: that column is not among the
# selected columns, so the rename had no effect.
viral_bc_filtered = (viral_bc_frac
                     .query('viral_barcode != "none"')
                     [output_columns]
                     .drop_duplicates())

viral_bc_filtered.to_csv(viral_bc_by_cell_filtered_csv,
                         columns=output_columns,
                         index=False)

### 