# Flow Matching Training

**Description:** Trains an optimal-transport (OT) flow-matching model with classifier-free guidance (CFG) on precomputed cell latents.
The model is conditioned on both **cell type** and **library size**.

In [None]:
import os
import sys
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import anndata as ad
from tqdm import tqdm

# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

## Configuration & Data Loading

In [None]:
# --- Paths -----------------------------------------------------------------
# Input AnnData (latents are read from .obsm["X_latent"]) and output checkpoint.
input_file_path = "/dtu/blackhole/06/213542/paperdata/pbmc3k_train_with_latent.h5ad"
flow_model_save_path = "/dtu/blackhole/06/213542/paperdata/lib_size_flow_model.pt"

# Ensure the checkpoint's parent directory exists before any training happens.
os.makedirs(os.path.dirname(flow_model_save_path), exist_ok=True)

# --- Hyperparameters -------------------------------------------------------
batch_size = 256      # cells per optimizer step
num_epochs = 400      # NOTE: number of optimizer steps (one random mini-batch each), not full passes
learning_rate = 5e-4  # AdamW learning rate
latent_dim = 50       # expected latent width; the training cell re-reads it from the data
p_uncond = 0.1        # probability of dropping the condition (classifier-free guidance)

# Load the AnnData object and move its latent embedding onto the training device.
adata = ad.read_h5ad(input_file_path)
latent = adata.obsm["X_latent"]
latent_tensor = torch.tensor(latent, dtype=torch.float32, device=device)

# Library sizes: prefer a precomputed "total_counts" column; otherwise sum each
# row of adata.X (assumes X holds raw counts -- TODO confirm).
lib_sizes = (
    adata.obs["total_counts"].values
    if "total_counts" in adata.obs
    else np.array(adata.X.sum(1)).flatten()
)
log_lib_sizes = np.log1p(lib_sizes)
# Column vector of log1p(library size), shape (n_cells, 1).
log_lib_tensor = torch.tensor(log_lib_sizes, dtype=torch.float32, device=device)[:, None]

# Summary statistics of the log library size. Training below does not use
# them; presumably they are kept for picking library sizes at sampling time.
lib_min = log_lib_tensor.min()
lib_max = log_lib_tensor.max()
lib_mean = log_lib_tensor.mean()
lib_std = log_lib_tensor.std()

# Map cell-type labels to contiguous integer codes (np.unique returns sorted labels).
cell_types = adata.obs["cell_type"].astype(str).values
unique_types, inverse_idx = np.unique(cell_types, return_inverse=True)
num_cell_types = len(unique_types)
cell_type_idx = torch.tensor(inverse_idx, dtype=torch.long, device=device)

print(f"Data Shape: {latent.shape}")
print(f"Library Size: Min={lib_min:.2f}, Max={lib_max:.2f}, Mean={lib_mean:.2f}")
print(f"Cell Types: {unique_types}")

## Model Architecture & Classes

In [None]:
class EmpiricalDistribution(nn.Module):
    """Uniform resampling, with replacement, from a fixed tensor of observations.

    The observations are registered as a buffer, so they follow ``.to(device)``
    and are included in ``state_dict()``.

    Args:
        data: Tensor whose first dimension indexes the observations.
    """

    def __init__(self, data):
        super().__init__()
        self.register_buffer("data", data)

    def sample(self, n):
        """Return ``n`` rows of ``data`` drawn uniformly at random (with replacement)."""
        pool = self.data
        picks = torch.randint(0, len(pool), (n,), device=pool.device)
        return pool.index_select(0, picks)

