In [None]:
import warnings 
warnings.filterwarnings('ignore')
import torch 
import torch.distributions as D
import matplotlib.pyplot as plt
from abc import ABC, abstractmethod
from typing import Type, Optional, List
import matplotlib.cm as cm
from torch import nn
from tqdm import tqdm
from matplotlib.axes._axes import Axes
import numpy as np
from torch.func import vmap, jacrev

from sklearn.datasets import make_moons, make_circles


In [None]:
class Sampleable(ABC):
    """Interface for objects that can draw random samples."""

    @abstractmethod
    def sample(self, n_samples: int):
        """Draw ``n_samples`` samples (implementations here return a [n_samples, dim] tensor)."""
        ...

class Density(ABC):
    """Interface for distributions that can evaluate their log-density pointwise."""

    @abstractmethod
    def log_density(self, x: torch.Tensor):
        """Return log p(x) for the points in ``x``."""
        ...

class Simulator(ABC):
    """Base class for one-step integrators; subclasses implement ``step``."""

    @abstractmethod
    def step(self, x, h, t):
        """Advance state ``x`` from time ``t`` by a step of size ``h``."""
        pass

    @torch.no_grad()
    def simulate_with_trajectory(self, x0, ts):
        """Integrate from ``x0`` over the time grid ``ts`` and keep every intermediate state.

        Args:
            x0: initial states, e.g. [bs, 2].
            ts: time grid indexed along dim 0, e.g. [nts, bs, 1].

        Returns:
            States stacked along dim 1, e.g. [bs, nts, 2]; entry 0 is a copy of ``x0``.
        """
        x = x0
        trajectory = [x.clone()]
        # Walk over consecutive (t_prev, t_next) pairs of the grid.
        for t_prev, t_next in zip(ts[:-1], ts[1:]):
            x = self.step(x, t_next - t_prev, t_prev)
            trajectory.append(x.clone())
        return torch.stack(trajectory, dim=1)

class ODE(ABC):
    """An ordinary differential equation dx/dt = drift_term(x, t)."""

    @abstractmethod
    def drift_term(self, x, t):
        """Velocity at state ``x`` and time ``t``."""
        ...

class EulerSampler(Simulator):
    """Explicit Euler integrator for an ODE: x <- x + u(x, t) * h."""

    def __init__(self, ode: ODE):
        self.ode = ode

    def step(self, x, h, t):
        velocity = self.ode.drift_term(x, t)
        return x + velocity * h

class SDE(ABC):
    """A stochastic differential equation dX = drift_term(X, t) dt + diff_term(X, t) dW."""

    @abstractmethod
    def drift_term(self, x, t):
        """Deterministic part of the dynamics at (x, t)."""
        ...

    @abstractmethod
    def diff_term(self, x, t):
        """Diffusion coefficient at (x, t); the sampler multiplies it by Brownian increments."""
        ...

class EulerMarySampler(Simulator):
    """Euler–Maruyama integrator for an SDE.

    One step: x <- x + f(x, t) * h + g(x, t) * sqrt(h) * eps, with eps ~ N(0, I) of x's shape.
    """

    def __init__(self, sde: SDE):
        self.sde = sde

    def step(self, x, h, t):
        deterministic = self.sde.drift_term(x, t) * h
        stochastic = self.sde.diff_term(x, t) * torch.sqrt(h) * torch.randn_like(x)
        return x + deterministic + stochastic


In [None]:
class Gaussian(nn.Module, Density, Sampleable):
    """Multivariate normal N(mean, cov) supporting both sampling and log-density evaluation."""

    def __init__(self, mean, cov):
        super().__init__()
        # Plain tensor attributes (not registered buffers), so nn.Module.to() leaves them in place.
        self.mean = mean
        self.cov = cov

    @property
    def distribution(self):
        """A fresh torch MultivariateNormal built from the current mean/cov (argument validation off)."""
        return D.MultivariateNormal(self.mean, self.cov, validate_args=False)

    @property
    def dim(self):
        """Event dimension of the distribution."""
        (event_dim,) = self.distribution.event_shape
        return event_dim

    def sample(self, n_samples):
        """Draw ``n_samples`` i.i.d. samples, shape [n_samples, dim]."""
        return self.distribution.sample(torch.Size([n_samples]))

    def log_density(self, x):
        """Log-probability of each point in ``x`` (reduced over the trailing event dimension)."""
        return self.distribution.log_prob(x)

    @classmethod
    def isotropic(cls, dim: int, std: float):
        """Zero-mean Gaussian with covariance ``std**2 * I``."""
        zero_mean = torch.zeros(dim)
        scaled_identity = torch.eye(dim) * std ** 2
        return cls(zero_mean, scaled_identity)

class GaussianMixture(nn.Module, Density, Sampleable):
    """Mixture of multivariate Gaussians sharing one event dimension.

    Args (as built by the constructors below):
        mean: component means, [modes, dim].
        cov: component covariance matrices, [modes, dim, dim].
        weight: mixing probabilities, [modes].
    """

    def __init__(self, mean, cov, weight):
        super().__init__()
        self.mean = mean
        self.cov = cov
        self.weight = weight

    @property
    def distribution(self):
        """A fresh torch MixtureSameFamily (argument validation off) built from the stored tensors."""
        return D.MixtureSameFamily(
            mixture_distribution=D.Categorical(probs=self.weight, validate_args=False),
            component_distribution=D.MultivariateNormal(self.mean, self.cov, validate_args=False),
            validate_args=False
        )

    @property
    def dim(self):
        """Event dimension of the mixture."""
        return self.distribution.event_shape[0]

    def log_density(self, x):
        """Log-density of the mixture at each point of ``x``."""
        return self.distribution.log_prob(x)

    def sample(self, n):
        """Draw ``n`` samples, shape [n, dim]."""
        return self.distribution.sample((n,))

    @classmethod
    def random2d(cls, modes, std=1.0, scale=5.0, seed=0.0):
        """2-D mixture with ``modes`` equal-weight components whose means are uniform in
        [-scale/2, scale/2]^2 and whose covariances are ``std**2 * I``.

        Note: re-seeds the global torch RNG with ``seed``.
        """
        torch.manual_seed(seed)
        mean = (torch.rand(modes, 2) - 0.5) * scale
        cov = torch.diag_embed(torch.ones(modes, 2)) * std ** 2
        weights = torch.ones(modes) / modes
        return cls(mean, cov, weights)

    @classmethod
    def symmetric2d(cls, modes, std=1.0, scale=5.0, seed=0.0):
        """2-D mixture with ``modes`` equal-weight components evenly spaced on a circle of
        radius ``scale``, each with covariance ``std**2 * I``.

        The construction is deterministic; ``seed`` only re-seeds the global torch RNG
        (kept for signature parity with ``random2d``).
        """
        torch.manual_seed(seed)
        angles = torch.linspace(0, 2 * torch.pi, modes + 1)[:modes]
        mean = torch.stack([torch.cos(angles), torch.sin(angles)], dim=-1) * scale
        # Honour `std` (it used to be ignored, always giving identity covariance).
        cov = torch.diag_embed(torch.ones(modes, 2)) * std ** 2
        weights = torch.ones(modes) / modes
        return cls(mean, cov, weights)


In [None]:
class Plotting:
    """Helper that draws a distribution (samples or log-density) onto a matplotlib Axes.

    Extra keyword arguments given to the constructor are forwarded verbatim to the
    matplotlib call made by whichever drawing method is invoked.
    """

    def __init__(self, sampler: Sampleable | Density, n_samples: Optional[int] = 10, ax: Optional[Axes] = None, **kwargs):
        self.sampler = sampler
        self.n_samples = n_samples
        # Fall back to the current pyplot axes when no explicit axes is given.
        self.ax = ax if ax is not None else plt.gca()
        self.kwargs = kwargs

    def scatter(self, x: Optional[torch.Tensor] = None, y: Optional[torch.Tensor] = None):
        """Scatter (x, y); if either is missing, scatter ``n_samples`` fresh samples instead."""
        if x is None or y is None:
            samples = self.sampler.sample(self.n_samples)
            x, y = samples[:, 0], samples[:, 1]
        self.ax.scatter(x, y, **self.kwargs)

    def plot_trajectory(self, x: torch.Tensor, y: torch.Tensor):
        """Draw a polyline through the points (x[i], y[i])."""
        self.ax.plot(x, y, **self.kwargs)

    def hist(self):
        """Overlay one 1-D histogram per coordinate of ``n_samples`` samples."""
        samples = self.sampler.sample(self.n_samples)
        for idx in range(samples.shape[-1]):
            self.ax.hist(samples[:, idx], **self.kwargs)

    def hist2d(self):
        """2-D histogram of the first two coordinates of ``n_samples`` samples."""
        samples = self.sampler.sample(self.n_samples)
        x, y = samples[:, 0], samples[:, 1]
        self.ax.hist2d(x, y, **self.kwargs)

    @staticmethod
    def _grid(scale, bins):
        """Square grid over [-scale, scale]^2 with X[i, j] = x_i and Y[i, j] = y_j."""
        axis = torch.linspace(-scale, scale, bins)
        # Explicit 'ij' matches torch.meshgrid's historical default without the deprecation warning.
        return torch.meshgrid(axis, axis, indexing='ij')

    def get_density(self, scale, bins):
        """Log-density on a bins x bins grid, transposed so rows index y (as imshow expects)."""
        X, Y = self._grid(scale, bins)
        xy = torch.stack([X.flatten(), Y.flatten()], dim=-1)
        density = self.sampler.log_density(xy).view(bins, bins).T
        return density

    def imshow(self, scale: float, bins: int):
        """Heat map of the log-density over [-scale, scale]^2."""
        density = self.get_density(scale, bins)
        self.ax.imshow(
            density,
            extent=[-scale, scale] * 2,
            origin='lower',
            **self.kwargs
        )

    def contour(self, scale, bins):
        """Contour lines of the log-density over [-scale, scale]^2."""
        density = self.get_density(scale, bins)
        self.ax.contour(
            density,
            extent=[-scale, scale] * 2,
            origin='lower',
            **self.kwargs
        )

    def quiver(self, scale: float, bins: int, ts: torch.Tensor, score_model):
        """Arrows of ``score_model(x, t)`` on a bins x bins grid at the single time ``ts``."""
        X, Y = self._grid(scale, bins)
        # Evaluate on ts's device so a model trained on GPU can be plotted directly.
        xy = torch.stack([X.reshape(-1), Y.reshape(-1)], dim=-1).to(ts.device)
        t = ts.view(-1, 1).repeat(bins ** 2, 1)
        with torch.no_grad():
            score = score_model(xy, t).cpu()
        xy = xy.cpu()
        self.ax.quiver(xy[:, 0], xy[:, 1], score[:, 0], score[:, 1], scale=125, alpha=0.5)


### Gaussian Conditional Probability Paths

In [None]:
# A randomly placed (but seeded) 5-component 2-D mixture for the first density plot.
gs = GaussianMixture.random2d(
    modes=5,
    std=0.5,
    scale=30.0,
    seed=12,
)

In [None]:
ax = plt.gca()
# Log-density heat map of `gs` (colour scale starting at vmin=-15) ...
Plotting(
    sampler=gs, n_samples=100, ax=ax,
    vmin=-15.0, cmap=plt.get_cmap('Blues'),
).imshow(scale=20.0, bins=200)
# ... with 20 grey iso-log-density contours over the same window.
Plotting(
    sampler=gs, n_samples=100, ax=ax,
    linestyles='solid', levels=20, colors='grey', alpha=0.25,
).contour(scale=20.0, bins=200)

In [None]:
# Plot half-width and geometry (radius, per-component std) of the symmetric target mixture.
PARAMS = dict(
    scale=15.0,
    target_scale=10.0,
    target_std=1.0,
)

In [None]:
scale = PARAMS['scale']
fig, axes = plt.subplots(1, 3, figsize=(24, 8))

# Source: standard isotropic 2-D Gaussian. Target: 5 equal-weight Gaussians evenly spaced on a circle.
p_simple = Gaussian.isotropic(dim=2, std=1.0)
p_data = GaussianMixture.symmetric2d(
    modes=5,
    std=PARAMS["target_std"],
    scale=PARAMS["target_scale"],
)

# Left: source only. Middle: target only.
Plotting(sampler=p_simple, ax=axes[0], vmin=-10, alpha=0.25,
         cmap=plt.get_cmap('Reds')).imshow(scale=scale, bins=200)
Plotting(sampler=p_data, ax=axes[1], vmin=-10, alpha=0.25,
         cmap=plt.get_cmap('Blues')).imshow(scale=scale, bins=200)

# Right: both distributions overlaid.
Plotting(sampler=p_simple, ax=axes[2], vmin=-10, alpha=0.25,
         cmap=plt.get_cmap('Reds')).imshow(scale=scale, bins=200)
Plotting(sampler=p_data, ax=axes[2], vmin=-10, alpha=0.25,
         cmap=plt.get_cmap('Blues')).imshow(scale=scale, bins=200)

In [None]:
class Alpha(ABC):
    """Data-coefficient schedule alpha_t of a Gaussian probability path.

    Concrete schedules must satisfy alpha(0) = 0 and alpha(1) = 1; this is asserted
    when an instance is constructed. ``dt`` differentiates ``__call__`` with autodiff
    and may be overridden with a closed form.
    """

    def __init__(self):
        # Boundary conditions of the schedule.
        assert torch.allclose(
            self(torch.zeros(1, 1)), torch.zeros(1, 1)
        )
        assert torch.allclose(
            self(torch.ones(1, 1)), torch.ones(1, 1)
        )

    @abstractmethod
    def __call__(self, t):
        """Evaluate alpha_t elementwise."""
        pass

    def dt(self, t):
        """d(alpha_t)/dt computed per sample with torch.func, flattened to [-1, 1]."""
        # Add a trailing axis so vmap hands jacrev one [1, 1] slice per row (for t of shape [bs, 1]).
        t = t.unsqueeze(-1)
        # Differentiate the schedule itself (self), then apply the batched Jacobian to t.
        return vmap(jacrev(self))(t).view(-1, 1)

class Beta(ABC):
    """Noise schedule beta_t of a Gaussian probability path.

    Concrete schedules must satisfy beta(0) = 1 and beta(1) = 0, asserted at
    construction. ``dt`` differentiates ``__call__`` with autodiff and may be
    overridden with a closed form.
    """

    def __init__(self):
        assert torch.allclose(self(torch.zeros(1, 1)), torch.ones(1, 1))
        assert torch.allclose(self(torch.ones(1, 1)), torch.zeros(1, 1))

    @abstractmethod
    def __call__(self, t):
        """Evaluate beta_t elementwise."""
        pass

    def dt(self, t):
        """d(beta_t)/dt computed per sample with torch.func, flattened to [-1, 1]."""
        # Trailing axis so each vmapped slice is [1, 1] (for t of shape [bs, 1]).
        per_sample = t.unsqueeze(-1)
        jacobians = vmap(jacrev(self))(per_sample)
        return jacobians.view(-1, 1)

In [None]:
class LinearAlpha(Alpha):
    """Linear schedule alpha_t = t, whose derivative is identically 1."""

    def __call__(self, t):
        # Identity: the data coefficient grows linearly from 0 to 1.
        return t

    def dt(self, t):
        # Closed-form derivative, avoiding the autodiff path of the base class.
        return torch.ones_like(t)

class LinearBeta(Beta):
    """Linear schedule beta_t = 1 - t, whose derivative is identically -1."""

    def __call__(self, t):
        # Noise coefficient decays linearly from 1 to 0.
        return 1 - t

    def dt(self, t):
        # Closed-form derivative.
        return torch.ones_like(t).neg()

class SqrtBeta(Beta):
    """Square-root schedule beta_t = sqrt(1 - t)."""

    def __call__(self, t):
        return (1 - t).sqrt()

    def dt(self, t):
        # d/dt sqrt(1 - t) = -1 / (2 sqrt(1 - t)); the 1e-4 keeps the value finite at t = 1.
        root = (1 - t).sqrt()
        return -0.5 / (root + 1e-4)

In [None]:
class CPP(nn.Module, ABC):
    """Abstract conditional probability path p_t(x | z).

    Holds a source distribution ``p_simple`` and a data distribution ``p_data``; the
    subclasses in this notebook start at p_simple at t = 0 and concentrate on z at t = 1.
    """

    def __init__(self, p_simple: Sampleable, p_data: Sampleable):
        super().__init__()
        self.p_simple = p_simple
        self.p_data = p_data

    def sample_marginal_path(self, t: torch.Tensor) -> torch.Tensor:
        """Sample x ~ p_t (the marginal path): draw z ~ p(z), then x ~ p_t(. | z).

        One z is drawn per entry along the leading dimension of ``t``.
        """
        return self.sample_conditional_path(self.sample_conditioning_variable(t.shape[0]), t)

    @abstractmethod
    def sample_conditioning_variable(self, num_samples: int) -> torch.Tensor:
        """Draw ``num_samples`` conditioning variables z (here: data points from p_data)."""
        pass

    @abstractmethod
    def sample_conditional_path(self, z: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Sample x ~ p_t(. | z) for each row of z at the matching time in t."""
        pass

    @abstractmethod
    def conditional_vector_field(self, x: torch.Tensor, z: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Conditional vector field u_t(x | z) whose flow follows p_t(. | z)."""
        pass

    @abstractmethod
    def conditional_score(self, x: torch.Tensor, z: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Conditional score: the gradient in x of log p_t(x | z)."""
        pass

In [None]:
class GCPP(CPP):
    """Gaussian conditional probability path p_t(x | z) = N(alpha_t * z, beta_t**2 * I).

    The source distribution ``p_simple`` is a standard isotropic Gaussian with the
    same dimension as ``p_data``.
    """

    def __init__(self, p_data: Sampleable, alpha: Alpha, beta: Beta):
        # alpha/beta may be any schedule (e.g. SqrtBeta), not only the linear ones.
        p_simple = Gaussian.isotropic(p_data.dim, 1.0)
        super().__init__(p_simple, p_data)
        self.alpha = alpha
        self.beta = beta

    def sample_conditioning_variable(self, n_sample: int) -> torch.Tensor:
        """Draw ``n_sample`` data points z ~ p_data."""
        return self.p_data.sample(n_sample)

    def sample_conditional_path(self, z, t):
        """Draw x ~ N(alpha_t z, beta_t^2 I) via the reparameterisation alpha_t z + beta_t eps."""
        return self.alpha(t) * z + self.beta(t) * torch.randn_like(z)

    def conditional_vector_field(self, x, z, t):
        """u_t(x | z) = (alpha'_t - beta'_t / beta_t * alpha_t) z + (beta'_t / beta_t) x.

        Not finite where beta_t = 0 (t = 1 for the schedules defined here).
        """
        alpha_t = self.alpha(t)
        alpha_dt = self.alpha.dt(t)
        beta_t = self.beta(t)
        beta_dt = self.beta.dt(t)
        return (alpha_dt - beta_dt / beta_t * alpha_t) * z + (beta_dt / beta_t) * x

    def conditional_score(self, x, z, t):
        """Gradient in x of log N(x; alpha_t z, beta_t^2 I), i.e. (alpha_t z - x) / beta_t^2."""
        return (self.alpha(t) * z - x) / self.beta(t) ** 2


In [None]:
# Gaussian conditional path x_t = t * z + sqrt(1 - t) * eps towards a 5-mode circular mixture of radius 20.
conditional_path = GCPP(
    p_data=GaussianMixture.symmetric2d(modes=5, scale=20.0),
    alpha=LinearAlpha(),
    beta=SqrtBeta(),
)

In [None]:
# Draw 100 conditioning points z ~ p_data from the Gaussian path defined above.
z_samples = conditional_path.sample_conditioning_variable(100)

In [None]:
scale = 30.0
ax = plt.gca()

# Background: target log-density heat map and grey contours.
Plotting(sampler=conditional_path.p_data, ax=ax, vmin=-10, alpha=0.75,
         cmap=plt.get_cmap('Blues')).imshow(scale=scale, bins=200)
Plotting(sampler=conditional_path.p_data, ax=ax, linestyles='solid', levels=20,
         colors='grey', alpha=0.25).contour(scale=scale, bins=200)

# A single conditioning point z (red star) ...
ts = torch.linspace(0.0, 1.0, 7)
z = conditional_path.sample_conditioning_variable(1)
Plotting(sampler=conditional_path.p_data, ax=ax, marker='*',
         color='red').scatter(x=z[:, 0], y=z[:, 1])

# ... and 10 draws from p_t(. | z) at each of the 7 times in [0, 1].
z_batch = z.repeat(10, 1)
for t in ts:
    samples = conditional_path.sample_conditional_path(z_batch, t.repeat(10, 1))
    ax.scatter(samples[:, 0], samples[:, 1], alpha=0.5)


### Gaussian Vector Field

In [None]:
class ConditionalVF(ODE):
    """ODE driven by the conditional vector field u_t(x | z) of a CPP for a fixed z."""

    def __init__(self, cpp: CPP, z: torch.Tensor):
        self.cond_prob_path = cpp
        self.z = z

    def drift_term(self, x, t):
        # Match z to x's batch size: a z with leading size 1 is shared by every particle,
        # a z that already has x's batch size is used row by row.
        z_batch = self.z.expand(x.shape[0], *self.z.shape[1:])
        return self.cond_prob_path.conditional_vector_field(x, z_batch, t)

class ConditionalVFStoc(SDE):
    """Conditional SDE dX = [u_t(X | z) + sigma^2 / 2 * score_t(X | z)] dt + sigma dW.

    The score correction keeps the same conditional marginals as the ODE for any sigma.
    """

    def __init__(self, cpp: CPP, z: torch.Tensor, sigma: float):
        self.cpp = cpp
        self.z = z
        self.sigma = sigma

    def drift_term(self, x, t):
        z_batch = self.z.expand(x.shape[0], *self.z.shape[1:])
        velocity = self.cpp.conditional_vector_field(x, z_batch, t)
        score = self.cpp.conditional_score(x, z_batch, t)
        return velocity + 0.5 * self.sigma ** 2 * score

    def diff_term(self, x, t):
        # Constant, state-independent diffusion coefficient.
        return self.sigma * torch.ones_like(x)


In [None]:
# Rebuild the same Gaussian conditional path (alpha_t = t, beta_t = sqrt(1 - t)) for the ODE/SDE demos.
conditional_path = GCPP(
    p_data=GaussianMixture.symmetric2d(modes=5, scale=20.0),
    alpha=LinearAlpha(),
    beta=SqrtBeta(),
)

In [None]:
# One conditioning point z; 10 particles start from p_0(. | z) and follow the conditional ODE.
z = conditional_path.sample_conditioning_variable(1)
ode = ConditionalVF(cpp=conditional_path, z=z)
sim = EulerSampler(ode=ode)

ts = torch.linspace(0.0, 1.0, 100)
x0 = conditional_path.sample_conditional_path(
    z.repeat(10, 1), ts[0].repeat(10, 1)
)
# simulate_with_trajectory takes ts as [nts, bs, 1] and returns [bs, nts, 2].
tjs = sim.simulate_with_trajectory(x0, ts.view(-1, 1, 1).repeat(1, 10, 1))
z.shape, x0.shape, ts.shape, tjs.shape

In [None]:
scale = 30.0
ax = plt.gca()

# Background: target (blue) and source (red) log-densities plus target contours.
Plotting(sampler=conditional_path.p_data, ax=ax, vmin=-10, alpha=0.75,
         cmap=plt.get_cmap('Blues')).imshow(scale=scale, bins=200)
Plotting(sampler=conditional_path.p_simple, ax=ax, vmin=-10, alpha=0.75,
         cmap=plt.get_cmap('Reds')).imshow(scale=scale, bins=200)
Plotting(sampler=conditional_path.p_data, ax=ax, linestyles='solid', levels=20,
         colors='grey', alpha=0.25).contour(scale=scale, bins=200)

# Conditioning point (red star), start points, and each ODE trajectory in grey.
ax.scatter(z[:, 0], z[:, 1], marker='*', color='red')
ax.scatter(x0[:, 0], x0[:, 1], alpha=0.5)

for trajectory in tjs:  # tjs: [bs, nts, 2]
    ax.plot(trajectory[:, 0], trajectory[:, 1], color='gray')

In [None]:
# Same setup as the ODE demo, but with noise level sigma and a score-corrected drift.
sigma = 0.88
z = conditional_path.sample_conditioning_variable(1)
sde = ConditionalVFStoc(cpp=conditional_path, z=z, sigma=sigma)
sim = EulerMarySampler(sde=sde)

ts = torch.linspace(0.0, 1.0, 100)
x0 = conditional_path.sample_conditional_path(
    z.repeat(10, 1), ts[0].repeat(10, 1)
)
# ts is passed as [nts, bs, 1]; the trajectory comes back as [bs, nts, 2].
tjs = sim.simulate_with_trajectory(x0, ts.view(-1, 1, 1).repeat(1, 10, 1))
z.shape, x0.shape, ts.shape, tjs.shape

In [None]:
scale = 30.0
ax = plt.gca()

# Background: target (blue) and source (red) log-densities plus target contours.
Plotting(sampler=conditional_path.p_data, ax=ax, vmin=-10, alpha=0.75,
         cmap=plt.get_cmap('Blues')).imshow(scale=scale, bins=200)
Plotting(sampler=conditional_path.p_simple, ax=ax, vmin=-10, alpha=0.75,
         cmap=plt.get_cmap('Reds')).imshow(scale=scale, bins=200)
Plotting(sampler=conditional_path.p_data, ax=ax, linestyles='solid', levels=20,
         colors='grey', alpha=0.25).contour(scale=scale, bins=200)

# Conditioning point (red star), start points, and each SDE trajectory in black.
ax.scatter(z[:, 0], z[:, 1], marker='*', color='red')
ax.scatter(x0[:, 0], x0[:, 1], alpha=0.5)

for trajectory in tjs:  # tjs: [bs, nts, 2]
    ax.plot(trajectory[:, 0], trajectory[:, 1], color='black')

### Flow matching 

In [None]:
def build_mlp(dims: List[int], activation: Type[nn.Module] = nn.SiLU):
    """Stack of Linear layers dims[0] -> dims[1] -> ... -> dims[-1].

    An ``activation()`` instance follows every Linear layer except the last one.
    """
    layers = []
    n_linear = len(dims) - 1
    for i, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
        layers.append(nn.Linear(d_in, d_out))
        is_last = i == n_linear - 1
        if not is_last:
            layers.append(activation())
    return nn.Sequential(*layers)

In [None]:
class MLPVF(nn.Module):
    """MLP vector field: maps (x, t) to an output with the same dimension as x."""

    def __init__(self, dim: int, hiddens: List[int]):
        super().__init__()
        self.dim = dim
        # Input is x concatenated with the scalar time t, hence dim + 1 features.
        self.net = build_mlp([dim + 1, *hiddens, dim])

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        # Callers here pass x: [bs, dim] and t: [bs, 1] -> input [bs, dim + 1] -> output [bs, dim].
        return self.net(torch.cat((x, t), dim=-1))

In [None]:
class Trainer(ABC):
    """Minimal training loop: one optimiser step per epoch on a loss defined by the subclass."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    @abstractmethod
    def get_train_loss(self, **kwargs) -> torch.Tensor:
        """Return a scalar loss for one optimisation step; kwargs are those passed to ``train``."""
        pass

    def get_optimizer(self, lr: float):
        return torch.optim.Adam(self.model.parameters(), lr=lr)

    def train(self, n_epochs: int, device: torch.device, lr: float = 1e-3, **kwargs):
        """Move the model to ``device`` and run ``n_epochs`` Adam steps.

        The model is left in eval mode afterwards, even if training is interrupted.
        """
        self.model.to(device)
        opt = self.get_optimizer(lr)
        self.model.train()

        try:
            # A sized iterable lets tqdm show a total and an ETA (enumerate(...) had no length).
            pbar = tqdm(range(n_epochs))
            for idx in pbar:
                opt.zero_grad()
                loss = self.get_train_loss(**kwargs)
                loss.backward()
                opt.step()
                if idx % 1000 == 0:
                    # Empty line every 1000 steps (presumably to leave a snapshot of the bar behind).
                    print()
                pbar.set_description(f'Epoch {idx}, loss: {loss.item()}')
        finally:
            self.model.eval()

In [None]:
class CFMTrainer(Trainer):
    """Conditional flow matching: regress model(x, t) onto the conditional vector field u_t(x | z)."""

    def __init__(self, cond_path: CPP, model: MLPVF, **kwargs):
        super().__init__(model, **kwargs)
        self.cond_path = cond_path

    def get_train_loss(self, batch_size: int):
        """Monte-Carlo estimate of E ||model(x, t) - u_t(x | z)||^2 with z ~ p_data, t ~ U[0, 1), x ~ p_t(. | z)."""
        # The distributions sample wherever their tensors live (presumably CPU), while
        # Trainer.train may have moved the model to GPU: train on the model's device.
        device = next(self.model.parameters()).device
        z = self.cond_path.p_data.sample(batch_size)  # [bs, dim]
        t = torch.rand(batch_size, 1).to(z)  # [bs, 1]
        x = self.cond_path.sample_conditional_path(z, t)  # [bs, dim]
        z, t, x = z.to(device), t.to(device), x.to(device)

        u_theta = self.model(x, t)
        u_ref = self.cond_path.conditional_vector_field(x, z, t)
        error = torch.sum(torch.square(u_theta - u_ref), dim=-1)
        return torch.mean(error)


In [None]:
# Plot half-width and geometry (radius, per-component std) of the flow-matching target.
PARAMS = dict(
    scale=15.0,
    target_scale=10.0,
    target_std=1.0,
)

In [None]:
# Gaussian path towards the PARAMS-configured 5-mode mixture.
cond_path = GCPP(
    p_data=GaussianMixture.symmetric2d(
        modes=5,
        std=PARAMS['target_std'],
        scale=PARAMS['target_scale'],
    ),
    alpha=LinearAlpha(),
    beta=SqrtBeta(),
)
# Vector-field network: four hidden layers of width 64.
mlp = MLPVF(dim=2, hiddens=[64] * 4)
trainer = CFMTrainer(cond_path=cond_path, model=mlp)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# 5000 Adam steps, each on a freshly sampled batch of 1000 (z, t, x) triples.
trainer.train(
    n_epochs=5000,
    device=device,
    lr=1e-3,
    batch_size=1000,
)

In [None]:
class LearnedVF(ODE):
    """ODE whose drift is a trained network u_theta(x, t)."""

    def __init__(self, net: MLPVF):
        self.net = net

    def drift_term(self, x, t):
        """Evaluate the network on its own device and return the result on x's device.

        Trainer.train may leave the net on GPU while simulators work on CPU tensors.
        """
        device = next(self.net.parameters()).device
        return self.net(x.to(device), t.to(device)).to(x.device)

In [None]:
scale = 20
fig, axes = plt.subplots(1, 4, figsize=(36, 12))

bs = 100       # particles per panel
n_steps = 100  # time points for the trajectory panels
# Time grids in the [nts, bs, 1] layout expected by simulate_with_trajectory.
ts = torch.linspace(0.0, 1.0, n_steps).view(-1, 1, 1).repeat(1, bs, 1)
ts_coarse = torch.linspace(0.0, 1.0, 7).view(-1, 1, 1).repeat(1, bs, 1)

# Shared background on every panel: target (blue) and source (red) log-densities plus target contours.
for ax in axes:
    Plotting(sampler=cond_path.p_data, ax=ax, vmin=-10, alpha=0.75, cmap=plt.get_cmap('Blues')).imshow(scale=scale, bins=200)
    Plotting(sampler=cond_path.p_simple, ax=ax, vmin=-10, alpha=0.75, cmap=plt.get_cmap('Reds')).imshow(scale=scale, bins=200)
    Plotting(sampler=cond_path.p_data, ax=ax, linestyles='solid', levels=20, colors='grey', alpha=0.25).contour(scale=scale, bins=200)

############################################
# Panel 0: samples of the ground-truth marginal path p_t at 7 times.
ax = axes[0]
ax.set_title('ground truth marginal VF', fontsize=20)
for t_batch in ts_coarse:  # t_batch: [bs, 1]
    samples = cond_path.sample_marginal_path(t_batch)
    ax.scatter(samples[:, 0], samples[:, 1])

############################################
# Panel 1: Euler trajectories of the ground-truth conditional ODE (one z per particle).
ax = axes[1]
ax.set_title('Ground Truth Marginal VF', fontsize=20)
z = cond_path.sample_conditioning_variable(bs)
sim = EulerSampler(ode=ConditionalVF(cpp=cond_path, z=z))
x0 = cond_path.p_simple.sample(bs)
tjs = sim.simulate_with_trajectory(x0, ts)  # [bs, n_steps, 2]
for traj in tjs:
    ax.plot(traj[:, 0], traj[:, 1], color='black', alpha=0.25)

############################################
# Panel 2: the learned ODE sampled at the 7 coarse times.
ax = axes[2]
ax.set_title('MLP learned VF', fontsize=20)
sim = EulerSampler(ode=LearnedVF(net=mlp))
x0 = cond_path.p_simple.sample(bs)
xts = sim.simulate_with_trajectory(x0, ts_coarse)  # [bs, 7, 2]
for step in range(xts.shape[1]):
    ax.scatter(xts[:, step, 0], xts[:, step, 1])

############################################
# Panel 3: full Euler trajectories of the learned ODE.
ax = axes[3]
ax.set_title('MLP Marginal VF', fontsize=20)
x0 = cond_path.p_simple.sample(bs)
tjs = sim.simulate_with_trajectory(x0, ts)  # [bs, n_steps, 2]
for traj in tjs:
    ax.plot(traj[:, 0], traj[:, 1], color='black', alpha=0.25)

### Score Matching

In [None]:
class CSMTrainer(Trainer):
    """Conditional score matching: regress model(x, t) onto the conditional score of p_t(x | z)."""

    def __init__(self, cond_path: CPP, model: MLPVF, **kwargs):
        super().__init__(model, **kwargs)
        self.cond_path = cond_path

    def get_train_loss(self, batch_size: int):
        """Monte-Carlo estimate of E ||model(x, t) - score_t(x | z)||^2 with z ~ p_data, t ~ U[0, 1), x ~ p_t(. | z)."""
        # Sample where the distributions live, then train on the model's device
        # (Trainer.train may have moved the model to GPU).
        device = next(self.model.parameters()).device
        z = self.cond_path.p_data.sample(batch_size)  # [bs, dim]
        t = torch.rand(batch_size, 1).to(z)  # [bs, 1]
        x = self.cond_path.sample_conditional_path(z, t)  # [bs, dim]
        z, t, x = z.to(device), t.to(device), x.to(device)

        s_theta = self.model(x, t)
        s_ref = self.cond_path.conditional_score(x, z, t)
        error = torch.sum(torch.square(s_theta - s_ref), dim=-1)
        return torch.mean(error)


In [None]:
# Wider plot window and target radius for the score-matching experiments.
PARAMS = dict(
    scale=20.0,
    target_scale=15.0,
    target_std=1.0,
)

In [None]:
# Gaussian path towards the PARAMS-configured 5-mode mixture.
path = GCPP(
    p_data=GaussianMixture.symmetric2d(
        modes=5,
        std=PARAMS['target_std'],
        scale=PARAMS['target_scale'],
    ),
    alpha=LinearAlpha(),
    beta=SqrtBeta(),
)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Two identically shaped networks: one learns u_t(x), the other the score.
flow_model = MLPVF(dim=2, hiddens=[64] * 4)
score_model = MLPVF(dim=2, hiddens=[64] * 4)

flow_trainer = CFMTrainer(cond_path=path, model=flow_model)
score_trainer = CSMTrainer(cond_path=path, model=score_model)

In [None]:
# Both models use identical optimiser settings.
print('training flow model')
flow_trainer.train(
    n_epochs=5000, device=device, lr=1e-3, batch_size=1000,
)
print('training score model')
score_trainer.train(
    n_epochs=5000, device=device, lr=1e-3, batch_size=1000,
)

In [None]:
class langevinSDE(SDE):
    """Learned SDE dX = [u_theta(X, t) + sigma^2 / 2 * s_theta(X, t)] dt + sigma dW."""

    def __init__(self, flow_model: MLPVF, score_model: MLPVF, sigma: float):
        self.flow_model = flow_model
        self.score_model = score_model
        self.sigma = sigma

    def drift_term(self, x, t):
        """Flow plus score correction, evaluated on the models' device and returned on x's device."""
        device = next(self.flow_model.parameters()).device
        x_dev, t_dev = x.to(device), t.to(device)
        drift = self.flow_model(x_dev, t_dev) + 0.5 * self.sigma ** 2 * self.score_model(x_dev, t_dev)
        return drift.to(x.device)

    def diff_term(self, x, t):
        # Deterministic diffusion coefficient: the Euler–Maruyama step already multiplies it by
        # sqrt(h) * N(0, I), so returning sigma * randn here would double-sample the noise.
        return self.sigma * torch.ones_like(x)

In [None]:
scale = PARAMS['scale']
sigma = 0.6
fig, axes = plt.subplots(1, 4, figsize=(36, 12))

bs = 100       # particles per panel
n_steps = 100  # time points for the trajectory panels
# Time grids in the [nts, bs, 1] layout expected by simulate_with_trajectory.
# `ts` ([n_steps, bs, 1]) is reused by the cells below.
ts = torch.linspace(0.0, 1.0, n_steps).view(-1, 1, 1).repeat(1, bs, 1)
ts_coarse = torch.linspace(0.0, 1.0, 7).view(-1, 1, 1).repeat(1, bs, 1)

# Shared background on every panel: target (blue) and source (red) log-densities plus target contours.
for ax in axes:
    Plotting(sampler=path.p_data, ax=ax, vmin=-10, alpha=0.75, cmap=plt.get_cmap('Blues')).imshow(scale=scale, bins=200)
    Plotting(sampler=path.p_simple, ax=ax, vmin=-10, alpha=0.75, cmap=plt.get_cmap('Reds')).imshow(scale=scale, bins=200)
    Plotting(sampler=path.p_data, ax=ax, linestyles='solid', levels=20, colors='grey', alpha=0.25).contour(scale=scale, bins=200)

############################################
# Panel 0: samples of the ground-truth marginal path p_t at 7 times.
ax = axes[0]
ax.set_title('ground truth marginal VF', fontsize=20)
for t_batch in ts_coarse:  # t_batch: [bs, 1]
    samples = path.sample_marginal_path(t_batch)
    ax.scatter(samples[:, 0], samples[:, 1])

############################################
# Panel 1: Euler trajectories of the ground-truth conditional ODE (one z per particle).
ax = axes[1]
ax.set_title('Ground Truth Marginal VF', fontsize=20)
z = path.sample_conditioning_variable(bs)
sim = EulerSampler(ode=ConditionalVF(cpp=path, z=z))
x0 = path.p_simple.sample(bs)
tjs = sim.simulate_with_trajectory(x0, ts)  # [bs, n_steps, 2]
for traj in tjs:
    ax.plot(traj[:, 0], traj[:, 1], color='black', alpha=0.25)

############################################
# Panel 2: the learned SDE (flow + score correction) sampled at the 7 coarse times.
ax = axes[2]
ax.set_title('MLP learned VF', fontsize=20)
sde = langevinSDE(flow_model=flow_model, score_model=score_model, sigma=sigma)
sim = EulerMarySampler(sde=sde)
x0 = path.p_simple.sample(bs)
xts = sim.simulate_with_trajectory(x0, ts_coarse)  # [bs, 7, 2]
for step in range(xts.shape[1]):
    ax.scatter(xts[:, step, 0], xts[:, step, 1])

############################################
# Panel 3: full Euler–Maruyama trajectories of the learned SDE.
ax = axes[3]
ax.set_title('MLP Marginal VF', fontsize=20)
x0 = path.p_simple.sample(bs)
tjs = sim.simulate_with_trajectory(x0, ts)  # [bs, n_steps, 2]
for traj in tjs:
    ax.plot(traj[:, 0], traj[:, 1], color='black', alpha=0.25)

In [None]:
# Time grid left by the previous cell, laid out as [nts, bs, 1].
ts.size()

In [None]:
ax = plt.gca()

# Closed-form Gaussian path for each trajectory: x_t = alpha_t * z + beta_t * eps, with a single
# (z, eps) pair per trajectory evaluated on the grid `ts` ([n_steps, n_traj, 1]).
# Drawing one z per *trajectory* (ts.shape[1]) is explicit here; calling
# sample_marginal_path(ts) drew ts.shape[0] of them and only worked when n_steps == n_traj.
n_traj = ts.shape[1]
z_traj = path.sample_conditioning_variable(n_traj)  # [n_traj, 2]
eps_traj = torch.randn_like(z_traj)  # [n_traj, 2]
xts = path.alpha(ts) * z_traj + path.beta(ts) * eps_traj  # [n_steps, n_traj, 2]

for traj_idx in range(n_traj):
    ax.plot(xts[:, traj_idx, 0], xts[:, traj_idx, 1], color='black', alpha=0.25)

### Visualizing the score vector field

In [None]:
# flow_model predicts the marginal vector field u_t(x)
# score_model predicts the marginal score \nabla_x log p_t(x)
# flow_score_model (a ScoreVF) converts flow_model's output into a score, to visualize the score vector field

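For Gaussian probability paths the vector field and the score are linked by $u_t(x) = a_t x + b_t \nabla_x \log p_t(x)$ with $a_t = \dot{\alpha}_t/\alpha_t$ and $b_t = \beta_t^2 \dot{\alpha}_t/\alpha_t - \dot{\beta}_t \beta_t$, so a score can be recovered from the learned flow:

$$\tilde{s}_t^{\theta}(x) = \frac{u_t^{\theta}(x) - a_t x}{b_t} = \frac{\alpha_t u_t^{\theta}(x) - \dot{\alpha}_t x}{\beta_t^2 \dot{\alpha}_t - \alpha_t \dot{\beta}_t \beta_t}.$$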

In [None]:
class ScoreVF(nn.Module):
    """Score derived from a learned flow for a Gaussian path:

    s_t(x) = (alpha_t u_t(x) - alpha'_t x) / (beta_t^2 alpha'_t - alpha_t beta'_t beta_t).
    """

    def __init__(self, flow_model: MLPVF, alpha: Alpha, beta: Beta):
        super().__init__()
        self.flow_model = flow_model
        self.alpha = alpha
        self.beta = beta

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        a, da = self.alpha(t), self.alpha.dt(t)
        b, db = self.beta(t), self.beta.dt(t)

        numerator = a * self.flow_model(x, t) - da * x
        denominator = b ** 2 * da - a * db * b
        return numerator / denominator

In [None]:
num_bins = 30
num_marginals = 4

learned_score_model = score_model
flow_score_model = ScoreVF(flow_model, path.alpha, path.beta)

fig, axes = plt.subplots(2, num_marginals, figsize=(6 * num_marginals, 12))

scale = PARAMS["scale"]

# Stop just short of t = 1, where beta_t = 0 makes the ScoreVF denominator vanish.
ts = torch.linspace(0.0, 0.9999, num_marginals).to(device)

axes[0,0].set_ylabel("Learned with Score Matching", fontsize=12)
axes[1,0].set_ylabel("Computed from $u_t^{{\\theta}}(x)$", fontsize=12)

# Column = time; row 0 = score network, row 1 = score recovered from the flow network.
for idx in range(num_marginals):
    t = ts[idx]
    for row, row_model in enumerate((learned_score_model, flow_score_model)):
        ax = axes[row, idx]
        Plotting(sampler=path.p_simple, ax=ax).quiver(
            scale=scale, bins=num_bins, ts=t, score_model=row_model
        )
        Plotting(sampler=path.p_simple, ax=ax, vmin=-10, alpha=0.75,
                 cmap=plt.get_cmap('Reds')).imshow(scale=scale, bins=200)
        Plotting(sampler=path.p_data, ax=ax, vmin=-10, alpha=0.75,
                 cmap=plt.get_cmap('Blues')).imshow(scale=scale, bins=200)

### Flow matching between arbitrary distributions with a linear conditional probability path

In [None]:
class MoonSampleable(Sampleable):
    """Two interleaving half-moons (sklearn ``make_moons``), scaled by ``scale`` and shifted by ``offset``.

    Sampling uses ``random_state=None``, so it draws from numpy's global RNG.
    """

    def __init__(self, noise: float=0.05, scale: float=5.0, offset: Optional[torch.Tensor]=None):
        self.noise = noise
        self.scale = scale
        if offset is None:
            offset = torch.zeros(2)
        self.offset = offset

    @property
    def dim(self) -> int:
        return 2

    def sample(self, n_samples: int) -> torch.Tensor:
        """Return ``n_samples`` points as a float32 tensor of shape [n_samples, 2]."""
        samples, _ = make_moons(
            n_samples=n_samples,
            noise=self.noise,
            random_state=None
        )
        points = torch.from_numpy(samples.astype(np.float32))
        # Apply the stored offset (zero by default, so the default output is unchanged).
        return self.scale * points + self.offset

class CircleSampleable(Sampleable):
    """Two concentric circles (sklearn ``make_circles``, inner radius factor 0.5),
    scaled by ``scale`` and shifted by ``offset``.

    Sampling uses ``random_state=None``, so it draws from numpy's global RNG.
    """

    def __init__(self, noise: float=0.05, scale: float=5.0, offset: Optional[torch.Tensor] = None):
        self.noise = noise
        self.scale = scale
        if offset is None:
            offset = torch.zeros(2)
        self.offset = offset

    @property
    def dim(self) -> int:
        return 2

    def sample(self, n_samples: int) -> torch.Tensor:
        """Return ``n_samples`` points as a float32 tensor of shape [n_samples, 2]."""
        samples, _ = make_circles(
            n_samples=n_samples,
            noise=self.noise,
            factor=0.5,
            random_state=None
        )
        points = torch.from_numpy(samples.astype(np.float32))
        # Apply the stored offset (zero by default, so the default output is unchanged).
        return self.scale * points + self.offset
        

class CheckerboardSampleable(Sampleable):
    """Uniform distribution over half the cells of a grid_size x grid_size checkerboard
    covering [-scale, scale]^2 (the cells whose column and row parities agree)."""

    def __init__(self, grid_size: int=3, scale: float=5.0):
        self.grid_size = grid_size
        self.scale = scale

    @property
    def dim(self) -> int:
        return 2

    def sample(self, num_samples: int) -> torch.Tensor:
        """Rejection-sample ``num_samples`` points, returned as a [num_samples, 2] tensor."""
        cell = 2 * self.scale / self.grid_size
        accepted = [torch.zeros(0, 2)]
        n_accepted = 0
        while n_accepted < num_samples:
            # Uniform proposals over the whole square, in batches of num_samples.
            proposal = (torch.rand(num_samples, 2) - 0.5) * 2 * self.scale
            col_even = torch.floor((proposal[:, 0] + self.scale) / cell) % 2 == 0
            row_even = torch.floor((proposal[:, 1] + self.scale) / cell) % 2 == 0
            # Keep points whose column and row parities match (== logical_xor(~col, row)).
            keep = col_even == row_even
            accepted.append(proposal[keep])
            n_accepted += int(keep.sum())
        return torch.cat(accepted, dim=0)[:num_samples]


In [None]:
# 2-D histogram (100 bins per axis) of 20k points from the concentric-circles dataset.
Plotting(
    sampler=CircleSampleable(),
    n_samples=20000,
    bins=100,
).hist2d()

In [None]:
class LCPP(CPP):
    """Linear conditional probability path from ``p_simple`` to a data point ``z``.

    x_t = (1 - t) * x0 + t * z with x0 ~ p_simple: the path starts at p_simple
    (t = 0) and collapses onto z at t = 1. ``p_simple`` may be any Sampleable,
    which is why no closed-form conditional score is provided.
    """

    def __init__(self, p_simple: Sampleable, p_data: Sampleable):
        super().__init__(p_simple, p_data)

    def sample_conditioning_variable(self, num_samples: int) -> torch.Tensor:
        """Draw ``num_samples`` conditioning points z ~ p_data."""
        return self.p_data.sample(num_samples)

    def sample_conditional_path(self, z: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Sample x_t ~ p_t(.|z) by interpolating a fresh x0 ~ p_simple toward z.

        Args:
            z: conditioning points, one row per sample (first dim is the batch).
            t: times broadcastable against z (e.g. shape (bs, 1) as used in this notebook).
        """
        # Visible samplers (e.g. CircleSampleable) build CPU float32 tensors; match z's
        # device and dtype so the interpolation also works when z lives elsewhere.
        x0 = self.p_simple.sample(z.shape[0]).to(z)
        return (1 - t) * x0 + t * z

    def conditional_vector_field(self, x: torch.Tensor, z: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """u_t(x|z) = d/dt x_t = z - x0, expressed in terms of x as (z - x) / (1 - t).

        Singular at t = 1. Euler integration on [0, 1] only evaluates the left
        endpoint of each step, so t = 1 is never reached there.
        """
        return (z - x) / (1 - t)

    def conditional_score(self, x: torch.Tensor, z: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # No tractable score for a linear path from an arbitrary p_simple.
        raise NotImplementedError("You should not be calling this function!")

In [None]:
# Linear path from an isotropic 2D Gaussian (std 1) to a 4x4 checkerboard, plus one
# fixed conditioning point z ~ p_data that the plots below condition on.
path = LCPP(Gaussian.isotropic(dim=2, std=1.0), CheckerboardSampleable(grid_size=4))
z = path.sample_conditioning_variable(1)  # shape (1, 2)

In [None]:
def hist2d_samples(samples, ax: Optional[Axes] = None, bins: int = 200, scale: float = 5.0, percentile: float = 99, xrange=None, yrange=None, **kwargs):
    """Draw a 2D histogram of ``samples`` as an image on ``ax``.

    Args:
        samples: array-like of shape (n, 2) -- a numpy array or CPU torch tensor.
        ax: axes to draw on; defaults to the current axes (``plt.gca()``).
        bins: number of histogram bins per axis.
        scale: half-width of the default square window [-scale, scale]^2.
        percentile: bin counts above this percentile saturate the colour map, so a
            few very dense bins do not wash out the rest of the image.
        xrange, yrange: explicit [min, max] windows; each overrides ``scale``.
        **kwargs: forwarded to ``Axes.imshow`` (e.g. ``alpha``, ``cmap``).
    """
    if ax is None:
        ax = plt.gca()
    if xrange is None:
        xrange = [-scale, scale]
    if yrange is None:
        yrange = [-scale, scale]

    samples = np.asarray(samples)
    H, xedges, yedges = np.histogram2d(samples[:, 0], samples[:, 1], bins=bins, range=[xrange, yrange])

    # Colour scale: 0 .. the requested percentile of the bin counts.
    norm = cm.colors.Normalize(vmin=0.0, vmax=np.percentile(H, percentile))

    # histogram2d indexes H as [x_bin, y_bin]; imshow expects [row=y, col=x], hence H.T.
    extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
    ax.imshow(H.T, extent=extent, origin='lower', norm=norm, **kwargs)

In [None]:
# Three views of the linear path toward one conditioning point z (red star, last column):
#   row 0: samples drawn directly from the conditional path p_t(.|z);
#   row 1: particles from p_simple pushed by Euler-integrating the conditional vector field;
#   row 2: samples from the marginal path p_t.
n_marginals = 5
scale = 6.0
n_samples = 10000
nts = 500  # Euler grid size for row 1
ts = torch.linspace(0, 1, n_marginals)  # panel times, shared by all three rows
_, axes = plt.subplots(3, n_marginals, figsize=(6 * n_marginals, 6 * 3))

for idx, t in enumerate(ts):
    zz = z.repeat(n_samples, 1)
    tt = t.view(1, 1).repeat(n_samples, 1)
    # Raise the saturation percentile as t grows: samples collapse onto z near t = 1.
    percentile = min(99 + 2 * torch.sin(t).item(), 100)
    samples = path.sample_conditional_path(zz, tt)  # [n_samples, 2]
    hist2d_samples(samples=samples.cpu(), ax=axes[0, idx], bins=300, scale=scale, percentile=percentile, alpha=1.0)
    axes[0, idx].set_title(f"conditional path, t = {t.item():.2f}")

ode = ConditionalVF(path, z)
sim = EulerSampler(ode=ode)
x0 = path.p_simple.sample(n_samples)
sim_ts = torch.linspace(0, 1, nts)
xts = sim.simulate_with_trajectory(x0, sim_ts.view(-1, 1, 1).expand(-1, n_samples, 1))  # [bs, nts, 2]

# Snapshot the trajectory at the Euler steps nearest the panel times, so every column
# of the figure shows the same t in all three rows.
snapshot_steps = torch.round(ts * (nts - 1)).long().tolist()
for idx, step in enumerate(snapshot_steps):
    t_step = sim_ts[step]
    percentile = min(99 + 2 * torch.sin(t_step).item(), 100)
    hist2d_samples(samples=xts[:, step, :].cpu(), ax=axes[1, idx], bins=300, scale=scale, percentile=percentile, alpha=1.0)
    axes[1, idx].set_title(f"simulated ODE, t = {t_step.item():.2f}")

for idx, t in enumerate(ts):
    tt = t.view(1, 1).expand(n_samples, 1)
    marginal_samples = path.sample_marginal_path(tt)
    hist2d_samples(samples=marginal_samples.cpu(), ax=axes[2, idx], bins=300, scale=scale, percentile=99, alpha=1.0)
    axes[2, idx].set_title(f"marginal path, t = {t.item():.2f}")

axes[0, -1].scatter(z[:, 0], z[:, 1], marker='*', color='red')
axes[1, -1].scatter(z[:, 0], z[:, 1], marker='*', color='red')
    

### Comparison between Gaussian and Linear Conditional Probability Paths 

In [None]:
# Display the hyper-parameter dict (defined elsewhere in the notebook); the next cell
# reads PARAMS['target_std'] and PARAMS['target_scale'] from it.
PARAMS

In [None]:
# Gaussian conditional path toward a symmetric 5-mode Gaussian mixture, using the
# LinearAlpha / SqrtBeta schedules.
gs_path = GCPP(
    p_data=GaussianMixture.symmetric2d(modes=5, std=PARAMS['target_std'], scale=PARAMS['target_scale']),
    alpha=LinearAlpha(),
    beta=SqrtBeta(),
)
# Linear conditional path from an isotropic 2D Gaussian to a 4x4 checkerboard.
linear_path = LCPP(Gaussian.isotropic(dim=2, std=1.0), CheckerboardSampleable(grid_size=4))

# One conditioning point per path (the Gaussian path's point is drawn first).
gs_z, ln_z = gs_path.sample_conditioning_variable(1), linear_path.sample_conditioning_variable(1)
gs_z, ln_z

#### Checking the linear and Gaussian conditional probability paths
Both paths look alike in these plots (the Gaussian path here uses a linear $\alpha_t$): the samples drift steadily toward the conditioning data point $z$, their variance shrinks over time, and at $t = 1$ they ultimately reach $z$.

In [None]:
ts = torch.linspace(0, 1, 5)
_, axes = plt.subplots(2, 5, figsize=(6 * 5, 6 * 2))

# Row 0: Gaussian path (window +-15); row 1: linear path (window +-6).
panels = (
    (gs_path, gs_z, 15),
    (linear_path, ln_z, 6),
)
for col, t in enumerate(ts):
    pct = min(99 + 2 * torch.sin(t).item(), 100)
    for row, (cond_path, z_point, window) in enumerate(panels):
        xs = cond_path.sample_conditional_path(z_point.repeat(n_samples, 1), t.view(1, 1).repeat(n_samples, 1))
        hist2d_samples(samples=xs.cpu(), ax=axes[row, col], bins=300, scale=window, percentile=pct, alpha=1.0)

# Mark each conditioning point on the last (t = 1) column.
axes[0, -1].scatter(gs_z[:, 0], gs_z[:, 1], marker='*', color='red')
axes[1, -1].scatter(ln_z[:, 0], ln_z[:, 1], marker='*', color='red')

#### Checking whether following the vector field also reaches the target data point

If we follow the conditional vector field, we also reach the target data point, in exactly the same manner as the conditional probability path.

In [None]:
def simulate(ode, n_samples, nts, p_simple=None):
    """Euler-integrate ``ode`` on a uniform grid over [0, 1] from samples of ``p_simple``.

    Args:
        ode: ODE whose ``drift_term(x, t)`` is integrated.
        n_samples: number of trajectories.
        nts: number of grid points, including t = 0 and t = 1.
        p_simple: Sampleable providing the initial points. If omitted, falls back to
            ``path.p_simple`` -- the notebook-level ``path`` global -- which keeps
            older calls working. Pass it explicitly to avoid depending on that global.

    Returns:
        Trajectories stacked as [n_samples, nts, dim] (see ``Simulator.simulate_with_trajectory``).
    """
    if p_simple is None:
        p_simple = path.p_simple
    sim = EulerSampler(ode=ode)
    x0 = p_simple.sample(n_samples)
    ts = torch.linspace(0, 1, nts)
    # expand is a read-only view of the time grid, one column per trajectory.
    return sim.simulate_with_trajectory(x0, ts.view(-1, 1, 1).expand(-1, n_samples, 1))

In [None]:
_, axes = plt.subplots(2, 5, figsize=(6 * 5, 6 * 2))
nts = 500

# Euler-integrate each conditional vector field toward its own conditioning point.
ln_xts = simulate(ode=ConditionalVF(linear_path, ln_z), n_samples=n_samples, nts=nts)
gs_xts = simulate(ode=ConditionalVF(gs_path, gs_z), n_samples=n_samples, nts=nts)

# Same uniform grid that `simulate` integrates on; maps step indices to times.
ts = torch.linspace(0, 1, nts)

snapshot_steps = [0, 100, 220, 350, 499]
for col, step in enumerate(snapshot_steps):
    t_step = ts[step]
    pct = min(99 + 2 * torch.sin(t_step).item(), 100)
    # Row 0: Gaussian-path trajectories; row 1: linear-path trajectories.
    for row, trajectories in ((0, gs_xts), (1, ln_xts)):
        hist2d_samples(samples=trajectories[:, step, :].cpu(), ax=axes[row, col], bins=300, scale=15, percentile=pct, alpha=1.0)

last_col = len(snapshot_steps) - 1
axes[0, last_col].scatter(gs_z[:, 0], gs_z[:, 1], marker='*', color='red')
axes[1, last_col].scatter(ln_z[:, 0], ln_z[:, 1], marker='*', color='red')

#### Checking whether the marginal path reaches the data distribution

In [None]:
ts = torch.linspace(0.0, 1.0, 5)
_, axes = plt.subplots(2, 5, figsize=(6 * 5, 6 * 2))

for col, t in enumerate(ts):
    t_col = t.view(1, 1).expand(n_samples, 1)
    # The linear-path marginal is drawn before the Gaussian one.
    ln_xts = linear_path.sample_marginal_path(t_col)
    gs_xts = gs_path.sample_marginal_path(t_col)

    hist2d_samples(samples=gs_xts.cpu(), ax=axes[0, col], bins=300, scale=20, percentile=99, alpha=1.0)
    hist2d_samples(samples=ln_xts.cpu(), ax=axes[1, col], bins=300, scale=15, percentile=99, alpha=1.0)

# Mark each path's conditioning point on the last (t = 1) column.
axes[0, -1].scatter(gs_z[:, 0], gs_z[:, 1], marker='*', color='red')
axes[1, -1].scatter(ln_z[:, 0], ln_z[:, 1], marker='*', color='red')

### Training a flow model to learn the vector field from an arbitrary source distribution to the target distribution

In [None]:
# Source is the concentric-circles distribution (not a Gaussian); target is the 4x4 checkerboard.
linear_path = LCPP(
    p_simple=CircleSampleable(),
    p_data=CheckerboardSampleable(grid_size=4),
)

device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# Vector-field network with four hidden layers of width 64, trained by conditional flow matching.
flow_model = MLPVF(dim=2, hiddens=[64] * 4)
flow_trainer = CFMTrainer(cond_path=linear_path, model=flow_model)

In [None]:
# Long-running: 20k optimisation steps with batches of 2000 conditioning samples.
flow_trainer.train(batch_size=2000, device=device, n_epochs=20000)

In [None]:
_, axes = plt.subplots(2, 5, figsize=(6 * 5, 6*2))
nts = 1000

# Integrate the learned vector field from fresh source samples on a uniform grid over [0, 1].
ts = torch.linspace(0, 1, nts)
x0 = linear_path.p_simple.sample(n_samples)
ode = LearnedVF(flow_model)
sim = EulerSampler(ode=ode)
# expand (a view) instead of repeat: the simulator only reads the time grid, and repeat
# would materialise an nts x n_samples tensor (10M floats here) for nothing.
ln_xts = sim.simulate_with_trajectory(x0, ts.view(-1, 1, 1).expand(-1, n_samples, 1))  # [n_samples, nts, 2]
# Compare learned-flow particles (row 0) with ground-truth marginal samples (row 1) at
# five Euler steps; with nts = 1000 these sit at t = step / (nts - 1) ~ 0, 0.4, 0.6, 0.8, 1.
for idx, t in enumerate([0, 399, 599, 799, 999]):
    xts_t = ln_xts[:, t, :]
    # `t` is rebound here from the step index to the matching time value.
    t = ts[t]
    # Same number of samples as the simulated particles, drawn from the true marginal p_t.
    xts_gt = linear_path.sample_marginal_path(t.view(1,1).repeat(xts_t.shape[0], 1))
            
    hist2d_samples(samples=xts_t.cpu(), ax=axes[0, idx], bins=200, scale=6, percentile=99, alpha=1.0)
    hist2d_samples(samples=xts_gt.cpu(), ax=axes[1, idx], bins=200, scale=6, percentile=99, alpha=1.0)
    