# SDE

We import the required libraries and visualise the target and prior distributions.

In [None]:
%load_ext autoreload
%autoreload 2

from dataset import TwoDimDataClass
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
# deep learning
import torch
import torch.nn as nn
import torch.nn.functional as F
from setup import device  # choose which device to use in setup.py

# misc
from tqdm import tqdm

# Target distribution: a 2-D swiss roll. The commented-out block below swaps
# in the "moon" dataset instead.
target_ds = TwoDimDataClass(dataset_type='swiss_roll',
                            N=1000000,
                            batch_size=256)

# target_ds = TwoDimDataClass(dataset_type='moon',
#                             N=1000000,
#                             batch_size=1000)

# Prior distribution: dataset_type='gaussian_centered'.
prior_ds = TwoDimDataClass(dataset_type='gaussian_centered',
                           N=1000000,
                           batch_size=1000)

# Take the first Ntest points of each dataset for a quick visual comparison.
Ntest = 500
sample_f = target_ds[0:Ntest]
sample_b = prior_ds[0:Ntest]

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.scatter(sample_f[:, 0], sample_f[:, 1], alpha=0.6)
ax.scatter(sample_b[:, 0], sample_b[:, 1], alpha=0.6)
ax.grid(False)
ax.set_aspect('equal', adjustable='box')
strtitle = "Target and Prior datasets"
ax.set_title(strtitle)
# Legend entries follow the scatter order above: target first, prior second.
ax.legend(['Dataset (target)', 'Dataset (prior)'])

We instantiate three SDEs: Ornstein-Uhlenbeck (OU), Variance-Preserving (VP), and Variance-Exploding (VE).

In [None]:
from sde import OU, VPSDE, VESDE

# One instance per SDE family, each constructed with N=1000 and T=1
# (presumably 1000 discretisation steps over t in [0, 1] -- confirm in sde.py).
ou = OU(T=1, N=1000)
vp = VPSDE(T=1, N=1000)
ve = VESDE(T=1, N=1000)

## Visualising the forward process

In [None]:
# plot the forward diffusion

def plot_mean_and_std(sde, name, n_arr=1000, plot_Ntest=100):
    """Plot forward-diffusion samples and how far their moments are from N(0, I).

    For ``n_arr`` times evenly spaced on ``[0, sde.T]``, draws
    ``x_t = mean + std * z`` using ``sde.marginal_prob(t, x0)`` for the first
    ``plot_Ntest`` target points. It then draws a 4-panel figure:

    1. all forward samples, coloured by time (viridis);
    2. the samples at the last time point ``t = sde.T``;
    3. the Euclidean distance of the empirical mean from 0;
    4. the Euclidean distance of the per-coordinate empirical std from 1.

    The figure is saved to ``./mean_std_{name}.jpg``.

    Args:
        sde: SDE object exposing ``T`` and ``marginal_prob(t, x0) -> (mean, std)``.
        name: Tag used in the output file name.
        n_arr: Number of time points to evaluate.
        plot_Ntest: Number of target points to diffuse.
    """
    cmap = plt.get_cmap('viridis')

    t_arr = np.linspace(0, sde.T, n_arr)
    mean_dist = np.zeros(n_arr)
    std_dist = np.zeros(n_arr)

    fig = plt.figure(figsize=(20, 5), dpi=80)
    ax = fig.add_subplot(1, 4, 1)
    ax_final = fig.add_subplot(1, 4, 2)
    for panel, title in ((ax, "forward trajectory"), (ax_final, "final samples")):
        panel.set_ylim(-4, 4)
        panel.set_xlim(-4, 4)
        panel.set_title(title)

    # The clean points are identical at every t, so fetch them only once.
    x0 = target_ds[range(plot_Ntest)].float()

    xt = None
    for k, tt in enumerate(tqdm(t_arr)):
        # Build t with x0's dtype/device so marginal_prob never mixes devices.
        t = torch.full((x0.shape[0],), float(tt), dtype=x0.dtype, device=x0.device)
        mean, std = sde.marginal_prob(t, x0)
        # randn_like keeps the noise on mean's device; the previous
        # ``torch.randn(...).to(device)`` could put z on the GPU while mean
        # stayed on the CPU.
        z = torch.randn_like(mean)
        xt = (mean + std * z).cpu().detach().numpy()

        # Report true (non-squared) distances, matching the panel titles.
        mean_dist[k] = np.linalg.norm(xt.mean(axis=0))
        std_dist[k] = np.linalg.norm(xt.std(axis=0) - 1)

        ax.scatter(xt[:, 0], xt[:, 1], alpha=0.4, color=cmap(tt / t_arr[-1]))

    # After the loop, xt holds the samples at the last time point.
    if xt is not None:
        ax_final.scatter(xt[:, 0], xt[:, 1], alpha=0.9, color=cmap(1.0))

    for panel in (ax, ax_final):
        panel.grid(False)
        panel.set_aspect('auto', adjustable='box')

    ax = fig.add_subplot(1, 4, 3)
    ax.plot(t_arr, mean_dist)
    ax.set_title("Distance of mean from 0")
    ax.grid(False)
    ax = fig.add_subplot(1, 4, 4)
    ax.plot(t_arr, std_dist)
    ax.set_title("Distance of standard deviation from 1")
    ax.grid(False)
    fig.savefig(f"./mean_std_{name}.jpg")

