## Tutorial notebook: training CRISP

In this notebook, we use the NeurIPS dataset as an example to show how to train CRISP on a single-cell RNA-seq dataset with measured perturbations. \
In practice, single-cell training datasets are large and gene features are high-dimensional, so we recommend training from a shell script rather than from the notebook. A single GPU node is sufficient.
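One way to follow the shell-script recommendation is to move the notebook cells into a small Python entry point that a shell script, or a batch job on the GPU node, can call. The sketch below uses only the `Trainer` calls shown in this notebook. The file name `train_crisp.py` and every command-line flag are hypothetical and are not part of CRISP.

```python
"""train_crisp.py: hypothetical command-line version of this notebook.

The CRISP calls mirror the notebook cells. The script name and all
command-line flags are illustrative, not an interface shipped with CRISP.
"""
import argparse

# Same dataset settings as the notebook cell that defines `dataset_params`.
DATASET_PARAMS = {
    "perturbation_key": "condition",
    "dose_key": "dose_val",
    "smiles_key": "SMILES",
    "celltype_key": "cell_type",
    "FM_key": "X_scGPT",
    "control_key": "neg_control",
    "pc_cov": "type_donor",
    "degs_key": "rank_genes_groups_cov",
    "pert_category": "cov_drug_name",
    "split_ood": True,
    "split_key": "split",
    "seed": 1327,
}


def build_parser():
    # Hypothetical flags; defaults copy the values used in the notebook.
    parser = argparse.ArgumentParser(description="Train CRISP (hypothetical wrapper)")
    parser.add_argument("--adata", required=True, help="preprocessed .h5ad file")
    parser.add_argument("--chem-df", required=True, help="drug-embedding .parquet file")
    parser.add_argument("--save-dir", required=True, help="output directory")
    parser.add_argument("--num-epochs", type=int, default=51)
    parser.add_argument("--checkpoint-freq", type=int, default=51)
    parser.add_argument("--max-minutes", type=int, default=1000)
    parser.add_argument("--model-seed", type=int, default=1337)
    return parser


def main(argv=None):
    args = build_parser().parse_args(argv)

    # Heavy imports are deferred so that `--help` works without the training stack.
    import pandas as pd
    import scanpy as sc
    from CRISP.trainer import Trainer

    exp = Trainer()
    exp.init_dataset(adata_obj=sc.read(args.adata), **DATASET_PARAMS)
    exp.init_drug_embedding(chem_model="rdkit", chem_df=pd.read_parquet(args.chem_df))
    exp.init_model(hparams="", seed=args.model_seed)
    exp.load_train()
    return exp.train(
        checkpoint_freq=args.checkpoint_freq,
        num_epochs=args.num_epochs,
        max_minutes=args.max_minutes,
        save_dir=args.save_dir,
    )
```

Add a final `main()` call under an `if __name__ == "__main__":` guard. A shell script could then launch it with `python train_crisp.py --adata adata_pp_filtered_scFM_resplit.h5ad --chem-df ../data/drug_embeddings/rdkit2D_embedding_lincs_nips.parquet --save-dir ../experiments/results/nips_test`.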

In [1]:
from CRISP.trainer import Trainer
import scanpy as sc
import torch
import pandas as pd

In [3]:
# load anndata
adata = sc.read('adata_pp_filtered_scFM_resplit.h5ad')

In [4]:
dataset_params = {
    'perturbation_key':'condition', # key with drug name
    'dose_key': 'dose_val', # key with dosage info
    'smiles_key': 'SMILES', # key with drug SMILES
    'celltype_key': 'cell_type', # key with cell types
    'FM_key': 'X_scGPT', # key with scFM embeddings
    'control_key': 'neg_control', # key with the control indicator (control: 1, treated: 0)
    'pc_cov': 'type_donor', # key with covariate to identify paired control group
    'degs_key': "rank_genes_groups_cov", # key with the dict of DE genes
    'pert_category': "cov_drug_name", # covariate combination used to group cells for evaluation
    'split_ood': True, # whether to evaluate on the out-of-distribution (OOD) subset
    'split_key': "split", # key with split label info
    'seed': 1327, # random seed
}
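`init_dataset` will fail if any key in `dataset_params` does not exist in the AnnData object, so a quick check before the next cells can save time. The helper below is a hypothetical sketch. It assumes that most keys name `adata.obs` columns, that `FM_key` names an `adata.obsm` matrix (following the `X_` naming convention), and that `degs_key` names an `adata.uns` entry. CRISP does not confirm this layout here.

```python
def missing_dataset_keys(dataset_params, obs_columns, obsm_keys, uns_keys):
    """List (parameter, key, location) for every referenced key that is absent.

    Assumed layout (not confirmed by CRISP): most keys name adata.obs columns,
    FM_key names an adata.obsm matrix, and degs_key names an adata.uns entry.
    """
    obs_params = [
        "perturbation_key", "dose_key", "smiles_key", "celltype_key",
        "control_key", "pc_cov", "pert_category", "split_key",
    ]
    expected = [(p, "obs") for p in obs_params]
    expected += [("FM_key", "obsm"), ("degs_key", "uns")]
    available = {"obs": set(obs_columns), "obsm": set(obsm_keys), "uns": set(uns_keys)}

    missing = []
    for param, location in expected:
        key = dataset_params.get(param)
        # Parameters that are not set are skipped; only set-but-absent keys are reported.
        if key is not None and key not in available[location]:
            missing.append((param, key, location))
    return missing
```

For example, you could call `missing_dataset_keys(dataset_params, adata.obs.columns, adata.obsm.keys(), adata.uns.keys())`. Under the assumed layout, an empty list means every referenced key is present.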

In [5]:
exp = Trainer()

In [None]:
# initialize dataset
exp.init_dataset(adata_obj=adata, **dataset_params)
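The `control_key` column flags control cells with 1. The `pc_cov` column (here `type_donor`) identifies which control group each treated cell is paired with. The sketch below illustrates that pairing and is not CRISP's internal code. For every treated cell, it returns the mean expression of the control cells that share the same covariate value. The function name and the NaN fallback for groups without controls are illustrative choices.

```python
import numpy as np


def paired_control_baseline(X, is_control, groups):
    """Return treated-cell indices and, for each, the mean of its paired controls.

    X: cells x genes matrix; is_control: 1 for control, 0 for treated;
    groups: per-cell value of the pairing covariate (e.g. the `pc_cov` column).
    Treated cells whose group has no control cells get a row of NaN.
    """
    X = np.asarray(X, dtype=float)
    is_control = np.asarray(is_control).astype(bool)
    groups = np.asarray(groups)

    # Mean control profile per covariate group.
    control_mean = {
        g: X[is_control & (groups == g)].mean(axis=0)
        for g in np.unique(groups[is_control])
    }

    treated = np.flatnonzero(~is_control)
    missing = np.full(X.shape[1], np.nan)
    rows = [control_mean.get(groups[i], missing) for i in treated]
    baseline = np.array(rows).reshape(len(treated), X.shape[1])
    return treated, baseline
```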

In [7]:
# load the drug (SMILES) embedding dataframe and initialize drug embeddings
chem_df = pd.read_parquet('../data/drug_embeddings/rdkit2D_embedding_lincs_nips.parquet')
exp.init_drug_embedding(chem_model='rdkit', chem_df=chem_df)
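Each drug in the dataset needs a row in the embedding table, otherwise its cells have no drug representation. The sketch below gathers one embedding vector per cell from a table keyed by SMILES and reports any SMILES without an embedding. It assumes that `chem_df` is indexed by SMILES strings with one descriptor vector per row. That layout is not confirmed here, and the helper is hypothetical, not what `init_drug_embedding` does internally.

```python
import numpy as np


def per_cell_drug_embeddings(cell_smiles, table_smiles, table_values):
    """Gather one embedding row per cell from a SMILES-keyed table (hypothetical).

    Raises KeyError listing (up to five) SMILES that have no embedding row.
    """
    row_of = {s: i for i, s in enumerate(table_smiles)}
    missing = sorted({s for s in cell_smiles if s not in row_of})
    if missing:
        raise KeyError(f"SMILES without embedding: {missing[:5]}")
    values = np.asarray(table_values, dtype=float)
    return values[[row_of[s] for s in cell_smiles]]
```

Under the assumed layout, the usage would be `per_cell_drug_embeddings(adata.obs['SMILES'], chem_df.index, chem_df.to_numpy())`.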

In [9]:
# init model
device = "cuda" if torch.cuda.is_available() else "cpu"
exp.init_model(
    hparams='',
    seed=1337,
)
exp.load_train()
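`init_model` receives `seed=1337`, which is separate from the dataset-level `seed` (1327) in `dataset_params`. This notebook does not show how `init_model` uses the seed internally. The helper below is a generic sketch of how PyTorch runs are commonly made reproducible, not CRISP's code. The name `seed_everything` is illustrative.

```python
import random

import numpy as np

try:
    import torch
except ImportError:  # lets the sketch run where PyTorch is not installed
    torch = None


def seed_everything(seed):
    """Seed Python, NumPy and (if available) PyTorch random number generators."""
    random.seed(seed)
    np.random.seed(seed)
    if torch is not None:
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
```

Seeding alone does not make GPU training fully deterministic, because some CUDA kernels are nondeterministic. Seeding does make random initialization and sampling repeatable.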

In [10]:
train_params = {
    'checkpoint_freq': 51, # how often (in epochs) to run evaluation
    'num_epochs': 51, # number of training epochs
    'max_minutes': 1000, # time limit for training, in minutes
    'save_dir': '../experiments/results/nips_test', # output directory
}
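Here `checkpoint_freq` equals `num_epochs` (both 51), so evaluation presumably runs only once, at the end of training. The loop below is a hypothetical sketch of how these three parameters typically interact. It evaluates every `checkpoint_freq` epochs, evaluates once more after the last epoch, and stops early once `max_minutes` of wall-clock time have elapsed. It is not `Trainer.train`. The functions `train_one_epoch` and `evaluate` are placeholder callbacks, and the injectable `clock` only makes the time budget testable.

```python
import time


def run_training(num_epochs, checkpoint_freq, max_minutes,
                 train_one_epoch, evaluate, clock=time.monotonic):
    """Hypothetical epoch loop: periodic evaluation plus a wall-clock budget.

    Evaluates every `checkpoint_freq` epochs (1-indexed), after the last epoch,
    and when the time budget runs out; returns a list of (epoch, result) pairs.
    """
    start = clock()
    evaluations = []
    for epoch in range(1, num_epochs + 1):
        train_one_epoch(epoch)
        out_of_time = (clock() - start) / 60.0 >= max_minutes
        if epoch % checkpoint_freq == 0 or epoch == num_epochs or out_of_time:
            evaluations.append((epoch, evaluate(epoch)))
        if out_of_time:
            break
    return evaluations
```

With the notebook's values, `run_training(51, 51, 1000, ...)` evaluates once, at epoch 51.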

In [11]:
# training
results = exp.train(**train_params)

100%|██████████| 51/51 [39:34<00:00, 46.56s/it]


In [12]:
# evaluation results for the OOD subset
results['ood']

[{'r2score': 0.9322444459834656,
  'r2score_de': 0.23656416248965573,
  'pearson': 0.9696602225764862,
  'pearson_de': 0.44427059194645313,
  'mse': 0.07321457,
  'mse_de': 0.45216295,
  'pearson_delta': 0.40612612374275175,
  'pearson_delta_de': 0.6822232767848168,
  'sinkhorn_de': 16.342594424625496}]
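`results['ood']` is a list containing one dictionary of metrics. The `_de` suffix presumably restricts a metric to the DE genes stored under `degs_key`. `pearson_delta` presumably correlates predicted and observed changes relative to control. The sketch below shows one common way to compute such metrics on mean expression profiles. It is an assumption about their definitions, not CRISP's evaluation code. It omits `sinkhorn_de`, which is presumably an entropy-regularized optimal-transport (Sinkhorn) distance between cell distributions.

```python
import numpy as np


def _pearson(a, b):
    # Pearson correlation between two 1-D vectors.
    return float(np.corrcoef(a, b)[0, 1])


def _r2(true, pred):
    # Coefficient of determination of `pred` with respect to `true`.
    ss_res = np.sum((true - pred) ** 2)
    ss_tot = np.sum((true - true.mean()) ** 2)
    return float(1.0 - ss_res / ss_tot)


def mean_expression_metrics(true_mean, pred_mean, ctrl_mean, de_idx):
    """Compare mean profiles on all genes and on the DE-gene subset (hypothetical)."""
    true_mean, pred_mean, ctrl_mean = (
        np.asarray(v, dtype=float) for v in (true_mean, pred_mean, ctrl_mean)
    )
    metrics = {}
    for suffix, idx in (("", slice(None)), ("_de", np.asarray(de_idx, dtype=int))):
        t, p, c = true_mean[idx], pred_mean[idx], ctrl_mean[idx]
        metrics["r2score" + suffix] = _r2(t, p)
        metrics["pearson" + suffix] = _pearson(t, p)
        metrics["mse" + suffix] = float(np.mean((t - p) ** 2))
        # Correlation of predicted vs. observed change relative to control.
        metrics["pearson_delta" + suffix] = _pearson(t - c, p - c)
    return metrics
```

Given a list of DE gene names, `np.flatnonzero(adata.var_names.isin(de_genes))` gives the matching `de_idx`. Each entry of `results['ood']` is a dict, so `pd.DataFrame(results['ood'])` displays the reported metrics as a table.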