# A Minimal Toy Image Generator

This file builds a basic toy-data image generator using a flow-matching method.

The file is adapted from the [MIT diffusion model course](https://diffusion.csail.mit.edu/); the accompanying repository is [here](https://github.com/eje24/iap-diffusion-labs).

## Overview

This notebook implements a complete toy framework for flow matching, designed to visualize and test the behavior of learned vector fields between simple source and target distributions.

<br>

### Part 1. Sampleable Distributions

We define four configurable 2D distributions from which samples can be drawn:

1. **Stretched Gaussian**: a 2D Gaussian whose covariance matrix is not diagonal.
2. **Moons**: a 2D crescent-shaped dataset.
3. **Checkerboard**: a discrete alternating square pattern.
4. **Circles**: a uniform distribution over a 2D circular region.

Each batch is flattened to a tensor of shape `(batch_size, dim)`, where `dim = channels * width * height`.

The images are stored in the `data` folder, where you can inspect them.

Part 1 is only preparation: it loads the data.

<br>

### Part 2. Vector Field and Simulators

Two numerical solvers are implemented to simulate trajectories based on learned vector fields:

1. **Euler Method** — first-order integrator.  
2. **Heun’s Method** — second-order integrator (improved Euler scheme).

<br>
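
To make the difference between the two solvers concrete, here is a minimal sketch in plain Python. It is not the notebook's `Simulator` classes: the scalar test problem $\dot{x} = -x$ with $x(0) = 1$ and the helper names `euler_step`, `heun_step`, and `simulate` are assumptions chosen for illustration. The exact solution at $t = 1$ is $e^{-1}$, so the error of each method can be measured directly.

```python
import math

def euler_step(f, x, t, h):
    # First-order update: follow the slope at the start of the interval.
    return x + h * f(x, t)

def heun_step(f, x, t, h):
    # Second-order update: average the slope at the start with the slope
    # at the Euler-predicted end point.
    k1 = f(x, t)
    x_euler = x + h * k1
    k2 = f(x_euler, t + h)
    return x + 0.5 * h * (k1 + k2)

def simulate(step, f, x0, num_steps):
    # Integrate from t = 0 to t = 1 with a uniform step size.
    h = 1.0 / num_steps
    x, t = x0, 0.0
    for _ in range(num_steps):
        x = step(f, x, t, h)
        t += h
    return x

def decay(x, t):
    # Hypothetical test field: dx/dt = -x, exact solution x(t) = exp(-t).
    return -x

exact = math.exp(-1.0)
euler_error = abs(simulate(euler_step, decay, 1.0, 10) - exact)
heun_error = abs(simulate(heun_step, decay, 1.0, 10) - exact)
print(f"Euler error: {euler_error:.5f}, Heun error: {heun_error:.5f}")
```

With the same number of steps, Heun's method costs two field evaluations per step instead of one, which matters when each evaluation is a forward pass through a network.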

### Part 3. Alpha/Beta and Conditional Vector Fields

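We define Gaussian conditional vector fields in closed form from a pair of schedules $\alpha(t)$ and $\beta(t)$. Two schedule pairs are used:

1. $ \alpha(t) = t,\ \beta(t) = 1 - t $, so that $ \alpha(t) + \beta(t) = 1 $
2. $ \alpha(t) = t,\ \beta(t) = \sqrt{1 - t} $, so that $ \alpha(t) + \beta(t)^2 = 1 $

These schedules determine how the target sample and the source noise are blended along the path, and therefore the velocity the network has to learn.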

<br>

### Part 4. Neural Network Architecture and Training Process

The core of the vector field $ u_\theta(x, t) $ is a Multi-Layer Perceptron (MLP). The hidden structure is customizable.

As for the training process:

1. **Sample from Target and Time**: draw $z$ from the data and draw a random time $t$.
2. **Sample from Source**: draw base samples $ x_0 \sim N(0, I_d) $ and compute $x_t$ from $x_0$, $z$, and $t$.
3. **Learn Vector Field**: train $ u_\theta $ to match the conditional vector field by minimizing the flow matching loss.

<br>
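
The three steps above can be sketched as a single optimizer update. This is a minimal illustration under stated assumptions, not the notebook's `MLPGaussianTrainer`: the 2D stand-in data, the tiny network `net`, the learning rate, and the helper `training_step` are all hypothetical. It uses the linear schedule $\alpha_t = t$, $\beta_t = 1 - t$, for which the conditional velocity at $x_t = t z + (1 - t) x_0$ simplifies to $z - x_0$.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins: 2-D "data" and a tiny network instead of 64x64 images.
dim, batch_size = 2, 64
data = torch.randn(512, dim) * 0.1 + torch.tensor([2.0, -1.0])
net = nn.Sequential(nn.Linear(dim + 1, 32), nn.SiLU(), nn.Linear(32, dim))
optimizer = torch.optim.Adam(net.parameters(), lr=1e-2)

def training_step() -> float:
    # Step 1: sample targets z from the data and times t ~ U[0, 1).
    z = data[torch.randint(0, data.shape[0], (batch_size,))]
    t = torch.rand(batch_size, 1)
    # Step 2: sample x_0 ~ N(0, I) and form x_t on the linear path.
    x0 = torch.randn_like(z)
    xt = t * z + (1 - t) * x0
    # Step 3: regress the network onto the conditional velocity,
    # which equals z - x0 for the linear schedule.
    target = z - x0
    prediction = net(torch.cat([xt, t], dim=-1))
    loss = ((prediction - target) ** 2).sum(dim=-1).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

losses = [training_step() for _ in range(200)]
print(f"mean of first 20 losses: {sum(losses[:20]) / 20:.3f}")
print(f"mean of last 20 losses:  {sum(losses[-20:]) / 20:.3f}")
```

The loss does not go to zero even for a perfect network, because the regression target depends on the unobserved $x_0$; only its conditional mean can be learned.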

### Part 5. Training, Visualization and Evaluation

# Part 0: Basic Preparation

In [None]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image       # Python Imaging Library (PIL / Pillow)
import torchvision.transforms as transforms
from tqdm import tqdm
from abc import ABC, abstractmethod
from typing import List, Dict, Type, Tuple
from torch.func import vmap, jacrev
import torch.nn as nn
import math
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"The device is: {device}")


# Part 1: Prepare Datasets and DataLoader

In [None]:
# Loads grayscale PNG images and returns them as tensors

class ToyImageData(Dataset):
    def __init__(self, root_dir: str, transform = None):
        self.root_dir = root_dir
        self.transform = transform if transform else transforms.Compose([
            transforms.Grayscale(num_output_channels=1),
            transforms.ToTensor(),
        ])
        self.image_files = sorted([
            file for file in os.listdir(self.root_dir)
            if file.endswith('.png')
        ])

    def __len__(self):
        return len(self.image_files)
    
    def __getitem__(self, index):
        image_path = os.path.join(self.root_dir, self.image_files[index])
        image = Image.open(image_path).convert('L')                                 # 'L' mode is 8-bit grayscale
        image = self.transform(image)                                               # convert the image to a tensor
        return image

In [None]:
datasets_checkerboard = ToyImageData("data/checkerboard")
dataloader_checkerboard = DataLoader(
    dataset = datasets_checkerboard,
    batch_size = 128,
    shuffle=True
)

datasets_circles = ToyImageData("data/circles")
dataloader_circles = DataLoader(
    dataset=datasets_circles,
    batch_size=128,
    shuffle=True
)

datasets_moons = ToyImageData("data/moons")
dataloader_moons = DataLoader(
    dataset=datasets_moons,
    batch_size=128,
    shuffle=True
)

datasets_stretched_gaussian = ToyImageData("data/stretched_gaussian")
dataloader_stretched_gaussian = DataLoader(
    dataset=datasets_stretched_gaussian,
    batch_size=128,
    shuffle=True
)

# Part 2: Vector Field and Simulator

In [None]:
class VectorField(ABC):
    @abstractmethod
    def velocity(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # xt: shape(bs, dims), t: shape(bs, 1), returns: shape(bs, dims) where dims = channels * width * height
        # return the velocity at position xt and time t
        pass

class Simulator(ABC):
    @abstractmethod
    def step(self, xt: torch.Tensor, t: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
        # xt: shape(bs, dims), t: shape(bs, 1), h: shape(bs, 1) (t and h are broadcast from scalars of shape ())
        # return the state at t + h
        pass

    def simulate(self, x: torch.Tensor, ts: torch.Tensor) -> torch.Tensor:
        # x: shape(bs, dims), ts: shape(num_ts,)
        # put in the state at ts[0], return the state at ts[-1]
        for index in range(ts.shape[0] - 1):
            t = ts[index].expand(x.shape[0], 1)
            h = (ts[index+1] - ts[index]).expand(x.shape[0], 1)
            x = self.step(x, t, h)

        return x
    
class EulerSimulator(Simulator):
    def __init__(self, vector_field: VectorField):
        self.vector_field = vector_field

    def step(self, xt: torch.Tensor, t: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
        return xt + self.vector_field.velocity(xt, t) * h
    
class HeunSimulator(Simulator):
    def __init__(self, vector_field: VectorField):
        self.vector_field = vector_field

    def step(self, xt: torch.Tensor, t: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
        # predictor: Euler step; corrector: average the velocities at both ends of the interval
        velocity_start = self.vector_field.velocity(xt, t)      # evaluated once and reused
        x_euler = xt + velocity_start * h
        velocity_end = self.vector_field.velocity(x_euler, t + h)
        return xt + 0.5 * (velocity_start + velocity_end) * h

# Part 3: Alpha, Beta and the Conditional Vector Field

#### 1. Gaussian conditional probability path:    

A Gaussian conditional probability path is given by

$$p_t(x|z) = N(x;\alpha_t z,\beta_t^2 I_d),\quad\quad\quad p_{\text{simple}}=N(0,I_d),$$

where $\alpha_t: [0,1] \to \mathbb{R}$ and $\beta_t: [0,1] \to \mathbb{R}$ are monotonic, continuously differentiable functions satisfying $\alpha_1 = \beta_0 = 1$ and $\alpha_0 = \beta_1 = 0$. 

In other words, this implies that $p_1(x|z) = \delta_z$ and $p_0(x|z) = N(0, I_d)$ is a unit Gaussian. Before we dive into things, let's take a look at $p_{\text{simple}}$ and $p_{\text{data}}$. 

Equivalently, a sample from the conditional path can be written as

$$X_{t | z} = \alpha_t z + \beta_t X_0, \quad\quad X_0 \sim N(0, I_d).$$
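
The closed-form velocity returned by `GaussianConditionalVectorField.velocity` follows from differentiating this interpolation in time and eliminating $X_0$. A sketch of the derivation, assuming $\beta_t > 0$ and writing $\dot{\alpha}_t$ and $\dot{\beta}_t$ for the time derivatives:

$$\frac{d}{dt} X_{t|z} = \dot{\alpha}_t z + \dot{\beta}_t X_0, \quad\quad X_0 = \frac{X_{t|z} - \alpha_t z}{\beta_t},$$

so that

$$u_t^{\text{ref}}(x|z) = \left(\dot{\alpha}_t - \frac{\dot{\beta}_t}{\beta_t}\alpha_t\right) z + \frac{\dot{\beta}_t}{\beta_t}\, x.$$

As a check, taking $\alpha_t = t$ and $\beta_t = 1 - t$ gives $\dot{\alpha}_t = 1$ and $\dot{\beta}_t = -1$, and the expression reduces to $u_t^{\text{ref}}(x|z) = \dfrac{z - x}{1 - t}$.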

In this section, we'll use two schedules: the linear schedule

$$\alpha_t = t \quad \quad \text{and} \quad \quad \beta_t = 1-t,$$

and the square-root schedule

$$\alpha_t = t \quad \quad \text{and} \quad \quad \beta_t = \sqrt{1-t}.$$
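
As a quick sanity check on these choices, the sketch below evaluates both $\beta_t$ options in plain Python. The function names are illustrative and separate from the `Alpha` and `Beta` classes that follow. It confirms the boundary conditions, compares the noise level at a few intermediate times, and checks the analytic derivative of $\sqrt{1-t}$ against a central finite difference.

```python
import math

def alpha_linear(t: float) -> float:
    return t

def beta_linear(t: float) -> float:
    return 1.0 - t

def beta_sqrt(t: float) -> float:
    return math.sqrt(1.0 - t)

# Boundary conditions: alpha_0 = beta_1 = 0 and alpha_1 = beta_0 = 1.
for beta in (beta_linear, beta_sqrt):
    assert alpha_linear(0.0) == 0.0 and alpha_linear(1.0) == 1.0
    assert beta(0.0) == 1.0 and beta(1.0) == 0.0

# Noise level beta_t at a few intermediate times.
times = [0.25, 0.5, 0.75]
noise_levels = [(t, beta_linear(t), beta_sqrt(t)) for t in times]
for t, b_lin, b_sqrt in noise_levels:
    print(f"t={t:.2f}  linear beta={b_lin:.3f}  sqrt beta={b_sqrt:.3f}")

# Finite-difference check of d(beta_sqrt)/dt = -0.5 / sqrt(1 - t) at t = 0.5.
t0, step = 0.5, 1e-6
numeric_derivative = (beta_sqrt(t0 + step) - beta_sqrt(t0 - step)) / (2 * step)
analytic_derivative = -0.5 / math.sqrt(1.0 - t0)
print(f"numeric: {numeric_derivative:.6f}, analytic: {analytic_derivative:.6f}")
```

Since $\sqrt{1-t} \ge 1-t$ on $[0, 1]$, the square-root schedule keeps more noise at every intermediate time and removes it more abruptly near $t = 1$, where its derivative is unbounded.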

In [None]:
class Alpha(ABC):
    def __init__(self, eps: float = 1e-8):
        self.eps = eps
        assert torch.allclose(self(torch.zeros(1, 1)), torch.zeros(1, 1), atol=math.sqrt(self.eps))
        assert torch.allclose(self(torch.ones(1, 1)), torch.ones(1, 1))

    @abstractmethod
    def __call__(self, t: torch.Tensor) -> torch.Tensor:
        # t: shape(bs, 1) (in the simplest case, shape(1, 1)); returns the same shape
        pass

    def dt(self, t: torch.Tensor) -> torch.Tensor:
        # t: shape(bs, 1); returns d(alpha)/dt with shape(bs, 1)
        # vmap(jacrev(...)) yields shape(bs, 1, 1), so the extra Jacobian axis is flattened
        return vmap(jacrev(self))(t).reshape(-1, 1)
    
class Beta(ABC):
    def __init__(self, eps: float = 1e-8):
        self.eps = eps
        assert torch.allclose(self(torch.zeros(1, 1)), torch.ones(1, 1))
        assert torch.allclose(self(torch.ones(1, 1)), torch.zeros(1, 1), atol=math.sqrt(self.eps))

    @abstractmethod
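    def __call__(self, t: torch.Tensor) -> torch.Tensor:
        # t: shape(bs, 1); returns shape(bs, 1)
        pass

    def dt(self, t: torch.Tensor) -> torch.Tensor:
        # t: shape(bs, 1); returns d(beta)/dt with shape(bs, 1)
        # vmap(jacrev(...)) yields shape(bs, 1, 1), so the extra Jacobian axis is flattened
        return vmap(jacrev(self))(t).reshape(-1, 1)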
    
class LinearAlpha(Alpha):
    def __init__(self):
        super().__init__()

    def __call__(self, t: torch.Tensor) -> torch.Tensor:
        return t
    
    def dt(self, t: torch.Tensor) -> torch.Tensor:
        return torch.ones_like(t)
    
class LinearBeta(Beta):
    def __init__(self):
        super().__init__()

    def __call__(self, t: torch.Tensor) -> torch.Tensor:
        return 1 - t
    
    def dt(self, t: torch.Tensor) -> torch.Tensor:
        return - torch.ones_like(t)
    
class SquareRootBeta(Beta):
    def __init__(self):
        super().__init__()

    def __call__(self, t: torch.Tensor) -> torch.Tensor:
        return torch.sqrt(torch.clamp(1 - t, min=self.eps))
    
    def dt(self, t: torch.Tensor) -> torch.Tensor:
        return -0.5 / torch.sqrt(torch.clamp(1 - t, min=self.eps))
    
alpha_linear = LinearAlpha()
beta_linear = LinearBeta()
beta_sqrt = SquareRootBeta()

In [None]:
class ConditionalVectorField(ABC):
    @abstractmethod
    def velocity(self, xt: torch.Tensor, t: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        # xt: shape(bs, dims), t: shape(bs, 1), z: shape(bs, dims) where dims = channels * width * height
        # at time t and position xt, given z, return the conditional velocity u_t(xt | z)
        pass

class GaussianConditionalVectorField(ConditionalVectorField):
    def __init__(self, alpha: Alpha, beta: Beta):
        self.alpha = alpha
        self.beta = beta

    def velocity(self, xt: torch.Tensor, t: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        alpha_d_t, beta_d_t = self.alpha.dt(t), self.beta.dt(t)
        alpha_t, beta_t = self.alpha(t), self.beta(t)
        return (alpha_d_t - beta_d_t * alpha_t / beta_t) * z + beta_d_t / beta_t * xt
    
gaussian_conditional_vector_field_linear = GaussianConditionalVectorField(alpha=alpha_linear, beta=beta_linear)
gaussian_conditional_vector_field_sqrt = GaussianConditionalVectorField(alpha=alpha_linear, beta=beta_sqrt)

# Part 4: NN and Train classes

In [None]:
def make_mlp(dims: List[int], activation: Type[nn.Module] = nn.SiLU):
    mlp = []
    for index in range(len(dims) - 1):
        mlp.append(nn.Linear(dims[index], dims[index+1]))
        if index < len(dims) - 2:
            mlp.append(activation())
    return nn.Sequential(*mlp)

class MLPVectorField(nn.Module):
    def __init__(self, dim: int = 1 * 64 * 64, hiddens: List[int] = [2048, 512, 256], activation = None):
        super().__init__()
        self.dim = dim
        self.mlp = make_mlp([dim+1] + hiddens + [dim], activation) if activation else make_mlp([dim+1] + hiddens + [dim])

    def forward(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # xt: shape(bs, dims), t: shape(bs, 1)
        xt_and_t = torch.cat([xt, t], dim=-1)       # shape(bs, dims + 1); avoids shadowing the builtin `input`
        return self.mlp(xt_and_t)
    
mlp_stretched_gaussian = MLPVectorField()
mlp_moons = MLPVectorField()
mlp_circles = MLPVectorField()
mlp_checkerboard = MLPVectorField()

Recall from lecture that our goal is to learn the *marginal vector field* $u_t^{\text{ref}}(x)$ given by 

$$u_t^{\text{ref}}(x) = \mathbb{E}_{z \sim p_t(z|x)}\left[u_t^{\text{ref}}(x|z)\right].$$

Unfortunately, we don't actually know what $u_t^{\text{ref}}(x)$ is! 

We will thus approximate $u_t^{\text{ref}}(x)$ as a neural network $u_t^{\theta}(x)$, and exploit the identity 

$$ u_t^{\text{ref}}(x) = \text{argmin}_{u_t(x)} \,\,\mathbb{E}_{z \sim p_t(z|x)} \left[\lVert u_t(x) - u_t^{\text{ref}}(x|z)\rVert^2\right]$$ 

to obtain the **conditional flow matching objective**

$$ \mathcal{L}_{\text{CFM}}(\theta) = \,\,\mathbb{E}_{z \sim p(z), x \sim p_t(x|z)} \left[\lVert u_t^{\theta}(x) - u_t^{\text{ref}}(x|z)\rVert^2\right].$$

To model $u_t^{\theta}(x)$, we'll use a simple MLP. This network will take in both $x$ and $t$, and will return the learned vector field $u_t^{\theta}(x)$.


We estimate the loss function

$$\mathcal{L}_{\text{CFM}}(\theta) = \,\,\mathbb{E}_{t \sim \mathcal{U}[0,1),\, z \sim p(z),\, x \sim p_t(x|z)} \left[\lVert u_t^{\theta}(x) - u_t^{\text{ref}}(x|z)\rVert^2\right]$$

using a Monte-Carlo estimate of the form

$$\frac{1}{N}\sum_{i=1}^N \lVert u_{t_i}^{\theta}(x_i) - u_{t_i}^{\text{ref}}(x_i|z_i)\rVert^2, \quad \quad \quad \forall i\in\{1, \dots, N\}: \,z_i \sim p_{\text{data}},\, t_i \sim \mathcal{U}[0,1),\, x_i \sim p_{t_i}(\cdot | z_i).$$

Here, $N$ is our *batch size*.
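
To see the estimator at work without a network, the sketch below uses a deliberately degenerate case in plain Python: the data distribution is assumed to be a single point `Z_STAR`, so the marginal vector field coincides with the conditional one. The linear schedule $\alpha_t = t$, $\beta_t = 1 - t$ is assumed, and `Z_STAR`, `sample_triplet`, `cfm_estimate`, `zero_field`, and `exact_field` are hypothetical names for illustration.

```python
import random

random.seed(0)
Z_STAR = (2.0, -1.0)  # hypothetical data distribution: a single point

def sample_triplet():
    # Draw t ~ U[0, 1), x_0 ~ N(0, I) and form x_t = t * z + (1 - t) * x_0.
    t = random.random()
    x0 = tuple(random.gauss(0.0, 1.0) for _ in Z_STAR)
    xt = tuple(t * z + (1.0 - t) * n for z, n in zip(Z_STAR, x0))
    return t, xt, Z_STAR

def u_ref(xt, t, z):
    # Conditional velocity for the linear schedule: (z - x) / (1 - t).
    return tuple((zi - xi) / (1.0 - t) for zi, xi in zip(z, xt))

def cfm_estimate(model, num_samples):
    # Monte-Carlo average of the squared error between model and target.
    total = 0.0
    for _ in range(num_samples):
        t, xt, z = sample_triplet()
        prediction = model(xt, t)
        target = u_ref(xt, t, z)
        total += sum((p - q) ** 2 for p, q in zip(prediction, target))
    return total / num_samples

def zero_field(xt, t):
    # A field that predicts no motion at all.
    return tuple(0.0 for _ in xt)

def exact_field(xt, t):
    # With point-mass data, the marginal field equals the conditional one.
    return u_ref(xt, t, Z_STAR)

loss_zero = cfm_estimate(zero_field, 4096)
loss_exact = cfm_estimate(exact_field, 4096)
print(f"zero field loss:  {loss_zero:.3f}")
print(f"exact field loss: {loss_exact:.3f}")
```

The zero loss for the exact field is special to point-mass data. With a real dataset, several $z$ can explain the same $x_t$, so even the true marginal field leaves a positive residual.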

In [None]:
class MLPGaussianTrainer():
    def __init__(self, u_theta: MLPVectorField, u_cond: GaussianConditionalVectorField, dataloader: DataLoader, save_addr: str = None):
        self.u_theta = u_theta
        self.u_cond = u_cond
        self.dataloader = dataloader
        if save_addr:
            if not save_addr.endswith(".pth"):
                save_addr += ".pth"                 # append the extension together with its dot
        else:
            save_addr = "mlp_parameters.pth"
        self.save_addr = save_addr

    def get_optimizer(self, lr: float = 1e-4):
        return torch.optim.Adam(self.u_theta.parameters(), lr=lr)
    
    def get_loss(self, xt: torch.Tensor, t: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        # xt: shape(bs, dims), t: shape(bs, 1)

        # velocity: shape(bs, dims)
        velocity_cond = self.u_cond.velocity(xt, t, z)
        velocity_mlp = self.u_theta(xt, t)

        errors = torch.sum((velocity_cond - velocity_mlp) ** 2, dim=-1)
        return torch.mean(errors)

    def train(self, device: torch.device, num_epochs: int = 50000, lr: float = 1e-4, loss_image_name: str = None):
        # make preparation
        self.u_theta.to(device)
        self.u_theta.train()
        optimizer = self.get_optimizer(lr=lr)
        losses = []
        
        # train process
        pbr = tqdm(range(num_epochs))
        for epoch in pbr:
            total_loss = 0.0
            for batch in self.dataloader:
                z = batch.to(device).view(batch.shape[0], -1)           # z: shape(bs, dims)

                t = torch.rand(z.shape[0], 1, device=device)            # t: shape(bs, 1), sampled from U[0, 1)

                x_init = torch.randn_like(z)                            # x_init: shape(bs, dims), created on z's device
                xt = self.u_cond.alpha(t) * z + self.u_cond.beta(t) * x_init

                loss = self.get_loss(xt, t, z)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
            pbr.set_description(f"Epoch {epoch}, the loss is: {total_loss}")
            losses.append(total_loss)

        self.u_theta.eval()

        # visualize the loss curve
        plt.plot(losses)
        plt.title("Training Loss Curve")
        plt.xlabel("num_epochs")
        plt.ylabel("loss")
        plt.grid(True)

        if loss_image_name is not None:
            loss_dir = os.path.dirname(loss_image_name)
            if loss_dir:                            # a bare filename has no directory to create
                os.makedirs(loss_dir, exist_ok=True)
            plt.savefig(loss_image_name, bbox_inches='tight', pad_inches=0, dpi=300)
            print(f"Loss curve saved to: {loss_image_name}")
        plt.show()

        # save the parameters
        torch.save(self.u_theta.state_dict(), self.save_addr)

In [None]:
trainer_stretched_gaussian = MLPGaussianTrainer(
    u_theta = mlp_stretched_gaussian,
    u_cond = gaussian_conditional_vector_field_linear,
    dataloader = dataloader_stretched_gaussian,
    save_addr = "to_stretched_gaussian_mlp.pth"
)

trainer_moons = MLPGaussianTrainer(
    u_theta = mlp_moons,
    u_cond = gaussian_conditional_vector_field_linear,
    dataloader = dataloader_moons,
    save_addr = "to_moons_mlp.pth"
)

trainer_circles = MLPGaussianTrainer(
    u_theta = mlp_circles,
    u_cond = gaussian_conditional_vector_field_linear,
    dataloader = dataloader_circles,
    save_addr = "to_circles_mlp.pth"
)

trainer_checkerboard = MLPGaussianTrainer(
    u_theta = mlp_checkerboard,
    u_cond = gaussian_conditional_vector_field_linear,
    dataloader = dataloader_checkerboard,
    save_addr = "to_checkerboard_mlp.pth"
)

In [None]:
class LearnedVectorField(VectorField):
    def __init__(self, model: MLPVectorField):
        self.model = model

    def velocity(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        return self.model(xt, t)

# Part 5: Training, Evaluation and Visualization

In [None]:
def visualize_generated_tensor(tensor: torch.Tensor, title: str = "Generated Image", save_path: str = None):
    """
    - tensor: shape = (4096,) → reshape to (64, 64), and visualize it as grayscale image.
    - save_path: if provided, saves the image to this file path.
    """
    if tensor.ndim != 1 or tensor.shape[0] != 4096:
        raise ValueError(f"Expected tensor of shape (4096,), got {tensor.shape}")

    image = tensor.detach().cpu().view(64, 64)  # reshape
    image = torch.clamp(image, 0.0, 1.0)        # ensure range in [0, 1]

    plt.imshow(image, cmap='gray')
    plt.title(title)
    plt.axis('off')

    if save_path is not None:
        save_dir = os.path.dirname(save_path)
        if save_dir:                                # a bare filename has no directory to create
            os.makedirs(save_dir, exist_ok=True)
        plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
        print(f"Image saved to: {save_path}")

    plt.show()

In [None]:
x_init = torch.randn(1,4096).to(device)
ts = torch.linspace(0.0, 1.0, 500).to(device)

### 1. Moons

In [None]:
trainer_moons.train(device=device, loss_image_name="loss_analysis/moons.png")

vector_field_moons = LearnedVectorField(model=mlp_moons)
euler_moons = EulerSimulator(vector_field=vector_field_moons)

In [None]:
generated_moons = euler_moons.simulate(x_init, ts).squeeze()
visualize_generated_tensor(tensor=generated_moons, title="Moons", save_path="generated/Moons.png")

### 2. Circles

In [None]:
trainer_circles.train(device=device, loss_image_name="loss_analysis/circles.png")

vector_field_circles = LearnedVectorField(model=mlp_circles)
euler_circles = EulerSimulator(vector_field=vector_field_circles)

In [None]:
generated_circles = euler_circles.simulate(x_init, ts).squeeze()
visualize_generated_tensor(tensor=generated_circles, title="Circles", save_path="generated/Circles.png")

### 3. Checkerboard

In [None]:
trainer_checkerboard.train(device=device, loss_image_name="loss_analysis/checkerboard.png")

vector_field_checkerboard = LearnedVectorField(model=mlp_checkerboard)
euler_checkerboard = EulerSimulator(vector_field=vector_field_checkerboard)

In [None]:
generated_checkerboard = euler_checkerboard.simulate(x_init, ts).squeeze()
visualize_generated_tensor(tensor=generated_checkerboard, title="Checkerboard", save_path="generated/Checkerboard.png")

### 4. Stretched Gaussian

In [None]:
trainer_stretched_gaussian.train(device=device, loss_image_name="loss_analysis/stretched_gaussian.png")

vector_field_stretched_gaussian = LearnedVectorField(model=mlp_stretched_gaussian)
euler_stretched_gaussian = EulerSimulator(vector_field=vector_field_stretched_gaussian)

In [None]:
generated_stretched_gaussian = euler_stretched_gaussian.simulate(x_init, ts).squeeze()
visualize_generated_tensor(tensor=generated_stretched_gaussian, title="Stretched Gaussian", save_path="generated/Stretched_Gaussian.png")