In [None]:
from typing import Dict, Optional
import scanpy as sc
import os
import anndata as ad
import anndata
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import matplotlib.ticker as ticker
import seaborn as sns
import math
from plotnine import *
#import scrublet as scr
#from scipy.stats import median_abs_deviation
import sctk as sk
import pandas as pd
import tables
import scipy.sparse as sp
sc.settings.verbosity = 0  # lowest scanpy verbosity: only errors are logged
 

# Utils functions

In [None]:
def calculate_expected_doublet_rate(adata_object):
    """
    Look up the expected 10x doublet (multiplet) rate for an object's cell count.

    The cell count (``adata_object.shape[0]``) is rounded UP to the next
    multiple of 1000 and clamped to the 1000 - 10_000 range covered by the
    lookup table, then mapped to the expected rate.

    Parameters
    ----------
    adata_object
        Any object whose ``shape[0]`` is the number of cells (e.g. AnnData).

    Returns
    -------
    tuple
        ``(expected_rate, recovered_cells)``: the expected rate as a fraction
        (0.076 == 7.6%) and the unrounded number of cells in the object.
    """
    # expected values from https://uofuhealth.utah.edu/huntsman/shared-resources/gba/htg/single-cell/genomics-10x
    expected_rates = {
        1000: 0.008,
        2000: 0.016,
        3000: 0.023,
        4000: 0.031,
        5000: 0.039,
        6000: 0.046,
        7000: 0.054,
        8000: 0.061,
        9000: 0.069,
        10_000: 0.076,
    }
    min_cells, max_cells = min(expected_rates), max(expected_rates)

    # number of cells, rounded up to the next table entry
    recovered_cells = adata_object.shape[0]
    rounded_recovered_cells = math.ceil(recovered_cells / 1000) * 1000
    if rounded_recovered_cells > max_cells:
        # Report the pre-clamp value so the message shows what was capped.
        print(f"Rounded recovered cells > {max_cells:_} so set to maximum "
              f"({recovered_cells} -> {rounded_recovered_cells} -> {max_cells:_}, "
              f"rate {expected_rates[max_cells]})")
        rounded_recovered_cells = max_cells
    elif rounded_recovered_cells < min_cells:
        # 0 cells rounds to 0, which is not a table key (previously a KeyError).
        rounded_recovered_cells = min_cells

    # set expected rate based on number of cells in object
    expected_rate = expected_rates[rounded_recovered_cells]
    print(f"Expected rate {expected_rate} ({expected_rate:.1%}) for cells {recovered_cells}")
    return expected_rate, recovered_cells

Metadata input must be a CSV whose columns include `Sample` and the column named by `path_type`, which holds the path to each input file (e.g. CellBender output or Cell Ranger/STARsolo output).


# Scrublet

In [None]:
import scanpy as sc

base_dir = "/nfs/team298/ls34/disease_atlas/mrvi/"
# TODO: set to the input .h5ad filename (relative to base_dir).
adata_path = ""
if not adata_path:
    # Fail with a clear message instead of the SyntaxError the empty placeholder caused.
    raise ValueError("Set `adata_path` to the input .h5ad file (relative to base_dir) before running.")
adata = sc.read_h5ad(base_dir + adata_path)
adata

In [None]:
# Run sctk's scrublet separately on each donor's cells, then copy the per-cell
# result columns written by sk.run_scrublet back into the full object.
apply_scrublet = True
DONOR_KEY = "DonorID"
SCRUBLET_COLUMNS = ["scrublet_score", "scrublet_score_z", "cluster_scrublet_score", "bh_pval"]

# Results are copied back by obs_name, so names must identify cells uniquely.
assert adata.obs_names.is_unique, "obs_names must be unique to copy scrublet results back"

# Seed with NaN (not the cell barcodes) so the columns stay numeric; any cell
# that is not processed is visibly missing instead of holding a string.
for col in SCRUBLET_COLUMNS:
    adata.obs[col] = np.nan

