# Viral gene presence
This notebook calls the presence or absence of each viral gene in each cell.

Python imports:

In [None]:
from IPython.display import display

from dms_variants.constants import CBPALETTE

import numpy

import pandas as pd

import plotnine as p9

import scanpy

import scipy

import statsmodels.stats.multitest

Input and output paths, parameters, and wildcards from snakemake.

In [None]:
# Input files read by this notebook.
matrix = snakemake.input.matrix
features = snakemake.input.features
cell_barcodes = snakemake.input.cell_barcodes
cell_barcodes_filtered = snakemake.input.cell_barcodes_filtered
cell_annotations_csv = snakemake.input.cell_annotations

# Output files written by this notebook.
viral_genes_by_cell_csv = snakemake.output.viral_genes_by_cell_csv
plot = snakemake.output.plot

# Parameters and wildcards.
viral_genes = snakemake.params.viral_genes
barcoded_viral_genes = snakemake.params.barcoded_viral_genes
expt = snakemake.wildcards.expt

Global variables for this notebook:

Style parameters. *N.b.* `CBPALETTE` is defined in imports above.

In [None]:
# Make the classic plotnine theme the default for every figure drawn below.
p9.theme_set(p9.theme_classic())
# Alternative four-color hex palette.
# NOTE(review): not referenced by any cell visible in this notebook — confirm it is still needed.
CBPALETTE_rich = ['#D81B60', '#1E88E5', '#FFC107', '#004D40']

### Load data
Cell infection status and tag labels:

In [None]:
cell_annotations = pd.read_csv(cell_annotations_csv)
cell_annotations = cell_annotations[['cell_barcode',
                                       'infected',
                                       'infecting_viral_tag',
                                       'total_UMIs',
                                       'viral_UMIs',
                                       'frac_viral_UMIs']]
display(cell_annotations)

Cell-gene matrix:

In [None]:
# Read the count matrix. As annotated below, its rows are genes (kept in
# `adata.obs`) and its columns are cells (kept in `adata.var`), i.e. the
# transpose of the usual scanpy cells-by-genes layout.
adata = scanpy.read_mtx(matrix)
# Column annotations: one cell barcode per matrix column (file read without a header row).
adata.var = pd.read_csv(cell_barcodes,
                        names=['cell_barcode'])
# Row annotations: one feature per matrix row (tab-separated, read without a header row).
adata.obs = pd.read_csv(features,
                        sep='\t',
                        names=['ensemble_id', 'gene', 'feature_type'])

# Because of the orientation above, `n_vars` counts cells and `n_obs` counts genes.
print(f"Read cell-gene matrix of {adata.n_vars} cells and {adata.n_obs} genes")

# Every configured viral gene must appear among the annotated features.
assert set(viral_genes).issubset(set(adata.obs['gene'])), 'lack viral genes'

#### Viral genes in each cell
Extract the UMI counts for each viral gene in each cell from the count matrix (`adata.X`), selecting each gene's rows via `adata.obs`:

In [None]:
# Sum the UMI counts of each viral gene in each cell.
# Matrix rows are genes (`adata.obs`) and columns are cells (`adata.var`), so
# summing the rows that match a gene name over axis 0 gives one count per
# cell, in the same order as the rows of `adata.var`.
# The loop over `viral_genes` replaces eight copy-pasted expressions and keeps
# this cell consistent with the `viral_genes` parameter; the resulting gene
# columns follow the order of `viral_genes`.
viral_gene_expression = adata.var.assign(**{
    gene: (numpy.sum(adata[adata.obs['gene'] == gene, ].X, axis=0)
           .A1.astype(int))
    for gene in viral_genes
})

viral_gene_expression

In [None]:
# Restrict the analysis to cells whose barcode is listed in the filtered
# barcode file.
filtered_cell_barcode_list = (
    pd.read_csv(cell_barcodes_filtered)['cell_barcode'].to_list()
)
viral_gene_expression = viral_gene_expression.loc[
    viral_gene_expression['cell_barcode'].isin(filtered_cell_barcode_list)
]
display(viral_gene_expression)

Merge in infection/infecting viral tag information:

