In [12]:
from abc import ABC, abstractmethod
from typing import Optional, List, Type, Tuple, Dict
import math
import os

import numpy as np
from matplotlib import pyplot as plt
from matplotlib.axes._axes import Axes
import torch
import torch.nn as nn
import torch.distributions as D
from torch.func import vmap, jacrev
from tqdm import tqdm
import seaborn as sns
from sklearn.datasets import make_moons, make_circles
from torchvision import datasets, transforms
from torchvision.utils import make_grid

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


In [4]:
class Sampleable(ABC):
    """A distribution that can be sampled from.

    ``sample`` is abstract so that a subclass which forgets to implement it
    fails at instantiation instead of silently returning ``None``.
    """

    @abstractmethod
    def sample(self, num_samples: int) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Draw ``num_samples`` samples from the distribution.

        Args:
            - num_samples: the desired number of samples
        Returns:
            - samples: shape (num_samples, ...)
            - labels: shape (num_samples, label_dim), or None when the
              distribution carries no labels
        """
        pass

class IsotropicGaussian(nn.Module, Sampleable):
    """Zero-mean isotropic Gaussian N(0, std^2 I) over tensors of a fixed shape.

    It is an ``nn.Module`` so that moving an owning module with ``.to(device)``
    also moves the ``dummy`` buffer, whose device decides where samples are
    allocated.
    """

    def __init__(self, shape: List[int], std: float = 1.0):
        """
        Args:
            - shape: per-sample shape, e.g. [c, h, w]
            - std: standard deviation shared by every coordinate
        """
        super().__init__()
        self.shape = shape
        self.std = std
        # Device tracker only (a persistent buffer named "dummy").
        # register_buffer is equivalent to assigning nn.Buffer(...) but also
        # works on older PyTorch releases that predate nn.Buffer.
        self.register_buffer("dummy", torch.zeros(1))

    def sample(self, num_samples: int) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Draw ``num_samples`` samples.

        Returns:
            - samples: shape (num_samples, *shape), on the same device as ``dummy``
            - labels: always None (this distribution is unlabelled)
        """
        # Allocate directly on the target device instead of sampling on the CPU
        # and copying, which avoids a host->device transfer on every call.
        noise = torch.randn(num_samples, *self.shape, device=self.dummy.device)
        return self.std * noise, None

