# Train different Diffusion Model (DM) variants with different neural network backbones on a simple vector-target task

In [None]:
%load_ext autoreload
%autoreload 2

## Define Targets

In [None]:
import numpy as np

# Seven 2-D target points, first laid out around the origin ...
base_targets = np.array([
    [0.5, 2.],
    [-0.5, 2.],
    [0.8, -2.],
    [0.5, -3.5],
    [0., -4],
    [-0.5, -3.5],
    [-0.8, -2.],
])
# ... then shifted away from it, so the model first has to learn to move its
# samples from the noise around the origin towards the data.
target_offset = np.array([10, 20])
targets = base_targets - target_offset

# All targets are equally fit.
fitness = np.ones(len(targets))

In [None]:
import matplotlib.pyplot as plt

# Quick look at the target constellation.
fig, ax = plt.subplots()
ax.scatter(targets[:, 0], targets[:, 1], c='r', marker='^', s=100, label='Targets')
ax.legend()
ax.grid()
plt.show()

## Imports & Helpers

In [None]:
import torch
from condevo.diffusion import RectFlow, DDIM
from condevo.nn import MLP, UNet
from condevo.es.utils import roulette_wheel

In [None]:
# Input dimensionality shared by all backbones (number of columns of `targets`).
num_params = targets.shape[-1]

def get_mlp(num_hidden=96,
            num_layers=6,
            activation="ReLU",
            dropout=0.0,
            layer_norm=False,
            batch_norm=False,
            time_embedding=0,  # 0: disable, >1: train linear projection layer for time embedding
            ):
    """Build a ``condevo.nn.MLP`` backbone for ``num_params``-dimensional data.

    Every keyword is forwarded unchanged to ``MLP``; ``num_params`` is the
    module-level value computed just above.
    """
    architecture = dict(num_hidden=num_hidden, num_layers=num_layers, activation=activation)
    regularization = dict(dropout=dropout, layer_norm=layer_norm, batch_norm=batch_norm)
    return MLP(num_params=num_params,
               time_embedding=time_embedding,
               **architecture,
               **regularization)

def get_unet(num_hidden=None,  # encoding side of the "U"; defaults to [64, 32, 16]
             activation="GELU",
             dropout=0.0,
             layer_norm=True,
             batch_norm=False,
             time_embedding=0,
             ):
    """Build a ``condevo.nn.UNet`` backbone for ``num_params``-dimensional data.

    Args:
        num_hidden: hidden widths of the encoding side of the "U". ``None``
            selects the default ``[64, 32, 16]``.
        activation, dropout, layer_norm, batch_norm, time_embedding:
            forwarded unchanged to ``UNet``.

    Returns:
        A freshly constructed ``UNet``.
    """
    # A list literal as default value is created once and shared by every
    # call (and by every UNet built from it); build a fresh list per call.
    num_hidden = [64, 32, 16] if num_hidden is None else list(num_hidden)
    return UNet(num_params=num_params, num_hidden=num_hidden, activation=activation, dropout=dropout, layer_norm=layer_norm, batch_norm=batch_norm, time_embedding=time_embedding)


In [None]:
def get_ddim(model, backbone,
             skip_connection=False,
             num_steps=300,
             alpha_schedule="cosine_nichol",
             noise_level=1.0,
             clip_gradients=1.,
             scaler=None,
             ):
    """Wrap the network ``model`` in a ``DDIM`` diffusion model.

    ``backbone`` is only a label: it is baked into the ``log_dir`` passed to
    ``DDIM`` (``data/ddim_<backbone>``). All other keywords are forwarded
    unchanged.
    """
    return DDIM(
        nn=model,
        skip_connection=skip_connection,
        num_steps=num_steps,
        alpha_schedule=alpha_schedule,
        noise_level=noise_level,
        clip_gradients=clip_gradients,
        log_dir=f"data/ddim_{backbone}",
        scaler=scaler,
    )


