# Extract viral tags from aligned 10x transcriptomic data
This Python Jupyter notebook uses the aligned 10x transcriptomic data to tally viral tags for each 10x cell barcode and UMI.
It does this only for the **valid** cell barcodes, and uses the error-corrected cell barcodes and UMIs reported in the BAM file created by `STARsolo`.

Import Python modules:

In [None]:
import pandas as pd

from pymodules.tags_and_barcodes import extract_tags

import pysam

import yaml

Get `snakemake` variables [as described here](https://snakemake.readthedocs.io/en/stable/snakefiles/rules.html#jupyter-notebook-integration):

In [None]:
# Pull the file paths and wildcards that snakemake injects into this notebook.
bam, bai = snakemake.input.bam, snakemake.input.bai  # aligned reads and their index file
cell_barcodes = snakemake.input.cell_barcodes
viral_tag_locs, viral_tag_identities = (snakemake.input.viral_tag_locs,
                                        snakemake.input.viral_tag_identities)
viral_tag_by_cell_umi_csv = snakemake.output.viral_tag_by_cell_umi_csv
expt = snakemake.wildcards.expt

Read expected identities for viral tags:

In [None]:
print(f"Reading expected tag identities from {viral_tag_identities}")
with open(viral_tag_identities) as f:
    tag_identities_dict = yaml.safe_load(f)

# Flatten the nested {gene: {tag: {tag_variant: nucleotide}}} mapping into one
# row per (gene, tag, tag_variant), then pivot so that each (gene, tag) row
# lists the nucleotide expected for every tag variant.
tag_identities_df = (
    pd.DataFrame.from_records(
        [(gene, tag, tagvariant, nt)
         for gene, genetags in tag_identities_dict.items()
         for tag, tagvals in genetags.items()
         for tagvariant, nt in tagvals.items()],
        columns=['gene', 'tag', 'tag_variant', 'nucleotide'])
    .pivot_table(index=['gene', 'tag'],
                  columns='tag_variant',
                  values='nucleotide',
                  aggfunc='sum')
    )

print('Here are the tag identities:')
display(tag_identities_df)

# Every (gene, tag) must define a nucleotide for every tag variant. A missing
# combination shows up as a null cell in the pivoted table, so *all* cells
# must be non-null (checking `.any` would only catch a completely empty table).
if not tag_identities_df.notnull().all(axis=None):
    raise ValueError('identities missing for some tags')

tag_variants = tag_identities_df.columns.tolist()
print(f"There are {len(tag_variants)} tag variants:\n  " + '\n  '.join(tag_variants))

# Tag calls are made by inverting each tag's {tag_variant: nucleotide} mapping,
# so two variants sharing a nucleotide at the same tag would silently collapse
# into a single variant.
if (tag_identities_df.nunique(axis=1) != len(tag_variants)).any():
    raise ValueError('tag variants must have distinct nucleotides at each tag')

# These labels are assigned by the tag-parsing step itself, so no real tag
# variant may use them (compared case-insensitively).
reserved_labels = {'ambiguous', 'invalid', 'tags_disagree'}
clashing_labels = reserved_labels & {tag.lower() for tag in tag_variants}
if clashing_labels:
    raise ValueError(f"cannot have tag variant(s) {sorted(clashing_labels)}")

Read the viral tag locations:

In [None]:
print(f"Reading viral tag locations from {viral_tag_locs}")
# One row per tag location; later cells use its gene, tag_name, start and end columns.
tag_locs_df = pd.read_csv(filepath_or_buffer=viral_tag_locs)
tag_locs_df  # show the table as the cell output

Make sure we have locations and tag identities for the same genes, and that for each gene we have the same set of tags:

In [None]:
tagged_genes = tag_locs_df['gene'].unique()

# The locations file and the identities file must describe exactly the same
# genes. Raise (rather than assert, which `python -O` strips) and report which
# genes are missing from which file.
missing_identities = set(tagged_genes) - set(tag_identities_dict)
missing_locations = set(tag_identities_dict) - set(tagged_genes)
if missing_identities or missing_locations:
    raise ValueError(
        f"inconsistent genes: no tag identities for {sorted(missing_identities, key=str)}, "
        f"no tag locations for {sorted(missing_locations, key=str)}")

# For each gene, both files must define the same set of tag names.
for gene in tagged_genes:
    loc_tags = set(tag_locs_df.query('gene == @gene')['tag_name'])
    identity_tags = set(tag_identities_dict[gene])
    if loc_tags != identity_tags:
        raise ValueError(f"inconsistent tags for {gene}: "
                         f"locations have {sorted(loc_tags, key=str)}, "
                         f"identities have {sorted(identity_tags, key=str)}")

Get the set of valid cell barcodes:

In [None]:
print(f"Reading valid cell barcodes from {cell_barcodes}")
# The file has no header row; the barcodes are taken from its first column.
cell_barcode_set = set(pd.read_csv(cell_barcodes, header=None).iloc[:, 0])
print(f"Read {len(cell_barcode_set)} valid barcodes.")

Now we get the viral tags.
Specifically, parse the BAM file, and for each read mapping to a viral gene with a valid cell barcode and UMI, we see if we can determine the viral tag identity.
The basic process is as follows:
 1. For each viral tag location, parse the tag identity for all reads for valid cell barcodes that cover that tag.
    The reads are grouped by UMI and cell barcode, and the tag is labeled as `ambiguous` if no single tag identity accounts for more than 50% of the reads for a UMI in a cell.
 2. The tags are then assigned to their tag variants, labeling the tag variant as `invalid` if the identity at the tag doesn't belong to an expected tag variant.
 3. Unify the tag assignments across tag locations within each gene for UMIs that have several tag locations covered.
    If all of the tag locations agree on the tag variant, then that's what we assign that UMI.
    If the tag locations disagree, label as `tags_disagree`.

In [None]:
print(f"Parsing viral tags from {bam} (index {bai}):\n")

# Column order of the per-location tag calls.
tag_call_columns = ['gene', 'cell_barcode', 'UMI', 'tag_name', 'tag_variant']

# Collect one DataFrame per tag location and concatenate once at the end:
# `DataFrame.append` was removed in pandas 2.0, and appending inside the loop
# also re-copies the accumulated frame on every iteration.
tag_call_frames = []

with pysam.AlignmentFile(bam, index_filename=bai) as bamfile:
    for gene in tagged_genes:
        print(f"Processing viral tags for {gene}")
        # parse each tag location for the gene
        for tup in tag_locs_df.query('gene == @gene').itertuples():
            # The identities file maps tag_variant -> nucleotide for each tag;
            # invert it so the identity observed at the tag can be looked up.
            nt_to_variant = {nt: variant for variant, nt in
                             tag_identities_dict[gene][tup.tag_name].items()}
            nt_to_variant['ambiguous'] = 'ambiguous'
            readiterator = bamfile.fetch(contig=tup.gene,
                                         start=tup.start - 1,  # convert 1- to 0-based indexing
                                         end=tup.end,
                                         )
            # Identities not in the mapping are labeled 'invalid'.
            gene_loc_tags_by_umi = (
                extract_tags(readiterator, cell_barcode_set, tup.start - 1, tup.end)
                .assign(tag_variant=lambda x: (x['tag']
                                               .map(nt_to_variant)
                                               .fillna('invalid')),
                        tag_name=tup.tag_name,
                        gene=gene,
                        )
                [tag_call_columns]
                )
            print(f"\tParsed tag {tup.tag_name} for {len(gene_loc_tags_by_umi)} UMIs.")
            tag_call_frames.append(gene_loc_tags_by_umi)

tags_by_umi = (pd.concat(tag_call_frames, ignore_index=True) if tag_call_frames
               else pd.DataFrame({}, columns=tag_call_columns))

# Each gene / cell / UMI / tag must have exactly one tag-variant call.
if tags_by_umi.duplicated(subset=['gene', 'cell_barcode', 'UMI', 'tag_name']).any():
    raise ValueError('not unique tag variant call for each gene / cell / UMI / tag')

# Unify tag assignments across tag locations for each gene / cell / UMI.
# If every called location agrees on one tag_variant, keep it; otherwise label
# the UMI 'tags_disagree'. 'ambiguous' and 'invalid' count as distinct calls here.
tags_by_umi = (
    tags_by_umi
    .groupby(['gene', 'cell_barcode', 'UMI'], as_index=False)
    .aggregate(tag_variant=pd.NamedAgg('tag_variant', 'first'),
               n_tag_locs_called=pd.NamedAgg('tag_variant', 'count'),
               n_tag_variants=pd.NamedAgg('tag_variant', 'nunique'),
               )
    .assign(tag_variant=lambda x: x['tag_variant'].where(x['n_tag_variants'] == 1,
                                                         'tags_disagree'),
            )
    .drop(columns='n_tag_variants')
    )

Write the viral tags to the output CSV file:

In [None]:
print(f"Writing viral tags to {viral_tag_by_cell_umi_csv}")

# Omit the row index. Output is gzip-compressed explicitly, whatever the file extension.
tags_by_umi.to_csv(path_or_buf=viral_tag_by_cell_umi_csv, compression='gzip', index=False)