In [None]:
# One figure per SDE, saved as ./mean_std_<name>.jpg.
plot_mean_and_std(sde=ou, name='ou')
plot_mean_and_std(sde=vp, name='vp')
plot_mean_and_std(sde=ve, name='ve')

In [None]:
# animate the diffusion process

def animate_diffusion(sdes, names, n_arr=1000, plot_Ntest=100):
    """Save a side-by-side GIF of the forward diffusion under several SDEs.

    Each SDE gets one panel. Frame ``k`` shows a fresh sample
    ``mean + std * z`` from ``sde.marginal_prob`` at the k-th of ``n_arr``
    times evenly spaced on ``[0, sde.T]``, for the first ``plot_Ntest``
    target points. Output: ``./diffusion.gif``.

    Args:
        sdes: Sequence of SDE objects exposing ``T`` and ``marginal_prob``.
        names: Panel titles, paired with ``sdes`` in order.
        n_arr: Number of time points (= animation frames).
        plot_Ntest: Number of target points to diffuse.
    """
    # Width scales with the number of panels: (15, 5) for three SDEs.
    fig = plt.figure(figsize=(5 * len(sdes), 5), dpi=80)
    x0 = target_ds[range(plot_Ntest)].float()  # identical for every SDE and t

    scats = []
    data = []
    for i, (sde, name) in enumerate(zip(sdes, names)):
        t_arr = np.linspace(0, sde.T, n_arr)

        ax = fig.add_subplot(1, len(sdes), i + 1)
        ax.set_ylim(-4, 4)
        ax.set_xlim(-4, 4)
        ax.set_title(f"{name}")
        scats.append(ax.scatter([], [], alpha=0.4))

        points = []
        for tt in tqdm(t_arr):
            t = torch.full((x0.shape[0],), float(tt), dtype=x0.dtype, device=x0.device)
            mean, std = sde.marginal_prob(t, x0)
            z = torch.randn_like(mean)  # same device as mean
            # .cpu() so this also works when the SDE returns GPU tensors.
            points.append((mean + std * z).cpu().detach().numpy())
        data.append(points)

        ax.grid(False)
        ax.set_aspect('auto', adjustable='box')

    # Local callback on purpose: a later cell redefines the module-level
    # ``update`` with an incompatible signature (k, data, scat).
    def _draw_frame(k):
        for scat, pts in zip(scats, data):
            scat.set_offsets(pts[k])
        fig.suptitle(f"Diffusion process (t={k+1}/{n_arr})")
        return scats

    ani = animation.FuncAnimation(fig, _draw_frame, frames=n_arr)
    ani.save("./diffusion.gif", fps=120)

