In [2]:
import optuna

In [3]:
import cpa

Global seed set to 0


In [None]:

from rdkit import Chem
import scanpy as sc
from pandas import CategoricalDtype

import numpy as np
import pandas as pd
import anndata
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import mean_squared_error, r2_score
from scipy.stats import pearsonr
from scipy.sparse import issparse
from matplotlib.gridspec import GridSpec
from optuna.samplers import GridSampler
import itertools

import scipy.sparse as sparse
from scipy.spatial.distance import cdist
from tqdm import tqdm

In [None]:

adata = sc.read('./plate9_for_cpa_RDkit_training.h5ad') # loading the data 

adata.obs.rename(columns={"label": "condition_ID"}, inplace=True) # renaming the 'label' column to 'condition'


In [6]:
SEED = 42
adata.obs["cell_type"] = adata.obs["cell_line"]
cpa.CPA.setup_anndata(adata,
                      perturbation_key='condition_ID', #defines perturbation conditions
                      dosage_key='log_dose',           #for dose-dependent effects.
                      control_group='DMSO_TF_00uM',
                      batch_key=None,
                      smiles_key='smiles_rdkit',       #to compute RDKit embeddings, which capture chemical structure.
                      is_count_data=True,              #raw counts
                      categorical_covariate_keys=['cell_line'], #CPA can model how cell type affects drug response.
                      deg_uns_key='rank_genes_groups', #where CPA can find differential expression gene lists in .uns.
                      deg_uns_cat_key='cov_drug_dose',               #cov_drug_dose on the groupby of DE 
                      max_comb_len=1,
                     )


100%|██████████| 17087/17087 [00:00<00:00, 55191.91it/s]
100%|██████████| 17087/17087 [00:00<00:00, 1111012.33it/s]
100%|██████████| 86/86 [00:00<00:00, 16631.78it/s]


