# Processing and Filtering UMI Expression Matrices

Note that we needed 128 GB of RAM to run this notebook, and that the notebook was run in the environment created by scanpy-default-mamba.yml.

**ALSO NOTE: If you change any samples in this process, the selected barcodes will change because of how the seeds are set: we set the seed once before the whole sampling loop, not before each sample, so any change in the samples changes the downstream results. For the same reason, you cannot run only a subsample of the data; you must run the whole thing.**
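
To make the seeding caveat concrete, here is a minimal sketch of the effect. The sample names, sizes, and the `draw_all` helper are hypothetical and only illustrate the idea: with a single seed set before the loop, each sample's draw depends on every draw made before it.

```python
import numpy as np

def draw_all(sample_sizes, n_keep, seed=42):
    """Seed once, then draw n_keep cell indices from each sample in order."""
    rng = np.random.RandomState(seed)  # one seed for the whole loop
    return {name: rng.choice(size, n_keep, replace=False)
            for name, size in sample_sizes.items()}

# Hypothetical samples: names and sizes are illustrative only
full = draw_all({"A": 10000, "B": 10000, "C": 10000}, n_keep=100)
subset = draw_all({"B": 10000, "C": 10000}, n_keep=100)

# Dropping sample A shifts the random stream: B now receives the draw A had
b_unchanged = np.array_equal(full["B"], subset["B"])
b_got_a_draw = np.array_equal(full["A"], subset["B"])
print(f"B unchanged: {b_unchanged}; B inherited A's draw: {b_got_a_draw}")
```

Seeding per sample (for example, from a stable hash of the sample name) would remove this dependence, but it would also change the barcodes selected here, so it is not what this notebook does.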

Lastly, all paths have been scrubbed from this notebook, so please insert the relevant file paths if you wish to rerun it (search for '#INSERT HERE').

In [1]:
##Import necessary packages
import numpy as np
import pandas as pd
import scanpy as sc
import scanpy.external as sce
import seaborn as sns
import anndata as ad
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import matplotlib.patches as mpatches
import re
import os
import sys
import time
import warnings
from tqdm import tqdm
from scipy.sparse import csr_matrix, issparse
from collections.abc import Iterable 
import gc

In [2]:
##Specify Analysis Parameters
GENE_PCT=0.1
SEED=42
N_PLATFORMS_DEVIANT = 1
DOUBLET_THRESH=0.25
FEATURE_SELECT_GENE_N = 2000
run_scrublet=False
filter_cells=False

In [3]:
parent_dir=##INSERT HERE
figure_dir=##INSERT HERE
supplement_dir=os.path.join(figure_dir, 'supplement')
supplement_fig_dir=os.path.join(supplement_dir, 'figures')
supplement_table_dir=os.path.join(supplement_dir, 'tables')

##make figure directories if they don't exist
os.makedirs(parent_dir, exist_ok=True)
os.makedirs(figure_dir, exist_ok=True)
os.makedirs(supplement_fig_dir, exist_ok=True)
os.makedirs(supplement_table_dir, exist_ok=True)

In [4]:
##LOAD CUSTOM FUNCTIONS FROM FUNCTIONS DIRECTORY IN REPO

##Note: These function files load their own necessary packages, so we don't need to load these dependencies here

##define function location
functions_dir = ##INSERT HERE

##if the directory is not in path, add
if functions_dir not in sys.path:
    sys.path.append(functions_dir)

##import functions
from get_statistics_functions import load_AnnData, load_raw_AnnData ##import all relevant processing functions
import highly_deviant_genes as hdg ##import all relevant highly deviant gene detection functions

In [5]:
## Define File Names
data_dir=##INSERT HERE

##filtered files (for barcode matching and labeling of called cells)
filtered_files = {'10X_3-rep1': ##INSERT HERE,
                  '10X_3-rep2': ##INSERT HERE,
                  '10X_5-rep1': ##INSERT HERE,
                  '10X_5-rep2': ##INSERT HERE,
                  '10X_FRP-rep1': ##INSERT HERE, 
                  '10X_FRP-rep2': ##INSERT HERE, 
                  'BD-rep1': ##INSERT HERE, 
                  'BD-rep2': ##INSERT HERE, 
                  'Fluent-rep1': ##INSERT HERE,
                  'Fluent-rep2': ##INSERT HERE,
                  'Fluent-rep3':  ##INSERT HERE,
                  'Honeycomb-rep1':  ##INSERT HERE,
                  'Honeycomb-rep2':  ##INSERT HERE,
                  'Singleron-rep1':  ##INSERT HERE,
                  'Singleron-rep2':  ##INSERT HERE,
                  'Scipio-rep1':  ##INSERT HERE,
                  'Scipio-rep2':  ##INSERT HERE,
                  'Parse-rep1':  ##INSERT HERE,
                  'Scale-rep1':  ##INSERT HERE,
                  'Broad-Reference':  ##INSERT HERE,
                 }

##get full paths
filtered_files = {method:os.path.join(data_dir, file) for method, file in filtered_files.items()}

##define methods
methods = filtered_files.keys()

##get tags
platform_tags, rep_tags = zip(*[mr.split("-") for mr in methods]) #split on the hyphen and separate into two lists

Get some summary stats for all platforms to make downsampling easier

#### <center> Process Data Separately (with Reference), so we can compare cell annotation across platforms

Note that we will add the Broad Institute Multiplatform Dataset (converted from Seurat) as the 'gold standard' for comparison AND for integration and label transfer (https://singlecell.broadinstitute.org/single_cell/study/SCP424/single-cell-comparison-pbmc-data and https://github.com/satijalab/seurat-data/). We also split the Broad reference to get a validation set (same size as other platforms) that we can use to assess how good our cell annotation methods are later.

To make sure we are being consistent with our analysis, we will first subset our validation dataset to 10000 cells to approximate the cell number target we used for each platform. We will then identify "doublets" with scanpy's external link to Scrublet which simulates doublets by randomly combining two observations into one and comparing the observed results with the simulated results. We manually set a threshold of 0.25 for the simulated doublets because the automatic caller gives poor results and 0.25 seems to be the most consistent across platforms. 
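
As a rough illustration of the doublet simulation idea described above, the sketch below sums the counts of two randomly chosen cells to form a synthetic doublet and applies a fixed score threshold. This is not Scrublet's implementation: the function names and toy data are hypothetical, and the scoring step (which Scrublet bases on neighbors among simulated doublets) is not reproduced here.

```python
import numpy as np

def simulate_doublets(counts, n_doublets, seed=0):
    """Create synthetic doublets by summing the counts of two random cells."""
    rng = np.random.default_rng(seed)
    parents = rng.integers(0, counts.shape[0], size=(n_doublets, 2))
    doublets = counts[parents[:, 0]] + counts[parents[:, 1]]
    return doublets, parents

def call_doublets(scores, threshold=0.25):
    """Flag cells whose doublet score exceeds a fixed, manually chosen threshold."""
    return np.asarray(scores) > threshold

# Toy count matrix (cells x genes); values are illustrative only
toy_counts = np.random.default_rng(1).poisson(2.0, size=(20, 5))
toy_doublets, toy_parents = simulate_doublets(toy_counts, n_doublets=8)

# Hypothetical doublet scores, thresholded at 0.25 as in this notebook
toy_calls = call_doublets([0.05, 0.24, 0.26, 0.90])
```

A fixed threshold trades adaptivity for consistency: every platform is judged against the same cutoff, which is the motivation given above for choosing 0.25 manually.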

To be fair for downstream analysis, we will downsample again to the smallest number of non-doublet cells for all platforms (except the reference) and process from there.

Lastly, we will find deviant genes per platform (top 2000) that can be used for PCA and cell annotation later
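
For readers unfamiliar with deviance-based feature selection, here is a minimal sketch of a per-gene binomial deviance under a constant-proportion null model. It is an assumption that the repository's `highly_deviant_genes` module computes something similar; that code is not shown in this notebook, and the function names and toy matrix below are hypothetical.

```python
import numpy as np

def binomial_deviance(counts):
    """Per-gene binomial deviance for a dense cells x genes count matrix."""
    counts = np.asarray(counts, dtype=float)
    n = counts.sum(axis=1, keepdims=True)             # total counts per cell
    pi = counts.sum(axis=0, keepdims=True) / n.sum()  # overall proportion per gene
    mu = n * pi                                       # expected counts under the null
    rest = n - counts
    with np.errstate(divide="ignore", invalid="ignore"):
        # Terms with a zero count contribute nothing to the deviance
        t1 = np.where(counts > 0, counts * np.log(counts / mu), 0.0)
        t2 = np.where(rest > 0, rest * np.log(rest / (n - mu)), 0.0)
    return 2.0 * (t1 + t2).sum(axis=0)

def top_n_genes(deviance, top_n):
    """Boolean mask marking the top_n genes with the largest deviance."""
    order = np.argsort(-deviance, kind="stable")
    mask = np.zeros(deviance.shape[0], dtype=bool)
    mask[order[:top_n]] = True
    return mask

# Toy data: gene 0 is a constant fraction of every cell, genes 1 and 2 alternate
toy = np.array([[10, 10, 0], [10, 0, 10], [10, 10, 0], [10, 0, 10]])
toy_dev = binomial_deviance(toy)
toy_mask = top_n_genes(toy_dev, top_n=2)
```

A gene whose counts track library size exactly has zero deviance, so it is never selected; genes whose expression varies across cells beyond what sequencing depth explains rank highest.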

Note that we perform batch normalization on the combined reference-platform anndata object so that we normalize according to each library size rather than overall

In [6]:
##Setup Directory to save anndata objects
anndata_object_dir = os.path.join(parent_dir, 'anndata_objects')
os.makedirs(anndata_object_dir, exist_ok=True)

##Load in all data and perform scrublet

##Set filtering params for pre-scrublet filtering
##Cells with fewer than 5 genes probably aren't cells, and genes not found in at least 3 cells probably aren't useful
##Based on ATAC-seq guidelines
min_genes=5
min_cells=3
n_doublets={}
doublet_rates={}
doublet_dir = os.path.join(parent_dir, 'doublet_data')
os.makedirs(doublet_dir, exist_ok=True)

if run_scrublet:

    ##initialize data lists and dicts
    adata_all = {}

    ##load data and process for cell selection
    for method, file in filtered_files.items():
        
        print(f"Method: {method}")

        ## load adata file depending on filename
        if re.match(r'.*10X.*', file):
            adata = load_raw_AnnData(file)
        else:
            adata = load_AnnData(file)

        ##create stats dictionary in data to store n_doublet and doublet rate
        if 'stats' not in adata.uns:
            adata.uns['stats'] = {}

        ##update obs names if and only if it hasn't been tagged with method tag
        if method not in adata.obs_names[0]:
            adata.obs_names = [f'{bc}_{method}' for bc in adata.obs_names]

        # filter barcodes with few genes or counts so they don't affect scrublet
        sc.pp.filter_cells(adata, min_genes = min_genes) ##5 is default because used in scATAC-seq paper

        ##if reference, split into reference and validation
        if method == 'Broad-Reference':
            
            ##for some reason the raw data contains '_index' as a column which gives an error when saving, so fix with this line
            adata.__dict__['_raw'].__dict__['_var'] = adata.__dict__['_raw'].__dict__['_var'].rename(columns={'_index': 'features'})
            
            
            cell_index = adata.obs_names ##all indices of ref
            N_test = 10000 ##initial target for all platforms
            np.random.seed(SEED) ##reset seed to make sure it will Always be the same split even if done outside of loop
            test_ind = np.random.choice(cell_index, N_test, replace = False) ##select test cell indices to be same size as other platforms
            train_ind = np.setdiff1d(cell_index, test_ind) ##get train indices as those not in test

            ##subset test data as adata_val
            adata_val = adata[test_ind]

            ##subset train data as adata
            adata = adata[train_ind]

            ##Filter genes so they don't affect scrublet
            sc.pp.filter_genes(adata_val, min_cells = min_cells) ##3 is default because used in scATAC-seq paper

            ##Identify and remove doublets with scrublet (Manually setting doublet threshold to 0.25 because it gives best results across platforms)
            sce.pp.scrublet(adata_val, threshold=DOUBLET_THRESH) ##identify doublet with threshold
            sce.pl.scrublet_score_distribution(adata_val)
            adata_val.uns['stats']['n_doublet'] = np.sum(adata_val.obs['predicted_doublet'] == True) ##get number of doublets identified
            adata_val.uns['stats']['doublet_rate'] = adata_val.uns['stats']['n_doublet']/adata_val.shape[0] ##get doublet rate
            
            ##Add platform info
            adata_val.obs['platform'] = 'Broad-Reference-Val'
            adata_val.obs['platform_type'] = 'reference_Broad_Val'
            
            ##save adata object with doublet scores prior to doublet removal
            output_file = os.path.join(anndata_object_dir, f'Broad-Reference-Val_filtered_withdoublets.h5ad')
            adata_val.write_h5ad(output_file)
            
            ##remove doublets
            adata_val = adata_val[adata_val.obs['predicted_doublet'] == False] ##keep only singlets
            
            n_doublets['Broad-Reference-Val'] = adata_val.uns['stats']['n_doublet']
            doublet_rates['Broad-Reference-Val'] = adata_val.uns['stats']['doublet_rate']


            ##Add "-Val" tag to the end of obs_names so that it is tagged with "Broad-Reference-Val" instead of "Broad-Reference"
            adata_val.obs_names = [f'{bc}-Val' for bc in adata_val.obs_names]

            

        #Filter genes so they don't affect scrublet
        sc.pp.filter_genes(adata, min_cells = min_cells) ##3 is default because used in scATAC-seq paper

        ##Identify and remove doublets with scrublet (Manually setting doublet threshold to 0.25 because it gives best results across platforms)
        sce.pp.scrublet(adata, threshold=DOUBLET_THRESH) ##identify doublet with threshold
        sce.pl.scrublet_score_distribution(adata)
        adata.uns['stats']['n_doublet'] = np.sum(adata.obs['predicted_doublet'] == True) ##get number of doublets identified
        adata.uns['stats']['doublet_rate'] = adata.uns['stats']['n_doublet']/adata.shape[0] ##get doublet rate

        ## Add platform names to adata.obs
        adata.obs['platform'] = method

        # Add platform type
        if method in ['BD-rep1','BD-rep2', 'Honeycomb-rep1','Honeycomb-rep2','Singleron-rep1', 'Singleron-rep2']:
            adata.obs['platform_type'] = 'well-based'
        elif method in ['Scale-rep1', 'Scale-rep2', 'Parse-rep1', 'Parse-rep2']:
            adata.obs['platform_type'] = 'combinatorial'
        elif method in ['10X_3-rep1', '10X_3-rep2', '10X_5-rep1', '10X_5-rep2','10X_FRP-rep1','10X_FRP-rep2']:
            adata.obs['platform_type'] = 'droplet_GEMs'
        elif method in ['Fluent-rep1', 'Fluent-rep2', 'Fluent-rep3']:
            adata.obs['platform_type'] = 'droplet_PIPs'
        elif method in ['Scipio-rep1', 'Scipio-rep2']:
            adata.obs['platform_type'] = 'hydrogel'
        elif method == 'Broad-Reference':
            adata.obs['platform_type'] = "reference_Broad"
            
        ##save adata object with doublet scores prior to doublet removal
        output_file = os.path.join(anndata_object_dir, f'{method}_filtered_withdoublets.h5ad')
        adata.write_h5ad(output_file)
            
        ##remove doublets for downstream steps
        adata = adata[adata.obs['predicted_doublet'] == False] ##keep only singlets
        
        n_doublets[method] = adata.uns['stats']['n_doublet']
        doublet_rates[method] = adata.uns['stats']['doublet_rate']

        ##store all data and all genes
        adata_all[method] = adata

        del adata

    ##add adata_val to adata_all
    adata_all['Broad-Reference-Val'] = adata_val

    ##remove for memory
    del adata_val
    
    doublet_df = pd.DataFrame([n_doublets, doublet_rates]).T
    doublet_df = doublet_df.reset_index()
    doublet_df.columns = ['method', 'n_doublets', 'doublet_rate']
    doublet_df.to_csv(os.path.join(doublet_dir, 'doublet_summary_df.csv'), index=False)
    
    ##also save to supplement directory for paper
    doublet_df.to_csv(os.path.join(supplement_table_dir, 'Table19_Doublet_Rates.csv'), index=False)
    
else:
    ##Filtered and Processed Data objects
    filtered_doublet_files = {
     '10X_3-rep1':  ##INSERT HERE,
     '10X_3-rep2':  ##INSERT HERE,
     '10X_5-rep1':  ##INSERT HERE,
     '10X_5-rep2':  ##INSERT HERE,
     '10X_FRP-rep1':  ##INSERT HERE,
     '10X_FRP-rep2':  ##INSERT HERE,
     'BD-rep1':  ##INSERT HERE,
     'BD-rep2':  ##INSERT HERE,
     'Fluent-rep1':  ##INSERT HERE,
     'Fluent-rep2':  ##INSERT HERE,
     'Fluent-rep3':  ##INSERT HERE,
     'Honeycomb-rep1':  ##INSERT HERE,
     'Honeycomb-rep2':  ##INSERT HERE,
     'Parse-rep1':  ##INSERT HERE,
     'Scale-rep1':  ##INSERT HERE,
     'Scipio-rep1':  ##INSERT HERE,
     'Scipio-rep2':  ##INSERT HERE,
     'Singleron-rep1':  ##INSERT HERE,
     'Singleron-rep2':  ##INSERT HERE,
     'Broad-Reference-Val':  ##INSERT HERE,
     'Broad-Reference': ##INSERT HERE
    }
    
    filtered_doublet_files = {method:os.path.join(anndata_object_dir, file) for method, file in filtered_doublet_files.items()}
    
    ##load
    adata_all = {}
    for method, file in filtered_doublet_files.items():
        
        ## load adata files
        print(f"Loading {method} ...")
        adata_all[method] = sc.read_h5ad(file)
    
    
    
    doublet_df = pd.read_csv(os.path.join(doublet_dir, 'doublet_summary_df.csv'))

Loading 10X_3-rep1 ...
Loading 10X_3-rep2 ...
Loading 10X_5-rep1 ...
Loading 10X_5-rep2 ...
Loading 10X_FRP-rep1 ...
Loading 10X_FRP-rep2 ...
Loading BD-rep1 ...
Loading BD-rep2 ...
Loading Fluent-rep1 ...
Loading Fluent-rep2 ...
Loading Fluent-rep3 ...
Loading Honeycomb-rep1 ...
Loading Honeycomb-rep2 ...
Loading Parse-rep2 ...
Loading Scale-rep1 ...
Loading Scipio-rep1 ...
Loading Scipio-rep2 ...
Loading Singleron-rep1 ...
Loading Singleron-rep2 ...
Loading Broad-Reference-Val ...
Loading Broad-Reference ...


Now we can subsample every platform to the smallest cell count among them (excluding Scipio, since it can only process about 5000 cells at most), and refilter for genes and cells that don't meet our cutoffs (genes must be seen in at least 3 cells and cells must have at least 5 genes).
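
The downsampling logic described above can be sketched on plain barcode lists, without AnnData. The helper name `downsample_barcodes` and the toy barcodes are hypothetical; the sketch only assumes, as in this notebook, that one seed is set before the loop and that excluded platforms keep all of their barcodes.

```python
import numpy as np

def downsample_barcodes(barcodes_by_method, exclude=(), seed=42):
    """Downsample every non-excluded method to the smallest non-excluded size."""
    sizes = {m: len(b) for m, b in barcodes_by_method.items() if m not in exclude}
    smallest = min(sizes, key=sizes.get)  # method with the fewest cells
    n_target = sizes[smallest]
    rng = np.random.RandomState(seed)     # seeded once, before the loop
    selected = {}
    for method, barcodes in barcodes_by_method.items():
        if method in exclude:
            selected[method] = np.asarray(barcodes)  # keep everything
        else:
            selected[method] = rng.choice(barcodes, n_target, replace=False)
    return smallest, n_target, selected

# Toy barcodes; platform names and sizes are illustrative only
toy = {
    "PlatformA": [f"a{i}" for i in range(10)],
    "PlatformB": [f"b{i}" for i in range(7)],
    "SmallPlatform": [f"s{i}" for i in range(4)],
}
smallest, n_target, selected = downsample_barcodes(toy, exclude=("SmallPlatform",))
```

Note that `min(sizes, key=sizes.get)` returns the method with the fewest cells, whereas `min(sizes)` alone would return the alphabetically first method name.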

In [7]:
if filter_cells:
    
    ##downsample size (min cells in experiments excluding scipio since it can only process about 5000 cells)
    filter1_sizes = {f'{method}':data.shape[0] for method, data in adata_all.items() if 'scipio' not in method.lower()}
    n_downsample = min(filter1_sizes.values())
    print(f'Lowest Sample Size ({min(filter1_sizes, key=filter1_sizes.get)}: {n_downsample})')

    ##set seed for downsampling
    np.random.seed(SEED)

    ##initialize data lists and dicts
    selected_barcodes={} ##downsampled barcodes that we can reuse (since we've tagged cell barcodes with method tag)

    ##Loop through adata dictionary to filter and downsample
    for method, adata in adata_all.items():

        ##copy adata to prevent in place modifications
        adata = adata.copy()

        ##filter barcodes with few genes or counts since they are not useful for this method
        ##min_genes and min_cells already set (5 and 3, respectively, but we will do this step again)
        sc.pp.filter_cells(adata, min_genes = min_genes) ##5 is default because used in scATAC-seq paper

        ## downsample to lowest cell number (EXCEPT FOR REFERENCE AND SCIPIO; don't want to lose training data or excess cells due to scipio)
        if method not in ['Broad-Reference', 'Scipio-rep1', 'Scipio-rep2']:
            sel_barcodes = np.random.choice(adata.obs_names, n_downsample, replace=False)
            selected_barcodes[method]=sel_barcodes
            adata = adata[sel_barcodes]
        else: ##if reference or scipio, keep all barcodes (scipio can only handle about 5000 cells max)
            selected_barcodes[method]=adata.obs_names

        ## filter genes not found in at least min_cells cells (filter AFTER downsample to make sure we check all available cells)
        sc.pp.filter_genes(adata, min_cells = min_cells) ##3 is default because used in scATAC-seq paper

        ##subset reference and validation metadata for annotation later
        if method == 'Broad-Reference':
            reference_metadata = adata.obs
        elif method == 'Broad-Reference-Val':
            validation_metadata = adata.obs ##save metadata for evaluation metrics
        
    
        ##Add .obs column for the Broad-Reference batch_key
        if method not in ['Broad-Reference', 'Broad-Reference-Val']:
            adata.obs['platform_broadincluded'] = adata.obs['platform']
        else:
            adata.obs['platform_broadincluded'] = [f"{platform}_{broad}" for platform, broad in zip(adata.obs['platform'], adata.obs['Method'])]

        ##update raw with newly filtered info (still not normalized X) also helps overcome Broad-Reference data save
        adata.raw = adata.copy()

        print(f"Method: {method}")

        output_file = os.path.join(anndata_object_dir, f'{method}_filtered_nodoublets.h5ad')
        adata.write_h5ad(output_file)
        
        ##update adata with downsampled data
        adata_all[method] = adata
        
        ##delete for memory
        del adata
        
else:
    
    ##Filtered and Processed Data objects
    filtered_files =  {
     '10X_3-rep1':  ##INSERT HERE,
     '10X_3-rep2':  ##INSERT HERE,
     '10X_5-rep1':  ##INSERT HERE,
     '10X_5-rep2':  ##INSERT HERE,
     '10X_FRP-rep1':  ##INSERT HERE,
     '10X_FRP-rep2':  ##INSERT HERE,
     'BD-rep1':  ##INSERT HERE,
     'BD-rep2':  ##INSERT HERE,
     'Fluent-rep1':  ##INSERT HERE,
     'Fluent-rep2':  ##INSERT HERE,
     'Fluent-rep3':  ##INSERT HERE,
     'Honeycomb-rep1':  ##INSERT HERE,
     'Honeycomb-rep2':  ##INSERT HERE,
     'Parse-rep1':  ##INSERT HERE,
     'Scale-rep1':  ##INSERT HERE,
     'Scipio-rep1':  ##INSERT HERE,
     'Scipio-rep2':  ##INSERT HERE,
     'Singleron-rep1':  ##INSERT HERE,
     'Singleron-rep2':  ##INSERT HERE,
     'Broad-Reference-Val':  ##INSERT HERE,
     'Broad-Reference': ##INSERT HERE
    }
    
    filtered_files = {method:os.path.join(anndata_object_dir, file) for method, file in filtered_files.items()}
    
    ##initialize data lists and dicts
    adata_all = {}
    selected_barcodes={} ##downsampled barcodes that we can reuse (since we've tagged cell barcodes with method tag)
    
    for method, file in filtered_files.items():
        
        ## load adata files
        print(f"Loading {method} ...")
        adata_all[method] = sc.read_h5ad(file)
        
        ## extract filtered and downsampled barcodes
        selected_barcodes[method] = adata_all[method].obs_names
        
        ##subset reference and validation metadata for annotation later
        if method == 'Broad-Reference':
            reference_metadata = adata_all[method].obs
        elif method == 'Broad-Reference-Val':
            validation_metadata = adata_all[method].obs ##save metadata for evaluation metrics

Loading 10X_3-rep1 ...
Loading 10X_3-rep2 ...
Loading 10X_5-rep1 ...
Loading 10X_5-rep2 ...
Loading 10X_FRP-rep1 ...
Loading 10X_FRP-rep2 ...
Loading BD-rep1 ...
Loading BD-rep2 ...
Loading Fluent-rep1 ...
Loading Fluent-rep2 ...
Loading Fluent-rep3 ...
Loading Honeycomb-rep1 ...
Loading Honeycomb-rep2 ...
Loading Parse-rep2 ...
Loading Scale-rep1 ...
Loading Scipio-rep1 ...
Loading Scipio-rep2 ...
Loading Singleron-rep1 ...
Loading Singleron-rep2 ...
Loading Broad-Reference-Val ...
Loading Broad-Reference ...


Show doublet rate for each

In [8]:
for method, adata in adata_all.items():
    print(f"{method} Doublet Stats:")
    print(f"{adata.uns['stats']}")
    print("")
    print("____________________________________________________________")

10X_3-rep1 Doublet Stats:
{'doublet_rate': 0.03288773216609488, 'n_doublet': 278}

____________________________________________________________
10X_3-rep2 Doublet Stats:
{'doublet_rate': 0.027156549520766772, 'n_doublet': 306}

____________________________________________________________
10X_5-rep1 Doublet Stats:
{'doublet_rate': 0.03287995269071555, 'n_doublet': 278}

____________________________________________________________
10X_5-rep2 Doublet Stats:
{'doublet_rate': 0.02876459741856177, 'n_doublet': 234}

____________________________________________________________
10X_FRP-rep1 Doublet Stats:
{'doublet_rate': 0.03654517254749903, 'n_doublet': 377}

____________________________________________________________
10X_FRP-rep2 Doublet Stats:
{'doublet_rate': 0.03702783300198807, 'n_doublet': 298}

____________________________________________________________
BD-rep1 Doublet Stats:
{'doublet_rate': 0.02417718168181417, 'n_doublet': 274}

___________________________________________________

In [9]:
doublet_df

| | method | n_doublets | doublet_rate |
|---|---|---|---|
| 0 | 10X_3-rep1 | 278.0 | 0.032888 |
| 1 | 10X_3-rep2 | 306.0 | 0.027157 |
| 2 | 10X_5-rep1 | 278.0 | 0.03288 |
| 3 | 10X_5-rep2 | 234.0 | 0.028765 |
| 4 | 10X_FRP-rep1 | 377.0 | 0.036545 |
| 5 | 10X_FRP-rep2 | 298.0 | 0.037028 |
| 6 | BD-rep1 | 274.0 | 0.024177 |
| 7 | BD-rep2 | 559.0 | 0.033193 |
| 8 | Fluent-rep1 | 94.0 | 0.008294 |
| 9 | Fluent-rep2 | 79.0 | 0.007665 |


Check to make sure all downsampled sizes are the same except scipio and reference

In [9]:
for k, v in adata_all.items():
    print(f'{k}: {v[v.obs["platform"] == k].shape[0]}')

10X_3-rep1: 7750
10X_3-rep2: 7750
10X_5-rep1: 7750
10X_5-rep2: 7750
10X_FRP-rep1: 7750
10X_FRP-rep2: 7750
BD-rep1: 7750
BD-rep2: 7750
Fluent-rep1: 7750
Fluent-rep2: 7750
Fluent-rep3: 7750
Honeycomb-rep1: 7750
Honeycomb-rep2: 7750
Singleron-rep1: 7750
Singleron-rep2: 7750
Scipio-rep1: 3457
Scipio-rep2: 4425
Parse-rep2: 7750
Scale-rep1: 7750
Broad-Reference: 20840
Broad-Reference-Val: 7750
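
The visual check above could also be expressed as an assertion so that a mismatch stops the run. This is a sketch only: `check_downsampled_sizes` is a hypothetical helper, and the dictionary repeats a few of the sizes printed above rather than reading them from `adata_all`.

```python
def check_downsampled_sizes(sizes, exempt_prefixes=("Scipio",),
                            exempt_names=("Broad-Reference",)):
    """Return the common size, raising if non-exempt platforms disagree."""
    checked = {name: n for name, n in sizes.items()
               if name not in exempt_names and not name.startswith(exempt_prefixes)}
    unique_sizes = set(checked.values())
    if len(unique_sizes) != 1:
        raise ValueError(f"Downsampled sizes differ: {checked}")
    return unique_sizes.pop()

# A few of the sizes printed above
printed_sizes = {
    "10X_3-rep1": 7750,
    "BD-rep1": 7750,
    "Scipio-rep1": 3457,
    "Scipio-rep2": 4425,
    "Broad-Reference": 20840,
    "Broad-Reference-Val": 7750,
}
common_size = check_downsampled_sizes(printed_sizes)
```

The exemption for the reference uses an exact name match, so `Broad-Reference-Val` is still checked against the other platforms.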


Add .obs column for the Broad-Reference batch_key

Save selected barcodes as CSV

In [1]:
##convert the selected barcodes dictionary to a list of tuples (platform, barcode) for each barcode
##Then convert this to dataframe using the first tuple element as the platform and the second as the tagged barcode
selected_barcodes_df = pd.DataFrame([(platform, barcode) for platform, barcodes in selected_barcodes.items() for barcode in barcodes], 
                                    columns=['platform', 'selected_barcodes_tagged'])

# extract original barcodes by splitting off the trailing '_<platform>' tag
selected_barcodes_df['original_barcodes'] =  selected_barcodes_df.apply(lambda row: row.selected_barcodes_tagged.rsplit('_'+row.platform, 1)[0], axis=1)

##save selected barcodes as CSV with original and new barcodes included
anndata_object_dir = os.path.join(parent_dir, 'anndata_objects')
os.makedirs(anndata_object_dir, exist_ok=True)
barcode_output_file = os.path.join(anndata_object_dir, 'selected_barcode_per_platform_30k.csv')
print(f"Saving selected barcodes to {barcode_output_file}")
selected_barcodes_df.to_csv(barcode_output_file, index = False)
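
As a small illustration of how the original barcodes are recovered from the tagged ones, the sketch below applies the same `rsplit` call to two made-up barcodes. The barcode strings and the `strip_platform_tag` helper are hypothetical.

```python
def strip_platform_tag(tagged_barcode, platform):
    """Remove the trailing '_<platform>' tag added to a barcode."""
    return tagged_barcode.rsplit("_" + platform, 1)[0]

# Hypothetical barcodes, for illustration only
examples = [
    ("AAACCTGAGAAACCAT-1_10X_3-rep1", "10X_3-rep1"),
    ("cell_17_BD-rep2", "BD-rep2"),
]
originals = [strip_platform_tag(barcode, platform) for barcode, platform in examples]
print(originals)
```

Splitting from the right with a maximum of one split means underscores inside the original barcode or inside the platform name are left untouched.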

In [12]:
def FindDeviantGenes_singlebatch(adata, layer='unnormalized', top_pct=0.15, top_n=None, verbose=True):
    
    """
    Input:
    adata: AnnData object 
    layer: Layer of adata object to use for finding deviant genes. Defaults to 'unnormalized' since deviance should only be calculated on unnormalized data 
    top_pct: percent of top deviant genes to select. Defaults to 15%
    top_n: number of top deviant genes to select (instead of top_pct; defaults to None, but if you set top_n, it overrides pct)
    
    Output:
    adata object with highly deviant genes identified and set as highly variable genes so that we can use only this subset
    for pca
    """
    
    ##replace adata with a copy so we don't write in place
    adata = adata.copy()
    
    ##check if top_pct or top_n (if top_n provided, it overrides top_pct)
    if top_n is not None:
        top_pct = None
        if verbose is True: print(f'Detecting top {top_n} deviant genes ...\n')
    elif top_pct is not None:
        if verbose is True: print(f'Detecting top {top_pct*100}% deviant genes ...\n')
    else:
        raise ValueError("Either top_pct or top_n must be provided; both cannot be None.")
    
    if layer != 'unnormalized':
        if verbose is True: print("WARNING: Gene expression matrix used for deviant gene detection should be unnormalized counts")
        
    elif hasattr(adata, 'layers') and adata.layers.get('unnormalized') is None:
        raise ValueError("""
        Anndata object must either have an 'unnormalized' layer 
        OR you must specify the layer to use for deviant gene detection (specifying 'None' will default to adata.X).
        Note that the gene expression matrix used for deviant gene detection should be unnormalized/raw counts
        """)
        
    ##Subset the layer we will use (defaults to unnormalized)
    adata_X = adata.layers[layer] if layer is not None else adata.X
     
    ##detect deviance based on top_n
    if top_n is not None:
        ##identify highly deviant genes (top N) using a binomial deviance calculation
        deviance = hdg.highly_deviant_genes(X = adata_X, top_n = top_n, 
                                            family = 'binomial', gene_names =adata.var.index.to_list())

    ##detect deviance based on top_pct
    else:
        ##identify highly deviant genes (top %) using a binomial deviance calculation
        deviance = hdg.highly_deviant_genes(X = adata_X, top_pct = top_pct, 
                                            family = 'binomial', gene_names =adata.var.index.to_list())
        
    del adata_X

    ##Finish deviance processing
    adata.var['singlebatch_deviance'] = deviance['deviance'] ##get deviance scores
    adata.var['highly_deviant'] = deviance['highly_deviant'] ##get deviance labels

    ##Total unfiltered genes and selected genes
    n_deviant = adata.var['highly_deviant'].sum()
    if verbose is True: print(f"Total features in data: {len(adata.var)}")

    ##print for number of deviant genes
    if verbose is True:
        if top_n is not None:
            print(f'Selected top {top_n} deviant genes')
        else: 
            print(f"Top {top_pct*100}% deviant genes feature number: {n_deviant}")

    ## set highly_variable as highly_deviant so we can use in PCA
    adata.var["highly_variable"] = adata.var["highly_deviant"]

    ##subset the data to highly deviant genes, rank by deviance, and create results dictionary
    deviant_df = adata.var[adata.var['highly_deviant'] == True][['singlebatch_deviance']] ##subset highly deviant genes and scores
    deviant_df = deviant_df.reset_index().rename(columns={'index':'gene_names'}) ##turn index to a column for gene names
    deviant_df.sort_values(by='singlebatch_deviance', ascending=False, inplace=True) ##rank according to deviance scores
    deviant_df['rank'] = range(1, len(deviant_df) + 1) ##give a rank

    ##add stats and res back to adata
    adata.uns['deviant_genes'] = deviant_df
    return adata

def FindDeviantGenes(adata, layer='unnormalized', top_pct=0.15, top_n=None, batch_key=None, min_cells=3):
    
    """
    Input:
    adata: AnnData object 
    layer: Layer of adata object to use for finding deviant genes. Defaults to 'unnormalized' since deviance should only be calculated on unnormalized data 
    top_pct: percent of top deviant genes to select. Defaults to 15%
    top_n: number of top deviant genes to select (instead of top_pct; defaults to None, but if you set top_n, it overrides pct)
    batch_key: string indicating adata.obs column indicating sample batches 
               (if supplied, returns top highly deviant genes of intersection of genes across batches)
    min_cells: The minimum number of cells that a gene must be in for a subset (within batch) to count
    
    Output:
    adata object with highly deviant genes identified and set as highly variable genes so that we can use only this subset
    for pca
    """
    
    ##replace adata with a copy so we don't write in place
    adata = adata.copy()
        
    ##Process per batch if batch_key is provided
    if batch_key is not None:
        
        ##give warning if layer is not unnormalized
        if layer != 'unnormalized':
            print("WARNING: Gene expression matrix used for deviant gene detection should be unnormalized counts")
            
        batches = np.unique(adata.obs[batch_key])
        batch_values = adata.obs[batch_key].values
        
        adata_dict = {}
        for batch in batches:
            
            ##subset batch
            adata_batch = adata[batch_values == batch]
            
            ##since the data is joined, the subsets might have genes with no counts, so we must remove those for the deviance calc to work
            sc.pp.filter_genes(adata_batch, min_cells = min_cells) ##3 is default because used in scATAC-seq paper
            
            ##get deviance per batch
            adata_dict[batch] = FindDeviantGenes_singlebatch(adata_batch, layer=layer, top_pct=top_pct, top_n=top_n, verbose=False)
            
            del adata_batch
            
        ##Rank genes by how many batches they were identified as deviant in and select top_perc or top_N
        ##Note that we only do this on those genes that are present in EACH batch (Intersection of Genes)
        ##Ties are broken by the median deviance across all batches
        
        ## Calculate the intersection of indices
        gene_intersect = set.intersection(*map(set, [a.var_names for a in adata_dict.values()]))
        
        ## Identify number of genes we need
        if top_n is not None:
            top_pct = None
            n_deviant = top_n
            print(f'Detecting top {top_n} deviant genes ...\n')
        elif top_pct is not None:
            print(f'Detecting top {top_pct*100}% deviant genes (from the intersection) ...\n')
            n_deviant = int(np.round(top_pct * len(gene_intersect))) ##n genes that will give top_pct of genes
        else:
            raise ValueError("Either top_pct or top_n must be provided; both cannot be None.")
            
        ##print some messages   
        print(f"Total features in data: {len(adata.var)}")
        print(f"Total intersecting features in data: {len(gene_intersect)}")

        ## Subset datasets to intersecting genes while preserving original order
        subset_datasets = []
        for batch_name, a in adata_dict.items():
            subset_indices = [index for index in a.var_names if index in gene_intersect]
            subset_data = a.var.loc[subset_indices]
            subset_datasets.append(subset_data[['singlebatch_deviance', 'highly_deviant']].to_numpy()) ##select only deviance (number) and highly_deviant (T/F) cols
            del subset_data
            
        del adata_dict

        ##Create a df containing median deviance and number of platforms where that gene was highly deviant across platforms
        batch_deviance_df = pd.DataFrame({'median_deviance': np.median(np.dstack(subset_datasets)[:,0,:], axis = 1),
                                          'n_platforms_deviant': np.dstack(subset_datasets)[:,1,:].sum(axis = 1)
                                         })
        batch_deviance_df.index = subset_indices ##set index as the gene names from the subset (intersection)
        
        ##Sort values on median deviance first (so we can break ties by median_deviance) and then rank on # platforms where it was highly_deviant
        batch_deviance_df = batch_deviance_df.sort_values(by = 'median_deviance', ascending = False)
        batch_deviance_df['deviant_gene_rank']=batch_deviance_df['n_platforms_deviant'].rank(method ='first', ascending = False)
        batch_deviance_df['highly_deviant'] = batch_deviance_df['deviant_gene_rank'] <= n_deviant
        batch_deviance_df = batch_deviance_df.loc[subset_indices]
        
        ##update adata object with identified deviant genes (set highly variable as highly deviant for PCA)
        ##Make sure we update the correct indices by subsetting batch_deviance_df.loc[subset_indices] again even though it is redundant
        adata.var['median_deviance'] = np.nan
        adata.var['deviant_gene_rank'] = np.nan
        adata.var['n_platforms_deviant'] = np.nan
        adata.var['highly_deviant'] = False
        adata.var.loc[subset_indices, 'median_deviance'] = batch_deviance_df.loc[subset_indices, 'median_deviance']
        adata.var.loc[subset_indices, 'deviant_gene_rank'] = batch_deviance_df.loc[subset_indices, 'deviant_gene_rank']
        adata.var.loc[subset_indices, 'n_platforms_deviant'] = batch_deviance_df.loc[subset_indices, 'n_platforms_deviant']
        adata.var.loc[subset_indices, 'highly_deviant'] = batch_deviance_df.loc[subset_indices, 'highly_deviant']
        adata.var['highly_variable'] = adata.var['highly_deviant']
        
        ##Explicitly convert median_deviance and n_platforms_deviant to float64 so that we can save to HDF5 (.h5ad); otherwise writing raises an error
        adata.var['median_deviance'] = adata.var['median_deviance'].astype(np.float64)
        adata.var['n_platforms_deviant'] = adata.var['n_platforms_deviant'].astype(np.float64)
        
        ##add deviant gene df to adata.uns
        # adata.uns['deviant_genes'] = batch_deviance_df
        
        ##print for number of deviant genes
        if top_n is not None:
            print(f'Selected top {top_n} deviant genes')
        else: 
            print(f"Top {top_pct*100}% deviant genes feature number: {n_deviant}")
            
        del batch_deviance_df
            
        return adata
            
    else:
        adata = FindDeviantGenes_singlebatch(adata, layer=layer, top_pct=top_pct, top_n=top_n)

        return adata


def plot_PCA(adata, pca_basis='X_pca', hue=None, figsize = (8,6), title="", titlesize=12, markersize = 10, colors=None, bbox_to_anchor=(1.3,1.01)):
    
    """
    Input:
    adata: AnnData object (must have PCA already performed)
    pca_basis: Key for PCA data in adata.obsm
    hue: Variable to group by for plotting (Defaults to None)
    figsize: Desired size of figure (tuple; Defaults to (8,6))
    title: Title to give plot
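    titlesize: Font size for the title (defaults to 12)
    markersize: Size for points on scatter (defaults to 10)
    colors: List of colors to give to all groups from hue (Defaults to None)
    bbox_to_anchor: Legend anchor position passed to ax.legend (Defaults to (1.3,1.01); only used when hue is provided)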
    
    Output:
    PCA plot of provided data grouped by hue (if provided)
    """
    
    ##ONLY necessary for coloring by platform. scanpy gives an error for some reason
    
    ##Get obs variables from data
    scatter_df = adata.obs.copy()
    
    ##add PC data
    scatter_df = scatter_df.assign(PC1 = adata.obsm[pca_basis][:,0],
                                   PC2 = adata.obsm[pca_basis][:,1])
    
    fig, ax = plt.subplots(1,1, figsize=figsize)
    
    if hue is not None:
        if colors is not None:
            sns.scatterplot(data = scatter_df, x = "PC1", y = "PC2", hue = hue, palette = colors, ax = ax, s = markersize)
        else:
            sns.scatterplot(data = scatter_df, x = "PC1", y = "PC2", hue = hue, ax = ax, s = markersize)
            
        ax.legend(title = hue, loc='upper right', bbox_to_anchor=bbox_to_anchor, prop={'size': 9})
    else:
        sns.scatterplot(data = scatter_df, x = "PC1", y = "PC2", ax = ax, s = markersize) ##apply markersize here too so it is honored without hue
    
    ## remove axis ticks
    ax.set(xticks=[], yticks=[])
    
    ##Set title
    ax.set_title(title, fontsize=titlesize)
        
    plt.show()
    
    del scatter_df
        
def log1p_normalize(adata):
    
    """
    Input: adata object with counts data stored in adata.X (should be unnormalized)
    
    Output: adata object with log1p(size-corrected counts) stored in adata.X and unnormalized counts 
            stored in adata.layers['unnormalized']
            
            
    Theory:
    
    The counts we have in our data are generated after cell capture, reverse transcription, amplification, and sequencing. 
    These steps inherently vary per cell, so the counts we see represent not only the biological variation per cell, but
    the technical variation as well. 
    
    Normalizing data is a valuable preprocessing step that adjusts the counts in the dataset for technical variance by 
    scaling the observed variance into a specific range. There are many techniques (shifted logarithm, Pearson 
    residuals, etc.) used to make subsequent analysis and statistics applicable, but their usage depends on the situation at 
    hand.
    
    According to Theis et al.'s Single Cell Best Practices book, the shifted logarithm approach is useful for stabilizing
    variance and identifying differentially expressed genes, while the Pearson residual approach is useful for identifying 
    biologically relevant genes and rare cell types. 
    
    A recent benchmark by [Ahlmann-Eltze & Huber (2023)](https://doi.org/10.1038/s41592-023-01814-1)  
    revealed that the shifted logarithm approach demonstrates superior performance compared to other methods in
    uncovering underlying latent structures and stabilizing variance for identifying differentially expressed genes, 
    especially when followed by PCA. In this approach, the log-shifted counts are defined as:
    
        f(y) = log(y / s_c + 1)
    
    where y is the raw count of a gene in cell c and s_c is the size factor of cell c, ie the total counts 
    in cell c divided by a target sum L.
    
    We will use the log-shifted normalization method because it performs well in the benchmark study and works to 
    identify differentially expressed genes. However, we will use the default size factor scaling
    provided in scanpy's `sc.pp.normalize_total()` function instead of the average counts as in the Ahlmann-Eltze & Huber
    paper. scanpy's default target is the median of total counts across all cells, which should produce similar results 
    to the Ahlmann-Eltze & Huber paper (the only difference is that we use median instead of mean counts). This is performed
    below. 
    
    Note that we will not set a target_sum in `sc.pp.normalize_total` because we don't want to define a set scaling factor 
    for the data (ie $L=10^6$ would give us counts per million).
    """
    
    ## Normalizing Data

    ##recalculate metrics to get total_counts before normalizing
    sc.pp.calculate_qc_metrics(adata, log1p = True, percent_top=[20], inplace=True)

    ##check if data has already been normalized (will have an 'unnormalized' layer)
    if adata.layers.get('unnormalized') is None:

        warnings.warn("Normalizing (and log1p) total counts from adata.X. Make sure adata.X is not already normalized, or it may cause problems.")

        #save unnormalized count data to adata.layers['unnormalized'] for feature selection using deviance
        adata.layers['unnormalized'] = adata.X.copy()

    else: ##ie it has already been normalized with log1p_normalize

        warnings.warn(f"Data already normalized. Using adata.layers['unnormalized'] to normalize and log1p data. Storing at adata.X")
        adata.X = adata.layers['unnormalized'].copy()

    ##Either way, scale and log-shift transform data
    scales_counts = sc.pp.normalize_total(adata, target_sum=None, inplace=False)##does the y/s_c calculation
    adata.X = sc.pp.log1p(scales_counts["X"], copy=True) ##does log-shift ie log(y_scaled + 1)
    
    return adata

def preprocessData(adata, scale=True, normalize_data=True, batch_preprocess=False, **kwargs):
    
    """
    Input:
    adata: AnnData object
    normalize_data: Whether to log1p normalize data (Defaults to True because it is in Best Practices (See Below))
    scale: Whether to scale data to mean of zero and variance of 1 (Defaults to True)
    batch_preprocess: Whether to perform batch-specific normalization and scaling (Usually recommended except for integration with Harmony Based on our Results). Defaults to False
    **kwargs: any keyword argument that can be supplied to internal functions
    
    Output:
    adata object that has been log1p-normalized, had highly deviant genes identified and set as highly variable genes, and been 
    scaled to a mean of zero and variance of 1 (if requested) so that we can use only this subset for PCA
    
    Best Practices: https://www.sc-best-practices.org/preprocessing_visualization
    
    """
    
    ###Extract potential kwargs (Default after comma)
    layer = kwargs.get('layer', 'unnormalized') ##default is 'unnormalized'
    top_pct = kwargs.get('top_pct', 0.15) ##default is 0.15 (top 15% of genes by deviance)
    top_n = kwargs.get('top_n', None) ##default is None for top deviance n since it defaults to using pct
    batch_key = kwargs.get('batch_key', None) ##default is none (no batch correction; single batch processing)
    min_cells = kwargs.get('min_cells', 3) ##default is 3 (only relevant for when batch_key is provided)
    max_value = kwargs.get('max_value', None) ##Default is none for max_value (value to clip to when scaling)
    
    ##Message warning users to use raw counts for Deviant Detection
    if scale or normalize_data:
        print("Normalizing and/or Scaling Data; Make sure to provide a raw count layer for Deviance Detection")
    
    ##Normalize (Gets metrics for total counts and then: sc.pp.normalize_total(adata, target_sum=None, inplace=False)##does the y/s_c calculation)
    
    ##check if batch preprocessing is desired to decide if we do per batch normalization, scaling
    if batch_preprocess is False:
        
        ##If a batch_key is provided without batch_preprocess, tell the user it is only used for deviance detection
        if batch_key is not None:
            print("'batch_preprocess' set to False but batch_key provided.") 
            print("Only batch deviance detection will be performed not batch normalization and scaling.")
        
        if normalize_data is True: 
            print("Normalizing Data ...")
            adata = log1p_normalize(adata)
        
        ##Scale data (defaults to mean of zero and sd of 1)
        if scale is True:
            print("Scaling Data ...")
            adata = sc.pp.scale(adata, max_value=max_value, copy = True) ##Not done inplace
            
        # check if data is sparse, if not, make it sparse (scaling makes data not sparse)
        if not issparse(adata.X):
            adata.X = csr_matrix(adata.X)

        # check to make sure zeros are removed, if not, remove
        if (adata.X.data == 0).any():
            adata.X.eliminate_zeros()
            
      
    ##if batch_preprocessing is desired, perform batch-specific scaling and normalization (if requested for each)
    ##Note: User may choose to do batch_deviance detection but not batch scaling and normalization
    else:
        
        ##batch_preprocess needs a batch_key to group by, so give an error if it is missing
        if batch_key is None:
            raise ValueError("'batch_preprocess' requires a batch_key variable to group by")
            
        ##get batches
        batches = np.unique(adata.obs[batch_key])
        batch_values = adata.obs[batch_key].values
        
        
        ##Print update messages
        if normalize_data and scale:
            print("Normalizing and Scaling Data per Batch ...")
        elif normalize_data:
            print("Normalizing Data per Batch ...")
        elif scale:
            print("Scaling Data per Batch ...")
        else:
            print("Not Normalizing or Scaling Data ...")
                  

        adata_dict = {}
        for batch in batches:

            ##subset batch (copy so we modify an independent object rather than a view of adata)
            adata_batch = adata[batch_values == batch].copy()

            if normalize_data is True:

                ##normalize per batch
                adata_batch = log1p_normalize(adata_batch)

            if scale is True:

                ##Batch Scale Data (defaults to mean of zero and sd of 1)
                adata_batch = sc.pp.scale(adata_batch, max_value=max_value, copy = True) ##Not done inplace
                
            # check if data is sparse, if not, make it sparse (scaling makes data not sparse)
            if not issparse(adata_batch.X):
                adata_batch.X = csr_matrix(adata_batch.X)

            # check to make sure zeros are removed, if not, remove
            if (adata_batch.X.data == 0).any():
                adata_batch.X.eliminate_zeros()

            ##add back to dict
            adata_dict[batch] = adata_batch

            del adata_batch

        ##concatenate adata batch back into adata so we have batch normalized and scaled data to correct for differences in library size
        adata = ad.concat(adata_dict, join = 'inner')

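        ##concatenation removes adata.var, so we restore it from the first batch. Gene annotations are the same in every batch,
        ##but per-batch statistics written during preprocessing (e.g. QC metrics, scaling mean/std) reflect the first batch only
        adata.var = adata_dict[batches[0]].var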
        
        ##delete for memory
        del adata_dict
        gc.collect()
            

    ##Find Top Deviant Genes (top_n if provided, otherwise top_pct) correcting for batch_key if necessary 
    ##deviance only considers raw counts (not normalized/scaled data), so we can run it after normalization and scaling
    adata = FindDeviantGenes(adata, layer=layer, top_pct=top_pct, top_n=top_n, batch_key=batch_key, min_cells=min_cells)
        
    return adata
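
The shifted-log normalization described in the `log1p_normalize` docstring can be illustrated on a toy matrix. The block below is a minimal NumPy-only sketch, assuming a dense cells x genes matrix; `shifted_log_normalize` and `toy_counts` are hypothetical names used for illustration, and this is not scanpy's implementation. When no target sum is given, it uses the median of the per-cell totals, which is how we understand the default behavior of `sc.pp.normalize_total`.

```python
import numpy as np

def shifted_log_normalize(counts, target_sum=None):
    """Shifted-log normalization of a dense cells x genes count matrix (illustrative only)."""
    counts = np.asarray(counts, dtype=float)
    totals = counts.sum(axis=1)                      # total counts per cell
    if target_sum is None:
        # Assumption: use the median total of non-empty cells as the target sum L
        target_sum = np.median(totals[totals > 0])
    size_factors = totals / target_sum               # s_c = total counts / L
    size_factors[size_factors == 0] = 1.0            # leave empty cells unchanged
    return np.log1p(counts / size_factors[:, None])  # log(y / s_c + 1)

# Cells 0 and 1 have the same composition but different sequencing depth
toy_counts = np.array([[1, 3, 0],
                       [2, 6, 0],
                       [4, 4, 8]])
normalized = shifted_log_normalize(toy_counts)
```

After normalization, cells 0 and 1 have identical profiles because the depth difference is removed by the size factor, and every cell's back-transformed total equals the median total.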

Now we can concatenate each platform with the reference to generate one dataset per platform for cell annotation, and find the deviant genes that will be used later for PCA, UMAP, and annotation. For deviant gene selection, we select the top 2000 genes for every platform so that the number of features is consistent.

We select deviant genes in two ways. First, we do a batch-corrected deviant gene selection to identify commonly deviant genes. Second, we take a query-centric approach and identify the top 2000 deviant genes in the query alone, which emphasizes platform-specific gene expression.

We also normalize to total counts and scale to a mean of 0 and variance of 1 for processing. This does NOT impact deviant gene detection, because detection works on the unnormalized data. We normalize per platform AND per method used to generate the Broad-Reference set (for instance, within Broad we have inDrops, 10X, etc., so we normalize each of these as its own batch).
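
The difference between the batch-corrected and the common gene sets can be seen on a small example. This is a sketch with made-up numbers, assuming two batches and four genes; the names `deviance`, `flags`, `summary`, and `ranked` are hypothetical and only mirror the logic of `FindDeviantGenes` (count the batches in which a gene is highly deviant, then break ties by median deviance).

```python
import pandas as pd

# Hypothetical per-batch deviance values for four genes
deviance = pd.DataFrame(
    {"batch_a": [5.0, 9.0, 7.0, 1.0], "batch_b": [6.0, 8.0, 2.0, 1.5]},
    index=["G1", "G2", "G3", "G4"],
)
# Hypothetical per-batch highly-deviant flags for the same genes
flags = pd.DataFrame(
    {"batch_a": [True, True, True, False], "batch_b": [True, True, False, False]},
    index=deviance.index,
)

summary = pd.DataFrame({
    "median_deviance": deviance.median(axis=1),
    "n_platforms_deviant": flags.sum(axis=1),
})

# Sort by median deviance first so that ties in the batch count
# are resolved in favor of the gene with the higher median deviance
ranked = summary.sort_values("median_deviance", ascending=False)
ranked["deviant_gene_rank"] = ranked["n_platforms_deviant"].rank(method="first", ascending=False)

top_n = 1
batch_corrected = ranked.index[ranked["deviant_gene_rank"] <= top_n]

# "Common" genes are those flagged as highly deviant in every batch
common = summary.index[summary["n_platforms_deviant"] == flags.shape[1]]
```

Here G1 and G2 are both deviant in two batches, so both are common genes, but only G2 is kept by the batch-corrected selection with `top_n = 1` because it has the higher median deviance.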

In [2]:
##concatenate each method with broad reference to generate a new adata object with reference for each

##Initialize dictionary for saved files
saved_files={}
selected_batchcor_genes={}
selected_common_genes={}
selected_query_genes={}

##Loop through adata to process
for method, adata in adata_all.items():
    
    ##copy adata so we don't update in place
    adata = adata.copy()
    
    ##don't combine Broad-Reference with itself
    if method == "Broad-Reference":
        continue 
    
    ##concat with broad-ref
    adata = ad.concat([adata, adata_all['Broad-Reference']], join='inner')
    
    ##set dummy variables for platforms (and convert to strings so plotting with scanpy works as binary (first bool then string))
    sel_methods = np.unique(adata.obs['platform'])
    adata.obs = pd.concat([adata.obs, pd.get_dummies(adata.obs['platform'], dtype=bool)], axis = 1) ##dummy variables for each platform for visualization
    adata.obs[sel_methods] = adata.obs[sel_methods].replace({True: 'True', False: 'False'})

    ## annotate mitochondrial (MT) and ribosomal (RPS and RPL) genes
    adata.var['mt'] = adata.var_names.str.startswith('MT-')  # annotate mitochondrial gene group as 'mt'
    adata.var['ribo'] = adata.var_names.str.startswith(("RPS","RPL"))  # annotate ribosomal gene group as 'ribo'

    ##Add relevant columns for cell-annotation later (only intersection)
    adata.obs = adata.obs.assign(Cluster_baseline = "Unassigned")
    adata.obs = adata.obs.assign(CellType_baseline = "Unassigned")
    adata.obs.loc[adata.obs['platform'] == 'Broad-Reference', ['Cluster_baseline', 'CellType_baseline']] = reference_metadata[['Cluster', 'CellType']].values

    if method == "Broad-Reference-Val":
        adata.obs.loc[adata.obs['platform'] == 'Broad-Reference-Val', ['Cluster_baseline', 'CellType_baseline']] = validation_metadata[['Cluster', 'CellType']].values

    ##Add a low resolution baseline (combine Monocytes, T Cells, and Dendritic Cells into one category, respectively)
    adata.obs['CellType_baseline_lowres'] = adata.obs['CellType_baseline'].copy()
    adata.obs['CellType_baseline_lowres'] = ['Monocyte (non-specific)' if 'monocyte' in cell.lower() else cell for cell in adata.obs['CellType_baseline_lowres']]
    adata.obs['CellType_baseline_lowres'] = ['T Cell (non-specific)' if 't cell' in cell.lower() else cell for cell in adata.obs['CellType_baseline_lowres']]
    adata.obs['CellType_baseline_lowres'] = ['Dendritic Cell (non-specific)' if 'dendritic' in cell.lower() else cell for cell in adata.obs['CellType_baseline_lowres']]

    ## Before any further preprocessing make sure to set a .raw attribute and subset query data for query-specific deviant gene detection
    adata.raw = adata
    adata.layers['unnormalized'] = adata.raw.X
    adata_query = adata[adata.obs['platform'] == method] ##subset query data so we can also find query-specific deviant genes
    
    ## Identify deviant genes for each platform (including Broad internal batches) separately so we can find commonly deviant genes (FEATURE_SELECT_GENE_N set at top of notebook)
    ## This will also normalize per Broad-Reference internal batch
    print(f"Processing Data to Get Deviant Genes per Batch for {method}...")
    adata = preprocessData(adata, scale=True, normalize_data=True, batch_preprocess=True,
                           layer='unnormalized', top_n=FEATURE_SELECT_GENE_N, batch_key='platform_broadincluded')
    
    ## Identify deviant genes for the query only so we can emphasize gene deviation in the query alone also
    ## Here we also use a batch_key for Broad-Reference-Val, otherwise we use no batch key
    print(f"Processing Query to Get Deviant Genes per Query for {method}...")
    if method == 'Broad-Reference-Val':
        print("(Processing Broad-Reference Batches Separately for Normalization and Deviant Gene Detection)")
        adata_query = preprocessData(adata_query, scale=True, normalize_data=True, batch_preprocess=True,
                                     layer='unnormalized', top_n=FEATURE_SELECT_GENE_N, batch_key='platform_broadincluded')
    else:
        adata_query = preprocessData(adata_query, scale=True, normalize_data=True, batch_preprocess=False,
                                     layer='unnormalized', top_n=FEATURE_SELECT_GENE_N) ##single batch method since only query
    
    ## Select those that are deviant in all platforms (including reference batch) as those to call deviant genes
    dev_genes_batch = adata.var['n_platforms_deviant'] == adata.obs['platform_broadincluded'].nunique()
    adata.var['highly_deviant_common'] = dev_genes_batch
    
    ## set those genes that are deviant in the query as those that are deviant/variable in the concatenated data
    ## adata.var and adata_query.var should have same genes since adata_query was a subset from adata
    if np.all(adata.var_names == adata_query.var_names):
        adata.var['highly_deviant_query'] = adata_query.var['highly_deviant']
    else:
        raise ValueError("adata.var_names does not match adata_query.var_names")
        
    ## delete for memory
    del adata_query
    
    ##save genes
    selected_common_genes[method] = adata.var_names[dev_genes_batch]
    selected_batchcor_genes[method] = adata.var_names[adata.var['highly_deviant']]
    selected_query_genes[method] = adata.var_names[adata.var['highly_deviant_query']]
    
    print(f"Common Highly Deviant Genes for {method}: {np.sum(dev_genes_batch)}")
    print(f"Highly Deviant Genes for {method}: {np.sum(adata.var['highly_deviant'])}")
    
    ## Save updated adata objects to their own files
    saved_files[method] = os.path.join(anndata_object_dir, f'{method}_intersection_w_reference_nodoublets.h5ad')
    adata.write_h5ad(saved_files[method])
    
    ##delete for memory
    del adata
    
    ##force garbage collection
    gc.collect()