def get_rflow(model, backbone,
              num_steps=50,
              noise_level=1.0,
              clip_gradients=1.,
              scaler=None,
              ):
    """Wrap the network ``model`` in a rectified-flow (``RectFlow``) model.

    ``backbone`` is only a label: it is baked into the ``log_dir`` passed to
    ``RectFlow`` (``data/RFlow_<backbone>``). All other keywords are
    forwarded unchanged.
    """
    return RectFlow(
        nn=model,
        num_steps=num_steps,
        noise_level=noise_level,
        clip_gradients=clip_gradients,
        log_dir=f"data/RFlow_{backbone}",
        scaler=scaler,
    )


## Train DDIM model with MLP backbone

In [None]:
dataset_size = 1000
max_epochs   =  500
device = "cpu"

# Fitness -> per-target training weights via `roulette_wheel`.
weights = roulette_wheel(torch.tensor(fitness, dtype=torch.float), s=1, normalize=True)
for t, f, w in zip(targets, fitness, weights):
    print(f"target {t} with fitness {f} is weighted as {w}")

# Augment the dataset: draw `dataset_size` targets uniformly with replacement so
# learning progress is smooth. Fitness enters only via the `weights` given to `fit`.
target_idx_training = np.random.choice(np.arange(len(targets)), size=dataset_size, p=None, replace=True)

# convert to torch tensors
x = torch.tensor(targets[target_idx_training], dtype=torch.float).to(device)
f = weights[target_idx_training][:, None].to(device)

# Initialize the model. The log-dir label follows the backbone that is actually
# built, and stays distinct from the StandardScaler run further below ("mlp_test").
use_unet = False  # True -> train the UNet backbone instead of the MLP
backbone = get_unet() if use_unet else get_mlp()
ddim = get_ddim(model=backbone, backbone=("unet" if use_unet else "mlp") + "_unscaled_test")
ddim = ddim.to(device)

# train
history = ddim.fit(
    x,
    weights=f,
    max_epoch=max_epochs,
    lr=3e-3,
    optimizer="Adam",
    weight_decay=1e-5,
)

In [None]:
samples = ddim.sample(shape=(num_params,), num=1000)

# Share of generated samples that land within distance 0.25 of each target.
count = [
    (torch.linalg.norm(samples - target, dim=-1) < 0.25).sum().item() / len(samples)
    for target in targets
]

for target, weight, rate in zip(targets, weights, count):
    print(f"target {target} weighted by {weight:.3f}: sampling rate {rate}")

In [None]:
fig, (ax_hist, ax) = plt.subplots(1, 2, figsize=(12, 6))

# Left panel: training loss per epoch.
ax_hist.plot(history)
ax_hist.set(xlabel="Epochs", ylabel="Loss")
ax_hist.grid()

# Right panel: generated samples overlaid with the targets.
ax.scatter(samples[:, 0], samples[:, 1], c='b', marker='o', alpha=0.5)
ax.scatter(targets[:, 0], targets[:, 1], c='r', marker='^', s=100, label='Targets')
ax.legend()
plt.show()

### Analyze the denoising stack of DDIM models

