### K562 overexpression example

This notebook prepares a dataset from a technology-development paper that reports a large number of Perturb-seq experiments, mostly in K562 cells ([Replogle et al 2020](https://www.nature.com/articles/s41587-020-0470-y)). We'll focus on just the CRISPRa multiplexing experiment, which overexpresses its target genes. 

Here we tidy the dataset and carry out a simple exploration in scanpy.

Note: there is a known issue in this notebook: the HVG ranking is computed before filtering for minimum expression, so the ranks of the retained genes may have gaps, e.g. 1, 2, 4, 6, 8, 9. For backwards compatibility, we have not revised this yet.

In [None]:
import warnings
warnings.filterwarnings('ignore')
import regex as re
import os
import shutil
import importlib
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
from scipy.stats import spearmanr as spearmanr
from IPython.display import display, HTML
# local
import importlib
import sys
sys.path.append("setup")
import ingestion
importlib.reload(ingestion)

import anndata
import os, sys
import itertools as it
from scipy.stats import spearmanr, pearsonr, rankdata, f_oneway
from statsmodels.stats.multitest import multipletests
from sklearn.metrics import mutual_info_score
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter


#      visualization settings
%matplotlib inline
plt.rcParams['figure.figsize'] = [6, 4.5]
plt.rcParams["savefig.dpi"] = 300

# I prefer to specify the working directory explicitly.
os.chdir("/home/ekernf01/Desktop/jhu/research/projects/perturbation_prediction/cell_type_knowledge_transfer/perturbation_data")

# Universal
geneAnnotationPath = "../accessory_data/gencode.v35.annotation.gtf.gz"       # Downloaded from https://www.gencodegenes.org/human/release_35.html
humanTFPath =  "../accessory_data/humanTFs.csv"                              # Downloaded from http://humantfs.ccbr.utoronto.ca/download.php
humanEpiPath = "../accessory_data/epiList.csv"                               # Downloaded from https://epifactors.autosome.org/description 
cellcycleGenePath = "../accessory_data/regev_lab_cell_cycle_genes.txt"

# Replogle1 Specific
dataset_name = "replogle"
perturbEffectTFOnlyPath = "setup/replogle1TFOnly.csv"                         # a path to store temp file
perturbEffectFullTranscriptomePath = "setup/replogle1FullTranscriptome.csv"   # a path to store temp file

### How many TFs are perturbed?

In [None]:
# Human TF table, restricted to the rows flagged as genuine TFs.
human_tfs = pd.read_csv(humanTFPath)
human_tfs = human_tfs.loc[human_tfs["Is TF?"]=="Yes",:]
# Epigenetic-factor table; only needed by the optional clause commented out below.
EpiList   = pd.read_csv(humanEpiPath, index_col=0).iloc[:, [0,14]]

# Perturbation table with exact duplicate rows dropped (no in-place mutation,
# so re-running the cell always starts from the file on disk).
replogle_perturbations = pd.read_csv(f"not_ready/{dataset_name}/perturbed_genes.csv").drop_duplicates()
replogle_perturbations["is_tf"] = replogle_perturbations["gene"].isin(human_tfs["HGNC symbol"]) # | replogle_perturbations["gene"].isin(EpiList["HGNC_symbol"]) 
display(replogle_perturbations.groupby("experiment").count()) #total
# Sum only the boolean flag: an unrestricted .sum() would also try to "sum"
# the string-valued gene column.
display(replogle_perturbations.groupby("experiment")[["is_tf"]].sum()) #tf only
replogle_perturbations.query("experiment=='CRISPRa multiplex' & is_tf")["gene"].unique()

### Load expression data & set up cell metadata

In [None]:
# Parsing the 10x .mtx directory is slow, so the parsed object is cached as a
# gzip-compressed h5ad next to it and reused on later runs.
h5ad_cache = f"not_ready/{dataset_name}/GSM4367986_exp8/overall.h5ad.gzip"
if os.path.exists(h5ad_cache):
    expression_quantified = sc.read_h5ad(h5ad_cache)
else:
    expression_quantified = sc.read_10x_mtx(f"not_ready/{dataset_name}/GSM4367986_exp8/")
    expression_quantified.write_h5ad(h5ad_cache, compression="gzip")

In [None]:
# QC summaries with scanpy's default settings, written into .obs / .var.
sc.pp.calculate_qc_metrics(expression_quantified, inplace = True)

# Guide-identity table, indexed by barcode (the barcode column is kept as well).
cell_metadata = (
    pd.read_csv(f"not_ready/{dataset_name}/GSM4367986_exp8/cell_identities.csv.gz")
    .set_index("cell_barcode", drop=False)
)
cell_metadata["target_symbol"] = ingestion.convert_ens_to_symbol(
    cell_metadata["target"], gtf=geneAnnotationPath, strip_version = True)
cell_metadata["target_is_tf"] = (
    cell_metadata["target_symbol"].isin(human_tfs["HGNC symbol"]).astype("int")
)

# Left join on the barcode index: cells without a guide record keep NaN metadata.
expression_quantified.obs = expression_quantified.obs.merge(
    cell_metadata, how = "left", left_index = True, right_index = True)

obs_coverage = expression_quantified.obs["coverage"]
expression_quantified.obs["has_guide_annotations"] = obs_coverage.notnull().astype("int")
expression_quantified.obs["good_coverage"] = expression_quantified.obs["good_coverage"].astype("str")
# Snapshot of the unfiltered object.
expression_quantified.raw = expression_quantified.copy()

In [None]:
# Keep only cells whose guide target was resolved to a symbol. The explicit
# .copy() makes this a standalone AnnData instead of a view, consistent with
# the other filters in this notebook; later cells assign to var_names and
# obs, which should not happen on a view.
expression_quantified = expression_quantified[~expression_quantified.obs.target_symbol.isna(), :].copy()

### How many cells do we have for each target?

In [None]:
# Number of cells per target, ascending, with a flag for TF targets.
n_cells_per_target = (
    expression_quantified.obs
    .groupby("target_symbol")[["cell_barcode"]]
    .count()
    .sort_values("cell_barcode")
    .rename(columns={"cell_barcode": "n_cells"})
)
n_cells_per_target["is_tf"] = n_cells_per_target.index.isin(human_tfs["HGNC symbol"])
n_cells_per_target = n_cells_per_target.reset_index()
display(n_cells_per_target)

# Taller figure so that every target gets its own readable bar.
plt.rcParams['figure.figsize'] = [6, 9]
target_barplot = sns.barplot(data = n_cells_per_target, x = "n_cells", y = "target_symbol", hue = "is_tf")
target_barplot.set_title("CRISPRa in K562")

### Convert Ensembl gene IDs to gene symbols

In [None]:
# Replace the gene identifiers in var_names with whatever the project helper
# returns for them, using the GENCODE annotation configured at the top.
# NOTE(review): unlike the cell-metadata call earlier, strip_version is not
# passed here — confirm the helper's default is what is wanted.
expression_quantified.var_names = ingestion.convert_ens_to_symbol(
    expression_quantified.var_names, 
    gtf=geneAnnotationPath, 
)
# Eyeball the result.
display(expression_quantified.var.head())
display(expression_quantified.var_names[0:5])

### Fill `perturbation` & `is_control`

In [None]:
# Labels in target_symbol that denote control guides.
controls = ["Non-Targeting"]
# The perturbation label is the target symbol; the control flag is stored both
# as a boolean and as a 0/1 integer.
expression_quantified.obs["perturbation"] = expression_quantified.obs['target_symbol']
expression_quantified.obs["is_control"] = expression_quantified.obs['target_symbol'].isin(controls)
expression_quantified.obs["is_control_int"] = expression_quantified.obs["is_control"].astype(int)

In [None]:
# Display the AnnData summary as this cell's output.
expression_quantified

### Single-cell RNA standard filters

We prefer to err on the side of discarding real cells rather than risking inclusion of empty droplets, doublets, or other artifacts in our analysis. Out of an abundance of caution, we remove droplets with especially high or low total counts, and we remove droplets with high fractions of RNA from mitochondrial RNAs, ribosomal protein subunit RNAs, or highly expressed RNAs. With apologies, please read the code below for the exact thresholds and numbers. 

In [None]:
# Plot the 30 highest-expressed genes before any filtering.
sc.pl.highest_expr_genes(expression_quantified, n_top=30, palette="Blues", width=.3)

In [None]:
# Flag mitochondrial and ribosomal-protein genes by symbol prefix.
for flag_name, prefixes in (("mt", "MT-"), ("ribo", ("RPS", "RPL"))):
    expression_quantified.var[flag_name] = expression_quantified.var_names.str.startswith(prefixes)
# Cell output: (number of mt genes, number of ribo genes).
(expression_quantified.var['mt'].sum(), expression_quantified.var['ribo'].sum())

In [None]:
# Add QC metrics for the 'ribo' and 'mt' gene flags. percent_top=None is passed,
# so the pct_counts_in_top_50_genes column used below presumably comes from the
# earlier default-settings call rather than this one — confirm.
sc.pp.calculate_qc_metrics(expression_quantified, qc_vars=['ribo', 'mt'], percent_top=None, log1p=False, inplace=True)

In [None]:
# Per-cell QC distributions before any cell filtering.
qc_violin_keys = [
    'n_genes_by_counts',
    'total_counts',
    'pct_counts_mt',
    'pct_counts_ribo',
    'pct_counts_in_top_50_genes',
]
axs = sc.pl.violin(expression_quantified, qc_violin_keys, jitter=0.5, multi_panel=True)

In [None]:
# Total counts vs. number of detected genes per cell, before filtering.
fig, ax = plt.subplots(1,1,figsize=(2,2))
sc.pl.scatter(expression_quantified, x='total_counts', y='n_genes_by_counts', ax=ax)

In [None]:
print("Number of cells: ", expression_quantified.n_obs)

# Upper cutoff for per-cell total counts: the 99th percentile.
thresh = np.percentile(expression_quantified.obs['total_counts'], 99)
print("99th percentile: ", thresh)

In [None]:
# Drop cells at or above the upper total-count cutoff computed above.
below_cutoff = expression_quantified.obs['total_counts'] < thresh
expression_quantified = expression_quantified[below_cutoff, :].copy()
print("Number of cells: ", expression_quantified.n_obs)

In [None]:
# Drop cells with fewer than 2000 total counts.
enough_counts = expression_quantified.obs["total_counts"] >= 2000
expression_quantified = expression_quantified[enough_counts, :].copy()
print("Number of cells: ", expression_quantified.n_obs)

In [None]:
# Drop cells where the 50 highest-expressed genes account for more than 40% of counts.
not_dominated = expression_quantified.obs["pct_counts_in_top_50_genes"] <= 40
expression_quantified = expression_quantified[not_dominated, :].copy()
print("Number of cells: ", expression_quantified.n_obs)

In [None]:
# Keep cells with less than 20% of counts coming from mitochondrial genes.
low_mito = expression_quantified.obs['pct_counts_mt'] < 20
expression_quantified = expression_quantified[low_mito, :].copy()
print("Number of cells: ", expression_quantified.n_obs)

In [None]:
# filter for % ribo: keep cells with less than 30% of counts from ribosomal genes
expression_quantified = expression_quantified[expression_quantified.obs['pct_counts_ribo'] < 30, :].copy()
print("Number of cells: ", expression_quantified.n_obs)

In [None]:
""" To verify the outcome of filtering cells """
# Recompute the QC metrics on the cells that survived the filters above,
# so the plots that follow describe the filtered data.
sc.pp.calculate_qc_metrics(expression_quantified, qc_vars=['ribo', 'mt'], percent_top=None, log1p=False, inplace=True)

In [None]:
# Per-cell QC distributions after cell filtering.
qc_violin_keys = [
    'n_genes_by_counts',
    'total_counts',
    'pct_counts_mt',
    'pct_counts_ribo',
    'pct_counts_in_top_50_genes',
]
axs = sc.pl.violin(expression_quantified, qc_violin_keys, jitter=0.4, multi_panel=True)

In [None]:
# Total counts vs. number of detected genes per cell, after cell filtering.
fig, ax = plt.subplots(1,1,figsize=(2,2))
sc.pl.scatter(expression_quantified, x='total_counts', y='n_genes_by_counts', ax=ax)

In [None]:
print("Number of genes: ", expression_quantified.n_vars)
# Minimum number of cells a gene must be detected in to be kept.
gThresh = 10

# In-place gene filter; the gene count is printed before and after for comparison.
sc.pp.filter_genes(expression_quantified, min_cells=gThresh)
print("Number of genes: ", expression_quantified.n_vars)

In [None]:
""" Specifically rescuing the perturbed genes """
# Boolean mask over genes: True where the gene symbol is one of the perturbation
# labels. Index.isin replaces indexing with a list of np.where() results, which
# breaks when a symbol occurs more than once in var_names (ragged index) and is
# interpreted as a multi-dimensional index by older numpy releases.
perturbedKeep = np.asarray(
    expression_quantified.var_names.isin(set(expression_quantified.obs.perturbation))
)
# Cell output: how many perturbed genes are present among the measured genes.
np.sum(perturbedKeep)

In [None]:
# Masks over the current var_names, selected by gene-symbol prefix.
mito_genes = expression_quantified.var_names.str.startswith('MT-')
ribo_genes = expression_quantified.var_names.str.startswith(("RPL","RPS"))
# Prefix match: flags every symbol that begins with "MALAT1".
malat_gene = expression_quantified.var_names.str.startswith("MALAT1")

In [None]:
# Genes to discard: mitochondrial, ribosomal or MALAT1 ...
remove = mito_genes | ribo_genes | malat_gene
# ... except that a perturbed gene is always retained.
keep = ~remove | perturbedKeep
expression_quantified = expression_quantified[:, keep].copy()
print("Number of genes: ", expression_quantified.n_vars)

In [None]:
# Positions of GAPDH among the remaining genes (case-insensitive symbol match).
np.flatnonzero(expression_quantified.var.index.str.upper() == 'GAPDH').tolist()

In [None]:
# Plot the 20 highest-expressed genes after gene filtering.
sc.pl.highest_expr_genes(expression_quantified, n_top=20, palette="Blues", width=.3)

In [None]:
""" To verify the outcome of filtering genes """
# Recompute the QC metrics now that genes have been removed,
# so the plots that follow describe the gene-filtered data.
sc.pp.calculate_qc_metrics(expression_quantified, qc_vars=['ribo', 'mt'], percent_top=None, log1p=False, inplace=True)

In [None]:
# Per-cell QC distributions after gene filtering.
qc_violin_keys = [
    'n_genes_by_counts',
    'total_counts',
    'pct_counts_mt',
    'pct_counts_ribo',
    'pct_counts_in_top_50_genes',
]
axs = sc.pl.violin(expression_quantified, qc_violin_keys, jitter=0.4, multi_panel=True)

In [None]:
# Total counts vs. number of detected genes per cell, after gene filtering.
fig, ax = plt.subplots(1,1,figsize=(2,2))
sc.pl.scatter(expression_quantified, x='total_counts', y='n_genes_by_counts', ax=ax)

### Basic EDA 

We supply some basic exploratory plots.

In [None]:
# When we do pseudo-bulk aggregation, we will want "raw" counts (not normalized), 
# but after applying the above filters. So we re-save the .raw attribute now.
expression_quantified.raw = expression_quantified.copy()
sc.pp.log1p(expression_quantified)
# Rank every gene (n_top_genes equals the number of genes) rather than picking a subset.
# NOTE(review): scanpy documents flavor="seurat_v3" as expecting count data, but X
# was log1p-transformed on the previous line. Left untouched because changing it
# would alter the stored HVG ranks — confirm whether this is intended.
sc.pp.highly_variable_genes(expression_quantified, flavor = "seurat_v3", n_top_genes=expression_quantified.var.shape[0])
sc.pl.highly_variable_genes(expression_quantified)
# No filter is installed inside this context manager; it only restores the
# warning state when the block exits.
with warnings.catch_warnings():
    sc.tl.pca(expression_quantified, n_comps=100)
sc.pp.neighbors(expression_quantified)
sc.tl.umap(expression_quantified)
# Not read anywhere in the visible notebook.
clusterResolutions = []
sc.tl.leiden(expression_quantified)
# The first 43 entries of the file are passed as S-phase genes and the remainder
# as G2/M genes; this relies on the ordering of the gene list file.
cc_genes = pd.read_csv(cellcycleGenePath, header = None)[0]
sc.tl.score_genes_cell_cycle(expression_quantified, s_genes=cc_genes[:43], g2m_genes=cc_genes[43:])
plt.rcParams['figure.figsize'] = [6, 4.5]
sc.pl.umap(expression_quantified, color = [
    # "PTPRC",
    "leiden", 
    "is_control_int",
    "perturbation",
    'total_counts', 
    'log1p_total_counts',
    'pct_counts_in_top_50_genes', 
    'has_guide_annotations',
])
# Will ask CellOracle to use only one cluster.
# This requires setting certain other undocumented aspects of object state. :(
expression_quantified.obs["fake_cluster"]="all_one_cluster"
expression_quantified.obs.fake_cluster = expression_quantified.obs.fake_cluster.astype("category")
expression_quantified.uns["fake_cluster_colors"] = ['#1f77b4']

### Aggregate For Pseudo-Bulk

In [None]:
# Build pseudobulk profiles with the project helper, grouping cells by
# (target_symbol, type). Presumably this aggregates the counts stored in .raw —
# confirm in ingestion.aggregate_by_perturbation.
pseudobulk = ingestion.aggregate_by_perturbation(expression_quantified, group_by = ['target_symbol', 'type'])

In [None]:
# Boolean mask over genes: True where the gene symbol is one of the perturbation
# labels. Index.isin replaces indexing with a list of np.where() results, which
# breaks when a symbol occurs more than once in var_names (ragged index) and is
# interpreted as a multi-dimensional index by older numpy releases.
perturbedKeep = np.asarray(
    expression_quantified.var_names.isin(set(expression_quantified.obs.perturbation))
)
print(f"{np.sum(perturbedKeep)} columns to keep (perturbed genes)")
# Keep a gene if its maximum value across the non-control pseudobulk samples
# exceeds 100, or if it is a perturbed gene.
retainColumn = np.max(pseudobulk.X[~pseudobulk.obs.is_control], axis=0) > 100
retainColumn = retainColumn | perturbedKeep
pseudobulk = pseudobulk[:, retainColumn].copy()
print(f"{pseudobulk.shape} is post-filtering shape.")

### Normalization on pseudobulk

In [None]:
# Freeze the un-normalized pseudobulk values in .raw. A copy is stored (as is
# done for expression_quantified.raw elsewhere in this notebook) so that .raw
# cannot share its matrix with the object that is normalized next.
pseudobulk.raw = pseudobulk.copy()

In [None]:
# Replace X with the output of the project's DESeq2-style normalization. The
# matrix is transposed on the way in and back on the way out — presumably the
# helper expects genes in rows; confirm in ingestion.deseq2Normalization.
pseudobulk.X = ingestion.deseq2Normalization(pseudobulk.X.T).T

### Visualize Normalization Effort

In [None]:
# Independent copy of the normalized pseudobulk, used by the histograms below.
z2 = pseudobulk.copy()

In [None]:
""" Sanity check: expression for house keeping genes are relatively stable """
# One histogram panel per housekeeping gene, matched case-insensitively by symbol.
fig, axes = plt.subplots(1, 2, figsize=(8,2))
for panel, gene in zip(axes, ("ACTB", "GAPDH")):
    gene_cols = [idx for idx, n in enumerate(pseudobulk.var.index) if n.upper() == gene]
    panel.hist(z2.X[:, gene_cols], bins=100, label=gene)
    panel.legend()
plt.suptitle("Expression across pseudobulk samples")
plt.show()

In [None]:
""" The sum of gene expression before and after normalization """
# Per-sample totals on a log-scaled count axis: stored raw values on the left,
# normalized values on the right.
fig, axes = plt.subplots(1, 2, figsize=(12,3))
matrices_and_labels = (
    (pseudobulk.raw.X, "before DESeq2 norm"),
    (pseudobulk.X, "after DESeq2 norm"),
)
for panel, (matrix, caption) in zip(axes, matrices_and_labels):
    panel.hist(matrix.sum(axis=1), bins=100, log=True, label=caption)
    panel.legend()
plt.show()

### Check Consistency between perturbation and measured

In [None]:
# Ask the project helper whether each pseudobulk sample behaves as expected for
# an overexpression experiment; it returns a per-sample status and a logFC.
# If verbose is set to True, display disconcordant trials and their controls
status, logFC = ingestion.checkConsistency(pseudobulk, 
                                           perturbationType="overexpression", 
                                           group=None,
                                           verbose=False) 
# Store both results in .obs; the status column is used by the filter further down.
pseudobulk.obs["consistentW/Perturbation"] = status
pseudobulk.obs["logFC"] = logFC
# Cell output: tally of the status values.
Counter(status)

### Final decision on QC filtering

Remove guide combinations that appear not to overexpress the targeted gene.

In [None]:
# Drop the samples whose status is exactly 'No'; every other status value is kept.
pseudobulk_filtered = pseudobulk[pseudobulk.obs['consistentW/Perturbation'] != 'No'].copy()

### Check Consistency between replications

In [None]:
# The helper returns a sequence whose first two entries are stored as the
# Spearman and Pearson correlation columns.
correlations = ingestion.computeCorrelation(pseudobulk_filtered, verbose=True)
pseudobulk_filtered.obs["spearmanCorr"] = correlations[0]
# NOTE(review): this column name starts with a space (" pearsonCorr"). It looks
# like an alignment artifact, but consumers of the saved file may already rely
# on it — confirm before renaming.
pseudobulk_filtered.obs[" pearsonCorr"] = correlations[1]

# From here on, `pseudobulk` refers to the filtered samples.
pseudobulk = pseudobulk_filtered.copy()

### Compute the Magnitude of Perturbation Effect

In [None]:
"""
Downloaded from http://humantfs.ccbr.utoronto.ca/download.php """
# Two columns of the TF table -> dict, keeping only rows whose second value is 'Yes'.
TFList = pd.read_csv(humanTFPath, index_col=0).iloc[:, [1,3]]
TFDict = {name: flag for name, flag in TFList.to_numpy().tolist() if flag == 'Yes'}

"""
Downloaded from https://epifactors.autosome.org/description """
# Two columns of the epigenetic-factor table -> dict (first column maps to second).
EpiList = pd.read_csv(humanEpiPath, index_col=0).iloc[:, [0,14]]
EpiDict = {name: value for name, value in EpiList.to_numpy().tolist()}

### How big are the effects?

We compute several different measures to get a sense for the overall strength of each effect. They are all well correlated. 

In [None]:
""" If want to look at bigness on TF only """
# Column positions of the genes that appear in the TF or epigenetic-factor dicts.
TFVar = [i for i,p in enumerate(pseudobulk.var.index) if p in TFDict or p in EpiDict]
pseudobulkTFOnly = pseudobulk[:, TFVar].copy()
# Effect-size metrics restricted to those genes; judging by the copy loop below,
# the helper stores results in .obs under the given prefix (fname is a temp file path).
ingestion.quantifyEffect(adata=pseudobulkTFOnly, 
                         fname=perturbEffectTFOnlyPath, 
                         group=None, 
                         diffExprFC=False, 
                         prefix="TFOnly")

# Same metrics over all retained genes, stored without a prefix.
ingestion.quantifyEffect(adata=pseudobulk, 
                         fname=perturbEffectFullTranscriptomePath, 
                         group=None,
                         diffExprFC=False, 
                         prefix="")

# Copy the TF-only metrics onto the main object so both sets live in one .obs table.
listOfMetrics = ["DEG", "MI", "logFCMean", "logFCNorm2", "logFCMedian"]
for m in listOfMetrics:
    pseudobulk.obs[f"TFOnly{m}"] = pseudobulkTFOnly.obs[f"TFOnly{m}"]

In [None]:
# Effect-size columns to compare: full-transcriptome versions first, TF-only versions second.
metricOfInterest = ["MI", "logFCMean", "logFCNorm2", "logFCMedian", 
                    "TFOnlyMI", "TFOnlylogFCMean", "TFOnlylogFCNorm2", "TFOnlylogFCMedian"]
ingestion.checkPerturbationEffectMetricCorrelation(pseudobulk, metrics=metricOfInterest)

In [None]:
# Plot the selected metrics with the project helper; the TF and epigenetic-factor
# dicts are passed along for its use.
ingestion.visualizePerturbationEffect(pseudobulk, metrics=metricOfInterest, TFDict=TFDict, EpiDict=EpiDict)

In [None]:
# Distinct (perturbation, logFCNorm2) pairs, ordered by effect size (ascending).
sorted(
    {tuple(row) for row in pseudobulk.obs[['perturbation', 'logFCNorm2']].to_numpy()},
    key=lambda pair: pair[1],
)

In [None]:
# Scratch copy so that the extra plotting column is not added to `pseudobulk`.
temp = pseudobulk.copy()

""" If wish to see more clearer, by masking the ones with 
much higher logFC norm2 values """
# temp = pseudobulk[(pseudobulk.obs.perturbation != 'CCDC51') 
#                   & (pseudobulk.obs.perturbation != 'HSPD1') 
#                   & (pseudobulk.obs.perturbation != 'SPI1') 
#                   & (pseudobulk.obs.perturbation != 'CEBPB')
#                  ].copy()

""" If you wish to see the magnitude of perturbation effect more clearer,
    i.e. a smoother gradient of the color shift, feel free to uncomment
    the line below, which takes the log of the norm2 """
# NOTE: despite the wording above, the next line is active, and the plot below
# depends on the column it creates (it is used as the hue).
temp.obs['logFCNorm2 (log-scale)'] = np.log2(temp.obs['logFCNorm2'])

# x: replicate correlation, y: logFC of the targeted gene,
# marker: consistency status, colour: log2 of the logFC norm.
ingestion.visualizePerturbationMetadata(temp, 
                                        x="spearmanCorr", 
                                        y="logFC", 
                                        style="consistentW/Perturbation", 
                                        hue="logFCNorm2 (log-scale)", 
                                        markers=['o', '^'], 
                                        xlim=[-0.1, 0.8])

### Basic EDA

What does the final, fully filtered set of pseudo-bulk profiles look like?

In [None]:
# Same exploratory pipeline as for single cells, applied to the pseudobulk
# profiles. log1p changes pseudobulk.X for everything that follows, including
# the file that is saved at the end.
sc.pp.log1p(pseudobulk)
# No filter is installed inside this context manager; it only restores the
# warning state when the block exits.
with warnings.catch_warnings():
    # NOTE(review): n_comps=100 presumably requires more than 100 pseudobulk samples — confirm.
    sc.tl.pca(pseudobulk, n_comps=100)
sc.pp.neighbors(pseudobulk)
sc.tl.umap(pseudobulk)
# Not read anywhere in the visible notebook.
clusterResolutions = []
sc.tl.leiden(pseudobulk)
# First 43 entries are passed as S-phase genes, the rest as G2/M genes.
cc_genes = pd.read_csv(cellcycleGenePath, header = None)[0]
sc.tl.score_genes_cell_cycle(pseudobulk, s_genes=cc_genes[:43], g2m_genes=cc_genes[43:])

In [None]:
# UMAP of the pseudobulk profiles, coloured by cluster, control flag and perturbation.
plt.rcParams['figure.figsize'] = [6, 4.5]
sc.pl.umap(pseudobulk, color = [
    # "PTPRC",
    "leiden", 
    "is_control_int",
    "perturbation",
])

In [None]:
# Display the AnnData summary of the final pseudobulk object.
pseudobulk

In [None]:
# Perturbation labels other than the controls, split by whether the gene is
# present among the measured genes.
perturbed_genes = set(pseudobulk.obs['perturbation'].unique()) - set(controls)
perturbed_and_measured_genes = perturbed_genes & set(pseudobulk.var.index)
perturbed_but_not_measured_genes = perturbed_genes - set(pseudobulk.var.index)
# Highly variable genes plus every perturbed gene that was measured.
genes_keep = set(pseudobulk.var.index[pseudobulk.var['highly_variable']]) | perturbed_and_measured_genes
print("These genes were perturbed:")
print(perturbed_genes)
print("These genes were perturbed but not measured:")
print(perturbed_but_not_measured_genes)


In [None]:
# final form, ready to save
pseudobulk.uns["perturbed_and_measured_genes"]     = list(perturbed_and_measured_genes)
pseudobulk.uns["perturbed_but_not_measured_genes"] = list(perturbed_but_not_measured_genes)
pseudobulk = ingestion.describe_perturbation_effect(pseudobulk, "overexpression")

In [None]:
# Create the output folder if needed and write the final pseudobulk object.
os.makedirs(f"perturbations/{dataset_name}", exist_ok = True)
pseudobulk.write_h5ad(f"perturbations/{dataset_name}/test.h5ad")