In [None]:
import os

import numpy as np
import matplotlib.pyplot as plt
import torch
from sklearn.datasets import load_digits
from sklearn import datasets
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as tt

**DISCLAIMER**

The presented code is not optimized; it serves an educational purpose. It is written for CPU and uses only fully-connected networks and an extremely simple dataset. However, it contains all the components needed to understand how flow matching works, and it should be rather easy to extend it to more sophisticated models. The code runs on almost any laptop/PC and takes at most a couple of minutes to produce a result.

## Dataset: Digits

In this example, we go wild and use a dataset that is even simpler than MNIST! We use the Digits dataset from scikit-learn (`sklearn.datasets.load_digits`). It consists of 1797 images of size 8x8, and each pixel takes values in $\{0, 1, \ldots, 16\}$.

We use this dataset so that everyone can run the code on a laptop, without a GPU.

In [None]:
class Digits(Dataset):
    """Flattened 8x8 images from scikit-learn's Digits dataset, split by row index.

    Args:
        mode: 'train' selects rows [0, 1000), 'val' selects rows [1000, 1350);
            any other value selects the remaining rows [1350, end).
        transforms: optional callable applied to each float32 sample on access.
    """

    def __init__(self, mode='train', transforms=None):
        pixels = load_digits().data
        # (mode name, row slice) pairs; unmatched modes fall back to the tail of the data.
        named_splits = (('train', slice(None, 1000)), ('val', slice(1000, 1350)))
        rows = next((rows_for_split for split_name, rows_for_split in named_splits if mode == split_name),
                    slice(1350, None))
        self.data = pixels[rows].astype(np.float32)
        self.transforms = transforms

    def __len__(self):
        """Number of images in the selected split."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return row ``idx`` of the split, passed through ``transforms`` when one is set."""
        image = self.data[idx]
        return self.transforms(image) if self.transforms else image

## Flow Matching

