In [None]:
import torch
from torch import nn
from tqdm import tqdm
import scanpy as sc
from torch.utils.data import DataLoader, random_split, TensorDataset
from sklearn.model_selection import train_test_split
import plotly.express as px
import plotly.graph_objects as go
from sklearn.decomposition import PCA
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from ae_model import Autoencoder
from neural_flow_model import FlowModel, ZINBSampler
from datasets import PerturbPairData
from scipy.stats import wasserstein_distance

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")  # prefer GPU when present

In [None]:
# Optimisation schedule for the flow model.
NUM_EPOCHS, BATCH_SIZE = 200, 2_048
LEARNING_RATE, WEIGHT_DECAY = 1e-4, 1e-5
# Network widths: flow network hidden size and the condition / perturbation embedding sizes.
HIDDEN_DIM = 128
COND_EMB_DIMS = PERT_EMB_DIMS = 16
# Scale of Gaussian noise added to x_0 / x_1 during training; 0 means no noise.
SIGMA = 0

In [None]:
def _log_normalize_and_pca(adata):
    """Normalise each cell to 1e4 total counts, log1p-transform, then run PCA (all in place).

    Returns the same AnnData object so loaders can ``return _log_normalize_and_pca(adata)``.
    """
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)
    sc.tl.pca(adata, svd_solver="arpack")
    return adata


def load_jiang(path: str):
    """Read the Jiang ``.h5ad`` file and preprocess it the same way as ``load_pbmc``.

    The untouched counts are stored in ``layers["counts"]`` before normalisation, because the
    evaluation code samples and scores raw counts from that layer.

    Args:
        path: path to the ``.h5ad`` file.

    Returns:
        AnnData with log-normalised ``X``, raw ``layers["counts"]`` and ``obsm["X_pca"]``.
    """
    adata = sc.read_h5ad(path)
    # To restrict to a single cell line: adata = adata[adata.obs["cell_type"] == "ht29"].copy()
    adata.layers["counts"] = adata.X.copy()
    return _log_normalize_and_pca(adata)


def load_pbmc(path: str):
    """Read the PBMC cytokine ``.h5ad`` file and preprocess it.

    Adds the columns the rest of the notebook expects:
    ``obs["perturbation"]`` (a copy of ``obs["cytokine"]``) and
    ``obs["control"]`` (1 for PBS cells, 0 otherwise).
    Raw counts are kept in ``layers["counts"]``.

    Args:
        path: path to the ``.h5ad`` file.

    Returns:
        AnnData with log-normalised ``X``, raw ``layers["counts"]`` and ``obsm["X_pca"]``.
    """
    adata = sc.read_h5ad(path)
    adata.layers["counts"] = adata.X.copy()
    adata.obs['perturbation'] = adata.obs['cytokine']
    adata.obs['control'] = (adata.obs['cytokine'] == 'PBS').astype(int)
    return _log_normalize_and_pca(adata)


# Load the PBMC cytokine subset (raw counts are kept in adata.layers["counts"]).
print("loading dataset")
adata = load_pbmc(path="data/pmbc_c14_mono_all_donors_20k_subset.h5ad")
print(adata)

In [None]:
# Restore the pretrained autoencoder; the checkpoint stores its constructor sizes next to the weights.
# NOTE: torch.load unpickles the file, so only load trusted checkpoints.
ckpt = torch.load(
    "models/model_geodesic_c14_mono_all_donors_20k_dim32.pt",
    map_location=device,
)

model_ae = Autoencoder(**{key: ckpt[key] for key in ("input_dim", "latent_dim", "hidden_dim")})
model_ae = model_ae.to(device)
model_ae.load_state_dict(ckpt["model_state"])
model_ae.eval()

In [None]:
# Represent every cell by its autoencoder latent code; the flow model is trained in this space.
# Alternatives that were tried: raw log-normalised expression (adata.X) or adata.obsm["X_pca"].
with torch.no_grad():
    X = model_ae.encode(
        torch.from_numpy(adata.X.toarray().astype("float32")).to(device)
    )

### Create mapping of perturbation and condition to index

In [None]:
# adata.obs columns: perturbation label, condition (cell type), and the 0/1 control flag.
PERTURBATION_COL, CONDITION_COL, CONTROL_COL = 'perturbation', 'cell_type', 'control'

In [None]:
all_conditions = sorted(adata.obs[CONDITION_COL].unique())
all_perturbations = sorted(adata.obs[PERTURBATION_COL].unique())

# Label -> integer code maps shared by every split, so embedding rows mean the same thing in each fold.
global_condition_to_idx = dict(zip(all_conditions, range(len(all_conditions))))
global_perturb_to_idx = dict(zip(all_perturbations, range(len(all_perturbations))))

print(f"total # conditions: {len(global_condition_to_idx)}")
print(f"total # perturbs: {len(global_perturb_to_idx)}")

