In [2]:

import os
from typing import List, Optional, Callable

import anndata as ad
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

In [3]:
# ---- Configuration: compute device, input/output paths, RNG seed ----

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Data / model locations. Override with the PAPERDATA_DIR and FLOW_MODEL_DIR
# environment variables instead of editing hard-coded absolute paths.
DATA_DIR = os.environ.get("PAPERDATA_DIR", "/dtu/blackhole/06/213542/paperdata")
MODEL_DIR = os.environ.get("FLOW_MODEL_DIR", "/dtu/blackhole/06/213542")

train_latent_path = os.path.join(DATA_DIR, "pbmc3k_train_with_latent.h5ad")
flow_model_save_path = os.path.join(MODEL_DIR, "flow_model.pt")

# Seed torch's global RNG so noise draws, t ~ U(0, 1), conditioning dropout and
# DataLoader shuffling are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

Using device: cpu


In [4]:
# ---- Load AnnData with latent space ----
adata = ad.read_h5ad(train_latent_path)

assert "X_latent" in adata.obsm, "Latent representation X_latent not found in .obsm"
z1_np = np.asarray(adata.obsm["X_latent"], dtype=np.float32)   # [N, latent_dim] (latent_dim=50 in the run below)
assert z1_np.ndim == 2 and z1_np.shape[0] == adata.n_obs, \
    f"X_latent must be 2-D with one row per cell, got shape {z1_np.shape}"

# cell type labels
assert "cell_type" in adata.obs.columns, "cell_type column not found in adata.obs"
celltypes = pd.Categorical(adata.obs["cell_type"])
# pandas encodes missing labels as code -1, which F.one_hot rejects with an
# opaque error further down; fail here with a readable message instead.
assert (celltypes.codes >= 0).all(), "cell_type contains missing values"
y_idx = torch.tensor(celltypes.codes, dtype=torch.long)   # integer labels
num_celltypes = len(celltypes.categories)

# library size (n_counts) -> log(1 + n_counts), used as a conditioning scalar
assert "n_counts" in adata.obs.columns, "n_counts column not found in adata.obs"
libsize = adata.obs["n_counts"].values.astype(np.float32)
assert np.isfinite(libsize).all() and (libsize >= 0).all(), \
    "n_counts must be finite and non-negative"
log_l_np = np.log(libsize + 1.0)[:, None]    # shape [N, 1]

# ---- Convert to torch tensors ----
z1 = torch.tensor(z1_np, dtype=torch.float32)
log_l = torch.tensor(log_l_np, dtype=torch.float32)

# one-hot for cell types
y_onehot = F.one_hot(y_idx, num_classes=num_celltypes).float()

print("z1 shape:", z1.shape)
print("y_onehot shape:", y_onehot.shape)
print("log_l shape:", log_l.shape)
print("Num cell types:", num_celltypes)

z1 shape: torch.Size([2110, 50])
y_onehot shape: torch.Size([2110, 8])
log_l shape: torch.Size([2110, 1])
Num cell types: 8


In [5]:
class LatentFlowDataset(Dataset):
    """Map-style dataset of (latent, one-hot cell type, log library size) triples.

    The three inputs are indexed in lock-step along their first axis, so they
    must all have the same length.
    """

    def __init__(self, z, y_onehot, log_l):
        assert len(z) == len(y_onehot) == len(log_l)
        self.z, self.y, self.log_l = z, y_onehot, log_l

    def __len__(self):
        n_rows = self.z.shape[0]
        return n_rows

    def __getitem__(self, idx):
        latent = self.z[idx]
        label = self.y[idx]
        size_factor = self.log_l[idx]
        return latent, label, size_factor

# Mini-batching: shuffle each epoch and discard the trailing partial batch
# (drop_last=True), so every training step sees exactly `batch_size` cells.
batch_size = 512
loader_opts = dict(batch_size=batch_size, shuffle=True, drop_last=True)

dataset = LatentFlowDataset(z1, y_onehot, log_l)
dataloader = DataLoader(dataset, **loader_opts)

# (number of cells, number of full batches per epoch)
len(dataset), len(dataloader)

(2110, 4)

In [6]:
# Sanity check on the dataset / DataLoader built in the previous cell:
# (number of cells, number of full batches per epoch with drop_last=True).
# LatentFlowDataset and the DataLoader are defined once, in the previous cell;
# re-declaring them here would silently shadow that definition.
len(dataset), len(dataloader)

(2110, 4)

