

---

# **📄 Gaussian Schrödinger Bridge: Cheat Sheet**

---

# **1. Optimal Transport (OT)**

### **Static OT**

Solve:
$$
\min_\pi \int \|x - x'\|^2 \, d\pi(x, x')
$$
subject to $\pi$ coupling $\rho_0, \rho_1$.

Only end-point movement; no time evolution.
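
As a small illustration of the static problem, the sketch below solves it for three equally weighted 1D samples. It relies on the fact that, in one dimension with squared cost, pairing sorted sources with sorted targets is optimal, and confirms this by brute force. The sample values and the helper names `coupling_cost` and `sorted_matching` are made up for this example.

```python
from itertools import permutations

def coupling_cost(xs, ys, perm):
    """Mean squared cost of matching xs[i] to ys[perm[i]]."""
    return sum((x - ys[j]) ** 2 for x, j in zip(xs, perm)) / len(xs)

def sorted_matching(xs, ys):
    """Pair the k-th smallest x with the k-th smallest y (optimal in 1D)."""
    order_x = sorted(range(len(xs)), key=lambda i: xs[i])
    order_y = sorted(range(len(ys)), key=lambda j: ys[j])
    perm = [0] * len(xs)
    for i, j in zip(order_x, order_y):
        perm[i] = j
    return perm

xs = [0.0, 2.0, 5.0]   # toy samples standing in for rho_0
ys = [6.0, 1.0, 3.0]   # toy samples standing in for rho_1

perm = sorted_matching(xs, ys)
# Brute force over all one-to-one couplings as a cross-check
best_brute_force = min(coupling_cost(xs, ys, p) for p in permutations(range(3)))
print(perm, coupling_cost(xs, ys, perm), best_brute_force)
```

The coupling only records which point goes where; nothing in it describes the path taken in between.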

### **Dynamic OT (Benamou–Brenier)**

Solve:
$$
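\min_{\rho_t,v_t} \int_0^1 \int \frac12 \|v_t(x)\|^2 \rho_t(x) \, dx \, dt
$$
subject to the continuity equation and the endpoint marginals:
$$
\partial_t \rho_t = -\nabla \cdot (\rho_t v_t), \qquad \rho_{t=0} = \rho_0, \quad \rho_{t=1} = \rho_1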
$$
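
To make the kinetic-energy objective concrete, here is a minimal particle sketch. It assumes three equally weighted particles that are already optimally paired, and compares the discretized action of straight constant-velocity paths with a path that takes a detour between the same endpoints. The function name `kinetic_action` and all numbers are illustrative.

```python
import numpy as np

def kinetic_action(path, dt):
    """Discretized int_0^1 (1/2)|v_t|^2 dt, averaged over particles.

    path has shape (n_steps + 1, n_particles).
    """
    velocity = np.diff(path, axis=0) / dt
    return 0.5 * np.mean(np.sum(velocity**2, axis=0) * dt)

n_steps = 100
dt = 1.0 / n_steps
t = np.linspace(0.0, 1.0, n_steps + 1)[:, None]   # (n_steps + 1, 1)

x0 = np.array([0.0, 2.0, 5.0])   # toy particles from rho_0
x1 = np.array([1.0, 3.0, 6.0])   # their partners in rho_1 (sorted pairing)

straight = (1.0 - t) * x0 + t * x1             # constant-velocity paths
detour = straight + 0.5 * np.sin(np.pi * t)    # same endpoints, extra motion

static_cost = np.mean((x1 - x0) ** 2)
print(kinetic_action(straight, dt), 0.5 * static_cost, kinetic_action(detour, dt))
```

With the factor 1/2 in the objective, the straight-line action equals half the static squared cost, and any detour costs more.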

---

# **2. Schrödinger Bridges (SBs)**

Given reference process $Y_t$ with law $Q_t$:

$$
\min_{P_0=\nu_0,\; P_1=\nu_1} \mathrm{KL}(P_t \,\|\, Q_t)
$$

Interpretation:

> The **most likely stochastic evolution** between two distributions, closest to a given SDE.

Connection:
$$
\text{SB} = \text{Dynamic OT} + \text{Entropy regularization}
$$

σ → 0 ⇒ classical OT.

---

# **3. Reference SDE**

General linear reference:
$$
dY_t = (c_t Y_t + \alpha_t)\,dt + g_t \, dW_t
$$

Covers:

* Brownian motion
* OU
* VP / VE / sub-VP
* Brownian dynamics (BDT)
* Many ML diffusions
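
A minimal Euler-Maruyama sketch of this linear reference SDE, assuming scalar states on $[0,1]$ and coefficients passed as plain Python callables. The function name `simulate_linear_sde` and the two parameter choices (Brownian motion with scale 0.5, and an OU process with $c_t = -1$) are illustrative.

```python
import numpy as np

def simulate_linear_sde(c, alpha, g, y0, n_steps=200, rng=None):
    """Euler-Maruyama for dY = (c(t) Y + alpha(t)) dt + g(t) dW on [0, 1].

    c, alpha, g are callables of t; y0 is an array of initial states.
    """
    rng = np.random.default_rng() if rng is None else rng
    dt = 1.0 / n_steps
    y = np.array(y0, dtype=float)
    for k in range(n_steps):
        t = k * dt
        noise = rng.standard_normal(y.shape)
        y = y + (c(t) * y + alpha(t)) * dt + g(t) * np.sqrt(dt) * noise
    return y

rng = np.random.default_rng(0)
y0 = np.ones(20_000)   # all paths start at 1

# Brownian motion with scale sigma: c = 0, alpha = 0, g = sigma
sigma = 0.5
bm = simulate_linear_sde(lambda t: 0.0, lambda t: 0.0, lambda t: sigma, y0, rng=rng)

# Ornstein-Uhlenbeck: c = -1, alpha = 0, g = 1
ou = simulate_linear_sde(lambda t: -1.0, lambda t: 0.0, lambda t: 1.0, y0, rng=rng)

# Var[Y_1] should be near sigma^2 for BM; E[Y_1] near exp(-1) for this OU
print(bm.var(), ou.mean())
```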

---

# **4. Gaussian Schrödinger Bridges (GSBs)**

Endpoints:
$$
X_0 \sim \mathcal{N}(\mu_0, \Sigma_0),\quad
X_1 \sim \mathcal{N}(\mu_1, \Sigma_1)
$$

Reference SDE linear ⇒ the optimal bridge is **Gaussian at all times**:
$$
X_t \sim \mathcal{N}(\mu_t, \Sigma_t)
$$

Goal: **explicit formulas** for $\mu_t$, $\Sigma_t$, and drift.

---

# **5. Geometry: Bures–Wasserstein Metric**

Covariances live on the manifold $S_{++}^d$ of symmetric positive-definite matrices.

Lyapunov operator: $L_\Sigma[U]$ denotes the solution $A$ of $A\Sigma + \Sigma A = U$.

Metric:
$$
\langle U, V\rangle_\Sigma
= \frac12 \mathrm{tr}(L_\Sigma[U]\,V)
$$
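
The Lyapunov operator and the metric can be sketched numerically as follows. The approach diagonalizes $\Sigma$, where the equation $A\Sigma + \Sigma A = U$ decouples entrywise; this is one standard way to solve it, not necessarily how any particular GSB implementation does. The names `lyapunov_solve` and `bw_inner` and the test matrices are illustrative.

```python
import numpy as np

def lyapunov_solve(Sigma, U):
    """Return A with A @ Sigma + Sigma @ A = U, for symmetric positive-definite Sigma."""
    lam, Q = np.linalg.eigh(Sigma)            # Sigma = Q diag(lam) Q^T
    U_tilde = Q.T @ U @ Q                     # U in the eigenbasis of Sigma
    A_tilde = U_tilde / (lam[:, None] + lam[None, :])
    return Q @ A_tilde @ Q.T

def bw_inner(Sigma, U, V):
    """Inner product <U, V>_Sigma = 0.5 * tr(L_Sigma[U] V)."""
    return 0.5 * np.trace(lyapunov_solve(Sigma, U) @ V)

Sigma = np.array([[2.0, 0.5], [0.5, 1.0]])
U = np.array([[1.0, 0.2], [0.2, -0.3]])      # symmetric tangent direction
V = np.array([[0.5, -0.1], [-0.1, 0.4]])     # another tangent direction

A = lyapunov_solve(Sigma, U)
print(np.allclose(A @ Sigma + Sigma @ A, U), bw_inner(Sigma, U, U))
```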

GSB = geodesic on this manifold + potential:
$$
U_\sigma(\Sigma) = -\frac{\sigma^2}{8}\mathrm{tr}(\Sigma^{-1})
$$

---

# **6. Closed-Form Covariance Evolution (Example: σWₜ)**

With endpoint covariances $\Sigma_0$ (at $t=0$) and $\Sigma_1$ (at $t=1$) as in Section 4, define:
$$
D_\sigma = \sqrt{4 \Sigma_0^{1/2} \Sigma_1 \Sigma_0^{1/2} + \sigma^4 I}
$$
$$
C_\sigma = \frac12 (\Sigma_0^{1/2} D_\sigma \Sigma_0^{-1/2} - \sigma^2 I)
$$

Then:
$$
\Sigma_t
= (1 - t)^2 \Sigma_0 + t^2 \Sigma_1 + t(1 - t)(C_\sigma + C_\sigma^\top + \sigma^2 I)
$$


* σ → 0 ⇒ Wasserstein geodesic.
* σ → ∞ ⇒ entropy-dominated smoothing.
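
A NumPy sketch of this covariance path, under the assumption that the formula is read with `cov_start` as the covariance at $t=0$ and `cov_end` as the covariance at $t=1$. The function names and the example matrices are illustrative; matrix square roots are taken through an eigendecomposition.

```python
import numpy as np

def sym_sqrt(M):
    """Square root of a symmetric positive semi-definite matrix."""
    lam, Q = np.linalg.eigh(M)
    return (Q * np.sqrt(np.clip(lam, 0.0, None))) @ Q.T

def gsb_covariance(cov_start, cov_end, sigma, t):
    """Covariance of the bridge at time t for the sigma * W_t reference."""
    eye = np.eye(cov_start.shape[0])
    half = sym_sqrt(cov_start)
    half_inv = np.linalg.inv(half)
    D = sym_sqrt(4.0 * half @ cov_end @ half + sigma**4 * eye)
    C = 0.5 * (half @ D @ half_inv - sigma**2 * eye)
    cross = C + C.T + sigma**2 * eye
    return (1 - t) ** 2 * cov_start + t**2 * cov_end + t * (1 - t) * cross

cov_start = np.array([[2.0, 0.5], [0.5, 1.0]])
cov_end = np.array([[1.0, -0.3], [-0.3, 1.5]])

mid = gsb_covariance(cov_start, cov_end, sigma=1.0, t=0.5)
print(mid)
```

In one dimension the $\sigma^2$ terms cancel and the path reduces to $(1-t)^2 s_0 + t^2 s_1 + t(1-t)\sqrt{4 s_0 s_1 + \sigma^4}$, which is a convenient hand check.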

---

# **7. Closed-Form Mean Evolution**

General formula:
$$
\mu_t = \bar r_t \mu_0 + r_t \mu_1 + \zeta(t) - r_t \zeta(1)
$$

Where:

* $r_t, \bar r_t$ depend on $c_t$ (reference SDE dynamics)
* $\zeta(t)$ depends on drift $\alpha_t$

If $\alpha_t = 0$:
$$
\mu_t = \bar r_t \mu_0 + r_t \mu_1
$$
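
The cheat sheet does not spell out $r_t$ and $\bar r_t$. As a hedged illustration for the $\alpha_t = 0$ case, the sketch below assumes a reference with constant $c_t = c$ and constant $g_t$, and uses the textbook Ornstein-Uhlenbeck bridge weights $r_t = \sinh(ct)/\sinh(c)$ and $\bar r_t = \sinh(c(1-t))/\sinh(c)$. These expressions are brought in from the standard OU bridge rather than from this document, and they reduce to $t$ and $1-t$ for a Brownian reference ($c \to 0$).

```python
import math

def ou_bridge_weights(c, t):
    """Return (r_bar_t, r_t) for a constant-coefficient reference dY = c Y dt + g dW.

    Assumption: standard OU bridge weights; c -> 0 gives (1 - t, t).
    """
    if abs(c) < 1e-8:
        return 1.0 - t, t
    denom = math.sinh(c)
    return math.sinh(c * (1.0 - t)) / denom, math.sinh(c * t) / denom

def gsb_mean(mu0, mu1, c, t):
    """Mean path mu_t = r_bar_t * mu0 + r_t * mu1 when alpha_t = 0."""
    r_bar, r = ou_bridge_weights(c, t)
    return r_bar * mu0 + r * mu1

for c in (0.0, -1.0, -3.0):
    print(c, [round(gsb_mean(0.0, 2.0, c, t), 4) for t in (0.0, 0.25, 0.5, 1.0)])
```

Whatever the value of $c$, the weights hit $(1, 0)$ at $t=0$ and $(0, 1)$ at $t=1$, so the endpoint means are matched exactly.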

---

# **8. Full Closed-Form Drift (Theorem 3)**

GSB satisfies SDE:
$$
dX_t = f_N(t,x)\,dt + g_t\,dW_t
$$

### **Drift:**

$$
f_N(t,x)
= S_t^\top \Sigma_t^{-1}(x - \mu_t) + \dot{\mu}_t
$$

Where $S_t$ is explicitly:
$$
S_t = P_t - Q_t^\top + (c_t \kappa(t,t)(1-\rho_t) - g_t^2 \rho_t) I
$$
with all terms from reference SDE.

**Important:**
$$
S_t^\top \Sigma_t^{-1} \text{ is symmetric (proven)}
$$
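
The general expression for $S_t$ needs quantities that this cheat sheet does not define, so the sketch below only checks the drift structure in one dimension with the Brownian reference $\sigma W_t$. It assumes the scalar variance path $(1-t)^2 s_0 + t^2 s_1 + t(1-t)\sqrt{4 s_0 s_1 + \sigma^4}$ and uses the moment equation $\dot\Sigma_t = 2 a_t \Sigma_t + \sigma^2$ of a linear SDE to solve for the slope $a_t$, which plays the role of $S_t^\top \Sigma_t^{-1}$. All function and variable names are illustrative.

```python
import math

def variance_and_rate(var_start, var_end, sigma, t):
    """Scalar Sigma_t and its time derivative for the sigma * W_t reference."""
    d = math.sqrt(4.0 * var_start * var_end + sigma**4)
    var_t = (1 - t) ** 2 * var_start + t**2 * var_end + t * (1 - t) * d
    dvar_t = -2 * (1 - t) * var_start + 2 * t * var_end + (1 - 2 * t) * d
    return var_t, dvar_t

def drift_1d(t, x, mu0, mu1, var_start, var_end, sigma):
    """Linear drift slope * (x - mu_t) + d(mu_t)/dt with mu_t = (1 - t) mu0 + t mu1."""
    var_t, dvar_t = variance_and_rate(var_start, var_end, sigma, t)
    slope = (dvar_t - sigma**2) / (2.0 * var_t)
    mu_t = (1 - t) * mu0 + t * mu1
    return slope * (x - mu_t) + (mu1 - mu0)

# Check: Euler-integrate the mean and variance ODEs implied by this drift.
mu0, mu1, var_start, var_end, sigma = -1.0, 2.0, 1.0, 4.0, 1.0
n_steps = 2000
dt = 1.0 / n_steps
mean, var = mu0, var_start
for k in range(n_steps):
    t = k * dt
    var_t, dvar_t = variance_and_rate(var_start, var_end, sigma, t)
    slope = (dvar_t - sigma**2) / (2.0 * var_t)
    mean += drift_1d(t, mean, mu0, mu1, var_start, var_end, sigma) * dt
    var += (2.0 * slope * var + sigma**2) * dt
print(mean, var)   # should land close to (mu1, var_end)
```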

---

# **9. Conditional Distributions (Corollary 1)**

$$
X_t \mid X_0=x_0 \sim \mathcal{N}(\mu_{t|0}, \Sigma_{t|0})
$$

Both mean and covariance have closed forms.
Similarly for $X_t \mid X_1$.

Used for:

* Forward & backward SB sampling
* Training GSBFLOW
* Exact data pair generation

---

# **10. Examples (Table 1)**

Plugging different $(c_t, \alpha_t, g_t)$ yields GSBs for:

* **VE (variance-exploding)**
* **VP (variance-preserving)**
* **sub-VP**
* **OU**
* **BDT**
* **Brownian motion**

One unified formula → many diffusion models.
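
As a rough guide to what "plugging in" means, the sketch below writes out $(c_t, \alpha_t, g_t)$ for several of these references, using the parameterizations commonly found in the score-based diffusion literature. The schedules (linear $\beta_t$, geometric $\sigma_t$) and default constants are assumptions chosen for illustration and may differ from the entries in Table 1; BDT is left out because its coefficients are not given here.

```python
import math

def brownian(sigma=1.0):
    """dY = sigma dW."""
    return (lambda t: 0.0, lambda t: 0.0, lambda t: sigma)

def ornstein_uhlenbeck(theta=1.0, mean=0.0, sigma=1.0):
    """dY = theta * (mean - Y) dt + sigma dW."""
    return (lambda t: -theta, lambda t: theta * mean, lambda t: sigma)

def vp(beta_min=0.1, beta_max=20.0):
    """Variance-preserving: dY = -0.5 beta_t Y dt + sqrt(beta_t) dW."""
    beta = lambda t: beta_min + t * (beta_max - beta_min)
    return (lambda t: -0.5 * beta(t), lambda t: 0.0, lambda t: math.sqrt(beta(t)))

def sub_vp(beta_min=0.1, beta_max=20.0):
    """sub-VP: same drift as VP, diffusion damped by 1 - exp(-2 * int_0^t beta)."""
    beta = lambda t: beta_min + t * (beta_max - beta_min)
    int_beta = lambda t: beta_min * t + 0.5 * t**2 * (beta_max - beta_min)
    g = lambda t: math.sqrt(beta(t) * (1.0 - math.exp(-2.0 * int_beta(t))))
    return (lambda t: -0.5 * beta(t), lambda t: 0.0, g)

def ve(sigma_min=0.01, sigma_max=50.0):
    """Variance-exploding: g_t^2 = d(sigma_t^2)/dt with geometric sigma_t."""
    ratio = sigma_max / sigma_min
    g = lambda t: sigma_min * ratio**t * math.sqrt(2.0 * math.log(ratio))
    return (lambda t: 0.0, lambda t: 0.0, g)

for name, (c, alpha, g) in {"BM": brownian(), "OU": ornstein_uhlenbeck(),
                            "VP": vp(), "sub-VP": sub_vp(), "VE": ve()}.items():
    print(name, c(0.5), alpha(0.5), round(g(0.5), 4))
```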

---

# **11. GSBFLOW Algorithm**

GSBFLOW trains a neural drift model $f_\theta(t,x)$ by **matching**:

$$
f_\theta(t,X_t) \approx f_N(t,X_t)
$$

using **exact closed-form targets**, not noisy score estimates.

Training pairs:
$$
(X_t,\; f_N(t,X_t))
$$

Benefits:

* No backward SDE needed
* No score matching / denoising objective
* Stable gradients
* Strong empirical performance

---



The notebook below:

* Trains a GSBFLOW-style model on MNIST
* Uses a Brownian bridge Schrödinger bridge between:
  * prior: Gaussian noise
  * data: MNIST images
* Trains a CNN drift network by drift matching

---

# **📓 GSBFLOW-Style MNIST Training: Walkthrough**

---

# **0. Setup**

```python
!pip install torch torchvision --quiet
```

---

# **1. Imports & Configuration**

In [11]:
import math
import os
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchvision
import torchvision.transforms as T

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

---

# **2. Load MNIST Dataset**

In [12]:
# Training hyperparameters
BATCH_SIZE = 128       # number of images per batch
NUM_EPOCHS = 5         # number of passes over the training data (increase for better results)
LR = 2e-4              # learning rate for optimizer
SIGMA = 0.5            # Brownian noise scale for the bridge
NUM_TRAIN_STEPS = None # if not None, you can limit total training steps (debugging)

# Sampling hyperparameters (for generating images after training)
NUM_SAMPLES = 16       # how many images to sample
NUM_STEPS_SAMPLE = 50  # number of Euler steps for solving the ODE during sampling

In [19]:
DATA_DIR = "./data_mnist"
# =====================================
# 2. DATASET AND PREPROCESSING SECTION
# =====================================

def get_mnist_dataloader(batch_size: int) -> DataLoader:
    """
    Create a PyTorch DataLoader for MNIST with simple preprocessing.
    """
    # Compose a list of transformations applied to each MNIST image:
    # 1. Convert PIL image to tensor (C x H x W, values in [0, 1])
    # 2. Normalize to approximately [-1, 1] (common for generative models)
    transform = T.Compose([
        T.ToTensor(),                      # convert image to tensor
        T.Normalize((0.5,), (0.5,)),       # scale from [0,1] -> roughly [-1,1]
    ])

    # Download / load MNIST training set
    train_dataset = torchvision.datasets.MNIST(
        root=DATA_DIR,       # directory to store data
        train=True,          # use training split
        download=True,       # if data not present, download it
        transform=transform  # apply transformation defined above
    )

    # Wrap dataset in a DataLoader to handle batching and shuffling
    train_loader = DataLoader(
        train_dataset,           # dataset to load from
        batch_size=batch_size,   # how many samples per batch
        shuffle=True,            # shuffle data every epoch
        num_workers=4,           # worker processes for loading data
        pin_memory=True          # speed: pin memory (good when using GPU)
    )

    # Return the DataLoader object
    return train_loader


---

# **3. Time Embedding Module**

This embeds the scalar time $t\in[0,1]$ into a vector the network can use.

In [14]:
# ================================
# 3. DRIFT NETWORK ARCHITECTURE
# ================================

class TimeEmbedding(nn.Module):
    """
    Simple time embedding: embed scalar t into a higher-dimensional vector
    using sinusoidal embeddings (like in transformers / diffusion models).
    """
    def __init__(self, dim: int):
        # Call parent constructor
        super().__init__()
        # Store the embedding dimension
        self.dim = dim

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """
        Input:
          t: shape (batch,) with values in [0,1]
        Output:
          embedding: shape (batch, dim)
        """
        # Compute half dimension (we'll use sin and cos pairs)
        half_dim = self.dim // 2
        # Create frequency exponents linearly spaced
        # (using float dtype and device same as t)
        freqs = torch.linspace(0, half_dim - 1, half_dim, device=t.device)
        # Scale frequencies so they span multiple periods
        freqs = math.log(10000.0) * freqs / (half_dim - 1)
        # Compute exponentials of frequencies
        freqs = torch.exp(freqs)  # shape (half_dim,)

        # Reshape t to (batch, 1) to broadcast with freqs
        t = t.unsqueeze(1)  # (batch, 1)

        # Compute argument of sin/cos: shape (batch, half_dim)
        args = t * freqs

        # Compute sin and cos embeddings
        sin_emb = torch.sin(args)
        cos_emb = torch.cos(args)

        # Concatenate sin and cos along feature dimension -> (batch, dim)
        emb = torch.cat([sin_emb, cos_emb], dim=1)

        # Return final embedding
        return emb
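
For a quick look at what this embedding produces without loading PyTorch, here is a NumPy mirror of `TimeEmbedding.forward`. The function name `time_embedding_np` is made up for this sketch; the frequency construction follows the cell above, so the frequencies grow geometrically from 1 to 10000.

```python
import math
import numpy as np

def time_embedding_np(t, dim):
    """NumPy mirror of the sinusoidal embedding: t has shape (batch,)."""
    half = dim // 2
    k = np.linspace(0, half - 1, half)
    freqs = np.exp(math.log(10000.0) * k / (half - 1))   # 1 ... 10000
    args = t[:, None] * freqs[None, :]                   # (batch, half)
    return np.concatenate([np.sin(args), np.cos(args)], axis=1)

emb = time_embedding_np(np.array([0.0, 0.5, 1.0]), dim=8)
print(emb.shape)   # (3, 8)
print(emb[0])      # at t = 0: sin part is 0, cos part is 1
```

Two details follow from the construction: an odd `dim` yields `dim - 1` output features, and `dim` must be at least 4 so that `half - 1` is nonzero.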

---

# **4. Simple U-Net Drift Network**

This predicts the drift $ f_\theta(t, x_t) $.

In [15]:
class SimpleUNet(nn.Module):
    """
    A very simple U-Net-like convolutional network that predicts
    the drift f_theta(t, x) given:
      - current image x_t (noisy / interpolated)
      - scalar time t
    The network outputs a tensor with same shape as x_t (drift vector field).
    """
    def __init__(self, time_dim: int = 128):
        # Call parent constructor
        super().__init__()

        # Store time embedding dimension
        self.time_dim = time_dim

        # Create time embedding module
        self.time_mlp = nn.Sequential(
            TimeEmbedding(time_dim),       # sinusoidal time embedding
            nn.Linear(time_dim, time_dim), # linear layer
            nn.SiLU(),                    # nonlinearity
        )

        # Number of input channels for MNIST images (1 grayscale)
        in_channels = 1

        # Define the first convolution block (stride 1, keeps the 28x28 resolution)
        self.conv_down1 = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),  # conv 1
            nn.GroupNorm(4, 32),                                   # normalize
            nn.SiLU(),                                             # nonlinearity
        )

        # Second convolution block (downsampling)
        self.conv_down2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # conv with stride=2 halves resolution
            nn.GroupNorm(8, 64),
            nn.SiLU(),
        )

        # Third convolution block (more features, same spatial size)
        self.conv_mid = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.GroupNorm(8, 64),
            nn.SiLU(),
        )

        # Time embedding projection to match mid-channel dimension
        self.time_to_mid = nn.Linear(time_dim, 64)

        # Upsampling layer to recover original spatial resolution
        self.upsample = nn.ConvTranspose2d(
            64, 32, kernel_size=4, stride=2, padding=1
        )

        # Final convolution block to map back to 1 channel
        self.conv_out = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.GroupNorm(4, 32),
            nn.SiLU(),
            nn.Conv2d(32, 1, kernel_size=3, padding=1),
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the drift network.
        Inputs:
          x: (batch, 1, 28, 28) current image state X_t
          t: (batch,) times in [0,1]
        Output:
          drift: (batch, 1, 28, 28)
        """
        # Pass t through time MLP to get time embedding (batch, time_dim)
        t_emb = self.time_mlp(t)  # (B, time_dim)

        # First conv block (no stride) -> keep 28x28 size, increase channels to 32
        h1 = self.conv_down1(x)  # (B, 32, 28, 28)

        # Second conv block with stride 2 -> 14x14, 64 channels
        h2 = self.conv_down2(h1)  # (B, 64, 14, 14)

        # Broadcast time embedding to spatial feature map:
        # First project time embedding to 64 channels
        t_mid = self.time_to_mid(t_emb)  # (B, 64)
        # Reshape to (B, 64, 1, 1) so it can be added to feature map
        t_mid = t_mid[:, :, None, None]  # (B, 64, 1, 1)

        # Add time conditioning to middle feature map (broadcast over H,W)
        h_mid = self.conv_mid(h2 + t_mid)  # (B, 64, 14, 14)

        # Upsample back to 28x28 and reduce channels to 32
        h_up = self.upsample(h_mid)  # (B, 32, 28, 28)

        # Final conv block to get output drift of shape (B, 1, 28, 28)
        out = self.conv_out(h_up)  # (B, 1, 28, 28)

        # Return predicted drift
        return out
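
The spatial sizes quoted in the comments can be checked with the usual output-size formulas for convolutions and transposed convolutions (dilation 1 and no output padding assumed). The helper names below are illustrative.

```python
def conv2d_out(size, kernel, stride=1, padding=0):
    """Output size of a Conv2d along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1

def conv_transpose2d_out(size, kernel, stride=1, padding=0):
    """Output size of a ConvTranspose2d along one spatial dimension."""
    return (size - 1) * stride - 2 * padding + kernel

sizes = [28]
sizes.append(conv2d_out(sizes[-1], kernel=3, stride=1, padding=1))            # conv_down1
sizes.append(conv2d_out(sizes[-1], kernel=3, stride=2, padding=1))            # conv_down2
sizes.append(conv2d_out(sizes[-1], kernel=3, stride=1, padding=1))            # conv_mid
sizes.append(conv_transpose2d_out(sizes[-1], kernel=4, stride=2, padding=1))  # upsample
print(sizes)   # [28, 28, 14, 14, 28]
```

The network has a single down/up stage and no skip connection from `h1` to the decoder, so it is closer to a small encoder-decoder than to a full U-Net.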

---

# **5. Brownian Bridge Schrödinger Bridge**

This is the analytic SB between two points $x_0 \to x_1$:

$$
X_t = (1-t)x_0 + t x_1 + \sigma\sqrt{t(1-t)}\,\epsilon
$$

The drift is:

$$
f^*(t, x_t) = \frac{x_1 - x_t}{1 - t}.
$$

In [16]:
# =========================================
# 4. BROWNIAN BRIDGE SCHRÖDINGER BRIDGE
# =========================================

def sample_brownian_bridge(
    x0: torch.Tensor,
    x1: torch.Tensor,
    t: torch.Tensor,
    sigma: float
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Sample a Brownian bridge state X_t between x0 and x1.

    Brownian bridge between two points (x0, x1) over [0,1] has:
      E[X_t] = (1 - t) * x0 + t * x1
      Var[X_t] = sigma^2 * t * (1 - t)

    We sample:
      X_t = (1 - t) * x0 + t * x1 + sigma * sqrt(t(1-t)) * eps

    We also return the *exact* conditional drift:
      f*(t, x_t | x1) = (x1 - x_t) / (1 - t)
    which is the drift of a Brownian bridge SDE:
      dX_t = (x1 - X_t) / (1 - t) dt + sigma dW_t

    Inputs:
      x0: (B, C, H, W) starting noise
      x1: (B, C, H, W) target data
      t:  (B,) times
      sigma: noise scale

    Outputs:
      x_t:        (B, C, H, W) sampled intermediate state
      drift_star: (B, C, H, W) analytic SB drift at (t, x_t)
    """
    # Ensure t has shape (B, 1, 1, 1) for broadcasting over image dimensions
    t_img = t.view(-1, 1, 1, 1)  # (B,1,1,1)

    # Compute deterministic mean part of Brownian bridge at time t:
    # (1 - t) * x0 + t * x1
    mean_t = (1.0 - t_img) * x0 + t_img * x1

    # Compute standard deviation for Brownian bridge at time t:
    # std = sigma * sqrt(t * (1 - t))
    std_t = sigma * torch.sqrt(torch.clamp(t * (1.0 - t), min=1e-5))  # (B,)
    # Reshape std_t for broadcasting
    std_img = std_t.view(-1, 1, 1, 1)  # (B,1,1,1)

    # Sample standard normal noise with same shape as x0
    eps = torch.randn_like(x0)  # (B,C,H,W)

    # Sample X_t = mean_t + std_img * eps
    x_t = mean_t + std_img * eps  # (B,C,H,W)

    # Compute analytic Brownian bridge drift:
    # f*(t,x_t) = (x1 - x_t) / (1 - t)
    # Avoid division by 0 by adding a small epsilon
    eps_denom = 1e-5
    denom = (1.0 - t_img) + eps_denom  # (B,1,1,1)
    drift_star = (x1 - x_t) / denom  # (B,C,H,W)

    # Return sampled state and analytic drift
    return x_t, drift_star
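
A scalar NumPy sketch of the same sampling step can be used to check the bridge statistics empirically. It fixes one pair $(x_0, x_1)$ and one time $t$, draws many samples, and compares against $E[X_t] = (1-t)x_0 + t x_1$, $\mathrm{Var}[X_t] = \sigma^2 t(1-t)$, and the fact that the drift target averages to $x_1 - x_0$. The name `sample_bridge_np` and the numbers are illustrative, and the clamping used in the cell above is omitted because $t$ is kept away from 0 and 1.

```python
import numpy as np

def sample_bridge_np(x0, x1, t, sigma, rng):
    """Sample X_t on the Brownian bridge and return it with the drift target."""
    mean = (1.0 - t) * x0 + t * x1
    std = sigma * np.sqrt(t * (1.0 - t))
    x_t = mean + std * rng.standard_normal(np.shape(x0))
    drift = (x1 - x_t) / (1.0 - t)
    return x_t, drift

rng = np.random.default_rng(0)
x0 = np.full(50_000, -1.0)   # the same start point repeated many times
x1 = np.full(50_000, 2.0)    # the same end point repeated many times
t, sigma = 0.3, 0.5

x_t, drift = sample_bridge_np(x0, x1, t, sigma, rng)
# Expected: mean near -0.1, variance near 0.0525, average drift near 3.0
print(x_t.mean(), x_t.var(), drift.mean())
```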

---

# **6. Training Loop (GSBFLOW-Style Drift Matching)**

In [17]:
# =====================================
# 5. TRAINING LOOP FOR GSBFLOW-STYLE
# =====================================

def train_gsbflow_mnist():
    """
    Main training function for the GSBFLOW-style Brownian bridge model on MNIST.
    """
    # Get MNIST training DataLoader
    train_loader = get_mnist_dataloader(BATCH_SIZE)

    # Create an instance of the drift network and move it to the chosen device
    model = SimpleUNet(time_dim=128).to(device)

    # Use Adam optimizer to train model parameters
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)

    # Put model in training mode (enables e.g., dropout, batchnorm updates)
    model.train()

    # Initialize global step counter (for optional stopping condition)
    global_step = 0

    # Loop over epochs
    for epoch in range(NUM_EPOCHS):
        # Loop over batches from the DataLoader
        for batch_idx, (x1, _) in enumerate(train_loader):
            # Move batch of images to the device (GPU/CPU)
            x1 = x1.to(device)  # (B,1,28,28)

            # Get batch size (may be smaller on last batch)
            B = x1.size(0)

            # Sample Gaussian noise x0 as starting distribution
            # Same shape as x1, from N(0, I)
            x0 = torch.randn_like(x1)  # (B,1,28,28)

            # Sample times t uniformly in (0,1)
            # Shape: (B,)
            t = torch.rand(B, device=device)

            # Sample Brownian bridge state X_t and analytic drift at that state
            x_t, drift_star = sample_brownian_bridge(
                x0=x0, x1=x1, t=t, sigma=SIGMA
            )

            # Predict drift using our neural network model
            # model expects x_t and t as inputs
            drift_pred = model(x_t, t)  # (B,1,28,28)

            # Compute mean-squared error between predicted drift and analytic drift
            loss = F.mse_loss(drift_pred, drift_star)

            # Zero gradients from previous iteration
            optimizer.zero_grad()

            # Backpropagate gradients from loss
            loss.backward()

            # Take an optimization step to update model parameters
            optimizer.step()

            # Increase global step counter
            global_step += 1

            # Optionally, print training progress every N steps
            if global_step % 100 == 0:
                print(
                    f"Epoch [{epoch+1}/{NUM_EPOCHS}] "
                    f"Batch [{batch_idx+1}/{len(train_loader)}] "
                    f"Step {global_step} "
                    f"Loss: {loss.item():.4f}"
                )

            # Optional: allow early stopping after some number of steps (debug)
            if NUM_TRAIN_STEPS is not None and global_step >= NUM_TRAIN_STEPS:
                break

        # If early stopping condition triggered, break outer epoch loop
        if NUM_TRAIN_STEPS is not None and global_step >= NUM_TRAIN_STEPS:
            break

    # After training completes, return the trained model
    return model


In [20]:
# Train the GSBFLOW-style Brownian bridge model on MNIST
model = train_gsbflow_mnist()

Epoch [1/5] Batch [100/469] Step 100 Loss: 2.7460
Epoch [1/5] Batch [200/469] Step 200 Loss: 1.3131
Epoch [1/5] Batch [300/469] Step 300 Loss: 3.3786
Epoch [1/5] Batch [400/469] Step 400 Loss: 0.6795
Epoch [2/5] Batch [31/469] Step 500 Loss: 0.5843
Epoch [2/5] Batch [131/469] Step 600 Loss: 0.5745
Epoch [2/5] Batch [231/469] Step 700 Loss: 0.6986
Epoch [2/5] Batch [331/469] Step 800 Loss: 0.6274
Epoch [2/5] Batch [431/469] Step 900 Loss: 0.5580
Epoch [3/5] Batch [62/469] Step 1000 Loss: 0.8079
Epoch [3/5] Batch [162/469] Step 1100 Loss: 0.8380
Epoch [3/5] Batch [262/469] Step 1200 Loss: 0.6169
Epoch [3/5] Batch [362/469] Step 1300 Loss: 0.5342
Epoch [3/5] Batch [462/469] Step 1400 Loss: 0.6541
Epoch [4/5] Batch [93/469] Step 1500 Loss: 0.7985
Epoch [4/5] Batch [193/469] Step 1600 Loss: 0.5385
Epoch [4/5] Batch [293/469] Step 1700 Loss: 0.9203
Epoch [4/5] Batch [393/469] Step 1800 Loss: 0.6311
Epoch [5/5] Batch [24/469] Step 1900 Loss: 0.6272
Epoch [5/5] Batch [124/469] Step 2000 Loss: 
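
The occasional jumps in the logged loss (for example 3.3786 at step 300) may partly reflect the regression target itself: substituting the bridge sample into $(x_1 - X_t)/(1-t)$ gives $(x_1 - x_0) - \sigma\sqrt{t/(1-t)}\,\epsilon$, whose noise variance $\sigma^2 t/(1-t)$ is unbounded as $t \to 1$. The sketch below tabulates that variance and shows one possible mitigation, a truncated time sampler. `sample_times` and its `t_max` parameter are hypothetical and not part of the notebook.

```python
import random

def target_noise_variance(t, sigma):
    """Per-coordinate variance of the noise in (x1 - X_t) / (1 - t)."""
    return sigma**2 * t / (1.0 - t)

def sample_times(batch_size, t_max=0.99, rng=random):
    """Hypothetical sampler: draw t uniformly from [0, t_max] instead of [0, 1)."""
    return [t_max * rng.random() for _ in range(batch_size)]

for t in (0.5, 0.9, 0.99, 0.999):
    print(t, target_noise_variance(t, sigma=0.5))

times = sample_times(128)
print(max(times))
```

Capping $t$ trades a small bias near the data end of the bridge for a bounded target variance; whether that helps in this setup would need to be checked empirically.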

---

# **7. Sampling from the Trained Model**

We integrate the learned drift as a deterministic ODE, dropping the Brownian noise term $\sigma\,dW_t$ for simplicity:

$$
dX_t = f_\theta(t, X_t)\,dt.
$$

In [21]:
# ======================================
# 6. SAMPLING / GENERATION FROM MODEL
# ======================================

@torch.no_grad()  # disable gradient computation in sampling
def sample_from_model(
    model: nn.Module,
    num_samples: int = NUM_SAMPLES,
    num_steps: int = NUM_STEPS_SAMPLE,
    sigma: float = SIGMA
) -> torch.Tensor:
    """
    Sample images from the trained GSBFLOW-style model by solving
    the forward ODE:
      dX_t = f_theta(t, X_t) dt

    We ignore Brownian noise in sampling for simplicity (deterministic flow).
    This is common: flow-matching often uses an ODE for sampling.

    Inputs:
      model: trained drift network
      num_samples: how many images to generate
      num_steps: number of Euler steps used in integration
      sigma: (unused here; kept for extension / reference)

    Output:
      samples: (num_samples,1,28,28) tensor of generated images in [-1,1]
    """
    # Put model in eval mode (disables dropout, etc.)
    model.eval()

    # Start from standard Gaussian noise as initial samples X_0
    x = torch.randn(num_samples, 1, 28, 28, device=device)

    # Create time grid for [0,1]
    # linspace from 0 to 1, inclusive
    t_grid = torch.linspace(0.0, 1.0, num_steps + 1, device=device)

    # Compute step size (delta t) assuming uniform steps
    dt = t_grid[1] - t_grid[0]

    # Loop over time steps from t_0 to t_{num_steps-1}
    for i in range(num_steps):
        # Current time t_i (scalar)
        t = t_grid[i]  # scalar tensor

        # Expand time to shape (batch,) for model input
        t_batch = torch.full((num_samples,), t.item(), device=device)

        # Predict drift at current state and time
        drift = model(x, t_batch)  # (B,1,28,28)

        # Euler update: X_{t+dt} = X_t + drift * dt
        x = x + drift * dt

    # After final step, x is an approximation of X_1 (data sample)
    # Clamp values to [-1,1] for sanity (since training images were normalized)
    x = torch.clamp(x, -1.0, 1.0)

    # Return generated samples
    return x
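
The same Euler loop can be exercised on a scalar toy problem where the answer is known. The sketch below replaces the network with the exact single-target bridge drift $(x_1 - x)/(1-t)$; with a uniform grid, the final step has $1 - t = dt$ and lands on $x_1$ up to floating-point error. The function name `euler_integrate` and the numbers are illustrative.

```python
def euler_integrate(drift, x0, num_steps):
    """Explicit Euler on [0, 1]: x <- x + drift(t, x) * dt, evaluated at left endpoints."""
    dt = 1.0 / num_steps
    x = x0
    for i in range(num_steps):
        t = i * dt
        x = x + drift(t, x) * dt
    return x

target = 2.0

def bridge_drift(t, x):
    """Exact drift toward a single known endpoint (stand-in for the network)."""
    return (target - x) / (1.0 - t)

x_end = euler_integrate(bridge_drift, x0=-1.0, num_steps=50)
print(x_end)
```

The drift is never evaluated at $t = 1$ because the loop uses left endpoints, which is also why the sampler above avoids the singularity of the training target.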

---

# **8. Visualize Samples**

In [22]:
# After training, sample generated images from the model
samples = sample_from_model(model, num_samples=NUM_SAMPLES)

# Optionally, save a grid of sampled images to disk for visualization
os.makedirs("samples", exist_ok=True)
# Denormalize from [-1,1] back to [0,1] for saving as PNG
samples_vis = (samples + 1.0) / 2.0
# Use torchvision to save a grid image
torchvision.utils.save_image(
    samples_vis,
    fp="samples/gsbflow_mnist_samples.png",
    nrow=int(math.sqrt(NUM_SAMPLES)),
)
print("Saved generated samples to samples/gsbflow_mnist_samples.png")

Saved generated samples to samples/gsbflow_mnist_samples.png