In [None]:
class FlowMatching(nn.Module):
    """Flow Matching generative model with Euler ODE integrators.

    The learned vector field is ``v(x, t) = vnet(x + time_embedding(t))``: the scalar
    time is mapped by a ``Linear(1, D) + Tanh`` embedding and added to the input of
    ``vnet``.

    Args:
        vnet: network mapping a ``D``-dimensional input to a ``D``-dimensional velocity.
        sigma: noise level of the conditional probability path.
        D: data dimensionality (also the size of the time embedding).
        T: number of points on the time grid used by ``sample`` and ``log_prob``
            (i.e. ``T - 1`` Euler steps).
        stochastic_euler: if True, ``sample`` adds Gaussian noise after every Euler step.
        prob_path: ``"icfm"`` (independent conditional flow matching) or ``"fm"``
            (Lipman et al.'s conditional path).
    """

    def __init__(self, vnet, sigma, D, T, stochastic_euler=False, prob_path="icfm"):
        super().__init__()

        print('Flow Matching by JT.')

        self.vnet = vnet

        # Learned embedding of the scalar time t; it is added to the input of `vnet`.
        self.time_embedding = nn.Sequential(nn.Linear(1, D), nn.Tanh())

        # other params
        self.D = D
        self.T = T
        self.sigma = sigma
        self.stochastic_euler = stochastic_euler

        assert prob_path in ["icfm", "fm"], f"Error: The probability path could be either Independent CFM (icfm) or Lipman's Flow Matching (fm) but {prob_path} was provided."
        self.prob_path = prob_path

        self.PI = torch.from_numpy(np.asarray(np.pi))

    def log_p_base(self, x, reduction='sum', dim=1):
        """Log-density of the standard Gaussian base distribution N(0, I).

        The elementwise log-density is reduced over ``dim`` with 'sum' or 'mean';
        any other ``reduction`` returns it unreduced.
        """
        log_p = -0.5 * torch.log(2. * self.PI) - 0.5 * x**2.
        if reduction == 'mean':
            return torch.mean(log_p, dim)
        elif reduction == 'sum':
            return torch.sum(log_p, dim)
        else:
            return log_p

    def sample_base(self, x_1):
        """Draw x_0 from the Gaussian base N(0, I), with the same shape as ``x_1``."""
        # Both supported probability paths share the same standard-normal base.
        if self.prob_path in ("icfm", "fm"):
            return torch.randn_like(x_1)
        return None

    def sample_p_t(self, x_0, x_1, t):
        """Sample x ~ p_t(x | x_0, x_1) = N(mu_t, sigma_t^2 I) for the chosen path.

        - icfm: mu_t = (1 - t) x_0 + t x_1 and sigma_t = sigma.
        - fm:   mu_t = t x_1 and sigma_t = 1 - (1 - sigma) t (``x_0`` is unused).

        Raises:
            ValueError: if ``self.prob_path`` is not a supported path.
        """
        if self.prob_path == "icfm":
            mu_t = (1. - t) * x_0 + t * x_1
            sigma_t = self.sigma
        elif self.prob_path == "fm":
            mu_t = t * x_1
            sigma_t = t * self.sigma - t + 1.
        else:
            raise ValueError(f"Unknown probability path: {self.prob_path}")

        return mu_t + sigma_t * torch.randn_like(x_1)

    def conditional_vector_field(self, x, x_0, x_1, t):
        """Target conditional vector field u_t(x | x_0, x_1) of the chosen path.

        Raises:
            ValueError: if ``self.prob_path`` is not a supported path.
        """
        if self.prob_path == "icfm":
            return x_1 - x_0
        if self.prob_path == "fm":
            return (x_1 - (1. - self.sigma) * x) / (1. - (1. - self.sigma) * t)
        raise ValueError(f"Unknown probability path: {self.prob_path}")

    def forward(self, x_1, reduction='mean'):
        """Conditional flow-matching loss for a data batch ``x_1`` of shape (batch, D).

        Returns the squared error between the predicted and the conditional vector
        field (averaged over dimensions), summed over the batch for 'sum' and
        averaged otherwise.
        """
        # z = (x_0, x_1): x_0 ~ base, x_1 ~ data; one t ~ Uniform(0, 1) per example.
        x_0 = self.sample_base(x_1)
        t = torch.rand(x_1.shape[0], 1, dtype=x_1.dtype, device=x_1.device)

        # x ~ p_t(x | z)
        x = self.sample_p_t(x_0, x_1, t)

        # Predicted vector field v(x, t).
        v = self.vnet(x + self.time_embedding(t))

        # Regression target: the conditional vector field u_t(x | z).
        u_t = self.conditional_vector_field(x, x_0, x_1, t)

        fm_loss = torch.pow(v - u_t, 2).mean(-1)

        if reduction == 'sum':
            return fm_loss.sum()
        return fm_loss.mean()

    def sample(self, batch_size=64):
        """Generate ``batch_size`` samples by Euler-integrating dx/dt = v(x, t) from t=0 to 1.

        The final state is squashed with ``tanh`` into (-1, 1).
        """
        # x_0 from the base distribution.
        x_t = self.sample_base(torch.empty(batch_size, self.D))

        ts = torch.linspace(0., 1., self.T)
        delta_t = ts[1] - ts[0]

        for t in ts[1:]:
            t_embedding = self.time_embedding(t.view(1))
            x_t = x_t + self.vnet(x_t + t_embedding) * delta_t
            if self.stochastic_euler:
                # NOTE(review): noise is scaled by delta_t, not sqrt(delta_t) as in
                # Euler-Maruyama; kept as the original heuristic.
                x_t = x_t + torch.randn_like(x_t) * delta_t

        return torch.tanh(x_t)

    def log_prob(self, x_1, reduction='mean'):
        """Estimate log p_1(x_1) by integrating the flow backward from t=1 to t=0.

        Uses the instantaneous change of variables (see Appendix C in Lipman's paper):
            log p_1(x_1) = log p_0(x_0) - integral_0^1 div v(x_t, t) dt,
        discretized with ``T - 1`` explicit Euler steps on ``linspace(1, 0, T)``.
        The divergence is estimated with Hutchinson's trick, e^T (dv/dx) e with
        e ~ N(0, I), using one probe per step. The result is detached from autograd.

        Args:
            x_1: data batch of shape (batch, D).
            reduction: 'mean' or 'sum' over the batch; any other value returns the
                per-example log-likelihoods of shape (batch,).
        """
        ts = torch.linspace(1., 0., self.T)

        # Evaluate `vnet` in eval mode, then restore the mode it was in before.
        was_training = self.vnet.training
        self.vnet.eval()

        x_t = x_1.detach()
        # Running estimate of integral_0^1 div v(x_t, t) dt, one value per example.
        div_integral = torch.zeros(x_1.shape[0], dtype=x_1.dtype, device=x_1.device)
        try:
            # The divergence needs autograd even if the caller runs under torch.no_grad().
            with torch.enable_grad():
                for t, t_next in zip(ts[:-1], ts[1:]):
                    h = t - t_next  # positive step size; we move backward in time
                    t_embedding = self.time_embedding(t.view(1))

                    x = x_t.detach().requires_grad_(True)
                    v = self.vnet(x + t_embedding)

                    # grad_x <v, e> = e^T J; dotting it with e gives e^T J e ~ tr(J).
                    e = torch.randn_like(x)
                    e_jac = torch.autograd.grad((v * e).sum(), x)[0]
                    div = (e_jac * e).view(x.shape[0], -1).sum(dim=1)

                    div_integral = div_integral + h * div.detach()
                    # Euler step backward in time: x_{t-h} = x_t - h * v(x_t, t).
                    x_t = x_t - h * v.detach()
        finally:
            self.vnet.train(was_training)

        log_p_1 = self.log_p_base(x_t, reduction='sum') - div_integral

        if reduction == "mean":
            return log_p_1.mean()
        if reduction == "sum":
            return log_p_1.sum()
        return log_p_1

