# Call cellular infection status and viral tag assignment for infected cells
This Python Jupyter notebook uses the viral tags to determine which cells were infected, and the tag variant of the infecting virus for infected cells.

Import Python modules:

In [None]:
import anndata

from dms_variants.constants import CBPALETTE

import kneed

import matplotlib.pyplot as plt

import numpy

import pandas as pd

import scanpy

Get `snakemake` variables [as described here](https://snakemake.readthedocs.io/en/stable/snakefiles/rules.html#jupyter-notebook-integration):

In [None]:
# experiment name (used in the figure title)
expt = snakemake.wildcards.expt

# names of the viral genes and of the viral tag variants
viral_genes = snakemake.params.viral_genes
viral_tags = snakemake.params.viral_tags

# input files
viral_tag_by_cell_csv = snakemake.input.viral_tag_by_cell_csv
matrix = snakemake.input.matrix
features = snakemake.input.features
cell_barcodes = snakemake.input.cell_barcodes

# output files written at the end of the notebook
cell_annotations = snakemake.output.cell_annotations
plot = snakemake.output.plot

This notebook assumes that there are exactly two viral tags; the code does not work otherwise:

In [None]:
# The major/minor tag assignment and the two knee-plot panels below only make sense for two tags.
if not len(viral_tags) == 2:
    raise ValueError('code assumes exactly two viral tags')

Read the cell-gene matrix into an [AnnData](https://anndata.readthedocs.io/) object:

In [None]:
# `scanpy.read_mtx` places the rows of the Matrix Market file in `obs` and the columns in `var`.
# The code below attaches the features (genes) to `obs` and the cell barcodes to `var`, so this
# AnnData is transposed relative to the usual scanpy convention: genes are observations and
# cells are variables.
adata = scanpy.read_mtx(matrix)
adata.var = pd.read_csv(cell_barcodes,
                        names=['cell_barcode'])
adata.obs = pd.read_csv(features,
                        sep='\t',
                        names=['ensembl_id', 'gene', 'feature_type'])

print(f"Read cell-gene matrix of {adata.n_vars} cells and {adata.n_obs} genes")

assert set(viral_genes).issubset(set(adata.obs['gene'])), 'cell-gene matrix missing viral genes'
# every later merge on `cell_barcode` assumes each barcode identifies exactly one cell
assert adata.var['cell_barcode'].is_unique, 'duplicate cell barcodes in cell-gene matrix'

Now get for each cell:
 - total UMI-corrected reads
 - UMI-corrected viral and cellular reads
 - fraction of UMIs that are viral and cellular

In [None]:
# boolean mask over genes (rows of `adata`) marking the viral genes
is_viral_gene = adata.obs['gene'].isin(viral_genes)

# per-cell (per-column) UMI sums: over all genes, over viral genes only, and over the rest
umi_counts = adata.var.assign(
    total_UMIs=numpy.sum(adata.X, axis=0).A1.astype(int),
    viral_UMIs=numpy.sum(adata[is_viral_gene, ].X, axis=0).A1.astype(int),
    cellular_UMIs=numpy.sum(adata[~is_viral_gene, ].X, axis=0).A1.astype(int),
)

# fractions of each cell's UMIs that come from viral and from cellular genes
umi_counts = umi_counts.assign(
    frac_viral_UMIs=umi_counts['viral_UMIs'] / umi_counts['total_UMIs'],
    frac_cellular_UMIs=umi_counts['cellular_UMIs'] / umi_counts['total_UMIs'],
)

Now we make some plots showing the amount of total and viral UMIs per cell.
We are going to eventually make a merged plot here, so we create a multi-axis subplots figure and then just start adding plots to it as we go through this notebook:

In [None]:
# A single 2 x 4 grid holds every panel drawn by this notebook. Later cells keep filling it;
# `i_axis` is the index of the next free panel in `axes.ravel()`.
fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(14, 6))
fig.suptitle(f"Calling viral tags and infection status for {expt}", size=16)
i_axis = 0

# histograms of total UMIs, viral UMIs, and fraction of UMIs from virus per cell;
# each tuple is (column to plot, text used for both the x-axis label and the title)
for hist_col, hist_label in [('total_UMIs', 'total UMIs per cell'),
                             ('viral_UMIs', 'viral UMIs per cell'),
                             ('frac_viral_UMIs', 'fraction of UMIs from virus'),
                             ]:
    ax = umi_counts.plot(y=hist_col,
                         kind='hist',
                         bins=25,
                         ax=axes.ravel()[i_axis],
                         legend=False,
                         color=CBPALETTE[0],
                         )
    _ = ax.set_xlabel(hist_label)
    _ = ax.set_ylabel('number of cells')
    _ = ax.set_title(hist_label)
    i_axis += 1

# scatter of viral versus cellular UMIs per cell
ax = umi_counts.plot(x='cellular_UMIs',
                     y='viral_UMIs',
                     kind='scatter',
                     ax=axes.ravel()[i_axis],
                     alpha=0.1,
                     legend=False,
                     color=CBPALETTE[0],
                     )
_ = ax.set_xlabel('cellular UMIs per cell')
_ = ax.set_ylabel('viral UMIs per cell')
_ = ax.set_title('viral vs cellular UMIs per cell')
i_axis += 1

# lay out the panels drawn so far
fig.tight_layout()

Combine UMI counts with counts of viral tags, identifying the "major" (more abundant) and "minor" (less abundant) viral tag for each cell, and also computing their ratio to total UMIs in each cell.

In [None]:
# read tag counts in tidy form
tidy_tag_counts = pd.read_csv(viral_tag_by_cell_csv)

# check tag count values
assert set(tidy_tag_counts['tag_variant']).issubset(set(viral_tags)), 'unrecognized viral tag'
assert set(tidy_tag_counts['cell_barcode']).issubset(set(umi_counts['cell_barcode'])), 'unrecognized cell barcode'
assert set(tidy_tag_counts['gene']).issubset(set(viral_genes)), 'unrecognized viral gene'

# for each cell barcode, identify more abundant (major) and less abundant (minor) tag and get counts
tag_counts = (
    tidy_tag_counts
    # sum counts across viral genes
    .groupby(['cell_barcode', 'tag_variant'], as_index=False)
    .aggregate(count=pd.NamedAgg('count', 'sum'))
    # fill counts of 0 for cells missing any counts for a tag: the right merge keeps one row
    # for every (cell barcode in the matrix, viral tag) pair
    .merge(pd.concat([pd.DataFrame({'cell_barcode': umi_counts['cell_barcode'],
                                    'tag_variant': viral_tag})
                      for viral_tag in viral_tags]),
           how='right', on=['cell_barcode', 'tag_variant'])
    .assign(count=lambda x: x['count'].fillna(0).astype(int))
    # get major and minor viral tag and their counts: after sorting by descending count, the
    # first row of each cell's group is the major tag and the last row is the minor tag
    .sort_values('count', ascending=False)
    .groupby('cell_barcode', as_index=False)
    .aggregate(viral_tag_counts_total=pd.NamedAgg('count', 'sum'),
               viral_tag_major=pd.NamedAgg('tag_variant', 'first'),
               viral_tag_counts_major=pd.NamedAgg('count', 'max'),
               viral_tag_minor=pd.NamedAgg('tag_variant', 'last'),
               viral_tag_counts_minor=pd.NamedAgg('count', 'min')
               )
    # add the UMI counts. Left joins onto the matrix's barcodes keep rows in the order of
    # `umi_counts` (i.e. of `adata.var`); an outer merge would re-sort barcodes
    # lexicographically, which can misalign the output with the matrix columns.
    .pipe(lambda agg: umi_counts[['cell_barcode']].merge(agg, on='cell_barcode', how='left'))
    .merge(umi_counts, on='cell_barcode', how='left')
    # compute ratio of viral tag to total UMIs
    .assign(viral_tag_major_to_total_UMI=lambda x: x['viral_tag_counts_major'] / x['total_UMIs'],
            viral_tag_minor_to_total_UMI=lambda x: x['viral_tag_counts_minor'] / x['total_UMIs'])
    )

assert tag_counts.notnull().all().all()
assert tag_counts['cell_barcode'].tolist() == umi_counts['cell_barcode'].tolist(), \
    'tag counts not aligned with cell barcodes of the cell-gene matrix'

Plot the ratio of minor and major tags to UMIs for each cell, and also the distribution of fraction of UMIs from virus for cells stratified by which viral tag variant is the major one:

In [None]:
# correlation of minor and major tag counts to total UMI ratios
ax = tag_counts.plot(x='viral_tag_major_to_total_UMI',
                     y='viral_tag_minor_to_total_UMI',
                     kind='scatter',
                     ax=axes.ravel()[i_axis],
                     alpha=0.1,
                     color=CBPALETTE[0],
                     legend=False,
                     )
_ = ax.set_xlabel('major tag counts / total UMIs')
_ = ax.set_ylabel('minor tag counts / total UMIs')
_ = ax.set_title('major and minor viral tags per cell')
i_axis += 1

# fraction of UMIs that are viral stratified by which is major tag variant; the histograms
# for the tags are overlaid on the same panel
tag_to_color = dict(zip(viral_tags, CBPALETTE[1:]))
for tag in viral_tags:
    # label each histogram with its tag so the legend explicitly pairs label and color,
    # instead of relying on the implicit artist order of `ax.legend(labels=...)`
    ax = tag_counts.query('viral_tag_major == @tag').plot(
                    y='frac_viral_UMIs',
                    kind='hist',
                    bins=25,
                    alpha=0.5,
                    color=tag_to_color[tag],
                    ax=axes.ravel()[i_axis],
                    legend=False,
                    label=tag,
                    )
ax.legend(title='major viral tag')
_ = ax.set_xlabel('fraction of UMIs from virus')
_ = ax.set_ylabel('number of cells')
_ = ax.set_title('fraction UMIs from virus by major tag')
i_axis += 1

# display figure
fig.tight_layout()
display(fig)

Now we identify the cells that are clearly **not** infected with each viral tag variant.
The logic is as follows: the cells were independently infected with the two tag variants, and were only mixed right before sequencing.
So even in infected cells, the "minor" tag variant should always be at the background level expected of uninfected cells.
The only exception is for doublets of cells infected with the two tag variants.
Therefore, we divide cells by which tag variant is the "minor" tag, and then make a knee plot showing the cell rank versus the counts of that tag normalized by total UMIs.
This should be a nearly flat line except for an upward "knee" near the end for doublets of cells infected with the two tag variants.
We use the [kneed](https://kneed.readthedocs.io/) package to identify this knee, and then plot the knee plots.

In [None]:
# identify the "knee" for each minor viral tag variant
knee = {}
knee_y = {}
for tag, ax in zip(viral_tags, axes.ravel()[i_axis: ]):

    # cells whose minor tag is `tag`, ranked 1..n by increasing minor tag / total UMIs
    df = (tag_counts
          .query('viral_tag_minor == @tag')
          .sort_values('viral_tag_minor_to_total_UMI')
          .assign(cell_rank=lambda x: x['viral_tag_minor_to_total_UMI'].rank(method='first'))
          )
    if df.empty:
        raise ValueError(f"no cells have {tag} as the minor viral tag; cannot locate knee")

    # identify knee
    kl = kneed.KneeLocator(x=df['cell_rank'].tolist(),
                           y=df['viral_tag_minor_to_total_UMI'].tolist(),
                           curve='convex',
                           direction='increasing',
                           S=3,
                           )
    # kneed reports "no knee found" as None; fail with a clear message rather than with a
    # TypeError from the `< knee_y[tag]` comparisons below
    if kl.knee is None or kl.knee_y is None:
        raise ValueError(f"could not identify knee for minor viral tag {tag}")
    knee[tag] = kl.knee
    knee_y[tag] = kl.knee_y

    # plot cells above / below knee
    ax = df.assign(color=lambda x: (x['viral_tag_minor_to_total_UMI'] < knee_y[tag])
                                    .map({True: CBPALETTE[3], False: CBPALETTE[4]})
                   ).plot(
                x='cell_rank',
                y='viral_tag_minor_to_total_UMI',
                kind='scatter',
                s=1.5,
                c='color',
                ax=ax)
    _ = ax.set_xlabel('cell rank')
    _ = ax.set_ylabel('minor tag / total UMIs')
    n_below = (df['viral_tag_minor_to_total_UMI'] < knee_y[tag]).sum()
    n_above = len(df) - n_below
    _ = ax.set_title(f"{tag} minor tag knee plot:\n{n_below} below & {n_above} above knee")
    i_axis += 1
    ax.axvline(x=knee[tag], color=CBPALETTE[0], alpha=0.5, linestyle='--')

# display figure
fig.tight_layout()
display(fig)

# annotate tag counts by whether below knee (uninfected with that tag): each cell's ratio is
# compared against the knee threshold of its own minor tag, looked up by mapping the tag name
tag_counts = (
    tag_counts
    .assign(viral_tag_minor_below_knee=lambda x: (x['viral_tag_minor_to_total_UMI'] <
                                                  x['viral_tag_minor'].map(knee_y)))
    )

Save the output CSV and plot:

In [None]:
# Each output row must describe the cell in the same column of the cell-gene matrix. Compare
# the barcode sequences positionally: the two frames may carry different index labels, and
# `Series == Series` refuses to compare differently-labeled Series.
assert tag_counts['cell_barcode'].tolist() == adata.var['cell_barcode'].tolist(), \
    'cell barcodes in annotations do not match those of the cell-gene matrix'

print(f"Saving viral-tag annotated cell barcodes to {cell_annotations}")
tag_counts.to_csv(cell_annotations,
                  compression='gzip',
                  index=False)

print(f"Saving plots to {plot}")
fig.savefig(plot)