# Viral barcode in each cell
This notebook filters the viral barcodes found in each cell of the 10X transcriptome data.

## Notebook setup
Import python modules:

In [None]:
from IPython.display import display

from dms_variants.constants import CBPALETTE

import pandas as pd

import plotnine as p9

import scipy

import statsmodels.stats.multitest

Get `snakemake` variables [as described here](https://snakemake.readthedocs.io/en/stable/snakefiles/rules.html#jupyter-notebook-integration):

In [None]:
# Input files supplied by the Snakemake rule.
viral_tag_by_cell_csv = snakemake.input.viral_tag_by_cell_csv
viral_bc_by_cell_corrected_csv = (
    snakemake.input.viral_bc_by_cell_corrected_csv
)
viral_genes_by_cell_csv = snakemake.input.viral_genes_by_cell_csv

# Output files written by this notebook.
valid_viral_barcodes_csv = snakemake.output.valid_viral_barcodes_csv
plot = snakemake.output.plot

# Wildcards and parameters of the rule.
expt = snakemake.wildcards.expt
barcoded_viral_genes = snakemake.params.barcoded_viral_genes

Set base plot style:

In [None]:
# Use plotnine's classic theme as the default for every figure below.
base_theme = p9.theme_classic()
p9.theme_set(base_theme)

## Organize data

Read the viral barcode UMI counts data into a pandas dataframe:

In [None]:
# Load the per-cell viral barcode counts and rename the generic `count`
# column so that it says what is being counted.
viral_bc_counts = (
    pd.read_csv(viral_bc_by_cell_corrected_csv)
    .rename(columns={'count': 'viral_bc_UMIs'})
)

# The genes present in the file must be exactly the expected barcoded genes.
assert set(barcoded_viral_genes) == set(viral_bc_counts['gene']), \
       "Barcoded genes in barcode counts do not match expectation."
display(viral_bc_counts)

Read the total number of UMIs per cell into a pandas dataframe. Only keep relevant columns.

In [None]:
# Load the per-cell table and keep only the columns used in this notebook.
all_cells = (
    pd.read_csv(viral_tag_by_cell_csv)
    .loc[:, ['cell_barcode',
             'infected',
             'infecting_viral_tag',
             'total_UMIs',
             'viral_UMIs',
             'cellular_UMIs',
             'frac_viral_UMIs']]
)
display(all_cells)

Sanity check that `total_UMIs` is equal to `viral_UMIs + cellular_UMIs`:

In [None]:
# Every cell's total must equal its viral plus cellular UMI counts.
assert bool(
    all_cells['total_UMIs']
    .eq(all_cells['viral_UMIs'] + all_cells['cellular_UMIs'])
    .all()
), "UMI counts do not add up"

Read the gene presence/absence data:

In [None]:
# Per-cell, per-gene presence/absence table; merged on `cell_barcode` and
# `gene` further below.
viral_genes_by_cell = pd.read_csv(
    filepath_or_buffer=viral_genes_by_cell_csv,
)
display(viral_genes_by_cell)

Merge dataframes:

In [None]:
# Replicate the per-cell table once per barcoded gene so that every
# (cell, gene) pair is represented, then attach the viral barcode counts.
# The outer join keeps (cell, gene) pairs that have no barcode counts.
viral_bc_frac = pd.merge(
    left=pd.concat([all_cells.assign(gene=gene)
                    for gene in barcoded_viral_genes]),
    right=viral_bc_counts,
    how='outer',
    on=['cell_barcode', 'gene'],
    validate='one_to_many')

# Compare the cell barcodes as sets. An element-wise comparison of the two
# `.unique()` arrays depends on row order (an outer merge sorts the join
# keys) and is ill-defined when the arrays differ in length.
assert (set(viral_bc_frac['cell_barcode']) ==
        set(all_cells['cell_barcode'])), \
       "Cell barcodes in merged dataframe don't " \
       "match barcodes in source data."
assert (viral_bc_frac['viral_barcode'].nunique() ==
        viral_bc_counts['viral_barcode'].nunique()), \
       "Number of viral barcodes in merged dataframe don't " \
       "match number of barcodes in source data."

# Make `infecting_viral_tag` a categorical column with a fixed category
# order. `set_categories` is used so that a dataset lacking one of the tags
# (e.g. no cells tagged "both") does not raise; tags outside the expected
# list are still rejected explicitly instead of being turned into NaN.
viral_tag_order = ['none', 'wt', 'syn', 'both']
unexpected_viral_tags = (set(viral_bc_frac['infecting_viral_tag'].dropna())
                         - set(viral_tag_order))
if unexpected_viral_tags:
    raise ValueError("Unexpected values in `infecting_viral_tag`: "
                     f"{sorted(map(str, unexpected_viral_tags))}")
viral_bc_frac['infecting_viral_tag'] = (viral_bc_frac['infecting_viral_tag']
                                        .astype('category')
                                        .cat
                                        .set_categories(viral_tag_order))

# Annotate every row with the gene presence/absence information.
viral_bc_frac = pd.merge(
    left=viral_bc_frac,
    right=viral_genes_by_cell,
    on=['cell_barcode', 'gene'],
    how='left',
    validate='many_to_one'
)

display(viral_bc_frac)

Calculate **each barcode's** fraction of all UMIs per cell:

In [None]:
# (cell, gene) pairs without any barcode have missing counts after the
# outer merge; count them as zero UMIs and store the counts as integers.
viral_bc_frac['viral_bc_UMIs'] = (
    viral_bc_frac['viral_bc_UMIs']
    .fillna(0)
    .astype(int, errors='raise')
)

# Each barcode's share of all UMIs in its cell.
viral_bc_frac = viral_bc_frac.assign(
    frac_viral_bc_UMIs=lambda df: df['viral_bc_UMIs'] / df['total_UMIs']
)

display(viral_bc_frac)

## Filtering
Now, I will filter the viral barcodes in each cell, keeping only the valid ones. A valid viral barcode meets the following criteria:
1. The cell is infected
2. The cell expresses the barcoded gene
3. There are at least 2 UMIs with the viral barcode in that cell-gene
4. The viral barcode is present at a frequency greater than some cutoff percentile observed in uninfected cells.

In [None]:
# Filter conditions are collected here and joined with " and " as they are
# introduced in the sections below.
filter_query_list = []

### Infected cells
Valid viral barcodes must come from infected cells:

In [None]:
# Unique (cell, gene, infection status, barcode fraction) combinations,
# ranked within each gene from the highest fraction to the lowest.
infected_rank = (
    viral_bc_frac
    .loc[:, ['cell_barcode', 'gene', 'infected', 'frac_viral_bc_UMIs']]
    .drop_duplicates()
    .assign(rank=lambda df: (df
                             .groupby('gene')['frac_viral_bc_UMIs']
                             .rank(ascending=False, method='first')))
)

# Rank plot of barcode fractions, one row per gene and one column per
# infection status.
fig = (
    p9.ggplot(infected_rank,
              p9.aes(x='rank', y='frac_viral_bc_UMIs', color='infected'))
    + p9.geom_point(alpha=.3)
    + p9.facet_grid('gene~infected')
    + p9.scale_y_log10()
    + p9.ggtitle('viral bc UMIs in infected and uninfected cells\n'
                 f'{expt}')
    + p9.labs(x='viral bc frequency rank')
    + p9.theme(figure_size=(6, 4),
               plot_title=p9.element_text(size=9),
               axis_title=p9.element_text(size=9),
               legend_title=p9.element_text(size=9),
               legend_title_align='center')
    + p9.scale_color_manual(CBPALETTE[0:])
)
display(fig)

Add infection status to filter query:

In [None]:
# Register the infection-status condition and rebuild the combined query.
filter_query_list += ['(infected == "infected")']
filter_query = " and ".join(filter_query_list)
print(f"The filter query is:  {filter_query}")

### Gene expression
Plot barcode UMIs per viral barcode in infected cells, annotated with gene expression:

In [None]:
# Rows passing the filters collected so far, reduced to unique
# (cell, gene, gene presence, barcode fraction) combinations and ranked
# within each gene from the highest fraction to the lowest.
gene_expressed_rank = (
    viral_bc_frac
    .query(filter_query)
    .loc[:, ['cell_barcode', 'gene', 'gene_present', 'frac_viral_bc_UMIs']]
    .drop_duplicates()
    .assign(rank=lambda df: (df
                             .groupby('gene')['frac_viral_bc_UMIs']
                             .rank(ascending=False, method='first')))
)

# Rank plot coloured by whether the barcoded gene is present in the cell.
fig = (
    p9.ggplot(gene_expressed_rank,
              p9.aes(x='rank', y='frac_viral_bc_UMIs', color='gene_present'))
    + p9.geom_point(alpha=0.5)
    + p9.facet_grid('gene~')
    + p9.scale_y_log10()
    + p9.ggtitle('viral bc UMIs on expressed/missing genes\n'
                 'in infected cells\n'
                 f'{expt}')
    + p9.labs(x='viral bc frequency rank')
    + p9.theme(figure_size=(3, 4),
               plot_title=p9.element_text(size=9),
               axis_title=p9.element_text(size=9),
               legend_title=p9.element_text(size=9),
               legend_title_align='center')
    + p9.scale_color_manual(CBPALETTE[0:])
)
display(fig)

Add gene expression to filter query:

In [None]:
# Register the gene-expression condition and rebuild the combined query.
filter_query_list += ['(gene_present == True)']
filter_query = " and ".join(filter_query_list)
print(f"The filter query is:  {filter_query}")

### Absolute abundance
Label and remove viral barcodes supported by only 1 UMI in a given cell (i.e., require at least 2 UMIs):

In [None]:
# A barcode must have strictly more UMIs than this limit in a cell.
UMI_limit = 1
viral_bc_frac['above_UMI_limit'] = viral_bc_frac['viral_bc_UMIs'].gt(UMI_limit)

# Rows passing the filters collected so far, reduced to unique
# (cell, gene, limit flag, UMI count) combinations and ranked within each
# gene from the highest UMI count to the lowest.
UMI_limit_rank = (
    viral_bc_frac
    .query(filter_query)
    .loc[:, ['cell_barcode', 'gene', 'above_UMI_limit', 'viral_bc_UMIs']]
    .drop_duplicates()
    .assign(rank=lambda df: (df
                             .groupby('gene')['viral_bc_UMIs']
                             .rank(ascending=False, method='first')))
)

# Rank plot of UMI counts with the limit drawn as a dashed horizontal line.
fig = (
    p9.ggplot(UMI_limit_rank,
              p9.aes(x='rank', y='viral_bc_UMIs', color='above_UMI_limit'))
    + p9.geom_point(alpha=0.5)
    + p9.geom_hline(yintercept=UMI_limit,
                    linetype='dashed',
                    color=CBPALETTE[2])
    + p9.facet_grid('gene~')
    + p9.scale_y_log10()
    + p9.ggtitle('viral barcode UMI threshold\n'
                 'in infected cells expressing barcoded gene\n'
                 f'{expt}')
    + p9.labs(x='viral bc UMI rank')
    + p9.theme(figure_size=(3, 4),
               plot_title=p9.element_text(size=9),
               axis_title=p9.element_text(size=9),
               legend_title=p9.element_text(size=9),
               legend_title_align='center')
    + p9.scale_color_manual(CBPALETTE[0:])
)
display(fig)

Add absolute abundance to filter query:

In [None]:
# Register the absolute-abundance condition and rebuild the combined query.
filter_query_list += ['(above_UMI_limit == True)']
filter_query = " and ".join(filter_query_list)
print(f"The filter query is:  {filter_query}")

### Relative abundance
Find the 99th percentile of individual viral barcode frequency in uninfected cells, separately for each barcoded gene. Label viral barcodes by whether they exceed this frequency, and filter out those at or below it.

In [None]:
limit_percentile = 0.99

# For each barcoded gene, take the `limit_percentile` quantile of barcode
# frequency over the rows whose `infecting_viral_tag` is "none".
uninfected_background_frac = {}
for gene in barcoded_viral_genes:
    in_background = (
        (viral_bc_frac['infecting_viral_tag'] == "none")
        & (viral_bc_frac['gene'] == f'{gene}')
    )
    limit = (viral_bc_frac
             .loc[in_background, 'frac_viral_bc_UMIs']
             .quantile(limit_percentile))
    uninfected_background_frac[f'{gene}'] = limit
    print(f'The limit for a viral barcode on {gene} is {limit:.5f}.')

# Convert to DataFrame (one row per gene) for plotting and merging
uninfected_background_frac = pd.DataFrame(
    list(uninfected_background_frac.items()),
    columns=['gene', 'uninf_background'],
)

# Histogram of barcode frequency in the "none" rows, with each gene's limit
# drawn as a dashed vertical line.
fig = (
    p9.ggplot(viral_bc_frac.query('infecting_viral_tag == "none"'),
              p9.aes(x='frac_viral_bc_UMIs'))
    + p9.geom_histogram(bins=50, fill="#3a3a3a")
    + p9.geom_vline(uninfected_background_frac,
                    p9.aes(xintercept='uninf_background'),
                    linetype='dashed',
                    color='#3A3B3C',
                    size=0.5)
    + p9.facet_grid('infecting_viral_tag~gene')
    + p9.ggtitle('viral barcode fraction in uninfected cells\n'
                 f'{expt}')
    + p9.labs(x='viral barcode frequency',)
    + p9.theme(figure_size=(5, 1),
               plot_title=p9.element_text(size=10),
               axis_title=p9.element_text(size=10),
               axis_text=p9.element_text(rotation=45),
               legend_title=p9.element_text(size=9),
               legend_title_align='center')
)
display(fig)

In [None]:
# Attach each gene's background limit to every row of that gene, then flag
# the barcodes whose frequency is strictly above it.
viral_bc_frac = viral_bc_frac.merge(
    uninfected_background_frac,
    on='gene',
    how='outer',
    validate='many_to_one',
)
viral_bc_frac['above_uninf_background'] = (
    viral_bc_frac['frac_viral_bc_UMIs']
    .gt(viral_bc_frac['uninf_background'])
)

In [None]:
# Rows passing the filters collected so far, reduced to unique
# (cell, gene, background flag, barcode fraction) combinations and ranked
# within each gene from the highest fraction to the lowest.
background_rank = (
    viral_bc_frac
    .query(filter_query)
    .loc[:, ['cell_barcode',
             'gene',
             'above_uninf_background',
             'frac_viral_bc_UMIs']]
    .drop_duplicates()
    .assign(rank=lambda df: (df
                             .groupby('gene')['frac_viral_bc_UMIs']
                             .rank(ascending=False, method='first')))
)

# Rank plot with each gene's background limit as a dashed horizontal line.
fig = (
    p9.ggplot(background_rank,
              p9.aes(x='rank',
                     y='frac_viral_bc_UMIs',
                     color='above_uninf_background'))
    + p9.geom_point(alpha=0.5)
    + p9.geom_hline(uninfected_background_frac,
                    p9.aes(yintercept='uninf_background'),
                    linetype='dashed',
                    color=CBPALETTE[2])
    + p9.facet_grid('gene~')
    + p9.scale_y_log10()
    + p9.ggtitle(f'{expt}')
    + p9.labs(x='viral bc frequency rank')
    + p9.theme(figure_size=(3, 4),
               plot_title=p9.element_text(size=9),
               axis_title=p9.element_text(size=9),
               legend_title=p9.element_text(size=9),
               legend_title_align='center')
    + p9.scale_color_manual(CBPALETTE[0:])
)
display(fig)

Add background frequency to filter query:

In [None]:
# Register the relative-abundance condition and rebuild the combined query.
filter_query_list += ['(above_uninf_background == True)']
filter_query = " and ".join(filter_query_list)
print(f"The filter query is:  {filter_query}")

## Assign filter status to each viral barcode

In [None]:
# A barcode is valid exactly when its row satisfies the combined filter
# query. Assigning the boolean mask directly gives a proper boolean column
# (rather than an object column of True/NaN that then needs `fillna`) and
# recomputes every row if the cell is executed again.
viral_bc_frac['valid_viral_bc'] = (viral_bc_frac
                                   .eval(filter_query)
                                   .astype(bool))
display(viral_bc_frac)

## Summary visualizations
### Valid viral barcodes in each infected cell

In [None]:
# Rows from infected cells, reduced to unique (cell, gene, barcode
# fraction, validity) combinations and ranked within each gene from the
# highest fraction to the lowest.
valid_rank = (
    viral_bc_frac
    .query('infected == "infected"')
    .loc[:, ['cell_barcode', 'gene', 'frac_viral_bc_UMIs', 'valid_viral_bc']]
    .drop_duplicates()
    .assign(rank=lambda df: (df
                             .groupby('gene')['frac_viral_bc_UMIs']
                             .rank(ascending=False, method='first')))
)

# Rank plot coloured by the final validity flag.
fig = (
    p9.ggplot(valid_rank,
              p9.aes(x='rank', y='frac_viral_bc_UMIs', color='valid_viral_bc'))
    + p9.geom_point(alpha=.3)
    + p9.facet_grid('gene~')
    + p9.scale_y_log10()
    + p9.ggtitle('valid and filterd viral barcodes\n'
                 'in infected cells\n'
                 f'{expt}')
    + p9.labs(x='viral bc frequency rank')
    + p9.theme(figure_size=(3, 4),
               plot_title=p9.element_text(size=9),
               axis_title=p9.element_text(size=9),
               legend_title=p9.element_text(size=9),
               legend_title_align='center')
    + p9.scale_color_manual(CBPALETTE[0:])
)
display(fig)

### Valid viral barcodes in each cell
Visualize the valid viral barcodes in each cell

In [None]:
# Unique valid (cell, gene, viral barcode, fraction) rows.
visualization_df = (
    viral_bc_frac
    .query('valid_viral_bc == True')
    .loc[:, ['cell_barcode', 'gene', 'viral_barcode', 'frac_viral_bc_UMIs']]
    .drop_duplicates()
)

# Summed fraction of the valid barcodes in each cell-gene, and each
# barcode's share of that sum.
visualization_df['sum_frac_cell_gene'] = (
    visualization_df
    .groupby(['cell_barcode', 'gene'])['frac_viral_bc_UMIs']
    .transform('sum')
)
visualization_df['proportion_valid_viral_bcs_cell_gene'] = (
    visualization_df['frac_viral_bc_UMIs']
    .div(visualization_df['sum_frac_cell_gene'])
)

# Integer code per cell barcode, used as the x position in the bar plot.
visualization_df['cell_barcode'] = (
    visualization_df['cell_barcode'].astype('category')
)
visualization_df['cell_barcode_dummy'] = (
    visualization_df['cell_barcode'].cat.codes
)

# Rank of each barcode within its cell-gene, largest share first.
visualization_df['viral_bc_rank'] = (
    visualization_df
    .groupby(['cell_barcode', 'gene'])['proportion_valid_viral_bcs_cell_gene']
    .rank(ascending=False, method='first')
)

display(visualization_df)

In [None]:
# Stacked bars: one bar per cell (x is the cell's integer code), split by
# the share of each valid barcode and filled by the barcode's rank.
fig = (
    p9.ggplot(visualization_df,
              p9.aes(x='cell_barcode_dummy',
                     y='proportion_valid_viral_bcs_cell_gene',
                     fill='factor(viral_bc_rank)'))
    + p9.geom_bar(stat='identity', width=1)
    + p9.facet_grid('gene~')
    + p9.ggtitle('valid and filterd viral barcodes\n'
                 'in infected cells\n'
                 f'{expt}')
    + p9.labs(x='infected cell arbitrary #')
    + p9.theme(figure_size=(6, 4),
               plot_title=p9.element_text(size=9),
               axis_title=p9.element_text(size=9),
               legend_title=p9.element_text(size=9),
               legend_title_align='center')
    + p9.scale_fill_manual(CBPALETTE[0:])
)
display(fig)

### Number of valid viral barcodes per cell

In [None]:
# Number of distinct valid viral barcodes for every cell-gene that has at
# least one valid barcode.
n_viral_bc = (
    viral_bc_frac
    .query('(valid_viral_bc == True)', engine='python')
    .groupby(['cell_barcode', 'gene', 'valid_viral_bc'])['viral_barcode']
    .nunique()
    .reset_index(name='n_viral_bc')
)

In [None]:
# Histogram of the number of valid viral barcodes per cell, one panel per
# gene. This is the figure that gets saved at the end of the notebook.
output_fig = (
    p9.ggplot(n_viral_bc)
    + p9.geom_histogram(p9.aes(x='n_viral_bc'),
                        binwidth=1,
                        position='dodge')
    + p9.facet_grid('~gene')
    + p9.ggtitle('Number of valid viral barcodes per cell\n'
                 'in infected cells\n'
                 f'{expt}')
    + p9.xlab('number of valid viral barcodes in cell')
    + p9.ylab('number of cells')
    + p9.theme(figure_size=(5, 2),
               plot_title=p9.element_text(size=11),
               axis_title=p9.element_text(size=10),
               legend_title=p9.element_text(size=10),
               legend_title_align='center')
)
display(output_fig)

In [None]:
print(f"Saving plots to {plot}")
# Save through the plot object's own `save` method instead of the
# module-level `ggsave` helper.
output_fig.save(filename=plot, verbose=False)

print(f"Saving filtered barcodes to {valid_viral_barcodes_csv}")

# Export valid cell barcode-viral barcode pairs. Only rows flagged as valid
# are written, and only the identifying columns plus the flag are kept.
(viral_bc_frac
 .query('valid_viral_bc == True')
 .to_csv(valid_viral_barcodes_csv,
         columns=['cell_barcode',
                  'gene',
                  'viral_barcode',
                  'valid_viral_bc'],
         index=False))