In [7]:
# -------------------------------
# MLP (copied from autoencoder)
# -------------------------------
class MLP(nn.Module):
    """Feed-forward stack of [Linear -> (BatchNorm1d) -> activation -> (Dropout)] blocks.

    `dims` lists layer widths from input to output, e.g. [in, h1, h2, out]. Every
    transition except the last becomes a hidden block wrapped in its own
    nn.Sequential; the last transition is a bare Linear layer, optionally followed
    by a final activation. The module layout (and therefore state_dict keys) is the
    same as the autoencoder's MLP for the default arguments.

    Args:
        dims: layer widths; must contain at least two entries (input and output).
        batch_norm: insert BatchNorm1d after each hidden Linear.
        dropout: insert Dropout after each hidden activation.
        dropout_p: dropout probability.
        activation: zero-argument factory for the hidden activation module
            (e.g. nn.ELU); None means the hidden blocks have no non-linearity.
        final_activation: None, "tanh" or "sigmoid".

    Raises:
        ValueError: if `dims` has fewer than two entries or `final_activation`
            is not one of the supported names.
    """

    # Supported values of `final_activation` -> module factory.
    _FINAL_ACTIVATIONS = {"tanh": nn.Tanh, "sigmoid": nn.Sigmoid}

    def __init__(self, 
                 dims: List[int],
                 batch_norm: bool = True, 
                 dropout: bool = True, 
                 dropout_p: float = 0.1, 
                 activation: Optional[Callable] = nn.ELU, 
                 final_activation: Optional[str] = None):
        super().__init__()
        if len(dims) < 2:
            raise ValueError(f"MLP needs at least input and output widths, got dims={dims!r}")
        if final_activation is not None and final_activation not in self._FINAL_ACTIVATIONS:
            # Previously an unknown name silently produced a linear output layer.
            raise ValueError(
                f"Unsupported final_activation {final_activation!r}; "
                f"expected one of {sorted(self._FINAL_ACTIVATIONS)} or None"
            )
        self.dims = dims

        layers = []
        # One hidden block per consecutive (in, out) pair, excluding the output layer.
        for in_dim, out_dim in zip(dims[:-2], dims[1:-1]):
            block = [nn.Linear(in_dim, out_dim)]
            if batch_norm:
                block.append(nn.BatchNorm1d(out_dim))
            if activation is not None:
                block.append(activation())
            if dropout:
                block.append(nn.Dropout(dropout_p))
            layers.append(nn.Sequential(*block))
        layers.append(nn.Linear(dims[-2], dims[-1]))
        self.net = nn.Sequential(*layers)

        self.final_activation = (
            self._FINAL_ACTIVATIONS[final_activation]() if final_activation is not None else None
        )

    def forward(self, x):
        x = self.net(x)
        return x if self.final_activation is None else self.final_activation(x)