In [None]:
def prepare_tensors(dataset):
    """Gather (control, treated, perturbation, condition) tensors from a pair dataset.

    Each entry of ``dataset.pairs`` is read as
    ``(control_row, treated_row, perturbation_label, condition_label)``.
    The rows index ``dataset.X``. The labels are mapped to integer codes via
    ``dataset.perturb_to_idx`` and ``dataset.condition_to_idx``.

    Returns:
        (x0, x1, perturb_idx, condition_idx):
            - x0, x1: the ``dataset.X`` rows of the control and treated cells.
            - perturb_idx, condition_idx: int64 index tensors on the global ``device``.
    """
    pairs = list(dataset.pairs)
    ctrl_rows = [pair[0] for pair in pairs]
    treated_rows = [pair[1] for pair in pairs]
    pert_codes = [dataset.perturb_to_idx[pair[2]] for pair in pairs]
    cond_codes = [dataset.condition_to_idx[pair[3]] for pair in pairs]
    return (
        dataset.X[ctrl_rows],
        dataset.X[treated_rows],
        # Explicit int64: torch.tensor([]) would otherwise default to float32, which cannot be
        # used as an index tensor when a split ends up with no pairs.
        torch.tensor(pert_codes, dtype=torch.long, device=device),
        torch.tensor(cond_codes, dtype=torch.long, device=device),
    )

In [None]:
from sklearn.metrics.pairwise import rbf_kernel

def compute_mmd(x_true, x_pred, gamma=None, max_samples=2000, random_state=None):
    """Biased squared-MMD estimate between two sample sets under an RBF kernel.

    Each input is first subsampled without replacement to at most ``max_samples`` rows.
    This keeps the kernel matrices bounded in size.

    Args:
        x_true, x_pred: 2-D arrays of shape (n_samples, n_features); row counts may differ.
        gamma: RBF bandwidth forwarded to ``rbf_kernel`` (None -> sklearn's 1 / n_features).
        max_samples: per-input subsample cap; None disables subsampling. Default 2000.
        random_state: seed or ``np.random.Generator`` for the subsampling.
            None draws from NumPy's global RNG (``np.random``), as before,
            so ``np.random.seed`` still controls it.

    Returns:
        numpy float: mean(K_xx) + mean(K_yy) - 2 * mean(K_xy).

    Raises:
        ValueError: if either input has no rows.
    """
    if len(x_true) == 0 or len(x_pred) == 0:
        raise ValueError("compute_mmd needs at least one sample in each input")

    rng = np.random if random_state is None else np.random.default_rng(random_state)

    def _cap(x):
        # Random row subset (without replacement) when x has more than max_samples rows.
        if max_samples is not None and x.shape[0] > max_samples:
            return x[rng.choice(x.shape[0], max_samples, replace=False)]
        return x

    # Same draw order as before: x_true first, then x_pred.
    x_true = _cap(x_true)
    x_pred = _cap(x_pred)

    K_xx = rbf_kernel(x_true, x_true, gamma=gamma)
    K_yy = rbf_kernel(x_pred, x_pred, gamma=gamma)
    K_xy = rbf_kernel(x_true, x_pred, gamma=gamma)

    return K_xx.mean() + K_yy.mean() - 2 * K_xy.mean()

In [None]:
def compute_wd_per_gene(x_true, x_pred):
    """Per-column (per-gene) 1-D Wasserstein distance between two sample matrices.

    Args:
        x_true, x_pred: 2-D array-likes of shape (n_cells, n_genes).
            The cell counts may differ, but the gene counts must match.
            CPU torch tensors are accepted; they are converted once with ``np.asarray``.

    Returns:
        np.ndarray of shape (n_genes,) holding ``wasserstein_distance`` for each column.

    Raises:
        ValueError: if an input is not 2-D or the number of genes differs.
    """
    true_arr = np.asarray(x_true)
    pred_arr = np.asarray(x_pred)
    if true_arr.ndim != 2 or pred_arr.ndim != 2:
        raise ValueError(
            f"expected 2-D inputs, got shapes {true_arr.shape} and {pred_arr.shape}"
        )
    if true_arr.shape[1] != pred_arr.shape[1]:
        # Previously a mismatch either raised IndexError or silently ignored extra columns.
        raise ValueError(
            f"gene count mismatch: {true_arr.shape[1]} vs {pred_arr.shape[1]}"
        )
    return np.array(
        [wasserstein_distance(true_arr[:, g], pred_arr[:, g]) for g in range(true_arr.shape[1])]
    )

In [None]:
# Leave-one-donor-out evaluation.
# For each held-out donor, ~30% of its non-control perturbations are hidden from training.
# A conditional flow model is trained on the autoencoder latents. Predictions for the hidden
# perturbations are sampled as raw counts (ZINBSampler) and scored per gene with the
# Wasserstein distance.
HOLDOUT_PERT_FRAC = 0.30  # fraction of the held-out donor's perturbations hidden from training
RUN_ALL_FOLDS = False     # False: stop after the first donor (set True for full cross-validation)

