In [1]:
%load_ext autoreload
%autoreload 2

In [17]:
import os
import sys
import copy
parent_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir, os.pardir))
sys.path.append(parent_dir)

# Regular Imports
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

from tqdm import tqdm
import matplotlib.pyplot as plt

from inference.distribution import Gaussian

# Distributions

## Target Distribution $\phi^4$
We load premade samples from a `torch.Tensor` of shape `(batch_size, lattice_points, lattice_points)`

In [39]:
from inference.distribution import Sampleable
class EmpiricalPhi4(Sampleable):

    def __init__(self, samples: torch.Tensor, device = None):
        self._samples = samples.to(device)
        self._lattice_size = int(samples.shape[-1])

    @property
    def dim(self):
        shape = self._samples.shape
        return shape[-1] * shape[-2]

    def sample(self, num_samples: int):
        '''
        Returns
            Shape (num_samples, L, L)
        '''
        batch_size = self._samples.shape[0]
        if num_samples > batch_size:
            raise ValueError(f"num_samples ({num_samples}) cannot exceed batch_size ({batch_size})")

        indices = torch.randperm(batch_size)[:num_samples]
        return self._samples[indices]
    

# Load Phi4 Samples
samples = torch.load('phi4_coupling0p02_kinetic0p3.pt', map_location=torch.device('cpu')) # Shape (batch_size, L, L)
dist_phi4 = EmpiricalPhi4(samples=samples)

## Easy to Sample Distribution

In [38]:
class GaussianNoise(Sampleable):

    def __init__(self, lattice_size, device = None):
        self._lattice_size = lattice_size
        self.device = device

    @property
    def dim(self):
        return self._lattice_size ** 2
    
    def sample(self, num_samples: int) -> torch.Tensor:
        return torch.rand(num_samples, self._lattice_size, self._lattice_size, device=self.device)


# Score Matching

In [None]:
from abc import ABC, abstractmethod