## Evaluation and Training functions

**Evaluation step, sampling and curve plotting**

In [None]:
def evaluation(test_loader, name=None, model_best=None, epoch=None):
    """Average negative log-likelihood per example of a model over a data loader.

    Args:
        test_loader: iterable of batches (tensors whose first dim is the batch size).
        name: path prefix; when ``model_best`` is None the model is loaded from
            ``name + '.model'``.
        model_best: model exposing ``eval()`` and ``log_prob(batch, reduction='sum')``.
        epoch: if given, the printed message is tagged as a validation epoch.

    Returns:
        The summed NLL divided by the number of examples.

    Raises:
        ValueError: if ``test_loader`` yields no examples.
    """
    if model_best is None:
        import inspect

        # load best performing model. The file holds a whole pickled nn.Module, which
        # PyTorch >= 2.6 refuses under its default weights_only=True; torch < 1.13 has
        # no `weights_only` argument at all. Only load model files you created yourself.
        load_kwargs = {'weights_only': False} if 'weights_only' in inspect.signature(torch.load).parameters else {}
        model_best = torch.load(name + '.model', **load_kwargs)

    model_best.eval()
    total_nll = 0.
    n_examples = 0
    for test_batch in test_loader:
        total_nll += -model_best.log_prob(test_batch, reduction='sum').item()
        n_examples += test_batch.shape[0]

    if n_examples == 0:
        raise ValueError('evaluation() got an empty data loader; cannot average the NLL.')
    loss = total_nll / n_examples

    if epoch is None:
        print(f'FINAL LOSS: nll={loss}')
    else:
        print(f'Epoch: {epoch}, val nll={loss}')

    return loss


def samples_real(name, test_loader):
    """Save a 4x4 grid of real 8x8 images from the first batch of ``test_loader``.

    Writes ``name + '_real_images.pdf'``. If the first batch holds fewer than 16
    images, the remaining panels are left blank instead of raising IndexError.
    """
    num_x = 4
    num_y = 4
    x = next(iter(test_loader)).detach().numpy()

    fig, axes = plt.subplots(num_x, num_y)
    for i, ax in enumerate(axes.flatten()):
        if i < len(x):
            ax.imshow(np.reshape(x[i], (8, 8)), cmap='gray')
        ax.axis('off')

    fig.savefig(name + '_real_images.pdf', bbox_inches='tight')
    plt.close(fig)


def samples_generated(name, data_loader, extra_name='', T=None):
    """Sample a 4x4 grid of images from the saved model and save it as a PDF.

    Args:
        name: path prefix; the model is read from ``name + '.model'`` and the figure is
            written to ``name + '_generated_images' + extra_name + '.pdf'``.
        data_loader: unused; kept for call-site compatibility.
        extra_name: suffix appended to the output filename.
        T: if given, overrides the number of time-grid points used for sampling.
    """
    import inspect

    # The file holds a whole pickled nn.Module, which PyTorch >= 2.6 refuses under its
    # default weights_only=True; torch < 1.13 has no `weights_only` argument at all.
    # Only load model files you created yourself.
    load_kwargs = {'weights_only': False} if 'weights_only' in inspect.signature(torch.load).parameters else {}
    model_best = torch.load(name + '.model', **load_kwargs)
    model_best.eval()

    if T is not None:
        model_best.T = T

    num_x = 4
    num_y = 4
    # Sampling only integrates the learned ODE forward; no gradients are needed.
    with torch.no_grad():
        x = model_best.sample(batch_size=num_x * num_y)
    x = x.detach().numpy()

    fig, axes = plt.subplots(num_x, num_y)
    for i, ax in enumerate(axes.flatten()):
        ax.imshow(np.reshape(x[i], (8, 8)), cmap='gray')
        ax.axis('off')

    fig.savefig(name + '_generated_images' + extra_name + '.pdf', bbox_inches='tight')
    plt.close(fig)

def plot_curve(name, nll_val):
    """Plot validation NLL per epoch and save it to ``name + '_nll_val_curve.pdf'``."""
    # A dedicated figure, so the curve is never drawn onto some leftover current figure.
    fig, ax = plt.subplots()
    ax.plot(np.arange(len(nll_val)), nll_val, linewidth=3)
    ax.set_xlabel('epochs')
    ax.set_ylabel('nll')
    fig.savefig(name + '_nll_val_curve.pdf', bbox_inches='tight')
    plt.close(fig)

