# Filter total UMIs
This Python Jupyter notebook uses the total number of UMIs per cell to exclude low-quality cells (too few UMIs) and suspected doublets (too many UMIs).

Import Python modules:

In [None]:
from IPython.display import display

from dms_variants.constants import CBPALETTE

import numpy

import pandas as pd

import plotnine as p9

import scanpy

Get `snakemake` variables [as described here](https://snakemake.readthedocs.io/en/stable/snakefiles/rules.html#jupyter-notebook-integration):

In [None]:
# Input files handed to this notebook by the Snakemake rule.
matrix, features, cell_barcodes = (
    snakemake.input.matrix,
    snakemake.input.features,
    snakemake.input.cell_barcodes,
)

# Width of the filtering window, in median absolute deviations of the
# log10 total-UMI distribution; coerced to float so it can be used in
# arithmetic regardless of how the param was written in the workflow.
total_UMI_deviations = float(snakemake.params.total_UMI_deviations)

# Experiment label; only used in plot titles below.
expt = snakemake.wildcards.expt

# Output paths: the saved histogram image and the filtered barcode list.
plot, cell_barcodes_filtered = (
    snakemake.output.plot,
    snakemake.output.cell_barcodes_filtered,
)

Style parameters:

In [None]:
# Make plotnine's classic theme the default for every plot created below,
# so the individual plots only need to override sizes.
p9.theme_set(p9.theme_classic())

Read the cell-gene matrix into an [AnnData](https://anndata.readthedocs.io/) object:

In [None]:
# Read the Matrix Market count matrix. In this file the rows are features and
# the columns are cell barcodes, so after reading, AnnData *observations* are
# genes and *variables* are cells -- the transpose of the usual scanpy
# cells x genes layout. The annotations below are attached accordingly.
adata = scanpy.read_mtx(matrix)

# Barcode file: one barcode per line, no header row.
adata.var = pd.read_csv(cell_barcodes,
                        header=None,
                        names=['cell_barcode'])

# Tab-separated feature table with no header: Ensembl ID, gene name,
# feature type.
adata.obs = pd.read_csv(features,
                        sep='\t',
                        header=None,
                        names=['ensembl_id', 'gene', 'feature_type'])

print(f"Read cell-gene matrix of {adata.n_vars} cells and {adata.n_obs} genes")

Now for each cell get the total UMIs.

In [None]:
# Total UMIs per cell = column sums of the matrix (cells are the columns,
# i.e. AnnData variables). `adata.X.sum(axis=0)` yields a 1 x n numpy.matrix
# for scipy sparse matrices but a 1-D ndarray for dense input or scipy sparse
# arrays; `numpy.asarray(...).ravel()` flattens any of these to 1-D, whereas
# `.A1` only exists on numpy.matrix.
# Counts may be stored as floats, so round before casting to int rather than
# letting astype(int) truncate toward zero.
per_cell_totals = numpy.asarray(adata.X.sum(axis=0)).ravel()
umi_counts = adata.var.assign(
    total_UMIs=numpy.rint(per_cell_totals).astype(int))
display(umi_counts)

Log transform the UMIs:

In [None]:
# Add a log10-transformed copy of the per-cell totals; the filtering limits
# are computed on this scale. A cell with 0 UMIs would map to -inf here.
umi_counts = umi_counts.assign(
    log_total_UMIs=numpy.log10(umi_counts['total_UMIs']))
display(umi_counts)

In [None]:
# Histogram of raw (linear-scale) total UMIs per cell, before filtering.
total_UMI_histogram = (
    p9.ggplot(data=umi_counts, mapping=p9.aes(x='total_UMIs'))
    + p9.geom_histogram(bins=25)
    + p9.ggtitle(f'total_UMIs per cell\n{expt}')
    + p9.theme(
        figure_size=(4, 2),
        plot_title=p9.element_text(size=9),
        axis_title=p9.element_text(size=9),
        legend_title=p9.element_text(size=9),
        legend_title_align='center',
    )
    + p9.scale_color_manual(CBPALETTE)
)
display(total_UMI_histogram)

In [None]:
# Histogram of log10 total UMIs per cell, before filtering; this is the
# scale on which the median/MAD cutoffs are computed.
log_total_UMI_histogram = (
    p9.ggplot(data=umi_counts, mapping=p9.aes(x='log_total_UMIs'))
    + p9.geom_histogram(bins=25)
    + p9.ggtitle(f'log total_UMIs per cell\n{expt}')
    + p9.theme(
        figure_size=(4, 2),
        plot_title=p9.element_text(size=9),
        axis_title=p9.element_text(size=9),
        legend_title=p9.element_text(size=9),
        legend_title_align='center',
    )
    + p9.scale_color_manual(CBPALETTE)
)
display(log_total_UMI_histogram)

Calculate the median and median absolute deviation of the log-transformed values:

In [None]:
def median_absolute_deviation(input_series):
    """Return the median and the median absolute deviation of numeric values.

    The MAD is the median of the absolute differences between each value and
    the overall median. It is returned unscaled (no normal-consistency
    factor is applied).

    Parameters
    ----------
    input_series : pandas.Series or numpy.ndarray
        Numeric values; must support element-wise subtraction of a scalar.

    Returns
    -------
    dict
        ``{'median': <median>, 'mad': <median absolute deviation>}``.
    """
    center = numpy.median(input_series)
    spread = numpy.median(numpy.absolute(input_series - center))
    return {'median': center, 'mad': spread}

# Centre and spread of the log10 total-UMI distribution; the filtering
# limits are derived from these two numbers.
log_total_UMIs_stats = median_absolute_deviation(
    input_series=umi_counts['log_total_UMIs'])
print('The median value of the log transformed UMI counts is: '
      f'{log_total_UMIs_stats["median"]}')
print('The median absolute deviation of the log transformed UMI counts is: '
      f'{log_total_UMIs_stats["mad"]}')

Set the filtering limits at `total_UMI_deviations` median absolute deviations below and above the median of the log-transformed counts:

In [None]:
# Symmetric cutoffs on the log10 scale:
#   median -/+ total_UMI_deviations * MAD
# The number of deviations is the Snakemake param read earlier, not a fixed 3.
half_width = total_UMI_deviations * log_total_UMIs_stats['mad']
log_total_UMIs_stats['lower_limit'] = log_total_UMIs_stats['median'] - half_width
log_total_UMIs_stats['upper_limit'] = log_total_UMIs_stats['median'] + half_width
print(f'The lower limit of the log transformed UMI counts is: {log_total_UMIs_stats["lower_limit"]}')
print(f'The upper limit of the log transformed UMI counts is: {log_total_UMIs_stats["upper_limit"]}')

Annotate excluded cells:

In [None]:
# Flag cells whose log10 total UMIs lie strictly outside
# [lower_limit, upper_limit]. Flagged cells are dropped from the exported
# barcode list further down.
log_umis = umi_counts['log_total_UMIs']
umi_counts['filtered'] = (
    log_umis.lt(log_total_UMIs_stats['lower_limit'])
    | log_umis.gt(log_total_UMIs_stats['upper_limit'])
)
display(umi_counts)

Plot the number of cells kept (`False`) and excluded (`True`):

In [None]:
# Bar chart counting cells per filter outcome (False = kept, True = excluded).
outcome_plot = (
    p9.ggplot(data=umi_counts,
              mapping=p9.aes(x='filtered', fill='filtered'))
    + p9.geom_bar(stat='count')
    + p9.ggtitle(f'Number of cells filtered on total UMIs\n{expt}')
    + p9.theme(
        figure_size=(2, 2),
        plot_title=p9.element_text(size=9),
        axis_title=p9.element_text(size=9),
        legend_title=p9.element_text(size=9),
        legend_title_align='center',
    )
    + p9.scale_fill_manual(CBPALETTE)
)
display(outcome_plot)

Plot histogram with limits:

In [None]:
# Log10 total-UMI histogram filled by filter status, with dashed green lines
# at the lower and upper cutoffs. This is the figure later saved to `plot`.
cutoffs = [log_total_UMIs_stats['lower_limit'],
           log_total_UMIs_stats['upper_limit']]
filtered_log_histogram = (
    p9.ggplot(data=umi_counts,
              mapping=p9.aes(x='log_total_UMIs', fill='filtered'))
    + p9.geom_histogram(bins=25)
    + p9.geom_vline(xintercept=cutoffs, linetype='dashed', color='green')
    + p9.ggtitle(f'log total_UMIs per cell\n{expt}\n'
                 f'limit set at {total_UMI_deviations} deviations from median')
    + p9.theme(
        figure_size=(4, 2),
        plot_title=p9.element_text(size=9),
        axis_title=p9.element_text(size=9),
        legend_title=p9.element_text(size=9),
        legend_title_align='center',
    )
    + p9.scale_fill_manual(CBPALETTE)
)
display(filtered_log_histogram)

Plot a histogram on a linear scale, annotated with filter status:

In [None]:
# Linear-scale total-UMI histogram filled by filter status, with x-axis
# breaks every `x_break_step` UMIs. range() excludes its stop value, so the
# stop is max + 1; a break that falls exactly on the largest total is then
# still drawn.
x_break_step = 50000
max_total_UMIs = int(umi_counts['total_UMIs'].max())
filtered_total_UMI_histogram = (
    p9.ggplot(
        (umi_counts),
        p9.aes(x='total_UMIs',
               fill='filtered')) +
    p9.geom_histogram(bins=25) +
    p9.ggtitle('total UMIs per cell\n'
               f'{expt}') +
    p9.scale_x_continuous(
        breaks=list(range(0, max_total_UMIs + 1, x_break_step))) +
    p9.theme(figure_size=(6, 2),
             plot_title=p9.element_text(size=9),
             axis_title=p9.element_text(size=9),
             legend_title=p9.element_text(size=9),
             legend_title_align='center') +
    p9.scale_fill_manual(CBPALETTE)
)
display(filtered_total_UMI_histogram)

Export the log-scale histogram:

In [None]:
# Write the log-scale histogram (with the cutoff lines) to the rule's output path.
print(f"Saving histogram to {plot}")
filtered_log_histogram.save(filename=plot)

Export filtered list of cell barcodes:

In [None]:
# Barcodes of the cells that passed the total-UMI filter, renumbered 0..n-1.
passed = ~umi_counts['filtered']
cell_barcodes_filtered_list = (
    umi_counts.loc[passed, 'cell_barcode'].reset_index(drop=True))

display(cell_barcodes_filtered_list)

# Written as a single 'cell_barcode' column (pandas writes the Series name
# as a header row by default), without the index.
print(f"Saving filtered cell barcodes to {cell_barcodes_filtered}")
cell_barcodes_filtered_list.to_csv(cell_barcodes_filtered, index=False)