def update(k, data, scats):
    """Animation callback: show frame ``k`` in every scatter panel.

    Args:
        k: Frame index into each entry of ``data``.
        data: One sequence of point arrays per scatter artist; each
            ``data[i][k]`` is passed to ``set_offsets`` (presumably (n, 2)).
        scats: Scatter artists, paired with ``data`` in order.

    Returns:
        A 1-tuple containing ``scats``.
    """
    for scat, points in zip(scats, data):
        scat.set_offsets(points[k])
    # Frame total comes from the data instead of a hard-coded 1000.
    n_frames = len(data[0]) if len(data) else 0
    plt.suptitle(f"Diffusion process (t={k+1}/{n_frames})")
    return scats,

## Learning the score

In [None]:
from network import SimpleNet 
from torch.optim import Adam
from train import get_sde_step_fn, train_diffusion
from loss import DSMLoss, ISMLoss
from copy import deepcopy

def train(sde, loss_fn=None, N_steps=10000, lr=1e-5):
    """Train a score network for ``sde`` and return its EMA copy.

    Args:
        sde: Forward SDE whose score is learned.
        loss_fn: Score-matching loss. Defaults to
            ``DSMLoss(alpha=0.3, diff_weight=True)``; pass ``ISMLoss()`` to use
            implicit score matching instead.
        N_steps: Number of optimisation steps (2000 is a quicker setting).
        lr: Adam learning rate.

    Returns:
        ``ema``: a deep copy of the freshly built network, handed to
        ``get_sde_step_fn`` together with the trained model.
    """
    model = SimpleNet(in_dim=2,
                      enc_shapes=[512, 512, 512, 512],
                      dec_shapes=[512, 512, 512],
                      z_dim=100).to(device)
    if loss_fn is None:
        loss_fn = DSMLoss(alpha=0.3, diff_weight=True)
    optimizer = Adam(model.parameters(), lr=lr)
    # ema starts as an exact copy of model; presumably step_fn keeps it
    # updated as an exponential moving average -- confirm in train.py.
    ema = deepcopy(model)
    step_fn = get_sde_step_fn(model=model, opt=optimizer, ema=ema, sde=sde, loss_fn=loss_fn)
    train_diffusion(target_ds.get_dataloader(), step_fn, N_steps, plot=True)
    return ema

In [None]:
# Train on the OU SDE; train() returns the EMA copy of the network.
ou_model = train(sde=ou)

In [None]:
# Train on the VE SDE; train() returns the EMA copy of the network.
ve_model = train(sde=ve)

In [None]:
# Train on the VP SDE (this cell previously re-trained VE by copy-paste,
# leaving VP untrained).
vp_model = train(vp)

## Sampling and visualising the backward process

In [None]:
from sampling import Sampler

def sample(sde, model, N_samples=200, eps=1e-3):
    """Run the reverse-time SDE driven by ``model`` and return the sampler output.

    Args:
        sde: Forward SDE; its ``reverse(model)`` gives the backward process.
        model: Trained score network.
        N_samples: Number of particles to generate.
        eps: Passed to ``Sampler(eps=...)``.

    Returns:
        The 4-tuple ``(out, ntot, timesteps, x_hist)`` produced by the
        sampling function, which starts from ``prior_ds``.
    """
    sde_backward = sde.reverse(model)
    sampler = Sampler(eps=eps)
    sampler_fn = sampler.get_sampling_fn(sde_backward, prior_ds)
    out, ntot, timesteps, x_hist = sampler_fn(N_samples=N_samples)
    return out, ntot, timesteps, x_hist

In [None]:
# Draw backward samples from the OU model; x_hist holds the sampler history.
out, ntot, timesteps, x_hist = sample(sde=ou, model=ou_model)

