In [57]:
from moscot.problems.generic._generic import ConditionalNeuralProblem
import jax.numpy as jnp
import scanpy as sc
import pickle as pkl
import jax

In [2]:
# Data locations. Set the SCIPLEX_DATA_DIR environment variable to point at a different
# copy of the datasets instead of editing the absolute path below.
import os
from pathlib import Path

DATA_DIR = Path(
    os.environ.get(
        "SCIPLEX_DATA_DIR",
        "/Users/gori/Desktop/thesis/ConditionalOT_Perturbations/Datasets",
    )
)

adata = sc.read_h5ad(DATA_DIR / "sciplex_complete_middle_subset.h5ad")
# NOTE: pickle.load can execute arbitrary code -- only load this file from a trusted source.
with open(DATA_DIR / "processed_data_25.pickle", "rb") as f:
    aux_data = pkl.load(f)


In [3]:
# Invert aux_data["mapping"] for both families so that the keys used in
# aux_data["embedding_train"] can be translated back to their names.
inv_map = {
    family: {emb_key: name for name, emb_key in aux_data["mapping"][family].items()}
    for family in ("cell_lines", "conditions")
}

In [4]:
# Flat name -> embedding lookup covering both cell lines and conditions (cell lines are
# inserted first, so a condition with the same name would overwrite a cell line).
embedding_data = {}
for family in ("cell_lines", "conditions"):
    names_by_key = inv_map[family]
    family_embeddings = aux_data["embedding_train"][family]
    embedding_data.update(
        {names_by_key[emb_key]: vector for emb_key, vector in family_embeddings.items()}
    )

In [36]:
# Training split only (an AnnData view). No trailing comma here: `adata[...],` would bind a
# 1-tuple instead of an AnnData and break every later `adata_train.obs` access.
adata_train = adata[adata.obs["split_ood_finetuning"] == "train"]
neural_problem = ConditionalNeuralProblem(
    adata_train,
    embedding_data=embedding_data,
)

In [37]:
# Build one (source, target) pair for every `cov_drug` value observed for each training
# cell line; the source is always that cell line's "<cell_line>_control" key.
obs_train = adata_train.obs
subset = []
for cell_line in obs_train["cell_type"].unique():
    is_this_line = (obs_train["cell_type"] == cell_line).to_numpy()
    for condition in obs_train.loc[is_this_line, "cov_drug"].unique():
        subset.append((f"{cell_line}_control", condition))

In [38]:
# Register only the first (control, target) pair from `subset` for this run, using the
# "explicit" policy (presumably: pairs are taken as given in `subset`) on the X_pca space.
neural_problem.prepare(
    subset=[subset[0]],
    policy="explicit",
    joint_attr="X_pca",
    key="cov_drug",
)

<moscot.problems.generic._generic.ConditionalNeuralProblem at 0x303be31d0>

In [39]:
# prepare() returns the pairs in inverse order (perturbation, control), but we want to solve
# from control to perturbation. Orient every pair so that the "<cell_line>_control" key (the
# source key used when building `subset`) comes first. Unlike a blind swap, this is
# idempotent: re-running the cell does not flip the pairs back.
neural_problem._sample_pairs = [
    (second, first)
    if second.endswith("_control") and not first.endswith("_control")
    else (first, second)
    for first, second in neural_problem._sample_pairs
]

In [40]:
# The conditioning vector is hstack([cell-line embedding, drug embedding]) (see the push
# cell below), so its size is the sum of the two embedding sizes. This is derived here
# instead of hard-coded (presumably 494 for the current embedding files -- TODO confirm).
cond_dim = int(
    jnp.size(next(iter(aux_data["embedding_train"]["cell_lines"].values())))
    + jnp.size(next(iter(aux_data["embedding_train"]["conditions"].values())))
)
neural_problem.solve(
    cond_dim=cond_dim,
    embedding_data=embedding_data,
    iterations=100,
    best_model_metric="sinkhorn",
)

100%|██████████| 100/100 [00:14<00:00,  6.70it/s]