In [None]:
@torch.no_grad()
def sample(model, num_samples, num_params, t_start=None):
    """Run the DDIM reverse loop step by step and return every intermediate state.

    Args:
        model: trained DDIM model. Reads ``num_steps``, ``alpha``, ``sigma``,
            ``noise_level`` and ``scaler``; calls ``_step_discrete_to_continuous``,
            ``predict_eps`` and ``get_clamped_eps_x0``.
        num_samples: number of trajectories (particles) to generate.
        num_params: dimensionality of each sample.
        t_start: discrete step to start from; defaults to ``model.num_steps - 1``.

    Returns:
        ``model.scaler.inverse_transform`` applied to the stack of shape
        ``(t_start + 1, num_samples, num_params)``. Index ``t_start`` holds the
        initial standard-normal noise, index 0 the final denoised samples.
    """
    # sample entire denoising stack (num_(t_)steps, num_samples, num_params) for t in [T, 0]
    if t_start is None:
        t_start = model.num_steps - 1

    # Only xt[t_start] is used as starting noise; all lower indices are overwritten below.
    # NOTE(review): allocated on the default device, not the model's — confirm for GPU use.
    xt = torch.randn(t_start + 1, num_samples, num_params)  # RETURN entire denoising stack
    # Column of ones used to broadcast the scalar time to shape (num_samples, 1).
    one = torch.ones(num_samples, 1, device=xt.device, dtype=xt.dtype)

    # Walk T = t_start, ..., 1, computing xt[T-1] from xt[T].
    for T in range(t_start, 0, -1):
        t = one * model._step_discrete_to_continuous(T)
        # Signal coefficient of the target step T-1 (presumably the cumulative alpha schedule — confirm).
        a = model.alpha[T-1]
        # Std of the stochastic term; with noise_level == 0 the z-term below vanishes.
        s = model.sigma[T] * model.noise_level
        z = torch.randn_like(xt[T])

        eps = model.predict_eps(xt[T], t)
        eps, x0_pred = model.get_clamped_eps_x0(xt[T], eps, T)

        # clamp_min(0) keeps the sqrt argument non-negative when s**2 exceeds 1 - a.
        eps_sqrt_term = (1 - a - s ** 2).clamp_min(0).sqrt()
        # DDIM update: re-noise the predicted x0 to step T-1 using the predicted eps plus fresh noise.
        xt[T-1] = a.sqrt() * x0_pred + eps_sqrt_term * eps + s * z

    return model.scaler.inverse_transform(xt)

In [None]:
from condevo.stats import kl_divergence_sampled
from condevo.stats import grid_entropy_2d

# Reference distribution ("PD"): draw `pd_size` points from the discrete target
# set, each target chosen with probability proportional to its weight.
pd_size = 2000
p = np.array(weights, dtype=float)
p /= p.sum()
idx = np.random.choice(np.arange(len(targets)), p=p, size=pd_size)
# `idx` indexes the *targets*; indexing the training set (`x[idx]`) would only
# ever return its first len(targets) rows instead of the targets themselves.
px = torch.tensor(targets[idx], dtype=torch.float)

# Full denoising stack: index 0 = final samples, last index = initial noise.
samples = sample(ddim, 1000, num_params)

# Histogram grid for the entropy / KL estimates. It must cover the data: the
# targets sit far from the origin (around [-10, -20]), so a fixed [-4, 4]
# window would exclude them. Use one square window spanning the targets and all
# intermediate samples, padded by a small margin.
grid_size = 33
samples_np = np.asarray(samples, dtype=float)
lo = min(targets.min(), samples_np.min())
hi = max(targets.max(), samples_np.max())
pad = 0.05 * (hi - lo)
range_min, range_max = float(lo - pad), float(hi + pad)
print("binsize =", (range_max - range_min) / (grid_size - 1))
# reversed(...) orders the stack from initial noise to final samples.
entropy = [grid_entropy_2d(xt, grid_size=grid_size, range_min=range_min, range_max=range_max) for xt in reversed(samples)]
KL_divergence = [kl_divergence_sampled(px, px_hat, grid_size=grid_size, range_min=range_min, range_max=range_max) for px_hat in reversed(samples)]

# plot statistics against the number of denoising steps taken, T - t
steps = np.arange(len(KL_divergence))
fig = plt.figure(figsize=(4, 4))
ax = plt.gca()
ax.plot(steps, KL_divergence)
ax.set_ylabel("KL-Divergence")
ax.set_xlabel("Denoising steps $(T-t)$")

t_ax = plt.twinx()
# Empty line: only adds a "KL-Divergence" entry (first cycle color, like the
# left-axis curve) to the twin axis' legend.
t_ax.plot([], label="KL-Divergence")
t_ax.plot(steps, entropy, color="tab:orange", label="Entropy")
plt.ylabel("Entropy")
plt.legend()
plt.show()