**Training step**

In [None]:
def training(name, max_patience, num_epochs, model, optimizer, training_loader, val_loader):
    """Train ``model`` with early stopping on validation NLL.

    After every epoch the model is evaluated on ``val_loader``. It is checkpointed to
    ``name + '.model'`` after the first epoch and whenever the validation NLL improves,
    and a grid of samples from the saved model is written. Training stops once the
    NLL has failed to improve for more than ``max_patience`` consecutive epochs.

    Returns:
        numpy array of the validation NLL after each completed epoch.
    """
    nll_val = []
    best_nll = 1000.
    patience = 0

    for e in range(num_epochs):
        # TRAINING
        model.train()
        for batch in training_loader:
            loss = model(batch)

            optimizer.zero_grad()
            # A fresh graph is built for every batch, so there is nothing to retain.
            loss.backward()
            optimizer.step()

        # Validation
        loss_val = evaluation(val_loader, model_best=model, epoch=e)
        nll_val.append(loss_val)  # save for plotting

        # Always checkpoint after the first epoch, because samples_generated below
        # loads the saved model file.
        if e == 0 or loss_val < best_nll:
            print('saved!')
            torch.save(model, name + '.model')
            best_nll = loss_val
            patience = 0
        else:
            patience = patience + 1

        samples_generated(name, val_loader, extra_name="_epoch_" + str(e))

        if patience > max_patience:
            break

    return np.asarray(nll_val)

## Experiments

**Initialize datasets**

In [None]:
# Pixel values are integers in {0, ..., 16}; map them linearly onto [-1, 1]
# (dividing by 17 would only reach 15/17 at the top end).
transforms = tt.Lambda(lambda x: 2. * (x / 16.) - 1.)

In [None]:
# Fixed train/val/test splits of Digits, each rescaled by `transforms`
train_data, val_data, test_data = (Digits(mode=split, transforms=transforms) for split in ('train', 'val', 'test'))

# Only the training loader reshuffles every epoch; evaluation order stays fixed
training_loader = DataLoader(train_data, batch_size=32, shuffle=True)
val_loader, test_loader = (DataLoader(dataset, batch_size=32, shuffle=False) for dataset in (val_data, test_data))

**Hyperparameters**

In [None]:
prob_path = "fm"  # "fm" (Lipman et al.) or "icfm" (independent CFM)

D = 64   # input dimension (8x8 images, flattened)

M = 512  # the number of hidden units in each layer of the vector-field network

sigma = 0.1  # noise level of the conditional probability path

T = 100  # number of time-grid points for the Euler integrators (T - 1 steps)

lr = 1e-3 # learning rate
num_epochs = 1000 # max. number of epochs
max_patience = 20 # early stopping: training stops once the validation NLL has not improved for more than 20 consecutive epochs

# Fix the random seeds so model initialization, batch shuffling and sampling are reproducible.
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)

**Creating a folder for results**

In [None]:
name = prob_path + '_' + str(T)
result_dir = 'results/' + name + '/'
# makedirs also creates the parent 'results/' folder on a fresh checkout (os.mkdir
# would raise FileNotFoundError there) and is a no-op when the folder already exists.
os.makedirs(result_dir, exist_ok=True)

**Initializing the model**

In [None]:
# Vector-field network: three SELU hidden layers of width M; outputs are clipped to [-3, 3]
nnet = nn.Sequential(
    nn.Linear(D, M), nn.SELU(),
    *(layer for _ in range(2) for layer in (nn.Linear(M, M), nn.SELU())),
    nn.Linear(M, D), nn.Hardtanh(min_val=-3., max_val=3.),
)

# Wrap the vector field into the full Flow Matching model
model = FlowMatching(nnet, sigma=sigma, T=T, D=D, stochastic_euler=False, prob_path=prob_path)

**Optimizer - here we use Adamax**

In [None]:
# OPTIMIZER: Adamax over every trainable parameter of the model
optimizer = torch.optim.Adamax(
    (param for param in model.parameters() if param.requires_grad),
    lr=lr,
)

**Training loop**

In [None]:
# Training procedure: early stopping on validation NLL; returns the per-epoch validation NLLs
nll_val = training(
    name=result_dir + name,
    model=model,
    optimizer=optimizer,
    training_loader=training_loader,
    val_loader=val_loader,
    num_epochs=num_epochs,
    max_patience=max_patience,
)

**The final evaluation**

In [None]:
test_loss = evaluation(name=result_dir + name, test_loader=test_loader)
# The context manager closes the file even if the write fails.
with open(result_dir + name + '_test_loss.txt', "w") as f:
    f.write(str(test_loss))

samples_real(result_dir + name, test_loader)
samples_generated(result_dir + name, test_loader, extra_name='FINAL')

plot_curve(result_dir + name, nll_val)