# Skip missing donor IDs: comparing against NaN would select zero cells.
donor_ids = adata.obs[DONOR_KEY].dropna().unique()
for i, sample_id in enumerate(donor_ids):
    print(f"#### sctk scrublet on: {sample_id}. {i+1}/{len(donor_ids)}")
    if not apply_scrublet:
        continue
    donor_mask = (adata.obs[DONOR_KEY] == sample_id).to_numpy()
    # Real copy: run_scrublet writes to .obs, which on a view would trigger
    # anndata's implicit view -> actual conversion.
    adata_donor = adata[donor_mask].copy()
    sk.run_scrublet(adata_donor)
    for col in SCRUBLET_COLUMNS:
        # Index-aligned: each donor cell's value lands on the matching obs_name.
        adata.obs.loc[donor_mask, col] = adata_donor.obs[col]

    


In [None]:
# scrublet_score is continuous, so value_counts() would list (almost) every
# cell; a summary shows the distribution. Expect a float column with count == n_obs.
adata.obs["scrublet_score"].describe()

In [None]:
# Per-cell doublet call at a 0.4 cutoff: score strictly below 0.4 -> "Pass",
# anything else (including missing scores) -> "Fail".
adata.obs["scrublet_score_binary04"] = (
    adata.obs["scrublet_score"].lt(0.4).map({True: "Pass", False: "Fail"})
)
adata.obs["scrublet_score_binary04"].value_counts()

In [None]:
# Per-cell doublet call at a 0.3 cutoff: score strictly below 0.3 -> "Pass",
# anything else (including missing scores) -> "Fail".
adata.obs["scrublet_score_binary03"] = np.where(adata.obs["scrublet_score"] < 0.3, "Pass", "Fail")
adata.obs["scrublet_score_binary03"].value_counts()

In [None]:
def apply_qc_thresholds(adata, MIN_N_GENES, MAX_TOTAL_COUNT, MAX_PCT_MT, label, MIN_TOTAL_COUNT=0,
                        cc_genes_path="/lustre/scratch126/cellgen/team298/sko_expimap_2023/pan_fetal_cc_genes.csv"):
    """
    Flag gene groups, compute scanpy QC metrics and add a per-cell QC label.

    Modifies ``adata`` in place:

    * ``adata.var`` gains boolean columns ``mt`` (``MT-`` prefix), ``ribo``
      (``RPS``/``RPL`` prefix), ``hb`` (``HB`` followed by any character other
      than ``P``) and ``cc_fetal`` (gene listed in ``cc_genes_path``).
    * ``sc.pp.calculate_qc_metrics`` is run with ``qc_vars=["mt", "ribo"]``,
      adding ``n_genes_by_counts``, ``total_counts`` and ``pct_counts_mt``
      (among others) to ``adata.obs``.
    * ``adata.obs[label]`` becomes a categorical holding the FIRST failed check,
      in the order Low_nFeature, High_MT, High total count, Low total count.
      Cells passing all checks get ``"Pass_<suffix>"``, where suffix is the
      text after the last ``_`` in ``label``. Cells matching none of the rules
      (only possible with NaN metrics) get ``"Unclassified"``.

    Parameters
    ----------
    adata : AnnData
        Object to annotate (modified in place).
    MIN_N_GENES, MAX_TOTAL_COUNT, MAX_PCT_MT, MIN_TOTAL_COUNT
        Inclusive pass bounds on ``n_genes_by_counts``, ``total_counts`` and
        ``pct_counts_mt``.
    label : str
        Name of the obs column to create.
    cc_genes_path : str
        CSV with a header row whose second column lists cell-cycle genes.

    Returns
    -------
    None
    """
    ## Cell cycle gene list (second CSV column, header row skipped)
    cc_genes = pd.read_csv(cc_genes_path, names=["ind", "gene_ids"], skiprows=1)["gene_ids"].tolist()

    # Mark MT/ribo/Hb/cell cycle genes
    var_names = adata.var_names
    adata.var['mt'] = var_names.str.startswith('MT-')
    adata.var["ribo"] = var_names.str.startswith(("RPS", "RPL"))
    # HB* except HBP*; the old pattern "^HB[^(P)]" had literal parentheses in the class.
    adata.var["hb"] = var_names.str.contains(r"^HB[^P]")
    adata.var["cc_fetal"] = var_names.isin(cc_genes)

    # Calculate QC metrics
    sc.pp.calculate_qc_metrics(adata, qc_vars=["mt", "ribo"], inplace=True, log1p=False)

    obs = adata.obs
    conditions = [
        obs['n_genes_by_counts'] < MIN_N_GENES,
        obs['pct_counts_mt'] > MAX_PCT_MT,
        obs['total_counts'] > MAX_TOTAL_COUNT,
        obs['total_counts'] < MIN_TOTAL_COUNT,
        (obs['pct_counts_mt'] <= MAX_PCT_MT)
        & (obs['n_genes_by_counts'] >= MIN_N_GENES)
        & (obs['total_counts'] <= MAX_TOTAL_COUNT)
        & (obs['total_counts'] >= MIN_TOTAL_COUNT),
    ]
    pass_name = "Pass_" + label.split("_")[-1]
    values = ['Low_nFeature', 'High_MT', 'High total count', 'Low total count', pass_name]

    # np.select picks the first matching condition. An explicit string default
    # replaces the implicit integer 0 (which showed up as a "0" label).
    adata.obs[label] = pd.Categorical(np.select(conditions, values, default="Unclassified"))

    print(adata.obs[label].value_counts())
    



