# Analyze cell-gene matrix
This Python Jupyter notebook processes the cell-gene matrix for basic purposes such as removing doublets and examining the amount of viral products per cell.


## Parameters
First, set the parameters for the notebook, such as the input files and output plots.
This is done in the next cell, which is tagged as a `parameters` cell to enable [papermill parameterization](https://papermill.readthedocs.io/en/latest/usage-parameterize.html):

In [None]:
# parameters cell; in order for notebook to run this cell must define:
#  - input_matrix: filtered gene-barcode matrix from `STARsolo`
#  - input_features: list of features (genes) from `STARsolo`
#  - input_barcodes: list of cell barcodes from `STARsolo`
#  - input_viral_gtf: GTF file giving names of viral genes
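Because the parameter values are supplied at run time, a quick check that all four names are defined can catch a misconfigured run early. The following is a minimal sketch and not part of the original notebook: `missing_parameters` is a hypothetical helper, and the file names in the example are placeholders.

```python
# Names the parameters cell is expected to define (from the comments above).
REQUIRED_PARAMS = ("input_matrix", "input_features",
                   "input_barcodes", "input_viral_gtf")


def missing_parameters(namespace, required=REQUIRED_PARAMS):
    """Return the required parameter names absent from `namespace`.

    Hypothetical helper: `namespace` is any mapping of variable names to
    values, such as the dict returned by `globals()` in a notebook.
    """
    return [name for name in required if name not in namespace]


# Placeholder namespace mimicking a run where one parameter was forgotten.
example_namespace = {
    "input_matrix": "matrix.mtx",
    "input_features": "features.tsv",
    "input_barcodes": "barcodes.tsv",
}
print(missing_parameters(example_namespace))
```

In the notebook itself, calling `missing_parameters(globals())` in a cell that runs after the parameters have been injected would be expected to return an empty list.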

## Import Python modules
We use [anndata](https://anndata.readthedocs.io/) and [scanpy](https://scanpy.readthedocs.io/) for most of the data processing, and [plotnine](https://plotnine.readthedocs.io/) for ggplot2-style plotting:

In [None]:
import os

import anndata
import BCBio.GFF
from IPython.display import display, HTML
import numpy
import pandas as pd
from plotnine import *
import scanpy as sc

Set [scanpy](https://scanpy.readthedocs.io/) to provide lots of information including hints:

In [None]:
sc.settings.verbosity = 3  # verbosity: errors (0), warnings (1), info (2), hints (3)

Print the versions of [scanpy](https://scanpy.readthedocs.io/) and its key dependencies:

In [None]:
sc.logging.print_versions()

Color-blind palette:

In [None]:
cbpalette = ('#999999', '#E69F00', '#56B4E9', '#009E73',
             '#F0E442', '#0072B2', '#D55E00', '#CC79A7')

Set [plotnine theme](https://plotnine.readthedocs.io/en/stable/api.html#themes):

In [None]:
_ = theme_set(theme_classic)

## Read the cell-gene matrix
Read the cell-gene matrix into an [AnnData](https://anndata.readthedocs.io/) annotated data object.
We can't _quite_ use the [scanpy read_10x_mtx](https://icb-scanpy.readthedocs-hosted.com/en/stable/api/scanpy.api.read_10x_mtx.html) function because the `STARsolo` output isn't in quite the right format, so instead we write our own code that accomplishes the same thing:

In [None]:
print(f"Reading cell-gene matrix from {input_matrix}")
adata = anndata.read_mtx(input_matrix).T

print(f"Reading features (genes) from {input_features}")
genes = pd.read_csv(input_features, header=None, sep='\t')
adata.var_names = (anndata.utils.make_index_unique(pd.Index(genes[1]))
                   .rename('gene_symbols')
                   )
adata.var['gene_ids'] = genes[0].values

print(f"Reading barcodes (cells) from {input_barcodes}")
cells = pd.read_csv(input_barcodes, header=None, sep='\t')[0]
adata.obs_names = cells.rename('cell_barcodes')

print(f"\nInfo on created annotated data object:\n{adata}")
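Gene symbols are not guaranteed to be unique, which is why the code above passes them through `make_index_unique` before using them as variable names. The sketch below illustrates the idea in plain Python. It is a simplified approximation and not anndata's implementation: `dedupe_symbols` is a hypothetical helper, the symbols are arbitrary examples, and the sketch does not guard against a generated name colliding with an existing one.

```python
from collections import Counter


def dedupe_symbols(symbols, join="-"):
    """Suffix repeated names with -1, -2, ... (first occurrence unchanged).

    Hypothetical helper for illustration only; it does not check whether a
    generated name such as 'GENE-1' already exists among the symbols.
    """
    seen = Counter()
    unique = []
    for name in symbols:
        if seen[name]:
            # Second and later occurrences get a numeric suffix.
            unique.append(f"{name}{join}{seen[name]}")
        else:
            unique.append(name)
        seen[name] += 1
    return unique


symbols = ["ACTB", "MATR3", "GAPDH", "MATR3"]
unique_symbols = dedupe_symbols(symbols)
print(unique_symbols)
```

The gene IDs in the first column of the features file stay untouched in `adata.var['gene_ids']`, so the original identifier of a renamed symbol remains recoverable.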

## Basic analysis of total and viral UMIs per cell

### Total UMIs per cell
Annotate and plot total UMIs per cell:

In [None]:
adata.obs['total_UMIs'] = numpy.sum(adata.X, axis=1).A1

p = (ggplot(adata.obs, aes('total_UMIs')) +
     geom_histogram(bins=50) +
     theme(figure_size=(3, 1.8)) +
     xlab('total UMIs per cell') +
     ylab('number of cells')
     )
_ = p.draw()

### Viral UMIs per cell
First get the viral transcripts:

In [None]:
print(f"Reading names of viral transcripts from {input_viral_gtf}")
with open(input_viral_gtf) as f:
    viral_genes = [seqrecord.id for seqrecord in BCBio.GFF.parse(f)]
    
print('The viral transcripts are as follows:\n\t' + '\n\t'.join(viral_genes))

assert set(viral_genes) <= set(adata.var_names), "missing some viral genes"
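The assertion above only reports that something is missing. When it fails, it can help to see exactly which names are absent from the matrix. A minimal sketch, assuming a hypothetical helper `find_missing` and placeholder gene names that do not come from any real annotation:

```python
def find_missing(required, available):
    """Return the names in `required` that are absent from `available`.

    The order of `required` is preserved so the report is easy to read.
    """
    available = set(available)
    return [name for name in required if name not in available]


# Placeholder names standing in for `viral_genes` and `adata.var_names`.
viral_genes_example = ["viral_gene_A", "viral_gene_B", "viral_gene_C"]
var_names_example = ["host_gene_1", "viral_gene_A",
                     "host_gene_2", "viral_gene_C"]

missing = find_missing(viral_genes_example, var_names_example)
print(f"missing viral genes: {missing}")
```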

Now annotate and plot total viral UMIs per cell:

In [None]:
adata.obs['viral_UMIs'] = numpy.sum(adata[:, viral_genes].X, axis=1).A1

p = (ggplot(adata.obs, aes('viral_UMIs')) +
     geom_histogram(bins=50) +
     theme(figure_size=(3, 1.8)) +
     xlab('viral UMIs per cell') +
     ylab('number of cells')
     )
_ = p.draw()

Also annotate total cellular (non-viral) UMIs per cell:

In [None]:
adata.obs['cellular_UMIs'] = adata.obs['total_UMIs'] - adata.obs['viral_UMIs']

And annotate the fraction of UMIs that are viral:

In [None]:
adata.obs['viral_UMI_frac'] = adata.obs['viral_UMIs'] / adata.obs['total_UMIs']
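To make the four per-cell quantities concrete, here is a small worked example on a made-up dense count matrix. It is a sketch under stated assumptions: the counts and gene names are invented, and a plain `numpy` array stands in for the sparse matrix held in `adata.X`.

```python
import numpy
import pandas as pd

# Toy count matrix: 3 cells (rows) x 4 genes (columns); values are made up.
gene_names = ["host_1", "host_2", "virus_1", "virus_2"]
counts = numpy.array([[10, 30,  0,  0],
                      [20, 20,  5,  5],
                      [ 5,  5, 60, 30]])

# Boolean mask selecting the viral gene columns.
is_viral = numpy.array([name.startswith("virus") for name in gene_names])

obs = pd.DataFrame(index=["cell_a", "cell_b", "cell_c"])
obs["total_UMIs"] = counts.sum(axis=1)                 # sum over all genes
obs["viral_UMIs"] = counts[:, is_viral].sum(axis=1)    # sum over viral genes
obs["cellular_UMIs"] = obs["total_UMIs"] - obs["viral_UMIs"]
obs["viral_UMI_frac"] = obs["viral_UMIs"] / obs["total_UMIs"]
print(obs)
```

Note that a cell with zero total UMIs would get `NaN` for the viral fraction, since the division is 0 / 0.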

### Correlation between viral and total/cellular UMIs
Calculate the Pearson correlation of viral UMIs with total and with cellular UMIs per cell:

In [None]:
for x in ['total_UMIs', 'cellular_UMIs']:
    
    corr = adata.obs[x].corr(adata.obs['viral_UMIs'], method='pearson')
    
    p = (ggplot(adata.obs, aes(x, 'viral_UMIs')) +
         geom_point(alpha=0.15) +
         theme(figure_size=(2.2, 2.2)) +
         xlab(x.replace('_', ' ')) +
         ylab('viral UMIs') +
         ggtitle(f"Correlation: {corr:.2f}")
         )
    _ = p.draw()
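For reference, the Pearson coefficient reported in the plot titles is the covariance of the two variables divided by the product of their standard deviations. The sketch below computes it by hand on invented numbers and compares the result with `Series.corr`; the data are purely illustrative.

```python
import numpy
import pandas as pd

# Made-up per-cell values standing in for cellular and viral UMI counts.
x = pd.Series([100, 200, 300, 400, 500], dtype=float)
y = pd.Series([5, 9, 20, 18, 30], dtype=float)

# Deviations from the mean for each variable.
dx = x - x.mean()
dy = y - y.mean()

# Pearson r: sum of cross-products over the root of the product of
# the sums of squares (the normalizing factors cancel).
manual_r = (dx * dy).sum() / numpy.sqrt((dx ** 2).sum() * (dy ** 2).sum())
pandas_r = x.corr(y, method="pearson")
print(f"manual: {manual_r:.4f}, pandas: {pandas_r:.4f}")
```

Pearson's r measures linear association only, so a few cells with very high counts can dominate the value.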

### Filter cells on total cellular UMIs
We filter cells with an unusually low or high number of cellular UMIs, defined here as more than 2.5-fold below or above the mean across cells.
Cells that fail this filter are marked as `filtered`:

In [None]:
mean_cell_UMIs = adata.obs['cellular_UMIs'].mean()
limits = (mean_cell_UMIs / 2.5, mean_cell_UMIs * 2.5)

print(f"Average of {mean_cell_UMIs:.1f} cellular UMIs / cell.\nMarking as "
      f"filtered if <{limits[0]:.1f} or >{limits[1]:.1f} cellular UMIs.")

adata.obs = (
    adata.obs
    .assign(filtered=lambda x: ((x['cellular_UMIs'] < limits[0]) |
                                (x['cellular_UMIs'] > limits[1])),
            filtered_desc=lambda x: numpy.where(~x['filtered'], 'retained',
                                    numpy.where(x['cellular_UMIs'] < limits[0],
                                    'too few cellular UMIs', 'too many cellular UMIs')),
            )
    )
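The filter logic can be checked on a toy example. The sketch below applies the same 2.5-fold rule to invented counts; it uses `numpy.select` in place of the nested `numpy.where` calls purely for readability, which is a stylistic substitution and not the notebook's own code.

```python
import numpy
import pandas as pd

# Made-up cellular UMI counts for six cells; their mean is 2000.
obs = pd.DataFrame({"cellular_UMIs": [100, 900, 1000, 1100, 1900, 7000]})

mean_umis = obs["cellular_UMIs"].mean()
limits = (mean_umis / 2.5, mean_umis * 2.5)   # (800.0, 5000.0)

too_few = obs["cellular_UMIs"] < limits[0]
too_many = obs["cellular_UMIs"] > limits[1]

obs["filtered"] = too_few | too_many
# Label each cell with the reason it was filtered, or 'retained'.
obs["filtered_desc"] = numpy.select(
    [too_few, too_many],
    ["too few cellular UMIs", "too many cellular UMIs"],
    default="retained",
)
print(obs)
```

Because both limits are derived from the mean, a few very high-count cells (such as the 7000-UMI cell here) raise the thresholds for everyone. A median-based rule would be less sensitive to such cells, though that would be a different filter from the one used in this notebook.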

Plot cellular and viral UMIs in filtered versus retained cells:

In [None]:
p = (ggplot(
        adata.obs.assign(ncells=lambda x: x.groupby('filtered_desc')
                                           ['filtered']
                                           .transform('count'),
                         cell_group=lambda x: x['filtered_desc'] + ' (' + 
                                              x['ncells'].astype(str) + ' cells)'
                         ),
        aes('cellular_UMIs', 'viral_UMIs', color='cell_group')) +
     geom_point(alpha=0.15) +
     theme(figure_size=(2.2, 2.2),
           legend_title=element_blank()) +
     xlab('cellular UMIs') +
     ylab('viral UMIs') +
     scale_color_manual(values=cbpalette[1:]) +
     guides(color=guide_legend(override_aes={'alpha': 1}))
     )
_ = p.draw()
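The legend labels in the plot above are built by counting the cells in each `filtered_desc` group with `groupby(...).transform('count')`, which returns one value per row rather than one per group. A minimal sketch on an invented table shows the resulting labels:

```python
import pandas as pd

# Made-up annotation table with the two columns the labels depend on.
obs = pd.DataFrame({
    "filtered": [False, False, True, False, True],
    "filtered_desc": ["retained", "retained", "too few cellular UMIs",
                      "retained", "too many cellular UMIs"],
})

labelled = obs.assign(
    # transform('count') broadcasts each group's size back onto its rows.
    ncells=lambda x: x.groupby("filtered_desc")["filtered"].transform("count"),
    # Later keyword arguments in assign can use columns created earlier.
    cell_group=lambda x: (x["filtered_desc"] + " (" +
                          x["ncells"].astype(str) + " cells)"),
)
print(labelled["cell_group"].unique())
```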

From here on, we restrict analyses to the retained (non-filtered) cells.

### Fraction of UMIs derived from virus
Make a basic plot of the fraction of UMIs derived from virus:

In [None]:
p = (ggplot(adata.obs.query('not filtered'), aes('viral_UMI_frac')) +
     geom_histogram(bins=50) +
     theme(figure_size=(3, 1.8)) +
     xlab('fraction of UMIs from virus') +
     ylab('number of cells')
     )
_ = p.draw()

Make the same plot, restricted to cells with more than 1% of their UMIs from virus.
Note that we haven't yet established a rigorous cutoff for which cells are truly infected, but requiring more than 1% of UMIs from virus probably removes most uninfected cells:

In [None]:
p = (ggplot(adata.obs.query('(not filtered) & (viral_UMI_frac > 0.01)'), 
            aes('viral_UMI_frac')) +
     geom_histogram(bins=50) +
     theme(figure_size=(3, 1.8)) +
     xlab('fraction of UMIs from virus') +
     ylab('number of cells')
     )
_ = p.draw()
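The `query` strings used in the two plots above combine the `filtered` flag with a threshold on `viral_UMI_frac`. The sketch below, on invented values, shows that the query selects the same rows as an explicit boolean mask:

```python
import pandas as pd

# Made-up per-cell annotations; cell_d sits exactly at the 1% cutoff.
obs = pd.DataFrame(
    {"filtered": [False, False, True, False],
     "viral_UMI_frac": [0.0005, 0.25, 0.40, 0.01]},
    index=["cell_a", "cell_b", "cell_c", "cell_d"],
)

# Selection written as a query string, as in the plotting code.
by_query = obs.query("(not filtered) & (viral_UMI_frac > 0.01)")

# The same selection written as an explicit boolean mask.
by_mask = obs[~obs["filtered"] & (obs["viral_UMI_frac"] > 0.01)]

print(by_query)
```

The comparison is strict (`>`), so a cell at exactly 1% is excluded, and a filtered cell is dropped regardless of its viral fraction.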
