## Notebook to identify potential doublets using Scrublet

- [Github repo](https://github.com/AllonKleinLab/scrublet)
- [repo example notebook](https://github.com/AllonKleinLab/scrublet/blob/master/examples/scrublet_basics.ipynb)
- [Cell Systems paper](https://www.sciencedirect.com/science/article/pii/S2405471218304745)

In [None]:
!date

#### import libraries

In [None]:
import scanpy as sc
import scrublet as scr
from pandas import read_csv, DataFrame
import matplotlib.pyplot as plt
from matplotlib.pyplot import rc_context
from seaborn import barplot

%matplotlib inline
# for white background of figures (only for docs rendering)
%config InlineBackend.print_figure_kwargs={'facecolor' : "w"}
%config InlineBackend.figure_format='retina'

#### set notebook variables

In [None]:
# naming
proj_name = 'aging'

# directories
wrk_dir = '/home/jupyter/brain_aging_phase1'
quants_dir = f'{wrk_dir}/demux'

# in files
anndata_file = f'{quants_dir}/{proj_name}.h5ad'
final_file = f'{quants_dir}/{proj_name}.pegasus.leiden_085.subclustered.h5ad'

# out files
scores_file = f'{quants_dir}/{proj_name}.scrublet_scores.csv'

# variables
DEBUG = True
dpi_value = 50
use_gene_only = False
expected_rate = 0.08

### load the anndata file

In [None]:
%%time
adata = sc.read(anndata_file)
if DEBUG:
    print(adata)

### Initialize Scrublet object
The relevant parameters are:

- expected_doublet_rate: the expected fraction of transcriptomes that are doublets, typically 0.05-0.1. Results are not particularly sensitive to this parameter. Here it is set with `expected_rate` (0.08) in the notebook variables; a suitable value can be taken from the Chromium User Guide: https://support.10xgenomics.com/permalink/3vzDu3zQjY0o2AqkkkI4CC
- sim_doublet_ratio: the number of doublets to simulate, relative to the number of observed transcriptomes. This should be high enough that all doublet states are well-represented by simulated doublets. Setting it too high is computationally expensive. The default value is 2, though values as low as 0.5 give very similar results for the datasets that have been tested.
- n_neighbors: the number of neighbors used to construct the KNN classifier of observed transcriptomes and simulated doublets. The default value of round(0.5*sqrt(n_cells)) generally works well.
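
The size of the simulation and of the neighborhood follow directly from the parameters above. The helper below is hypothetical (it is not part of the Scrublet API) and only does that arithmetic with the defaults described in the list; treat it as a sketch of the documented formulas, not Scrublet's code.

```python
import math

def scrublet_sizes(n_cells, sim_doublet_ratio=2.0, n_neighbors=None):
    """Hypothetical helper: number of simulated doublets and the KNN size,
    following the parameter descriptions above (not Scrublet's own code)."""
    # sim_doublet_ratio is relative to the number of observed transcriptomes
    n_simulated = int(round(n_cells * sim_doublet_ratio))
    # default neighborhood size: round(0.5 * sqrt(n_cells))
    if n_neighbors is None:
        n_neighbors = int(round(0.5 * math.sqrt(n_cells)))
    return n_simulated, n_neighbors

print(scrublet_sizes(10_000))                         # (20000, 50)
print(scrublet_sizes(10_000, sim_doublet_ratio=0.5))  # (5000, 50)
```

Lowering sim_doublet_ratio shrinks the simulated set that the KNN search has to cover, while the default neighborhood size depends only on the number of observed cells.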

In [None]:
%%time
scrub = scr.Scrublet(adata.X, expected_doublet_rate=expected_rate)

### Run the default pipeline, which includes:
1. Doublet simulation
2. Normalization, gene filtering, rescaling, PCA
3. Doublet score calculation
4. Doublet score threshold detection and doublet calling
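
Steps 1 and 2 can be sketched with NumPy on a small dense toy matrix. This is a simplified illustration under stated assumptions, not Scrublet's implementation: it pairs observed cells uniformly at random (with replacement), sums their raw count vectors to form synthetic doublets, and rescales every row to a common total. Gene filtering, rescaling and PCA are left out, and the function names are hypothetical.

```python
import numpy as np

def simulate_doublets(counts, sim_doublet_ratio=2.0, seed=0):
    """Sum the raw counts of randomly paired observed cells (simplified sketch)."""
    rng = np.random.default_rng(seed)
    n_cells = counts.shape[0]
    n_sim = int(round(n_cells * sim_doublet_ratio))
    # each row holds the indices of the two "parent" cells of one synthetic doublet
    parents = rng.integers(0, n_cells, size=(n_sim, 2))
    simulated = counts[parents[:, 0]] + counts[parents[:, 1]]
    return simulated, parents

def normalize_total(counts, target=None):
    """Rescale every row to the same total (default: the mean library size)."""
    counts = np.asarray(counts, dtype=float)
    totals = counts.sum(axis=1, keepdims=True)
    if target is None:
        target = totals.mean()
    # assumes no all-zero rows, which would divide by zero
    return counts / totals * target

# toy data: 6 cells x 4 genes of hypothetical counts
rng = np.random.default_rng(1)
observed = rng.poisson(5, size=(6, 4))
simulated, parents = simulate_doublets(observed, sim_doublet_ratio=2.0)
combined = normalize_total(np.vstack([observed, simulated]))
print(simulated.shape, combined.sum(axis=1).round(3))
```

Normalizing observed cells and simulated doublets together puts them on the same scale before the shared dimensionality reduction and neighbor search.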

In [None]:
%%time
doublet_scores, predicted_doublets = scrub.scrub_doublets(log_transform=True)

### Plot doublet score histograms for observed transcriptomes and simulated doublets
The simulated doublet histogram is typically bimodal. The left mode corresponds to "embedded" doublets generated by two cells with similar gene expression. The right mode corresponds to "neotypic" doublets, which are generated by cells with distinct gene expression (e.g., different cell types) and are expected to introduce more artifacts in downstream analyses. Scrublet can only detect neotypic doublets.
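
A toy example shows why only neotypic doublets are detectable. The four-gene profiles below are made up: a doublet of two similar cells ends up almost on top of its parents after library-size normalization, whereas a doublet of two distinct cell types lands in a region of expression space where no real cell sits. This is an illustration of the idea, not a Scrublet computation.

```python
import numpy as np

def profile(counts):
    """Convert a count vector to proportions (library-size normalization)."""
    counts = np.asarray(counts, dtype=float)
    return counts / counts.sum()

def dist_to_nearest(x, refs):
    """Euclidean distance from profile x to the closest reference cell profile."""
    return min(np.linalg.norm(x - profile(r)) for r in refs)

# hypothetical counts over four genes for two cell types
type_a = np.array([50, 40, 5, 5])
a2 = np.array([45, 45, 6, 4])       # another cell of type A
type_b = np.array([5, 5, 40, 50])

embedded = profile(type_a + a2)     # doublet of two similar cells
neotypic = profile(type_a + type_b) # doublet of two distinct cell types

refs = [type_a, a2, type_b]
print(dist_to_nearest(embedded, refs), dist_to_nearest(neotypic, refs))
```

The embedded doublet sits a small fraction of the distance from a real cell that the neotypic one does, so a neighbor-based score has little to separate it from singlets.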

To call doublets vs. singlets, we must set a threshold doublet score, ideally at the minimum between the two modes of the simulated doublet histogram. scrub_doublets() attempts to identify this point automatically; check the histogram below to confirm that the chosen threshold sits in the valley between the two modes for this dataset. If automatic threshold detection doesn't work well, adjust the threshold with the call_doublets() function, for example:

`scrub.call_doublets(threshold=0.25)`
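
The idea of placing the threshold at the minimum between the two modes can be sketched as a simple histogram-valley search. This is an independent, simplified illustration and an assumption, not Scrublet's own threshold code: it repeatedly smooths the histogram of simulated-doublet scores with a 3-point moving average until exactly two peaks remain, then returns the score at the lowest point between them. The function name and the synthetic bimodal scores are hypothetical.

```python
import numpy as np

def valley_threshold(sim_scores, bins=100, max_iter=1000):
    """Hypothetical helper: score at the histogram minimum between two modes."""
    counts, edges = np.histogram(sim_scores, bins=bins)
    centers = (edges[:-1] + edges[1:]) / 2
    smooth = counts.astype(float)
    for _ in range(max_iter):
        # interior local maxima (first point of a flat top counts once)
        peaks = [i for i in range(1, len(smooth) - 1)
                 if smooth[i - 1] < smooth[i] >= smooth[i + 1]]
        if len(peaks) == 2:
            lo, hi = peaks
            return centers[lo + np.argmin(smooth[lo:hi + 1])]
        if len(peaks) < 2:
            break
        # 3-point moving average, padding each end with its own value
        padded = np.pad(smooth, 1, mode='edge')
        smooth = (padded[:-2] + padded[1:-1] + padded[2:]) / 3
    raise ValueError('could not find two modes; set the threshold manually')

# synthetic bimodal scores: "embedded" mode near 0.1, "neotypic" mode near 0.6
rng = np.random.default_rng(0)
sim_scores = np.clip(np.concatenate([rng.normal(0.1, 0.03, 3000),
                                     rng.normal(0.6, 0.08, 3000)]), 0, 1)
print(round(float(valley_threshold(sim_scores)), 3))
```

When the simulated histogram is not clearly bimodal this search fails, which is exactly the situation where setting the threshold by hand with call_doublets() is the better option.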

In [None]:
with rc_context({'figure.figsize': (9, 9), 'figure.dpi': dpi_value}):
    plt.style.use('seaborn-bright')
    scrub.plot_histogram()

### Get 2-D embedding to visualize the results

In [None]:
print('Running UMAP...')
scrub.set_embedding('UMAP', scr.get_umap(scrub.manifold_obs_, 10, min_dist=0.3))

print('Done.')

### Plot doublet predictions on 2-D embedding
Predicted doublets should co-localize in distinct states.

In [None]:
with rc_context({'figure.figsize': (9, 9), 'figure.dpi': dpi_value}):
    plt.style.use('seaborn-bright')
    scrub.plot_embedding('UMAP', order_points=True)

### add the scores to the cell observations

In [None]:
adata.obs['doublet_score'] = doublet_scores
adata.obs['predicted_doublet'] = predicted_doublets

In [None]:
display(adata.obs.predicted_doublet.value_counts())
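
To put these counts in context, the predicted doublet fraction can be compared with expected_rate. The sketch below is an assumption-labeled illustration with a hypothetical function name: it calls doublets as observed scores above a threshold, computes the detected rate among observed cells and the fraction of simulated doublets above the threshold (the detectable fraction), and divides the two to estimate the overall doublet rate, in the spirit of the summary Scrublet prints. With the fitted object, the inputs would presumably be its observed scores, simulated scores and threshold (for example scrub.doublet_scores_obs_, scrub.doublet_scores_sim_ and scrub.threshold_, assuming your Scrublet version exposes them).

```python
import numpy as np

def summarize_calls(obs_scores, sim_scores, threshold):
    """Hypothetical helper: doublet calls plus detected, detectable and estimated rates."""
    obs_scores = np.asarray(obs_scores, dtype=float)
    sim_scores = np.asarray(sim_scores, dtype=float)
    predicted = obs_scores > threshold
    detected_rate = predicted.mean()                  # fraction of observed cells called doublets
    detectable = (sim_scores > threshold).mean()      # fraction of simulated doublets above threshold
    # only the detectable share of real doublets can be called, so scale up by it
    estimated_rate = detected_rate / detectable if detectable > 0 else float('nan')
    return predicted, detected_rate, detectable, estimated_rate

predicted, detected, detectable, estimated = summarize_calls(
    [0.1, 0.2, 0.5, 0.7], [0.1, 0.6, 0.7, 0.8], threshold=0.4)
print(predicted, detected, detectable, round(estimated, 3))
```

An estimated overall rate far from expected_rate is a hint to revisit the threshold or the expected rate before removing cells.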

### save the scores

In [None]:
adata.obs.to_csv(scores_file)

In [None]:
scrublet_data = read_csv(scores_file, index_col=0)
print(scrublet_data.shape)
display(scrublet_data.predicted_doublet.value_counts())
doublets = scrublet_data.loc[scrublet_data.predicted_doublet]
display(doublets.predicted_doublet.value_counts())
print(doublets.shape)
if DEBUG:
    display(scrublet_data.sample(5))
    display(doublets.sample(5))
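
The round trip through CSV relies on read_csv turning the True/False text written by to_csv back into a bool column, which is what lets scrublet_data.predicted_doublet be used directly as a row mask. A minimal check on two hypothetical barcodes:

```python
import io
import pandas as pd

# two hypothetical barcodes written the way DataFrame.to_csv writes booleans
csv_text = ("barcode,doublet_score,predicted_doublet\n"
            "AAACCTG-1,0.05,False\n"
            "AAACGGG-1,0.41,True\n")
scores = pd.read_csv(io.StringIO(csv_text), index_col=0)
print(scores.predicted_doublet.dtype)

# a bool column can select rows directly
doublets = scores.loc[scores.predicted_doublet]
print(doublets.index.tolist())
```

If the column ever contained missing values it would not be a plain bool column, so checking its dtype before masking is a cheap safeguard.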

### see which cell-type clusters are impacted by the possible doublets

#### load the fully processed anndata file

In [None]:
%%time
adata_done = sc.read(final_file)

if DEBUG:
    print(adata_done)

### add the scrublet predictions to the observation data

In [None]:
adata_done.obs['scrublet_doublet'] = 'no'
adata_done.obs.loc[adata_done.obs.index.isin(doublets.index), 'scrublet_doublet'] = 'yes'
print(adata_done.obs.shape)
display(adata_done.obs.scrublet_doublet.value_counts())
if DEBUG:
    display(adata_done.obs.sample(5))
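
Labeling by index membership assumes the barcodes in the Scrublet output and in the processed file match exactly. If cells were filtered during later processing, some predicted doublets will simply not appear in adata_done. The hypothetical helper below labels the cells and also returns the doublet barcodes that were not found, so that gap can be checked (for example with `label_doublets(adata_done.obs, doublets.index)`).

```python
import numpy as np
import pandas as pd

def label_doublets(obs, doublet_barcodes, col='scrublet_doublet'):
    """Return a copy of obs with a 'yes'/'no' doublet column, plus the doublet
    barcodes that are absent from obs (e.g. removed by later filtering)."""
    obs = obs.copy()
    obs[col] = np.where(obs.index.isin(doublet_barcodes), 'yes', 'no')
    missing = pd.Index(doublet_barcodes).difference(obs.index)
    return obs, missing

# hypothetical barcodes: 'CCC-1' is not present in the processed data
processed_obs = pd.DataFrame(index=['AAA-1', 'BBB-1', 'DDD-1'])
labeled, missing = label_doublets(processed_obs, ['BBB-1', 'CCC-1'])
print(labeled.scrublet_doublet.tolist(), missing.tolist())
```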

### visualize the predicted doublets in the full data UMAP

In [None]:
with rc_context({'figure.figsize': (9, 9), 'figure.dpi': dpi_value}):
    plt.style.use('seaborn-bright')
    sc.pl.umap(adata_done, color=['scrublet_doublet'], 
               frameon=False)

### take a look at what is impacted by the predicted doublets

In [None]:
adata_doubles = adata_done[adata_done.obs.scrublet_doublet == 'yes']
print(adata_doubles)
if DEBUG:
    display(adata_doubles.obs.sample(5))

In [None]:
with rc_context({'figure.figsize': (12, 12), 'figure.dpi': dpi_value}):
    plt.style.use('seaborn-bright')
    sc.pl.umap(adata_done, color=['new_anno'], 
               frameon=False, legend_loc='on data')

In [None]:
with rc_context({'figure.figsize': (9, 9), 'figure.dpi': dpi_value}):
    plt.style.use('seaborn-bright')
    sc.pl.umap(adata_doubles, color=['scrublet_doublet'], 
               frameon=False, legend_loc=None)

In [None]:
with rc_context({'figure.figsize': (9, 9), 'figure.dpi': dpi_value}):
    plt.style.use('seaborn-bright')
    sc.pl.umap(adata_doubles, color=['new_anno', 'broad_celltype', 'Brain_region'], 
               frameon=False, legend_loc='on data')

### what cell-types are the predicted doublets being assigned to

In [None]:
# count predicted doublets per annotated cell type; rename() fixes the column name
# to 'counts' regardless of how the pandas version names value_counts() output
celltype_doublet_counts = adata_doubles.obs.new_anno.value_counts().rename('counts').to_frame()
celltype_doublet_counts['percent'] = (celltype_doublet_counts.counts / celltype_doublet_counts.counts.sum() * 100).round()
if DEBUG:
    display(celltype_doublet_counts)

In [None]:
with rc_context({'figure.figsize': (9, 9)}):  
    barplot(x=celltype_doublet_counts.index, y='percent', data=celltype_doublet_counts, palette='Purples')
    plt.grid(axis='y')
    plt.xticks(rotation=90)
    plt.tight_layout()
    plt.ylabel('percent scrublet doublet')
    plt.title('Percentage of scrublet doublet assigned to cluster cell-type')
    plt.show()

### what percentage of each cell-type cluster is impacted by possible doublets

In [None]:
cluster_impact_fracs = {}
# per-cell-type counts of 'yes'/'no' calls; unstack/reindex with fill_value=0 avoid a
# KeyError for cell types (or a dataset) without any predicted doublets
cluster_impact_counts = (adata_done.obs.groupby('new_anno', observed=True)['scrublet_doublet']
                         .value_counts().unstack(fill_value=0)
                         .reindex(columns=['yes', 'no'], fill_value=0))
for cell_type, row in cluster_impact_counts.iterrows():
    print(cell_type)
    print(row['yes'], row['no'])
    cluster_impact_fracs[cell_type] = round(row['yes'] / row.sum(), 3)
    print(cluster_impact_fracs[cell_type])
if DEBUG:
    display(cluster_impact_counts)
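
The same per-cell-type fractions can be computed in one table with pandas.crosstab, which also makes it easy to sort and filter. The helper below is a hypothetical alternative, not part of the original workflow; it assumes the obs columns are named as in this notebook, and the toy annotations are made up.

```python
import pandas as pd

def doublet_fraction_by_group(obs, group_col='new_anno', flag_col='scrublet_doublet'):
    """Hypothetical helper: 'no'/'yes' counts and the doublet fraction per group."""
    table = pd.crosstab(obs[group_col], obs[flag_col])
    # make sure both columns exist even if one call never occurs
    table = table.reindex(columns=['no', 'yes'], fill_value=0)
    table['frac_yes'] = table['yes'] / table[['no', 'yes']].sum(axis=1)
    return table.sort_values('frac_yes', ascending=False)

toy_obs = pd.DataFrame({'new_anno': ['A', 'A', 'A', 'B', 'B'],
                        'scrublet_doublet': ['yes', 'no', 'no', 'no', 'no']})
table = doublet_fraction_by_group(toy_obs)
print(table)
```

On the real data, `table[table.frac_yes >= 0.05]` would list the cell types at or above the 5% mark used in the next cell.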

#### which cell-types have at least 5% of their cells flagged as possible doublets

In [None]:
for cell_type, frac in cluster_impact_fracs.items():
    this_percent = round(frac*100, 3)
    if this_percent >= 5:
        print(cell_type, this_percent)

In [None]:
with rc_context({'figure.figsize': (9, 9)}):
    df = DataFrame.from_dict(cluster_impact_fracs, orient='index', columns=['frac'])
    df['percent'] = df.frac * 100
    df = df.sort_values('percent')
    barplot(x=df.index, y='percent', data=df, palette='Purples')
    plt.grid(axis='y')
    plt.xticks(rotation=90)
    plt.ylabel('percent scrublet doublet')
    plt.title('Percentage of each cluster cell-type flagged as scrublet doublet')
    # call tight_layout after the labels are set so they fit in the figure
    plt.tight_layout()
    plt.show()

In [None]:
!date