In [None]:
# Stringent ("hi") QC label: >= 600 genes, 1000-30_000 counts, <= 1% MT.
apply_qc_thresholds(
    adata,
    label="QC_hi",
    MIN_N_GENES=600,
    MIN_TOTAL_COUNT=1000,
    MAX_TOTAL_COUNT=30_000,
    MAX_PCT_MT=1,
)


In [None]:
# Checkpoint: per-cell scrublet results plus the QC_hi labels.
scrublet_h5ad = f"{base_dir}{adata_path}.scrublet"
adata.write(scrublet_h5ad)

In [None]:
# Leiden clustering on the "neighbor_30" graph at each resolution; labels are
# stored as obs["leiden_res<resolution>"] and their keys collected for plotting.
neighbor_id = 'neighbor_30'
leidenres_list = [2]
leiden_to_plot = []
for leidenres in leidenres_list:
    print("###", leidenres)
    leiden_id = f"leiden_res{leidenres}"  # gayoso 1.2
    leiden_to_plot.append(leiden_id)
    sc.tl.leiden(
        adata,
        resolution=leidenres,
        neighbors_key=neighbor_id,
        key_added=leiden_id,
    )


In [None]:
# Mean scrublet score per leiden_res2 cluster (cluster label -> mean score),
# kept in .uns so it is saved with the object. observed=True limits the result
# to categories that contain cells and avoids pandas' observed=False FutureWarning.
CLUSTER_KEY = "leiden_res2"
dict_scrublet_score = (
    adata.obs.groupby(CLUSTER_KEY, observed=True)["scrublet_score"].mean().to_dict()
)
adata.uns["dict_scrublet_score"] = dict_scrublet_score

print(dict_scrublet_score)


In [None]:
import scanpy as sc

# Overwrite the ".scrublet" checkpoint so it now also carries leiden_res2 and
# the per-cluster mean scores stored in .uns.
adata.write(f"{base_dir}{adata_path}.scrublet")

In [None]:
# Deliberate halt for "Run All": the cells above write the ".scrublet"
# checkpoint; the cells below reload a saved checkpoint and filter it.
raise RuntimeError("Stopping here on purpose - run the cells below manually after reviewing the scrublet results.")

In [None]:
import scanpy as sc

