# Count viral barcodes from aligned Illumina 10X reads
This Python Jupyter notebook counts the viral barcodes from aligned Illumina 10X data and outputs the counts of each viral barcode for each cell barcode and barcoded gene into a CSV.

## Parameters for notebook
First, set the parameters for the notebook.
Set them in the cell below tagged `parameters`, which enables [papermill parameterization](https://papermill.readthedocs.io/en/latest/usage-parameterize.html):

In [None]:
# parameters cell; in order for notebook to run this cell must define:
#  - input_fastq10x_bam: BAM file with aligned FASTQ 10X reads
#  - input_fastq10x_bai: BAM index file for `input_fastq10x_bam`
#  - input_viralbc_locs: CSV file giving the location of the viral barcodes
#  - input_cellbarcodes: TSV file giving valid cell barcodes
#  - output_viralbc_counts: created CSV file with the counts of each viral barcode for each cell barcode and gene

In [None]:
# Parameters
input_fastq10x_bam = "results/aligned_fastq10x/wt_virus_pilot/Aligned.sortedByCoord.out.bam"
input_fastq10x_bai = "results/aligned_fastq10x/wt_virus_pilot/Aligned.sortedByCoord.out.bam.bai"
input_viralbc_locs = "results/viral_fastq10x/viralbc_locs.csv"
input_cellbarcodes = "results/aligned_fastq10x/wt_virus_pilot/Solo.out/Gene/filtered/barcodes.tsv"
output_viralbc_counts = "results/viral_fastq10x/wt_virus_pilot_viralbc_counts.csv"

## Import Python modules

In [None]:
import pandas as pd

from plotnine import *

from pymodules.tags_and_barcodes import extract_tags

import pysam

## Read viral barcode locations

Read the viral barcode locations:

In [None]:
print(f"Reading viral barcode locations from {input_viralbc_locs}")
viralbc_locs_df = pd.read_csv(input_viralbc_locs)
viralbc_locs_df
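
The counting loop further down passes `start - 1` and `end` to the read fetcher, which suggests the `start` and `end` columns are 1-based and inclusive. As a minimal sketch under that assumption (the helper name `to_zero_based_half_open` and the toy sequence are hypothetical and not part of this notebook), the conversion to the 0-based, half-open intervals used by Python slicing looks like this:

```python
def to_zero_based_half_open(start, end):
    """Convert a 1-based inclusive interval to a 0-based half-open one."""
    # Only the start shifts; an inclusive 1-based end equals an exclusive 0-based end.
    return start - 1, end


toy_seq = "ACGTACGTAC"      # hypothetical reference sequence
start, end = 3, 6           # 1-based inclusive: the third through sixth nucleotides
start0, end0 = to_zero_based_half_open(start, end)
toy_barcode = toy_seq[start0:end0]
print(toy_barcode)
```

The interval length is unchanged by the conversion: `end - start + 1` in 1-based inclusive coordinates equals `end0 - start0` in 0-based half-open coordinates.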

Get names of the barcoded viral genes:

In [None]:
bc_viral_genes = viralbc_locs_df['gene'].unique()

assert len(bc_viral_genes) == len(viralbc_locs_df), 'currently only support one barcode per gene'

## Get set of valid cell barcodes

In [None]:
print(f"Reading valid cell barcodes from {input_cellbarcodes}")

cellbarcodes = set(pd.read_csv(input_cellbarcodes, header=None)[0])

print(f"Read {len(cellbarcodes)} valid barcodes.")
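
As a small illustration of the cell above, the following sketch reads a toy barcode list from an in-memory string rather than a file. The barcode sequences and the name `toy_cellbarcodes` are made up for this example. It shows why `header=None` matters (the first barcode would otherwise be consumed as a column name) and that converting to a `set` drops duplicates and gives fast membership checks:

```python
import io

import pandas as pd

# Hypothetical file contents: one barcode per line, no header, one duplicate.
toy_tsv = io.StringIO("AAACCCAAGAAACACT\nAAACCCAAGAAACCAT\nAAACCCAAGAAACACT\n")

# header=None keeps the first line as data; column 0 holds the barcodes.
toy_cellbarcodes = set(pd.read_csv(toy_tsv, header=None)[0])

print(len(toy_cellbarcodes))
print("AAACCCAAGAAACACT" in toy_cellbarcodes)
```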

## Count viral barcodes
For each cell barcode and each viral gene, we count the number of unique UMIs supporting each viral barcode.

For each barcoded gene, we parse the viral barcode sequence from every read that covers the barcode location.
The reads are grouped by cell barcode and UMI, and each group is assigned a single consensus viral barcode.
A group's viral barcode is labeled `ambiguous` if any site in the barcode lacks a majority (>50%) consensus nucleotide.
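
The consensus call itself happens inside `extract_tags`, whose implementation is not shown in this notebook. The following is only a minimal sketch of how a per-site majority-consensus rule of this kind could work; the function `consensus_barcode` and its handling of edge cases (empty input, unequal lengths) are assumptions for illustration and may differ from what `extract_tags` actually does:

```python
from collections import Counter


def consensus_barcode(barcodes):
    """Call a consensus barcode from reads sharing a cell barcode and UMI.

    Hypothetical helper: returns 'ambiguous' unless every site has a
    nucleotide present in strictly more than 50% of the reads.
    """
    # Assumption: no reads, or reads of unequal length, count as ambiguous.
    if not barcodes or len({len(bc) for bc in barcodes}) != 1:
        return "ambiguous"
    consensus = []
    for site in zip(*barcodes):  # iterate over barcode positions
        nucleotide, n = Counter(site).most_common(1)[0]
        if n / len(site) <= 0.5:  # no strict majority at this site
            return "ambiguous"
        consensus.append(nucleotide)
    return "".join(consensus)


majority = consensus_barcode(["ACGT", "ACGT", "ACGA"])  # last site: T in 2 of 3 reads
tie = consensus_barcode(["ACGT", "ACGA"])               # last site: 1 to 1 tie
print(majority, tie)
```

Under this rule a two-read group that disagrees at any site is always ambiguous, because a 1 to 1 tie is exactly 50% and does not exceed the threshold.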
    
The output of this process is the tidy data frame `viralbc_counts`:

In [None]:
print(f"Parsing reads from {input_fastq10x_bam} (index {input_fastq10x_bai}):\n")

with pysam.AlignmentFile(input_fastq10x_bam, index_filename=input_fastq10x_bai) as bamfile:

    # collect per-gene counts in a list and concatenate once at the end
    # (`DataFrame.append` was removed in pandas 2.0)
    viralbc_counts_list = []

    assert len(viralbc_locs_df) == viralbc_locs_df['gene'].nunique()
    for tup in viralbc_locs_df.itertuples():
        print(f"Processing viral barcodes for {tup.gene}...", end=' ')

        readiterator = bamfile.fetch(contig=tup.gene,
                                     start=tup.start - 1,  # convert 1- to 0-based indexing
                                     end=tup.end,
                                     )
        gene_counts_df = (
                extract_tags(readiterator, cellbarcodes, tup.start - 1, tup.end)
                .rename(columns={'tag': 'viral_barcode'})
                [['cell_barcode', 'UMI', 'viral_barcode']]
                )
        print(f"parsed viral barcodes for {len(gene_counts_df)} UMIs.")

        # aggregate viral barcode counts by gene and cell barcode
        viralbc_counts_list.append(
            gene_counts_df
            .groupby(['cell_barcode', 'viral_barcode'])
            .aggregate(count=pd.NamedAgg('UMI', 'count'))
            .reset_index()
            .assign(gene=tup.gene,
                    is_ambiguous=lambda x: x['viral_barcode'] == 'ambiguous')
            [['gene', 'cell_barcode', 'viral_barcode', 'count', 'is_ambiguous']]
            )

# combine the per-gene counts into one tidy data frame
viralbc_counts = pd.concat(viralbc_counts_list, ignore_index=True, sort=False)

The results are now in the data frame `viralbc_counts`:

In [None]:
viralbc_counts
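
To make the aggregation step in the loop above concrete, here is a sketch on a toy table with one row per UMI. The cell barcodes, UMIs, viral barcodes, and the gene name `toy_gene` are invented for illustration. Counting UMIs within each (cell barcode, viral barcode) group gives the `count` column, and `is_ambiguous` flags the rows whose viral barcode could not be called:

```python
import pandas as pd

# Hypothetical per-UMI calls for a single gene.
toy_umis = pd.DataFrame({
    "cell_barcode": ["AAAC", "AAAC", "AAAC", "TTTG"],
    "UMI": ["U1", "U2", "U3", "U1"],
    "viral_barcode": ["ACGT", "ACGT", "ambiguous", "GGCA"],
})

toy_counts = (
    toy_umis
    .groupby(["cell_barcode", "viral_barcode"])
    .aggregate(count=pd.NamedAgg("UMI", "count"))  # number of UMIs per group
    .reset_index()
    .assign(gene="toy_gene",
            is_ambiguous=lambda x: x["viral_barcode"] == "ambiguous")
    [["gene", "cell_barcode", "viral_barcode", "count", "is_ambiguous"]]
)

print(toy_counts)
```

Cell `AAAC` ends up with two UMIs for viral barcode `ACGT` and one ambiguous UMI, while cell `TTTG` has a single UMI for `GGCA`.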

## Number of called viral barcodes
Tabulate and plot the total number of UMIs with called viral barcodes for each gene, indicating the ones that are ambiguous as well.

First tabulate:

In [None]:
summary_df = (viralbc_counts
              .groupby(['gene', 'is_ambiguous'])
              .aggregate({'count': 'sum'})
              .reset_index()
              )

summary_df

Now plot:

In [None]:
p = (ggplot(summary_df, aes('gene', 'count', fill='is_ambiguous')) +
     geom_bar(stat='identity', position='dodge') +
     theme(figure_size=(1 * len(bc_viral_genes), 2.2),
           axis_text_x=element_text(angle=90)) +
     ylab('UMIs with viral barcode') +
     scale_fill_manual(values=('#E69F00', '#56B4E9'))
     )

_ = p.draw()

## Distribution of UMIs per viral barcode
Plot the number of UMIs per viral barcode (aggregated over all cell barcodes) in a [knee plot](https://liorpachter.wordpress.com/tag/knee-plot), excluding ambiguous barcodes:

In [None]:
n_umis = (
    viralbc_counts
    .query('is_ambiguous == False')
    .groupby(['gene', 'viral_barcode'])
    .aggregate({'count': 'sum'})
    .reset_index()
    .sort_values('count', ascending=False)
    .assign(rank=lambda x: x.groupby('gene').cumcount() + 1)
    )

p = (ggplot(n_umis, aes('rank', 'count')) +
     geom_path() +
     facet_wrap('~ gene', nrow=1) +
     theme(figure_size=(3 * len(bc_viral_genes), 2.5)) +
     scale_x_log10(name='viral barcode rank') +
     scale_y_log10(name='number UMIs (across all cells)')
     )

_ = p.draw()
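
The x-axis of the knee plot is a within-gene rank. As a sketch of how that rank comes out of the `sort_values` and `cumcount` calls used above, consider a toy table (the gene names, barcodes, and counts are hypothetical). Sorting all rows by descending count first means that `cumcount` within each gene numbers the barcodes from most to least abundant:

```python
import pandas as pd

# Hypothetical total UMI counts per viral barcode for two genes.
toy_n_umis = pd.DataFrame({
    "gene": ["geneA", "geneA", "geneB", "geneA", "geneB"],
    "viral_barcode": ["AAAA", "CCCC", "GGGG", "TTTT", "ACAC"],
    "count": [5, 50, 7, 20, 70],
})

toy_ranked = (
    toy_n_umis
    .sort_values("count", ascending=False)
    # cumcount is 0-based and follows the current row order within each gene
    .assign(rank=lambda x: x.groupby("gene").cumcount() + 1)
)

toy_rank_by_barcode = toy_ranked.set_index("viral_barcode")["rank"].to_dict()
print(toy_ranked)
```

Each gene gets its own ranking starting at 1, so the facets of the knee plot are comparable even when the genes have different numbers of barcodes.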

## Write viral barcodes to file
We write a CSV file giving the count of each viral barcode for each gene and cell, excluding ambiguous viral barcodes:

In [None]:
print(f"Writing counts to {output_viralbc_counts}")

(viralbc_counts
 .query('is_ambiguous == False')
 [['gene', 'cell_barcode', 'viral_barcode', 'count']]
 .to_csv(output_viralbc_counts, index=False)
 )