<moscot.problems.generic._generic.ConditionalNeuralProblem at 0x303be31d0>

In [41]:
# Display the fitted solution object (its repr shows, e.g., predicted cost, best loss and Sinkhorn distance).
neural_problem.solution

CondNeuralDualOutput[predicted_cost=0.289, best_loss=Array(0.09143162, dtype=float32), sinkhorn_dist=0.20153522491455078]

In [42]:
# Sanity check: push an all-zeros PCA vector through the learned map, conditioned on the
# target of the first training pair.
first_target = subset[0][1]
# maxsplit=1: only the first "_" separates cell line from drug (assumes cell-line names
# contain no "_"), so a drug name containing "_" is kept intact.
cell_line_name, drug_name = first_target.split("_", 1)
cond_vector = jnp.hstack([embedding_data[cell_line_name], embedding_data[drug_name]])
# Input dimension of the map: the model was trained on joint_attr="X_pca".
n_pcs = adata.obsm["X_pca"].shape[1]
neural_problem.solution.push(
    x=jnp.zeros(n_pcs),
    cond=cond_vector,
)

Array([[ 0.44209042, -0.16016953,  0.18922329, -0.1408967 , -0.2180253 ,
         0.3904039 , -0.11984493,  0.0663259 ,  0.10406263, -0.20578012,
         0.11186115,  0.26099184,  0.05799967, -0.02125153,  0.16450763,
         0.03003113,  0.05401461, -0.0679647 ,  0.11992903,  0.01917309,
        -0.13954584,  0.07577816, -0.0289299 , -0.08964204,  0.06985531]],      dtype=float32)

In [None]:
# Vectorise `push` over a batch: row i of `x` is pushed under condition row i of `cond`
# (both arguments are mapped along their leading axis, jax.vmap's default).
batch_mapper = jax.vmap(
    lambda x, cond: neural_problem.solution.push(x=x, cond=cond),
)

In [93]:
# Push test-split cells through the learned map for every training target and keep every
# prediction, keyed by the target "<cell_line>_<drug>" string (previously each iteration
# overwrote the last result).
# NOTE(review): the selection below takes cells already *treated* with `condition`, while the
# map was trained from "<cell_line>_control" to the perturbation (see `subset` and the
# `_sample_pairs` cell). Pushing control cells is probably what was intended -- TODO confirm.
predictions = {}
for _, target in subset:
    # maxsplit=1 keeps any further "_" inside the drug name.
    cell_line, condition = target.split("_", 1)
    try:
        embedding = jnp.hstack([embedding_data[cell_line], embedding_data[condition]])
    except KeyError:
        print(f"Skipping {cell_line}_{condition}")
        continue

    all_gex = adata[
        (adata.obs["split_ood_finetuning"] == "test")
        & (adata.obs["cell_type"] == cell_line)
        & (adata.obs["condition"] == condition)
    ].obsm["X_pca"]
    if all_gex.shape[0] == 0:
        # Nothing to push for this target in the test split.
        print(f"No test cells for {target}")
        continue

    # Same condition vector for every cell of this target.
    all_embeddings = jnp.repeat(embedding[None, :], all_gex.shape[0], axis=0)

    batch_results = batch_mapper(all_gex, all_embeddings)
    predictions[target] = batch_results





Skipping A549_(+)-JQ1
Skipping A549_Hesperadin
Skipping A549_CUDC-101
Skipping A549_Raltitrexed
Skipping A549_Trametinib
Skipping A549_Dacinostat
Skipping A549_CUDC-907
Skipping A549_Pirarubicin
Skipping A549_Tanespimycin
Skipping A549_Givinostat
Skipping MCF7_Trametinib
Skipping MCF7_(+)-JQ1
Skipping MCF7_Tanespimycin
Skipping MCF7_Givinostat
Skipping MCF7_Pirarubicin
Skipping MCF7_Raltitrexed
Skipping MCF7_Hesperadin


KeyboardInterrupt: 