# -------------------------------
# Vector Field Network
# -------------------------------
class VectorField(nn.Module):
    """Conditional velocity field v_theta(z_t, t, cond) for flow matching in latent space.

    The scalar time is lifted to a learned embedding by a small MLP. The embedding is
    concatenated with the current latent state and the conditioning vector and then
    mapped to a velocity with the same width as the latent.

    Args:
        latent_dim: width of z_t and of the returned velocity.
        cond_dim: width of the conditioning vector.
        hidden_dim: width of the two hidden layers of the main MLP.
        time_emb_dim: width of the time embedding. The default of 64 reproduces the
            original architecture, so previously saved state_dicts still load.
    """

    def __init__(self, latent_dim: int, cond_dim: int, hidden_dim: int = 256,
                 time_emb_dim: int = 64):
        super().__init__()
        # time embedding: scalar t -> time_emb_dim features
        self.time_mlp = nn.Sequential(
            nn.Linear(1, time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.SiLU(),
        )
        # main network on [z_t | time embedding | cond]
        self.net = MLP(
            dims=[latent_dim + time_emb_dim + cond_dim, hidden_dim, hidden_dim, latent_dim],
            batch_norm=True,
            dropout=True,
            dropout_p=0.1,
            activation=nn.ELU,
        )

    def forward(self, zt, t, cond):
        """
        zt:   [B, latent_dim]
        t:    [B, 1], [B], or a scalar / 0-d tensor (broadcast to every row)
        cond: [B, cond_dim]
        returns: velocity [B, latent_dim]
        """
        # Normalise t to a [B, 1] tensor matching zt's dtype/device. This is a no-op
        # for the [B, 1] float tensors used by the training and sampling code.
        t = torch.as_tensor(t, dtype=zt.dtype, device=zt.device)
        t = t.reshape(-1, 1).expand(zt.shape[0], 1)
        t_emb = self.time_mlp(t)           # [B, time_emb_dim]
        x = torch.cat([zt, t_emb, cond], dim=1)
        return self.net(x)

In [8]:
# Conditioning vector layout: [one-hot cell type | log library size]
cond_dim = num_celltypes + 1
# Latent width of the autoencoder embedding (50 in the run shown below)
latent_dim = z1.shape[1]

vf = VectorField(latent_dim=latent_dim, cond_dim=cond_dim)
vf = vf.to(device)
optimizer = torch.optim.AdamW(params=vf.parameters(), lr=1e-4)
print(vf)

VectorField(
  (time_mlp): Sequential(
    (0): Linear(in_features=1, out_features=64, bias=True)
    (1): SiLU()
    (2): Linear(in_features=64, out_features=64, bias=True)
    (3): SiLU()
  )
  (net): MLP(
    (net): Sequential(
      (0): Sequential(
        (0): Linear(in_features=123, out_features=256, bias=True)
        (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ELU(alpha=1.0)
        (3): Dropout(p=0.1, inplace=False)
      )
      (1): Sequential(
        (0): Linear(in_features=256, out_features=256, bias=True)
        (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ELU(alpha=1.0)
        (3): Dropout(p=0.1, inplace=False)
      )
      (2): Linear(in_features=256, out_features=50, bias=True)
    )
  )
)


In [9]:
import math

num_epochs = 50             # adjust as you like
drop_prob = 0.15            # probability to drop conditioning (for unconditional training)


def _flow_matching_batch_loss(z_data, onehot, log_size):
    """Conditional flow-matching loss for one mini-batch on the straight-line path.

    Draws z0 ~ N(0, I) and t ~ U(0, 1), forms z_t = (1 - t) * z0 + t * z1, and
    regresses the network output onto the constant path velocity z1 - z0. Each row's
    whole conditioning vector is zeroed with probability `drop_prob`, so the same
    network also learns the unconditional field used for classifier-free guidance.
    """
    rows = z_data.shape[0]
    noise = torch.randn_like(z_data)                      # z0
    tau = torch.rand(rows, 1, device=device)              # t
    z_tau = (1.0 - tau) * noise + tau * z_data            # point on the straight path
    target = z_data - noise                               # velocity, independent of t

    cond = torch.cat([onehot, log_size], dim=1)           # [B, num_celltypes + 1]
    keep = (torch.rand(rows, 1, device=device) > drop_prob).float()
    prediction = vf(z_tau, tau, cond * keep)
    return ((prediction - target) ** 2).mean()


vf.train()

for epoch in range(1, num_epochs + 1):
    batch_losses = []
    for z1_batch, y_batch, logl_batch in dataloader:
        loss = _flow_matching_batch_loss(
            z1_batch.to(device), y_batch.to(device), logl_batch.to(device)
        )
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())

    avg_loss = sum(batch_losses) / max(1, len(batch_losses))
    print(f"Epoch {epoch}/{num_epochs} - Flow Matching Loss: {avg_loss:.6f}")

# save trained flow model
torch.save(vf.state_dict(), flow_model_save_path)
print(f"Flow model saved to {flow_model_save_path}")

Epoch 1/50 - Flow Matching Loss: 3.946148
Epoch 2/50 - Flow Matching Loss: 3.670063
Epoch 3/50 - Flow Matching Loss: 3.469714
Epoch 4/50 - Flow Matching Loss: 3.243708
Epoch 5/50 - Flow Matching Loss: 3.082470
Epoch 6/50 - Flow Matching Loss: 2.934929
Epoch 7/50 - Flow Matching Loss: 2.761802
Epoch 8/50 - Flow Matching Loss: 2.688230
Epoch 9/50 - Flow Matching Loss: 2.595561
Epoch 10/50 - Flow Matching Loss: 2.459846
Epoch 11/50 - Flow Matching Loss: 2.386738
Epoch 12/50 - Flow Matching Loss: 2.354287
Epoch 13/50 - Flow Matching Loss: 2.295142
Epoch 14/50 - Flow Matching Loss: 2.204769
Epoch 15/50 - Flow Matching Loss: 2.211609
Epoch 16/50 - Flow Matching Loss: 2.138831
Epoch 17/50 - Flow Matching Loss: 2.070782
Epoch 18/50 - Flow Matching Loss: 2.072066
Epoch 19/50 - Flow Matching Loss: 2.060832
Epoch 20/50 - Flow Matching Loss: 2.031931
Epoch 21/50 - Flow Matching Loss: 1.980508
Epoch 22/50 - Flow Matching Loss: 1.954944
Epoch 23/50 - Flow Matching Loss: 1.927908
Epoch 24/50 - Flow M

In [None]:
import torch
import torch.nn.functional as F
import numpy as np

# --- Utility: build conditioning batches ---

def build_cond_batch(celltype_idx: int,
                     log_l_value: float,
                     n_samples: int,
                     num_celltypes: int,
                     device: torch.device):
    """
    Return an [n_samples, num_celltypes + 1] conditioning matrix in which every row
    is the one-hot code of `celltype_idx` followed by the scalar `log_l_value`.
    """
    labels = torch.full((n_samples,), celltype_idx, dtype=torch.long, device=device)
    pieces = [
        F.one_hot(labels, num_classes=num_celltypes).float(),              # cell-type block
        torch.full((n_samples, 1), float(log_l_value), device=device),     # log library size column
    ]
    return torch.cat(pieces, dim=1)


@torch.no_grad()
def sample_latent_from_flow(
    vf: "VectorField",
    cond: torch.Tensor,
    latent_dim: int,
    n_steps: int = 100,
    guidance_scale: float = 1.5,
    device: torch.device = torch.device("cpu"),
    generator: Optional[torch.Generator] = None,
):
    """
    Integrate dz/dt = v(z, t, cond) from t=0 to t=1 with `n_steps` explicit Euler steps,
    starting from standard-normal noise and using classifier-free guidance:
        v = v_uncond + guidance_scale * (v_cond - v_uncond),
    where v_uncond is the field evaluated with an all-zero conditioning vector.

    cond: [B, cond_dim], held fixed for all steps
    generator: optional torch.Generator (on `device`) for reproducible initial noise
    returns: z1_samples [B, latent_dim] on the CPU

    The model is put in eval mode while sampling and its previous train/eval
    mode is restored afterwards, so sampling does not leak into later training.

    Raises:
        ValueError: if n_steps < 1.
    """
    if n_steps < 1:
        raise ValueError(f"n_steps must be a positive integer, got {n_steps}")

    was_training = vf.training
    vf.eval()
    try:
        B = cond.shape[0]
        cond = cond.to(device)
        uncond = torch.zeros_like(cond)   # "null" conditioning, as in CFG training

        # initial noise (t=0)
        z = torch.randn(B, latent_dim, device=device, generator=generator)

        dt = 1.0 / n_steps
        t = torch.zeros(B, 1, device=device)

        for k in range(n_steps):
            t.fill_(k * dt)
            v_uncond = vf(z, t, uncond)
            v_cond = vf(z, t, cond)
            v = v_uncond + guidance_scale * (v_cond - v_uncond)
            z = z + v * dt
    finally:
        vf.train(was_training)

    return z.cpu()

In [10]:
# Example: draw synthetic latent cells for a single chosen cell type.
target_celltype_name = "CD4 T cells"
n_samples = 1000

# Resolve the cell-type name to its integer code in the training `celltypes` categorical
categories = list(celltypes.categories)
assert target_celltype_name in categories, f"{target_celltype_name} not in {categories}"
celltype_idx = categories.index(target_celltype_name)

# Condition on the median log(1 + n_counts) observed for this cell type in training
mask_ct = (celltypes == target_celltype_name)
counts_ct = adata.obs.loc[mask_ct, "n_counts"].values.astype(np.float32)
log_l_ct = np.log(counts_ct + 1.0)
log_l_value = float(np.median(log_l_ct))

print(f"Using cell type '{target_celltype_name}' (idx={celltype_idx}) "
      f"with median log_l={log_l_value:.3f}")

# Same conditioning row repeated n_samples times, then integrate the flow
cond_batch = build_cond_batch(celltype_idx, log_l_value, n_samples, num_celltypes, device)

sampler_opts = dict(
    n_steps=100,          # Euler steps; tunable
    guidance_scale=1.5,   # CFG strength (>= 1); tunable
    device=device,
)
z1_samples = sample_latent_from_flow(
    vf=vf,
    cond=cond_batch,
    latent_dim=latent_dim,
    **sampler_opts,
)

print("Generated latent samples:", z1_samples.shape)  # [n_samples, latent_dim]

Using cell type 'CD4 T cells' (idx=0) with median log_l=7.747


NameError: name 'build_cond_batch' is not defined