# Coverage of viral genes by Illumina 10X reads
This Python Jupyter notebook examines where the aligned 10X FASTQ reads fall on the viral genes (i.e., their coverage), including over the viral barcodes and viral tags.
It also writes the locations of the viral barcodes and tags within the viral genes to output files.

## Parameters
First, set the parameters for the notebook, such as the input files and output files.
This is done in the next cell, which is tagged as a `parameters` cell to enable [papermill parameterization](https://papermill.readthedocs.io/en/latest/usage-parameterize.html):

In [None]:
# parameters cell; in order for the notebook to run, this cell must define:
#  - samples_10x: list of 10X samples
#  - input_fastq10x_bams: list of BAM files with alignments of 10X reads for each sample
#  - input_fastq10x_bais: BAM indices for each file in `input_fastq10x_bams`
#  - input_viral_genbank: GenBank file with annotated viral genes
#  - output_viraltag_locs: output file with locations of viral tags (1-based numbering)
#  - output_viralbc_locs: output file with locations of viral barcodes (1-based numbering)

Check that the input lists all have the same length, then make dicts that map sample names to BAM and BAM index files:

In [None]:
assert len(samples_10x) == len(input_fastq10x_bams) == len(input_fastq10x_bais)

fastq10x_bams = dict(zip(samples_10x, input_fastq10x_bams))
fastq10x_bais = dict(zip(samples_10x, input_fastq10x_bais))

## Import Python modules
We use [pysam](https://pysam.readthedocs.io/) to process the BAM files, [dna_features_viewer](https://edinburgh-genome-foundry.github.io/DnaFeaturesViewer/) to plot genes, and [plotnine](https://plotnine.readthedocs.io/) for some ggplot2-style plotting:

In [None]:
import collections
import itertools

import Bio.SeqIO

from IPython.display import display, HTML

import mizani.formatters

import pandas as pd

from plotnine import *

from pymodules.plot_viral_genes import plot_genes_and_coverage

import pysam

Color-blind palette:

In [None]:
cbpalette = ('#999999', '#E69F00', '#56B4E9', '#009E73',
             '#F0E442', '#0072B2', '#D55E00', '#CC79A7')

Set [plotnine theme](https://plotnine.readthedocs.io/en/stable/api.html#themes):

In [None]:
_ = theme_set(theme_classic)

## Read the viral genes
Get all the viral genes as [BioPython SeqRecords](https://biopython.org/wiki/SeqRecord):

In [None]:
print(f"Reading viral genes from {input_viral_genbank}")
viral_genes = list(Bio.SeqIO.parse(input_viral_genbank, 'genbank'))
viral_gene_names = [s.id for s in viral_genes]
print(f"Found {len(viral_gene_names)} viral genes:\n\t" +
      '\n\t'.join(viral_gene_names))
assert len(viral_genes) == len(set(viral_gene_names)), 'viral gene names not unique'

## Get viral tag and barcode locations
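Get the locations of the viral tags and barcodes in each gene using 1-based numbering that includes both the first and last site (the numbering that GenBank files use).
Biopython stores feature locations as 0-based, half-open intervals, so we add 1 to each start but leave each end unchanged: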

In [None]:
viraltag_df = []
viralbc_df = []
for seqrecord in viral_genes:
    for f in seqrecord.features:
        if f.type == 'viral_barcode':
            viralbc_df.append([seqrecord.id, int(f.location.start) + 1, int(f.location.end)])
        elif 'tag' in f.type:
            viraltag_df.append((seqrecord.id, f.type, int(f.location.start) + 1, int(f.location.end)))
            
viraltag_df = pd.DataFrame.from_records(viraltag_df,
                                        columns=['gene', 'tag_name', 'start', 'end'])
assert len(viraltag_df) == len(viraltag_df.drop_duplicates())
print(f"\nViral tag locations; writing to {output_viraltag_locs}")
viraltag_df.to_csv(output_viraltag_locs, index=False)
display(HTML(viraltag_df.to_html(index=False)))

viralbc_df = pd.DataFrame.from_records(viralbc_df,
                                       columns=['gene', 'start', 'end'])
assert len(viralbc_df) == len(viralbc_df.drop_duplicates())
print(f"\nViral barcode locations; writing to {output_viralbc_locs}")
viralbc_df.to_csv(output_viralbc_locs, index=False)
display(HTML(viralbc_df.to_html(index=False)))
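
Because Biopython feature locations are 0-based and half-open (`start` inclusive, `end` exclusive), the cell above adds 1 to `start` but keeps `end` as-is. The minimal sketch below illustrates that conversion with plain integers; the helper name `to_genbank_numbering` is hypothetical and not part of Biopython.

```python
def to_genbank_numbering(start0, end_exclusive):
    """Convert a 0-based half-open interval to 1-based inclusive numbering.

    Hypothetical helper mirroring the conversion used for the viral tags and
    barcodes above: add 1 to the start, keep the end unchanged.
    """
    if not 0 <= start0 < end_exclusive:
        raise ValueError("expected 0 <= start < end")
    return start0 + 1, end_exclusive


# A feature on the first 5 nucleotides: Python slice seq[0:5] is GenBank 1..5
print(to_genbank_numbering(0, 5))  # (1, 5)

# The feature length is the same under both conventions
s0, e0 = 12, 28
s1, e1 = to_genbank_numbering(s0, e0)
print(e0 - s0, e1 - s1 + 1)  # 16 16
```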

## Get alignment statistics for each gene and sample
Use [pysam](https://pysam.readthedocs.io/) to get the following dataframes:
 - `nreads_df`: total number of reads aligned to each gene in each sample, and the number of those reads that are gapped (defined below).
 - `coverage_df`: coverage at each site (total and for each nucleotide) for each viral gene in each sample.
 - `readlen_df`: distribution of lengths of the **aligned** portion of reads (not including soft-clipped bases) for each gene in each sample.
 - `gapped_coverage_df`: per-site coverage like `coverage_df`, but counting only gapped reads.

In computing these statistics, a read is considered gapped if its alignment contains at least `gapped_sites_threshold` gapped sites (this variable is set at the top of the next cell).
Gapped sites are those covered by `D` (deletion) or `N` (skipped region, such as an intron) operations in the CIGAR string.
The rationale is that reads with reasonably long gaps might span internal deletion junctions in the viral genes.
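
To make the gapped-read criterion concrete, the sketch below parses a CIGAR string by hand and computes three quantities analogous to those the next cell gets from pysam: the aligned length excluding soft clips (like `query_alignment_length`), the number of gapped reference sites in `D` plus `N` operations (what `sum(read.get_cigar_stats()[0][2: 4])` counts), and the 0-based reference positions aligned to read bases (like `get_reference_positions()`). It is an illustration only; the function name `cigar_summary` and the example CIGAR strings are hypothetical.

```python
import re

# A whole CIGAR string is a run of <length><operation> elements
_CIGAR_FULL_RE = re.compile(r"(?:\d+[MIDNSHP=X])+")
_CIGAR_OP_RE = re.compile(r"(\d+)([MIDNSHP=X])")


def cigar_summary(cigar, ref_start):
    """Summarize one alignment from its CIGAR string and 0-based reference start.

    Returns (aligned_length, gapped_sites, ref_positions), where
      - aligned_length counts read bases between any soft clips (M, I, =, X),
      - gapped_sites counts reference sites in D or N operations,
      - ref_positions lists 0-based reference sites aligned to a read base.
    """
    if not _CIGAR_FULL_RE.fullmatch(cigar):
        raise ValueError(f"malformed CIGAR string: {cigar!r}")
    aligned_length = 0
    gapped_sites = 0
    ref_positions = []
    pos = ref_start  # current 0-based reference position
    for length, op in _CIGAR_OP_RE.findall(cigar):
        n = int(length)
        if op in "M=X":  # consumes read and reference
            aligned_length += n
            ref_positions.extend(range(pos, pos + n))
            pos += n
        elif op == "I":  # consumes read only
            aligned_length += n
        elif op in "DN":  # consumes reference only: a gap in the read
            gapped_sites += n
            pos += n
        # S, H and P consume no reference and are not part of the aligned read
    return aligned_length, gapped_sites, ref_positions


gapped_sites_threshold = 10  # the notebook's default threshold

for cigar in ["5S20M", "10M15D10M", "8M3D12M"]:
    aligned_length, gapped_sites, _ = cigar_summary(cigar, ref_start=100)
    is_gapped = gapped_sites >= gapped_sites_threshold
    print(cigar, aligned_length, gapped_sites, is_gapped)
```

Here only `10M15D10M` passes the threshold; `8M3D12M` has a short deletion that is more likely a sequencing or alignment artifact than an internal deletion junction.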

In [None]:
# count a read as contributing to `gapped_coverage_list` if it has at
# least this many gapped sites (deletion or intron operations)
gapped_sites_threshold = 10

coverage_list = []
nreads_list = []
readlen_list = []
gapped_coverage_list = []

for viral_gene, sample10x in itertools.product(viral_genes, samples_10x):
    gene_name = viral_gene.id
    bam = fastq10x_bams[sample10x]
    bai = fastq10x_bais[sample10x]
    print(f"Getting statistics for {gene_name} in {sample10x} from {bam}")
    
    with pysam.AlignmentFile(bam, mode='rb', index_filename=bai) as bamfile:
        if len(viral_gene) != bamfile.get_reference_length(gene_name):
            raise ValueError(f"length of {gene_name} not as expected in {bam}")
            
        coverage_list.append(
                pd.DataFrame(dict(zip('ACGT',
                                      bamfile.count_coverage(contig=gene_name))))
                .assign(coverage=lambda x: x.sum(axis=1),
                        site=lambda x: x.index + 1,
                        gene=gene_name,
                        sample=sample10x)
                )
        
        readlens = collections.defaultdict(int)
        n_gapped_reads = 0
        gapped_coverage = collections.Counter({i:0 for i in range(len(viral_gene))})
        for read in bamfile.fetch(contig=gene_name):
            readlens[read.query_alignment_length] += 1
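            # a read is gapped if its D (deletion) plus N (skipped region) operations
            # span at least `gapped_sites_threshold` reference sites; in the first array
            # returned by get_cigar_stats(), indices 2 and 3 are the D and N base counts
            cs = read.cigarstring
            if ((('D' in cs) or ('N' in cs)) and (sum(read.get_cigar_stats()[0][2: 4]) >=
                                                  gapped_sites_threshold)):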
                gapped_coverage.update(read.get_reference_positions())
                n_gapped_reads += 1
                
        nreads_list.append(
                pd.DataFrame({'gene': [gene_name],
                              'sample': [sample10x],
                              'n_reads': [bamfile.count(contig=gene_name)],
                              'n_gapped_reads': [n_gapped_reads],
                              })
                ) 
            
        readlen_list.append(
                pd.DataFrame.from_records(list(readlens.items()),
                                          columns=['aligned_read_length', 'nreads'])
                .assign(gene=gene_name,
                        sample=sample10x)
                .sort_values('aligned_read_length')
                [['gene', 'sample', 'aligned_read_length', 'nreads']]
                )
        
        gapped_coverage_list.append(
                pd.DataFrame.from_records(list(gapped_coverage.items()),
                                          columns=['site', 'coverage'])
                .assign(site=lambda x: x['site'] + 1,
                        gene=gene_name,
                        sample=sample10x)
                [['gene', 'sample', 'site', 'coverage']]
                )
        
        
nreads_df = (pd.concat(nreads_list, sort=False, ignore_index=True)
             .assign(gene=lambda x: pd.Categorical(x['gene'],
                                                   viral_gene_names,
                                                   ordered=True))
             )
print('\n`nreads_df`:')
display(HTML(nreads_df.to_html(index=False)))
        
coverage_df = (pd.concat(coverage_list, sort=False, ignore_index=True)
               .assign(gene=lambda x: pd.Categorical(x['gene'],
                                                     viral_gene_names,
                                                     ordered=True))
               )
print('\nFirst few lines of `coverage_df`:')
display(HTML(coverage_df.head().to_html(index=False)))

readlen_df = (pd.concat(readlen_list, sort=False, ignore_index=True)
              .assign(gene=lambda x: pd.Categorical(x['gene'],
                                                    viral_gene_names,
                                                    ordered=True))
              )
print('\nFirst few lines of `readlen_df`:')
display(HTML(readlen_df.head().to_html(index=False)))

gapped_coverage_df = (pd.concat(gapped_coverage_list, sort=False,
                                ignore_index=True)
                      .assign(gene=lambda x: pd.Categorical(x['gene'],
                                                            viral_gene_names,
                                                            ordered=True))
                      )
print('\nFirst few lines of `gapped_coverage_df`:')
display(HTML(gapped_coverage_df.head().to_html(index=False)))
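
The `gapped_coverage` tally in the cell above seeds a `collections.Counter` with every 0-based site set to zero, so sites that no gapped read covers still appear with coverage 0, and then shifts to 1-based `site` numbers. The sketch below reproduces that pattern in plain Python on hypothetical read positions; the function name `per_site_coverage` is illustrative only. Note that `coverage_df` instead comes from pysam's `count_coverage`, which by default skips bases below a base quality of 15 and reads flagged as unmapped, secondary, QC-failed, or duplicate, so its totals need not match a tally built from `get_reference_positions()`.

```python
import collections


def per_site_coverage(gene_length, reads_positions):
    """Count how many reads cover each site of a gene.

    gene_length: number of sites in the gene.
    reads_positions: iterable of per-read collections of 0-based reference
        positions (e.g. as returned by pysam's `get_reference_positions()`).
    Returns one (site, coverage) tuple per site, with 1-based sites and
    zero-coverage sites included.
    """
    # seed every site with 0 so uncovered sites are still reported
    coverage = collections.Counter({i: 0 for i in range(gene_length)})
    for positions in reads_positions:
        coverage.update(positions)  # +1 for each position this read covers
    return [(i + 1, coverage[i]) for i in range(gene_length)]


# three hypothetical reads on a 10-site gene
reads = [range(0, 4), range(2, 6), [8, 9]]
for site, cov in per_site_coverage(10, reads):
    print(site, cov)
```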

Sanity check to make sure we got the right number of sites for each gene in `coverage_df` and `gapped_coverage_df`:

In [None]:
for viral_gene in viral_genes:
    gene_name = viral_gene.id
    for df in [coverage_df, gapped_coverage_df]:
        lengths = (df
                   .query('gene == @gene_name')
                   .groupby('sample')
                   .size()
                   .values
                   )
        if not all(lengths == len(viral_gene)):
            raise ValueError(f"coverage not computed for the expected number of sites for {gene_name}")

Sanity check to make sure `nreads_df` and `readlen_df` have the same number of reads:

In [None]:
nreads_n = nreads_df.sort_values(['gene', 'sample'])['n_reads'].values
readlen_n = (readlen_df
             .groupby(['gene', 'sample'])
             .aggregate({'nreads': 'sum'})
             ['nreads']
             .values
             )
if not all(nreads_n == readlen_n):
    raise ValueError('nreads differ for `nreads_df` and `readlen_df`')

## Total reads per viral gene
Plot the total number of aligned reads for each viral gene:

In [None]:
p = (ggplot(nreads_df, aes('gene', 'n_reads')) +
     geom_bar(stat='identity') +
     facet_wrap('~ sample', nrow=1) +
     theme(figure_size=(2.5 * len(samples_10x), 2),
           axis_text_x=element_text(angle=90)) +
     scale_y_continuous(labels=mizani.formatters.scientific_format(),
                        name='number aligned reads')
     )

_ = p.draw()

## Fraction of reads that are gapped
Plot the fraction of all reads that are gapped for each viral gene:

In [None]:
p = (ggplot(nreads_df.assign(frac=lambda x: x['n_gapped_reads'] / x['n_reads']),
            aes('gene', 'frac')) +
     geom_bar(stat='identity') +
     facet_wrap('~ sample', nrow=1) +
     theme(figure_size=(2.5 * len(samples_10x), 2),
           axis_text_x=element_text(angle=90)) +
     ylab('fraction gapped reads')
     )

_ = p.draw()

## Lengths of aligned reads for each viral gene
Plot the distribution of the lengths of the **aligned** portions of each read for each viral gene:

In [None]:
p = (ggplot(readlen_df, aes('aligned_read_length', 'nreads')) +
     geom_bar(stat='identity') +
     facet_grid('gene ~ sample', scales='free_y') +
     theme(figure_size=(2 * len(samples_10x), 1 * len(viral_gene_names)),
           axis_text_x=element_text(angle=90)) +
     scale_y_continuous(labels=mizani.formatters.scientific_format())
     )

_ = p.draw()

## Coverage per site, including over viral tags and barcodes
Now plot coverage per site alongside gene structure.
In this plot, we indicate the viral tags (blue) and viral barcodes (orange):

In [None]:
fig, _ = plot_genes_and_coverage(viral_genes, coverage_df)

In the above plot, reads that align away from the 3' end of a gene have two likely causes:
 - Premature polyadenylation, or mis-priming by oligo-dT primers, causes the 3' primer to anneal before the end of the gene.
 - An internal deletion in the viral gene.

We can get some sense of which explanation applies by plotting coverage with the observed mutations indicated (an enrichment of `A` mutations just before a coverage peak could support the polyA explanation), and by plotting coverage from only the gapped reads (which could indicate an internal deletion in the viral gene).
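
One way to look for the polyA signature numerically, not just by eye, is to combine the per-nucleotide counts in `coverage_df` (its `A`, `C`, `G`, `T` columns) with the reference sequence: at each site the mutant count is the coverage minus the count of the reference nucleotide, and, assuming "`A` mutations" means mutations to `A`, the `A` mutant count is the count of `A` at non-`A` reference sites. The sketch below does this in plain Python; the function name and toy data are hypothetical, and it is not necessarily how `plot_genes_and_coverage` colors mutations.

```python
def mutation_counts(ref_seq, counts):
    """Per-site mutant counts from nucleotide counts.

    ref_seq: reference sequence of the gene (string of A/C/G/T).
    counts: list of dicts mapping 'A', 'C', 'G', 'T' to read counts at each
        site, in the same order as ref_seq (like the A/C/G/T columns of
        coverage_df).
    Returns one dict per site with the 1-based site, coverage, total mutant
    count, and the number of mutations to A (an assumed interpretation).
    """
    if len(ref_seq) != len(counts):
        raise ValueError("ref_seq and counts must have the same length")
    results = []
    for i, (wt, c) in enumerate(zip(ref_seq.upper(), counts)):
        coverage = sum(c[nt] for nt in "ACGT")
        n_mutant = coverage - c[wt]  # reads not matching the reference
        n_mutant_to_a = c["A"] if wt != "A" else 0  # only counts at non-A sites
        results.append({"site": i + 1, "coverage": coverage,
                        "n_mutant": n_mutant, "n_mutant_to_A": n_mutant_to_a})
    return results


# toy data: site 1 (reference G) is dominated by A in the reads
ref = "GCTA"
counts = [
    {"A": 8, "C": 0, "G": 2, "T": 0},
    {"A": 0, "C": 10, "G": 0, "T": 0},
    {"A": 1, "C": 0, "G": 0, "T": 9},
    {"A": 10, "C": 0, "G": 0, "T": 0},
]
for row in mutation_counts(ref, counts):
    frac_a = row["n_mutant_to_A"] / row["coverage"] if row["coverage"] else 0.0
    print(row["site"], row["coverage"], row["n_mutant"], f"{frac_a:.2f}")
```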

First make the plot showing mutant nucleotides at each site:

In [None]:
fig, _ = plot_genes_and_coverage(viral_genes, coverage_df, color_mutations=True)

Now plot coverage for **just** the gapped reads:

In [None]:
fig, _ = plot_genes_and_coverage(viral_genes, gapped_coverage_df)