## Difference Transition Matching (DTM) vs. Flow Matching (FM) Demo

This notebook is a small demo that shows the difference between DTM and FM in a 2-dimensional example, and reproduces my DTM vs. FM GIF.


![Simulation Result](dtm_vs_fm.gif)

In [None]:
import os
import time
from typing import Optional, List
import io

from tqdm import tqdm
from IPython.display import display, Video, Image

import torch
from torch import nn, Tensor

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colors as mcolors
import imageio.v2 as imageio

# Run on the first CUDA GPU when one is available, otherwise on the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print('Using gpu' if device != 'cpu' else 'Using cpu.')

# Architecture

In [None]:
# Activation class
class Swish(nn.Module):
    """Swish activation: each element is multiplied by its own sigmoid."""

    def forward(self, x: Tensor) -> Tensor:
        """Return ``sigmoid(x) * x`` element-wise."""
        gate = torch.sigmoid(x)
        return gate * x

# Model class
class MLP(nn.Module):
    """Fully connected network with five linear layers and Swish activations.

    By default the network maps ``(x, t)`` to a 2-D output. With
    ``is_tm=True`` (transition matching) it additionally conditions on
    ``(y, s)``, which widens the input by three features.
    """

    # Default input width: 2 coordinates of x plus 1 time value t.
    input_dim: int = 3
    # Output width.
    output_dim: int = 2

    def __init__(self, hidden_dim: int = 128, is_tm: bool = False):
        """Build the network.

        Args:
            hidden_dim: Width of every hidden layer.
            is_tm: When True, the network also takes ``y`` (2 values) and
                ``s`` (1 value) as inputs.
        """
        super().__init__()
        self.is_tm = is_tm
        # Stored on the instance, so the class-level default stays untouched.
        self.input_dim = self.input_dim + (3 if is_tm else 0)

        self.main = nn.Sequential(
            nn.Linear(self.input_dim, hidden_dim),
            Swish(),
            nn.Linear(hidden_dim, hidden_dim),
            Swish(),
            nn.Linear(hidden_dim, hidden_dim),
            Swish(),
            nn.Linear(hidden_dim, hidden_dim),
            Swish(),
            nn.Linear(hidden_dim, self.output_dim),
        )

    @staticmethod
    def _as_column(value: Tensor, batch: int) -> Tensor:
        """Reshape a time tensor to ``[batch, 1]``, broadcasting a single value."""
        column = value.reshape(-1, 1)
        if column.size(0) == 1 and batch != 1:
            # A scalar / single-element time is shared by every row of the batch.
            column = column.expand(batch, 1)
        return column

    def forward(
        self,
        x: Tensor,
        t: Tensor,
        y: Optional[Tensor] = None,
        s: Optional[Tensor] = None
    ) -> Tensor:
        """Evaluate the network.

        Args:
            x: Points, reshaped to ``[batch, 2]``.
            t: Times, reshaped to ``[batch, 1]``; a single value is broadcast
                to the whole batch.
            y: Second 2-D input, reshaped to ``[batch, 2]``. Required when
                ``is_tm`` is True, ignored otherwise.
            s: Second time input, handled like ``t``. Required when ``is_tm``
                is True, ignored otherwise.

        Returns:
            Tensor of shape ``[batch, output_dim]``.

        Raises:
            ValueError: If ``is_tm`` is True and ``s`` or ``y`` is missing.
        """
        x = x.reshape(-1, 2)
        batch = x.size(0)
        t = self._as_column(t, batch)

        if not self.is_tm:
            return self.main(torch.cat([x, t], dim=1))

        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        if s is None:
            raise ValueError("s timesteps must be provided for transition matching")
        if y is None:
            raise ValueError("y target parametrization must be provided for transition matching")

        s = self._as_column(s, batch)
        y = y.reshape(-1, 2)
        return self.main(torch.cat([x, t, y, s], dim=1))

# Gaussian Mixture Model (GMM)

