# Flow Sampling & Visualization

**Description:** Sampling from the trained Flow model using Euler integration, decoding via VAE, and plotting UMAPs.

In [None]:
import os
import sys
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import anndata as ad
import scanpy as sc
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from scvi.distributions import NegativeBinomial

# Import VAE utils
# "../../" is a relative entry, so it is resolved against the current working
# directory; it must point at the directory that contains the `utils` package.
sys.path.append("../../")
from utils.autoencoder_utils import NB_Autoencoder

# Set device
# Use CUDA when PyTorch reports it as available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## Configuration & Data Stats
We load the data here primarily to recover cell type mappings and library size statistics for sampling.

In [None]:
# Input artifacts: the training AnnData (read below for its `X_latent`
# embedding and `obs` metadata), the autoencoder weights, and the flow checkpoint.
input_file_path = "/dtu/blackhole/06/213542/paperdata/pbmc3k_train_with_latent.h5ad"
vae_model_path = "/dtu/blackhole/06/213542/paperdata/pbmc3k_train_nb_autoencoder.pt"
flow_model_save_path = "/dtu/blackhole/06/213542/paperdata/lib_size_flow_model.pt"

# Sampling Hyperparameters
# guidance_scale: guidance weight that generate_samples hands to LearnedVectorFieldODE.
# n_steps: number of Euler steps; the step size used by the samplers is 1 / n_steps.
# latent_dim: width of the latent vectors fed to the flow and the VAE decoder.
# NOTE(review): latent_dim is hard-coded and never checked against
# latent.shape[1] below — confirm it matches the stored embedding.
guidance_scale = 2.0
n_steps = 50 
latent_dim = 50

# Load Data for Reference
adata = ad.read_h5ad(input_file_path)
latent = adata.obsm["X_latent"]
latent_tensor = torch.tensor(latent, dtype=torch.float32, device=device)

# Library Stats
# Prefer the precomputed per-cell totals; otherwise sum each row of X.
if "total_counts" in adata.obs:
    lib_sizes = adata.obs["total_counts"].values
else:
    lib_sizes = np.array(adata.X.sum(1)).flatten()

# Library sizes are summarised in log1p space; lib_mean / lib_std are 0-d
# tensors that the sampling functions read through .item().
log_lib_sizes = np.log1p(lib_sizes)
log_lib_tensor = torch.tensor(log_lib_sizes, dtype=torch.float32, device=device).unsqueeze(1)
lib_mean, lib_std = log_lib_tensor.mean(), log_lib_tensor.std()

# Cell Type Mappings
# np.unique returns the sorted distinct labels plus, for every cell, the index
# of its label in that sorted array. The position in unique_types is the
# integer id passed to the conditioner when sampling.
# NOTE(review): presumably training used the same label ordering — confirm.
cell_types = adata.obs["cell_type"].astype(str).values
unique_types, inverse_idx = np.unique(cell_types, return_inverse=True)
num_cell_types = len(unique_types)

print(f"Loaded metadata. {num_cell_types} cell types found.")

## Re-defining Model Classes
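The flow checkpoint is restored with `load_state_dict`, so the model classes must be re-defined here with the same architecture that was used during training.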

In [None]:
class TimeEmbedder(nn.Module):
    """Sinusoidal embedding of the flow time followed by a small SiLU MLP.

    ``embed_dim`` is expected to be even and at least 4: the sin and cos
    halves each have ``embed_dim // 2`` columns, and the frequency spacing
    divides by ``embed_dim // 2 - 1``.
    """

    def __init__(self, embed_dim=32):
        super().__init__()
        self.embed_dim = embed_dim
        # Layer positions (0 and 2 carry weights) define the state-dict keys.
        layers = [
            nn.Linear(embed_dim, embed_dim),
            nn.SiLU(),
            nn.Linear(embed_dim, embed_dim),
            nn.SiLU(),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, t):
        """Embed ``t``; it must broadcast against a ``(1, embed_dim // 2)`` row."""
        n_freqs = self.embed_dim // 2
        # Geometric frequency ladder running from 1 down to 1/10000.
        log_step = math.log(10000) / (n_freqs - 1)
        freqs = torch.exp(torch.arange(n_freqs, device=t.device) * -log_step)
        angles = t * freqs[None, :]
        features = torch.cat((torch.sin(angles), torch.cos(angles)), dim=1)
        return self.mlp(features)
    
class ResNetBlock(nn.Module):
    """Residual MLP block: ``x + Linear(SiLU(Linear(x)))``.

    The inner MLP widens ``dim`` to ``hidden_dim`` and projects back to
    ``dim`` so the result can be added to the input.
    """

    def __init__(self, dim, hidden_dim):
        super().__init__()
        widen = nn.Linear(dim, hidden_dim)
        narrow = nn.Linear(hidden_dim, dim)
        self.mlp = nn.Sequential(widen, nn.SiLU(), narrow)

    def forward(self, x):
        update = self.mlp(x)
        return x + update

class NeuralVectorField(nn.Module):
    """Conditional vector field over the latent space.

    The state ``x``, the condition ``c`` and the library-size input ``l`` are
    each projected to ``hidden_dim``, concatenated with the time embedding,
    passed through a stack of residual blocks and mapped back to
    ``latent_dim``. ``null_cond`` is a learnable ``(1, latent_dim)`` vector
    that LearnedVectorFieldODE substitutes for ``c`` in the unconditional pass.
    """

    def __init__(self, latent_dim, hidden_dim=256, n_resblocks=5, time_embed_dim=64):
        super().__init__()
        # Input projections: x and c are latent_dim wide, l is a single column.
        self.x_proj = nn.Linear(latent_dim, hidden_dim)
        self.c_proj = nn.Linear(latent_dim, hidden_dim)
        self.l_proj = nn.Linear(1, hidden_dim)
        self.time_embedder = TimeEmbedder(time_embed_dim)
        self.null_cond = nn.Parameter(torch.randn(1, latent_dim))

        # Width of the concatenated [x, c, l, t] feature vector.
        trunk_width = 3 * hidden_dim
        input_dim = trunk_width + time_embed_dim
        blocks = [ResNetBlock(input_dim, trunk_width) for _ in range(n_resblocks)]
        self.resblocks = nn.ModuleList(blocks)
        self.output_layer = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, latent_dim),
        )

    def forward(self, x, c, t, l):
        """Return the predicted velocity for state ``x`` at time ``t``."""
        parts = [
            self.x_proj(x),
            self.c_proj(c),
            self.l_proj(l),
            self.time_embedder(t),
        ]
        hidden = torch.cat(parts, dim=-1)
        for resblock in self.resblocks:
            hidden = resblock(hidden)
        return self.output_layer(hidden)

class CellTypeConditioner(nn.Module):
    """Lookup table turning integer cell-type ids into ``latent_dim`` vectors."""

    def __init__(self, n_types, latent_dim):
        super().__init__()
        self.embed = nn.Embedding(num_embeddings=n_types, embedding_dim=latent_dim)

    def forward(self, idx):
        """Return one embedding row per entry of the integer tensor ``idx``."""
        embedded = self.embed(idx)
        return embedded

# Guided drift for the sampling ODE (the Euler loop itself lives in the samplers)
class LearnedVectorFieldODE:
    """Classifier-free-guided drift built from a vector field and a conditioner.

    The condition embedding and its null counterpart are computed once at
    construction; ``drift`` then evaluates the field on a doubled batch and
    blends the conditional and unconditional velocities.
    """

    def __init__(self, vf_model, conditioner, z_target_idx, l_target, guidance_scale=2.0):
        self.vf = vf_model
        self.scale = guidance_scale
        self.l = l_target
        # Embed the cell type indices once; they are fixed for the whole trajectory.
        self.c = conditioner(z_target_idx)
        batch_size = self.c.shape[0]
        # One copy of the learned null condition per sample.
        self.c_null = self.vf.null_cond.expand(batch_size, -1)

    def drift(self, x, t):
        """Return ``v_uncond + scale * (v_cond - v_uncond)`` at ``(x, t)``."""

        def twice(tensor):
            # Stack a tensor on itself so both passes share one forward call.
            return torch.cat([tensor, tensor], dim=0)

        # First half of the batch is conditioned, second half uses the null condition.
        cond_stack = torch.cat([self.c, self.c_null], dim=0)
        velocities = self.vf(twice(x), cond_stack, twice(t), twice(self.l))
        v_cond, v_uncond = velocities.chunk(2, dim=0)

        guidance_delta = v_cond - v_uncond
        return v_uncond + self.scale * guidance_delta

## Load Models

In [None]:
# 1. Initialize and Load Flow Model
# Both modules are rebuilt with their default hyperparameters; these have to
# match the ones used when the checkpoint was written, otherwise
# load_state_dict raises on mismatched keys or shapes.
conditioner = CellTypeConditioner(n_types=num_cell_types, latent_dim=latent_dim).to(device)
vf_model = NeuralVectorField(latent_dim=latent_dim).to(device)

# The checkpoint is a dict holding the two state dicts under 'vf_state' and 'cond_state'.
checkpoint = torch.load(flow_model_save_path, map_location=device)
vf_model.load_state_dict(checkpoint['vf_state'])
conditioner.load_state_dict(checkpoint['cond_state'])
print("Flow model loaded.")

# 2. Load VAE for Decoding
# Only the file read is inside the try, so FileNotFoundError can come from
# nowhere else.
try:
    vae_state = torch.load(vae_model_path, map_location=device)
except FileNotFoundError:
    # Stay best-effort (the flow model is still usable), but do not leave a
    # randomly initialised decoder bound to `vae`: with None any later decode
    # attempt fails immediately instead of silently producing garbage counts.
    vae = None
    print(f"Warning: VAE model not found at {vae_model_path}. Decoding will fail.")
else:
    vae = NB_Autoencoder(num_features=adata.n_vars, latent_dim=latent_dim)
    vae.load_state_dict(vae_state)
    vae.to(device)
    vae.eval()
    print("VAE loaded successfully.")

## Sampling & Decoding

In [None]:
def generate_samples(target_type_idx, num_samples, fix_library_size=True):
    """Sample latent vectors for one cell type by Euler-integrating the flow.

    Starts from standard-normal noise and takes ``n_steps`` explicit Euler
    steps of size ``1 / n_steps`` along the guided drift, using the
    module-level ``vf_model``, ``conditioner``, ``guidance_scale``,
    ``n_steps``, ``latent_dim``, ``device``, ``lib_mean`` and ``lib_std``.

    Args:
        target_type_idx: Integer id of the cell type, i.e. a row index into
            the conditioner's embedding table.
        num_samples: Number of latent vectors to draw.
        fix_library_size: If True every sample is conditioned on ``lib_mean``;
            otherwise a value is drawn per sample from
            ``Normal(lib_mean, lib_std)``.

    Returns:
        Tensor of shape ``(num_samples, latent_dim)`` on ``device``.

    Raises:
        IndexError: If ``target_type_idx`` is not a valid embedding row.
    """
    # Check up front: an out-of-range embedding index on CUDA surfaces as a
    # device-side assert rather than a clean Python exception.
    n_types = conditioner.embed.num_embeddings
    if not 0 <= target_type_idx < n_types:
        raise IndexError(
            f"target_type_idx {target_type_idx} is out of range for {n_types} cell types"
        )

    vf_model.eval()
    conditioner.eval()

    x = torch.randn(num_samples, latent_dim, device=device)

    # Conditions
    type_tensor = torch.full((num_samples,), target_type_idx, dtype=torch.long, device=device)

    if fix_library_size:
        l_val = lib_mean.item()
        l_tensor = torch.full((num_samples, 1), l_val, device=device)
    else:
        l_tensor = torch.normal(lib_mean.item(), lib_std.item(), (num_samples, 1), device=device)

    dt = 1.0 / n_steps
    t = torch.zeros(num_samples, 1, device=device)

    print(f"Sampling {num_samples} cells (Type {target_type_idx})...")
    with torch.no_grad():
        # Built under no_grad so the cached condition embedding carries no
        # autograd graph; sampling never back-propagates.
        ode = LearnedVectorFieldODE(vf_model, conditioner, type_tensor, l_tensor, guidance_scale)
        # ODE Integration: explicit Euler from t = 0 towards t = 1.
        for _ in range(n_steps):
            v = ode.drift(x, t)
            x = x + v * dt
            t = t + dt

    return x

# Example single sampling run
target_idx = 2
n_gen = 200
generated_latents = generate_samples(target_idx, n_gen)

# Rescaling Logic
# Per-dimension moment matching: standardise the generated latents, then give
# them the mean / std of the reference latents.
# NOTE(review): the reference moments are computed over ALL cells in
# latent_tensor, while generated_latents holds a single cell type — confirm
# that matching one type to whole-dataset statistics is intended.
std_orig = latent_tensor.std(dim=0)
mean_orig = latent_tensor.mean(dim=0)
std_gen = generated_latents.std(dim=0)
mean_gen = generated_latents.mean(dim=0)

generated_rescaled = (generated_latents - mean_gen) / std_gen * std_orig + mean_orig

# Decode with Fixed Library Size
target_lib_size = 1000 
print("Decoding to counts...")
with torch.no_grad():
    # decode() is a project helper; its outputs are read as a dict with "mu"
    # and "theta", and "theta" is exponentiated before use.
    outputs = vae.decode(generated_rescaled, adata, target_lib_size)
    mu = outputs["mu"]
    theta = torch.exp(outputs["theta"])
    nb_dist = NegativeBinomial(mu=mu, theta=theta)
    X_gen_counts = nb_dist.sample().cpu().numpy()

# Save single batch
save_dir = "/dtu/blackhole/06/213542/paperdata/"
counts_save_path = os.path.join(save_dir, "new_generated_pbmc3k_counts.h5ad")
adata_gen = ad.AnnData(X=X_gen_counts)
# `adata` is already required by the decode call above, so the gene names and
# label are attached unconditionally.
adata_gen.var_names = adata.var_names
adata_gen.obs['cell_type'] = unique_types[target_idx]
adata_gen.write(counts_save_path)
print(f"Saved generated counts to: {counts_save_path}")

## Full Dataset Generation & Visualization

In [None]:
def generate_full_dataset(n_per_type=200, guidance_scale=10.0):
    """Generate a synthetic count matrix covering every cell type.

    For each entry of the module-level ``unique_types`` this draws
    ``n_per_type`` log-library sizes from ``Normal(lib_mean, lib_std)``,
    Euler-integrates the guided flow from Gaussian noise for ``n_steps``
    steps, then moment-matches the pooled latents to ``latent_tensor`` and
    decodes them with the VAE into negative-binomial count samples.

    Args:
        n_per_type: Number of cells generated for each cell type.
        guidance_scale: Classifier-free guidance weight for the drift. The
            default keeps the value this function has always used; note that
            it is independent of the module-level ``guidance_scale``.

    Returns:
        Tuple ``(X_counts, types)``: ``X_counts`` is a NumPy array of sampled
        counts with one row per generated cell, and ``types`` is a NumPy array
        holding the cell-type label of each row, in the same order.
    """
    vf_model.eval()
    conditioner.eval()

    all_latents = []
    all_lib_sizes = []
    all_types = []

    dt = 1.0 / n_steps  # Euler step size, identical for every cell type

    print(f"Generating {n_per_type} cells per type for {len(unique_types)} types...")
    for idx, ct in enumerate(unique_types):
        # Conditions
        type_tensor = torch.full((n_per_type,), idx, dtype=torch.long, device=device)
        l_tensor = torch.normal(lib_mean.item(), lib_std.item(), (n_per_type, 1), device=device)

        # Initial Noise
        x = torch.randn(n_per_type, latent_dim, device=device)

        # Integration: explicit Euler from t = 0 towards t = 1.
        ode = LearnedVectorFieldODE(
            vf_model, conditioner, type_tensor, l_tensor, guidance_scale=guidance_scale
        )
        t = torch.zeros(n_per_type, 1, device=device)

        with torch.no_grad():
            for _ in range(n_steps):
                v = ode.drift(x, t)
                x = x + v * dt
                t = t + dt

        all_latents.append(x)
        all_lib_sizes.append(l_tensor)
        all_types.extend([ct] * n_per_type)

    gen_latents_tensor = torch.cat(all_latents, dim=0)
    gen_libs_tensor = torch.cat(all_lib_sizes, dim=0)

    # Rescaling: per-dimension moment matching of the pooled generated latents
    # to the reference latents.
    # NOTE(review): every type contributes n_per_type cells here, whereas the
    # reference set has its own type proportions — confirm this is acceptable.
    mean_gen = gen_latents_tensor.mean(dim=0)
    std_gen = gen_latents_tensor.std(dim=0)
    mean_orig = latent_tensor.mean(dim=0)
    std_orig = latent_tensor.std(dim=0)

    gen_rescaled = (gen_latents_tensor - mean_gen) / std_gen * std_orig + mean_orig

    # Decode to Counts
    # Library sizes were summarised with log1p, so expm1 is the exact inverse.
    lib_counts = torch.expm1(gen_libs_tensor)

    print("Decoding full dataset...")
    with torch.no_grad():
        outputs = vae.decode(gen_rescaled, adata, lib_counts)
        mu = outputs["mu"]
        theta = torch.exp(outputs["theta"])
        nb_dist = NegativeBinomial(mu=mu, theta=theta)
        X_counts = nb_dist.sample().cpu().numpy()

    return X_counts, np.array(all_types)

# Generate and Save
# 250 cells per type are requested here instead of the function's default.
X_gen_all, types_gen_all = generate_full_dataset(n_per_type=250)

# Wrap the counts in an AnnData that shares the reference gene names and is
# tagged so it can be told apart after merging with the real cells.
# Rebinding `adata_gen` replaces the single-type object created earlier.
adata_gen = ad.AnnData(X=X_gen_all)
adata_gen.obs['cell_type'] = types_gen_all
adata_gen.obs['dataset'] = 'Generated'
adata_gen.var_names = adata.var_names

# The file is written next to the flow checkpoint.
save_gen_path = os.path.join(os.path.dirname(flow_model_save_path), "generated_cells.h5ad")
adata_gen.write(save_gen_path)
print(f"Saved all generated cells to: {save_gen_path}")

## Plotting Results (Figure A2 Style)

In [None]:
# Prepare Merged Dataset
# Work on a copy so the reference `adata` is left untouched, and densify X
# when it exposes `toarray` (i.e. is stored sparse).
adata_real = adata.copy()
adata_real.obs['dataset'] = 'Real'
if hasattr(adata_real.X, "toarray"):
    adata_real.X = adata_real.X.toarray()

# Stack real and generated cells; `batch` records which input each row came from.
adata_merged = ad.concat([adata_real, adata_gen], join='outer', label='batch', keys=['Real', 'Generated'])
adata_merged.obs['cell_type'] = adata_merged.obs['cell_type'].astype('category')

# Preprocessing & Embedding
# NOTE(review): normalize_total + log1p presumably expect raw counts in
# adata.X — confirm the loaded file was not already normalised.
print("Running PCA and UMAP on merged dataset...")
sc.pp.normalize_total(adata_merged, target_sum=1e4)
sc.pp.log1p(adata_merged)
sc.pp.pca(adata_merged, n_comps=30)
sc.pp.neighbors(adata_merged, n_neighbors=15)
sc.tl.umap(adata_merged)

# Visualization
print("\n--- Starting Visualization Generation ---")
# Cell types that get their own real-vs-generated overlap panel.
highlight_types = ['B cells', 'CD14+ Monocytes', 'CD4 T cells', 'CD8 T cells']

# 2x3 grid: two overview panels plus one panel per highlighted type.
fig = plt.figure(figsize=(18, 10))
gs = gridspec.GridSpec(2, 3, figure=fig)

def plot_umap_scatter(ax, adata_subset, color_col, title, palette=None, s=10):
    """Draw the UMAP of ``adata_subset`` on ``ax``, coloured by ``color_col``.

    Category labels are placed on the data, the frame is kept, and the axes
    are labelled UMAP1 / UMAP2. The figure is not shown by this call.
    """
    umap_options = dict(
        color=color_col,
        ax=ax,
        show=False,
        title=title,
        frameon=True,
        s=s,
        palette=palette,
        legend_loc='on data',
    )
    sc.pl.umap(adata_subset, **umap_options)
    for set_label, text in ((ax.set_xlabel, "UMAP1"), (ax.set_ylabel, "UMAP2")):
        set_label(text)

# Panel 1: Real Data
# Top-left grid cell; only rows tagged 'Real' are drawn, on the shared UMAP.
ax1 = fig.add_subplot(gs[0, 0])
plot_umap_scatter(ax1, adata_merged[adata_merged.obs['dataset']=='Real'], 
                  'cell_type', 'Real Data')

# Panel 2: Generated Data
# Top-middle grid cell; same embedding, restricted to rows tagged 'Generated'.
ax2 = fig.add_subplot(gs[0, 1])
plot_umap_scatter(ax2, adata_merged[adata_merged.obs['dataset']=='Generated'], 
                  'cell_type', 'Generated (All)')

def plot_overlap(ax, c_type):
    """Overlay real and generated cells of one type on ``ax`` in UMAP space.

    Reads the module-level ``adata_merged``; real cells are drawn first (blue)
    and generated cells on top (red). Axis ticks are removed and the panel is
    titled with ``c_type``.
    """
    obs = adata_merged.obs
    is_type = obs['cell_type'] == c_type

    # (dataset tag, colour, alpha) in drawing order.
    layers = (
        ('Real', '#377eb8', 0.5),
        ('Generated', '#e41a1c', 0.6),
    )
    for source, colour, opacity in layers:
        selected = (obs['dataset'] == source) & is_type
        coords = adata_merged[selected].obsm['X_umap']
        ax.scatter(coords[:, 0], coords[:, 1], s=15, c=colour, alpha=opacity, label=source)

    ax.set_title(c_type)
    ax.set_xticks([])
    ax.set_yticks([])

# Overlap Panels
# Remaining grid cells, filled in reading order; zip stops at the shorter of
# the two lists, so surplus highlight types are simply not drawn.
locs = [(0, 2), (1, 0), (1, 1), (1, 2)]
for i, (grid_pos, c_type) in enumerate(zip(locs, highlight_types)):
    ax = fig.add_subplot(gs[grid_pos])
    plot_overlap(ax, c_type)
    # A single legend on the first overlap panel is enough.
    if i == 0:
        ax.legend(frameon=False, loc='upper right', markerscale=2)

plt.tight_layout()
# The figure is written next to the flow checkpoint before being displayed.
output_dir = os.path.dirname(flow_model_save_path)
save_plot_path = os.path.join(output_dir, "cfgen_results_figure_a2.png")
plt.savefig(save_plot_path, dpi=300)
print(f"Figure saved to {save_plot_path}")
plt.show()