In [None]:
##### plot denoising over time
from matplotlib.animation import FuncAnimation
from IPython.display import HTML, display

samples = sample(ddim, 1000, 2)

T, _, _ = samples.shape  # number of frames
fig = plt.figure(figsize=(4, 4))
ax = plt.gca()

# Convert to a numpy float array (also handles torch tensors) and reverse the
# stack: `sample` stores the final samples at index 0 and the initial noise at
# the last index, so after reversal frame 0 is the initial noise.
samples_np = np.asarray(samples, dtype=float)[::-1]

# Fixed axis limits spanning all frames, so the view does not jump.
lims = [(samples_np[..., i].min(), samples_np[..., i].max()) for i in range(2)]
ax.set_xlim(lims[0])
ax.set_ylim(lims[1])

# Init
xy0 = samples_np[0]
scat = ax.scatter(xy0[:, 0], xy0[:, 1], s=5, alpha=0.8)
title = ax.set_title("t = 0")


def update(frame):
    """Draw frame `frame`: move the scatter points and relabel the title."""
    xy = samples_np[frame]              # (N, 2)
    scat.set_offsets(xy[:, :2])
    title.set_text(f"t = {frame}")
    return scat, title


anim = FuncAnimation(fig, update, frames=T, interval=50, blit=False)
# Close the figure so the inline backend does not also render it as a stray
# static image; to_jshtml below still renders every frame.
plt.close(fig)

display(HTML(anim.to_jshtml()))

The animation shows some important features:
- First, the particle cloud (the to-be-generated samples) expands from the origin towards the target location (near [-10, -20]).
- Then, it separates hierarchically into the respective refined clouds around the individual targets.

In particular, the diffusion model needs to learn the initial drift towards the target data regime.

**Importantly: if the number of denoising steps does not suffice, the diffusion might not reach the target regime in time. Try, for instance, a dataset shifted to a mean of [-100, -200].**

For such cases, data scaling is important! Below, we repeat the experiments with a simple mean/std `StandardScaler`.

## Using a StandardScaler

### DDIM

In [None]:
import torch

dataset_size = 1000
max_epochs = 500
device = "cpu"

# Per-target weights derived from fitness; print them for inspection.
weights = roulette_wheel(torch.tensor(fitness, dtype=torch.float), s=1, normalize=True)
for tgt, fit_value, weight in zip(targets, fitness, weights):
    print(f"target {tgt} with fitness {fit_value} is weighted as {weight}")

# Uniformly resample the targets (with replacement) into a larger training set
# so the learning curve is smooth; the fitness weights are passed to `fit`.
target_idx_training = np.random.choice(np.arange(len(targets)), size=dataset_size, replace=True)
x = torch.tensor(targets[target_idx_training], dtype=torch.float).to(device)
f = weights[target_idx_training][:, None].to(device)

# MLP backbone inside a DDIM that uses a mean/std "StandardScaler"
# (implemented in condevo.preprocessing).
backbone = get_mlp()
ddim = get_ddim(
    model=backbone,
    backbone="mlp_test",
    scaler="StandardScaler",
).to(device)

fit_kwargs = dict(max_epoch=max_epochs, lr=3e-3, optimizer="Adam", weight_decay=1e-5)
history = ddim.fit(x, weights=f, **fit_kwargs)

In [None]:
# Inspect the statistics held by the DDIM's StandardScaler (presumably fitted
# on `x` during `ddim.fit` — confirm); the notebook displays this tuple.
ddim.scaler.mean, ddim.scaler.std

In [None]:
samples = ddim.sample(shape=(num_params,), num=1000)

# Share of generated samples that land within distance 0.25 of each target.
count = [
    (torch.linalg.norm(samples - target, dim=-1) < 0.25).sum().item() / len(samples)
    for target in targets
]
for target, weight, rate in zip(targets, weights, count):
    print(f"target {target} weighted by {weight:.3f}: sampling rate {rate}")