In [None]:
# Inner join on cell barcode; `validate` raises if a barcode is repeated on
# either side of the merge.
viral_gene_expression = cell_annotations.merge(
    viral_gene_expression,
    on='cell_barcode',
    validate='one_to_one',
)

display(viral_gene_expression)

Check that the individual viral genes sum to total viral UMIs for each cell:

In [None]:
# Sanity check: the per-gene counts must add up to each cell's viral UMI total.
# The columns named in `viral_genes` are summed (rather than a hand-written
# list of gene names) so this check uses the same gene set as the rest of the
# notebook.
assert (viral_gene_expression['viral_UMIs'] ==
        viral_gene_expression[list(viral_genes)].sum(axis=1)
        ).all(), "Genes do not add to viral total"

Melt the table into long format, so that `gene` becomes a column and each row holds the UMI count for one gene in one cell:

In [None]:
# Reshape to long format: every non-id column becomes a (gene, gene_UMIs) row.
# Then express each gene's count as a fraction of the cell's total UMIs.
viral_gene_expression_long = (
    viral_gene_expression
    .melt(id_vars=['cell_barcode',
                   'infected',
                   'infecting_viral_tag',
                   'total_UMIs',
                   'viral_UMIs',
                   'frac_viral_UMIs'],
          var_name='gene',
          value_name='gene_UMIs')
    .assign(frac_gene_UMIs=lambda df: df['gene_UMIs'] / df['total_UMIs'])
)

viral_gene_expression_long

Check that the total number of rows equals the number of cells times the number of genes:

In [None]:
# After melting, one row is expected for every (cell, gene) combination.
assert (
    viral_gene_expression_long.shape[0]
    == (viral_gene_expression_long['cell_barcode'].unique().size
        * len(viral_genes))
), "not 8 genes listed for every cell"

Check that the extracted gene_UMIs sum to the total viral UMIs for each cell:

In [None]:
# Compare each cell's viral UMI total with the sum of its per-gene counts.
# `groupby` returns cells sorted by barcode with a fresh 0..n-1 index, so the
# expected frame is sorted and re-indexed the same way. Without this the
# comparison only passes when the cells already happen to be in barcode order.
pd.testing.assert_frame_equal(
    (viral_gene_expression_long[['cell_barcode',
                                 'viral_UMIs']]
     .drop_duplicates()
     .sort_values('cell_barcode')
     .reset_index(drop=True)),
    (viral_gene_expression_long
     .groupby('cell_barcode')
     ['gene_UMIs']
     .sum()
     .to_frame()
     .reset_index()
     .rename(columns={'gene_UMIs': 'viral_UMIs'})),
    check_names=False,
    check_index_type=False
)

### Analyze distributions

Plot distribution of absolute UMI counts for each gene:

In [None]:
# Histogram of the raw per-gene UMI counts, with one facet row per gene.
# NOTE(review): the x axis is log10-scaled and zero counts are not replaced by
# a pseudocount here, so cells with 0 UMIs for a gene presumably drop out of
# this plot — confirm this is intended.
fig = (p9.ggplot(viral_gene_expression_long,
                 p9.aes(x='gene_UMIs')) +
       p9.geom_histogram(bins=20) +
       p9.facet_grid('gene~') +
       p9.ggtitle('viral gene expression\n'
                  'absolute UMI counts\n'
                  'per cell\n'
                  f'{expt}') +
       p9.labs(x='UMIs from gene in cell',
               y='cells') +
       p9.scale_x_log10() +
       p9.theme(figure_size=(4, 6),
                plot_title=p9.element_text(size=12),
                axis_title=p9.element_text(size=10),
                legend_title=p9.element_text(size=10),
                legend_title_align='center'))
display(fig)

Distribution of viral gene expression in infected and **uninfected** cells:

**N.b.** Plot 0 values on far left of axis, 100-fold lower than lowest real value:

In [None]:
# Pseudocount for zeros: the smallest non-zero per-gene fraction divided by 100.
# Zeros are replaced by this value so they can be drawn on the log-scaled axis.
zero_pseudocount = (min(viral_gene_expression_long
                        .query('frac_gene_UMIs > 0')
                        ['frac_gene_UMIs']) /
                    100)
