In [1]:
from apdpa._data import get_data_loaders, create_combo_split, get_all_single_perturbation_idx

In [2]:
# Import data
import numpy as np
import pandas as pd
import scanpy as sc

adata = sc.read_h5ad("data/proteomics/data/preprocessed_small.h5ad") # 730 × 5519 (presumably n_obs × n_vars; confirm with adata.shape)

In [3]:
# Count samples per type, restricted to the MCF7 protein plate.
mcf7_mask = adata.obs["protein_plate"] == "MCF7"
adata[mcf7_mask].obs["type"].value_counts()

type
singleDrug         175
drugCombination    108
noDrug              65
Name: count, dtype: int64

In [4]:
# Maps the generic roles expected by the data helpers onto the column names
# and category values used in this dataset's adata.obs.
naming_config = dict(
    # --- Sample type -------------------------------------------------------
    type_col="type",                    # adata.obs column storing the sample type
    single_type="singleDrug",           # value marking single-perturbation samples
    no_type="noDrug",                   # value marking unperturbed samples
    combo_type="drugCombination",       # value marking combination samples
    # --- Perturbation identity ---------------------------------------------
    perturbation_a_col="anchor_drug",   # first perturbation of a combo
    perturbation_b_col="library_drug",  # second perturbation of a combo
    # --- Condition / grouping ----------------------------------------------
    condition_col="protein_plate",      # condition (cell line / protein plate)
    # To force whole conditions (e.g. protein plates) into a specific split,
    # set group_col to the same value as condition_col; otherwise keep None.
    group_col=None,
    # --- Perturbation strength ---------------------------------------------
    strength_col="perturbation_strength",  # column with the perturbation strength
    strength_a_col="anchor_dose",       # strength of perturbation A (single samples and combo A)
    strength_b_col="library_dose",      # strength of perturbation B (combos only)
)


In [5]:

# 1. CHECK COMBO SPLIT


def _combo_ids(obs, idxs, config):
    """Return the set of "A+B" combination identifiers at positions ``idxs``.

    Args:
        obs: DataFrame of per-sample annotations (here ``adata.obs``).
        idxs: Positional row indices; they are consumed by ``.iloc``.
        config: Naming config providing ``perturbation_a_col`` and
            ``perturbation_b_col``.

    Returns:
        Set of strings of the form ``"<a>+<b>"``. The identifier is ordered,
        so "A+B" and "B+A" count as different combinations.
        NOTE(review): confirm that ordered pairs match how create_combo_split
        groups combinations.
    """
    col_a = config["perturbation_a_col"]
    col_b = config["perturbation_b_col"]
    rows = obs[[col_a, col_b]].iloc[idxs]
    # Build the identifiers as a standalone Series rather than writing a new
    # column into the slice; this avoids pandas' SettingWithCopyWarning.
    combo_id = rows[col_a].astype(str) + "+" + rows[col_b].astype(str)
    return set(combo_id)


# Get the train / val / test indices of the combination split.
train_combo_idxs, val_combo_idxs, test_combo_idxs = create_combo_split(
    adata,
    config=naming_config,
    test_frac=0.2,
    val_frac=0.1,
    random_state=42,
)

# Unique drug combinations present in each split.
train_drugs = _combo_ids(adata.obs, train_combo_idxs, naming_config)
val_drugs = _combo_ids(adata.obs, val_combo_idxs, naming_config)
test_drugs = _combo_ids(adata.obs, test_combo_idxs, naming_config)

# Check for overlap between splits: every count should be 0, i.e. no
# combination appears in more than one split.
print("Train and val overlap:", len(train_drugs.intersection(val_drugs)))
print("Train and test overlap:", len(train_drugs.intersection(test_drugs)))
print("Val and test overlap:", len(val_drugs.intersection(test_drugs)))

Train and val overlap: 0
Train and test overlap: 0
Val and test overlap: 0


In [6]:
# 2. CHECK SINGLE SPLIT

# Indices returned for single perturbations; the type counts of those rows
# are displayed so that anything other than "singleDrug" stands out.
train_single_idxs = get_all_single_perturbation_idx(
    adata,
    config=naming_config,
)
adata.obs.iloc[train_single_idxs]["type"].value_counts()

type
singleDrug         360
drugCombination      0
noDrug               0
Name: count, dtype: int64

In [7]:
# 3. CHECK DATA LOADERS

# get_data_loaders returns four values; only the first two (bound here as the
# single and combo training loaders) are kept, the other two are discarded.
(
    train_loader_single,
    train_loader_combo,
    _,
    _,
) = get_data_loaders(adata, naming_config, batch_size_single=32, batch_size_combo=32)

In [8]:
# Pull the first batch from the single-perturbation loader and name its
# first four fields.
iter_train_single = iter(train_loader_single)
batch = next(iter_train_single)
X_noPerturbation = batch[0]
X_single = batch[1]
perturbation_single = batch[2]
strength_single = batch[3]

In [9]:
# Create an iterator over the combination training loader and fetch its first
# batch. Note that this rebinds `batch`, which previously held the batch taken
# from the single-perturbation loader.
iter_train_combo = iter(train_loader_combo)
batch = next(iter_train_combo)

In [10]:
# Unpack the twelve fields of the combo batch already held in `batch`.
# Calling next(iter_train_combo) here instead would skip that batch, advance
# the iterator on every re-run of this cell, and eventually raise
# StopIteration once the loader is exhausted.
(X_combo, strength_combo_a, strength_combo_b,
 X_a, strength_single_a, perturbation_a,
 X_b, strength_single_b, perturbation_b,
 X_no1, X_no2, cond) = batch