Isoflow

In [None]:
import sys, os
import scanpy as sc
import anndata
import torch
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from typing import List, Optional, Callable
import torch.nn as nn
import torch.nn.functional as F
from autoencoder_utils import NB_Autoencoder
from abc import ABC, abstractmethod
from typing import Optional, List, Type, Tuple, Dict
import math
import anndata as ad
import numpy as np
from matplotlib import pyplot as plt
import matplotlib.cm as cm
from matplotlib.axes._axes import Axes
import torch
import torch.distributions as D
from torch.func import vmap, jacrev
from tqdm import tqdm
import seaborn as sns
from sklearn.datasets import make_moons, make_circles
from pathlib import Path
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
from flow_utils import *

In [3]:
# import the dataset
TRAIN_DATA_PATH = Path("/dtu/blackhole/06/213542/paperdata/pbmc3k_train.h5ad")
TEST_DATA_PATH = Path("/dtu/blackhole/06/213542/paperdata/pbmc3k_test.h5ad")
RESULTS_DATA_PATH = Path("/dtu/blackhole/06/213542/paperdata/")
RESULTS_DATA_PATH.mkdir(parents = True, exist_ok = True)

adata = sc.read_h5ad(TRAIN_DATA_PATH)
adata_test = sc.read_h5ad(TEST_DATA_PATH)

#remove genes with low count
#sc.pp.filter_genes(adata, min_cells=20)
#sc.pp.filter_genes(adata_test, min_cells=20)

In [4]:
# -------------------------------
# Main training + encoding
# -------------------------------
    # --- Hyperparameters ---
input_file = adata
latent_dim = 50
hidden_dims = [512, 256]
batch_size = 512
epochs = 5            # short run to check
learning_rate = 1e-3

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

X = adata.X
if hasattr(X, "toarray"):
    X = X.toarray()
X = torch.tensor(X, dtype=torch.float32)

dataset = TensorDataset(X)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # --- Initialize model ---
num_genes = adata.n_vars
model = NB_Autoencoder(num_features=num_genes, latent_dim=latent_dim, hidden_dims=hidden_dims)
model = model.to(device)
model.train()

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # --- Training loop ---
for epoch in range(epochs):
    epoch_loss = 0
    for batch in dataloader:
        x_batch = batch[0].to(device)
        outputs = model(x_batch)
        loss_dict = model.loss_function(x_batch, outputs)
        loss = loss_dict["loss"]

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item() * x_batch.size(0)

    epoch_loss /= len(dataset)
    print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss:.3f}, NLL: {loss_dict['nll'].item():.3f}")

    # --- Save trained model ---
model_file = TRAIN_DATA_PATH.with_name(TRAIN_DATA_PATH.stem + "_nb_autoencoder.pt")
torch.save(model.state_dict(), model_file)
print(f"Trained model saved to {model_file}")

    # --- Encode all cells into latent space ---
model.eval()
all_z = []
with torch.no_grad():
    for batch in DataLoader(dataset, batch_size=batch_size):
        x_batch = batch[0].to(device)
        z = model(x_batch)["z"].cpu().numpy()
        all_z.append(z)
latent = np.concatenate(all_z, axis=0)

    # --- Save latent space to AnnData ---
adata.obsm["X_latent"] = latent
output_file = TRAIN_DATA_PATH.with_name(TRAIN_DATA_PATH.stem + "_with_latent.h5ad")
adata.write(output_file)
print(f"Latent space saved to {output_file}")



Epoch 1/5, Loss: 5837.239, NLL: 5400.990
Epoch 2/5, Loss: 4941.993, NLL: 4514.593
Epoch 3/5, Loss: 4193.137, NLL: 3856.201
Epoch 4/5, Loss: 3700.656, NLL: 3548.700
Epoch 5/5, Loss: 3452.399, NLL: 3385.501
Trained model saved to /dtu/blackhole/06/213542/paperdata/pbmc3k_train_nb_autoencoder.pt
Latent space saved to /dtu/blackhole/06/213542/paperdata/pbmc3k_train_with_latent.h5ad


In [None]:
# --- Load model ---
num_genes = adata_test.n_vars
model = NB_Autoencoder(num_features=num_genes, latent_dim=50, hidden_dims=[512, 256])
model.load_state_dict(torch.load(model_file))
model = model.to(device)
model.eval()

# Encode new cells
X_new = adata_test.X
if hasattr(X_new, "toarray"):
    X_new = X_new.toarray()