fig, (ax_hist, ax) = plt.subplots(1, 2, figsize=(12, 6))

# Left panel: training loss per epoch.
ax_hist.plot(history)
ax_hist.set(xlabel="Epochs", ylabel="Loss")
ax_hist.grid()

# Right panel: generated samples overlaid with the targets.
ax.scatter(samples[:, 0], samples[:, 1], c='b', marker='o', alpha=0.5)
ax.scatter(targets[:, 0], targets[:, 1], c='r', marker='^', s=100, label='Targets')
ax.legend()
plt.show()

In [None]:
##### plot denoising over time
from matplotlib.animation import FuncAnimation
from IPython.display import HTML, display

samples = sample(ddim, 500, 2)

T, _, _ = samples.shape  # number of frames
fig = plt.figure(figsize=(4, 4))
ax = plt.gca()

# Convert to a numpy float array (also handles torch tensors) and reverse the
# stack: `sample` stores the final samples at index 0 and the initial noise at
# the last index, so after reversal frame 0 is the initial noise.
samples_np = np.asarray(samples, dtype=float)[::-1]

# Fixed axis limits spanning all frames, so the view does not jump.
lims = [(samples_np[..., i].min(), samples_np[..., i].max()) for i in range(2)]
ax.set_xlim(lims[0])
ax.set_ylim(lims[1])

# Init
xy0 = samples_np[0]
scat = ax.scatter(xy0[:, 0], xy0[:, 1], s=5, alpha=0.8)
title = ax.set_title("t = 0")


def update(frame):
    """Draw frame `frame`: move the scatter points and relabel the title."""
    xy = samples_np[frame]              # (N, 2)
    scat.set_offsets(xy[:, :2])
    title.set_text(f"t = {frame}")
    return scat, title


anim = FuncAnimation(fig, update, frames=T, interval=50, blit=False)
# Close the figure so the inline backend does not also render it as a stray
# static image; to_jshtml below still renders every frame.
plt.close(fig)

display(HTML(anim.to_jshtml()))

### Rectified Flow

In [None]:
import torch

dataset_size = 1000
max_epochs   =  500
device = "cpu"

# Fitness -> per-target training weights via `roulette_wheel`.
weights = roulette_wheel(torch.tensor(fitness, dtype=torch.float), s=1, normalize=True)
for t, f, w in zip(targets, fitness, weights):
    print(f"target {t} with fitness {f} is weighted as {w}")

# Augment the dataset: draw `dataset_size` targets uniformly with replacement so
# learning progress is smooth. Fitness enters only via the `weights` given to `fit`.
target_idx_training = np.random.choice(np.arange(len(targets)), size=dataset_size, p=None, replace=True)

# convert to torch tensors
x = torch.tensor(targets[target_idx_training], dtype=torch.float).to(device)
f = weights[target_idx_training][:, None].to(device)

# Initialize the model; the log-dir label follows the backbone actually built
# (it previously said "unet_test" while an MLP was trained).
use_unet = False  # True -> train the UNet backbone instead of the MLP
backbone = get_unet() if use_unet else get_mlp()
rflow = get_rflow(model=backbone, backbone="unet_test" if use_unet else "mlp_test", scaler="StandardScaler")
rflow = rflow.to(device)

# train
history = rflow.fit(
    x,
    weights=f,
    max_epoch=max_epochs,
    lr=3e-3,
    optimizer="Adam",
    weight_decay=1e-5,
)

In [None]:
samples = rflow.sample(shape=(num_params,), num=1000)

# Share of generated samples that land within distance 0.25 of each target.
count = [
    (torch.linalg.norm(samples - target, dim=-1) < 0.25).sum().item() / len(samples)
    for target in targets
]
for target, weight, rate in zip(targets, weights, count):
    print(f"target {target} weighted by {weight:.3f}: sampling rate {rate}")

fig, (ax_hist, ax) = plt.subplots(1, 2, figsize=(12, 6))

