# Cells per viral barcode
This notebook calculates and plots the number of cells each viral barcode appears in, and annotates viral barcodes that are found in more than one cell.

## Notebook setup
Import python modules:

In [None]:
from IPython.display import display

from dms_variants.constants import CBPALETTE

import pandas as pd

import plotnine as p9

Get `snakemake` variables [as described here](https://snakemake.readthedocs.io/en/stable/snakefiles/rules.html#jupyter-notebook-integration):

In [None]:
# Unpack the paths, wildcards and params that snakemake injects into the
# notebook namespace.
# NOTE(review): `viral_bc_background_freq_csv` is read with `pd.read_csv`
# later in this notebook, yet it is taken from `snakemake.output` rather
# than `snakemake.input` -- confirm against the rule definition.
viral_bc_background_freq_csv = snakemake.output.viral_bc_background_freq_csv
cells_per_viral_bc_csv = snakemake.output.cells_per_viral_bc_csv
plot = snakemake.output.plot
expt = snakemake.wildcards.expt
barcoded_viral_genes = snakemake.params.barcoded_viral_genes

## Organize data
Read the viral barcode UMI counts data into a pandas dataframe:

In [None]:
# Load the viral barcode table from the CSV path supplied by snakemake.
viral_bc_frac = pd.read_csv(viral_bc_background_freq_csv)

# The genes present in the data must be exactly the configured barcoded genes.
observed_genes = set(viral_bc_frac['gene'])
expected_genes = set(barcoded_viral_genes)
assert observed_genes == expected_genes, \
       "Barcoded genes in barcode counts do not match expectation."

display(viral_bc_frac)

Set base plot style:

In [None]:
# Use plotnine's classic theme as the default for every figure below.
base_theme = p9.theme_classic()
p9.theme_set(base_theme)

## Cells per viral barcode
Next, I want to remove viral barcodes that are found above background in multiple cells. This should be very unlikely in a low MOI infection.

Calculate the number of cells each viral barcode is found in:

In [None]:
# Grouping keys, used both for counting and for merging the counts back.
viral_bc_keys = ['infecting_viral_tag',
                 'gene',
                 'viral_barcode',
                 'reject_uninfected']

# Number of distinct cell barcodes observed for each key combination.
n_cell_bc = (viral_bc_frac
             .groupby(viral_bc_keys)
             ['cell_barcode']
             .nunique()
             .reset_index()
             .rename(columns={'cell_barcode': 'n_cell_bc'}))

# A bare expression is only rendered when it is the last statement of a
# cell, so show this intermediate table explicitly.
display(n_cell_bc)

# Attach the per-barcode count to every row. Any `n_cell_bc` column left
# over from a previous run of this cell is dropped first, so re-running
# does not produce `n_cell_bc_x` / `n_cell_bc_y` columns.
viral_bc_frac = pd.merge(
    viral_bc_frac.drop(columns='n_cell_bc', errors='ignore'),
    n_cell_bc,
    on=viral_bc_keys,
    how='left')

viral_bc_frac

Plot the distribution of the number of cell barcodes per viral barcode:

In [None]:
# Restrict the histogram to rows flagged `reject_uninfected`.
# NOTE(review): this frame presumably has one row per (cell, viral barcode)
# pair, so a barcode seen in k cells would contribute k rows to the
# histogram -- confirm this weighting is intended rather than plotting
# `n_cell_bc` directly.
rejected_rows = viral_bc_frac.query('reject_uninfected == True')

hist_theme = p9.theme(
    figure_size=(5, 2),
    plot_title=p9.element_text(size=11),
    axis_title=p9.element_text(size=10),
    legend_title=p9.element_text(size=10),
    legend_title_align='center',
)

fig = (
    p9.ggplot(rejected_rows)
    + p9.geom_histogram(p9.aes(x='n_cell_bc'), binwidth=1, position='dodge')
    + p9.facet_grid('~gene')
    + p9.ggtitle('Number of cells per viral barcode\n'
                 f'{expt}')
    + p9.xlab('n_cells')
    + hist_theme
)
display(fig)

### Annotate viral barcodes that are found in more than 1 cell.
Now I will label viral barcodes that are found in more than 1 cell.

In [None]:
# Flag every row whose viral barcode was counted in more than one cell.
viral_bc_frac = viral_bc_frac.assign(
    gt1_cell=viral_bc_frac['n_cell_bc'].gt(1))

viral_bc_frac

Plot number of viral barcodes that were found in too many cells:

In [None]:
# Bar chart of row counts by `infected`, filled by the `gt1_cell` flag and
# faceted by gene.
bar_theme = p9.theme(
    figure_size=(4, 2),
    plot_title=p9.element_text(size=10),
    axis_title=p9.element_text(size=10),
    legend_title=p9.element_text(size=9),
    legend_title_align='center',
)

output_fig = (
    p9.ggplot(viral_bc_frac, p9.aes(x='infected', fill='gt1_cell'))
    + p9.geom_bar(stat='count', position='dodge')
    + p9.ggtitle('more than 1 cell per viral barcode\n'
                 f'{expt}')
    + p9.facet_grid('~gene')
    + bar_theme
    + p9.scale_fill_manual(CBPALETTE[:])
)
display(output_fig)

Export plot:

In [None]:
# Write the summary figure to the path requested by snakemake.
print(f"Saving plots to {plot}")
# Call the `ggplot.save` method directly: the module-level `p9.ggsave` is
# only a legacy wrapper around it, and the method does not depend on the
# wrapper's `plot=` keyword signature.
output_fig.save(filename=plot, verbose=False)

Export CSV:

In [None]:
# Export all viral barcodes with the new `gt1_cell` annotation, keeping
# only the columns listed here and in this order.
export_cols = [
    'cell_barcode',
    'infected',
    'infecting_viral_tag',
    'gene',
    'viral_barcode',
    'frac_viral_bc_UMIs',
    'reject_uninfected',
    'gt1_cell',
]
viral_bc_frac[export_cols].to_csv(cells_per_viral_bc_csv, index=False)