In [None]:
class GaussianMixtureModel(torch.nn.Module):
    """Mixture of isotropic Gaussians with fixed (non-trainable) parameters.

    Args:
        weights: Mixture weights, one per component (expected shape ``[K]``).
        means: Component means, one row per component (expected shape ``[K, D]``).
        sigmas: Per-component standard deviations, shared across the D
            coordinates (expected shape ``[K]``).
    """

    def __init__(self, weights: Tensor, means: Tensor, sigmas: Tensor):
        super().__init__()
        # Frozen Parameters: they follow the module through `.to(device)`
        # but never receive gradients.
        self.weights = torch.nn.Parameter(
            weights.float(), requires_grad=False
        )
        self.means = torch.nn.Parameter(
            means.float(), requires_grad=False
        )
        self.sigmas = torch.nn.Parameter(
            sigmas.float(), requires_grad=False
        )

    def log_prob(self, x: Tensor) -> Tensor:
        """Return the log-density of the mixture at each row of ``x``.

        Args:
            x: Points of shape ``[N, D]``.

        Returns:
            Tensor of shape ``[N]`` with the mixture log-density.
        """
        means = self.means[None]      # [1, K, D]
        sigmas = self.sigmas[None]    # [1, K]
        weights = self.weights[None]  # [1, K]
        x = x[:, None]                # [N, 1, D]
        dim = self.means.size(-1)

        # Log-density of a single isotropic Gaussian in `dim` dimensions.
        # The normalising constant is (sqrt(2*pi)*sigma)**dim, so its log is
        # scaled by `dim`; without that factor the density does not integrate
        # to one for dim > 1.
        squared_dist = torch.sum((x - means) ** 2, dim=-1)
        gaussian_log_probs = -squared_dist / (2 * sigmas**2) \
            - dim * torch.log((2 * torch.pi) ** 0.5 * sigmas)
        # Log-density of the mixture, weighting each Gaussian by `weights`.
        mixture_log_probs = torch.logsumexp(torch.log(weights) + gaussian_log_probs, dim=-1)

        return mixture_log_probs

    def sample(self, num_samples: int) -> Tensor:
        """Draw ``num_samples`` points from the mixture.

        Returns:
            Tensor of shape ``[num_samples, D]`` on the same device as the means.
        """
        # Pick one component index per sample according to the weights.
        index = torch.multinomial(self.weights, num_samples, replacement=True)
        # Sample from the selected Gaussians.
        normal = torch.randn(
            size=(num_samples, self.means.size(-1)), device=self.means.device
        )
        sample = self.means[index] + self.sigmas[index][..., None] * normal

        return sample

    def forward(self, x: Tensor) -> Tensor:
        """Return the mixture density (not the log-density) at each row of ``x``."""
        return torch.exp(self.log_prob(x))

# Instantiate source and target

In [None]:
# Target distribution: a two-component mixture. Entry k of `weights` / `sigmas`
# and row k of `means` describe component k.
target = GaussianMixtureModel(
    weights=torch.tensor([0.66, 0.34]),
    means=torch.tensor([[7, -1], [6, 1]]),
    sigmas=torch.tensor([0.80, 0.80]),
).to(device)

# Source distribution: a single Gaussian at the origin with unit std.
source = GaussianMixtureModel(
    weights=torch.tensor([1.0]),
    means=torch.tensor([[0, 0]]),
    sigmas=torch.tensor([1.0]),
).to(device)

