# Stage 1: Data preprocessing

In this tutorial, we first walk through how to prepare datasets for use in
CASCADE, using the [Norman et al. (2019)](https://doi.org/10.1126/science.aax4438)
dataset as an example. This dataset contains single- and double-gene CRISPRa
perturbations.

In [1]:
import networkx as nx
import numpy as np
import pandas as pd
import scanpy as sc
from sklearn.preprocessing import OneHotEncoder, StandardScaler

from cascade.data import (
    configure_dataset,
    encode_regime,
    filter_unobserved_targets,
    get_all_targets,
    get_configuration,
    neighbor_impute,
)
from cascade.graph import assemble_scaffolds

## Read data

First, we need to convert the dataset into `AnnData` objects. If you are
unfamiliar with them, see the [documentation](https://anndata.readthedocs.io/)
for details, including how to construct `AnnData` objects from scratch and how
to read data in other formats (csv, mtx, loom, etc.) into `AnnData` objects.

Here we just load an existing `h5ad` file, the native file format for
`AnnData`. The `h5ad` file used in this tutorial can be downloaded from here:

- http://ftp.cbi.pku.edu.cn/pub/cascade-download/Norman-2019.h5ad

In [2]:
adata = sc.read_h5ad("Norman-2019.h5ad")
adata

AnnData object with n_obs × n_vars = 86744 × 22881
    obs: 'guide_id', 'gemgroup', 'ncounts', 'knockup'
    var: 'perturbed'

## Data format requirements

CASCADE requires the following data format:

- Raw counts in `adata.X`;
- Total counts in `adata.obs`, which are used when fitting data with
  the negative binomial distribution;
- HGNC gene symbols as `adata.var_names`;
- Perturbation label in `adata.obs` that specifies which genes are perturbed
  in each cell:
  - For control cells with no perturbation, the value **MUST** be an empty
    string `""`;
  - For cells with multiple perturbations, the perturbed genes should be
    comma-separated, e.g., `"CEBPB,KLF1"`;
  - Names of perturbed genes must match those in `adata.var_names`.
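
The label conventions above can be checked with a few lines of plain Python. The following is a minimal sketch and not part of CASCADE: `parse_targets` and `check_labels` are hypothetical helpers, and the toy lists stand in for `adata.obs["knockup"]` and `adata.var_names`.

```python
def parse_targets(label):
    """Split a perturbation label into a set of gene names.

    An empty string marks a control cell and yields an empty set.
    """
    return set(label.split(",")) if label else set()


def check_labels(labels, var_names):
    """Return the sorted perturbed genes that are missing from var_names."""
    measured = set(var_names)
    missing = set()
    for label in labels:
        missing |= parse_targets(label) - measured
    return sorted(missing)


# Toy stand-ins for adata.obs["knockup"] and adata.var_names (assumed values).
labels = ["", "CEBPB,KLF1", "KLF1", "FOSB"]
var_names = ["CEBPB", "KLF1", "SET"]

# Genes that are perturbed in some cell but not measured.
print(check_labels(labels, var_names))
```

An empty result means every perturbed gene name matches a measured gene. In the toy input, `FOSB` is reported because it does not appear in `var_names`.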

In this case, we can verify that the expression matrix contains raw counts:

In [3]:
adata.X, adata.X.data

(<Compressed Sparse Row sparse matrix of dtype 'float32'
 	with 268281595 stored elements and shape (86744, 22881)>,
 array([ 1.,  1.,  1., ..., 12.,  3., 16.], dtype=float32))

The total counts are stored as `"ncounts"` in `adata.obs`:

In [4]:
adata.obs["ncounts"]

TTGAACGAGACTCGGA      15097.0
CGTTGGGGTGTTTGTG       8551.0
GAACCTAAGTGTTAGA      10999.0
CCTTCCCTCCGTCATC      38454.0
TCCCGATGTCTCTTAT      21433.0
                       ...   
TTTCCTCGTACGCACC      11991.0
TTTCCTCTCTTGCCGT      16561.0
TTTGCGCAGTCATGCT       5192.0
TTTGCGCCAGGACCCT      15704.0
TTTGCGCTCTCGCATC-1     6825.0
Name: ncounts, Length: 86744, dtype: float32

And perturbation labels are stored as `"knockup"` in `adata.obs`:

In [5]:
adata.obs["knockup"]

TTGAACGAGACTCGGA            ARID1A
CGTTGGGGTGTTTGTG            BCORL1
GAACCTAAGTGTTAGA              FOSB
CCTTCCCTCCGTCATC          KLF1,SET
TCCCGATGTCTCTTAT         BAK1,KLF1
                          ...     
TTTCCTCGTACGCACC                  
TTTCCTCTCTTGCCGT               HK2
TTTGCGCAGTCATGCT            RHOXF2
TTTGCGCCAGGACCCT      BAK1,BCL2L11
TTTGCGCTCTCGCATC-1      CEBPB,OSR2
Name: knockup, Length: 86744, dtype: category
Categories (237, object): ['', 'AHR', 'AHR,FEV', 'AHR,KLF1', ..., 'ZBTB10', 'ZBTB25', 'ZC3HAV1', 'ZNF318']

Before any further processing, we back up the raw UMI counts in a layer called
`"counts"`, which will be used later during model training.

In [6]:
adata.layers["counts"] = adata.X.copy()

## Cell and gene selection

Since CASCADE can only model perturbations in measured genes, we first filter out
any perturbation that was missing from the readout. A utility function called
[filter_unobserved_targets](api/cascade.data.filter_unobserved_targets.rst) is
provided for this purpose.

In this case no cell was filtered out; the returned view still contains all
86744 cells:

In [7]:
filter_unobserved_targets(adata, "knockup")

View of AnnData object with n_obs × n_vars = 86744 × 22881
    obs: 'guide_id', 'gemgroup', 'ncounts', 'knockup'
    var: 'perturbed'
    layers: 'counts'

Next, we identify highly variable genes using the `"seurat_v3"` method,
to allow the model to focus on informative genes:

In [8]:
sc.pp.highly_variable_genes(adata, n_top_genes=1000, flavor="seurat_v3")

Again, as CASCADE can only model perturbations in measured genes, we expand this
highly variable gene set to include all perturbed genes (via a utility
function [get_all_targets](api/cascade.data.get_all_targets.rst)) to avoid
discarding useful perturbation information:

In [9]:
all_targets = get_all_targets(adata, key="knockup")
all_targets

AHR,ARID1A,ARRDC3,ATL1,BAK1,BCL2L11,BCORL1,BPGM,CBARP,CBFA2T3,CBL,CDKN1A,CDKN1B,CDKN1C,CEBPA,CEBPB,CEBPE,CELF2,CITED1,CKS1B,CLDN6,CNN1,CNNM4,COL1A1,COL2A1,CSRNP1,DLX2,DUSP9,EGR1,ETS2,FEV,FOSB,FOXA1,FOXA3,FOXF1,FOXL2,FOXL2NB,FOXO4,GLB1L2,HES7,HK2,HNF4A,HOXA13,HOXB9,HOXC13,IER5L,IGDCC3,IKZF3,IRF1,ISL2,JUN,KIF18B,KIF2C,KLF1,KMT2A,LHX1,LYL1,MAML2,MAP2K3,MAP2K6,MAP3K21,MAP4K3,MAP4K5,MAP7D1,MAPK1,MEIS1,MIDEAS,MIDN,NCL,NIT1,OSR2,PLK4,POU3F2,PRDM1,PRTG,PTPN1,PTPN12,PTPN13,PTPN9,RHOXF2,RREB1,RUNX1T1,S1PR2,SAMD1,SET,SGK1,SLC38A2,SLC4A1,SLC6A9,SNAI1,SPI1,STIL,TBX2,TBX3,TGFBR2,TMSB4X,TP73,TSC22D1,UBASH3A,UBASH3B,ZBTB1,ZBTB10,ZBTB25,ZC3HAV1,ZNF318

In [10]:
adata.var["selected"] = adata.var["highly_variable"] | adata.var_names.isin(all_targets)
adata.var["selected"].sum()

1064

## Encode intervention regime

CASCADE represents genetic perturbations as a cell-by-gene binary matrix,
which can be encoded from the `adata.obs["knockup"]` column using the
[encode_regime](api/cascade.data.encode_regime.rst) function. The function
stores the encoded regime matrix in a layer with user-specified name,
here using the name `"interv"`.

In [11]:
encode_regime(adata, "interv", key="knockup")
adata.layers["interv"]

<Compressed Sparse Row sparse matrix of dtype 'bool'
	with 108992 stored elements and shape (86744, 22881)>
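
To make the encoding concrete, here is a minimal sketch of how comma-separated labels map to the non-zero entries of a cell-by-gene boolean matrix. It is a plain-Python illustration, not the implementation of `encode_regime`; `encode_labels` is a hypothetical helper and the inputs are toy values.

```python
def encode_labels(labels, var_names):
    """Return sorted (cell, gene) index pairs where the regime matrix is True."""
    gene_index = {gene: j for j, gene in enumerate(var_names)}
    entries = []
    for i, label in enumerate(labels):
        if not label:  # control cell: no stored entries in this row
            continue
        for gene in label.split(","):
            entries.append((i, gene_index[gene]))
    return sorted(entries)


# Toy stand-ins for adata.obs["knockup"] and adata.var_names (assumed values).
labels = ["", "KLF1,SET", "CEBPB", "KLF1"]
var_names = ["CEBPB", "KLF1", "SET"]
entries = encode_labels(labels, var_names)

# Expand to a dense matrix only for display; real data should stay sparse.
dense = [
    [(i, j) in entries for j in range(len(var_names))]
    for i in range(len(labels))
]
for row in dense:
    print([int(v) for v in row])
```

In this sketch a single perturbation contributes one stored entry, a double perturbation two, and a control cell none. This is consistent with the real matrix above holding somewhat more stored elements (108992) than there are cells (86744).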

## Encode technical covariates

To minimize the effect of technical confounding on the causal discovery process,
it is recommended to add all possible confounding factors into a covariate matrix
in `adata.obsm`.

Here we add the one-hot encoded batch label (`"gemgroup"`) and the
log-transformed, standardized total counts as covariates:

In [12]:
batch = OneHotEncoder().fit_transform(adata.obs[["gemgroup"]]).toarray()
batch

array([[0., 1., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 1., 0.],
       [0., 0., 0., ..., 1., 0., 0.],
       ...,
       [0., 1., 0., ..., 0., 0., 0.],
       [0., 0., 1., ..., 0., 0., 0.],
       [0., 0., 0., ..., 1., 0., 0.]])

In [13]:
log_ncounts = StandardScaler().fit_transform(np.log10(adata.obs[["ncounts"]]))
log_ncounts

array([[ 0.46084866],
       [-0.8991263 ],
       [-0.29681763],
       ...,
       [-2.092782  ],
       [ 0.5551557 ],
       [-1.4385142 ]], dtype=float32)

In [14]:
adata.obsm["covariate"] = np.concatenate([batch, log_ncounts], axis=1)
adata.obsm["covariate"].shape

(86744, 9)
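
For readers who want to see what the two transformations compute, the sketch below reproduces them on toy values using only the standard library. The helper names `one_hot` and `standardized_log10` and the input values are assumptions made for illustration; the tutorial itself relies on scikit-learn as shown above.

```python
import math
import statistics


def one_hot(values):
    """One column per distinct category, in sorted order."""
    categories = sorted(set(values))
    return [[float(v == c) for c in categories] for v in values]


def standardized_log10(counts):
    """log10-transform, then center to zero mean and scale to unit variance."""
    logs = [math.log10(c) for c in counts]
    mean = statistics.fmean(logs)
    std = statistics.pstdev(logs)  # population std, as StandardScaler uses
    return [[(x - mean) / std] for x in logs]


gemgroup = [1, 2, 1, 3]                      # toy batch labels
ncounts = [100.0, 1000.0, 10000.0, 1000.0]   # toy total counts

# Concatenate the batch columns and the size column row by row.
covariate = [
    b + s for b, s in zip(one_hot(gemgroup), standardized_log10(ncounts))
]
print(len(covariate), len(covariate[0]))  # rows, columns
```

The column count is the number of distinct batches plus one. The 9 columns obtained above therefore suggest eight distinct `gemgroup` values plus the single total-count column.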

## Data normalization

Next, we follow the standard scRNA-seq preprocessing approach in `scanpy`
to normalize the expression matrix in `adata.X`. You may visit its
[documentation](https://scanpy.readthedocs.io/) if unfamiliar.

In [15]:
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
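
As a rough guide to what these two calls do to each cell, the sketch below applies the same arithmetic to one toy cell: scale the counts so they sum to `target_sum`, then take the natural logarithm of one plus the value. The helper `normalize_log1p` is hypothetical and assumes a non-zero total; it illustrates the computation and is not scanpy's code.

```python
import math


def normalize_log1p(counts, target_sum=1e4):
    """Scale one cell's counts to sum to target_sum, then apply log(1 + x)."""
    total = sum(counts)  # assumed to be greater than zero
    return [math.log1p(c * target_sum / total) for c in counts]


cell = [1.0, 3.0, 0.0, 16.0]  # toy raw counts for one cell
normalized = normalize_log1p(cell)
print([round(v, 3) for v in normalized])

# Undoing the log recovers values that sum to target_sum.
print(round(sum(math.expm1(v) for v in normalized)))
```

Zero counts stay at zero after the transformation, so sparsity is preserved. The raw counts remain available in `adata.layers["counts"]` because they were copied there before normalization.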

Now we can subset the dataset to retain the selected genes only:

In [16]:
adata = adata[:, adata.var["selected"]].copy()
adata

AnnData object with n_obs × n_vars = 86744 × 1064
    obs: 'guide_id', 'gemgroup', 'ncounts', 'knockup'
    var: 'perturbed', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm', 'selected'
    uns: 'hvg', 'log1p'
    obsm: 'covariate'
    layers: 'counts', 'interv'

## Neighbor-based imputation

Given that scRNA-seq data can be sparse, we recommend conducting a lightweight
neighbor-based data imputation before model training. This is done by aggregating
cells that are close in PCA space and share the same perturbation. We provide a
utility function called [neighbor_impute](api/cascade.data.neighbor_impute.rst)
for this purpose:

In [17]:
sc.pp.pca(adata)

In [18]:
adata = neighbor_impute(
    adata,
    k=20,
    use_rep="X_pca",
    use_batch="knockup",
    X_agg="mean",
    obs_agg={"ncounts": "sum"},
    obsm_agg={"covariate": "mean"},
    layers_agg={"counts": "sum"},
)

Note that we used the `"sum"` aggregation for `adata.obs["ncounts"]` and
`adata.layers["counts"]`, which maintains their count-based nature.
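
The idea behind this step can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the implementation of `neighbor_impute`: the hypothetical helper `impute` uses Euclidean distance, counts each cell as its own nearest neighbor, and runs on toy inputs.

```python
import math


def impute(rep, labels, X, counts, k):
    """Aggregate each cell with its k nearest same-label cells (itself included).

    X is averaged and counts are summed, mirroring X_agg="mean" and
    layers_agg={"counts": "sum"} in the call above.
    """
    X_new, counts_new = [], []
    for i, label in enumerate(labels):
        # Candidate neighbors are restricted to cells with the same label.
        group = [j for j, other in enumerate(labels) if other == label]
        group.sort(key=lambda j: math.dist(rep[i], rep[j]))
        nearest = group[:k]
        n_genes = len(X[i])
        X_new.append(
            [sum(X[j][g] for j in nearest) / len(nearest) for g in range(n_genes)]
        )
        counts_new.append(
            [sum(counts[j][g] for j in nearest) for g in range(n_genes)]
        )
    return X_new, counts_new


rep = [[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [0.0, 0.1]]  # toy PCA coordinates
labels = ["KLF1", "KLF1", "KLF1", ""]                   # toy perturbation labels
X = [[1.0, 0.0], [3.0, 0.0], [9.0, 9.0], [7.0, 7.0]]    # toy normalized expression
counts = [[1, 0], [2, 0], [5, 5], [4, 4]]               # toy raw counts

X_new, counts_new = impute(rep, labels, X, counts, k=2)
print(X_new[0], counts_new[0])
```

In the toy input, the control cell (last row) lies close to the `KLF1` cells in the representation but is never mixed with them, because neighbors are searched only within the same perturbation label. Summed counts stay integer-valued, while averaged expression does not.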

## Configure dataset

Now we can use the function
[configure_dataset](api/cascade.data.configure_dataset.rst) to tell CASCADE
where the expression matrix, intervention regime, covariates and total counts
are stored:

In [19]:
configure_dataset(
    adata,
    use_regime="interv",
    use_covariate="covariate",
    use_size="ncounts",
    use_layer="counts",
)
get_configuration(adata)

{'regime': 'interv',
 'covariate': 'covariate',
 'size': 'ncounts',
 'weight': None,
 'layer': 'counts'}

## Construct scaffold graph

Next, we need to construct a scaffold graph to guide the causal discovery process.

The following 4 pre-built human gene scaffolds are available for download:

- KEGG pathways:
  http://ftp.cbi.pku.edu.cn/pub/cascade-download/inferred_kegg_gene_only.gml.gz
- TF-target predictions:
  http://ftp.cbi.pku.edu.cn/pub/cascade-download/TF-target.gml.gz
- BioGRID protein-protein interactions:
  http://ftp.cbi.pku.edu.cn/pub/cascade-download/biogrid.gml.gz
- GTEx correlated genes:
  http://ftp.cbi.pku.edu.cn/pub/cascade-download/corr.gml.gz

In [20]:
kegg = nx.read_gml("inferred_kegg_gene_only.gml.gz")
tf_target = nx.read_gml("TF-target.gml.gz")
biogrid = nx.read_gml("biogrid.gml.gz")
corr = nx.read_gml("corr.gml.gz")

The individual scaffold components can be assembled into a hybrid scaffold using
the [assemble_scaffolds](api/cascade.graph.assemble_scaffolds.rst) function,
which also marginalizes these components with regard to the genes being
modeled here:

In [21]:
scaffold = assemble_scaffolds(corr, biogrid, tf_target, kegg, nodes=adata.var_names)
scaffold.number_of_nodes(), scaffold.number_of_edges()

(1064, 32264)

## Prepare gene function embeddings

Lastly, we also fetch relevant entries from gene embeddings pre-computed using
LSI of their GO annotations, which can be downloaded here:

- http://ftp.cbi.pku.edu.cn/pub/cascade-download/gene2gos_lsi.csv.gz

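This will serve as the input of the interventional latent variable in CASCADE.
Genes without an embedding are dropped by `dropna()`, so the result can cover
fewer genes than `adata.var_names`: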

In [22]:
latent_emb = pd.read_csv("gene2gos_lsi.csv.gz", index_col=0)
latent_emb = latent_emb.reindex(adata.var_names).dropna()
latent_emb.shape

(866, 32)

## Save processed data files

Finally, save the preprocessed data files for use in [stage 2](training.ipynb).

In [23]:
adata.write("adata.h5ad", compression="gzip")

In [24]:
nx.write_gml(scaffold, "scaffold.gml.gz")

In [25]:
latent_emb.to_csv("latent_emb.csv.gz")

## Afterword

Described above is the minimal preprocessing for running CASCADE. Additional
steps such as filtering non-perturbed cells using
[mixscape](https://pertpy.readthedocs.io/en/latest/tutorials/notebooks/mixscape.html)
may also be useful depending on the data at hand.