# Left panel: training loss per epoch.
ax_hist.plot(history)
ax_hist.set(xlabel="Epochs", ylabel="Loss")
ax_hist.grid()

# Right panel: generated samples overlaid with the targets.
ax.scatter(samples[:, 0], samples[:, 1], c='b', marker='o', alpha=0.5)
ax.scatter(targets[:, 0], targets[:, 1], c='r', marker='^', s=100, label='Targets')
ax.legend()
plt.show()

In [None]:
@torch.no_grad()
def rflow_sample(model, num_samples, num_params, t_start=None):
    """Integrate the rectified flow with Euler steps, keeping every state.

    Args:
        model: trained RectFlow model. Reads ``num_steps``, ``device``,
            ``matthew_factor``, ``noise_level``, ``diff_range_filter`` and
            ``scaler``; calls ``model(x, t)`` for the velocity and ``diff_clamp``.
        num_samples: number of particles to integrate.
        num_params: dimensionality of each sample.
        t_start: index of the first Euler step to run (default 0).

    Returns:
        ``model.scaler.inverse_transform`` of the stack of shape
        ``(num_steps, num_samples, num_params)``. For ``T >= t_start``, entry
        ``T`` is the state after Euler step ``T`` (with ``t_start == 0`` the
        initial noise itself is overwritten and not returned).
    """
    # sample entire denoising stack (num_(t_)steps, num_samples, num_params) for t in [T, 0]
    if t_start is None:
        t_start = 0

    # Allocate the stack on the model's device so it matches the time grid
    # below; a CPU stack would mix devices in model(xt, t) once the model is on a GPU.
    xtt = torch.randn(model.num_steps, num_samples, num_params, device=model.device)  # RETURN entire denoising stack
    tt = torch.linspace(0, 1, model.num_steps, device=model.device).view(-1, 1)  # (T,1)
    dt = 1. / model.num_steps * model.matthew_factor

    # Start from the noise in slot 0.
    xt = xtt[0]
    for T in range(t_start, model.num_steps):
        t = tt[T].expand(xt.size(0), 1)  # (B,1)
        v = model(xt, t)
        xt = xt + v * dt  # Euler step along the predicted velocity (creates a new tensor)

        if model.noise_level:
            # Optional stochastic term, scaled down to zero as t -> 1.
            xt += (dt ** 0.5) * torch.randn_like(xt) * model.noise_level * (1 - t)

        if model.diff_range_filter:
            xt = model.diff_clamp(xt)

        xtt[T, ...] = xt.clone()

    return model.scaler.inverse_transform(xtt)

In [None]:
##### plot denoising over time
from matplotlib.animation import FuncAnimation
from IPython.display import HTML, display

samples = rflow_sample(rflow, 500, 2)

T, _, _ = samples.shape  # number of frames
fig = plt.figure(figsize=(4, 4))
ax = plt.gca()

# Convert to a numpy float array (also handles torch tensors). `rflow_sample`
# already stores states in integration order, so no reversal is needed.
samples_np = np.asarray(samples, dtype=float)

# Fixed axis limits spanning all frames, so the view does not jump.
lims = [(samples_np[..., i].min(), samples_np[..., i].max()) for i in range(2)]
ax.set_xlim(lims[0])
ax.set_ylim(lims[1])

# Init
xy0 = samples_np[0]
scat = ax.scatter(xy0[:, 0], xy0[:, 1], s=5, alpha=0.8)
title = ax.set_title("t = 0")


def update(frame):
    """Draw frame `frame`: move the scatter points and relabel the title."""
    xy = samples_np[frame]              # (N, 2)
    scat.set_offsets(xy[:, :2])
    title.set_text(f"t = {frame}")
    return scat, title


anim = FuncAnimation(fig, update, frames=T, interval=50, blit=False)
# Close the figure so the inline backend does not also render it as a stray
# static image; to_jshtml below still renders every frame.
plt.close(fig)

display(HTML(anim.to_jshtml()))