def set_background(ax: plt.Axes):
    """Draw the source and target densities as filled contours behind `ax`.

    Reads the module-level `source`, `target` and `device`. Each density is
    evaluated on its own rectangular grid, and both share one colour scale.
    """
    # grid size n_points x n_points
    n_points = 200
    # per panel: (x_min, x_max, y_min, y_max) window and the density to evaluate
    panels = {
        'source': ((-2.5, 3, -4, 4), source),
        'target': ((3, 9.0, -4, 4), target),
    }

    meshes = {}
    for name, ((x_lo, x_hi, y_lo, y_hi), dist) in panels.items():
        xs = torch.linspace(x_lo, x_hi, n_points)
        ys = torch.linspace(y_lo, y_hi, n_points)
        grid_x, grid_y = torch.meshgrid(xs, ys, indexing="ij")
        # flat grid of shape [n_points**2, 2], evaluated on `device`
        flat = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=1).to(device)
        dens = dist(flat).reshape(n_points, n_points).cpu()
        meshes[name] = (grid_x.numpy(), grid_y.numpy(), dens.numpy())

    # Custom colormap for the target: the lowest `cutoff` fraction of the range
    # is taken from OrRd, the remainder from PuBu.
    cutoff = 0.055
    low_part = cm.OrRd(np.linspace(0, cutoff, int((256*cutoff))))
    high_part = cm.PuBu(np.linspace(cutoff, 1.0, int((256*(1-cutoff)))))
    OrBu = mcolors.ListedColormap(np.vstack([low_part, high_part]), name='OrBu')
    palettes = {'source': cm.OrRd, 'target': OrBu}

    # common min / max over both densities so the colour scales are comparable
    stacked = np.stack([dens for _, _, dens in meshes.values()])
    vmin, vmax = stacked.min(), stacked.max()

    for name, (grid_x, grid_y, dens) in meshes.items():
        ax.contourf(
            grid_x,
            grid_y,
            dens,
            levels=10,
            cmap=palettes[name],
            alpha=0.8,
            vmin=vmin,
            vmax=vmax,
            zorder=0,
        )

    ax.set_xlim(-2.5, 9.0)
    ax.set_ylim(-4, 4)
    ax.set_aspect('equal')
    ax.axis('off')


# Preview source and target distributions on a single 8x6 inch axis.
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
set_background(ax)  # density contours of source and target
fig.tight_layout()
fig.show()

# Train DTM and FM models

In [None]:
def train(
    model,
    loss_fn,
    *,
    lr=0.001,
    batch_size=4096,
    iterations=50000,
    print_every=10000,
    eta_min=0.00001,
):
    """Train `model` in place with Adam and a cosine-annealed learning rate.

    Every iteration draws a fresh batch from the module-level `source` and
    `target` distributions (on the module-level `device`), builds the linear
    interpolation X_t = (1-t)*X_0 + t*X_1 and the difference Y = X_1 - X_0,
    and minimises ``loss_fn(model, x_t, t, y)``.

    Args:
        model: Network to optimise.
        loss_fn: Callable ``(model, x_t, t, y) -> scalar loss tensor``.
        lr: Initial Adam learning rate.
        batch_size: Number of (X_0, X_1) pairs sampled per iteration.
        iterations: Number of optimisation steps (also the cosine period).
        print_every: Log the loss and step time every this many iterations;
            ``0`` or ``None`` disables logging.
        eta_min: Final learning rate of the cosine schedule.
    """
    # init optimizer
    optim = torch.optim.Adam(model.parameters(), lr=lr)

    # cosine annealing scheduler
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optim, T_max=iterations, eta_min=eta_min
    )

    # train
    start_time = time.time()
    for i in range(iterations):
        optim.zero_grad()

        # sample time
        t = torch.rand(batch_size, 1, device=device)
        # sample datapoint X_1
        x_1 = target.sample(batch_size)
        # sample source X_0
        x_0 = source.sample(batch_size)
        # get X_t
        x_t = (1-t)*x_0 + t*x_1
        # get Y
        y = x_1 - x_0

        # compute loss
        loss = loss_fn(model, x_t, t, y)

        # optimizer step
        loss.backward()
        optim.step()
        scheduler.step()

        # log the loss of the current batch and the mean step time since the last log
        if print_every and (i+1) % print_every == 0:
            elapsed = time.time() - start_time
            print('| iter {:6d} | {:5.2f} ms/step | loss {:8.3f} '
                .format(i+1, elapsed*1000/print_every, loss.item()))
            start_time = time.time()


# Difference transition matching (DTM) and flow matching (FM)
# share the same supervising process parameterization.
# Linear process: X_t = (1-t)*X_0 + t*X_1.
# Difference prediction: Y = X_1 - X_0.
# However, they differ in their choice of modeling, i.e., loss functions and architecture.