In [None]:
def plot_backward(out, ntot, timesteps, x_hist, sde, model, t0=0.05):
    """Plot the reverse diffusion: trajectory, learned score field, final samples.

    Panels:
    1. every state in ``x_hist``, coloured with viridis by step index;
    2. ``model(t, x)`` evaluated on a 20x20 grid over [-4, 4]^2 at ``t = t0``,
       drawn with ``quiver``;
    3. the last state of ``x_hist`` together with 200 target points.

    The figure is saved to ``./task1_output.jpg``.

    Args:
        out, ntot, timesteps: Sampler outputs; not used here (kept so the
            call signature matches ``sample``'s return value).
        x_hist: Tensor of sampler states indexed ``x_hist[step, particle, coord]``.
        sde: The sampled SDE; not used here (kept for the call signature).
        model: Score network called as ``model(t, x)``.
        t0: Time at which the score field is visualised.
    """
    plot_Ntest = 200
    cmap = plt.get_cmap('viridis')
    fig = plt.figure(figsize=(15, 5), dpi=50)

    # Panel 1: whole trajectory in one scatter call (one call per step was
    # ~1000 separate artists). Iterates over the real history length instead
    # of assuming it equals sde.N.
    hist = x_hist.detach().cpu().numpy()
    n_steps, n_particles = hist.shape[0], hist.shape[1]
    step_colors = np.repeat(cmap(np.linspace(0.0, 1.0, n_steps)), n_particles, axis=0)
    flat = hist.reshape(-1, hist.shape[-1])  # step-major, matching step_colors
    ax = fig.add_subplot(1, 3, 1)
    ax.scatter(flat[:, 0], flat[:, 1], alpha=0.3, c=step_colors)
    ax.grid(False)
    ax.set_aspect('auto', adjustable='box')
    ax.set_title("backward trajectory")

    # Panel 2: score field on a grid.
    xmin, xmax = -4, 4
    ymin, ymax = -4, 4
    xx, yy = np.mgrid[xmin:xmax:20j, ymin:ymax:20j]
    grid = np.stack((xx.ravel(), yy.ravel()), axis=-1)

    # Evaluate on the model's own device. train() moves the network with
    # .to(device), so CPU-only inputs would fail on a GPU. Fall back to CPU
    # (the original behaviour) if the device cannot be determined.
    try:
        model_device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        model_device = torch.device('cpu')
    x_tensor = torch.tensor(grid, dtype=torch.float32, device=model_device)
    t_tensor = torch.full((x_tensor.shape[0],), t0, dtype=torch.float32, device=model_device)

    with torch.no_grad():
        score = model(t_tensor, x_tensor).cpu().numpy()  # don't shadow the `out` parameter

    u = score[:, 0].reshape(xx.shape)
    v = score[:, 1].reshape(yy.shape)
    ax = fig.add_subplot(1, 3, 2)
    ax.quiver(xx, yy, u, v)
    ax.set_aspect('auto', adjustable='box')
    ax.set_title(f"score at time t={t0}")

    # Panel 3: final particles against target samples.
    out_true = target_ds[range(0, plot_Ntest)]
    ax = fig.add_subplot(1, 3, 3)
    ax.scatter(hist[-1, :, 0], hist[-1, :, 1], alpha=0.6)
    ax.scatter(out_true[:, 0], out_true[:, 1], alpha=0.6)
    ax.grid(False)
    ax.set_aspect('auto', adjustable='box')
    ax.set_title("final backward particles")
    fig.savefig("./task1_output.jpg")

In [None]:
# Visualise the OU backward run sampled above.
plot_backward(out, ntot, timesteps, x_hist, sde=ou, model=ou_model)