In [7]:
class ConditionalProbabilityPath(nn.Module, ABC):
    """Abstract conditional probability path p_t(x | z).

    Holds a simple source distribution ``p_simple`` and the data distribution
    ``p_data``. Shape convention for every method: data tensors are
    (bs, c, h, w) and time tensors are (bs, 1, 1, 1) so they broadcast against
    the data.
    """

    def __init__(self, p_simple: Sampleable, p_data: Sampleable):
        super().__init__()
        self.p_simple, self.p_data = p_simple, p_data

    @abstractmethod
    def sample_conditioning_variable(self, num_samples: int) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Draw the conditioning variable z together with its labels y.

        Returns:
            - z: shape (num_samples, c, h, w)
            - y: shape (num_samples, dim_label)
        """

    @abstractmethod
    def sample_conditional_path(self, z: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Draw a position x ~ p_t(x | z).

        Args:
            - z: shape (bs, c, h, w)
            - t: shape (bs, 1, 1, 1)
        Returns:
            - x: shape (bs, c, h, w)
        """

    @abstractmethod
    def conditional_vector_field(self, x: torch.Tensor, z: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Velocity of the conditional path at position x, condition z and time t.

        Args:
            - x: shape (bs, c, h, w)
            - z: shape (bs, c, h, w)
            - t: shape (bs, 1, 1, 1)
        Returns:
            - velocity: shape (bs, c, h, w)
        """

In [8]:
class Alpha(ABC):
    """Abstract schedule alpha_t with alpha_0 = 0 and alpha_1 = 1.

    The two boundary conditions are asserted once, at construction time.
    """

    def __init__(self):
        for boundary in (0.0, 1.0):
            value = self(torch.full((1, 1, 1, 1), boundary))
            assert torch.allclose(value, torch.full((1, 1, 1, 1), boundary))

    @abstractmethod
    def __call__(self, t: torch.Tensor) -> torch.Tensor:
        """Evaluate alpha_t elementwise."""

    def dt(self, t: torch.Tensor) -> torch.Tensor:
        """Time derivative d(alpha_t)/dt, computed with autodiff.

        Args:
            - t: shape (bs, 1, 1, 1)
        Returns:
            - shape (bs, 1, 1, 1)
        """
        # The extra singleton axis lets vmap map over the batch while every
        # per-sample call to self still receives a (1, 1, 1, 1) tensor.
        per_sample_jacobian = vmap(jacrev(self))(t.unsqueeze(1))
        return per_sample_jacobian.view(-1, 1, 1, 1)
    
class Beta(ABC):
    """Abstract schedule beta_t with beta_0 = 1 and beta_1 = 0.

    The two boundary conditions are asserted once, at construction time.
    """

    def __init__(self):
        for t_value, expected in ((0.0, 1.0), (1.0, 0.0)):
            value = self(torch.full((1, 1, 1, 1), t_value))
            assert torch.allclose(value, torch.full((1, 1, 1, 1), expected))

    @abstractmethod
    def __call__(self, t: torch.Tensor) -> torch.Tensor:
        """Evaluate beta_t elementwise."""

    def dt(self, t: torch.Tensor) -> torch.Tensor:
        """Time derivative d(beta_t)/dt, computed with autodiff.

        Args:
            - t: shape (bs, 1, 1, 1)
        Returns:
            - shape (bs, 1, 1, 1)
        """
        # vmap over the batch; each per-sample call sees a (1, 1, 1, 1) tensor.
        batched = t.unsqueeze(1)
        return vmap(jacrev(self))(batched).view(-1, 1, 1, 1)
    
class LinearAlpha(Alpha):
    """Linear schedule alpha_t = t; its derivative is the constant 1 (no autodiff)."""

    def __call__(self, t: torch.Tensor) -> torch.Tensor:
        """Identity map; t has shape (num_samples, 1, 1, 1)."""
        return t

    def dt(self, t: torch.Tensor) -> torch.Tensor:
        """Closed-form derivative: a tensor of ones shaped like t (num_samples, 1, 1, 1)."""
        return torch.ones_like(t)
    
class LinearBeta(Beta):
    """Linear schedule beta_t = 1 - t; its derivative is the constant -1 (no autodiff)."""

    def __call__(self, t: torch.Tensor) -> torch.Tensor:
        """Return 1 - t elementwise; t has shape (num_samples, 1, 1, 1)."""
        return 1 - t

    def dt(self, t: torch.Tensor) -> torch.Tensor:
        """Closed-form derivative: a tensor of -1 shaped like t (num_samples, 1, 1, 1)."""
        return -torch.ones_like(t)

In [9]:
class GaussianConditionalProbabilityPath(ConditionalProbabilityPath):
    """Gaussian conditional path p_t(x | z) = N(alpha_t * z, beta_t^2 * I).

    The simple distribution is a standard isotropic Gaussian of shape
    ``p_simple_shape``; z is drawn from ``p_data``.
    """

    def __init__(self, p_data: Sampleable, p_simple_shape: List[int], alpha: Alpha, beta: Beta):
        super().__init__(
            p_data=p_data,
            p_simple=IsotropicGaussian(shape=p_simple_shape, std=1.0),
        )
        self.alpha = alpha
        self.beta = beta

    def sample_conditioning_variable(self, num_samples: int) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Delegate to ``p_data``: returns (z, y)."""
        return self.p_data.sample(num_samples)

    def sample_conditional_path(self, z: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """x = alpha_t * z + beta_t * eps with eps ~ N(0, I).

        Args:
            - z: shape (bs, c, h, w)
            - t: shape (bs, 1, 1, 1)
        """
        mean = self.alpha(t) * z
        scale = self.beta(t)
        return mean + scale * torch.randn_like(z)

    def conditional_vector_field(self, x: torch.Tensor, z: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """u_t(x | z) = (d_alpha - (d_beta / beta) * alpha) * z + (d_beta / beta) * x.

        Args:
            - x: shape (bs, c, h, w)
            - z: shape (bs, c, h, w)
            - t: shape (bs, 1, 1, 1)
        Note: divides by beta_t, so the result is not finite wherever
        beta_t == 0 (t = 1 for schedules meeting Beta's boundary check).
        """
        alpha_t, beta_t = self.alpha(t), self.beta(t)           # (bs, 1, 1, 1)
        d_alpha, d_beta = self.alpha.dt(t), self.beta.dt(t)     # (bs, 1, 1, 1)
        ratio = d_beta / beta_t
        return (d_alpha - ratio * alpha_t) * z + ratio * x

In [11]:
class VectorField(ABC):
    """Abstract time-dependent vector field u_t(x)."""

    @abstractmethod
    def velocity(self, xt: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor:
        """Velocity at position ``xt`` and time ``t``.

        Args:
            - xt: position at time t, shape (bs, c, h, w)
            - t: shape (bs, 1, 1, 1)
        Returns:
            - velocity, shape (bs, c, h, w)
        """

class Simulator(ABC):
    """Abstract fixed-grid integrator.

    Subclasses define a single ``step``; ``simulate`` chains steps along a
    user-supplied time grid.
    """

    @abstractmethod
    def step(self, xt: torch.Tensor, t: torch.Tensor, dt: torch.Tensor, **kwargs):
        """Advance the state by one step.

        Args:
            - xt: state at time t, shape (bs, c, h, w)
            - t: current time, shape (bs, 1, 1, 1)
            - dt: step size, shape (bs, 1, 1, 1)
        Returns:
            - state at time t + dt, shape (bs, c, h, w)
        """

    @torch.no_grad()
    def simulate(self, x: torch.Tensor, ts: torch.Tensor, **kwargs) -> torch.Tensor:
        """Integrate along the discretization given by ``ts`` (no gradients).

        Args:
            - x: initial state, shape (bs, c, h, w)
            - ts: timesteps, shape (bs, nts, 1, 1, 1)
            - **kwargs: forwarded unchanged to every ``step`` call
        Returns:
            - final state at time ts[:, -1], shape (bs, c, h, w)
        """
        num_steps = ts.shape[1] - 1
        for i in tqdm(range(num_steps)):
            t_now, t_next = ts[:, i], ts[:, i + 1]
            x = self.step(x, t_now, t_next - t_now, **kwargs)
        return x
    
class EulerSimulator(Simulator):
    """Explicit Euler integrator: x_{t+dt} = x_t + u_t(x_t) * dt."""

    def __init__(self, vectorfield: VectorField):
        self.vectorfield = vectorfield

    def step(self, xt: torch.Tensor, t: torch.Tensor, dt: torch.Tensor, **kwargs) -> torch.Tensor:
        """One Euler step; ``kwargs`` are passed through to ``velocity``."""
        velocity = self.vectorfield.velocity(xt, t, **kwargs)
        return xt + velocity * dt

In [None]:
MiB = 1024 ** 2  # bytes per mebibyte

def model_size_b(model: nn.Module) -> int:
    """Return the in-memory size of a module's parameters and buffers, in bytes.

    Args:
        - model: the module to measure
    Returns:
        - total of ``nelement() * element_size()`` over ``model.parameters()``
          and ``model.buffers()`` (a plain Python int, 0 for an empty module)
    """
    param_bytes = sum(p.nelement() * p.element_size() for p in model.parameters())
    buffer_bytes = sum(b.nelement() * b.element_size() for b in model.buffers())
    return param_bytes + buffer_bytes

class Trainer(ABC):
    """Minimal training loop around ``self.model``.

    Subclasses implement :meth:`get_train_loss`. Each "epoch" of
    :meth:`train` is one optimizer step on the loss it returns.
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    @abstractmethod
    def get_train_loss(self, **kwargs) -> torch.Tensor:
        """Return a scalar loss for one step; receives ``train``'s extra kwargs."""
        pass

    def get_optimizer(self, lr: float):
        """Optimizer used by :meth:`train`: Adam over all model parameters."""
        return torch.optim.Adam(self.model.parameters(), lr=lr)

    def train(self, num_epochs: int, device: torch.device, name: str, lr: float=1e-3, **kwargs):
        """Train the model, then save its loss curve and weights.

        Outputs go to ``./training_output/<name>_loss.png`` (loss curve) and
        ``./training_output/<name>.pt`` (``state_dict`` of ``self.model``).

        Args:
            - num_epochs: number of optimizer steps
            - device: device the model is moved to before training
            - name: file stem used for both output files
            - lr: learning rate passed to :meth:`get_optimizer`
            - **kwargs: forwarded to :meth:`get_train_loss` on every step
        """
        size_b = model_size_b(self.model)
        print(f'Training model with size: {size_b / MiB:.3f} MiB')

        self.model.to(device)
        opt = self.get_optimizer(lr=lr)
        self.model.train()
        losses = []

        # Iterate the range directly (not enumerate(range(...))) so tqdm
        # can read its length and show a proper total.
        pbar = tqdm(range(num_epochs))
        for epoch in pbar:
            opt.zero_grad()
            loss = self.get_train_loss(**kwargs)
            loss.backward()
            opt.step()
            loss_value = loss.item()
            losses.append(loss_value)
            pbar.set_description(f'Epoch {epoch}, loss: {loss_value:.3f}')

        self.model.eval()

        save_dir = "./training_output"
        os.makedirs(save_dir, exist_ok=True)

        self._save_loss_curve(losses, os.path.join(save_dir, f"{name}_loss.png"))

        model_path = os.path.join(save_dir, f"{name}.pt")
        torch.save(self.model.state_dict(), model_path)
        print(f"Model saved to: {model_path}")

    @staticmethod
    def _save_loss_curve(losses: List[float], path: str) -> None:
        """Plot per-step losses, save the figure to ``path``, then display it."""
        fig, ax = plt.subplots()
        ax.plot(losses)
        ax.set_title("Training Loss Curve")
        ax.set_xlabel("num_epochs")
        ax.set_ylabel("loss")
        ax.grid(True)
        fig.savefig(path, bbox_inches='tight', pad_inches=0, dpi=300)
        print(f"Loss curve saved to: {path}")
        plt.show()