In [1]:
import sys
import yaml
import torch
import logging
from pathlib import Path

# Add the path to the directory containing the omnicell package
# Assuming the omnicell package is in the parent directory of your notebook
sys.path.append('..')  # Adjust this path as needed

import yaml
import torch
import logging
from pathlib import Path
from omnicell.config.config import Config, ETLConfig, ModelConfig, DatasplitConfig, EvalConfig, EmbeddingConfig
from omnicell.data.loader import DataLoader
from omnicell.constants import PERT_KEY, GENE_EMBEDDING_KEY, CONTROL_PERT
from train import get_model

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Configure paths
MODEL_CONFIG = ModelConfig.from_yaml("/orcd/data/omarabu/001/njwfish/omnicell/configs/models/cinemaot.yaml")
ETL_CONFIG = ETLConfig(name = "no_preprocessing", log1p = False, drop_unmatched_perts = True)
EMBEDDING_CONFIG = EmbeddingConfig(pert_embedding='GenePT')

SPLIT_CONFIG = DatasplitConfig.from_yaml("/orcd/data/omarabu/001/njwfish/omnicell/configs/splits/repogle_k562_essential_raw/random_splits/rs_accP_k562_ood_ss:ns_20_2_most_pert_0.1/split_0/split_config.yaml")
EVAL_CONFIG = EvalConfig.from_yaml("/orcd/data/omarabu/001/njwfish/omnicell/configs/splits/repogle_k562_essential_raw/random_splits/rs_accP_k562_ood_ss:ns_20_2_most_pert_0.1/split_0/eval_config.yaml")  # Set this if you want to run evaluations

# Load configurations
config = Config(model_config=MODEL_CONFIG,
                 etl_config=ETL_CONFIG, 
                 datasplit_config=SPLIT_CONFIG, 
                 eval_config=EVAL_CONFIG)




#Alternatively you can initialize the config objects manually as follows:
# etl_config = ETLConfig(name = XXX, log1p = False, drop_unmatched_perts = False, ...)
# model_config = ...
# embedding_config = ...
# datasplit_config = ...
# eval_config = ...
# config = Config(etl_config, model_config, datasplit_config, eval_config)


config.etl_config.pert_embedding = 'bioBERT'
config.etl_config.drop_unmatched_perts = True
# Set up device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Initialize data loader and load training data
loader = DataLoader(config)
adata, pert_rep_map = loader.get_training_data()

# Get dimensions and perturbation IDs
input_dim = adata.obsm['embedding'].shape[1]
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
pert_ids = adata.obs[PERT_KEY].unique()
gene_emb_dim = adata.varm[GENE_EMBEDDING_KEY].shape[1] if GENE_EMBEDDING_KEY in adata.varm else None

print(f"Data loaded:")
print(f"- Number of cells: {adata.shape[0]}")
print(f"- Input dimension: {input_dim}")
print(f"- Number of perturbations: {len(pert_ids)}")



2025-01-31 09:19:43,867 - INFO - Loading data catalogue from /orcd/data/omarabu/001/njwfish/omnicell/configs/catalogue/repogle_k562_essential_raw.yaml
2025-01-31 09:19:43,869 - INFO - Loading data catalogue from /orcd/data/omarabu/001/njwfish/omnicell/configs/catalogue/satija_IFNB_raw.yaml
2025-01-31 09:19:43,871 - INFO - Loading data catalogue from /orcd/data/omarabu/001/njwfish/omnicell/configs/catalogue/adamson_INCOMPLETE.yaml
2025-01-31 09:19:43,872 - INFO - Loading data catalogue from /orcd/data/omarabu/001/njwfish/omnicell/configs/catalogue/satija_IFNB_HVG.yaml
2025-01-31 09:19:43,874 - INFO - Loading data catalogue from /orcd/data/omarabu/001/njwfish/omnicell/configs/catalogue/kang.yaml
2025-01-31 09:19:43,875 - INFO - Loading data catalogue from /orcd/data/omarabu/001/njwfish/omnicell/configs/catalogue/essential_gene_knockouts_raw.yaml
2025-01-31 09:19:43,877 - INFO - Loading data catalogue from /orcd/data/omarabu/001/njwfish/omnicell/configs/catalogue/satija_IFNG_raw_INCOMPLET

Using device: cuda


2025-01-31 09:19:47,859 - INFO - Loaded unpreprocessed data, # of data points: 310385, # of genes: 8563.
2025-01-31 09:19:47,859 - INFO - Preprocessing training data
2025-01-31 09:19:47,862 - INFO - Using identity features for perturbations
2025-01-31 09:19:47,983 - INFO - Removing observations with perturbations not in the dataset as a column
2025-01-31 09:19:48,169 - INFO - Removed 189 perturbations that were not in the dataset columns and 0 perturbations that did not have an embedding for a total of 189 perturbations removed out of an initial 2058 perturbations
  adata.obsm["embedding"] = adata.X.toarray().astype('float32')
2025-01-31 09:20:29,515 - INFO - Doing OOD split


Data loaded:
- Number of cells: 279630
- Input dimension: 8563
- Number of perturbations: 1850


In [2]:

model = get_model(config.model_config.name, config.model_config.parameters, loader, pert_rep_map, input_dim, device, pert_ids, gene_emb_dim)

2025-01-31 09:20:37,843 - INFO - Mean model selected


In [3]:
model.train(adata)

  pert_means = df.groupby('perturbation').mean()


In [6]:
import numpy as np

logger.info("Running evaluation")

