# Count viral tags on aligned Illumina 10X reads
This Python Jupyter notebook counts the viral tags in aligned Illumina 10X reads and writes the counts of each viral tag variant for each cell barcode and gene to a CSV file.

## Parameters for notebook
First, set the parameters for the notebook in the next cell, which is tagged as a `parameters` cell to enable [papermill parameterization](https://papermill.readthedocs.io/en/latest/usage-parameterize.html):

In [None]:
# parameters cell; for the notebook to run, this cell must define:
#  - input_fastq10x_bam: BAM file with aligned FASTQ 10X reads
#  - input_fastq10x_bai: BAM index file for `input_fastq10x_bam`
#  - input_viraltag_locs: CSV file giving the location of the viral tags
#  - input_viraltag_identities: YAML file giving the expected identity of each tag for each tag variant
#  - input_cellbarcodes: TSV file giving the valid cell barcodes
#  - output_viraltag_counts: created CSV file with the counts of each tag variant for each gene and cell barcode

## Import Python modules

In [None]:
import pandas as pd

from plotnine import *

from pymodules.tags_and_barcodes import extract_tags

import pysam

import yaml

## Read tags and expected identities

In [None]:
print(f"Reading expected tag identities from {input_viraltag_identities}")
with open(input_viraltag_identities) as f:
    viraltag_identities_dict = yaml.safe_load(f)

viraltag_identities_df = (
    pd.DataFrame.from_records(
        [(gene, tag, tagvariant, nt)
         for gene, genetags in viraltag_identities_dict.items()
         for tag, tagvals in genetags.items()
         for tagvariant, nt in tagvals.items()],
        columns=['gene', 'tag', 'tag_variant', 'nucleotide'])
    .pivot_table(index=['gene', 'tag'],
                  columns='tag_variant',
                  values='nucleotide',
                  aggfunc='sum')
    )

print('Here are the tag identities:')
viraltag_identities_df
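
As a minimal illustration of the reshaping above, here is the same tidy-then-pivot step on a small hand-written dictionary with the nesting the code assumes (gene, then tag, then tag variant, then nucleotide). The gene, tag, and variant names are hypothetical placeholders, not values from the real YAML file. This sketch uses `DataFrame.pivot` rather than `pivot_table`, since each cell holds exactly one nucleotide string and `pivot` raises an error on duplicate entries instead of aggregating them:

```python
import pandas as pd

# hypothetical identities with the same nesting as the YAML file:
# gene -> tag -> tag_variant -> nucleotide
identities = {
    'geneA': {'tag_1': {'wt': 'A', 'syn': 'G'},
              'tag_2': {'wt': 'C', 'syn': 'T'}},
}

# flatten to one row per (gene, tag, tag_variant)
tidy = pd.DataFrame.from_records(
    [(gene, tag, variant, nt)
     for gene, tags in identities.items()
     for tag, variants in tags.items()
     for variant, nt in variants.items()],
    columns=['gene', 'tag', 'tag_variant', 'nucleotide'])

# one row per (gene, tag), one column per tag variant
wide = tidy.pivot(index=['gene', 'tag'], columns='tag_variant', values='nucleotide')
print(wide)
```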

Get tag variants and make sure an identity is defined for each tag variant at each tag:

In [None]:
tag_variants = viraltag_identities_df.columns.tolist()
print(f"Here are the {len(tag_variants)} tag variants: {', '.join(tag_variants)}")

assert viraltag_identities_df.notnull().all(axis=None), 'identities missing for some tags'
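
The check needs every (gene, tag) row to have an identity for every tag variant. On a toy table with one missing entry (hypothetical tag and variant names), `notnull().all(axis=None)` catches the gap, while `notnull().any(axis=None)` only confirms that at least one identity is present:

```python
import numpy as np
import pandas as pd

# toy identity table with one missing entry (tag_2 has no 'syn' identity)
df = pd.DataFrame({'syn': ['G', np.nan], 'wt': ['A', 'C']},
                  index=['tag_1', 'tag_2'])

has_any = df.notnull().any(axis=None)  # True: at least one identity present
has_all = df.notnull().all(axis=None)  # False: one identity is missing
print(has_any, has_all)
```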

Read the viral tag locations:

In [None]:
print(f"Reading viral tag locations from {input_viraltag_locs}")
viraltag_locs_df = pd.read_csv(input_viraltag_locs)
viraltag_locs_df
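
The counting step below passes `start - 1` and `end` to pysam, which implies that the `start` and `end` columns here are 1-based and inclusive, whereas pysam regions are 0-based and half-open. A tiny sketch of that conversion under this assumption (the helper name is hypothetical):

```python
def to_zero_based_half_open(start, end):
    """Convert a 1-based inclusive interval [start, end] to a 0-based
    half-open interval [start - 1, end), as used by pysam regions."""
    return start - 1, end


# a tag spanning 1-based positions 10..12 (three nucleotides)
zstart, zend = to_zero_based_half_open(10, 12)
print(zstart, zend, list(range(zstart, zend)))  # 9 12 [9, 10, 11]
```

The interval length is unchanged (`zend - zstart == 12 - 10 + 1`), which is a quick way to confirm the conversion is off by neither one nor two.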

Make sure we have locations and tag identities for the same genes:

In [None]:
tagged_viral_genes = viraltag_locs_df['gene'].unique()
assert set(tagged_viral_genes) == set(viraltag_identities_dict), 'locations and identities are for different genes'

And that for each viral gene we have the same set of tags:

In [None]:
for gene in tagged_viral_genes:
    if (set(viraltag_locs_df.query('gene == @gene')['tag_name']) != set(viraltag_identities_dict[gene])):
        raise ValueError(f"inconsistent tags for {gene}")
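
The two checks above only report *that* something is inconsistent. A minimal sketch of a variant that also reports *what* is missing on each side, using plain set differences (the function and the toy tag names are hypothetical):

```python
def describe_mismatch(from_locs, from_identities):
    """Return a message listing items present in only one of two collections,
    or None if both contain the same items."""
    locs, idents = set(from_locs), set(from_identities)
    if locs == idents:
        return None
    return (f"only in locations: {sorted(locs - idents)}; "
            f"only in identities: {sorted(idents - locs)}")


print(describe_mismatch(['tag_1', 'tag_2'], {'tag_1': {}, 'tag_3': {}}))
# only in locations: ['tag_2']; only in identities: ['tag_3']
```

Inside the loop above it could be called as `describe_mismatch(viraltag_locs_df.query('gene == @gene')['tag_name'], viraltag_identities_dict[gene])`, with the returned message passed to the `ValueError`.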

## Get set of valid cell barcodes

In [None]:
print(f"Reading valid cell barcodes from {input_cellbarcodes}")

cellbarcodes = set(pd.read_csv(input_cellbarcodes, sep='\t', header=None)[0])

print(f"Read {len(cellbarcodes)} valid barcodes.")

## Count viral tags
For each cell barcode and each viral gene, we count the number of unique UMIs (that is, unique molecules rather than individual reads) assigned to each viral tag variant.

The basic process is as follows:
 1. For each viral tag, parse the tag identity for all reads that cover that tag.
    The reads are grouped by UMI and cell barcode, and the tag is labeled `ambiguous` if no single tag identity accounts for more than 50% of the reads for that UMI in that cell.
 2. Assign each tag identity to its tag variant, labeling the tag variant `invalid` if the identity at the tag doesn't match any expected tag variant.
 3. Unify the tag variant assignments across tags for UMIs that cover several tags.
    If all of the tags agree on the tag variant, the UMI is assigned that tag variant.
    If the tags disagree, the UMI is labeled `tags disagree`.
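
The three steps above can be sketched in plain Python on a handful of already-parsed reads. This illustrates the logic under stated assumptions and is not the notebook's `extract_tags` implementation: it assumes each read has already been reduced to a `(cell_barcode, UMI, tag_name, nucleotide)` tuple, and the barcodes, UMIs, tag names, and variant names are hypothetical.

```python
from collections import Counter, defaultdict


def call_tag(nucleotides):
    """Majority call for one tag in one UMI: return the most common identity
    if it is carried by more than 50% of the reads, else 'ambiguous'."""
    nt, n = Counter(nucleotides).most_common(1)[0]
    return nt if n > len(nucleotides) / 2 else 'ambiguous'


def count_tag_variants(reads, identities):
    """Count UMIs per (cell_barcode, tag_variant) for one gene.

    reads: iterable of (cell_barcode, UMI, tag_name, nucleotide) tuples
    identities: {tag_name: {tag_variant: nucleotide}}
    """
    # step 1: group the reads by cell barcode, UMI, and tag
    grouped = defaultdict(list)
    for cell, umi, tag, nt in reads:
        grouped[(cell, umi, tag)].append(nt)

    # steps 1-2: call each tag by majority vote, then translate the call
    # into a tag variant ('invalid' if the identity is not expected)
    umi_variants = defaultdict(set)
    for (cell, umi, tag), nts in grouped.items():
        nt_to_variant = {nt: variant for variant, nt in identities[tag].items()}
        nt_to_variant['ambiguous'] = 'ambiguous'
        umi_variants[(cell, umi)].add(nt_to_variant.get(call_tag(nts), 'invalid'))

    # step 3: a UMI keeps its tag variant only if all of its tags agree
    counts = Counter()
    for (cell, umi), variants in umi_variants.items():
        variant = next(iter(variants)) if len(variants) == 1 else 'tags disagree'
        counts[(cell, variant)] += 1
    return dict(counts)


identities = {'tag_1': {'wt': 'A', 'syn': 'G'},
              'tag_2': {'wt': 'C', 'syn': 'T'}}
reads = [
    ('AAAC', 'u1', 'tag_1', 'A'), ('AAAC', 'u1', 'tag_1', 'A'),
    ('AAAC', 'u1', 'tag_1', 'G'),   # 2 of 3 reads are 'A' -> wt
    ('AAAC', 'u1', 'tag_2', 'C'),   # 'C' -> wt, agrees with tag_1
    ('AAAC', 'u2', 'tag_1', 'G'),   # 'G' -> syn ...
    ('AAAC', 'u2', 'tag_2', 'C'),   # ... but 'C' -> wt, so the tags disagree
    ('TTTG', 'u3', 'tag_1', 'A'), ('TTTG', 'u3', 'tag_1', 'G'),  # 50/50 -> ambiguous
    ('TTTG', 'u4', 'tag_2', 'N'),   # 'N' is not an expected identity -> invalid
]
print(count_tag_variants(reads, identities))
```

As in the notebook, an `ambiguous` or `invalid` label on one tag counts as a disagreement if another tag of the same UMI gets a different label.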
    
The output of this process is the tidy data frame `tag_variant_counts`:

In [None]:
print(f"Parsing reads from {input_fastq10x_bam} (index {input_fastq10x_bai}):\n")

gene_counts = []  # one data frame of counts per gene, concatenated at the end

with pysam.AlignmentFile(input_fastq10x_bam, index_filename=input_fastq10x_bai) as bamfile:

    for gene in tagged_viral_genes:
        print(f"Processing tags for {gene}")

        # parse each tag for the gene
        tag_dfs = []
        for tup in viraltag_locs_df.query('gene == @gene').itertuples():
            # map the nucleotide identity at this tag to its tag variant
            nt_to_variant = {nt: variant for variant, nt in
                             viraltag_identities_dict[gene][tup.tag_name].items()}
            nt_to_variant['ambiguous'] = 'ambiguous'
            readiterator = bamfile.fetch(contig=tup.gene,
                                         start=tup.start - 1,  # convert 1- to 0-based indexing
                                         end=tup.end,
                                         )
            tag_counts_df = (
                extract_tags(readiterator, cellbarcodes, tup.start - 1, tup.end)
                .assign(tag_variant=lambda x: (x['tag']
                                               .map(nt_to_variant)
                                               .fillna('invalid'))
                        )
                [['cell_barcode', 'UMI', 'tag_variant']]
                )
            print(f"\tParsed tag {tup.tag_name} for {len(tag_counts_df)} UMIs.")
            tag_dfs.append(tag_counts_df)
        counts_df = pd.concat(tag_dfs, ignore_index=True)

        # unify tag variant assignments across tags for gene: a UMI keeps its
        # tag variant only if all tags it covers agree, else 'tags disagree'
        gene_counts.append(
            counts_df
            .groupby(['cell_barcode', 'UMI'])
            .aggregate(tag_variant=pd.NamedAgg('tag_variant', 'first'),
                       n_tag_variants=pd.NamedAgg('tag_variant', 'nunique'))
            .reset_index()
            .assign(tag_variant=lambda x: x['tag_variant'].where(x['n_tag_variants'] == 1,
                                                                 'tags disagree'))
            .groupby(['cell_barcode', 'tag_variant'])
            .size()  # number of UMIs per cell barcode and tag variant
            .reset_index(name='count')
            .assign(gene=gene)
            [['gene', 'cell_barcode', 'tag_variant', 'count']]
            )

tag_variant_counts = pd.concat(gene_counts, ignore_index=True, sort=False)

The results are now in the data frame `tag_variant_counts`:

In [None]:
tag_variant_counts

## Summarize the viral tag counts
Compute the fraction of UMIs assigned to each tag variant for each gene, pooled across all cell barcodes:

In [None]:
tag_order = tag_variants + sorted(t for t in tag_variant_counts['tag_variant'].unique()
                                  if t not in tag_variants)

total_tag_variant_counts = (
    tag_variant_counts
    .groupby(['gene', 'tag_variant'])
    .aggregate({'count': 'sum'})
    .reset_index()
    .assign(gene=lambda x: pd.Categorical(x['gene'],
                                          tagged_viral_genes,
                                          ordered=True),
            tag_variant=lambda x: pd.Categorical(x['tag_variant'],
                                                 tag_order,
                                                 ordered=True,
                                                 ),
            total=lambda x: x.groupby('gene')['count'].transform('sum'),
            frac=lambda x: x['count'] / x['total'],
            valid_tag=lambda x: x['tag_variant'].isin(tag_variants)
            )
    )

(total_tag_variant_counts
 .pivot_table(index=['gene', 'total'],
              columns='tag_variant',
              values='frac',
              fill_value=0)
 .round(3)
 )
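
The `total` and `frac` columns come from a group-wise `transform`, which returns one value per row rather than one per group, so each count can be divided by its own gene's total. A toy version with made-up genes and counts:

```python
import pandas as pd

# toy per-gene counts (hypothetical genes, variants, and counts)
counts = pd.DataFrame({'gene': ['geneA', 'geneA', 'geneA', 'geneB', 'geneB'],
                       'tag_variant': ['wt', 'syn', 'ambiguous', 'wt', 'syn'],
                       'count': [60, 30, 10, 5, 15]})

summary = counts.assign(
    # total UMIs per gene, broadcast back to every row of that gene
    total=lambda x: x.groupby('gene')['count'].transform('sum'),
    frac=lambda x: x['count'] / x['total'])
print(summary)
```

Within each gene the `frac` values sum to 1, which is what makes the pivoted table above readable as per-gene proportions.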

Plot the number of UMIs with called viral tags for each gene, split by whether the tag variant is valid:

In [None]:
p = (ggplot(total_tag_variant_counts
            .groupby(['gene', 'valid_tag'])
            .aggregate({'count': 'sum'})
            .reset_index(),
            aes('gene', 'count', fill='valid_tag')) +
     geom_bar(stat='identity', position='dodge') +
     theme(figure_size=(0.4 * len(tagged_viral_genes), 2),
           axis_text_x=element_text(angle=90)) +
     ylab('UMIs with viral tag') +
     scale_fill_manual(values=('#E69F00', '#56B4E9'))
     )
_ = p.draw()

Plot the fraction of called tags assigned to each tag variant for each gene:

In [None]:
p = (ggplot(total_tag_variant_counts,
            aes('tag_variant', 'frac', fill='valid_tag')) +
     geom_bar(stat='identity') +
     facet_wrap('~ gene', nrow=1) +
     theme(figure_size=(1.25 * len(tagged_viral_genes), 2),
           axis_text_x=element_text(angle=90)) +
     ylab('fraction with tag') +
     scale_fill_manual(values=('#E69F00', '#56B4E9'))
     )
_ = p.draw()

## Write counts of valid viral tag variants
Finally, we write the counts of the valid viral tag variants to an output CSV file for later use.
Note that we **only** write non-zero counts for valid tag variants and valid cell barcodes:

In [None]:
print(f"Writing counts to {output_viraltag_counts}")

(tag_variant_counts
 .query('tag_variant in @tag_variants')
 [['gene', 'cell_barcode', 'tag_variant', 'count']]
 .to_csv(output_viraltag_counts, index=False)
 )
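
Because only non-zero counts are written, a downstream reader has to restore the zeros. A minimal sketch of loading the file back into a cell-by-variant table, assuming the column layout written above (the inline CSV text is made-up stand-in data, not real output):

```python
import io

import pandas as pd

# stand-in for the output CSV (hypothetical barcodes, genes, and counts)
csv_text = """gene,cell_barcode,tag_variant,count
geneA,AAACCTGA,wt,3
geneA,AAACCTGA,syn,1
geneA,TTTGGTTC,syn,2
"""

counts = pd.read_csv(io.StringIO(csv_text))

# one row per (cell_barcode, gene), one column per tag variant; combinations
# absent from the file were zero counts, so fill them with 0
per_cell = counts.pivot_table(index=['cell_barcode', 'gene'],
                              columns='tag_variant',
                              values='count',
                              aggfunc='sum',
                              fill_value=0)
print(per_cell)
```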