# A bare expression in the middle of a cell is never displayed, so print the
# value explicitly.
print(f'zero_pseudocount = {zero_pseudocount}')

# Histogram of the per-gene fraction of each cell's UMIs, faceted by infection
# status (rows) and gene (columns); zeros are shown at the pseudocount.
fig = (p9.ggplot((viral_gene_expression_long
                  .replace(to_replace={'frac_gene_UMIs':0},
                           value=zero_pseudocount)),
                 p9.aes(x='frac_gene_UMIs')) +
       p9.geom_histogram(bins=20) +
       p9.facet_grid('infected~gene', scales='free_y') +
       p9.ggtitle(f'viral gene expression\n'
                  f'per gene\n'
                  f'{expt}') +
       p9.labs(x='fraction of total UMIs in cell') +
       p9.scale_x_log10() +
       p9.theme(figure_size=(15, 5),
                plot_title=p9.element_text(size=12),
                axis_title=p9.element_text(size=10),
                legend_title=p9.element_text(size=10),
                legend_title_align='center'))
display(fig)

### Call presence/absence
Next I will extract a limit percentile value from the uninfected cells. We were very conservative in calling cells as infected (likely there are some infected cells in the population labelled "uninfected"). A priori, we expect most infected cells to express most viral genes. So we can be more lenient (lower percentile) in what we use as a cutoff in calling a gene present.