[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        


In [None]:
# Define Data Splits (split_1ct_MEC)
np.random.seed(42)
conds = adata.obs["condition_ID"].unique().tolist()
n_ood = max(1, int(0.1 * len(conds)))
ood_conds = np.random.choice(conds, n_ood, replace=False)

# initialize all as train
adata.obs["split_1ct_MEC"] = "train"
# mark OOD drugs
mask = adata.obs["condition_ID"].isin(ood_conds)
adata.obs.loc[mask, "split_1ct_MEC"] = "ood"

# validation set
val_conds = np.random.choice(
    [c for c in conds if c not in ood_conds],
    max(1, int(0.05 * len(conds))),
    replace=False
)
mask_val = adata.obs["condition_ID"].isin(val_conds)
adata.obs.loc[mask_val, "split_1ct_MEC"] = "valid"

print(adata.obs['split_1ct_MEC'].value_counts())

# count of cells in each data split
ood_conds = list(adata[adata.obs['split_1ct_MEC'] == 'ood'].obs['condition_ID'].value_counts().index)

#all unique condition IDs from the OOD split,
print("-> OOD conditions:")
print(ood_conds) 

# Tags each cell's condition as either an OOD condition or 'other'. 
adata.obs['condition_split'] = adata.obs['condition_ID'].apply(lambda x: x if x in ood_conds else 'other')



#  rebuilt cov drug dose
def build_cov_drug_dose(row):
    cell = row["cell_line"]
    drug = row["drug_cleaned"] 
    dose = float(row["dose"])
    return f"{cell}_{drug}_{dose:.1f}"


adata.obs["cov_drug_dose"] = adata.obs.apply(build_cov_drug_dose, axis=1)

print("-> Done.")



split_1ct_MEC
train    14687
ood       1600
valid      800
Name: count, dtype: int64
-> OOD conditions:
['Almonertinib (mesylate)_50uM', 'Bisoprolol (hemifumarate)_50uM', 'Clofarabine_50uM', 'Cytarabine_50uM', 'Lucanthone_50uM', 'Rucaparib (phosphate)_50uM', 'Sivelestat (sodium tetrahydrate)_50uM', 'Tranilast_50uM']
-> Done.


In [None]:
# Define Search Space for Hyperparameter Optimization ------------------------------------
 
search_space = {
    "n_latent": [64, 128, 256],
    "n_hidden_encoder": [128, 256, 512],
    "n_layers_encoder": [2, 3],
    "dropout_rate_encoder": [0.05, 0.2, 0.3],
    "lr": [1e-4, 3e-4],
}
sampler = GridSampler(search_space)

#  Optuna Objective

def objective(trial: optuna.trial.Trial) -> float:
    # Sample hyperparameters
    n_latent = trial.suggest_categorical("n_latent", search_space["n_latent"])
    n_hidden_encoder = trial.suggest_categorical("n_hidden_encoder", search_space["n_hidden_encoder"])
    n_layers_encoder = trial.suggest_categorical("n_layers_encoder", search_space["n_layers_encoder"])
    dropout_rate_encoder = trial.suggest_categorical("dropout_rate_encoder", search_space["dropout_rate_encoder"])
    lr = trial.suggest_categorical("lr", search_space["lr"])

    label = f"lat{n_latent}_enc{n_hidden_encoder}x{n_layers_encoder}_do{dropout_rate_encoder}_lr{lr}"
    print(f"\n🔧 Trial {trial.number:02d} – {label}")

    # Model parameters
    model_params = dict(
        n_latent=n_latent,
        recon_loss="nb",
        doser_type="linear",
        n_hidden_encoder=n_hidden_encoder,
        n_layers_encoder=n_layers_encoder,
        n_hidden_decoder=128,
        n_layers_decoder=2,
        use_batch_norm_encoder=True,
        use_layer_norm_encoder=False,
        use_batch_norm_decoder=False,
        use_layer_norm_decoder=True,
        dropout_rate_encoder=dropout_rate_encoder,
        dropout_rate_decoder=0.05,
        variational=False,
        seed=6977,
    )

    trainer_params = {
        "n_epochs_kl_warmup": None,
        "n_epochs_pretrain_ae": 30,
        "n_epochs_adv_warmup": 50,
        "n_epochs_mixup_warmup": 0,
        "mixup_alpha": 0.0,
        "adv_steps": None,
        "n_hidden_adv": 64,
        "n_layers_adv": 3,
        "use_batch_norm_adv": True,
        "use_layer_norm_adv": False,
        "dropout_rate_adv": 0.3,
        "reg_adv": 20.0,
        "pen_adv": 5.0,
        "lr": lr,
        "wd": 4e-07,
        "adv_lr": 0.0003,
        "adv_wd": 4e-07,
        "adv_loss": "cce",
        "doser_lr": 0.0003,
        "doser_wd": 4e-07,
        "do_clip_grad": True,
        "gradient_clip_value": 1.0,
        "step_size_lr": 10,
    }

    try:
        # Train CPA model
        model = cpa.CPA(
            adata,
            split_key='split_1ct_MEC',
            train_split='train',
            valid_split='valid',
            test_split='ood',
            use_rdkit_embeddings=True,
            **model_params,
        )
        model.train(
            max_epochs=20,
            use_gpu=False,
            batch_size=32,
            plan_kwargs=trainer_params,
            early_stopping_patience=5,
            check_val_every_n_epoch=2,
        )

        # Evaluate on OOD conditions
        ood_conds = adata.obs.loc[adata.obs['split_1ct_MEC'] == 'ood', 'condition_ID'].unique()
        detailed_results = []
        for cond in ood_conds:
            mask_ood = (adata.obs['condition_ID'] == cond) & (adata.obs['split_1ct_MEC'] == 'ood')
            if mask_ood.sum() == 0:
                continue

            adata_ood = adata[mask_ood].copy()
            if "CPA_pred" not in adata_ood.obsm:
                model.predict(adata=adata_ood)

            x_pred = adata_ood.obsm["CPA_pred"]
            x_true = adata_ood.layers["counts"] if "counts" in adata_ood.layers else adata_ood.X
            if hasattr(x_true, "toarray"):
                x_true = x_true.toarray()
            x_pred = np.log1p(x_pred)
            x_true = np.log1p(x_true)

            # Control
            ctrl_mask = (adata.obs["cell_line"] == adata_ood.obs["cell_line"].unique()[0]) & \
                        (adata.obs["condition_ID"] == "DMSO_TF_00uM")
            adata_ctrl = adata[ctrl_mask].copy()
            x_ctrl = adata_ctrl.layers["counts"] if "counts" in adata_ctrl.layers else adata_ctrl.X
            if hasattr(x_ctrl, "toarray"):
                x_ctrl = x_ctrl.toarray()
            x_ctrl = np.log1p(x_ctrl)

            mean_true = x_true.mean(axis=0)
            mean_pred = x_pred.mean(axis=0)
            mean_ctrl = x_ctrl.mean(axis=0)

            r2 = r2_score(mean_true, mean_pred)
            rmse = np.sqrt(mean_squared_error(mean_true, mean_pred))

            lfc_true = mean_true - mean_ctrl
            lfc_pred = mean_pred - mean_ctrl
            r2_lfc = r2_score(lfc_true, lfc_pred)
            rmse_lfc = np.sqrt(mean_squared_error(lfc_true, lfc_pred))
            pearson_corr, _ = pearsonr(mean_true, mean_pred)


            detailed_results.append({
                "model": label,
                "OOD_condition": cond,
                "n_cells": mask_ood.sum(),
                "R2": r2,
                "RMSE": rmse,
                "R2_LFC": r2_lfc,
                "RMSE_LFC": rmse_lfc,
                "Pearson": pearson_corr,
            })

        # Aggregate metrics
        r2_mean = np.mean([r["R2"] for r in detailed_results])
        rmse_mean = np.mean([r["RMSE"] for r in detailed_results])
        r2_lfc_mean = np.mean([r["R2_LFC"] for r in detailed_results])
        rmse_lfc_mean = np.mean([r["RMSE_LFC"] for r in detailed_results])
        pearson_mean = np.mean([r["Pearson"] for r in detailed_results])


        trial.set_user_attr("metrics", {
            "model": label,
            "R2_mean": r2_mean,
            "RMSE_mean": rmse_mean,
            "R2_LFC_mean": r2_lfc_mean,
            "RMSE_LFC_mean": rmse_lfc_mean,
            "Pearson_mean": pearson_mean,

        })
        trial.set_user_attr("detailed", detailed_results)

        print(f"[Trial {trial.number}] R2={r2_mean:.3f}, RMSE={rmse_mean:.3f}, R2_LFC={r2_lfc_mean:.3f}")
        return r2_lfc_mean

    except Exception as e:
        print(f"⚠ Trial {trial.number} failed: {e}")
        return -np.inf

# Compute Total Combinations

n_combinations = len(list(itertools.product(
    search_space["n_latent"],
    search_space["n_hidden_encoder"],
    search_space["n_layers_encoder"],
    search_space["dropout_rate_encoder"],
    search_space["lr"]
)))
print(f"Total combinations: {n_combinations}")


# Run the Optuna Study (Grid Search over Hyperparameters)

study = optuna.create_study(direction="maximize", sampler=sampler)
study.optimize(objective, n_trials=n_combinations, show_progress_bar=True)


# Save and Display Summary of Trial Metrics

trial_metrics = [t.user_attrs["metrics"] for t in study.trials if "metrics" in t.user_attrs]
results_df = pd.DataFrame(trial_metrics)
results_df = results_df.set_index("model")
results_df.to_csv("./RDkit_optuna_optimized_full_grid.tsv", sep="\t", index=True)
display(results_df.style.format("{:.3f}"))
print("✅ Summary saved to RDkit_optuna_optimized_full_grid.tsv")

# Save Detailed Trial Results (Per-Condition Metrics)
detailed_all = [r for t in study.trials if "detailed" in t.user_attrs for r in t.user_attrs["detailed"]]
detailed_df = pd.DataFrame(detailed_all)
detailed_df.to_csv("./RDkit_optuna_optimized_full_grid_detailed.tsv", sep="\t", index=False)
print("✅ Detailed results saved to RDkit_optuna_optimized_full_grid_detailed.tsv")


# Best Trial and hyperparameters

print("\n🏆 Best Trial:")
print(f"  Number: {study.best_trial.number}")
print(f"  Value (R2_LFC_mean): {study.best_trial.value:.3f}")
print(f"  Params: {study.best_trial.params}")

[32m[I 2025-07-22 19:33:41,993][0m A new study created in memory with name: no-name-db24d2f9-8b52-423d-92eb-719421629f63[0m


Total combinations: 108


  0%|          | 0/108 [00:00<?, ?it/s]

Global seed set to 6977



🔧 Trial 00 – lat64_enc512x3_do0.2_lr0.0001
(86, 2048)


100%|██████████| 86/86 [00:00<00:00, 370.04it/s]
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Epoch 2/20:   5%|▌         | 1/20 [00:26<08:23, 26.52s/it, v_num=1, recon=373, r2_mean=0.21, adv_loss=4.49, acc_pert=0.0257]


Epoch 00001: cpa_metric reached. Module best state updated.


Epoch 4/20:  15%|█▌        | 3/20 [01:21<07:43, 27.26s/it, v_num=1, recon=294, r2_mean=0.342, adv_loss=4.32, acc_pert=0.0298, val_recon=295, disnt_basal=0.226, disnt_after=0.226, val_r2_mean=0.675, val_KL=nan]


Epoch 00003: cpa_metric reached. Module best state updated.


Epoch 6/20:  25%|██▌       | 5/20 [02:15<06:45, 27.06s/it, v_num=1, recon=286, r2_mean=0.363, adv_loss=4.24, acc_pert=0.0349, val_recon=285, disnt_basal=0.226, disnt_after=0.226, val_r2_mean=0.683, val_KL=nan]


Epoch 00005: cpa_metric reached. Module best state updated.


Epoch 8/20:  35%|███▌      | 7/20 [03:09<05:50, 26.92s/it, v_num=1, recon=282, r2_mean=0.375, adv_loss=4.19, acc_pert=0.0381, val_recon=281, disnt_basal=0.226, disnt_after=0.227, val_r2_mean=0.695, val_KL=nan]


Epoch 00007: cpa_metric reached. Module best state updated.


Epoch 10/20:  45%|████▌     | 9/20 [04:08<05:13, 28.53s/it, v_num=1, recon=279, r2_mean=0.385, adv_loss=4.15, acc_pert=0.0432, val_recon=277, disnt_basal=0.226, disnt_after=0.226, val_r2_mean=0.698, val_KL=nan]


Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 0.22527270507270508
disnt_after = 0.22568827283827284
val_r2_mean = 0.7028860199451447
val_r2_var = 0.22657282173633575
Epoch 12/20:  55%|█████▌    | 11/20 [05:04<04:14, 28.32s/it, v_num=1, recon=277, r2_mean=0.394, adv_loss=4.13, acc_pert=0.0445, val_recon=275, disnt_basal=0.225, disnt_after=0.226, val_r2_mean=0.703, val_KL=nan]


Epoch 00011: cpa_metric reached. Module best state updated.


Epoch 14/20:  65%|██████▌   | 13/20 [06:00<03:15, 27.94s/it, v_num=1, recon=274, r2_mean=0.399, adv_loss=4.11, acc_pert=0.0464, val_recon=273, disnt_basal=0.224, disnt_after=0.225, val_r2_mean=0.707, val_KL=nan]


Epoch 00013: cpa_metric reached. Module best state updated.


Epoch 16/20:  75%|███████▌  | 15/20 [06:58<02:23, 28.64s/it, v_num=1, recon=273, r2_mean=0.404, adv_loss=4.09, acc_pert=0.0509, val_recon=271, disnt_basal=0.226, disnt_after=0.226, val_r2_mean=0.711, val_KL=nan]

In [None]:
results_df.to_csv("./RDkit_optuna_optimized_full_grid.csv", sep="\t", index=True) # Saving the results DataFrame to a TSV file for further analysis

In [None]:
results_df.to_csv("./RDkit_optuna_optimized_full_grid_detailed.csv", sep="\t", index=True) # Saving the results DataFrame to a TSV file for further analysis

In [None]:
df = pd.read_csv("RDkit_optuna_optimized_full_grid.tsv", sep="\t")

## Vizualization of Model Performance Results

### Bar plot for R²


In [None]:
df_sorted = df.sort_values(by="R2", ascending=False)

plt.figure(figsize=(12, 6))
plt.barh(df_sorted["model"], df_sorted["R2"])
plt.xlabel("R²")
plt.title("Model Performance - R² Score")
plt.tight_layout()
plt.show()

### Bar plot for MSE

In [None]:
df_sorted = df.sort_values(by="MSE", ascending=True)

plt.figure(figsize=(12, 6))
plt.barh(df_sorted["model"], df_sorted["MSE"])
plt.xlabel("MSE")
plt.title("Model Performance - Mean Squared Error")
plt.tight_layout()
plt.show()

### Bar plot for Pearson correlation

In [None]:
df_sorted = df.sort_values(by="Pearson", ascending=False)

plt.figure(figsize=(12, 6))
plt.barh(df_sorted["model"], df_sorted["Pearson"])
plt.xlabel("Pearson Correlation")
plt.title("Model Performance - Pearson")
plt.tight_layout()
plt.show()

###  Bar plot for KDE divergence

In [None]:
df_sorted = df.sort_values(by="e_distance", ascending=True)

plt.figure(figsize=(12, 6))
plt.barh(df_sorted["model"], df_sorted["e_distance"])
plt.xlabel("E-distance")
plt.title("Model Performance - E-distance")
plt.tight_layout()
plt.show()

### Bar plot for Top-100 DEG Jaccard similarity

In [None]:
df_sorted = df.sort_values(by="mv_kde", ascending=True)

plt.figure(figsize=(12, 6))
plt.barh(df_sorted["model"], df_sorted["mv_kde"])
plt.xlabel("Multivariate KDE distance")
plt.title("Model Performance - mv_kde")
plt.tight_layout()
plt.show()

In [None]:
df_sorted = df.sort_values(by="jaccard_top100", ascending=False)

plt.figure(figsize=(12, 6))
plt.barh(df_sorted["model"], df_sorted["jaccard_top100"])
plt.xlabel("Jaccard Index (Top 100 DEGs)")
plt.title("Model Performance - DEG Overlap")
plt.tight_layout()
plt.show()