# QC filtering

This notebook filters scATAC-seq data stored in an AnnData object to achieve a good-quality peak matrix.
For this task, various QC filters are implemented. These include filters related to:

- doublet score
- number of features per barcode
- mean insert size
- promoter enrichment


## QC Settings

In [None]:
# Name of the adata.obs column holding the biological condition to evaluate
sample_column = "sample"

# Convert the X matrix to binary values (True) or leave it as is (False)
binarize_mtx = True

############################# Filters ##############################

# General filters
filter_chrM = True  # True or False; drop features on chrM
filter_xy = True    # True or False; drop features on chrX and chrY

# Doublet removal with scrublet
filter_doublets = True            # False skips doublet estimation and removal entirely
use_native_scrublet = True        # True: call scrublet directly per sample; False: use qc.estimate_doublets
threads = 2                       # passed to qc.estimate_doublets (only read when use_native_scrublet is False)
doublet_threshold = 0.2           # passed to qc.estimate_doublets (only read when use_native_scrublet is False)
use_condition_column = False      # True: overwrite condition_doublet_removal with a condition column
condition_doublet_removal = None  # groupby value handed to qc.estimate_doublets

# Estimate thresholds individually per condition (False) or globally (True)
global_threshold = False

############################# Set default filter thresholds ###############################

# These thresholds apply to all samples; they can be adjusted manually once plotted
use_default_thresholds = True  # set to False to ignore default_thresholds
default_thresholds = {
    "n_features": {"min": 100, "max": 5000},
    "fld_score_cwt": {"min": 0.01, "max": 0.6},
    # add additional columns if needed
}

----------------------

## Loading packages and setup

In [None]:
# sctoolbox modules
import sctoolbox.utils as utils
import sctoolbox.tools as tools
import sctoolbox.plotting as pl
import sctoolbox.tools.qc_filter as qc

import scanpy as sc
import matplotlib.pyplot as plt
import episcanpy as epi
import numpy as np
import scrublet as scr

# Load the sctoolbox settings from config.yaml using the key "02"
# (NOTE(review): which settings this affects is defined by sctoolbox, not visible here)
utils.settings_from_config("config.yaml", key="02")

## Load anndata 

In [None]:
# Read the input AnnData object and show its summary
adata_input = "anndata_1.h5ad"
adata = utils.load_h5ad(adata_input)
display(adata)

-------------

## Show STARsolo quality (optional)

If the data was mapped using STARsolo, set `quant_folder` to the path of the STARsolo runs to plot quality measures across runs. The path must be a folder, e.g. "path/to/starsolo_output", which contains one folder per condition, e.g. "cond1", "cond2", etc.

In [None]:
# Folder containing the STARsolo runs; leave as "" to skip the STARsolo quality plot
quant_folder = ""

In [None]:
# Plot the STARsolo quality measures only when a folder was provided.
# The truthiness check skips the plot for None as well as for the empty string.
if quant_folder:
    _ = pl.plot_starsolo_quality(quant_folder, save="starsolo_quality.pdf")

-----------

## Calculate QC metrics

### 1. Removing empty cells and features


In [None]:
# Matrix dimensions before removing empty cells and features
adata.shape

In [None]:
# Keep only the cells (rows) whose summed counts are above zero ...
nonempty_cells = adata.X.sum(axis=1) > 0
adata = adata[nonempty_cells]

# ... then keep only the features (columns) still covered by the remaining cells
nonempty_features = adata.X.sum(axis=0) > 0
adata = adata[:, nonempty_features]

In [None]:
# Matrix dimensions after removing empty cells and features
adata.shape

In [None]:
# Show the per-cell annotation table
adata.obs

### 2. Binarize matrix

In [None]:
# Preserve the unmodified matrix in the "raw" layer before X is (optionally) binarized
raw_matrix = adata.X.copy()
adata.layers["raw"] = raw_matrix

if binarize_mtx:
    # Return value is not used, so this is expected to modify adata in place
    epi.pp.binarize(adata)

### 3. Filtering out chrX, chrY and chrM

In [None]:
if filter_chrM:
    # Select every feature whose name does not begin with 'chrM'
    kept_features = [feature for feature in adata.var_names
                     if not feature.startswith('chrM')]
    adata = adata[:, kept_features]

In [None]:
if filter_xy:
    # Keep every feature that is neither on chrX nor on chrY.
    # str.startswith accepts a tuple of prefixes, which avoids combining two
    # booleans with the bitwise `|` operator and relying on its precedence over `not`.
    sex_chromosomes = ('chrX', 'chrY')
    non_xy = [name for name in adata.var_names if not name.startswith(sex_chromosomes)]
    adata = adata[:, non_xy]

### 4. Calculate metrics

In [None]:
# Compute the QC metrics, treating the variables as 'features'
adata = tools.calculate_qc_metrics(adata, var_type='features')

# Append the metric column names to the "obs_metrics" entry of the uns info
metric_columns = ["n_features", "log1p_total_counts"]
utils.add_uns_info(adata, "obs_metrics", metric_columns, how="append")