all_donors = sorted(adata.obs['donor'].unique())
results = []

all_wds = []
all_deg_wds = []

# Control flag per cell; identical for every fold.
ctrl_mask = adata.obs[CONTROL_COL].to_numpy().astype(bool)

# Column settings shared by the train and test pair datasets.
pair_data_kwargs = dict(
    perturb_col=PERTURBATION_COL,
    control_col=CONTROL_COL,
    condition_col=CONDITION_COL,
    device=device,
    perturb_to_idx=global_perturb_to_idx,
    condition_to_idx=global_condition_to_idx,
)

for holdout_donor in all_donors:
    print(f"--- holding out {holdout_donor} ---")

    donor_mask = (adata.obs["donor"] == holdout_donor).to_numpy()
    donor_perts = list(adata.obs.loc[donor_mask & (~ctrl_mask), PERTURBATION_COL].unique())
    if not donor_perts:
        # rng.choice raises on an empty population; there is nothing to evaluate for this donor.
        print(f"no non-control perturbations for {holdout_donor}; skipping")
        continue

    rng = np.random.default_rng(42)
    heldout_perts = rng.choice(
        donor_perts,
        size=max(1, int(np.ceil(HOLDOUT_PERT_FRAC * len(donor_perts)))),
        replace=False,
    )

    print(f'holding out perturbations: {heldout_perts}')

    heldout_mask = adata.obs[PERTURBATION_COL].isin(heldout_perts).to_numpy()
    # Test: held-out donor's controls plus its held-out perturbations.
    # Train: everything except the held-out donor's held-out perturbations (its controls stay in).
    test_mask = donor_mask & (ctrl_mask | heldout_mask)
    train_mask = ~(donor_mask & heldout_mask)
    train_idx = np.where(train_mask)[0]
    test_idx = np.where(test_mask)[0]

    X_train = X[train_idx]
    X_test = X[test_idx]

    obs_train = adata.obs.iloc[train_idx].copy()
    obs_test = adata.obs.iloc[test_idx].copy()

    train_dataset = PerturbPairData(X_train, obs=obs_train, seed=42, **pair_data_kwargs)
    test_dataset = PerturbPairData(X_test, obs=obs_test, seed=43, **pair_data_kwargs)

    train_x0, train_x1, train_p, train_c = prepare_tensors(train_dataset)
    test_x0, test_x1, test_p, test_c = prepare_tensors(test_dataset)

    torch.manual_seed(42)  # reproducible weight init, batch order and t-sampling per fold
    flow_model = FlowModel(
        dim=X.shape[1],
        hidden_dim=HIDDEN_DIM,
        conditional_model=True,
        num_conditions=train_dataset.num_conditions(),
        num_perturbs=len(global_perturb_to_idx),  # perturbation codes come from the global map
        condition_embedding_dim=COND_EMB_DIMS,
        perturb_embedding_dim=PERT_EMB_DIMS,
    ).to(device)

    optimizer = torch.optim.AdamW(flow_model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    loss_fn = nn.HuberLoss(delta=1.0)

    train_losses, test_losses = [], []
    n_train = train_x0.shape[0]
    n_test = test_x0.shape[0]

    for epoch in tqdm(range(NUM_EPOCHS), desc="Training"):
        # Train: regress the velocity of the straight path x_t = (1 - t) x_0 + t x_1.
        perm = torch.randperm(n_train, device=device)
        flow_model.train()
        train_loss_sum_epoch = 0.0
        train_n_epoch = 0

        for i in range(0, n_train, BATCH_SIZE):
            idx = perm[i : i + BATCH_SIZE]

            x_0 = train_x0[idx]
            x_1 = train_x1[idx]

            # Optional endpoint noise; skipped entirely (no wasted randn) when SIGMA == 0.
            if SIGMA > 0:
                x_0 = x_0 + torch.randn_like(x_0) * SIGMA
                x_1 = x_1 + torch.randn_like(x_1) * SIGMA

            perturb = train_p[idx]
            ct = train_c[idx]
            t = torch.rand(len(x_1), 1, device=device)

            x_t = (1 - t) * x_0 + t * x_1
            dx_t = x_1 - x_0  # target velocity along the straight path

            pred = flow_model(x_t, t, perturb, ct)

            optimizer.zero_grad()
            loss = loss_fn(pred, dx_t)
            loss.backward()
            optimizer.step()
            train_loss_sum_epoch += loss.item() * x_1.size(0)
            train_n_epoch += x_1.size(0)

        avg_train_loss = train_loss_sum_epoch / max(train_n_epoch, 1)

        # Validation: same objective on the test pairs, no gradient.
        flow_model.eval()
        test_loss_sum_epoch = 0.0
        test_n_epoch = 0

        with torch.no_grad():
            for i in range(0, n_test, BATCH_SIZE):
                x_0 = test_x0[i: i + BATCH_SIZE]
                x_1 = test_x1[i: i + BATCH_SIZE]
                perturb = test_p[i: i + BATCH_SIZE]
                ct = test_c[i: i + BATCH_SIZE]
                t = torch.rand(len(x_1), 1, device=device)

                x_t = (1 - t) * x_0 + t * x_1
                dx_t = x_1 - x_0

                pred = flow_model(x_t, t, perturb, ct)
                loss = loss_fn(pred, dx_t)

                test_loss_sum_epoch += loss.item() * x_1.size(0)
                test_n_epoch += x_1.size(0)

        avg_test_loss = test_loss_sum_epoch / max(test_n_epoch, 1)

        train_losses.append(avg_train_loss)
        test_losses.append(avg_test_loss)

        if (epoch + 1) % 10 == 0:
            print(f"Epoch [{epoch + 1}/{NUM_EPOCHS}] | train loss: {avg_train_loss:.6f} | test loss: {avg_test_loss:.6f}")

    # Evaluate the held-out perturbations in raw-count space.
    sampler = ZINBSampler(model_ae, device=device)

    # NOTE(review): ctrl_mask covers control cells from *all* donors.
    # Predictions therefore start from every donor's controls, and the DE reference pools all
    # donors' PBS cells, while the treated cells come from the held-out donor only.
    # Confirm this is intended.
    # These control-side inputs do not depend on the perturbation, so they are built once per fold.
    X_ctrl = torch.from_numpy(adata[ctrl_mask].X.toarray().astype("float32")).to(device)
    library_sizes = torch.from_numpy(adata.layers['counts'][ctrl_mask].sum(1)).float().to(device)
    control_cell_types = adata.obs.loc[ctrl_mask, CONDITION_COL].values
    cond_indices = [global_condition_to_idx[ct] for ct in control_cell_types]
    cond_tensor = torch.tensor(cond_indices, device=device)
    with torch.no_grad():
        z_control = model_ae.encode(X_ctrl)

    wds = []
    deg_wds_list = []

    for pert in heldout_perts:
        try:
            treated_mask = (adata.obs[PERTURBATION_COL] == pert) & (adata.obs['donor'] == holdout_donor)
            X_actual_treated_raw = torch.from_numpy(
                adata[treated_mask].layers['counts'].toarray().astype("float32")
            ).to(device)

            # Wilcoxon DE of this perturbation vs PBS; keep genes with adj. p < 0.05 and |logFC| > 0.5.
            adata_deg = adata[ctrl_mask | treated_mask].copy()
            sc.tl.rank_genes_groups(
                adata_deg,
                groupby='cytokine',
                groups=[pert],
                reference='PBS',
                method='wilcoxon'
            )
            de_results_df = sc.get.rank_genes_groups_df(adata_deg, group=pert)
            top_degs = de_results_df[
                (de_results_df['pvals_adj'] < 0.05) &
                (np.abs(de_results_df['logfoldchanges']) > 0.5)
            ]
            pert_degs = top_degs['names'].values.tolist()

            with torch.no_grad():
                pert_idx = torch.full(
                    (X_ctrl.shape[0],), global_perturb_to_idx[pert], dtype=torch.long, device=device
                )
                z_pred = flow_model.integrate(
                    x0=z_control,
                    t0=0.0,
                    t1=1.0,
                    perturbations=pert_idx,
                    conditions=cond_tensor
                )
                X_pred_treated_raw = sampler.sample(z_pred, library_sizes)

            wd_gene = compute_wd_per_gene(X_pred_treated_raw.cpu(), X_actual_treated_raw.cpu())
            wds.append(wd_gene)

            deg_indices = [adata.var_names.get_loc(g) for g in pert_degs if g in adata.var_names]
            deg_wds_list.extend(wd_gene[deg_indices])
        except Exception as e:
            # Best-effort per perturbation, but report which perturbation failed and why.
            print(f"skipping perturbation {pert!r}: {type(e).__name__}: {e}")

    # Means over every (perturbation, gene) pair; nan when nothing could be evaluated.
    mean_wd = np.mean(wds) if wds else float("nan")
    print(f'mean wd across all genes and perts: {mean_wd:.4f}')
    all_wds.append(mean_wd)

    mean_deg_wd = np.mean(deg_wds_list) if deg_wds_list else float("nan")
    print(f'mean deg wd across all DEGs and perts: {mean_deg_wd:.4f}')
    all_deg_wds.append(mean_deg_wd)

    if not RUN_ALL_FOLDS:
        break

print(f"cv mean wd: {np.array(all_wds).mean():.4f}")
print(f"cv mean deg wd: {np.array(all_deg_wds).mean():.4f}")

In [None]:
# Train / test loss curves of the last trained fold (both lists are reset at the start of each fold).
loss_traces = [
    go.Scatter(y=values, name=label)
    for label, values in (("train", train_losses), ("test", test_losses))
]
fig_loss = go.Figure(loss_traces)
fig_loss.update_layout(xaxis_title="epoch", yaxis_title="loss", yaxis=dict(range=[0, None]))
fig_loss.show()

In [None]:
# wds = []
# deg_wds_list = []

# for pert in heldout_perts:
#     # treated_mask = (adata.obs['perturbation'] == pert) & (adata.obs['donor'] == holdout_donor)
#     treated_mask = (adata.obs['perturbation'] == pert) & (adata.obs['donor'] != holdout_donor)
#     # print(treated_mask.sum())

#     X_ctrl = torch.from_numpy(adata[ctrl_mask].X.toarray().astype("float32")).to(device)
#     X_actual_treated_raw = torch.from_numpy(adata[treated_mask].layers['counts'].toarray().astype("float32")).to(device)
#     library_sizes = torch.from_numpy(adata.layers['counts'][ctrl_mask].sum(1)).float().to(device)

#     adata_deg = adata[ctrl_mask | treated_mask].copy()

#     sc.tl.rank_genes_groups(
#         adata_deg,
#         groupby='cytokine',
#         groups=[pert],
#         reference='PBS',
#         method='wilcoxon'
#     )

#     de_results_df = sc.get.rank_genes_groups_df(
#         adata_deg,
#         group=pert
#     )

#     top_degs = de_results_df[
#         (de_results_df['pvals_adj'] < 0.05) & 
#         (np.abs(de_results_df['logfoldchanges']) > 0.5)
#     ]

#     # print(f"len(top_degs): {len(top_degs)}")

#     pert_degs = top_degs['names'].values.tolist()

#     # print(pert_degs)

#     with torch.no_grad():

#         z_control = model_ae.encode(X_ctrl)

#         pert_idx = torch.tensor([global_perturb_to_idx[pert]] * X_ctrl.shape[0], device=device)
#         control_cell_types = adata.obs.loc[ctrl_mask, 'cell_type'].values
#         cond_indices = [global_condition_to_idx[ct] for ct in control_cell_types]
#         cond_tensor = torch.tensor(cond_indices, device=device)

#         z_pred = flow_model.integrate(
#                 x0=z_control, 
#                 t0=0.0, 
#                 t1=1.0, 
#                 perturbations=pert_idx,
#                 conditions=cond_tensor
#             )

#         X_pred_treated_raw = sampler.sample(z_pred, library_sizes)

#         wd_gene = compute_wd_per_gene(X_pred_treated_raw.cpu(), X_actual_treated_raw.cpu())

#         # print(f'mean wd on test data: {wd_gene.mean()}')
    
#         wds.append(wd_gene)

#         deg_indices = [adata.var_names.get_loc(g) for g in pert_degs if g in adata.var_names]
#         wd_degs = wd_gene[deg_indices]
#         deg_wds_list.extend(wd_degs)

#         # print(wd_degs)

# np.array(deg_wds_list).mean()

In [None]:
# Re-generate predictions for the held-out perturbations, comparing against the treated cells of
# the other donors (donor != holdout_donor). Only the last perturbation's tensors survive this cell;
# the plotting cells below use them.
sampler = ZINBSampler(model_ae, device=device)

# The control-side inputs are the same for every perturbation, so build and encode them once.
X_ctrl = torch.from_numpy(adata[ctrl_mask].X.toarray().astype("float32")).to(device)
library_sizes = torch.from_numpy(adata.layers['counts'][ctrl_mask].sum(1)).float().to(device)
control_cell_types = adata.obs.loc[ctrl_mask, 'cell_type'].values
cond_indices = [global_condition_to_idx[ct] for ct in control_cell_types]
cond_tensor = torch.tensor(cond_indices, device=device)

with torch.no_grad():
    z_control = model_ae.encode(X_ctrl)

for pert in heldout_perts:
    # Alternative target: (adata.obs['donor'] == holdout_donor) & (adata.obs['perturbation'] == pert)
    treated_mask = (adata.obs['perturbation'] == pert) & (adata.obs['donor'] != holdout_donor)
    X_actual_treated_raw = torch.from_numpy(
        adata[treated_mask].layers['counts'].toarray().astype("float32")
    ).to(device)

    with torch.no_grad():
        pert_idx = torch.tensor([global_perturb_to_idx[pert]] * X_ctrl.shape[0], device=device)
        z_pred = flow_model.integrate(
            x0=z_control,
            t0=0.0,
            t1=1.0,
            perturbations=pert_idx,
            conditions=cond_tensor
        )
        X_pred_treated_raw = sampler.sample(z_pred, library_sizes)

In [None]:
np.argmax(wd_gene)  # gene index with the largest Wasserstein distance

In [None]:
# Upper-tail summary of the per-gene WD of the last evaluated perturbation.
# The printed label now matches the percentile that is actually computed
# (it previously said "98th" while computing 99.8).
WD_PERCENTILE = 99.8
wd_tail = np.percentile(wd_gene, WD_PERCENTILE)
print(f"{WD_PERCENTILE}th percentile WD: {wd_tail:.4f}")

In [None]:
# Distribution of per-gene Wasserstein distances.
# WD values are continuous, so discrete=True (unit-width integer bins) would lump them into
# one or two bars. The old 'true treated' label did not describe this data.
fig, ax = plt.subplots(figsize=(6, 4))
sns.histplot(wd_gene, stat='density', color='blue', alpha=0.4, element='step', ax=ax)
ax.set(
    title="per-gene Wasserstein distance (last evaluated perturbation)",
    xlabel="Wasserstein distance (raw counts)",
    ylabel="density",
)
plt.show()

In [None]:
# Ten genes with the highest mean raw count among the observed treated cells.
mean_expression = X_actual_treated_raw.mean(0)
top_values, top_indices = mean_expression.topk(10)
top_indices

In [None]:
# Observed vs predicted raw counts for one gene.
# 849 is a hard-coded index (presumably picked from wd_gene.argmax() / top_indices above -- TODO confirm).
gene_idx = 849
x_pred = X_pred_treated_raw[:, gene_idx].cpu().numpy()
x_true = X_actual_treated_raw[:, gene_idx].cpu().numpy()

fig, ax = plt.subplots(figsize=(6, 4))
hist_kwargs = dict(stat='density', discrete=True, alpha=0.4, element='step', ax=ax)
sns.histplot(x_true, color='blue', label='true treated', **hist_kwargs)
sns.histplot(x_pred, color='orange', label='pred treated', **hist_kwargs)
ax.set(
    title=f"raw counts for gene {adata.var_names[gene_idx]}",
    xlabel="raw count",
    ylabel="density",
)
ax.legend()
plt.show()

In [None]:
print(X_pred_treated_raw.shape, X_actual_treated_raw.shape, sep="\n")  # predicted vs observed (cells, genes)

In [None]:
np.mean(wd_gene)  # mean per-gene WD of the last evaluated perturbation

In [None]:
# flow_model.eval()

# x_0_list, x_1_list, x_pred_list, cond_list = [], [], [], []

# for split, loader in zip(["train", "test"], [train_loader, test_loader]):
# # for split, loader in zip(["train"], [train_loader]):

#     with torch.no_grad():
#         for batch in loader:
#             x_0 = batch["x_0"].to(device)
#             x_1 = batch["x_1"].to(device)
#             perturb = batch["perturb"].to(device)
#             conditions = batch["condition"].to(device)
#             x_pred = flow_model.integrate(x_0, 0.0, 1.0, perturbations=perturb, conditions=conditions)
#             x_0_list.append(x_0.detach().cpu())
#             x_1_list.append(x_1.detach().cpu())
#             x_pred_list.append(x_pred.detach().cpu())
#             cond_list.append(conditions.detach().cpu())

#     X_0 = torch.cat(x_0_list, dim=0).numpy()
#     X_1 = torch.cat(x_1_list, dim=0).numpy()
#     X_pred = torch.cat(x_pred_list)
#     cond_all = torch.cat(cond_list, dim=0).numpy()
#     cond_all = cond_all.squeeze().astype(str)

#     print(f"split: {split}, cond_all: {cond_all}")

#     X_all = np.vstack([X_0, X_1, X_pred])

#     pca = PCA(n_components=2, random_state=0)
#     pca.fit(np.vstack([X_0, X_1]))
#     Z_all = pca.transform(X_all)

#     Z0 = Z_all[:len(X_0)]
#     Z1 = Z_all[len(X_0):len(Z_all)-len(X_pred)]
#     X_pred_pca = Z_all[len(Z_all)-len(X_pred):]

#     df = pd.DataFrame({
#         "pc1": np.concatenate([Z0[:,0], Z1[:,0], X_pred_pca[:, 0]]),
#         "pc2": np.concatenate([Z0[:,1], Z1[:,1], X_pred_pca[:, 1]]),
#         "split": ["x_0 (control)"]*len(Z0) + ["x_1 (perturbed)"]*len(Z1) + ["x_1 (predicted)"]*len(X_pred_pca),
#         "cond": cond_all.tolist() + [None] * (len(Z1) + len(X_pred_pca)),
#     })

#     fig = px.scatter(df, x="pc1", y="pc2", color="split", opacity=0.6, title=f"{split} set in PCA space", width=600, height=500)
#     fig.show()

#     # fig = px.scatter(df, x="pc1", y="pc2", color="cond", opacity=0.6, title=f"{split} set in PCA space", width=600, height=500)
#     # fig.show()

In [None]:
# Condition (cell type) -> index map of the test dataset. It was constructed with
# condition_to_idx=global_condition_to_idx, so presumably identical to that map -- verify.
test_dataset.condition_to_idx

In [None]:
# The held-out donor's controls are in both splits by construction (train_mask only removes its
# held-out perturbations), so a plain train/test disjointness check is always False.
# The leakage check that matters is on the treated test cells.
test_treated_idx = np.where(test_mask & ~ctrl_mask)[0]
set(train_idx).isdisjoint(test_treated_idx)

In [None]:
adata.var_names[1011]  # gene name at column 1011

In [None]:
# Indices of the 10 genes with the highest mean log-normalised expression across all cells.
scores = np.asarray(adata.X.mean(axis=0)).ravel()
top_n_indices = scores.argsort()[::-1][:10]
top_n_indices

In [None]:
adata.obs.cytokine.value_counts()  # cells per cytokine (PBS = control)

In [None]:
# Perturbation labels held by the test dataset; the DEG loop below iterates over them.
test_dataset.perturbs

In [None]:
def _significant_degs(pair_adata, pert, max_genes=20):
    """Wilcoxon DE of `pert` vs PBS; genes with adj. p < 0.05 and |logFC| > 0.5, first `max_genes` rows."""
    sc.tl.rank_genes_groups(
        pair_adata,
        groupby='cytokine',
        groups=[pert],
        reference='PBS',
        method='wilcoxon'
    )
    de = sc.get.rank_genes_groups_df(pair_adata, group=pert)
    keep = (de['pvals_adj'] < 0.05) & (np.abs(de['logfoldchanges']) > 0.5)
    return de[keep].head(max_genes)


# For every non-control perturbation in the test set: list of (gene, logFC) significant DEGs vs PBS.
sig_effects = {}

for pert in (label for label in test_dataset.perturbs if label != 'PBS'):
    print("----" * 10)
    print(f"pert: {pert}")

    test_adata = adata[adata.obs['cytokine'].isin(['PBS', pert])].copy()
    top_degs = _significant_degs(test_adata, pert)

    print(top_degs[['names', 'pvals_adj', 'logfoldchanges']])

    sig_effects[pert] = list(zip(top_degs['names'], top_degs['logfoldchanges']))

In [None]:
# Per perturbation: list of (gene, logFC) for up to 20 significant DEGs vs PBS.
sig_effects

In [None]:
PERTURBATION = "IL-15"  # alternative: "IL-7"; not referenced by the visible cells below

In [None]:
np.nonzero(adata.var.index == 'IFIT3')  # column index (tuple form) of gene IFIT3

In [None]:
from scvi.distributions import ZeroInflatedNegativeBinomial
import matplotlib.pyplot as plt
from scipy.stats import wasserstein_distance

mmds = []
wds = []

# Restrict to the held-out donor's test-split cells from the CV loop.
# The donor was hard-coded as 'Donor1', which only matches test_idx when that donor is held out.
split_mask = np.zeros(adata.n_obs, dtype=bool)
split_mask[test_idx] = True
target_donors = [holdout_donor]
donor_mask = adata.obs['donor'].isin(target_donors)

control_mask = (adata.obs['cytokine'] == 'PBS') & split_mask & donor_mask
if not control_mask.any():
    raise ValueError(f"no PBS control cells for {target_donors} in the test split")

# Control-side inputs depend on neither the perturbation nor the gene, so build them once.
x_control_tensor = torch.from_numpy(adata[control_mask].X.toarray().astype("float32")).to(device)
control_cell_types = adata.obs.loc[control_mask, 'cell_type'].values
cond_indices = [train_dataset.condition_to_idx[ct] for ct in control_cell_types]
cond_tensor = torch.tensor(cond_indices, device=device)
batch_library_size = torch.from_numpy(
    adata.layers['counts'][control_mask.values].sum(axis=1)
).float().to(device).flatten()

with torch.no_grad():
    z_control = model_ae.encode(x_control_tensor)
    # Assumes decoder_dispersion holds per-gene log-dispersions (it is exponentiated and indexed by gene).
    gene_thetas = torch.exp(model_ae.decoder_dispersion)

for pert, effects in sig_effects.items():

    if len(effects) == 0:
        continue

    treated_mask = (adata.obs['cytokine'] == pert) & split_mask & donor_mask
    if not treated_mask.any():
        print(f"pert: {pert}: no treated cells in the test split for {target_donors}; skipping")
        continue

    x_treated_tensor = torch.from_numpy(adata[treated_mask].X.toarray().astype("float32")).to(device)

    pert_idx = train_dataset.perturb_to_idx[pert]
    batch_size = z_control.shape[0]
    # integrate() receives one perturbation index per control cell, as in the CV loop.
    pert_tensor = torch.tensor([pert_idx], device=device).repeat(batch_size)

    with torch.no_grad():
        z_treated = model_ae.encode(x_treated_tensor)
        z_pred = flow_model.integrate(
            x0=z_control,
            t0=0.0,
            t1=1.0,
            perturbations=pert_tensor,
            conditions=cond_tensor
        )
        # Decode once per perturbation; the per-gene ZINB parameters are sliced out below.
        h_decoded = model_ae.decoder(z_pred)
        pred_proportions = model_ae.decoder_mean(h_decoded)
        pred_dropout = model_ae.decoder_dropout(h_decoded)

    for sig_gene, fold_change in effects:

        print(f'pert: {pert}, gene: {sig_gene}, fold_change: {fold_change}')

        gene_idx = np.where(adata.var.index == sig_gene)[0][0]

        batch_proportions = pred_proportions[:, gene_idx]
        batch_dropout = pred_dropout[:, gene_idx]  # passed as zi_logits -- assumes logits, TODO confirm
        gene_theta = gene_thetas[gene_idx]

        # ZINB over raw counts: mean = predicted proportion * control library size.
        dist = ZeroInflatedNegativeBinomial(
            mu=batch_proportions * batch_library_size,
            theta=gene_theta,
            zi_logits=batch_dropout,
        )

        synthetic_treated_counts = dist.sample().cpu().numpy()
        real_control_counts = adata.layers['counts'][control_mask.values, gene_idx].toarray().flatten()
        real_treated_counts = adata.layers['counts'][treated_mask.values, gene_idx].toarray().flatten()

        plt.figure(figsize=(5, 3))

        sns.histplot(real_control_counts, label='observed control', fill=True, stat='density', discrete=True, element='step', color='gray')
        sns.histplot(real_treated_counts, label='observed treated', fill=True, stat='density', discrete=True, element='step', color='green')
        sns.histplot(synthetic_treated_counts, label='model prediction treated', fill=True, stat='density', discrete=True, element='step', color='blue', alpha=0.2)

        wd = wasserstein_distance(real_treated_counts, synthetic_treated_counts)
        print(f"Gene {sig_gene}: Wasserstein Distance = {wd:.4f}")
        wds.append(wd)

        plt.title(f"model fit for gene {adata.var_names[gene_idx]}")
        plt.xlabel("gene expression raw")
        plt.legend()
        plt.show()

    # Project control / treated / predicted latents onto the observed control->treated mean-shift
    # direction, standardised by the control projection.
    diff_vector = z_treated.mean(0) - z_control.mean(0)
    diff_vector = diff_vector / diff_vector.norm()
    proj_control = (z_control @ diff_vector).cpu().numpy()
    proj_treated = (z_treated @ diff_vector).cpu().numpy()
    proj_pred = (z_pred @ diff_vector).cpu().numpy()

    mu_c = proj_control.mean()
    std_c = proj_control.std()

    norm_control = (proj_control - mu_c) / std_c
    norm_treated = (proj_treated - mu_c) / std_c
    norm_pred = (proj_pred - mu_c) / std_c

    plt.figure(figsize=(5, 3))
    sns.kdeplot(norm_control, fill=True, label='control', color='gray')
    sns.kdeplot(norm_treated, fill=True, label='treated', color='green')
    sns.kdeplot(norm_pred, fill=True, label='predicted', color='blue')

    x_true = norm_treated.reshape(-1, 1)
    x_pred = norm_pred.reshape(-1, 1)
    mmd = compute_mmd(x_true, x_pred)
    print(f"mmd: {mmd:.4f}")

    mmds.append(mmd)

    plt.title(f"projection onto perturbation axis for {pert}")
    plt.xlabel("latent score AU")
    plt.legend()
    plt.show()


if mmds:
    print(f"avg mmd: {np.array(mmds).mean():.4f}")

print(f"avg wd: {np.array(wds).mean():.4f}")

In [None]:
import umap

# 2-D views (PCA and UMAP) of the learned perturbation embedding table
# (presumably one row per perturbation index of global_perturb_to_idx).
perturb_embedding = flow_model.perturb_emb.weight.detach().cpu().numpy()
pca_result = PCA(n_components=2).fit_transform(perturb_embedding)
reducer = umap.UMAP(n_components=2)
umap_result = reducer.fit_transform(perturb_embedding)

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].scatter(pca_result[:, 0], pca_result[:, 1])
ax[0].set_title("PCA")
ax[1].scatter(umap_result[:, 0], umap_result[:, 1])
ax[1].set_title("UMAP")  # each panel gets its own title (previously both were set on ax[0])
fig.suptitle("learned perturbation embeddings")
plt.show()