def transition_matching_loss(model, x_t, t, y):
    """Flow matching loss on Y, conditioned on (X_t, t), with a Gaussian source.

    A second time `s` and a Gaussian point Y_0 are drawn, Y_s is their linear
    interpolation with Y, and the model output at (x_t, t, Y_s, s) is
    regressed onto Y - Y_0 with a mean squared error.
    """
    # inner time s, one value per entry of t
    inner_t = torch.rand_like(t)
    # Gaussian source point Y_0
    noise = torch.randn_like(y)
    # interpolated point Y_s
    interpolant = (1 - inner_t) * noise + inner_t * y
    prediction = model(x_t, t, interpolant, inner_t)
    regression_target = y - noise
    return torch.nn.functional.mse_loss(prediction, regression_target)

def flow_matching_loss(model, x_t, t, y):
    """Flow matching loss: mean squared error between ``model(x_t, t)`` and Y."""
    prediction = model(x_t, t)
    return torch.nn.functional.mse_loss(prediction, y)

# Both networks are created before any training starts: the DTM model
# (with the extra (y, s) inputs) first, then the FM model.
model_dtm = MLP(hidden_dim=512, is_tm=True).to(device)
model_fm = MLP(hidden_dim=512).to(device)

# Train DTM with the transition matching loss.
print("Training DTM model:")
train(model=model_dtm, loss_fn=transition_matching_loss)

# Blank separator, then train FM with the flow matching loss.
print("\n")
print("Training FM model:")
train(model=model_fm, loss_fn=flow_matching_loss)

# Sample DTM and FM

In [None]:
# Number of Euler steps used to integrate the inner flow over s in [0, 1].
INNER_STEPS = 32


def dtm_fn(x_t, t):
    """Sample Y_t ~ p_{Y|X_t} with the DTM model.

    Starts from Gaussian noise (one row per row of `x_t`) and takes
    INNER_STEPS Euler steps along the field returned by `model_dtm`.
    """
    # noise is drawn with the default generator and then moved to `device`
    y_cur = torch.randn(x_t.size(0), 2).to(device)
    knots = torch.linspace(0, 1, INNER_STEPS+1, device=device)
    widths = torch.diff(knots)
    for k in range(INNER_STEPS):
        velocity = model_dtm(x_t, t, y_cur, knots[k])
        y_cur = y_cur + widths[k] * velocity

    return y_cur

def fm_fn(x_t, t):
    """Return Y_t = E[Y|X_t] as predicted by the FM model at (x_t, t)."""
    y_t = model_fm(x_t, t)
    return y_t

@torch.no_grad()
def sampler(model_fn, x_0, T):
    """Roll out one trajectory with T equal steps of size 1/T, starting at t = 0.

    At every step `model_fn(x_t, t)` supplies Y_t and the state is advanced
    by X_{t+1/T} = X_t + Y_t / T.

    Returns:
        (traj_t, traj_x, traj_y): the T+1 visited times (shape [T+1, 1]),
        the T+1 visited points (shape [T+1, 2]) and the T values of Y_t
        (shape [T, 2]), all allocated on the module-level `device`.
    """
    step = 1/T

    # buffers for t, X_t and Y_t along the trajectory
    traj_t = torch.empty(T+1,1, device=device)
    traj_x = torch.empty(T+1,2, device=device)
    traj_y = torch.empty(T,2, device=device)

    # initial state
    t = torch.zeros(1,1, device=device)
    x_t = x_0.view(1, 2)
    traj_t[0], traj_x[0] = t, x_t

    for i in range(T):
        y_t = model_fn(x_t, t)
        traj_y[i] = y_t
        # advance both the state and the clock by one step
        x_t = x_t + step * y_t
        t = t + step
        traj_t[i+1], traj_x[i+1] = t, x_t

    return traj_t, traj_x, traj_y

# number of transition steps to plot: T = 2, 4, 8, 32, 128
T_dtm = 2**torch.tensor([1, 2, 3, 5, 7], device=device)

# seed for source points
torch.manual_seed(2025)
X_0 = torch.randn(len(T_dtm), 2).to(device)
# all but the last two start points are then re-drawn from the same generator
X_0[:-2] = torch.randn(3, 2).to(device)


