In [1]:
import warnings
from functools import partial

import numpy as np
import ot as pot
import torch


class OTPlanSampler:

    def __init__(
        self,
        method: str,
        reg: float = 0.05,
        reg_m: float = 1.0,
        normalize_cost: bool = False,
        warn: bool = True,
    ) -> None:

        torch.manual_seed(12345)
        np.random.seed(12345)

        if method == "exact":
            self.ot_fn = pot.emd
        elif method == "sinkhorn":
            self.ot_fn = partial(pot.sinkhorn, reg=reg)
        elif method == "unbalanced":
            self.ot_fn = partial(pot.unbalanced.sinkhorn_knopp_unbalanced, reg=reg, reg_m=reg_m)
        elif method == "partial":
            self.ot_fn = partial(pot.partial.entropic_partial_wasserstein, reg=reg)
        else:
            raise ValueError(f"Unknown method: {method}")
        self.reg = reg
        self.reg_m = reg_m
        self.normalize_cost = normalize_cost
        self.warn = warn

    def get_map(self, x0, x1):
        torch.manual_seed(12345)
        np.random.seed(12345)

        a, b = pot.unif(x0.shape[0]), pot.unif(x1.shape[0])
        if x0.dim() > 2:
            x0 = x0.reshape(x0.shape[0], -1)
        if x1.dim() > 2:
            x1 = x1.reshape(x1.shape[0], -1)
        x1 = x1.reshape(x1.shape[0], -1)
        M = torch.cdist(x0, x1) ** 2
        if self.normalize_cost:
            M = M / M.max()  # should not be normalized when using minibatches
        p = self.ot_fn(a, b, M.detach().cpu().numpy())
        if not np.all(np.isfinite(p)):
            print("ERROR: p is not finite")
            print(p)
            print("Cost mean, max", M.mean(), M.max())
            print(x0, x1)
        if np.abs(p.sum()) < 1e-8:
            if self.warn:
                warnings.warn("Numerical errors in OT plan, reverting to uniform plan.")
            p = np.ones_like(p) / p.size
        return p

    def sample_map(self, pi, batch_size, replace=False):
        torch.manual_seed(12345)
        np.random.seed(12345)

        p = pi.flatten()
        p = p / p.sum()
        choices = np.random.choice(
            pi.shape[0] * pi.shape[1], p=p, size=batch_size, replace=replace
        )
        return np.divmod(choices, pi.shape[1])

    def sample_plan(self, x0, x1, replace=False):
        torch.manual_seed(12345)
        np.random.seed(12345)
        pi = self.get_map(x0, x1)
        self.i, self.j = self.sample_map(pi, x0.shape[0], replace=replace)
        return x0[self.i], x1[self.j]



In [9]:
import torch 
from typing import Union
from dataclasses import dataclass

class ConditionalFlowMatcher:
    """Base class for conditional flow matching methods. This class implements the independent
    conditional flow matching methods from [1] and serves as a parent class for all other flow
    matching methods.

    It implements:
    - Drawing data from gaussian probability path N(t * x1 + (1 - t) * x0, sigma) function
    - conditional flow matching ut(x1|x0) = x1 - x0
    - score function $\nabla log p_t(x|x0, x1)$
    """
    def __init__(self, sigma: Union[float, int] = 0.0):
        self.sigma = sigma

    def compute_mu_t(self, x0, x1, t):	
        torch.manual_seed(12345)
        t = pad_t_like_x(t, x0)
        print(11, t.shape, x0.shape, x1.shape)
        return t * x1 + (1 - t) * x0

    def compute_sigma_t(self, t):
        del t
        return self.sigma

    def sample_xt(self, x0, x1, t, epsilon):
        torch.manual_seed(12345)
        mu_t = self.compute_mu_t(x0, x1, t)
        sigma_t = self.compute_sigma_t(t)
        sigma_t = pad_t_like_x(sigma_t, x0)
        return mu_t + sigma_t * epsilon

    def compute_conditional_flow(self, x0, x1, t, xt):
        torch.manual_seed(12345)
        del t, xt
        return x1 - x0

    def sample_noise_like(self, x):
        torch.manual_seed(12345)
        return torch.randn_like(x)

    def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=False):
        torch.manual_seed(12345)
        if t is None:
            t = torch.rand(x0.shape[0]).type_as(x0)
        assert len(t) == x0.shape[0], "t has to have batch size dimension"

        eps = self.sample_noise_like(x0)
        xt = self.sample_xt(x0, x1, t, eps)
        ut = self.compute_conditional_flow(x0, x1, t, xt)
        self.x0 = x0
        self.x1 = x1
        if return_noise:
            return t, xt, ut, eps
        else:
            return t, xt, ut

    def compute_lambda(self, t):
        torch.manual_seed(12345)
        sigma_t = self.compute_sigma_t(t)
        return 2 * sigma_t / (self.sigma**2 + 1e-8)

