In [1]:
from typing import Optional, List, Type, Tuple, Dict
from torchvision import datasets, transforms
import torch
import torch.nn as nn
import os
import ssl 

# NOTE(security): this replaces the process-wide default HTTPS context with an
# unverified one, so certificate checks are skipped for every HTTPS request made
# through the stdlib. Presumably a workaround for dataset-download certificate
# errors -- confirm it is still needed before reusing this notebook elsewhere.
ssl._create_default_https_context = ssl._create_unverified_context
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

### Step 0: Construct the target data distribution and the initial distribution

In [2]:
from gaussian import Sampleable

class MNISTSampler(nn.Module, Sampleable):
    """
    Sampleable wrapper for the MNIST training split.

    Every image is resized to 32x32, converted to a tensor and normalized with
    mean 0.5 / std 0.5 before being returned.
    """
    def __init__(self):
        super().__init__()
        # Loads the training split from ./data (download=True is passed to torchvision).
        self.dataset = datasets.MNIST(
            root='./data',
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.Resize((32, 32)),
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (0.5,)),
            ])
        )
        self.dummy = nn.Buffer(torch.zeros(1)) # Will automatically be moved when self.to(...) is called...
        # Shape of a single transformed image, taken from the first dataset item.
        self.dim = list(self.dataset[0][0].shape)

    def sample(self, num_samples: int) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Draws distinct images (sampling without replacement) and their class labels.

        Args:
            - num_samples: the desired number of samples, in [0, len(dataset)]
        Returns:
            - samples: shape (num_samples, c, h, w), on the device/dtype of self.dummy
            - labels: shape (num_samples,), dtype int64, on the device of self.dummy
        Raises:
            - ValueError: if num_samples is negative or exceeds the dataset size
        """
        if num_samples < 0:
            # A negative count would otherwise be interpreted as a slice bound
            # below and silently return the wrong number of samples.
            raise ValueError(f"num_samples must be non-negative, got {num_samples}")
        if num_samples > len(self.dataset):
            raise ValueError(f"num_samples exceeds dataset size: {len(self.dataset)}")

        if num_samples == 0:
            # zip(*[]) cannot be unpacked into two names, so build the empty batch directly.
            empty_samples = torch.empty(0, *self.dim).to(self.dummy)
            empty_labels = torch.empty(0, dtype=torch.int64).to(self.dummy.device)
            return empty_samples, empty_labels

        # Plain Python ints keep the dataset indexing independent of tensor-index support.
        indices = torch.randperm(len(self.dataset))[:num_samples].tolist()
        samples, labels = zip(*[self.dataset[i] for i in indices])
        samples = torch.stack(samples).to(self.dummy)
        labels = torch.tensor(labels, dtype=torch.int64).to(self.dummy.device)
        return samples, labels

In [3]:
# Target data distribution: constructing the sampler builds the MNIST training
# dataset under ./data (torchvision is asked to download it via download=True).
p_data = MNISTSampler()

### Step 1: Build Gaussian Conditional Probability Path

In [4]:
from gaussian import Sampleable

class StandardNormal(nn.Module, Sampleable):
    """
    Isotropic Gaussian noise source built on torch.randn, scaled by `std`.
    """
    def __init__(self, shape: List[int], std: float = 1.0):
        """
        Args:
            - shape: per-sample shape of the generated noise (without the batch dimension)
            - std: standard deviation applied to the unit-variance noise
        """
        super().__init__()
        self.std = std
        self.shape = shape
        # Zero-size-agnostic marker buffer: it follows self.to(...), so its device
        # tells sample() where the noise should end up.
        self.dummy = nn.Buffer(torch.zeros(1))

    def sample(self, num_samples) -> torch.Tensor:
        """
        Args:
            - num_samples: number of independent draws
        Returns:
            - noise of shape (num_samples, *shape) on the device of self.dummy
        """
        target_device = self.dummy.device
        unit_noise = torch.randn(num_samples, *self.shape).to(target_device)
        return self.std * unit_noise

In [5]:
class LinearAlpha:
    """
    Linear schedule alpha_t = t for the data coefficient.
    """
    def __init__(self):
        zero = torch.zeros(1,1,1,1)
        one = torch.ones(1,1,1,1)
        # Boundary condition at t = 0: alpha_0 = 0
        assert torch.allclose(self(zero), zero)
        # Boundary condition at t = 1: alpha_1 = 1
        assert torch.allclose(self(one), one)

    def __call__(self, t: torch.Tensor) -> torch.Tensor:
        """
        Evaluates alpha_t elementwise.
        Args:
            - t: time tensor of any shape (this notebook passes (num_samples, 1, 1, 1))
        Returns:
            - alpha_t, same shape as t (the input tensor itself is returned)
        """
        return t

    def dt(self, t: torch.Tensor) -> torch.Tensor:
        """
        Evaluates the time derivative d/dt alpha_t, which is constant 1.
        Args:
            - t: time tensor of any shape (this notebook passes (num_samples, 1, 1, 1))
        Returns:
            - tensor of ones with the same shape, dtype and device as t
        """
        derivative = torch.ones_like(t)
        return derivative
        

class LinearBeta:
    """
    Linear schedule beta_t = 1 - t for the noise coefficient.
    """
    def __init__(self):
        zero = torch.zeros(1,1,1,1)
        one = torch.ones(1,1,1,1)
        # Boundary condition at t = 0: beta_0 = 1
        assert torch.allclose(self(zero), one)
        # Boundary condition at t = 1: beta_1 = 0
        assert torch.allclose(self(one), zero)

    def __call__(self, t: torch.Tensor) -> torch.Tensor:
        """
        Evaluates beta_t elementwise.
        Args:
            - t: time tensor of any shape (this notebook passes (num_samples, 1, 1, 1))
        Returns:
            - beta_t = 1 - t, same shape as t
        """
        beta_t = 1 - t
        return beta_t

    def dt(self, t: torch.Tensor) -> torch.Tensor:
        """
        Evaluates the time derivative d/dt beta_t, which is constant -1.
        Args:
            - t: time tensor of any shape (this notebook passes (num_samples, 1, 1, 1))
        Returns:
            - tensor of -1 with the same shape, dtype and device as t
        """
        return torch.neg(torch.ones_like(t))

In [6]:
class GaussianConditionalProbabilityPath(nn.Module):
    """
    Gaussian conditional probability path p_t(x|z) = N(alpha_t * z, beta_t^2 * I).

    Holds the data sampler, a standard-normal initial distribution with the same
    per-sample shape, and the alpha/beta schedules.
    """
    def __init__(self, p_data: Sampleable, alpha: LinearAlpha, beta: LinearBeta):
        """
        Args:
            - p_data: data sampler; must expose `dim` (per-sample shape) and `sample(n)`
            - alpha: schedule multiplying the data point z
            - beta: schedule multiplying the Gaussian noise
        """
        super().__init__()
        p_init = StandardNormal(shape = p_data.dim, std = 1.0)
        self.p_init = p_init
        self.p_data = p_data

        self.alpha = alpha
        self.beta = beta

    def sample_marginal_path(self, t: torch.Tensor) -> torch.Tensor:
        """
        Samples x ~ p_t(x) by drawing z from the data and then x ~ p_t(x|z).
        Args:
            - t: time tensor whose first dimension is the number of samples
        Returns:
            - x: shape (num_samples, c, h, w)
        """
        num_samples = t.shape[0]
        # Sample conditioning variable z ~ p(z); the labels are not needed here
        z, _ = self.sample_conditioning_variable(num_samples) # (num_samples, c, h, w)
        # Sample conditional probability path x ~ p_t(x|z)
        return self.sample_conditional_path(z, t) # (num_samples, c, h, w)

    def sample_conditioning_variable(self, num_samples: int) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Draws conditioning variables from the data sampler.
        Args:
            - num_samples: number of draws
        Returns:
            - whatever `p_data.sample` returns; sample_marginal_path unpacks it as a
              (samples, labels) pair
        """
        return self.p_data.sample(num_samples)

    def sample_conditional_path(self, z: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Samples x = alpha_t * z + beta_t * eps with eps ~ N(0, I).
        Args:
            - z: conditioning data points
            - t: time tensor broadcastable against z
        Returns:
            - x: same shape as z
        """
        return self.alpha(t) * z + self.beta(t) * torch.randn_like(z)

    def conditional_vector_field(self, x: torch.Tensor, z: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Evaluates the conditional vector field
            u_t(x|z) = (alpha_t' - beta_t'/beta_t * alpha_t) * z + beta_t'/beta_t * x.

        The expression divides by beta_t, so it is not finite wherever beta_t == 0
        (t == 1 for the LinearBeta schedule).
        Args:
            - x: current positions
            - z: conditioning data points
            - t: time tensor broadcastable against x and z
        Returns:
            - u_t(x|z): same shape as x
        """
        alpha_t = self.alpha(t) # (num_samples, 1, 1, 1)
        beta_t = self.beta(t) # (num_samples, 1, 1, 1)
        dt_alpha_t = self.alpha.dt(t) # (num_samples, 1, 1, 1)
        dt_beta_t = self.beta.dt(t) # (num_samples, 1, 1, 1)

        # Shared factor beta_t'/beta_t, computed once and reused for both terms.
        beta_ratio = dt_beta_t / beta_t
        return (dt_alpha_t - beta_ratio * alpha_t) * z + beta_ratio * x

    def conditional_score(self, x: torch.Tensor, z: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Evaluates the conditional score (alpha_t * z - x) / beta_t^2.

        Divides by beta_t^2, so it is not finite wherever beta_t == 0.
        Args:
            - x: current positions
            - z: conditioning data points
            - t: time tensor broadcastable against x and z
        Returns:
            - score: same shape as x
        """
        alpha_t = self.alpha(t)
        beta_t = self.beta(t)
        return (z * alpha_t - x) / beta_t ** 2

In [7]:
# Construct the conditional probability path over the MNIST sampler with the
# linear schedules alpha_t = t and beta_t = 1 - t, then move the module
# (including its registered sub-modules and buffers) to `device`.
path = GaussianConditionalProbabilityPath(
    p_data = p_data,
    alpha = LinearAlpha(),
    beta = LinearBeta()
).to(device)

#### Step 1.5: Visualize the conditional path from $X_0$ to a data point in $p_{\text{data}}$

In [None]:
from ultility import plot_conditional_path

# Helper from the local `ultility` module; presumably draws samples along the
# conditional path for visual inspection -- see that module to confirm.
plot_conditional_path(path)

### Step 2: Learn the vector field $u(x, t)$ with a U-Net

In [9]:
from trainer import Trainer
from unet import MNISTUNet

class ConditionalFlowMatchingTrainer(Trainer):
    """
    Trainer that regresses the model onto the conditional vector field of `path`.
    """
    def __init__(self, path: GaussianConditionalProbabilityPath, model: MNISTUNet, **kwargs):
        """
        Args:
            - path: probability path providing data samples and the reference vector field
            - model: network evaluated as model(x, t)
            - **kwargs: forwarded unchanged to Trainer.__init__
        """
        super().__init__(model, **kwargs)
        self.path = path

    def get_train_loss(self, batch_size: int) -> torch.Tensor:
        """
        Computes the conditional flow-matching loss on one freshly sampled batch.
        Args:
            - batch_size: number of (z, t, x) triples to draw
        Returns:
            - scalar mean squared error between the model output and u_t(x|z)
        """
        # Data points; the labels returned by the sampler are not used.
        data, _ = self.path.p_data.sample(batch_size)
        # One uniform time in [0, 1) per sample, shaped to broadcast over (c, h, w).
        time = torch.rand(batch_size, 1, 1, 1).to(data.device)
        noisy = self.path.sample_conditional_path(data, time)
        predicted_field = self.model(noisy, time)
        reference_field = self.path.conditional_vector_field(noisy, data, time)

        residual = predicted_field - reference_field
        return residual.pow(2).mean()

In [10]:
# Construct learnable vector field
flow_model = MNISTUNet(
    channels = [32, 64, 128],
    num_residual_layers = 4,
    t_embed_dim = 128,
)

# Checkpoint expected to ship with the notebook.
PRETRAINED_PATH = 'trained/mnist_unet_fm_10000.pt'
# Checkpoint written by this cell when it has to train from scratch.
TRAINED_PATH = 'trained/mnist_unet_fm.pt'

# Prefer the pretrained weights, then a checkpoint from an earlier run of this
# cell, so that Restart & Run All does not retrain once weights exist on disk.
checkpoint_path = next(
    (candidate for candidate in (PRETRAINED_PATH, TRAINED_PATH) if os.path.exists(candidate)),
    None,
)

if checkpoint_path is not None:
    # weights_only=True restricts unpickling to tensors/containers, which is all
    # a state_dict needs and avoids executing arbitrary pickled code.
    state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'), weights_only=True)
    flow_model.load_state_dict(state_dict)
else:
    # Create the output directory up front so a missing folder cannot discard
    # the result of the whole training run at save time.
    os.makedirs(os.path.dirname(TRAINED_PATH), exist_ok=True)
    # Construct trainer
    trainer = ConditionalFlowMatchingTrainer(path, flow_model)
    losses = trainer.train(num_epochs=5000, device=device, lr=1e-3, batch_size=250)
    torch.save(flow_model.state_dict(), TRAINED_PATH)

# The model is only used for sampling from here on: keep it on the same device
# as `path` (weights are loaded onto the CPU above) and in inference mode.
# If the trainer already did either of these, repeating them is a no-op.
flow_model = flow_model.to(device).eval()

### Step 3: Generate samples from the learned model

In [None]:
from ode import LearnedVectorFieldODE, EulerSimulator
from ultility import plot_generated_sample


# Wrap the flow model as an ODE and hand it to the Euler simulator. Both classes
# come from the local `ode` module; their behavior is assumed from their names
# -- TODO confirm against that module.
ode = LearnedVectorFieldODE(flow_model)
simulator = EulerSimulator(ode)

# Helper from the local `ultility` module; num_timesteps is presumably the
# number of integration steps used by the simulator -- verify in the helper.
plot_generated_sample(path, simulator, num_timesteps = 1000)