display(adata)

### 5. Doublet removal

In [None]:
 if filter_doublets:
        
    if use_condition_column:
        condition_doublet_removal = condition_column
    
    if use_native_scrublet:
        # TODO: Implement Wrapper function for sctoolbox
        adata.obs['doublet_score'] = float('NaN')
        adata.obs['predicted_doublet'] = None

        sample_dict = {}
        for sample in adata.obs[sample_column].unique():
            print('Run scrublet for condition: ' + sample)
            X = adata.X[adata.obs[sample_column] == sample]
            scrub = scr.Scrublet(X)
            doublet_scores, predicted_doublets = scrub.scrub_doublets()
            adata.obs.loc[adata.obs[sample_column]==sample, 'doublet_score'] = doublet_scores
            adata.obs.loc[adata.obs[sample_column]==sample, 'predicted_doublet'] = predicted_doublets
            
        adata.obs['predicted_doublet'] = adata.obs['predicted_doublet'].astype(bool)
        
    else:
        qc.estimate_doublets(adata, groupby=condition_doublet_removal, threads=threads, threshold=doublet_threshold)
    
    #Remove the duplicates from adata
    tools.filter_cells(adata, "predicted_doublet", remove_bool=True)


In [None]:
# Remove cells and features that no longer have any counts
cells_with_counts = adata.X.sum(axis=1) > 0
adata = adata[cells_with_counts]

features_with_counts = adata.X.sum(axis=0) > 0
adata = adata[:, features_with_counts]

### 6. Visualize global quality features

In [None]:
# Plotting coverage with episcanpy sets the global sns style, which affects all other plots as well.
# We should provide plotting tools within sctoolbox.plotting for these.

#epi.pp.coverage_cells(adata, binary=True, log=False, bins=50)
#epi.pp.coverage_cells(adata, binary=True, log=10, bins=50)

#epi.pp.coverage_features(adata, binary=True, log=False, bins=50)
#epi.pp.coverage_features(adata, binary=True, log=10, bins=50)

---------------------

## Cell filtering

### Get thresholds dict

In [None]:
# Thresholds are estimated per sample unless a global threshold was requested
groupby = sample_column if global_threshold is False else None

# Honour the use_default_thresholds switch from the settings: when it is False,
# only automatic thresholds are requested for the columns listed in default_thresholds.
# NOTE(review): assumes only_automatic_thresholds=True makes the wrapper ignore the
# given min/max values and estimate them instead — confirm against the sctoolbox docs.
initial_thresholds = tools.get_thresholds_wrapper(
    adata,
    default_thresholds,
    only_automatic_thresholds=not use_default_thresholds,
    groupby=groupby,
)

# The columns with thresholds are the ones plotted and filtered below
obs_columns = list(initial_thresholds.keys())
tools.thresholds_as_table(initial_thresholds)

### Plot thresholds

In [None]:
%matplotlib widget
%bgcolor PowderBlue

# Violin plots of the QC columns together with sliders for adjusting the thresholds
obs_figure, obs_slider_dict = pl.quality_violin(
    adata,
    columns=obs_columns,
    which="obs",
    groupby=sample_column,
    global_threshold=global_threshold,
    thresholds=initial_thresholds,
    title="Cell quality control (before)",
    save="cell_filtering.png",
)
obs_figure

In [None]:
# Read the final thresholds back from the sliders of the plot above
final_thresholds = pl.get_slider_thresholds(obs_slider_dict)
tools.thresholds_as_table(final_thresholds) # display the thresholds as a table

In [None]:
# Show pairwise comparisons of column values w/ thresholds (mean values in case thresholds are grouped)
%matplotlib inline
# Close the previous figure before drawing a new one
plt.close()

# The pairwise plot is only drawn when more than one threshold entry exists
if len(final_thresholds) > 1:
    mean_thresholds = qc.get_mean_thresholds(final_thresholds)
    _ = pl.pairwise_scatter(
        adata.obs,
        obs_columns,
        thresholds=mean_thresholds,
        save="cell_filtering_scatter.pdf",
    )

### Filter adata

In [None]:
# Apply the final thresholds to adata, using the same grouping as for their estimation
# (the result is not reassigned, so the call is expected to work in place)
tools.apply_qc_thresholds(adata, final_thresholds, groupby=groupby)

# Drop features left without any counts after the cell filtering
remaining_features = adata.X.sum(axis=0) > 0
adata = adata[:, remaining_features]

### Save plots

In [None]:
%matplotlib inline 

# Violin plots of the same QC columns after filtering, saved as PDF
figure, slider_dict = pl.quality_violin(
    adata,
    columns=obs_columns,
    which="obs",
    groupby=sample_column,
    ncols=3,
    global_threshold=global_threshold,
    title="Cell quality control (after)",
    save="cell_filtering_final.pdf",
)
figure

-------------

## Save anndata

In [None]:
# Write the filtered AnnData object to disk
adata_output = "anndata_2.h5ad"
utils.save_h5ad(adata, adata_output)