For low expressing genes (e.g. the polymerase complex), the limit value is 0, and this threshold does not work well. For these genes, instead, **we simply call genes without any UMIs as absent, and genes with at least 1 UMI as present.**  In some previous analyses (e.g. [Russell et al. 2018](https://elifesciences.org/articles/32303)) a single transcript was used to call the presence of a gene in infected cells.

Which genes are low expression?

In [None]:
# Mean UMI count per gene across all cells in the long table.
(viral_gene_expression_long
 .groupby('gene')
 ['gene_UMIs']
 .mean())

In [None]:
# Genes that get no percentile cutoff. They receive no `limit_frac` row and
# are therefore called present from a single UMI further below.
low_expression_genes = ['fluPB2', 'fluPB1', 'fluPA', 'fluNA']
# Quantile of the background distribution that is used as the presence cutoff.
limit_percentile = 0.99

# Background distribution: rows from cells with `infecting_viral_tag == "none"`.
background_fracs = viral_gene_expression_long.query(
    'infecting_viral_tag == "none"')
thresholded_genes = [gene for gene in viral_genes
                     if gene not in low_expression_genes]

# Build the table with explicit columns so that `gene` and `limit_frac` exist
# even when no gene is thresholded. Converting an empty dict would produce a
# frame without a `limit_frac` column, which later cells rely on.
limit = pd.DataFrame({
    'gene': thresholded_genes,
    'limit_frac': pd.Series(
        [(background_fracs
          .loc[background_fracs['gene'] == gene, 'frac_gene_UMIs']
          .quantile(limit_percentile))
         for gene in thresholded_genes],
        dtype=float),
})

display(limit)

In [None]:
# Same per-gene fraction histogram as before (zeros shown at the pseudocount),
# with a dashed vertical line at each gene's `limit_frac` taken from `limit`.
# Genes without a row in `limit` get no line.
fig = (p9.ggplot((viral_gene_expression_long
                  .replace(to_replace={'frac_gene_UMIs':0},
                           value=zero_pseudocount)),
                 p9.aes(x='frac_gene_UMIs')) +
       p9.geom_histogram(bins=20) +
       p9.geom_vline(limit,
                     p9.aes(xintercept='limit_frac'),
                     linetype='dashed',
                     color='#3A3B3C',
                     size=0.5) +
       p9.facet_grid('infected~gene', scales='free_y') +
       p9.ggtitle(f'viral gene expression\n'
                  f'per gene\n'
                  f'{expt}') +
       p9.labs(x='fraction of total UMIs in cell') +
       p9.theme(figure_size=(10, 3),
                plot_title=p9.element_text(size=12),
                axis_title=p9.element_text(size=10),
                legend_title=p9.element_text(size=10),
                legend_title_align='center') +
       p9.scale_x_log10())
display(fig)

Merge limit values into `viral_gene_expression_long` df.

In [None]:
# Attach each gene's cutoff to its rows. Genes without a row in `limit` get a
# missing `limit_frac` from the left join.
viral_gene_expression_long = viral_gene_expression_long.merge(
    limit,
    how='left',
    on='gene',
    validate='many_to_one',
)

display(viral_gene_expression_long)

Label genes as absent if they fall below this limit.

In [None]:
# A gene is called present when it has at least one UMI and either has no
# cutoff (`limit_frac` is missing) or its fraction of the cell's UMIs exceeds
# that cutoff.
viral_gene_expression_long['gene_present'] = (
    viral_gene_expression_long['gene_UMIs'].gt(0)
    & (viral_gene_expression_long['limit_frac'].isna()
       | viral_gene_expression_long['frac_gene_UMIs'].gt(
           viral_gene_expression_long['limit_frac']))
)

display(viral_gene_expression_long)

### Visualize results

Color the proportion histogram by present/absent call:

In [None]:
# Per-gene fraction histogram (zeros shown at the pseudocount) with the bars
# filled by the `gene_present` call and a dashed line at each gene's cutoff.
# NOTE(review): the two fill colors are given as CBPALETTE[1] then CBPALETTE[0];
# presumably these map to False and True in that order — confirm.
fig = (p9.ggplot((viral_gene_expression_long
                  .replace(to_replace={'frac_gene_UMIs':0},
                           value=zero_pseudocount)),
                 p9.aes(x='frac_gene_UMIs',
                        fill='gene_present')) +
       p9.geom_histogram(bins=20) +
       p9.geom_vline(limit,
                     p9.aes(xintercept='limit_frac'),
                     linetype='dashed',
                     color='#3A3B3C',
                     size=0.5) +
       p9.facet_grid('infected~gene', scales='free_y') +
       p9.ggtitle(f'viral gene expression\n'
                  f'per gene\n'
                  f'{expt}') +
       p9.labs(x='fraction of total UMIs in cell') +
       p9.theme(figure_size=(10, 3),
                plot_title=p9.element_text(size=12),
                axis_title=p9.element_text(size=10),
                legend_title=p9.element_text(size=10),
                legend_title_align='center') +
       p9.scale_x_log10() +
       p9.scale_fill_manual([CBPALETTE[1],CBPALETTE[0]])
       )
display(fig)

In general, expression of all viral genes correlates with one another in infected cells.  Let's check to see if we are correctly excluding viral genes that violate this correlation:

In [None]:
# Scatter of each gene's fraction of UMIs against the cell's total viral
# fraction, for infected cells whose viral tag is not "both", colored by the
# `gene_present` call. The dashed linear fit uses only rows called present.
# NOTE(review): both axes are log10-scaled and zeros are not replaced by a
# pseudocount here, so rows with a zero fraction presumably drop out of the
# plot — confirm this is intended.
gene_expression_plot = (
    p9.ggplot((viral_gene_expression_long
               .query('infected == "infected" and '
                      'infecting_viral_tag != "both"')),
               p9.aes(x='frac_viral_UMIs',
                      y='frac_gene_UMIs',
                      color='gene_present')) +
    p9.geom_point(alpha=0.1) +
    p9.geom_smooth((viral_gene_expression_long
                    .query('infected == "infected" and '
                           'infecting_viral_tag != "both" and '
                           'gene_present == True')),
                    p9.aes(x='frac_viral_UMIs',
                           y='frac_gene_UMIs'),
                    method='lm',
                    color=f'{CBPALETTE[2]}',
                    se=False,
                    linetype='dashed') +
    p9.facet_grid('~gene') +
    p9.ggtitle(f'viral gene expression\n'
               f'vs viral burden\n'
               f'{expt}') +
    p9.labs(x='fraction viral UMIs in cell',
            y='fraction of UMIs from gene') +
    p9.scale_x_log10() +
    p9.scale_y_log10() +
    p9.theme(figure_size=(12, 2),
             plot_title=p9.element_text(size=12),
             axis_title=p9.element_text(size=10),
             legend_title=p9.element_text(size=10),
             legend_title_align='center') +
    p9.scale_color_manual([CBPALETTE[1],CBPALETTE[0]]))


display(gene_expression_plot)

Export gene expression plot:

In [None]:
# Write the expression-vs-burden figure to the snakemake output path.
# The log message previously ended with a stray double quote after the path.
print(f'Saving figure to {plot}')
gene_expression_plot.save(plot)
print('Done.')

Annotate the fraction of cells that have each viral gene:

In [None]:
# For each infection status and gene, the proportion of cells with the gene
# called present or absent. Cells whose viral tag is "both" are excluded.
has_gene = (
    viral_gene_expression_long
    .loc[viral_gene_expression_long['infecting_viral_tag'] != "both"]
    .groupby(['infected', 'gene'])['gene_present']
    .value_counts(normalize=True)
    .reset_index(name='prop_cells')
)

display(has_gene)

Plot this fraction for each gene, faceted by infection status:

In [None]:
# Stacked bars of the present/absent proportions per gene, with one facet row
# per infection status.
fig = (p9.ggplot((has_gene),
                  p9.aes(x='gene',
                         y='prop_cells',
                         fill='gene_present')) +
              p9.geom_bar(stat='identity') +
              p9.ggtitle('gene is present above uninfected background\n'
                         f'{expt}') +
              p9.ylab('proportion of cells') +
              p9.facet_grid('infected~') +
              p9.theme(figure_size=(4, 4),
                       plot_title=p9.element_text(size=10),
                       axis_title=p9.element_text(size=10),
                       legend_title=p9.element_text(size=9),
                       legend_title_align='center') +
              p9.scale_fill_manual([CBPALETTE[1],CBPALETTE[0]]))
display(fig)

### Count genes per cell
Count the number of viral genes called as present in each cell.

In [None]:
# Number of distinct genes called present in each cell. Cells with no present
# gene do not appear in this table.
n_viral_genes_by_cell = (
    viral_gene_expression_long
    .loc[viral_gene_expression_long['gene_present']]
    .groupby('cell_barcode')['gene']
    .nunique()
    .rename('n_viral_genes')
    .reset_index()
)

# Left-join the counts onto every row. Cells missing from the count table get
# 0, and the column is stored as an integer.
viral_gene_expression_long = (
    viral_gene_expression_long
    .merge(n_viral_genes_by_cell, on='cell_barcode', how='left')
    .assign(n_viral_genes=lambda df: (df['n_viral_genes']
                                      .fillna(0)
                                      .astype(int)))
)

display(viral_gene_expression_long)

In [None]:
# Bar chart of how many cells have each number of viral genes called present,
# using one row per cell and one facet row per infection status.
# NOTE(review): no `fill` aesthetic is mapped in this plot, so the
# `scale_fill_manual` layer appears to have no effect — confirm.
n_viral_genes_histogram = (
    p9.ggplot((viral_gene_expression_long
               [['cell_barcode', 'infected', 'n_viral_genes']]
               .drop_duplicates()),
              p9.aes(x='n_viral_genes',)) +
    p9.geom_bar(stat='count', position='dodge') +
    p9.facet_grid('infected~', scales='free_y') + 
    p9.ggtitle('Number of viral genes\n'
               'in each cell\n'
               f'{expt}') +
    p9.labs(x='n viral genes detected',
            y='n cells') +
    p9.theme(figure_size=(4, 4),
                plot_title=p9.element_text(size=9),
                axis_title=p9.element_text(size=9),
                legend_title=p9.element_text(size=9),
                legend_title_align='center') +
    p9.scale_fill_manual(CBPALETTE[0:])
)

display(n_viral_genes_histogram)

### Export annotations
Export a CSV with each cell barcode and whether each viral gene is called as present or absent for that cell.

In [None]:
# Keep only the per-cell, per-gene columns that make up the exported table,
# then write them to the snakemake output CSV without the index.
gene_present_df = viral_gene_expression_long.loc[
    :,
    ['cell_barcode', 'n_viral_genes', 'gene', 'frac_gene_UMIs', 'gene_present'],
]
display(gene_present_df)
print(f'Writing gene presence data to {viral_genes_by_cell_csv}')
gene_present_df.to_csv(viral_genes_by_cell_csv, index=False)
print('Done.')