class GaussianConditionalProbabilityPath:
    """Straight-line (OT) conditional path between Gaussian noise and data.

    Convention: t = 0 is the noise endpoint x_0, t = 1 is the data point z.

    Args:
        p_data: Data distribution; stored on the instance but not used by the
            methods below.
    """

    def __init__(self, p_data):
        self.p_data = p_data

    def sample_conditional_path(self, z, t, x_0=None):
        """Return x_t = t * z + (1 - t) * x_0.

        Args:
            z: Data endpoint(s).
            t: Time(s) broadcastable against ``z``, e.g. shape (batch, 1).
            x_0: Optional noise endpoint. When ``None`` (the original behaviour),
                fresh standard-normal noise shaped like ``z`` is drawn. Pass it
                explicitly to reuse the same noise for the regression target
                ``conditional_vector_field(z, x_0)``; otherwise the noise is
                hidden from the caller and the target cannot be formed.
        """
        if x_0 is None:
            x_0 = torch.randn_like(z)
        return t * z + (1 - t) * x_0

    def conditional_vector_field(self, x_1, x_0):
        """Return the constant velocity ``x_1 - x_0`` of the straight-line path."""
        return x_1 - x_0

class TimeEmbedder(nn.Module):
    """Sinusoidal time embedding followed by a two-layer SiLU MLP.

    Args:
        embed_dim: Width of the sinusoidal features, of both MLP layers, and of
            the output.
    """

    def __init__(self, embed_dim=32):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, embed_dim), nn.SiLU(),
            nn.Linear(embed_dim, embed_dim), nn.SiLU()
        )
        self.embed_dim = embed_dim

    def forward(self, t):
        """Embed times ``t`` into a tensor of shape (batch, embed_dim).

        Args:
            t: Times of shape (batch, 1). A 1-D tensor of shape (batch,) is also
                accepted and treated as a column.
        """
        if t.dim() == 1:
            t = t.unsqueeze(-1)
        half_dim = self.embed_dim // 2
        # Log-spaced frequencies from 1 down to 1/10000. The max() guards the
        # division when half_dim == 1 (embed_dim of 2 or 3).
        scale = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=t.device) * -scale)
        angles = t * freqs[None, :]
        emb = torch.cat((torch.sin(angles), torch.cos(angles)), dim=1)
        # sin/cos yield 2 * half_dim features; zero-pad the missing column for an
        # odd embed_dim so the width matches the MLP's input layer.
        if emb.shape[1] < self.embed_dim:
            emb = F.pad(emb, (0, self.embed_dim - emb.shape[1]))
        return self.mlp(emb)
    
class ResNetBlock(nn.Module):
    """Residual MLP block computing ``x + Linear(SiLU(Linear(x)))``.

    Args:
        dim: Input and output width (equal so the skip connection can be added).
        hidden_dim: Width of the inner layer.
    """

    def __init__(self, dim, hidden_dim):
        super().__init__()
        layers = [
            nn.Linear(dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, dim),
        ]
        # The attribute name ``mlp`` and the layer order define the state_dict keys.
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.mlp(x)
        return x + residual

class NeuralVectorField(nn.Module):
    """MLP vector field v(x, c, t, l) for conditional flow matching.

    Each input is projected separately. The projections are concatenated in the
    order (x, c, l, t), passed through a stack of residual MLP blocks, and
    mapped back to ``latent_dim``.

    Args:
        latent_dim: Width of the state ``x`` and of the condition vector ``c``.
        hidden_dim: Width of each of the x / c / l projections.
        n_resblocks: Number of residual blocks in the trunk.
        time_embed_dim: Width of the time embedding.
    """

    def __init__(self, latent_dim, hidden_dim=256, n_resblocks=5, time_embed_dim=64):
        super().__init__()
        # Submodule names and creation order are deliberate: they fix the
        # state_dict keys and the sequence of random initialisations.
        self.x_proj = nn.Linear(latent_dim, hidden_dim)
        self.c_proj = nn.Linear(latent_dim, hidden_dim)   # condition vector (e.g. cell-type embedding)
        self.l_proj = nn.Linear(1, hidden_dim)            # scalar library-size input
        self.time_embedder = TimeEmbedder(time_embed_dim)

        # Learnable "no condition" vector that stands in for c under CFG dropout.
        self.null_cond = nn.Parameter(torch.randn(1, latent_dim))

        # Trunk width: three hidden projections plus the time embedding.
        trunk_dim = 3 * hidden_dim + time_embed_dim
        self.resblocks = nn.ModuleList(
            [ResNetBlock(trunk_dim, 3 * hidden_dim) for _ in range(n_resblocks)]
        )
        self.output_layer = nn.Sequential(
            nn.Linear(trunk_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, latent_dim),
        )

    def forward(self, x, c, t, l):
        """Predict the velocity at state ``x`` for condition ``c``, time ``t`` and library-size input ``l``."""
        features = [
            self.x_proj(x),
            self.c_proj(c),
            self.l_proj(l),
            self.time_embedder(t),
        ]
        h = torch.cat(features, dim=-1)
        for block in self.resblocks:
            h = block(h)
        return self.output_layer(h)