def _collect_trajectories(model_fn, step_counts, start_points):
    """Run `sampler` for each (T, x_0) pair; return CPU lists of t, X and Y trajectories."""
    ts, xs, ys = [], [], []
    for n_steps, start in zip(step_counts, start_points):
        traj_t, traj_x, traj_y = sampler(model_fn, start, n_steps)
        ts.append(traj_t.detach().cpu())
        xs.append(traj_x.detach().cpu())
        ys.append(traj_y.detach().cpu())
    return ts, xs, ys


# seed for sampling the transition probability
torch.manual_seed(111)
t_dtm, X_dtm, Y_dtm = _collect_trajectories(dtm_fn, T_dtm, X_0)

# number of transition steps for FM is always 100;
# FM transitions are deterministic so no seed is needed for reproducibility
T_fm = torch.tensor([100]*len(T_dtm), device=device)
t_fm, X_fm, Y_fm = _collect_trajectories(fm_fn, T_fm, X_0)

# flag read by the plotting cell so its preprocessing runs only once
is_preprocessed_for_plotting = False

# Preview: one figure per DTM step count, with the DTM path in solid black
# and the FM path from the same start point in translucent black.
for i, T in enumerate(T_dtm):
    dtm_path, fm_path = X_dtm[i], X_fm[i]

    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
    set_background(ax)

    # DTM trajectory
    ax.plot(
        dtm_path[:, 0], dtm_path[:, 1],
        color='black', linewidth=4,
        label=f'DTM: {T.item():1d}-steps',
    )
    # FM trajectory
    ax.plot(
        fm_path[:, 0], fm_path[:, 1],
        color='black', alpha=0.3, linewidth=4,
        label='FM',
    )

    ax.legend(
        fontsize=20,
        loc='upper right',
        bbox_to_anchor=(0.98, 0.98),
        borderaxespad=0.2,
        frameon=False,
    )

    ax.set_aspect('equal')
    ax.axis('off')
    fig.tight_layout()
    fig.show()


# Make frames and GIF

In [None]:
if not is_preprocessed_for_plotting:
    # Downsample the FM trajectories for visualization (keep every 4th entry).
    stride = 4
    t_fm = [traj[::stride] for traj in t_fm]
    X_fm = [traj[::stride] for traj in X_fm]
    Y_fm = [traj[::stride] for traj in Y_fm]

    def _endpoint_pairs(xs, ys, ts):
        """Turn each Y_t into the segment (noise prediction, data prediction).

        Since Y = X_1 - X_0, Y_t is drawn as the segment between the noise
        prediction X_0 = X_t - t*Y_t and the data prediction
        X_1 = X_t + (1-t)*Y_t. The last X_t / t entry has no Y_t and is dropped.
        """
        pairs = []
        for x_t, y_t, t in zip(xs, ys, ts):
            x_now, t_now = x_t[:-1], t[:-1]
            noise_pred = x_now - t_now*y_t
            data_pred = x_now + (1-t_now)*y_t
            pairs.append(torch.stack([noise_pred, data_pred], dim=1))
        return pairs

    Y_dtm = _endpoint_pairs(X_dtm, Y_dtm, t_dtm)
    Y_fm = _endpoint_pairs(X_fm, Y_fm, t_fm)

    # guard so that re-running this cell does not preprocess twice
    is_preprocessed_for_plotting = True

# Hex colours used when drawing the animation frames.
blue, red = '#4C72B0', '#C44E52'
green, yellow = "#81C784", "#FFF176"

def set_legend(ax: plt.Axes, source_point: Tensor, T_dtm: int):
    """Add the 'DTM: T-steps' and 'FM' legend entries to `ax`.

    Each entry is created by plotting a single-point line at `source_point`,
    so the legend shows the line styles without adding a visible path.
    """
    x, y = source_point.view(-1)

    # one style per legend entry: DTM in black, FM in translucent gray
    entries = (
        dict(color='black', linewidth=4, label=f'DTM: {T_dtm:1d}-steps'),
        dict(color= "#9E9E9E", alpha=0.8, linewidth=4, label='FM'),
    )
    for style in entries:
        ax.plot(x, y, **style)

    ax.legend(
        fontsize=20,
        loc='upper right',
        bbox_to_anchor=(0.98, 0.98),
        borderaxespad=0.2,
        frameon=False,
    )