# Reload a saved scrublet checkpoint from a hard-coded path (presumably so this
# section can run after a kernel restart without the cells above).
scrublet_checkpoint = '/nfs/team298/ls34/disease_atlas/mrvi/adata_inflamm_scanvi6.h5ad.filtered.scrublet'
adata = sc.read_h5ad(scrublet_checkpoint)
adata

In [None]:
# Large UMAPs: cluster labels next to per-cell scrublet score (colour capped at 0.4).
sc.settings.set_figure_params(figsize=(18, 18), dpi=100, frameon=False, facecolor="white")
umap_keys = ["leiden_res2", "scrublet_score"]
sc.pl.umap(adata, color=umap_keys, s=2, vmax=0.4, legend_loc="on data", legend_fontsize=9)

In [None]:
# Recompute the per-cluster mean scrublet score on the reloaded object.
# observed=True skips empty categories and avoids pandas' observed=False FutureWarning.
dict_scrublet_score = (
    adata.obs.groupby("leiden_res2", observed=True)["scrublet_score"].mean().to_dict()
)
adata.uns["dict_scrublet_score"] = dict_scrublet_score



In [None]:
# Mean scrublet score per leiden_res2 cluster; bars at or above the cutoff are
# red. The dashed line marks the cutoff.
SCORE_CUTOFF = 0.3
keys = list(dict_scrublet_score.keys())
values = list(dict_scrublet_score.values())

fig, ax = plt.subplots(figsize=(20, 6))
ax.bar(keys, values, color=['red' if value >= SCORE_CUTOFF else 'blue' for value in values])
ax.axhline(SCORE_CUTOFF, color='grey', linestyle='--', lw=1)
ax.set_xlabel('Leiden cluster (leiden_res2)')
ax.set_ylabel('Mean scrublet score')
ax.set_title('Mean scrublet score per cluster')
ax.tick_params(axis='x', labelrotation=90)
ax.grid(True)
ax.legend(handles=[
    plt.Line2D([0], [0], color='red', lw=4, label=f'Mean >= {SCORE_CUTOFF}'),
    plt.Line2D([0], [0], color='blue', lw=4, label=f'Mean < {SCORE_CUTOFF}'),
])
plt.show()


In [None]:
# Bar chart of the per-cluster scores: red when >= 0.3, blue otherwise.
keys = [cluster for cluster in dict_scrublet_score]
values = [dict_scrublet_score[cluster] for cluster in keys]
bar_colors = ['red' if value >= 0.3 else 'blue' for value in values]
legend_handles = [
    plt.Line2D([0], [0], color=colour, lw=4, label=text)
    for colour, text in (('red', 'Value >= 0.3'), ('blue', 'Value < 0.3'))
]

plt.figure(figsize=(20, 6))
bars = plt.bar(keys, values, color=bar_colors)
plt.title('Value for Each Key')
plt.xlabel('Keys')
plt.ylabel('Values')
plt.xticks(rotation=90)
plt.grid(True)
plt.legend(handles=legend_handles)
plt.show()


In [None]:
# Tag every cell by whether its leiden_res2 cluster's mean score is >= 0.3,
# then show that tag next to the lvl3 annotation on the UMAP.
binarized_scores = {
    cluster: ('>=0.3' if mean_score >= 0.3 else '<0.3')
    for cluster, mean_score in dict_scrublet_score.items()
}
adata.obs['highlight'] = adata.obs['leiden_res2'].map(binarized_scores)
print(adata.obs.head())

sc.settings.set_figure_params(figsize=(18, 18), dpi=100, frameon=False, facecolor="white")
sc.pl.umap(adata, color=['highlight', "lvl3_annotation"], s=2,
           legend_loc='on data', legend_fontsize=9)


In [None]:
adata.obs["highlight"].value_counts()  # cells in clusters with mean score >= 0.3 vs < 0.3

In [None]:
# Where the per-cell 0.3-cutoff failures sit on the embedding.
sc.pl.umap(adata, color=['scrublet_score_binary03'], s=2, legend_loc='on data', legend_fontsize=9)


