# Viral tags by cell from 10x transcriptomics
This Python Jupyter notebook aggregates the viral tags by cell in the 10x transcriptomic data, and drops any viral tags that are ambiguous or invalid.

Import Python modules:

In [None]:
import matplotlib.pyplot as plt

import pandas as pd

Get `snakemake` variables [as described here](https://snakemake.readthedocs.io/en/stable/snakefiles/rules.html#jupyter-notebook-integration):

In [None]:
# `snakemake` is the object provided by Snakemake's notebook integration.

# wildcard: experiment name (used in the plot title)
expt = snakemake.wildcards.expt

# input: CSV of viral tags by cell and UMI
viral_tag_by_cell_umi_csv = snakemake.input.viral_tag_by_cell_umi_csv

# outputs: CSV of per-cell viral tag counts, and the summary plot
plot = snakemake.output.plot
viral_tag_by_cell_csv = snakemake.output.viral_tag_by_cell_csv

Read data frame of viral tags by cell and UMI:

In [None]:
# Load the viral tags by cell and UMI; the cells below use the `gene`,
# `cell_barcode`, `tag_variant`, and `UMI` columns of this data frame.
viral_tag_by_cell_umi = pd.read_csv(viral_tag_by_cell_umi_csv)

Create data frame where we count UMIs for each viral tag by cell (and gene), and then annotate as `not_valid` any tags that are invalid, ambiguous, or disagree within the UMI:

In [None]:
# Values of `tag_variant` that are flagged as not valid.
not_valid_tags = ['ambiguous', 'invalid', 'tags_disagree']

# Every other value seen in the data is treated as a valid tag, kept in order
# of first appearance. Missing values are skipped: the `groupby` below drops
# rows with a missing key anyway, and `pd.Categorical` raises a `ValueError`
# if a null value is passed as a category.
observed_tags = viral_tag_by_cell_umi['tag_variant'].dropna().unique()
valid_tags = [tag for tag in observed_tags if tag not in not_valid_tags]

viral_tag_by_cell = (
    viral_tag_by_cell_umi
    # one row per gene / cell barcode / tag, with the number of UMIs
    .groupby(['gene', 'cell_barcode', 'tag_variant'], as_index=False)
    .aggregate(count=pd.NamedAgg('UMI', 'count'))
    # flag the not-valid tags, and make `tag_variant` an ordered categorical
    # with the valid tags first and the not-valid ones last
    .assign(not_valid=lambda x: x['tag_variant'].isin(not_valid_tags),
            tag_variant=lambda x: pd.Categorical(x['tag_variant'],
                                                 valid_tags + not_valid_tags,
                                                 ordered=True)
            )
    )

Write output CSV file with the viral tag counts per cell, dropping invalid tags:

In [None]:
print(f"Writing per-cell viral tag counts to {viral_tag_by_cell_csv}")

# keep only the rows for valid tags; the flag column itself is not written
is_valid = ~viral_tag_by_cell['not_valid']
valid_tag_counts = viral_tag_by_cell.loc[is_valid].drop(columns='not_valid')

valid_tag_counts.to_csv(viral_tag_by_cell_csv,
                        compression='gzip',
                        index=False)

Make summary plots:

In [None]:
# two stacked panels: UMI counts on top, per-gene fractions below
fig, axes = plt.subplots(nrows=2,
                         figsize=(15, 7))

fig.suptitle(f"viral tags in 10x transcriptomics for experiment {expt}")

# number of UMIs with each viral tag for each gene (rows are genes, columns
# are the tag variants, which include the not-valid categories when present)
tot_umi_counts = (
    viral_tag_by_cell
    .groupby(['gene', 'tag_variant'], as_index=False, observed=True)
    .aggregate(UMIs=pd.NamedAgg('count', 'sum'))
    # `observed=True` matches the `groupby` above and avoids relying on the
    # default handling of unobserved categories of `tag_variant`
    .pivot_table(index='gene',
                 columns='tag_variant',
                 observed=True)
    .fillna(0)
    )
# columns are a (value, tag_variant) MultiIndex; keep only the tag level
tot_umi_counts.columns = tot_umi_counts.columns.get_level_values(1)
tot_umi_counts.columns.name = None

# data for each panel: raw counts, and counts divided by each gene's total
panels = [
    ('number', tot_umi_counts),
    ('fraction', tot_umi_counts.div(tot_umi_counts.sum(axis=1), axis=0)),
    ]

for ax, (ytype, df) in zip(axes, panels):
    _ = df.plot(kind='bar',
                ax=ax,
                width=0.8,
                ).legend(loc='center left', bbox_to_anchor=(1, 0.5))
    ax.set_ylabel(f"{ytype} UMIs with viral tag")
    # leave headroom above the tallest bar for the rotated value labels
    ymax = df.max().max()
    ax.set_ylim(0, 1.4 * ymax)
    # label each bar with its height, just above the bar and in its color
    for p in ax.patches:
        ax.annotate(f"{p.get_height():.2g}",
                    (p.get_x() + 0.02, p.get_height() + 0.05 * ymax),
                    rotation=90,
                    color=p.get_facecolor())

# save plot
fig.tight_layout()
print(f"Saving plot to {plot}")
fig.savefig(plot)