class ConditionalProbabilityPath(torch.nn.Module, ABC):
    """
    Abstract base class for conditional probability paths
    """
    def __init__(self, p_simple: Sampleable, p_data: Sampleable):
        super().__init__()
        self.p_simple = p_simple
        self.p_data = p_data

    def sample_marginal_path(self, t: torch.Tensor) -> torch.Tensor:
        """
        Samples from the marginal distribution p_t(x) = p_t(x|z) p(z)
        Args:
            - t: time (num_samples, 1)
        Returns:
            - x: samples from p_t(x), (num_samples, dim)
        """
        num_samples = t.shape[0]
        # Sample conditioning variable z ~ p(z)
        z = self.sample_conditioning_variable(num_samples) # (num_samples, dim)
        # Sample conditional probability path x ~ p_t(x|z)
        x = self.sample_conditional_path(z, t) # (num_samples, dim)
        return x

    @abstractmethod
    def sample_conditioning_variable(self, num_samples: int) -> torch.Tensor:
        """
        Samples the conditioning variable z
        Args:
            - num_samples: the number of samples
        Returns:
            - z: samples from p(z), (num_samples, dim)
        """
        pass

    @abstractmethod
    def sample_conditional_path(self, z: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Samples from the conditional distribution p_t(x|z)
        Args:
            - z: conditioning variable (num_samples, dim)
            - t: time (num_samples, 1)
        Returns:
            - x: samples from p_t(x|z), (num_samples, dim)
        """
        pass

    @abstractmethod
    def conditional_vector_field(self, x: torch.Tensor, z: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Evaluates the conditional vector field u_t(x|z)
        Args:
            - x: position variable (num_samples, dim)
            - z: conditioning variable (num_samples, dim)
            - t: time (num_samples, 1)
        Returns:
            - conditional_vector_field: conditional vector field (num_samples, dim)
        """
        pass

    @abstractmethod
    def conditional_score(self, x: torch.Tensor, z: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Evaluates the conditional score of p_t(x|z)
        Args:
            - x: position variable (num_samples, dim)
            - z: conditioning variable (num_samples, dim)
            - t: time (num_samples, 1)
        Returns:
            - conditional_score: conditional score (num_samples, dim)
        """
        pass

class LinearConditionalProbabilityPath(ConditionalProbabilityPath):
    def __init__(self, p_simple: Sampleable, p_data: Sampleable):
        super().__init__(p_simple, p_data)

    def sample_conditioning_variable(self, num_samples: int) -> torch.Tensor:
        """
        Samples the conditioning variable z ~ p_data(x)
        Args:
            - num_samples: the number of samples
        Returns:
            - z: samples from p(z), (num_samples, ...)
        """
        return self.p_data.sample(num_samples)

    def sample_conditional_path(self, z: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Samples the random variable X_t = (1-t) X_0 + t z
        Args:
            - z: conditioning variable (num_samples, dim)
            - t: time (num_samples, 1)
        Returns:
            - x: samples from p_t(x|z), (num_samples, dim)
        """
        num_samples, _, _ = z.shape

        x0 = self.p_simple.sample(num_samples)

        return (1-t) * x0 + t * z

    def conditional_vector_field(self, x: torch.Tensor, z: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Evaluates the conditional vector field u_t(x|z) = (z - x) / (1 - t)
        Note: Only defined on t in [0,1)
        Args:
            - x: position variable (num_samples, L, L)
            - z: conditioning variable (num_samples, L, L)
            - t: time (num_samples, 1, 1)
        Returns:
            - conditional_vector_field: conditional vector field (num_samples, L, L)
        """
        return (z - x) / (1 - t)

    def conditional_score(self, x: torch.Tensor, z: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Not known for Linear Conditional Probability Paths
        """
        raise Exception("You should not be calling this function!")

# Machine Learning

## Model

In [101]:
from typing import List, Type

def build_mlp(dims: List[int], activation: Type[torch.nn.Module] = torch.nn.SiLU):
        mlp = []
        for idx in range(len(dims) - 1):
            mlp.append(torch.nn.Linear(dims[idx], dims[idx + 1]))
            if idx < len(dims) - 2:
                mlp.append(activation())
        return torch.nn.Sequential(*mlp)

class MLPVectorField(torch.nn.Module):
    """
    MLP-parameterization of the learned vector field u_t^theta(x)
    """
    def __init__(self, dim: int, hiddens: List[int]):
        super().__init__()
        self.dim = dim
        self.net = build_mlp([dim+1] + hiddens + [dim])

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """
        Args:
        - x: (bs, L, L)
        - t: (bs, 1, 1)
        Returns:
        - u_t^theta(x): (bs, L, L)

        """
        og_shape = x.shape # (bs, L, L)

        x = x.view(x.shape[0], -1)  # shape: (bs, L*L)
        t = t.view(-1, 1)           # shape: (bs, 1)
        xt = torch.cat([x, t], dim=-1)  # shape: (bs, L*L + 1)
        xt = self.net(xt)

        return xt.view(og_shape)        # reshape to (bs, L, L)

## Trainer

In [106]:
from abc import ABC

class Trainer(ABC):
    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model

    @abstractmethod
    def get_train_loss(self, **kwargs) -> torch.Tensor:
        pass

    def get_optimizer(self, lr: float):
        return torch.optim.Adam(self.model.parameters(), lr=lr)

    def train(self, num_epochs: int, device: torch.device, lr: float = 1e-3, **kwargs) -> torch.Tensor:
        # Start
        self.model.to(device)
        opt = self.get_optimizer(lr)
        self.model.train()

        # Train loop
        pbar = tqdm(enumerate(range(num_epochs)))
        for idx, epoch in pbar:
            opt.zero_grad()
            loss = self.get_train_loss(**kwargs)
            loss.backward()
            opt.step()
            pbar.set_description(f'Epoch {idx}, loss: {loss.item()}')

        # Finish
        self.model.eval()

class ConditionalFlowMatchingTrainer(Trainer):
    def __init__(self, path: ConditionalProbabilityPath, model: MLPVectorField, **kwargs):
        super().__init__(model, **kwargs)
        self.path = path

    def get_train_loss(self, batch_size: int) -> torch.Tensor:
        z = self.path.p_data.sample(batch_size)
        t = torch.rand(batch_size, 1, 1)
        x = self.path.sample_conditional_path(z,t)

        u_model = self.model(x,t)
        u_ref = self.path.conditional_vector_field(x,z,t)

        return torch.norm(u_model - u_ref) / batch_size

In [None]:
# Construct conditional probability path
path = LinearConditionalProbabilityPath(
    p_simple = GaussianNoise(lattice_size=dist_phi4._lattice_size),
    p_data = dist_phi4
).to(device)

# Construct learnable vector field
linear_flow_model = MLPVectorField(dim=dist_phi4.dim, hiddens=[64,64,64,64])

# Construct trainer
trainer = ConditionalFlowMatchingTrainer(path, linear_flow_model)

losses = trainer.train(num_epochs=10000, device=device, lr=1e-3, batch_size=2000)

0it [00:00, ?it/s]

torch.Size([2000, 32, 32])


Epoch 2, loss: 1.3018250465393066: : 2it [00:00,  2.72it/s]

torch.Size([2000, 32, 32])
torch.Size([2000, 32, 32])
torch.Size([2000, 32, 32])


Epoch 5, loss: 1.2503498792648315: : 5it [00:01,  6.01it/s]

torch.Size([2000, 32, 32])
torch.Size([2000, 32, 32])


Epoch 8, loss: 1.167574405670166: : 9it [00:01,  9.07it/s] 

torch.Size([2000, 32, 32])
torch.Size([2000, 32, 32])
torch.Size([2000, 32, 32])


Epoch 10, loss: 1.1025828123092651: : 11it [00:01,  9.13it/s]

torch.Size([2000, 32, 32])
torch.Size([2000, 32, 32])


Epoch 13, loss: 1.0052131414413452: : 13it [00:01, 10.45it/s]

torch.Size([2000, 32, 32])
torch.Size([2000, 32, 32])
torch.Size([2000, 32, 32])


Epoch 15, loss: 0.9435123205184937: : 15it [00:02, 11.28it/s]

torch.Size([2000, 32, 32])
torch.Size([2000, 32, 32])
torch.Size([2000, 32, 32])


Epoch 19, loss: 0.8287903666496277: : 19it [00:02, 11.52it/s]

torch.Size([2000, 32, 32])
torch.Size([2000, 32, 32])
torch.Size([2000, 32, 32])


Epoch 21, loss: 0.7843371033668518: : 21it [00:02, 12.07it/s]

torch.Size([2000, 32, 32])
torch.Size([2000, 32, 32])


Epoch 24, loss: 0.716825008392334: : 25it [00:02, 11.66it/s] 

torch.Size([2000, 32, 32])
torch.Size([2000, 32, 32])
torch.Size([2000, 32, 32])


Epoch 26, loss: 0.683868408203125: : 27it [00:02, 12.41it/s] 

torch.Size([2000, 32, 32])
torch.Size([2000, 32, 32])


Epoch 26, loss: 0.683868408203125: : 27it [00:03,  8.58it/s]


torch.Size([2000, 32, 32])


KeyboardInterrupt: 

0it [00:00, ?it/s]


RuntimeError: The size of tensor a (2000) must match the size of tensor b (32) at non-singleton dimension 1