In [None]:
# Drop cells whose own 0.3-cutoff call is "Fail". .copy() materialises the
# subset so later in-place changes (scanpy plotting stores colour palettes in
# .uns; var columns are added below) act on a real object, not a view.
adata = adata[adata.obs["scrublet_score_binary03"] != "Fail"].copy()


In [None]:
highlight_counts = adata.obs["highlight"].value_counts()
highlight_counts

In [None]:
# Remove whole leiden_res2 clusters judged doublet-rich.
# NOTE(review): these IDs presumably come from the per-cluster score plot /
# highlight above - confirm, since leiden labels change if clustering is re-run.
DOUBLET_CLUSTERS = ["48", "54", "56", "60"]
# One mask instead of four chained views; .copy() yields a real AnnData.
adata = adata[~adata.obs["leiden_res2"].isin(DOUBLET_CLUSTERS)].copy()
 

In [None]:
# Save the cell- and cluster-filtered object.
scrublet_filtered_h5ad = '/nfs/team298/ls34/disease_atlas/mrvi/adata_inflamm_scanvi6.h5ad.filtered.scrubletfiltered'
adata.write(scrublet_filtered_h5ad)

In [None]:
# Sanity check after filtering: no "Fail" cells should remain.
umap_kwargs = dict(s=2, legend_loc='on data', legend_fontsize=9)
sc.pl.umap(adata, color=['scrublet_score_binary03'], **umap_kwargs)


In [None]:
# Remaining leiden_res2 clusters after dropping the doublet clusters.
sc.pl.umap(adata, color=['leiden_res2'], s=2, legend_loc='on data', legend_fontsize=9)


In [None]:
f"{base_dir}{adata_path}.scrubletfiltered"

In [None]:
import scanpy as sc

# Save the FILTERED object under ".scrubletfiltered" (the path shown above).
# Writing to ".scrublet" here overwrote the unfiltered, scored checkpoint.
adata.write(base_dir + adata_path + ".scrubletfiltered")

In [None]:
"""
re-run with 0.35 threshold
"""

# Map gene names

In [None]:
adata  # inspect the object before remapping gene names

In [None]:
adata.var  # NOTE(review): index presumably holds Ensembl IDs at this point (mapped below) - verify

In [None]:
import pickle

# Ensembl ID -> gene symbol lookup.
# NOTE: pickle.load can execute arbitrary code; only acceptable because this
# file is a trusted, team-internal artefact.
with open('/lustre/scratch126/cellgen/team298/ls34/gene_ensgids_dictionaries.pkl', 'rb') as file:
    gene_dict = pickle.load(file)['gene_dict']

# Keep the original IDs. Guarded so re-running this cell does not overwrite
# them with the symbols assigned below.
if "ensg_id" not in adata.var:
    adata.var["ensg_id"] = adata.var.index
# Unmapped IDs keep their Ensembl ID instead of becoming NaN var_names.
adata.var["gene_symbol"] = adata.var["ensg_id"].map(gene_dict).fillna(adata.var["ensg_id"])
# NOTE(review): several IDs may share a symbol; consider adata.var_names_make_unique().
adata.var_names = adata.var["gene_symbol"]
adata.var

In [None]:
# Symbol per Ensembl ID, falling back to the ID itself when unmapped.
ensg_ids = adata.var["ensg_id"]
adata.var["gene_symbol2"] = ensg_ids.map(gene_dict).fillna(ensg_ids)


In [None]:
# Use gene_symbol2 (symbol, or the Ensembl ID where unmapped) as var_names.
adata.var_names = adata.var["gene_symbol2"] 
adata.var

# Save

In [None]:
from datetime import datetime

# Final gzip-compressed save after scrublet filtering, QC labels and gene renaming.
final_h5ad = "/lustre/scratch126/cellgen/team298/ls34/beacon/adata_files/adata_postscrublet_postqc"
adata.write(final_h5ad, compression="gzip")

timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
print(f"Saved! Time: {timestamp}")