class ExactOptimalTransportConditionalFlowMatcher(ConditionalFlowMatcher):
    def __init__(self, sigma: Union[float, int] = 0.0):
        super().__init__(sigma)
        self.ot_sampler = OTPlanSampler(method="exact")

    def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=False):
        x0, x1 = self.ot_sampler.sample_plan(x0, x1)
        self.x0 = x0
        self.x1 = x1
        self.i = self.ot_sampler.i
        self.j = self.ot_sampler.j
        return super().sample_location_and_conditional_flow(x0, x1, t, return_noise)


class SchrodingerBridgeConditionalFlowMatcher(ConditionalFlowMatcher):
    def __init__(self, sigma: Union[float, int] = 1.0, ot_method="exact"):
        if sigma <= 0:
            raise ValueError(f"Sigma must be strictly positive, got {sigma}.")
        elif sigma < 1e-3:
            warnings.warn("Small sigma values may lead to numerical instability.")
        super().__init__(sigma)
        self.ot_method = ot_method
        self.ot_sampler = OTPlanSampler(method=ot_method, reg=2 * self.sigma**2)

    def compute_sigma_t(self, t):
        return self.sigma * torch.sqrt(t * (1 - t))

    def compute_conditional_flow(self, x0, x1, t, xt):
        t = pad_t_like_x(t, x0)
        mu_t = self.compute_mu_t(x0, x1, t)
        sigma_t_prime_over_sigma_t = (1 - 2 * t) / (2 * t * (1 - t) + 1e-8)
        ut = sigma_t_prime_over_sigma_t * (xt - mu_t) + x1 - x0
        return ut

    def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=False):
        x0, x1 = self.ot_sampler.sample_plan(x0, x1)
        self.x0 = x0
        self.x1 = x1
        self.i = self.ot_sampler.i
        self.j = self.ot_sampler.j
        return super().sample_location_and_conditional_flow(x0, x1, t, return_noise)

class CFM:
    def __init__(self, config: dataclass):
        self.sigma_min = config.SIGMA

    def flowmatcher(self, batch):
        torch.manual_seed(12345)
        CFM = ConditionalFlowMatcher(sigma=self.sigma_min)
        t, xt, ut = CFM.sample_location_and_conditional_flow(batch['source'], batch['target'])
        self.x0 , self.x1 = CFM.x0, CFM.x1
        self.t = t[:, None]
        self.path = xt
        self.u = ut

class OTCFM:
    def __init__(self, config: dataclass):
        self.sigma_min = config.SIGMA

    def flowmatcher(self, batch):
        torch.manual_seed(12345)
        OTFM = ExactOptimalTransportConditionalFlowMatcher(sigma=self.sigma_min)
        t, xt, ut = OTFM.sample_location_and_conditional_flow(batch['source'], batch['target'])
        self.i, self.j = OTFM.i, OTFM.j
        self.x0 , self.x1 = OTFM.x0, OTFM.x1
        self.t = t[:, None]
        self.path = xt
        self.u = ut