# Flag that sets whether frames are also written to disk as PNG files
# or only rendered in memory.
SAVE_TO_DISK = False
# set DPI to 150 or 300 for full size.
DPI = 150
# rendered frames, in the order they were saved
frames = []
def save_frame(fig: plt.Figure, name_counter: int):
    """Render `fig` to PNG and append the decoded image to `frames`.

    Args:
        fig: Figure to render.
        name_counter: Frame index; used for the file name
            ``dtm_vs_fm/<index>.png`` when SAVE_TO_DISK is set.
    """
    if SAVE_TO_DISK:
        os.makedirs('dtm_vs_fm', exist_ok=True)
        frame_path = f'dtm_vs_fm/{name_counter:03d}.png'
        fig.savefig(frame_path, format='png', dpi=DPI, bbox_inches='tight')
        # read the frame back
        frames.append(imageio.imread(frame_path))
        return

    # In-memory path: the `with` block releases the buffer even if rendering
    # or decoding raises.
    with io.BytesIO() as buffer:
        fig.savefig(buffer, format='png', dpi=DPI, bbox_inches='tight')
        buffer.seek(0)
        frames.append(imageio.imread(buffer))

# Render all animation frames. For every DTM step count T (with its start
# point), the frames are produced in this order:
#   1. one intro frame with background, legend and the source point,
#   2. the FM trajectory growing step by step,
#   3. the DTM trajectory growing step by step on top of the full FM trajectory.
# Every frame is passed to save_frame; name_counter is the running frame index.
name_counter = 0
for i, T in enumerate(T_dtm):
    print(f'Generating frames for T={T.item():1d}-steps DTM vs FM:')
    # intro frame
    fig, ax = plt.subplots(1,1,figsize=(8,6))
    set_background(ax)
    # same limits / aspect / axis settings that set_background applies
    ax.set_xlim(-2.5, 9.0)
    ax.set_ylim(-4, 4)
    ax.set_aspect('equal')
    ax.axis('off')
    set_legend(
        ax, 
        source_point=X_fm[i][0], 
        T_dtm=T.item()
    )
    # scatter source point in red
    ax.scatter(
        X_fm[i][0, 0], X_fm[i][0, 1], s=128, c=red, alpha=1.0, zorder=3
    )
    # save frame (this figure is closed by the plt.close('all') in the FM loop below)
    fig.tight_layout()
    save_frame(fig, name_counter)
    name_counter += 1

    # plot frames of FM trajectory: one frame per stored FM point
    for j in range(len(X_fm[i])):
        fig, ax = plt.subplots(1,1,figsize=(8,6))
        set_background(ax)
        set_legend(
            ax, 
            source_point=X_fm[i][0], 
            T_dtm=T.item()
        )
        # plot FM trajectory up to step j
        ax.plot(
            X_fm[i][:j+1, 0], 
            X_fm[i][:j+1, 1], 
            color= "#9E9E9E",
            alpha=0.8,
            linewidth=4,
        )
        # scatter source point in red
        ax.scatter(
            X_fm[i][0, 0], X_fm[i][0, 1], s=128, c=red, alpha=1.0, zorder=3
        )
        # if j is the last step, scatter the end of the line in blue
        if j == len(X_fm[i])-1:
            ax.scatter(
                X_fm[i][j, 0], X_fm[i][j, 1], s=128, c=blue, alpha=1.0, zorder=3
            )
        # else, scatter the end of the line in gray,
        # and plot the j-th Y_t as a yellow segment between its two endpoints.
        else:
            ax.scatter(
                X_fm[i][j, 0], X_fm[i][j, 1], s=128, 
                color= "#9E9E9E", alpha=0.8,
                zorder=3
            )
            # NOTE(review): markersize 11.31 is presumably sqrt(128), chosen so the
            # line markers match the s=128 scatter dots - confirm.
            ax.plot(
                    Y_fm[i][j,:, 0],
                    Y_fm[i][j,:, 1], 
                    color=yellow,
                    linewidth=4,
                    marker='o', 
                    markersize=11.31,
                    zorder=1,
                )
        
        # save frame
        fig.tight_layout()
        save_frame(fig, name_counter)
        name_counter += 1
        # close every open figure so they do not accumulate across frames
        plt.close('all')
    
    
    # plot frames of DTM trajectory: one group of frames per stored DTM point
    for j in tqdm(range(len(X_dtm[i]))):
        # For DTM, Y_t is dissolved into the plot using the alpha scale.
        # To prevent the GIF from being too long, for larger numbers of steps T,
        # only a subset of the alpha scales is used.
        alpha_scale = [0, 0.25, 0.5, 0.75, 1.0, 1.0]

        # T is an element of the tensor T_dtm and is compared by value here.
        # k_indices is only assigned for the step counts listed below.
        if T in [2, 4]:
            k_indices = [0, 1, 2, 3, 4, 5]
        
        elif T == 8:
            k_indices = [0, 2, 4, 5]
        
        elif T == 32:
            k_indices = [0, 4, 5]
        
        elif T == 128:
            k_indices = [5]
        
        # True only for the last frame of the group; on that frame the
        # trajectory and its end marker advance by one step.
        def make_step(k):
            return k == k_indices[-1]
        
        for k in k_indices:
            fig, ax = plt.subplots(1,1,figsize=(8,6))
            set_background(ax)
            set_legend(
                ax, 
                source_point=X_fm[i][0], 
                T_dtm=T.item()
            )
            # scatter source point in red
            ax.scatter(
                X_fm[i][0, 0], X_fm[i][0, 1], s=128, c=red, alpha=1.0, zorder=3
            )
            # plot the full FM trajectory
            ax.plot(
                X_fm[i][:, 0], 
                X_fm[i][:, 1], 
                color= "#9E9E9E",
                alpha=0.8,
                linewidth=4,
            )
            # scatter the end point of FM trajectory in blue
            ax.scatter(
                    X_fm[i][-1, 0], X_fm[i][-1, 1], s=128, c=blue, alpha=1.0, zorder=3
                )
            
            # opacity of the green Y_t segment on this frame
            alpha = alpha_scale[k]
            if j == len(X_dtm[i])-1:
                # scatter the end point of DTM trajectory in blue
                ax.scatter(
                    X_dtm[i][-1, 0], X_dtm[i][-1, 1], s=128, c=blue, alpha=1.0, zorder=3
                )
            else:
                # else, plot the j-th Y_t as a green segment,
                # and scatter the end of the line in black.
                ax.plot(
                        Y_dtm[i][j,:, 0],  
                        Y_dtm[i][j,:, 1],  
                        color=green,
                        alpha=alpha,
                        linewidth=4,
                        marker='o', 
                        markersize=11.31,
                    )
                ax.scatter(
                    X_dtm[i][j+int(make_step(k)), 0], 
                    X_dtm[i][j+int(make_step(k)), 1], 
                    s=128, 
                    c='black',
                    alpha=1.0, 
                    zorder=3,
                )
            # plot DTM trajectory up to step j + int(make_step(k))
            ax.plot(
                X_dtm[i][:j+1+int(make_step(k)), 0], 
                X_dtm[i][:j+1+int(make_step(k)), 1], 
                color='black', 
                linewidth=4,
            )
            # save frame
            fig.tight_layout()
            save_frame(fig, name_counter)
            name_counter += 1
            # close every open figure so they do not accumulate across frames
            plt.close('all')


# Output format switch: True writes a GIF, False writes an MP4 video.
GIF = True
# Create GIF or MP4 video from saved frames, then display it in the notebook
if not GIF:
    # Save as MP4 video
    imageio.mimwrite("dtm_vs_fm.mp4", frames, fps=24)
    display(Video("dtm_vs_fm.mp4", width=500, height=300))
else:
    # Save as GIF
    imageio.mimsave("dtm_vs_fm.gif", frames, duration=1/24)
    display(Image("dtm_vs_fm.gif"))
