# NeurIPS Single-Cell Perturbation Response Prediction
This notebook demonstrates the data preprocessing pipeline for the NeurIPS single-cell perturbation response prediction task. It includes:
1. Data loading and AnnData object creation
2. Preprocessing of perturbation data
3. Train-test-OOD (out-of-distribution) splitting for evaluation

## Required Libraries
The following libraries are needed to run this notebook:

In [1]:
import pandas as pd
import scanpy as sc
from CRISP.utils import rank_genes_groups_by_cov
import numpy as np
import CRISP.scFM as scFM

### Required Input Files:

1. **adata_obs_meta.csv**: Contains metadata for each observation (cell), including:

    - `obs_id`: Unique identifier for each cell
    - `cell_type`: Cell type classification (e.g., 'T cells CD4+', 'B cells')
    - `donor_id`: Donor identifier 
    - `sm_name`: Small molecule (drug) name
    - `dose_uM`: Drug concentration in µM
    - `control`: Binary indicator for control samples
    - `SMILES`: SMILES notation representing drug molecular structure

2. **adata_train.parquet**: Contains gene expression data with columns:
    - `obs_id`: Identifier linking to metadata
    - `gene`: Gene identifier
    - `normalized_count`: Normalized expression value

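3. **de_train.parquet**: Contains differential expression data. The gene columns start at zero-based column index 5; the columns before that are not genes. The gene column names define the gene set kept in the expression matrix.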

The final AnnData object will have the following structure:
- `adata.X`: Gene expression matrix (cells × genes)
- `adata.obs`: Cell metadata with additional computed fields
- `adata.var`: Gene metadata

Critical fields in the processed data include:
- `condition`: Drug name (alphanumeric only)
- `cell_type`: Type of cell
- `neg_control`: Binary indicator for negative control (DMSO treatment)
- `dose_val`: Normalized drug concentration

Let's start by loading the raw data:

In [None]:
# The raw data can be downloaded from the NeurIPS competition page on Kaggle: https://www.kaggle.com/competitions/open-problems-single-cell-perturbations/data?select=sample_submission.csv
obs_meta = pd.read_csv('raw/adata_obs_meta.csv')
adata_train = pd.read_parquet('raw/adata_train.parquet')
de_train = pd.read_parquet('raw/de_train.parquet')

## Creating AnnData Object
We'll now process the raw data to create an AnnData object, which is the standard format for single-cell analysis in the scanpy ecosystem. The process involves:
1. Creating indices for mapping cells and genes
2. Constructing a sparse matrix of gene expression values
3. Building a complete AnnData object with metadata
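
The three steps can be sketched on toy data before running them on the full tables. This is a minimal illustration, not the notebook's actual pipeline: `toy_obs`, `toy_genes`, and `toy_long` are made-up stand-ins for `obs_meta`, the gene columns of `de_train`, and `adata_train`. It assumes the cell table has a default 0..n-1 index, and it passes an explicit `shape` so that cells or genes without any entries still get a row or column.

```python
import numpy as np
import pandas as pd
from scipy.sparse import coo_matrix

# Hypothetical toy inputs (not real data)
toy_obs = pd.DataFrame({'obs_id': ['c0', 'c1', 'c2']})
toy_genes = ['G0', 'G1', 'G2', 'G3']
toy_long = pd.DataFrame({
    'obs_id': ['c0', 'c0', 'c1', 'c1'],
    'gene': ['G0', 'G2', 'G1', 'GX'],  # 'GX' is not in toy_genes and will be dropped
    'normalized_count': [1.5, 2.0, 0.5, 9.9],
})

# Step 1: lookups from cell ID / gene name to row / column position
obs_idx = pd.Series(toy_obs.index, index=toy_obs['obs_id'])
gene_idx = pd.Series(np.arange(len(toy_genes)), index=toy_genes)

# Step 2: map the long table to integer coordinates; unknown genes become NaN
toy_long['obs_idx'] = toy_long['obs_id'].map(obs_idx)
toy_long['gene_idx'] = toy_long['gene'].map(gene_idx)
kept = toy_long.dropna(subset=['gene_idx']).copy()
kept['gene_idx'] = kept['gene_idx'].astype(int)

# Step 3: assemble the sparse cells x genes matrix with an explicit shape
toy_X = coo_matrix(
    (kept['normalized_count'].to_numpy(),
     (kept['obs_idx'].to_numpy(), kept['gene_idx'].to_numpy())),
    shape=(len(toy_obs), len(toy_genes)),
).tocsr()

print(toy_X.toarray())
```

Cell `c2` has no entries in the long table, yet it still appears as an all-zero row because the shape is fixed from the metadata rather than inferred from the largest index.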

In [None]:
# Create dictionaries to map observation IDs and gene names to indices
gene_names = list(de_train.columns)[5:]  # gene columns start at index 5
obs_idx_dict = dict(zip(obs_meta['obs_id'].values, obs_meta.index))
gene_idx_dict = {gene: idx for idx, gene in enumerate(gene_names)}

In [None]:
# Map observation IDs to indices
adata_train['obs_idx'] = adata_train['obs_id'].apply(lambda i: obs_idx_dict[i])
# Map gene names to indices, handling cases where genes might not be in the dictionary
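def map_gene_idx(i):
    # dict.get returns None for genes absent from the dictionary,
    # without the bare `except` that would also hide unrelated errors
    return gene_idx_dict.get(i)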

In [None]:
# Apply gene index mapping and filter out genes not in the dictionary
adata_train['gene_idx'] = adata_train['gene'].apply(map_gene_idx)
# .copy() avoids pandas' SettingWithCopyWarning when the column is reassigned below
adata_train_sub = adata_train[~adata_train['gene_idx'].isna()].copy()
adata_train_sub['gene_idx'] = adata_train_sub['gene_idx'].astype(int)

In [None]:
# Create a sparse matrix from the normalized counts
from scipy.sparse import coo_matrix
from anndata import AnnData
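# Pass the shape explicitly so the matrix matches obs_meta and the gene list
# even if the last cells or genes have no entries
sparse_matrix = coo_matrix(
    (adata_train_sub['normalized_count'], (adata_train_sub['obs_idx'], adata_train_sub['gene_idx'])),
    shape=(len(obs_meta), len(gene_idx_dict)),
)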
adata = AnnData(X=sparse_matrix)

In [None]:
# Add metadata to the AnnData object
adata.obs = obs_meta
adata.obs_names = obs_meta['obs_id'].values
adata.X = adata.X.tocsc()
adata.var_names = list(gene_idx_dict.keys())
adata.var['gene_id'] = list(gene_idx_dict.keys())

In [None]:
sc.write('adata_pp.h5ad',adata)

## Preprocessing the Data

Now we'll preprocess the data by:
1. Cleaning and standardizing drug names
2. Creating normalized dose values
3. Generating composite covariates for cell type, drug, and dose combinations
4. Setting control indicators

These steps create several fields required for the drug response prediction task. We first reload the saved AnnData object:

In [None]:
adata = sc.read('adata_pp.h5ad')

In [None]:
import re

# Function to clean drug names by removing non-alphanumeric characters
def remove_non_alphanumeric(input_string):
    return re.sub(r'[^a-zA-Z0-9]', '', input_string)

# Create and standardize required fields
# 1. Clean drug condition names
adata.obs['condition'] = adata.obs['sm_name']
adata.obs['condition'] = adata.obs['condition'].apply(remove_non_alphanumeric)
adata.obs['condition'] = adata.obs['condition'].replace('DimethylSulfoxide','DMSO')

# 2. Normalize dose values by dividing by the maximum dose (the largest dose maps to 1)
adata.obs['dose_val'] = adata.obs['dose_uM'].astype(float) / np.max(adata.obs['dose_uM'].astype(float))

# 3. Create composite covariates for analysis
# This combines cell type, drug, and dose information into single identifiers
adata.obs['cov_drug_dose_name'] = adata.obs.cell_type.astype(str) + '_' + adata.obs.condition.astype(str) + '_' + adata.obs.dose_val.astype(str)
adata.obs['cov_drug_name'] = adata.obs.cell_type.astype(str) + '_' + adata.obs.condition.astype(str)
adata.obs['eval_category'] = adata.obs['cov_drug_name']

# 4. Convert control indicators to integers
adata.obs['control'] = adata.obs['control'].astype(int)

# 5. Create additional identifiers for different analysis levels
adata.obs['drug_dose_name'] = adata.obs.condition.astype(str) + '_' + adata.obs.dose_val.astype(str)
adata.obs['neg_control'] = (adata.obs['condition']=='DMSO').astype(int)
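
To make the derived fields concrete, here is a small sketch that applies the same transformations to a toy table. The drug name `Drug-A (test)` and the dose values are invented for illustration only; `clean_name` mirrors `remove_non_alphanumeric` above.

```python
import re
import pandas as pd

def clean_name(name):
    # Keep letters and digits only, as in remove_non_alphanumeric
    return re.sub(r'[^a-zA-Z0-9]', '', name)

# Hypothetical toy metadata (not real data)
toy = pd.DataFrame({
    'cell_type': ['B cells', 'B cells', 'NK cells'],
    'sm_name': ['Dimethyl Sulfoxide', 'Drug-A (test)', 'Dimethyl Sulfoxide'],
    'dose_uM': [10.0, 1.0, 10.0],
})

# Cleaned drug name, with the control compound renamed to DMSO
toy['condition'] = toy['sm_name'].apply(clean_name).replace('DimethylSulfoxide', 'DMSO')
# Dose relative to the largest dose in the table
toy['dose_val'] = toy['dose_uM'] / toy['dose_uM'].max()
# Composite identifiers at two levels of detail
toy['cov_drug_name'] = toy['cell_type'] + '_' + toy['condition']
toy['cov_drug_dose_name'] = toy['cov_drug_name'] + '_' + toy['dose_val'].astype(str)
# Negative-control flag
toy['neg_control'] = (toy['condition'] == 'DMSO').astype(int)

print(toy[['condition', 'dose_val', 'cov_drug_dose_name', 'neg_control']])
```

Because the composite identifiers embed the dose as a string, any change to how `dose_val` is computed or rounded also changes `cov_drug_dose_name`.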

### Filtering by Covariate Frequency

We'll filter out drug-cell type combinations with fewer than 5 samples to ensure robustness in the analysis:

In [None]:
# Count occurrences of each cell type-drug combination
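# value_counts() returns a Series; indexing it directly avoids relying on the
# column name, which differs between pandas versions
counts = adata.obs['cov_drug_name'].value_counts()
# Identify combinations with fewer than 5 samples
type_drug_less_index = counts[counts < 5].index
# Filter the AnnData object to keep only well-represented combinations
# (.copy() returns an independent object rather than a view)
adata_filtered = adata[~adata.obs['cov_drug_name'].isin(type_drug_less_index)].copy()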
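
The same frequency filter can be checked on a toy table. This sketch uses only pandas and invented group labels; it maps each row to the size of its group, which is one of several equivalent ways to express the threshold.

```python
import pandas as pd

# Hypothetical toy metadata: group sizes 6, 2, and 5
toy_obs = pd.DataFrame({
    'cov_drug_name': ['B cells_DrugA'] * 6 + ['B cells_DrugB'] * 2 + ['NK cells_DrugA'] * 5,
})

min_cells = 5  # same threshold as in the text

# Size of the group each row belongs to
group_size = toy_obs['cov_drug_name'].map(toy_obs['cov_drug_name'].value_counts())

# Keep rows whose group has at least min_cells members
toy_filtered = toy_obs[group_size >= min_cells]

print(toy_filtered['cov_drug_name'].value_counts().to_dict())
```

A group with exactly 5 samples is kept, since the rule removes only groups with fewer than 5.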

### Differential Expression Analysis
Now we'll perform differential expression analysis to identify genes responding to perturbations:

In [None]:
# Perform differential expression analysis using the rank_genes_groups_by_cov function
# This identifies genes differentially expressed between drug conditions while accounting for cell type
rank_genes_groups_by_cov(adata_filtered, groupby='cov_drug_name', covariate='cell_type', control_group='DMSO')

# The results are stored in adata_filtered.uns['rank_genes_groups']

### Drug Structure Canonicalization
For drug response prediction, we'll canonicalize SMILES strings to ensure consistent representation of drug structures:

In [None]:
# Canonicalize SMILES strings for consistent drug structure representation
from rdkit import Chem
smiles_list = adata_filtered.obs.SMILES.apply(lambda s: Chem.CanonSmiles(s))

### Calculate scGPT Embeddings
We'll use scGPT to generate embeddings that capture the gene expression patterns:

In [None]:
# Set the path to the pre-trained scGPT model (use 'blood' version for immune cells)
model_path = '/path/to/scGPT/model' # use blood
# Calculate scGPT embeddings and store them in the AnnData object
adata_filtered = scFM.calc_gpt(adata_filtered,model_path,gene_name='gene_name',return_key='X_scGPT')

## Train-Test-OOD Split

We'll create multiple train-test-OOD (out-of-distribution) splits to evaluate model performance. This is a crucial step for assessing how well the model generalizes to:
1. New samples from known cell types and drugs (test set)
2. New cell types not seen during training (OOD set)

First, let's define a function for creating these splits:

In [None]:
def split_dataset(adata, cell_types_inood, split_key):
    strata = ['cell_type', 'donor_id', 'condition']

    # Mark every sample from the held-out cell types as OOD
    adata.obs[split_key] = 'train'
    is_ood = adata.obs['cell_type'].isin(cell_types_inood)
    adata.obs.loc[is_ood, split_key] = 'ood'

    # Move 20% of the remaining samples to test, stratified by cell type, donor, and condition
    in_dist = adata.obs[adata.obs[split_key] != 'ood']
    settest_idx = in_dist.groupby(strata, group_keys=False).apply(lambda g: g.sample(frac=0.2)).index
    adata.obs.loc[settest_idx, split_key] = 'test'

    # Return 75% of the unperturbed (negative-control) samples of the OOD cell types to train
    ood_ctrl = adata.obs[(adata.obs[split_key] == 'ood') & (adata.obs['neg_control'] == 1)]
    settrain_idx = ood_ctrl.groupby(strata, group_keys=False).apply(lambda g: g.sample(frac=0.75)).index
    adata.obs.loc[settrain_idx, split_key] = 'train'

    # Return the object so that `adata = split_dataset(...)` does not assign None
    return adata


In [None]:
adata = split_dataset(adata,['Myeloid cells','T regulatory cells'],'split')
adata = split_dataset(adata,['T cells CD4+','B cells'],'split2')
adata = split_dataset(adata,['T cells CD8+','NK cells'],'split3')
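
The splitting rules can be sanity-checked on a small pandas table without any single-cell tooling. The sketch below is a simplified re-implementation for illustration, not the function above: `toy_split`, the toy table, and the fixed `seed` are assumptions introduced here, and it uses `DataFrameGroupBy.sample` (pandas 1.1 or later) for the stratified draws.

```python
import pandas as pd

def toy_split(obs, ood_cell_types, seed=0):
    """Assign 'train' / 'test' / 'ood' labels following the rules described above."""
    keys = ['cell_type', 'donor_id', 'condition']
    split = pd.Series('train', index=obs.index)

    # All cells of the held-out types start as OOD
    is_ood = obs['cell_type'].isin(ood_cell_types)
    split[is_ood] = 'ood'

    # 20% of each in-distribution stratum goes to test
    test_idx = obs[~is_ood].groupby(keys).sample(frac=0.2, random_state=seed).index
    split.loc[test_idx] = 'test'

    # 75% of the control cells of held-out types go back to train
    ood_ctrl = obs[is_ood & (obs['neg_control'] == 1)]
    train_idx = ood_ctrl.groupby(keys).sample(frac=0.75, random_state=seed).index
    split.loc[train_idx] = 'train'
    return split

# Hypothetical toy metadata: 2 cell types x 2 conditions x 20 cells, one donor
rows = [
    {'cell_type': ct, 'donor_id': 'd1', 'condition': cond, 'neg_control': int(cond == 'DMSO')}
    for ct in ['B cells', 'NK cells']
    for cond in ['DMSO', 'DrugA']
    for _ in range(20)
]
toy_obs = pd.DataFrame(rows)
toy_obs['split'] = toy_split(toy_obs, ood_cell_types=['NK cells'])

print(pd.crosstab(toy_obs['split'], [toy_obs['cell_type'], toy_obs['condition']]))
```

In this toy run, perturbed cells of the held-out type stay entirely in `ood`, while most of that type's control cells return to `train`, so a model sees the unperturbed state of the held-out cell type but never its drug response.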

In [6]:
pd.crosstab(adata[adata.obs['neg_control']==0].obs['split'],adata[adata.obs['neg_control']==0].obs['cell_type'])

| split | B cells | Myeloid cells | NK cells | T cells CD4+ | T cells CD8+ | T regulatory cells |
|---|---|---|---|---|---|---|
| ood | 0 | 11264 | 0 | 0 | 0 | 7418 |
| test | 2217 | 0 | 10675 | 22849 | 2849 | 0 |
| train | 8863 | 0 | 42703 | 91344 | 11415 | 0 |


In [7]:
pd.crosstab(adata[adata.obs['neg_control']==0].obs['split2'],adata[adata.obs['neg_control']==0].obs['cell_type'])

| split2 | B cells | Myeloid cells | NK cells | T cells CD4+ | T cells CD8+ | T regulatory cells |
|---|---|---|---|---|---|---|
| ood | 11080 | 0 | 0 | 114193 | 0 | 0 |
| test | 0 | 2253 | 10684 | 0 | 2854 | 1491 |
| train | 0 | 9011 | 42694 | 0 | 11410 | 5927 |


In [8]:
pd.crosstab(adata[adata.obs['neg_control']==0].obs['split3'],adata[adata.obs['neg_control']==0].obs['cell_type'])

| split3 | B cells | Myeloid cells | NK cells | T cells CD4+ | T cells CD8+ | T regulatory cells |
|---|---|---|---|---|---|---|
| ood | 0 | 0 | 53378 | 0 | 14264 | 0 |
| test | 2214 | 2253 | 0 | 22846 | 0 | 1491 |
| train | 8866 | 9011 | 0 | 91347 | 0 | 5927 |