class SBCFM:
    def __init__(self, config: dataclass):
        self.sigma_min = config.SIGMA
    
    def flowmatcher(self, batch):
        torch.manual_seed(12345)
        SBFM = SchrodingerBridgeConditionalFlowMatcher(sigma=self.sigma_min, ot_method='exact')
        t, xt, ut = SBFM.sample_location_and_conditional_flow(batch['source'], batch['target'])
        self.i, self.j = SBFM.i, SBFM.j
        self.x0 , self.x1 = SBFM.x0, SBFM.x1
        self.t = t[:, None]
        self.path = xt
        self.u = ut 

def pad_t_like_x(t, x):
    if isinstance(t, (float, int)):
        return t
    return t.reshape(-1, *([1] * (x.dim() - 1)))



In [18]:
import torch 
from dataclasses import dataclass

class ConditionalFlowMatching:

	def __init__(self, config: dataclass, coupling: str = None):
		self.sigma_min = config.SIGMA
		self.batch_size = config.BATCH_SIZE
		self.coupling = coupling

	def source_target_coupling(self, batch):
		""" conditional variable z = (x_0, x1) ~ pi(x_0, x_1)
		"""	
		if self.coupling == 'OT':
			OT = OTPlanSampler(method='exact')	
			pi = OT.get_map(batch['source'], batch['target'])		
			self.i, self.j = OT.sample_map(pi, self.batch_size, replace=False)
			self.x0 = batch['target'][self.i]  
			self.x1 = batch['source'][self.j] 

		elif self.coupling == 'SB':
			OT = OTPlanSampler(method='exact', reg=2 * self.sigma_min**2)	
			pi = OT.get_map(batch['source'], batch['target'])		
			self.i, self.j = OT.sample_map(pi, self.batch_size, replace=False)
			self.x0 = batch['target'][self.i]  
			self.x1 = batch['source'][self.j] 
		
		elif self.coupling == 'ContextOT':
			OT = OTPlanSampler(method='exact')	
			pi = OT.get_map(batch['source context'], batch['target context'])		
			self.i, self.j = OT.sample_map(pi, self.batch_size, replace=False)
			self.x0 = batch['target'][self.i]  
			self.x1 = batch['source'][self.j] 

		elif self.coupling == 'ContextSB':
			OT = OTPlanSampler(method='exact', reg=2 * self.sigma**2)	
			pi = OT.get_map(batch['source'], batch['target'] )		
			self.i, self.j = OT.sample_map(pi, self.batch_size, replace=False)
			self.x0 = batch['target'][self.i]  
			self.x1 = batch['source'][self.j] 

		else:	
			self.x0 = batch['source'] 
			self.x1 = batch['target']

	def conditional_probability_path(self):
		""" mean and std of the Guassian conditional probability p_t(x|x_0,x_1)
		"""
		t = self.reshape_time(self.t, x=self.x1)
		self.mean = t * self.x1 + (1 - t) * self.x0
		self.std = self.sigma_min * (torch.sqrt(t * (1 - t)) if self.coupling == 'SB' else 1.0 )
	
	def conditional_vector_fields(self):
		""" regression objective: conditional vector field u_t(x|x_0,x_1)
		"""
		self.u = self.x1 - self.x0 

	def sample_time(self):
		""" sample time from Uniform: t ~ U[0,1]
		"""
		torch.manual_seed(12345)
		self.t = torch.rand(self.batch_size, device=self.x1.device).type_as(self.x1)
	

	def sample_conditional_path(self):
		""" sample a path: x_t ~ p_t(x|x_0, x_1)
		"""
		torch.manual_seed(12345)
		self.conditional_probability_path()
		self.path = self.mean + self.std * torch.randn_like(self.x1)

	def flowmatcher(self, batch):
		""" conditional flow-mathcing MSE loss
		"""
		self.source_target_coupling(batch)
		self.sample_time() 
		self.conditional_vector_fields()
		self.sample_conditional_path()

	def reshape_time(self, t, x):
		if isinstance(t, (float, int)):
			return t
		return t.reshape(-1, *([1] * (x.dim() - 1)))