X_new = torch.tensor(X_new, dtype=torch.float32).to(device)

with torch.no_grad():
    z_new = model(X_new)["z"].cpu().numpy()

# Save to AnnData
adata_test.obsm["X_latent"] = z_new
new_cells = RESULTS_DATA_PATH /"new_cells_with_latent.h5ad"
adata_test.write(new_cells)

print(f"new cells are generated in  >>{new_cells}")


Flow model below

In [None]:
adata = ad.read_h5ad(output_file)

# Access latent representation
latent = adata.obsm["X_latent"]
# make it to a tensor and save in GPU
latent_tensor = torch.tensor(latent, dtype=torch.float32, device = device)
print("Shape of latent space:", latent.shape)
print(latent[:5])

In [None]:

dist = EmpiricalDistribution(latent_tensor)
samples = dist.sample(3)
logp = dist.log_density(samples)
print(logp)


In [None]:
# We want to go with Gaussian probability path, therefore we need to load functions for alpha and beta
alpha = LinearAlpha()
beta = LinearBeta()
path = GaussianConditionalProbabilityPath(
    p_data=emp_dist,
    alpha=alpha,
    beta=beta
)
print(path)

In [None]:
# now that we were able to construct a Gaussian probability path, we have to be able to make a conditional vector field
cvf_ode = ConditionalVectorFieldODE(path, z)

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

batch_size = 2110
num_epochs = 5000
learning_rate = 1e-3
latent_dim = latent_tensor.shape[1]  # e.g., 50

vf_model = NeuralVectorField(latent_dim=latent_dim).to(device)
optimizer = torch.optim.AdamW(vf_model.parameters(), lr=learning_rate)

# Initialize GaussianConditionalProbabilityPath and ConditionalVectorFieldODE
path = GaussianConditionalProbabilityPath(emp_dist, alpha, beta)  # define alpha, beta
cvf_ode = ConditionalVectorFieldODE(path, z=torch.zeros(1, latent_dim, device=device))

for epoch in range(num_epochs):
    # --- Sample conditioning variable z ---
    z = emp_dist.sample(batch_size).to(device)

    # --- Sample time ---
    t = torch.rand(batch_size, 1, device=device)

    # --- Sample x_t from conditional path ---
    with torch.no_grad():
        x = path.sample_conditional_path(z, t)
        u_target = path.conditional_vector_field(x, z, t)

    # --- Normalize target ---
    u_mean = u_target.mean(dim=0, keepdim=True)
    u_std = u_target.std(dim=0, keepdim=True) + 1e-6
    u_target_norm = (u_target - u_mean) / u_std

    # --- Forward pass ---
    v_pred = vf_model(x, z, t)

    # --- Loss ---
    loss = F.mse_loss(v_pred, u_target_norm)

    # --- Backprop ---
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(vf_model.parameters(), max_norm=1.0)
    optimizer.step()

    if epoch % 50 == 0:
        print(f"[{epoch}] Loss: {loss.item():.6f}")


In [None]:
# Wrap the trained neural network
learned_ode = LearnedVectorFieldODE(vf_model)

# Save the wrapper
torch.save(learned_ode, RESULTS_DATA_PATH / "learned_ode.pt")

In [None]:
# Number of samples and latent dimension
n_samples = 1000
latent_dim = latent_tensor.shape[1]

# Starting points (noise)
x = torch.randn(n_samples, latent_dim, device=device)

# Conditioning variable z
# Single vector, broadcast to all samples
z = torch.zeros(1, latent_dim, device=device)  # or z = emp_dist.sample(1)

# Wrap the trained neural network as an ODE
learned_ode = LearnedVectorFieldODE(vf_model)

# Create Euler simulator with the conditioning variable
simulator = EulerSimulator(learned_ode, z)

# Simulation parameters
t0, t1 = 0.0, 1.0
n_steps = 50
dt = (t1 - t0) / n_steps

# Store trajectory
trajectory = [x.clone()]
t = torch.full((n_samples, 1), t0, device=device)

# Euler integration
for _ in range(n_steps):
    x = simulator.step(x, t, dt)
    trajectory.append(x.clone())
    t = t + dt

# Final generated samples
generated_cells = trajectory[-1]
print(generated_cells.shape)  # (1000, latent_dim)
torch.save(generated_cells, RESULTS_DATA_PATH / "generated_latent.pt")