# evaluate each pair of cells and perts
eval_dict = {}
for cell_id, pert_id, ctrl_data, gt_data in loader.get_eval_data():
    logger.debug(f"Making predictions for cell: {cell_id}, pert: {pert_id}")

    preds = model.make_predict(ctrl_data, pert_id, cell_id)
    eval_dict[(cell_id, pert_id)] = (ctrl_data.X.toarray(), gt_data.X.toarray(), preds)
    break
    
if not config.etl_config.log1p:
    for (cell, pert) in eval_dict:  
        ctrl_data, gt_data, pred_pert = eval_dict[(cell, pert)]
        eval_dict[(cell, pert)] =  (np.log1p(ctrl_data), np.log1p(gt_data), np.log1p(pred_pert))


2025-01-31 09:21:37,825 - INFO - Running evaluation
2025-01-31 09:21:37,827 - INFO - Loading data catalogue from /orcd/data/omarabu/001/njwfish/omnicell/configs/catalogue/repogle_k562_essential_raw.yaml
2025-01-31 09:21:37,829 - INFO - Loading data catalogue from /orcd/data/omarabu/001/njwfish/omnicell/configs/catalogue/satija_IFNB_raw.yaml
2025-01-31 09:21:37,831 - INFO - Loading data catalogue from /orcd/data/omarabu/001/njwfish/omnicell/configs/catalogue/adamson_INCOMPLETE.yaml
2025-01-31 09:21:37,832 - INFO - Loading data catalogue from /orcd/data/omarabu/001/njwfish/omnicell/configs/catalogue/satija_IFNB_HVG.yaml
2025-01-31 09:21:37,834 - INFO - Loading data catalogue from /orcd/data/omarabu/001/njwfish/omnicell/configs/catalogue/kang.yaml
2025-01-31 09:21:37,836 - INFO - Loading data catalogue from /orcd/data/omarabu/001/njwfish/omnicell/configs/catalogue/essential_gene_knockouts_raw.yaml
2025-01-31 09:21:37,837 - INFO - Loading data catalogue from /orcd/data/omarabu/001/njwfish/

In [7]:
import scanpy as sc
from omnicell.evaluation.utils import get_DEGs, get_eval, get_DEG_Coverage_Recall, get_DEGs_overlaps
pval_threshold = 0.05
log_fold_change_threshold = 0.0

results_dict = {}
for (cell, pert) in eval_dict:  
    ctrl_data, gt_data, pred_pert = eval_dict[(cell, pert)]

    pred_pert = sc.AnnData(X=pred_pert)
    true_pert = sc.AnnData(X=gt_data)
    control = sc.AnnData(X=ctrl_data)

    logger.debug(f"Getting ground Truth DEGs for {pert} and {cell}")
    true_DEGs_df = get_DEGs(control, true_pert)
    signif_true_DEG = true_DEGs_df[true_DEGs_df['pvals_adj'] < pval_threshold]

    logger.debug(f"Number of significant DEGS from ground truth: {signif_true_DEG.shape[0]}")

    logger.debug(f"Getting predicted DEGs for {pert} and {cell}")
    pred_DEGs_df = get_DEGs(control, pred_pert)


    logger.debug(f"Getting evaluation metrics for {pert} and {cell}")
    r2_and_mse = get_eval(control, true_pert, pred_pert, true_DEGs_df, [100,50,20], pval_threshold, log_fold_change_threshold)

    logger.debug(f"Getting DEG overlaps for {pert} and {cell}")
    DEGs_overlaps = get_DEGs_overlaps(true_DEGs_df, pred_DEGs_df, [100,50,20], pval_threshold, log_fold_change_threshold)

    results_dict[(cell, pert)] = (r2_and_mse, DEGs_overlaps)



  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")


In [8]:
for (cell, pert) in results_dict:
    r2_and_mse, DEGs_overlaps = results_dict[(cell, pert)]
    print(f"Cell: {cell}, Pert: {pert}")
    # print(f"R2 and MSE: {r2_and_mse}")
    print(f"DEGs Overlaps: {DEGs_overlaps}")
    print("-"*100)


Cell: k562, Pert: RPL15
DEGs Overlaps: {'Overlap_in_top_2159_DEGs': 875, 'Overlap_in_top_100_DEGs': 4, 'Overlap_in_top_50_DEGs': 0, 'Overlap_in_top_20_DEGs': 0, 'Jaccard': 0.17341105257258155}
----------------------------------------------------------------------------------------------------


In [9]:
r2_and_mse

{'all_genes_mean_sub_diff_R': np.float64(0.6966497841526466),
 'all_genes_mean_sub_diff_R2': np.float64(0.485320921759929),
 'all_genes_mean_sub_diff_MSE': np.float64(0.01054379975115533),
 'all_genes_mean_fold_diff_R': np.float64(0.7028902237471792),
 'all_genes_mean_fold_diff_R2': np.float64(0.4940546666393597),
 'all_genes_mean_fold_diff_MSE': np.float64(7.89763734207286),
 'all_genes_mean_R': np.float64(0.9880260038744728),
 'all_genes_mean_R2': np.float64(0.9761953843321598),
 'all_genes_mean_MSE': np.float64(0.010543799755630833),
 'all_genes_var_R': np.float64(0.8846038081846144),
 'all_genes_var_R2': np.float64(0.7825238974547221),
 'all_genes_var_MSE': np.float64(0.003680699246692525),
 'all_genes_corr_mtx_R': np.float64(0.5764657556881528),
 'all_genes_corr_mtx_R2': np.float64(0.3323127674811131),
 'all_genes_corr_mtx_MSE': np.float64(0.0074322404434879),
 'all_genes_cov_mtx_R': np.float64(0.7999105612729749),
 'all_genes_cov_mtx_R2': np.float64(0.6398569060360457),
 'all_gen