In [21]:
from DynGenModels.datamodules.jetclass.configs import JetClass_Config
from DynGenModels.datamodules.jetclass.datasets import JetClassDataset
from DynGenModels.datamodules.jetclass.dataloader import JetClassDataLoader

config = JetClass_Config()
config.DATA_SPLIT_FRACS = [0.9, 0.1, 0.0]
config.BATCH_SIZE = 32
config.SIGMA = 0.1

jetclass = JetClassDataset(config)
dataloader = JetClassDataLoader(jetclass, config)

cfm = ConditionalFlowMatching(config)
torchcfm = CFM(config)

for batch in dataloader.valid:
    torchcfm.flowmatcher(batch)
    cfm.flowmatcher(batch)

    assert torch.allclose(cfm.path, torchcfm.path), f"cfm path != torchcfm path"
    assert torch.allclose(cfm.x0, torchcfm.x0),  f"cfm x0 != torchcfm x0"
    assert torch.allclose(cfm.x1, torchcfm.x1),  f"cfm x1 != torchcfm x1"
    assert torch.allclose(cfm.u, torchcfm.u),  f"cfm ut != torchcfm ut"
    break

INFO: building dataloaders...
INFO: train/val/test split ratios: 0.9/0.1/0.0
INFO: train size: 90000, validation size: 10000, testing sizes: 0
11 torch.Size([32, 1, 1]) torch.Size([32, 30, 3]) torch.Size([32, 30, 3])
t shape =  torch.Size([32, 1, 1])
x1 shape =  torch.Size([32, 30, 3])


In [22]:
from DynGenModels.datamodules.jetclass.configs import JetClass_Config
from DynGenModels.datamodules.jetclass.datasets import JetClassDataset
from DynGenModels.datamodules.jetclass.dataloader import JetClassDataLoader

config = JetClass_Config()
config.DATA_SPLIT_FRACS = [0.9, 0.1, 0.0]
config.BATCH_SIZE = 32
config.SIGMA = 0.1

jetclass = JetClassDataset(config)
dataloader = JetClassDataLoader(jetclass, config)

cfm = ConditionalFlowMatching(config, coupling='OT')
torchcfm = OTCFM(config)

for batch in dataloader.valid:
    cfm.flowmatcher(batch)
    torchcfm.flowmatcher(batch)
    print(cfm.i)
    print(cfm.j)
    print(torchcfm.i)
    print(torchcfm.j)
    print(cfm.x0[0])
    print(cfm.x1[0])
    print(torchcfm.x0[0])
    print(torchcfm.x1[0])
    print(cfm.u[0], torchcfm.u[0])
    break

INFO: building dataloaders...
INFO: train/val/test split ratios: 0.9/0.1/0.0
INFO: train size: 90000, validation size: 10000, testing sizes: 0
t shape =  torch.Size([32, 1, 1])
x1 shape =  torch.Size([32, 30, 3])
11 torch.Size([32, 1, 1]) torch.Size([32, 30, 3]) torch.Size([32, 30, 3])
[29 10  5  6 18 19 30 20 23  0  3  9 21 25 27 22 14 31 26  1 12 13 15 17
 24  2  7 11 16  4  8 28]
[15  3  6 21 17 22 31 16 27 10  2  5  4  0 13 12 25 29 24 20  1  7 19  8
 14 23 11 26  9 18 28 30]
[29 10  5  6 18 19 30 20 23  0  3  9 21 25 27 22 14 31 26  1 12 13 15 17
 24  2  7 11 16  4  8 28]
[15  3  6 21 17 22 31 16 27 10  2  5  4  0 13 12 25 29 24 20  1  7 19  8
 14 23 11 26  9 18 28 30]
