## Problem Set 5, Problem 6

In this exercise, we work with stochastic differential equations (SDEs) and with a diffusion model for non-Euclidean data (rotations in SO(3)).

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch import optim
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions import Distribution, Categorical
from scipy.stats import wasserstein_distance
from scipy.spatial.transform import Rotation

# Seed PyTorch's global RNG so the torch-based sampling below is reproducible.
# NOTE(review): scipy's Rotation.random (used for the SO(3) initial samples)
# presumably draws from NumPy's RNG, which is not seeded here — confirm.
# The trailing ';' suppresses the notebook's echo of the returned Generator.
torch.manual_seed(2024);

<hr>

### Itô process
Consider the stochastic differential equation (SDE) $dX_t = f(X_t, t)dt + g(X_t, t)dW_t$. 

$f(X_t, t)dt$ is the drift term. It is the smooth, predictable contribution to $dX_t$, analogous to the deterministic rate of change in an ordinary differential equation.

$g(X_t, t)dW_t$ is the diffusion term. It adds randomness to the change $dX_t$ and is the product of a volatility factor $g(X_t, t)$ and the increment $dW_t$ of a Wiener process.

For a twice-differentiable function $\psi(X)$, Itô's rule gives $d\psi$ as follows:

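$d\psi = \frac{\partial \psi}{\partial X} (f dt + g dW) + \frac{1}{2} \frac{\partial^2 \psi}{\partial X^2} g^2 dt$

For vector-valued $X \in \mathbb{R}^d$ with one diffusion coefficient per coordinate, $g = (g_1, \dots, g_d)$ (a diagonal diffusion matrix), this becomes $d\psi = \nabla \psi^\top (f\,dt + g \odot dW) + \frac{1}{2} \sum_{k} \frac{\partial^2 \psi}{\partial x_k^2} g_k^2\, dt$, where $dW \sim \mathcal{N}(0, dt\, I)$.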

We will show that the second-derivative term is essential for describing how $\psi$ changes.

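First, define the function **forward_process**, which advances $X_t$ by one time step $dt$ of the SDE. For example, use the Euler–Maruyama scheme $X_{t+dt} = X_t + f\,dt + g\,\sqrt{dt}\,z$ with $z \sim \mathcal{N}(0, I)$.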

In [None]:
def psi_function(x):
    '''
    Squared Euclidean norm of the first two coordinates of every sample,
    psi(x) = x_0^2 + x_1^2.
    < input >
    x: torch.Tensor (b, 2)
    < output >
    psi: torch.Tensor (b, )
    '''
    squared = x[:, :2].pow(2)
    return squared[:, 0] + squared[:, 1]

def drift_term(x, t):
    '''
    Time-independent linear drift f(x, t) = A x with the skew-symmetric
    generator A = [[0, -1], [1, 0]], i.e. f(x) = (-x_1, x_0): a rotation
    about the origin. Because x . f(x) = 0, the drift does not change
    |x|^2 to first order.
    < input >
    x: torch.Tensor (b, 2)
    t: float or torch.Tensor (unused)
    < output >
    drift: torch.Tensor (b, 2)
    '''
    # Build the generator in x's dtype/device so float64 or GPU inputs work
    # (a default float32 CPU matrix makes torch.mm fail for them).
    generator = torch.tensor([[0., -1.], [1., 0.]], dtype=x.dtype, device=x.device)
    # Row-vector convention: (A x_i)^T = x_i^T A^T for every row x_i.
    return torch.mm(x, generator.t())

def diffusion_term(x, t):
    '''
    Constant volatility g(x, t) = 0.1, returned with the same shape as x
    (one coefficient per coordinate).
    < input >
    x: torch.Tensor (b, 2)
    t: float (unused)
    < output >
    g: torch.Tensor shaped like x, every entry 0.1
    '''
    volatility = 0.1
    return volatility * torch.ones_like(x)

def forward_process(x, t, f, g, dt = 0.01):
    """
    Forward process for a stochastic differential equation
    dx = f(x, t) dt + g(x, t) dW

    Takes one Euler–Maruyama step:
        x_next = x + f(x, t) * dt + g(x, t) * dW,   dW = sqrt(dt) * z,  z ~ N(0, I)
    g is applied elementwise (one coefficient per coordinate).
    < input >
    x: torch.Tensor (b, 2)
    t: float
    f, g: functions returning tensors broadcastable to x
    dt: float, step size
    < output >
    x_next: torch.Tensor (b, 2)
    """
    # Wiener increment over dt has variance dt per coordinate
    dW = torch.randn_like(x) * dt ** 0.5
    x_next = x + f(x, t) * dt + g(x, t) * dW
    return x_next

Next, define the functions **compute_gradient** and **compute_hessian**. With these two functions, we can compute $d\psi$ using Itô's rule.

Using torch.autograd.grad() might be helpful.

In [None]:
def compute_gradient(psi, x):
    """
    Compute the gradient of psi with respect to x, one row per sample.

    psi is assumed to act row-wise (psi(x)[i] depends only on x[i]). Under that
    assumption, differentiating sum_i psi(x)[i] w.r.t. x yields every
    per-sample gradient in a single backward pass.
    < input >
    psi: function of x, mapping (b, d) -> (b, )
    x: torch.Tensor (b, 2)
    < output >
    dpsi_dx: torch.Tensor (b, 2)
    """
    if not x.requires_grad:
        # autograd needs a graph: differentiate w.r.t. a grad-enabled copy
        x = x.detach().requires_grad_(True)
    # create_graph=True keeps the result differentiable (e.g. for Hessians)
    (dpsi_dx,) = torch.autograd.grad(psi(x).sum(), x, create_graph=True)
    return dpsi_dx

def compute_hessian(psi, x):
    """
    Compute the per-sample Hessian of psi with respect to x:
        d2psi_dx2[i, j, k] = d^2 psi(x_i) / (dx_j dx_k)

    psi is assumed to act row-wise (psi(x)[i] depends only on x[i]). Under that
    assumption, differentiating sum_i dpsi/dx_j (x_i) w.r.t. x gives row j of
    every sample's Hessian.
    < input >
    psi: function of x, mapping (b, d) -> (b, )
    x: torch.Tensor (b, 2)
    < output >
    d2psi_dx2: torch.Tensor (b, 2, 2)
    """
    if not x.requires_grad:
        # autograd needs a graph: differentiate w.r.t. a grad-enabled copy
        x = x.detach().requires_grad_(True)
    n, d = x.shape
    # first derivatives, kept differentiable so they can be differentiated again
    (grad,) = torch.autograd.grad(psi(x).sum(), x, create_graph=True)
    if not grad.requires_grad:
        # gradient does not depend on x (psi at most linear): Hessian vanishes
        return x.new_zeros(n, d, d)
    rows = []
    for j in range(d):
        (row,) = torch.autograd.grad(grad[:, j].sum(), x, retain_graph=True, allow_unused=True)
        # None means dpsi/dx_j is constant in x
        rows.append(x.new_zeros(n, d) if row is None else row)
    return torch.stack(rows, dim=1)

Check that the functions return tensors of the correct shapes.

In [None]:
# Sanity-check the output shapes on a random batch of 10 two-dimensional points.
x = torch.rand(10, 2)
x.requires_grad = True  # let autograd differentiate psi with respect to x
psi = psi_function(x)
assert psi.shape == (x.shape[0], )
# gradient: one row per sample; Hessian: one (2, 2) matrix per sample
assert compute_gradient(psi_function, x).shape == x.shape
assert compute_hessian(psi_function, x).shape == (x.shape[0], x.shape[1], x.shape[1])

Next, compute $d\psi$ by applying Itô's rule. In the second function, omit the second-derivative term.

In [None]:
def d_psi_with_second_derivative(psi, x, t, f, g, dt = 0.01):
    """
    Ito process for a stochastic differential equation. This function includes the second derivative of psi.

    One Euler step of Itô's rule for psi(X) with dX = f dt + g dW:

        dpsi = grad(psi) . (f dt + g * dW) + 1/2 * sum_k (d^2 psi / dx_k^2) * g_k^2 * dt

    g is applied elementwise (a diagonal diffusion matrix), so only the
    diagonal of the Hessian enters the correction. The Wiener increment
    dW = sqrt(dt) * z, z ~ N(0, I), is drawn here. It is independent of any
    noise used elsewhere to move x.
    psi is assumed to act row-wise (psi(x)[i] depends only on x[i]).
    < input >
    psi: function of x
    x: torch.Tensor (b, 2)
    t: float
    f, g: functions
    dt: float, step size
    < output >
    dpsi: torch.Tensor (b, )
    """
    if not x.requires_grad:
        # autograd needs a graph: differentiate w.r.t. a grad-enabled copy
        x = x.detach().requires_grad_(True)
    # create_graph=True so the gradient can be differentiated once more
    (grad,) = torch.autograd.grad(psi(x).sum(), x, create_graph=True)

    diffusion = g(x, t)
    dW = torch.randn_like(x) * dt ** 0.5
    first_order = (grad * (f(x, t) * dt + diffusion * dW)).sum(dim=1)

    # diagonal of each sample's Hessian, one coordinate at a time
    diagonal = []
    for k in range(x.shape[1]):
        second = None
        if grad.requires_grad:
            (second,) = torch.autograd.grad(grad[:, k].sum(), x, retain_graph=True, allow_unused=True)
        # None (or no graph) means dpsi/dx_k is constant in x: zero curvature
        diagonal.append(torch.zeros_like(x[:, k]) if second is None else second[:, k])
    hessian_diag = torch.stack(diagonal, dim=1)

    # Itô correction: 1/2 * tr(G^T H G) dt with G = diag(g)
    second_order = 0.5 * (hessian_diag * diffusion ** 2).sum(dim=1) * dt
    dpsi = first_order + second_order
    return dpsi

def d_psi_without_second_derivative(psi, x, t, f, g, dt = 0.01):
    """
    Ito process for a stochastic differential equation. This function does not include the second derivative of psi.

    First-order ("chain rule only") increment of psi(X) for dX = f dt + g dW:

        dpsi = grad(psi) . (f dt + g * dW)

    This deliberately omits the Itô correction term. dW = sqrt(dt) * z,
    z ~ N(0, I), is drawn here, and g is applied elementwise.
    psi is assumed to act row-wise (psi(x)[i] depends only on x[i]).
    < input >
    psi: function of x
    x: torch.Tensor (b, 2)
    t: float
    f, g: functions
    dt: float, step size
    < output >
    dpsi: torch.Tensor (b, )
    """
    if not x.requires_grad:
        # autograd needs a graph: differentiate w.r.t. a grad-enabled copy
        x = x.detach().requires_grad_(True)
    (grad,) = torch.autograd.grad(psi(x).sum(), x)
    dW = torch.randn_like(x) * dt ** 0.5
    dpsi = (grad * (f(x, t) * dt + g(x, t) * dW)).sum(dim=1)
    return dpsi

As the initial distribution, we use an equal-weight mixture of two Gaussians with means $(0.5, 0)$ and $(-0.5, 0)$ and a common covariance $0.01 \cdot I_{2\times 2}$.

In [None]:
class GaussianMixture(Distribution):
    '''
    Finite mixture of multivariate Gaussians.

    < input >
    means: torch.Tensor (K, d), one mean per component
    covariances: torch.Tensor (K, d, d), one covariance matrix per component
    weights: optional torch.Tensor (K, ) of non-negative mixing weights,
             normalised to sum to 1; defaults to equal weights 1/K
    '''
    # No parameter constraints are declared; defining this also keeps
    # Distribution.__repr__ from raising NotImplementedError.
    arg_constraints = {}

    def __init__(self, means, covariances, weights = None):
        self.means = means
        self.covariances = covariances
        n_components = means.shape[0]
        if weights is None:
            weights = torch.full((n_components, ), 1.0 / n_components, dtype=means.dtype)
        weights = torch.as_tensor(weights, dtype=means.dtype)
        self.weights = weights / weights.sum()
        self.categorical = Categorical(self.weights)
        # One batched MVN: batch_shape (K,), event_shape (d,)
        self.components = MultivariateNormal(means, covariance_matrix=covariances)
        super().__init__(batch_shape=torch.Size(), event_shape=means.shape[-1:], validate_args=False)

    def sample(self, sample_shape = torch.Size()):
        '''
        Draw i.i.d. samples from the mixture.
        < output >
        samples: torch.Tensor (*sample_shape, d), components interleaved at random
        '''
        sample_shape = torch.Size(sample_shape)
        labels = self.categorical.sample(sample_shape)   # (*S, )
        draws = self.components.sample(sample_shape)     # (*S, K, d)
        # For every sample, keep the draw from its own component.
        index = labels[..., None, None].expand(*sample_shape, 1, draws.shape[-1])
        return draws.gather(-2, index).squeeze(-2)

    def log_prob(self, value):
        '''
        < input >
        value: torch.Tensor (..., d)
        < output >
        log_prob: torch.Tensor (..., ) = log sum_k w_k N(value; mu_k, Sigma_k)
        '''
        # (..., 1, d) broadcasts against the K components -> (..., K)
        component_log_probs = self.components.log_prob(value.unsqueeze(-2))
        return torch.logsumexp(torch.log(self.weights) + component_log_probs, dim=-1)
    
# Two isotropic Gaussians at (+0.5, 0) and (-0.5, 0), each with covariance 0.01 * I.
means = torch.tensor([[0.5, 0.0], [-0.5, 0.0]])
covariances = 0.01 * torch.eye(2).repeat(2, 1, 1)
dist = GaussianMixture(means, covariances)

Sample 100 points to visualize the distribution of the mixture of Gaussians.

In [None]:
# Draw 100 points from the mixture and scatter-plot them. Equal axis scaling
# keeps the isotropic Gaussian blobs from looking stretched.
x = dist.sample((100, ))
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.axis('equal')
ax.scatter(x[:, 0], x[:, 1])
plt.show()

Using the $d\psi$ functions above, we predict how $\psi$ evolves as $X$ evolves.
We then compare the predicted distribution of $\psi$ with the actual distribution of $\psi(X)$ at the final time.
The comparison uses the (1-D) Wasserstein distance between the two empirical distributions.


In [None]:
# Simulate `batch` SDE trajectories. Alongside them, integrate psi with the two
# Itô-increment functions, then compare the predicted psi with psi(X) at the end.
batch = 100
dt = 0.01
t = torch.arange(0, 1.57, dt)  # roughly 157 time points covering [0, ~pi/2)
x = dist.sample((batch, ))
x.requires_grad = True  # lets the d_psi functions differentiate psi w.r.t. x

psi1 = psi_function(x)  # integrated WITH the second-order Itô term
psi2 = psi_function(x)  # integrated WITHOUT it

x_list = [x]
psi1_list = [psi1]
psi2_list = [psi2]
for i in range(len(t)):
    x = forward_process(x, t[i], drift_term, diffusion_term, dt)
    # NOTE: the d_psi functions are not given the noise forward_process used,
    # so psi1/psi2 can only be compared with psi(x) in distribution.
    dpsi1 = d_psi_with_second_derivative(psi_function, x, t[i], drift_term, diffusion_term, dt)
    dpsi2 = d_psi_without_second_derivative(psi_function, x, t[i], drift_term, diffusion_term, dt)
    psi1 = psi1 + dpsi1
    psi2 = psi2 + dpsi2
    x_list.append(x)
    psi1_list.append(psi1)
    psi2_list.append(psi2)

# Predicted psi values at the final step vs. psi of the simulated endpoints.
psi1 = psi1_list[-1]
psi2 = psi2_list[-1]
psi_true = psi_function(x_list[-1])

# 1-D Wasserstein distance between the empirical distributions of the samples.
distance1 = wasserstein_distance(psi1.detach().numpy(), psi_true.detach().numpy())
distance2 = wasserstein_distance(psi2.detach().numpy(), psi_true.detach().numpy())
print("Wasserstein distance")
print(f"Itô rule with second derivative : {distance1}")
print(f"Itô rule without second derivative : {distance2}")

Q. Explain the above results.

A. ~~

<hr>

### Diffusion model on SO(3)

Next, we explore diffusion models based on SDEs, focusing on score-based generative modeling, in which the forward process has zero drift.
We recommend referring to the paper "Generative Modeling by Estimating Gradients of the Data Distribution" by Yang Song et al. for further details.

In particular, we train with denoising score matching and sample with annealed Langevin dynamics. Denoising score matching perturbs a data point $X$ with a pre-specified noise distribution $q_{\sigma}(\tilde{X}|X)$. It then uses score matching to estimate the score of the perturbed data distribution $q_{\sigma}(\tilde{X}) \triangleq \int q_{\sigma} (\tilde{X}|X)p_{data}(X)dX$.

We use $q_\sigma (\tilde{X}|X) = \mathcal{N}(\tilde{X}|X, \sigma^2 I)$ as the noise distribution resulting in $\nabla_{\tilde{X}} \log q_\sigma (\tilde{X}|X) = - \frac{\tilde{X}-X}{\sigma^2}$. 
Consequently, the denoising score matching objective for the noise schedule $\sigma_t$ can be written as follows.

$\mathcal{L}(\theta; \sigma) \triangleq \frac{1}{2} \mathbb{E}_{t \sim \mathcal{U}(0, T)} \mathbb{E}_{X \sim p_{data}(X)} \mathbb{E}_{\tilde{X} \sim \mathcal{N}(X, \sigma_t^2 I)} \| s_\theta (\tilde{X}, \sigma_t) + \frac{\tilde{X}-X}{\sigma_t^2}\|_2^2$

For sampling, we use annealed Langevin dynamics. Starting from noise, we approach the data distribution by repeatedly applying the following update, with step sizes $\alpha_i$ scheduled according to $\sigma_i$.

$X_t \leftarrow X_{t+1} + \frac{\alpha_i}{2} s_\theta (X_{t+1}, \sigma_i) + \sqrt{\alpha_i} z_t$, where $z_t \sim \mathcal{N}(0, I)$, $t = T-1, T-2, \cdots, 1, 0$

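To apply these operations on the SO(3) manifold, perturbations are defined in the tangent space, identified with $\mathbb{R}^3$ (the Lie algebra $\mathfrak{so}(3)$). Points are then moved along the manifold with the matrix exponential, e.g. $\tilde{R} = R \exp(\hat{\epsilon})$ for a tangent vector $\epsilon \in \mathbb{R}^3$.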
Applying noise with increasing variance to the data distribution brings it closer to a uniform distribution on the SO(3) manifold. 
Consequently, starting from this uniform distribution, the data distribution can be recovered by applying scores learned through training the model.
For more details, refer to the paper "SE(3)-DiffusionFields: Learning smooth cost functions for joint grasp and motion optimization through diffusion" by Julen Urain et al.

In [None]:
from functions.Lie import Logmap, Expmap

def sample_from_data(batch: int):
    '''
    Sample `batch` rotation matrices about the z-axis, with angle theta drawn
    uniformly from [0, 2*pi). Each one is Expmap applied to the tangent
    vector (0, 0, theta) — presumably its axis-angle form; confirm against
    functions.Lie.
    '''
    angles = torch.rand(batch) * 2 * np.pi
    zeros = torch.zeros_like(angles)
    tangent = torch.stack((zeros, zeros, angles), dim=1)
    return Expmap(tangent)

def marginal_prob_std(t, sigma = 0.5):
    '''
    Scheduled noise standard deviation sigma_t at diffusion time t:

        sigma_t = 2 * sqrt((sigma^(2t) - 1) / (2 * ln(sigma)))

    It is 0 at t = 0 and increases with t for any sigma > 0, sigma != 1.
    Accepts Python floats as well as NumPy arrays (evaluated elementwise).
    '''
    variance_scale = (sigma ** (2 * t) - 1.) / (2. * np.log(sigma))
    return 2 * np.sqrt(variance_scale)

class diffusion_SO3(nn.Module):
    '''
    Score-based diffusion model on SO(3).

    An MLP maps a flattened rotation matrix (9 values) plus the diffusion time
    (1 value) to a 3-vector score. The score is used as tangent-space
    coordinates at R, the same coordinates that Expmap consumes. Sampling
    runs annealed Langevin dynamics starting from uniformly random rotations.
    '''
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(9+1, 512),
            torch.nn.ReLU(),
            torch.nn.Linear(512, 512),
            torch.nn.ReLU(),
            torch.nn.Linear(512, 3)
        )
        # annealing grid in [0, 1); sampling visits it as 1 - t (noisy -> clean)
        self.timesteps = torch.arange(0, 1, 0.005)

    def forward(self, R, t):
        # nn.Module.__call__ dispatches here, so model(R, t) works and hooks run
        return self.score_function(R, t)

    def init_sample(self, batch: int):
        # uniformly random rotations from scipy
        return torch.tensor(Rotation.random(batch).as_matrix(), dtype=torch.float32)

    @torch.no_grad()  # sampling never needs gradients; avoids a 250-step graph
    def sample(self, batch: int):
        steps_fit = 50
        R = self.init_sample(batch)
        for t in self.timesteps: # annealed Langevin dynamics, time 1 -> ~0
            R = self._step(R, 1 - float(t), noise_off = False)
        for _ in range(steps_fit): # extra noise-free steps for fitting
            R = self._step(R, 0, noise_off = True)
        return R

    def _step(self, R, t, noise_off = False):
        '''
        One annealed Langevin update on SO(3), used for sampling.

        Implements x <- x + alpha/2 * s + sqrt(alpha) * z in the tangent space at R:
            R_next = R @ Expmap(alpha / 2 * score + sqrt(alpha) * noise)
        Right-multiplication matches how training perturbs data
        (R_data @ Expmap(noise * std)).
        < input >
        R: torch.Tensor (b, 3, 3)
        t: float, diffusion time in [0, 1]
        noise_off: if True, take a deterministic step with a fixed alpha
        < output >
        R_next: torch.Tensor (b, 3, 3)
        '''
        batch = R.shape[0]
        eps = 1e-3
        time = t * (1-eps) + eps
        sigma_T = marginal_prob_std(eps)   # std at the smallest time eps
        sigma_i = marginal_prob_std(time)
        # step size grows with the current noise level: alpha ∝ sigma_i^2
        ratio = sigma_i ** 2 / sigma_T ** 2
        alpha = 1e-3 * ratio
        noise = torch.randn(batch, 3) * 0.5
        if noise_off:
            alpha = 0.003
            noise = torch.zeros(batch, 3)
        score = self.score_function(R, time)
        tangent = score * (alpha / 2) + noise * (alpha ** 0.5)
        R_next = torch.bmm(R, Expmap(tangent))
        return R_next

    def score_function(self, R, t):
        '''
        It computes the score function of the diffusion process using neural network.
        < input >
        R: torch.Tensor (b, 3, 3)
        t: torch.Tensor (b, ), or a scalar (float / 0-d tensor) shared by the batch
        < output >
        score: torch.Tensor (b, 3)
        '''
        batch = R.shape[0]
        t = torch.as_tensor(t, dtype=R.dtype, device=R.device)
        if t.dim() == 0:
            t = t.expand(batch)
        # network input: 9 flattened matrix entries followed by the time value
        features = torch.cat([R.reshape(batch, 9), t.reshape(batch, 1)], dim=-1)
        score = self.net(features)
        return score

In [None]:
# Train the score network with denoising score matching on z-axis rotations.
model = diffusion_SO3()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
batch = 1000
epoch = 10000
loss_list = []
for e in range(epoch):
    # diffusion time t ~ U(1e-3, 1); the lower bound keeps std (and the 1/std target) finite
    t = torch.rand(batch) * (1 - 1e-3) + 1e-3
    # sigma_t per sample, shape (batch, 1) so it broadcasts over the 3 tangent coordinates
    std = torch.tensor(marginal_prob_std(t.numpy()), dtype = torch.float32).unsqueeze(-1)
    R_data = sample_from_data(batch)
    noise = torch.randn(batch, 3)
    # perturb in the tangent space, mapped back with Expmap and applied on the right
    R_perturb = torch.bmm(R_data, Expmap(noise * std))
    score_predict = model(R_perturb, t)
    score_target = - noise / std # == - (x_tilde - x) / std^2
    
    # NOTE(review): unweighted loss (no 1/2 factor, no lambda(sigma) weighting). The
    # target scales like 1/std, so small-t samples dominate; Song & Ermon weight
    # by sigma^2 — consider whether that is intended here.
    loss = ((score_predict - score_target).pow(2).sum(-1)).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # record the loss every 100 epochs, print it every 1000
    if (e+1) % 100 == 0:
        loss_list.append(loss.item())
        if (e+1) % 1000 == 0:
            print(f"Epoch {e+1} : {loss.item()}")
    
plt.plot(loss_list)
plt.show()

The following code visually illustrates the initial and final distributions of the diffusion model. 
Each rotation matrix is drawn as a coordinate frame at the origin. Its red, green, and blue arrows are the matrix's first, second, and third columns.

In [None]:
def _draw_frames(rotations):
    '''Plot each 3x3 rotation as a frame at the origin: columns 0/1/2 in red/green/blue.'''
    fig = plt.figure(figsize=(5, 5))
    ax = fig.add_subplot(111, projection='3d')
    ax.set_xlim(-1, 1)
    ax.set_ylim(-1, 1)
    ax.set_zlim(-1, 1)
    axis_colors = ('tab:red', 'tab:green', 'tab:blue')
    for rot in rotations:
        for col, color in enumerate(axis_colors):
            ax.quiver(0, 0, 0, rot[0, col], rot[1, col], rot[2, col], color=color)
    plt.show()

batch = 100
R0 = model.init_sample(batch).detach().numpy()  # random starting rotations
R = model.sample(batch).detach().numpy()        # rotations after Langevin sampling

_draw_frames(R0)
_draw_frames(R)


Q. Explain the above results.

A. ~~