class CellTypeConditioner(nn.Module):
    """Learnable lookup table mapping integer cell-type indices to condition vectors.

    Args:
        n_types: Number of distinct cell types (rows of the table).
        latent_dim: Width of each condition vector.
    """

    def __init__(self, n_types, latent_dim):
        super().__init__()
        self.embed = nn.Embedding(num_embeddings=n_types, embedding_dim=latent_dim)

    def forward(self, idx):
        """Look up the condition vector for every integer index in ``idx``."""
        table = self.embed
        return table(idx)

## Training Loop

In [None]:
# Initialisation
emp_dist = EmpiricalDistribution(latent_tensor)
latent_dim = latent_tensor.shape[1]

# Initialize models
conditioner = CellTypeConditioner(n_types=num_cell_types, latent_dim=latent_dim).to(device)
vf_model = NeuralVectorField(latent_dim=latent_dim).to(device)

# Optimize -- and gradient-clip -- both the vector field and the cell-type embedding.
trainable_params = list(vf_model.parameters()) + list(conditioner.parameters())
optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate)

epochs_list = []
loss_list = []

print("Starting CFG Training...")

# NOTE: each "epoch" here is a single optimizer step on one random mini-batch
# (indices drawn with replacement), not a full pass over the dataset.
n_cells = latent_tensor.shape[0]
for epoch in range(num_epochs):
    # 1. Sample a mini-batch; indices live on the same device as the data tensors.
    indices = torch.randint(0, n_cells, (batch_size,), device=latent_tensor.device)
    z = latent_tensor[indices]          # data endpoint (t = 1)
    c_idx = cell_type_idx[indices]      # condition: cell-type index
    l = log_lib_tensor[indices]         # condition: log1p(library size), shape (B, 1)

    # 2. Sample source noise (t = 0) and time.
    x0 = torch.randn(batch_size, latent_dim, device=device)
    t = torch.rand(batch_size, 1, device=device)

    # 3. Point on the straight-line (OT) path and its conditional velocity.
    #    The target is regressed un-normalized: standardizing it with per-batch
    #    statistics (discarded after each step) would leave the learned field at
    #    an unknown scale that cannot be undone at sampling time.
    x_t = t * z + (1 - t) * x0
    u_target = z - x0

    # 4. CFG: replace the condition with the learned null vector with prob. p_uncond.
    c_emb = conditioner(c_idx)
    mask = (torch.rand(batch_size, 1, device=device) < p_uncond).float()  # 1 = drop condition
    c_input = mask * vf_model.null_cond.expand(batch_size, -1) + (1 - mask) * c_emb

    # 5. Predict the velocity at the interpolated point x_t (not at the raw noise).
    v_pred = vf_model(x_t, c_input, t, l)

    # 6. Conditional flow-matching regression loss.
    loss = F.mse_loss(v_pred, u_target)

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
    optimizer.step()

    epochs_list.append(epoch)
    loss_list.append(loss.item())

    if epoch % 50 == 0:
        print(f"[{epoch}] Loss: {loss.item():.6f}")

# Save models together with the metadata needed to condition at sampling time
# (index -> cell-type label order, latent width, log library-size statistics).
torch.save({
    'vf_state': vf_model.state_dict(),
    'cond_state': conditioner.state_dict(),
    'cell_types': [str(ct) for ct in unique_types],
    'latent_dim': int(latent_dim),
    'log_lib_stats': {
        'min': float(lib_min),
        'max': float(lib_max),
        'mean': float(lib_mean),
        'std': float(lib_std),
    },
}, flow_model_save_path)
print(f"Saved models to {flow_model_save_path}")