In [None]:
def animate_scatter(x_hist, target_ds, stride=6):
    """Save a GIF of the backward diffusion over 200 fixed target points.

    Shows every ``stride``-th state of ``x_hist`` (indices ``stride-1``,
    ``2*stride-1``, ...) and always ends on the final state. Output:
    ``./backward_diffusion.gif``.

    Args:
        x_hist: Sampler history indexed ``x_hist[step]``; a torch tensor (it
            is moved to the CPU) or anything ``np.asarray`` accepts.
        target_ds: Dataset whose first 200 points are drawn as the reference.
        stride: Show one frame per ``stride`` history steps.
    """
    # matplotlib needs host arrays; the sampler returns a torch tensor.
    if isinstance(x_hist, torch.Tensor):
        hist = x_hist.detach().cpu().numpy()
    else:
        hist = np.asarray(x_hist)
    n_hist = len(hist)

    out_true = target_ds[range(0, 200)]

    fig = plt.figure(figsize=(5, 5), dpi=80)
    ax = fig.add_subplot(1, 1, 1)
    ax.set_ylim(-4, 4)
    ax.set_xlim(-4, 4)
    scat = ax.scatter([], [], alpha=0.6)
    ax.scatter(out_true[:, 0], out_true[:, 1], alpha=0.6)
    ax.grid(False)
    ax.set_aspect('auto', adjustable='box')

    # Same indices as the old 6*k+5 scheme, plus the last state, which that
    # scheme skipped unless the history length was a multiple of 6.
    frame_ids = list(range(stride - 1, n_hist, stride))
    if n_hist and (not frame_ids or frame_ids[-1] != n_hist - 1):
        frame_ids.append(n_hist - 1)

    # Local callback: the module-level name ``update`` is defined twice in
    # this notebook with different signatures, so don't depend on it.
    def _draw_frame(idx):
        scat.set_offsets(hist[idx])
        fig.suptitle(f"Backward diffusion process (t={idx}/{n_hist})")
        return scat,

    ani = animation.FuncAnimation(fig, _draw_frame, frames=frame_ids)
    ani.save("./backward_diffusion.gif", fps=120)

def update(k, data, scat):
    """Animation callback: show sampler state ``6*k + 5`` in ``scat``.

    Args:
        k: Animation frame number.
        data: Indexable history of point sets; ``data[6*k + 5]`` is passed to
            ``set_offsets``.
        scat: The scatter artist to update.

    Returns:
        A 1-tuple containing ``scat``.
    """
    idx = 6 * k + 5
    scat.set_offsets(data[idx])
    # Total comes from the history length instead of a hard-coded 1000.
    plt.suptitle(f"Backward diffusion process (t={idx}/{len(data)})")
    return scat,

In [None]:
# Animate the OU backward run sampled above.
animate_scatter(x_hist=x_hist, target_ds=target_ds)

## Evaluating

In [None]:
# from pytorch3d.loss.chamfer import chamfer_distance
from chamferdist import ChamferDistance
chamfer_distance = ChamferDistance()
from sampling import Sampler
# Number of generated / reference points used by evaluate().
N_test = 10_000

In [None]:
def evaluate(sde, model, n_batches=100):
    """Sample ``N_test`` points and report their Chamfer distance to the target.

    Generated and reference sets (``N_test`` points each) are reshaped into
    ``n_batches`` clouds of ``N_test // n_batches`` points and passed to
    ``chamfer_distance`` as (batch, points, dim) tensors. The result is then
    divided by ``n_batches``. With the defaults (``N_test = 10000``,
    ``n_batches = 100``) this matches the original 100x100 split exactly.

    Args:
        sde: Forward SDE; its ``reverse(model)`` gives the backward process.
        model: Trained score network.
        n_batches: Number of point clouds the sets are split into.

    Returns:
        The value printed after ``CD:``.

    Raises:
        ValueError: If ``N_test`` is not divisible by ``n_batches``.
    """
    if N_test % n_batches:
        raise ValueError(f"N_test={N_test} must be divisible by n_batches={n_batches}")

    sde_backward = sde.reverse(model)
    sampler = Sampler(eps=1e-3)
    sampler_fn = sampler.get_sampling_fn(sde_backward, prior_ds)

    pc_gen, _, _, _ = sampler_fn(N_samples=N_test)
    # Sanity baseline: replace pc_gen with uniform noise, 8 * torch.rand(N_test, 2) - 4.
    pc_ref = target_ds[:N_test]  # was hard-coded to 10000, independent of N_test

    pc_gen = pc_gen.reshape(n_batches, -1, pc_gen.shape[-1]).to(device)
    pc_ref = pc_ref.reshape(n_batches, -1, pc_ref.shape[-1]).to(device)
    cd = chamfer_distance(pc_gen, pc_ref) / len(pc_gen)
    print("CD: ", cd)
    return cd

In [None]:
# Chamfer-distance evaluation of the OU model.
evaluate(sde=ou, model=ou_model)