In [35]:
# this the cleaner version of the flow matching model
# import all packages and data
# the data comes from the encoder in 50 dimensional format
from abc import ABC, abstractmethod
from typing import Optional, List, Type, Tuple, Dict
import math
import anndata as ad
import numpy as np
from matplotlib import pyplot as plt
import matplotlib.cm as cm
from matplotlib.axes._axes import Axes
import torch
import torch.distributions as D
from torch.func import vmap, jacrev
from tqdm import tqdm
import seaborn as sns
from sklearn.datasets import make_moons, make_circles
import torch.nn.functional as F
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [22]:
# Load the encoded data from the autoencoder
input_file_path = "/dtu/blackhole/06/213542/paperdata/pbmc3k_train_with_latent.h5ad"
adata = ad.read_h5ad(input_file_path)

# Access latent representation
latent = adata.obsm["X_latent"]
# make it to a tensor and save in GPU
latent_tensor = torch.tensor(latent, dtype=torch.float32, device = device)
print("Shape of latent space:", latent.shape)
print(latent[:5])

Shape of latent space: (2110, 50)
[[ 7.2119427e-01  4.5030427e+00  4.2068303e-02 -1.2502990e+00
  -4.2017937e+00 -2.4717948e+00 -1.5488725e+00 -1.8064263e+00
   8.0419064e-02  1.3857546e+00  1.5554056e+00 -3.6009327e-01
  -3.2110305e+00 -6.2240481e-01 -1.2413149e+00 -1.1609215e+00
   3.6916310e-01  1.0324364e+00  4.2429629e-01 -1.3808545e+00
  -2.6205373e+00  1.0236942e+00  2.3036060e+00  2.9912877e+00
   1.8578960e+00 -1.3467599e+00 -2.4267607e+00  4.2268958e+00
   6.7074037e-01 -3.4374635e+00  1.6940813e+00  1.2496483e+00
  -1.0571158e+00  2.6391823e+00 -3.2057946e+00  6.9281155e-01
  -1.6405296e+00 -4.2034798e+00  1.0298922e+00  3.6848754e-01
   1.9148359e+00  5.8723283e-01 -8.6443865e-01 -4.0662823e+00
  -1.6493044e+00  8.4716082e-04 -4.6969833e+00 -4.1647023e-01
  -1.7651422e+00  3.6352050e-01]
 [ 1.0745261e+00  4.1269851e+00  6.6678226e-02 -1.6734134e+00
  -5.4810619e+00 -2.8954308e+00 -1.4715413e+00 -1.3079782e+00
  -2.7325594e-01  1.6386645e+00  1.7486247e+00  8.0161452e-01
  -

In [23]:
# This is a way of encoding our data for empirical data
class EmpiricalDistribution(torch.nn.Module):
    def __init__(
        self,
        data: torch.Tensor,
        bandwidth: Optional[float] = None,
        compute_log_density: bool = True,
    ):
        super().__init__()
        assert data.dim() == 2, "data must be shape (N, D)"
        data = data.contiguous()
        
        self.register_buffer("data", data)   # (N, D)
        self.n = data.shape[0]
        self.data_dim = data.shape[1]        # <-- renamed attribute
        self.compute_log_density_flag = compute_log_density

        # Bandwidth estimation
        if bandwidth is None:
            std = torch.std(data, dim=0).mean().item()
            factor = (4.0 / (self.data_dim + 2.0)) ** (1.0 / (self.data_dim + 4.0))
            bw = factor * (self.n ** (-1.0 / (self.data_dim + 4.0))) * (std + 1e-6)
            self.bandwidth = torch.tensor(float(bw), device=self.data.device)
        else:
            self.bandwidth = torch.tensor(float(bandwidth), device=self.data.device)

        self._log_const = -0.5 * self.data_dim * math.log(2.0 * math.pi) - self.data_dim * torch.log(self.bandwidth).item()

    @property
    def dim(self):
        return self.data_dim
    def sample(self, num_samples: int) -> torch.Tensor:
        idx = torch.randint(0, self.n, (num_samples,), device=self.data.device)
        return self.data[idx]

    def log_density(self, x: torch.Tensor) -> torch.Tensor:
        if not self.compute_log_density_flag:
            raise RuntimeError("log_density disabled (compute_log_density=False).")

        assert x.dim() == 2 and x.shape[1] == self.data_dim

        x = x.to(self.data.device)
        x_norm2 = (x ** 2).sum(dim=1, keepdim=True)
        data_norm2 = (self.data ** 2).sum(dim=1).unsqueeze(0)
        cross = x @ self.data.t()
        d2 = x_norm2 + data_norm2 - 2.0 * cross

        sigma2 = (self.bandwidth ** 2).item()
        exponents = -0.5 * d2 / (sigma2 + 1e-12)
        lse = torch.logsumexp(exponents, dim=1, keepdim=True)

        log_prob = math.log(1.0 / self.n) + lse + self._log_const
        return log_prob

In [24]:
# lets test if the empirical distribution class actually works
# the data has to be a torch tensor

dist = EmpiricalDistribution(latent_tensor)
samples = dist.sample(3)
logp = dist.log_density(samples)
print(logp)

# it seems to work

tensor([[-59.5364],
        [-58.2875],
        [-57.4158]], device='cuda:0')


In [25]:
# we have to have a class that can draw from a Gaussian distribution

class Gaussian(torch.nn.Module):
    """
    Multivariate Gaussian distribution
    """
    def __init__(self, mean: torch.Tensor, cov: torch.Tensor):
        """
        mean: shape (dim,)
        cov: shape (dim,dim)
        """
        super().__init__()
        self.register_buffer("mean", mean)
        self.register_buffer("cov", cov)

    @property
    def dim(self) -> int:
        return self.mean.shape[0]

    @property
    def distribution(self):
        return D.MultivariateNormal(self.mean, self.cov, validate_args=False)

    def sample(self, num_samples) -> torch.Tensor:
        return self.distribution.sample((num_samples,))
        
    def log_density(self, x: torch.Tensor):
        return self.distribution.log_prob(x).view(-1, 1)

    @classmethod
    def isotropic(cls, dim: int, std: float) -> "Gaussian":
        mean = torch.zeros(dim)
        cov = torch.eye(dim) * std ** 2
        return cls(mean, cov)

In [26]:
# We want to go with Gaussian probability path, therefore we need to load functions for alpha and beta
class LinearAlpha():
    """Implements alpha_t = t"""
    
    def __call__(self, t: torch.Tensor) -> torch.Tensor:
        return t  # linear in time

    def dt(self, t: torch.Tensor) -> torch.Tensor:
        return torch.ones_like(t)  # derivative of t is 1


class LinearBeta():
    """Implements beta_t = 1 - t"""
    
    def __call__(self, t: torch.Tensor) -> torch.Tensor:
        return 1 - t

    def dt(self, t: torch.Tensor) -> torch.Tensor:
        return -torch.ones_like(t)  # derivative of 1 - t is -1



In [27]:
class GaussianConditionalProbabilityPath():
    def __init__(self, p_data, alpha, beta):
        self.p_data = p_data 
        p_simple = Gaussian.isotropic(p_data.dim, 1.0)
        self.alpha = alpha
        self.beta = beta

    def sample_conditioning_variable(self, num_samples: int) -> torch.Tensor:
        """
        Samples the conditioning variable z ~ p_data(x)
        Args:
            - num_samples: the number of samples
        Returns:
            - z: samples from p(z), (num_samples, dim)
        """
        return self.p_data.sample(num_samples)
    
    def sample_conditional_path(self, z: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Samples from the conditional distribution p_t(x|z) = N(alpha_t * z, beta_t**2 * I_d)
        Args:
            - z: conditioning variable (num_samples, dim)
            - t: time (num_samples, 1)
        Returns:
            - x: samples from p_t(x|z), (num_samples, dim)
        """
        return self.alpha(t) * z + self.beta(t) * torch.randn_like(z)
        
    def conditional_vector_field(self, x: torch.Tensor, z: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Evaluates the conditional vector field u_t(x|z)
        Note: Only defined on t in [0,1)
        Args:
            - x: position variable (num_samples, dim)
            - z: conditioning variable (num_samples, dim)
            - t: time (num_samples, 1)
        Returns:
            - conditional_vector_field: conditional vector field (num_samples, dim)
        """ 
        alpha_t = self.alpha(t) # (num_samples, 1)
        beta_t = self.beta(t) # (num_samples, 1)
        dt_alpha_t = self.alpha.dt(t) # (num_samples, 1)
        dt_beta_t = self.beta.dt(t) # (num_samples, 1)

        return (dt_alpha_t - dt_beta_t / beta_t * alpha_t) * z + dt_beta_t / beta_t * x

    def conditional_score(self, x: torch.Tensor, z: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Evaluates the conditional score of p_t(x|z) = N(alpha_t * z, beta_t**2 * I_d)
        Note: Only defined on t in [0,1)
        Args:
            - x: position variable (num_samples, dim)
            - z: conditioning variable (num_samples, dim)
            - t: time (num_samples, 1)
        Returns:
        - conditional_score: conditional score (num_samples, dim)
        """ 
        alpha_t = self.alpha(t)
        beta_t = self.beta(t)
        return (z * alpha_t - x) / beta_t ** 2

In [28]:
emp_dist = dist
alpha = LinearAlpha()
beta = LinearBeta()
path = GaussianConditionalProbabilityPath(
    p_data=emp_dist,
    alpha=alpha,
    beta=beta
)
print(path)

<__main__.GaussianConditionalProbabilityPath object at 0x7f5dbeed5f10>


In [29]:
# now that we were able to construct a Gaussian probability path, we have to be able to make a conditional vector field

class ConditionalVectorFieldODE():
    def __init__(self, path, z: torch.Tensor):
        """
        Args:
        - path: the ConditionalProbabilityPath object to which this vector field corresponds
        - z: the conditioning variable, (1, dim)
        """
        super().__init__()
        self.path = path
        self.z = z

    def drift_coefficient(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Returns the conditional vector field u_t(x|z)
        Args:
            - x: state at time t, shape (bs, dim)
            - t: time, shape (bs,.)
        Returns:
            - u_t(x|z): shape (batch_size, dim)
        """
        bs = x.shape[0]
        z = self.z.expand(bs, *self.z.shape[1:])
        return self.path.conditional_vector_field(x,z,t)

In [30]:
# now we somehow want to model the marginal vector field from the conditonal vector field
# for that we will use eulers:
class EulerSimulator():
    def __init__(self, ode, z: torch.Tensor):
        self.ode = ode
        self.z = z

    def step(self, xt: torch.Tensor, t: torch.Tensor, h: float):
        
        # Expand z to match batch size
        if self.z.shape[0] == 1:
            z_exp = self.z.expand(xt.shape[0], -1)
        else:
            z_exp = self.z
        dx = self.ode.drift_coefficient(xt, t, z_exp)
        return xt + dx * h



In [31]:
import math
import torch.nn as nn

class TimeEmbedder(nn.Module):
    def __init__(self, embed_dim=32, max_freq=1e4):
        super().__init__()
        self.embed_dim = embed_dim
        self.max_freq = max_freq
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.SiLU(),
            nn.Linear(embed_dim, embed_dim),
            nn.SiLU()
        )

    def forward(self, t):
        freqs = torch.exp(torch.linspace(0, math.log(self.max_freq), self.embed_dim // 2, device=t.device))
        args = t * freqs
        emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
        return self.mlp(emb)

class ResNetBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.block = nn.Sequential(
            nn.Linear(dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim)
        )

    def forward(self, x):
        return x + self.block(x)

class NeuralVectorField(nn.Module):
    def __init__(self, latent_dim, hidden_dim=128, n_resblocks=3, time_embed_dim=32):
        super().__init__()
        self.x_proj = nn.Linear(latent_dim, hidden_dim)
        self.z_proj = nn.Linear(latent_dim, hidden_dim)
        self.time_embedder = TimeEmbedder(time_embed_dim)

        self.resblocks = nn.ModuleList([
            ResNetBlock(hidden_dim*2 + time_embed_dim) for _ in range(n_resblocks)
        ])
        self.output_layer = nn.Linear(hidden_dim*2 + time_embed_dim, latent_dim)

    def forward(self, x, z, t):
        xh = self.x_proj(x)
        zh = self.z_proj(z)
        th = self.time_embedder(t)
        h = torch.cat([xh, zh, th], dim=-1)
        for block in self.resblocks:
            h = block(h)
        return self.output_layer(h)


In [36]:
device = "cuda" if torch.cuda.is_available() else "cpu"

batch_size = 2110
num_epochs = 5000
learning_rate = 1e-3
latent_dim = latent_tensor.shape[1]  # e.g., 50

vf_model = NeuralVectorField(latent_dim=latent_dim).to(device)
optimizer = torch.optim.AdamW(vf_model.parameters(), lr=learning_rate)

# Initialize GaussianConditionalProbabilityPath and ConditionalVectorFieldODE
path = GaussianConditionalProbabilityPath(emp_dist, alpha, beta)  # define alpha, beta
#cvf_ode = ConditionalVectorFieldODE(path, z=torch.zeros(1, latent_dim, device=device))

for epoch in range(num_epochs):
    # --- Sample conditioning variable z ---
    z = emp_dist.sample(batch_size).to(device)

    # --- Sample time ---
    t = torch.rand(batch_size, 1, device=device)

    # --- Sample x_t from conditional path ---
    with torch.no_grad():
        x = path.sample_conditional_path(z, t)
        u_target = path.conditional_vector_field(x, z, t)

    # --- Normalize target ---
    u_mean = u_target.mean(dim=0, keepdim=True)
    u_std = u_target.std(dim=0, keepdim=True) + 1e-6
    u_target_norm = (u_target - u_mean) / u_std

    # --- Forward pass ---
    v_pred = vf_model(x, z, t)

    # --- Loss ---
    loss = F.mse_loss(v_pred, u_target_norm)

    # --- Backprop ---
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(vf_model.parameters(), max_norm=1.0)
    optimizer.step()

    if epoch % 50 == 0:
        print(f"[{epoch}] Loss: {loss.item():.6f}")


[0] Loss: 1.204401
[50] Loss: 0.185784
[100] Loss: 0.120905
[150] Loss: 0.094456
[200] Loss: 0.071584
[250] Loss: 0.056728
[300] Loss: 0.049617
[350] Loss: 0.048656
[400] Loss: 0.045756
[450] Loss: 0.044495
[500] Loss: 0.038891
[550] Loss: 0.034101
[600] Loss: 0.031455
[650] Loss: 0.034569
[700] Loss: 0.031710
[750] Loss: 0.032153
[800] Loss: 0.027621
[850] Loss: 0.030345
[900] Loss: 0.026650
[950] Loss: 0.025186
[1000] Loss: 0.022602
[1050] Loss: 0.021637
[1100] Loss: 0.021811
[1150] Loss: 0.023561
[1200] Loss: 0.020222
[1250] Loss: 0.024248
[1300] Loss: 0.022474
[1350] Loss: 0.018282
[1400] Loss: 0.019463
[1450] Loss: 0.019736
[1500] Loss: 0.022325
[1550] Loss: 0.019251
[1600] Loss: 0.018110
[1650] Loss: 0.021988
[1700] Loss: 0.016521
[1750] Loss: 0.019133
[1800] Loss: 0.017260
[1850] Loss: 0.017991
[1900] Loss: 0.018593
[1950] Loss: 0.017139
[2000] Loss: 0.019657
[2050] Loss: 0.014752
[2100] Loss: 0.016337
[2150] Loss: 0.017055
[2200] Loss: 0.016753
[2250] Loss: 0.016394
[2300] Loss

In [37]:
# we want to save the best vector field:
class LearnedVectorFieldODE():
    def __init__(self, vf_model):
        self.vf_model = vf_model

    def drift_coefficient(self, x: torch.Tensor, t: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        # x, z: (batch_size, latent_dim)
        # t: (batch_size, 1)
        return self.vf_model(x, z, t)


In [38]:
# Wrap the trained neural network
learned_ode = LearnedVectorFieldODE(vf_model)

# Save the wrapper
torch.save(learned_ode, "learned_ode.pt")


In [39]:
# Number of samples and latent dimension
n_samples = 1000
latent_dim = latent_tensor.shape[1]

# Starting points (noise)
x = torch.randn(n_samples, latent_dim, device=device)

# Conditioning variable z
# Single vector, broadcast to all samples
z = torch.zeros(1, latent_dim, device=device)  # or z = emp_dist.sample(1)

# Wrap the trained neural network as an ODE
learned_ode = LearnedVectorFieldODE(vf_model)

# Create Euler simulator with the conditioning variable
simulator = EulerSimulator(learned_ode, z)

# Simulation parameters
t0, t1 = 0.0, 1.0
n_steps = 50
dt = (t1 - t0) / n_steps

# Store trajectory
trajectory = [x.clone()]
t = torch.full((n_samples, 1), t0, device=device)

# Euler integration
for _ in range(n_steps):
    x = simulator.step(x, t, dt)
    trajectory.append(x.clone())
    t = t + dt

# Final generated samples
generated_cells = trajectory[-1]
print(generated_cells.shape)  # (1000, latent_dim)
torch.save(generated_cells, "generated_latent.pt")


torch.Size([1000, 50])