tensor([[ 4.8348,  0.3245, -0.9599],
        [ 3.2693, -0.5537, -0.4143],
        [ 1.6528,  0.6981,  1.5533],
        [ 0.6421,  0.7413,  1.5577],
        [ 0.3023, -0.5516, -0.4333],
        [ 0.1522, -2.2564,  0.8458],
        [ 0.1505,  0.7221,  1.6054],
        [-0.0328,  0.5769, -0.9108],
        [-0.1935,  0.

In [6]:
from DynGenModels.datamodules.jetclass.configs import JetClass_Config
from DynGenModels.datamodules.jetclass.datasets import JetClassDataset
from DynGenModels.datamodules.jetclass.dataloader import JetClassDataLoader

config = JetClass_Config()
config.DATA_SPLIT_FRACS = [0.9, 0.1, 0.0]
config.BATCH_SIZE = 32
config.SIGMA = 0.1

jetclass = JetClassDataset(config)
dataloader = JetClassDataLoader(jetclass, config)

cfm = ConditionalFlowMatching(config, coupling='ContextOT')
torchcfm = OTCFM(config)

for batch in dataloader.valid:
    cfm.flowmatcher(batch)
    torchcfm.flowmatcher(batch)
    print(cfm.i)
    print(cfm.j)
    print(torchcfm.i)
    print(torchcfm.j)
    break

INFO: building dataloaders...
INFO: train/val/test split ratios: 0.9/0.1/0.0
INFO: train size: 90000, validation size: 10000, testing sizes: 0
[29 10  5  6 18 19 30 20 23  0  3  9 21 25 27 22 14 31 26  1 12 13 15 17
 24  2  7 11 16  4  8 28]
[10 29 11 12 18 19  9  4  0 23 21  8 14 27  6 16  5  7 13  1 15 26 24 25
 28 17 22 20 31  3 30  2]
[29 10  5  6 18 19 30 20 23  0  3  9 21 25 27 22 14 31 26  1 12 13 15 17
 24  2  7 11 16  4  8 28]
[ 5 19 25 27 11 13  9 23  2 24 17 14 21 28 10 18  0 12  6 16 30 15 31  7
  8 29  3  1 22 20 26  4]


In [23]:
from DynGenModels.datamodules.jetclass.configs import JetClass_Config
from DynGenModels.datamodules.jetclass.datasets import JetClassDataset
from DynGenModels.datamodules.jetclass.dataloader import JetClassDataLoader

config = JetClass_Config()
config.DATA_SPLIT_FRACS = [0.9, 0.1, 0.0]
config.BATCH_SIZE = 32
config.SIGMA = 0.1

jetclass = JetClassDataset(config)
dataloader = JetClassDataLoader(jetclass, config)

cfm = ConditionalFlowMatching(config, coupling='SB')
torchcfm = SBCFM(config)

for batch in dataloader.valid:
    cfm.flowmatcher(batch)
    torchcfm.flowmatcher(batch)
    print(cfm.i)
    print(cfm.j)
    print(torchcfm.i)
    print(torchcfm.j)
    break

INFO: building dataloaders...
INFO: train/val/test split ratios: 0.9/0.1/0.0
INFO: train size: 90000, validation size: 10000, testing sizes: 0
t shape =  torch.Size([32, 1, 1])
x1 shape =  torch.Size([32, 30, 3])
11 torch.Size([32, 1, 1]) torch.Size([32, 30, 3]) torch.Size([32, 30, 3])
11 torch.Size([32, 1, 1]) torch.Size([32, 30, 3]) torch.Size([32, 30, 3])
[29 10  5  6 18 19 30 20 23  0  3  9 21 25 27 22 14 31 26  1 12 13 15 17
 24  2  7 11 16  4  8 28]
[ 5 19 25 27 11 13  9 23  2 24 17 14 21 28 10 18  0 12  6 16 30 15 31  7
  8 29  3  1 22 20 26  4]
[29 10  5  6 18 19 30 20 23  0  3  9 21 25 27 22 14 31 26  1 12 13 15 17
 24  2  7 11 16  4  8 28]
[ 5 19 25 27 11 13  9 23  2 24 17 14 21 28 10 18  0 12  6 16 30 15 31  7
  8 29  3  1 22 20 26  4]
