# Assign infection status
This Python Jupyter notebook uses the fraction of viral UMIs to determine which cells were infected.

Import Python modules:

In [None]:
from IPython.display import display

from dms_variants.constants import CBPALETTE

import kneed

import numpy

import pandas as pd

import plotnine as p9

import scanpy

Get `snakemake` variables [as described here](https://snakemake.readthedocs.io/en/stable/snakefiles/rules.html#jupyter-notebook-integration):

In [None]:
# Inputs: the cell-gene count matrix and its two companion tables
# (features / genes and cell barcodes).
matrix, features, cell_barcodes = (
    snakemake.input.matrix,
    snakemake.input.features,
    snakemake.input.cell_barcodes,
)

# Outputs: the per-cell infection-status CSV and the knee-plot image.
infection_status_csv, plot = (
    snakemake.output.infection_status_csv,
    snakemake.output.plot,
)

# Rule parameters and wildcards.
viral_genes = snakemake.params.viral_genes
expt = snakemake.wildcards.expt

Style parameters:

In [None]:
# Use the classic theme for every plot below. `theme_set` returns the
# previously active theme; the trailing semicolon stops the notebook from
# rendering that object's repr as this cell's output.
p9.theme_set(p9.theme_classic());

Read the cell-gene matrix into an [AnnData](https://anndata.readthedocs.io/) object:

In [None]:
# The matrix is loaded with genes as rows (AnnData `obs`) and cells as
# columns (AnnData `var`). That is the transpose of the usual AnnData
# cells-by-genes layout, so the per-gene table is attached to `obs` and the
# cell barcodes to `var`.
adata = scanpy.read_mtx(matrix)
adata.var = pd.read_csv(cell_barcodes,
                        names=['cell_barcode'])
# NOTE(review): assumes a headerless, tab-separated features file whose
# columns are (Ensembl ID, gene name, feature type) -- confirm against input.
adata.obs = pd.read_csv(features,
                        sep='\t',
                        names=['ensembl_id', 'gene', 'feature_type'])

print(f"Read cell-gene matrix of {adata.n_vars} cells and {adata.n_obs} genes")

# Every configured viral gene must be present. A missing one would
# contribute nothing to the viral UMI counts below, without any error.
assert set(viral_genes).issubset(set(adata.obs['gene'])), 'lack viral genes'

Now compute, for each cell:
 - the total number of UMIs (UMI-corrected reads)
 - the number of UMIs from viral genes and from cellular genes
 - the fraction of UMIs that are viral and the fraction that are cellular

In [None]:
# In this orientation genes live on `obs`; flag the viral ones once.
is_viral_gene = adata.obs['gene'].isin(viral_genes)

# Summing over axis 0 (genes) gives one UMI total per cell. `.A1` flattens
# the (1, n_cells) numpy matrix returned by the sum; X is presumably a
# scipy sparse matrix -- confirm if the reader changes.
total_per_cell = numpy.sum(adata.X, axis=0).A1.astype(int)
viral_per_cell = numpy.sum(adata[is_viral_gene, :].X, axis=0).A1.astype(int)
cellular_per_cell = numpy.sum(adata[~is_viral_gene, :].X,
                              axis=0).A1.astype(int)

# NOTE(review): a cell with zero total UMIs would get NaN fractions here.
umi_counts = adata.var.assign(
    total_UMIs=total_per_cell,
    viral_UMIs=viral_per_cell,
    cellular_UMIs=cellular_per_cell,
    frac_viral_UMIs=lambda df: df['viral_UMIs'] / df['total_UMIs'],
    frac_cellular_UMIs=lambda df: df['cellular_UMIs'] / df['total_UMIs'],
)
display(umi_counts)

In [None]:
# Rank 1 is the cell with the highest viral UMI fraction. method='first'
# breaks ties by row order, so every cell gets a distinct rank.
umi_counts['rank_viral_burden'] = (
    umi_counts['frac_viral_UMIs']
    .rank(ascending=False, method='first')
)
display(umi_counts.sort_values('rank_viral_burden'))

## Assign infection status

We will use two thresholds on the fraction of viral UMIs to label each cell's infection status. The first is a threshold for cells that are clearly infected: cells above it are labeled infected. The second, lower threshold marks cells that are ambiguous: cells between the two thresholds could either have been infected and expressed low levels of viral transcripts, or be droplets that incorporated many viral transcripts from the ambient supernatant. Cells at or below the lower threshold are labeled uninfected.

In [None]:
# Viral-UMI-fraction cutoffs. A fraction above `infected_threshold` is called
# infected. A fraction above the ten-fold lower `ambiguous_threshold` (but
# not above the infected cutoff) is called ambiguous. Anything else is
# uninfected.
infected_threshold = 5e-3
ambiguous_threshold = infected_threshold / 10


def assign_infection_status(x):
    """Classify one cell by its fraction of viral UMIs.

    Parameters
    ----------
    x : mapping (e.g. a row of ``umi_counts``)
        Must provide ``x['frac_viral_UMIs']``.

    Returns
    -------
    str
        ``'infected'`` if the fraction is above ``infected_threshold``;
        ``'ambiguous'`` if it is above ``ambiguous_threshold`` but not above
        ``infected_threshold``; otherwise ``'uninfected'``. A NaN fraction
        (e.g. a cell with zero total UMIs) fails both comparisons and is
        reported as ``'uninfected'``.
    """
    frac_viral = x['frac_viral_UMIs']
    if frac_viral > infected_threshold:
        return 'infected'
    # Reaching this branch already implies frac_viral <= infected_threshold.
    elif frac_viral > ambiguous_threshold:
        return 'ambiguous'
    else:
        return 'uninfected'
    
# Label every cell row by row, then report how many cells fall in each class.
infection_labels = umi_counts.apply(assign_infection_status, axis=1)
umi_counts['infected'] = infection_labels
umi_counts['infected'].value_counts()

In [None]:
# Knee plot: cells ordered by viral UMI fraction (rank 1 = highest) against
# that fraction on a log scale, colored by assigned infection status. Dashed
# gray lines mark the infected and ambiguous thresholds used for labeling.
# Cells whose viral fraction is exactly 0 cannot be placed on the log scale
# and are not drawn.
viral_burden_knee_plot = (
    p9.ggplot(
        umi_counts,
        p9.aes(x='rank_viral_burden',
               y='frac_viral_UMIs',
               color='factor(infected)')) +
    p9.geom_point(alpha=0.1) +
    p9.geom_hline(yintercept=[infected_threshold, ambiguous_threshold],
                  linetype='dashed',
                  color='gray') +
    p9.scale_y_log10() +
    p9.ggtitle('Fraction of viral UMIs per cell') +
    p9.labs(x='cell rank by viral UMI fraction',
            y='fraction of UMIs from virus',
            color='infection status') +
    p9.theme(figure_size=(4, 2),
             plot_title=p9.element_text(size=9),
             axis_title=p9.element_text(size=9),
             legend_title=p9.element_text(size=9),
             legend_title_align='center') +
    p9.scale_color_manual(CBPALETTE)
)
display(viral_burden_knee_plot)

Examine the relationship between total UMIs and viral UMIs:

In [None]:
# Viral vs. total UMIs per cell on log-log axes, colored by infection status.
# A pseudocount of 1 is added to the viral UMIs. log10(0) is undefined, so
# without it every cell with no viral UMIs (typically the uninfected ones)
# would be left off the plot.
capture_efficiency_plot = (
    p9.ggplot(
        umi_counts,
        p9.aes(x='total_UMIs',
               y='viral_UMIs + 1',
               color='factor(infected)')) +
    p9.geom_point(alpha=0.2) +
    p9.scale_x_log10() +
    p9.scale_y_log10() +
    p9.ggtitle('Viral UMIs as a function of total UMIs captured') +
    p9.labs(x='total UMIs',
            y='viral UMIs + 1',
            color='infection status') +
    p9.theme(figure_size=(4, 3),
             plot_title=p9.element_text(size=9),
             axis_title=p9.element_text(size=9),
             legend_title=p9.element_text(size=9),
             legend_title_align='center') +
    p9.scale_color_manual(CBPALETTE)
)
display(capture_efficiency_plot)

Export the knee plot:

In [None]:
# Write the knee plot to the rule's `plot` output path.
print(f"Saving knee plot to {plot}")
viral_burden_knee_plot.save(filename=plot)

Export infection status and UMI counts in a CSV:

In [None]:
# Columns written to the output CSV, in this order.
export_columns = ['cell_barcode', 'total_UMIs', 'viral_UMIs',
                  'cellular_UMIs', 'frac_viral_UMIs', 'infected']
infection_status_curated = umi_counts[export_columns]

display(infection_status_curated)

# umi_counts was built from adata.var, so its rows must still line up
# one-to-one with the matrix's cell barcodes.
barcodes_in_order = (infection_status_curated['cell_barcode']
                     == adata.var['cell_barcode']).all()
assert barcodes_in_order, 'cell barcodes out of order'

# Note: the file is gzip-compressed regardless of the path's extension.
print(f"Saving viral-tag annotated cell barcodes to {infection_status_csv}")
infection_status_curated.to_csv(
    infection_status_csv,
    index=False,
    compression='gzip',
)