In [None]:
from dataclasses import dataclass

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from datasets.load import load_dataset


@dataclass
class Config:
  """Hyperparameters and run settings for the diffusion notebook.

  Every field has a default, so ``Config()`` yields the baseline run. The two
  read-only properties derive step counts from the sample budget
  (`total_samples`), the batch size and the number of epochs.
  """

  # --- run settings ---
  batch_size: int = 32
  img_size: int = 128
  epochs: int = 500
  total_samples: int = 5_000_000
  loss_type: str = 'mae'
  dataset: str = 'cartoonset'
  viz: str = 'matplotlib'
  model: str = 'stable_unet'
  eval_every: int = 200
  log_every: int = 200
  # --- model ---
  channels_in: int = 3
  channels_block: int = 128
  channels_mlp: int = 128
  patches_mlp: int = 128
  num_layers: int = 8
  patch_size: int = 8
  num_patch_rows: int = 16
  # --- exponential moving average of the weights ---
  ema_decay: float = 0.995
  ema_update_every: int = 10
  ema_update_after_step: int = 100
  # --- noise schedule ---
  schedule: str = 'cosine'
  beta_start: float = 3e-4
  beta_end: float = 0.5
  timesteps: int = 1_000
  # --- optimizer ---
  lr_start: float = 2e-5
  drop_1_mult: float = 1.0
  drop_2_mult: float = 1.0

  @property
  def steps_per_epoch(self) -> int:
    """Number of batches per epoch, rounded down."""
    divisor = self.epochs * self.batch_size
    return self.total_samples // divisor

  @property
  def total_steps(self) -> int:
    """Number of whole batches that fit in the sample budget."""
    budget, per_step = self.total_samples, self.batch_size
    return budget // per_step


# Instantiate the run configuration with every field at its default value.
config = Config()

# Show which devices JAX can see.
print(jax.devices())

In [None]:
# Build `ds`, a tf.data pipeline of float32 images, for the configured dataset.
if config.dataset == 'mnist':
  hfds = load_dataset('mnist', split='train')
  # append a trailing axis of size 1 to the stacked images
  X = np.stack(hfds['image'])[..., None]
  ds = tf.data.Dataset.from_tensor_slices(X.astype(np.float32))
elif config.dataset == 'pokemon':
  hfds = load_dataset('lambdalabs/pokemon-blip-captions', split='train')
  hfds = hfds.map(
    lambda sample: {'image': sample['image'].resize((64 + 16, 64 + 16))},
    remove_columns=['text'],
    batch_size=96,
  )
  X = np.stack(hfds['image'])
  ds = tf.data.Dataset.from_tensor_slices(X.astype(np.float32))
elif config.dataset == 'cartoonset':
  hfds = load_dataset('cgarciae/cartoonset', '10k', split='train')
  # stream records lazily; only the raw PNG bytes are declared in the signature
  ds = tf.data.Dataset.from_generator(
    lambda: hfds,
    output_signature={
      'img_bytes': tf.TensorSpec(shape=(), dtype=tf.string),
    },
  )

  def process_fn(x):
    """Decode one PNG record to float32 RGB and resize it to the configured size."""
    x = tf.image.decode_png(x['img_bytes'], channels=3)
    x = tf.cast(x, tf.float32)
    # follow config.img_size instead of a hard-coded (128, 128)
    x = tf.image.resize(x, (config.img_size, config.img_size))
    return x

  ds = ds.map(process_fn)
else:
  raise ValueError(f'Unknown dataset {config.dataset}')

# map pixel values from the 0..255 scale to [-1, 1]
ds = ds.map(lambda x: x / 127.5 - 1.0)
# NOTE: shuffling happens after repeat(), so the shuffle buffer may mix
# elements from neighbouring passes over the data.
ds = ds.repeat()
ds = ds.shuffle(seed=42, buffer_size=1_000)
ds = ds.batch(config.batch_size, drop_remainder=True)
ds = ds.prefetch(tf.data.AUTOTUNE)

In [None]:
from einop import einop

# Pull a single batch from the pipeline; it is used below as a preview.
x_sample: np.ndarray = next(ds.as_numpy_iterator())

n_rows = 4
n_cols = 7
# tile the first n_rows * n_cols images of the batch into one grid image
t = einop(
  x_sample[: n_rows * n_cols],
  '(row col) h w c -> (row h) (col w) c',
  row=n_rows,
  col=n_cols,
)
plt.figure(figsize=(3 * n_cols, 3 * n_rows))
# the pipeline normalizes pixels to [-1, 1]; imshow expects floats in [0, 1],
# so map back before plotting instead of letting negative values be clipped
plt.imshow((t + 1.0) / 2.0);

In [None]:
from flax.experimental import nnx


def expand_to(a: jax.Array, b: jax.Array):
  """Reshape `a` by appending size-1 axes until it has as many dims as `b`.

  The trailing singleton axes let `a` broadcast against `b` along `a`'s
  leading axes. If `b` has no more dimensions than `a`, the shape of `a` is
  left as it is.
  """
  missing_dims = b.ndim - a.ndim
  return a.reshape(a.shape + (1,) * missing_dims)


class ScheduleState(nnx.Variable):
  """Variable type used to tag the arrays held by the noise schedule."""


class DDPMSchedule(nnx.GraphNode):
  """Noise schedule providing the forward (noising) and reverse (denoising) steps.

  Holds ``betas``, ``alphas = 1 - betas`` and ``alpha_bars = cumprod(alphas)``
  as `ScheduleState` variables, together with the `nnx.Rngs` from which the
  ``schedule`` stream is drawn whenever noise is needed.
  """

  def __init__(self, betas: jax.Array, timesteps: int, rngs: nnx.Rngs):
    alphas = 1.0 - betas
    self.timesteps = timesteps
    self.betas = ScheduleState(betas)
    self.alphas = ScheduleState(alphas)
    self.alpha_bars = ScheduleState(jnp.cumprod(alphas))
    self.rngs = rngs

  @nnx.jit
  def forward(self, x: jax.Array, t: jax.Array):
    """Noise `x` to timestep `t`.

    Returns the pair ``(xt, noise)`` where
    ``xt = sqrt(alpha_bar_t) * x + sqrt(1 - alpha_bar_t) * noise`` and `noise`
    is freshly drawn standard-normal noise with the shape of `x`.
    """
    rng_key = self.rngs.schedule()
    # per-sample coefficient, padded with singleton axes to broadcast over x
    signal_level = expand_to(self.alpha_bars[t], x)
    noise = jax.random.normal(rng_key, x.shape)
    noisy = jnp.sqrt(signal_level) * x + jnp.sqrt(1.0 - signal_level) * noise
    return noisy, noise

  @nnx.jit
  def reverse(self, x: jax.Array, noise: jax.Array, t: jax.Array):
    """Take one denoising step from timestep `t` given the predicted `noise`.

    Fresh noise scaled by ``sqrt(beta_t)`` is added wherever ``t > 0``; entries
    with ``t == 0`` receive no added noise.
    """
    beta_t = expand_to(self.betas[t], x)
    alpha_t = expand_to(self.alphas[t], x)
    alpha_bar_t = expand_to(self.alpha_bars[t], x)

    rng_key = self.rngs.schedule()
    fresh_noise = jax.random.normal(rng_key, x.shape)
    z = jnp.where(expand_to(t, x) > 0, fresh_noise, jnp.zeros_like(x))
    mean_scale = 1.0 / jnp.sqrt(alpha_t)
    noise_scale = beta_t / jnp.sqrt(1.0 - alpha_bar_t)
    return mean_scale * (x - noise_scale * noise) + jnp.sqrt(beta_t) * z

In [None]:
def polynomial_schedule(
  beta_start, beta_end, timesteps, exponent=2.0, **kwargs
):
  """Betas following a power curve from `beta_start` to `beta_end`.

  A ramp of `timesteps` evenly spaced values in [0, 1] is raised to
  `exponent` and then rescaled to the requested range. Extra keyword
  arguments are accepted and ignored.
  """
  ramp = jnp.linspace(0, 1, timesteps) ** exponent
  span = beta_end - beta_start
  return ramp * span + beta_start


def sigmoid_schedule(beta_start, beta_end, timesteps, **kwargs):
  """Betas following a sigmoid curve, rescaled to `beta_start`..`beta_end`.

  The sigmoid is evaluated on `timesteps` evenly spaced points in [-6, 6].
  Extra keyword arguments are accepted and ignored.
  """
  curve = jax.nn.sigmoid(jnp.linspace(-6, 6, timesteps))
  span = beta_end - beta_start
  return curve * span + beta_start


def cosine_schedule(beta_start, beta_end, timesteps, s=0.008, **kwargs):
  """Betas derived from a squared-cosine cumulative-alpha curve.

  The raw betas are clipped to [0.0001, 0.9999], min-max normalized to
  [0, 1] and finally rescaled to `beta_start`..`beta_end`. `s` is the offset
  added inside the cosine. Extra keyword arguments are accepted and ignored.
  """
  steps = jnp.linspace(0, timesteps, timesteps + 1)
  curve = jnp.cos(((steps / timesteps) + s) / (1 + s) * jnp.pi * 0.5) ** 2
  alpha_bar = curve / curve[0]
  # beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}
  raw = jnp.clip(1 - (alpha_bar[1:] / alpha_bar[:-1]), 0.0001, 0.9999)
  unit = (raw - raw.min()) / (raw.max() - raw.min())
  return unit * (beta_end - beta_start) + beta_start


# %%
# Look up the schedule builder by name.
schedule_fns = {
  'polynomial': polynomial_schedule,
  'sigmoid': sigmoid_schedule,
  'cosine': cosine_schedule,
}
if config.schedule not in schedule_fns:
  raise ValueError(f'Unknown schedule {config.schedule}')
schedule_fn = schedule_fns[config.schedule]

betas = schedule_fn(
  config.beta_start,
  config.beta_end,
  config.timesteps,
)
rngs = nnx.Rngs(0)
schedule = DDPMSchedule(betas, config.timesteps, rngs)
n_rows = 2
n_cols = 7

_, (ax_img, ax_plot) = plt.subplots(2, 1, figsize=(3 * n_cols, 3 * n_rows))

# Top panel: the same image noised to n_cols evenly spaced timesteps.
# Valid indices into the schedule tables are 0 .. timesteps - 1, so the last
# column must be timesteps - 1 (using `timesteps` indexes out of range).
t = jnp.linspace(0, config.timesteps - 1, n_cols).astype(jnp.uint32)
x_data = einop(x_sample[0], 'h w c -> b h w c', b=n_cols)
x_data, x_data_noise = schedule.forward(x_data, t)
x_viz = einop(x_data, 'col h w c -> h (col w) c', col=n_cols)
# map from [-1, 1] to [0, 1]; noised pixels can leave that range, so clip
x_viz = jnp.clip((x_viz + 1.0) / 2.0, 0.0, 1.0)
ax_img.imshow(x_viz)

# Bottom panel: the configured betas against a linear ramp over the same range.
linear = polynomial_schedule(
  betas.min(), betas.max(), config.timesteps, exponent=1.0
)
ax_plot.plot(linear, label='linear', color='black', linestyle='dotted')
ax_plot.plot(betas, label=config.schedule)
ax_plot.legend()
for s in ['top', 'bottom', 'left', 'right']:
  ax_plot.spines[s].set_visible(False)

In [None]:
# %%
from functools import partial


@partial(nnx.jit, static_argnames=['return_all'])
def sample(
  model: nnx.Module,
  x0: jax.Array,
  ts: jax.Array,
  schedule: DDPMSchedule,
  *,
  return_all=True,
):
  """Run the reverse process from `x0` through the timesteps in `ts`.

  Args:
    model: callable as ``model(x, t)``; its output is passed to
      `schedule.reverse` as the predicted noise.
    x0: starting batch; its leading axis is used as the batch size.
    ts: 1-D array of timesteps; each entry is repeated across the batch.
    schedule: provides the `reverse` denoising step.
    return_all: static flag. When true the per-step outputs collected by the
      scan are returned, otherwise only the final carry is returned.

  Returns:
    The scan's collected outputs if `return_all` is true, else the final `x`.
  """
  # repeat every timestep across the batch -> shape (t, b)
  ts = einop(ts, 't -> t b', b=x0.shape[0])

  # NOTE(review): presumably nnx.scan iterates over the leading axis of `ts`
  # with `x` as the carry, like jax.lax.scan; confirm against the nnx version
  # in use (experimental API).
  @partial(nnx.scan, state_axes={}, split_rngs=True)
  def scan_fn(
    x: jax.Array, t: jax.Array, model: nnx.Module, schedule: DDPMSchedule
  ):
    # one reverse step: predict the noise, then denoise with the schedule
    pred_noise = model(x, t)
    x = schedule.reverse(x, pred_noise, t)
    if return_all:
      return x, x
    else:
      # emit nothing per step when only the final result is wanted
      return x, None

  x, xs = scan_fn(x0, ts, model, schedule)

  if xs is not None:
    return xs
  else:
    return x


def render_image(x, ax=None):
  """Draw an image whose values are in [-1, 1] onto `ax` with the axes hidden.

  Args:
    x: image array; a trailing axis of size 1 is dropped and drawn in gray.
    ax: target axes; the current matplotlib axes are used when omitted.
  """
  ax = plt.gca() if ax is None else ax

  is_single_channel = x.shape[-1] == 1
  if is_single_channel:
    x = x[..., 0]

  # [-1, 1] -> [0, 255], clipped and stored as bytes
  pixels = np.clip((x / 2.0 + 0.5) * 255, 0, 255).astype(np.uint8)
  ax.imshow(pixels, cmap='gray' if is_single_channel else None)
  ax.axis('off')

In [None]:
class Baseline(nnx.Module):
  """Parameter-free stand-in model that points away from one fixed image.

  Its "noise" prediction is ten times the difference between the input and
  the stored `sample`, ignoring the timestep.
  """

  def __init__(self, sample: jax.Array):
    self.sample = sample

  def __call__(self, x, t):
    # add a leading batch axis to the stored image so it broadcasts over x
    target = self.sample[None]
    offset = x - target
    return 10.0 * offset


# Check the sampler with the parameter-free baseline built from the first image.
baseline = Baseline(x_sample[0])

# start from pure noise shaped like the data batch
x0 = jax.random.normal(rngs.sample(), x_sample.shape)
# timesteps in descending order: timesteps - 1, ..., 0
ts = jnp.flip(jnp.arange(config.timesteps))
xf = sample(baseline, x0, ts, schedule, return_all=False)

nnx.display(xf[0])
render_image(xf[0])

In [None]:
from functools import partial
from flax.experimental import nnx


def MLP(din, dmid, dout, rngs):
  """Build a two-layer perceptron: Linear(din, dmid) -> relu -> Linear(dmid, dout)."""
  # the layers are constructed in forward order, each drawing from `rngs`
  layers = [
    nnx.Linear(din, dmid, rngs=rngs),
    nnx.relu,
    nnx.Linear(dmid, dout, rngs=rngs),
  ]
  return nnx.Sequential(*layers)


class MixerBlock(nnx.Module):
  """One mixer block: a residual MLP across patches, then one across channels.

  Each MLP is preceded by a LayerNorm over the axis it mixes.
  """

  def __init__(
    self, num_patches, hidden_size, mix_patch_size, mix_hidden_size, *, rngs
  ):
    self.patch_mixer = MLP(
      din=num_patches, dmid=mix_patch_size, dout=num_patches, rngs=rngs
    )
    self.hidden_mixer = MLP(
      din=hidden_size, dmid=mix_hidden_size, dout=hidden_size, rngs=rngs
    )
    self.norm1 = nnx.LayerNorm(num_patches, rngs=rngs)
    self.norm2 = nnx.LayerNorm(hidden_size, rngs=rngs)

  def __call__(self, y):
    """Apply the block to `y`, whose last two axes are (patches, channels)."""
    # put the patch axis last so the first MLP mixes across patches
    by_channel = jnp.swapaxes(y, -1, -2)
    by_channel = by_channel + self.patch_mixer(self.norm1(by_channel))
    # restore (patches, channels) and mix across channels
    by_patch = jnp.swapaxes(by_channel, -1, -2)
    return by_patch + self.hidden_mixer(self.norm2(by_patch))


class Mixer2d(nnx.Module):
  """Mixer-style network over image patches, conditioned on the timestep.

  A strided convolution turns the image (plus one extra timestep channel)
  into patch tokens, a stack of `MixerBlock`s processes the tokens, and a
  transposed convolution maps them back to an image with the input's
  channel count.
  """

  def __init__(
    self,
    img_size,
    patch_size,
    hidden_size,
    mix_patch_size,
    mix_hidden_size,
    num_blocks,
    t1,
    *,
    rngs,
  ):
    height, width, channels = img_size
    assert (height % patch_size) == 0
    assert (width % patch_size) == 0
    rows, cols = height // patch_size, width // patch_size
    num_patches = rows * cols
    patch_shape = (patch_size, patch_size)

    # one extra input channel carries the timestep
    self.conv_in = nnx.Conv(
      channels + 1,
      hidden_size,
      kernel_size=patch_shape,
      strides=patch_shape,
      rngs=rngs,
    )
    self.conv_out = nnx.ConvTranspose(
      hidden_size,
      channels,
      kernel_size=patch_shape,
      strides=patch_shape,
      rngs=rngs,
    )
    self.blocks = [
      MixerBlock(
        num_patches, hidden_size, mix_patch_size, mix_hidden_size, rngs=rngs
      )
      for _ in range(num_blocks)
    ]
    # normalizes over the patch axis (axis -2) of the token tensor
    self.norm = nnx.LayerNorm(
      num_patches, reduction_axes=-2, feature_axes=-2, rngs=rngs
    )
    self.t1 = t1

  def __call__(self, x, t):
    """Map a 4-D image batch `x` and per-sample timesteps `t` to an image batch."""
    _, height, width, _ = x.shape
    # scale the timestep by t1 and paint it as a constant extra channel
    time_plane = einop(t / self.t1, 'b -> b h w 1', h=height, w=width)
    tokens = self.conv_in(jnp.concatenate([x, time_plane], axis=-1))
    _, grid_h, grid_w, _ = tokens.shape
    tokens = einop(tokens, 'b h w c -> b (h w) c')
    for block in self.blocks:
      tokens = block(tokens)
    tokens = self.norm(tokens)
    feature_map = einop(tokens, 'b (h w) c -> b h w c', h=grid_h, w=grid_w)
    return self.conv_out(feature_map)


# Build the noise-prediction model. The image shape is read from the data
# batch and the patch size from the config (rather than hard-coding
# (128, 128, 3) and 8) so the model follows the selected dataset.
model = Mixer2d(
  img_size=tuple(x_sample.shape[1:]),
  patch_size=config.patch_size,
  hidden_size=64,
  mix_patch_size=256,
  mix_hidden_size=256,
  num_blocks=4,
  t1=10.0,
  rngs=rngs,
)
# smoke test: one forward pass on the sample batch at a constant timestep
t_sample = jnp.full((x_sample.shape[0],), 100)
y_sample = model(x_sample, t_sample)

nnx.display(y_sample[0])
nnx.display(model)

In [None]:
from typing import Generic, TypeVar

M = TypeVar('M', bound=nnx.Module)


class EMA(nnx.GraphNode, Generic[M]):
  """Keeps an exponential moving average of a model's `nnx.Param` state.

  The wrapped model is split into its graph definition, its parameters and
  everything else; only the parameters are averaged.
  """

  def __init__(self, model: M, decay: float):
    self.decay = decay
    self.model = model
    graphdef, params, rest = nnx.split(model, nnx.Param, ...)
    self.graphdef = graphdef
    # copy arrays to avoid aliasing
    self.ema_params = jax.tree.map(jnp.array, params)
    self.rest = rest

  def ema_model(self):
    """Return a model rebuilt from the averaged params and the remaining state."""
    return nnx.merge(self.graphdef, self.ema_params, self.rest)

  @nnx.jit
  def update(self):
    """Blend the model's current params into the running average."""
    decay = self.decay
    live_params = nnx.state(self.model, nnx.Param)
    self.ema_params: nnx.State = jax.tree.map(
      lambda avg, new: decay * avg + (1.0 - decay) * new,
      self.ema_params,
      live_params,
    )


# Wrap the model in an EMA tracker using the configured decay.
ema = EMA(model, config.ema_decay)

# Show the structure of the EMA object.
nnx.display(ema)

In [None]:
import optax

# Learning rate: starts at lr_start and is multiplied by drop_1_mult after one
# third of the run and by drop_2_mult after two thirds.
lr_boundaries_and_scales = {
  int(config.total_steps * 1 / 3): config.drop_1_mult,
  int(config.total_steps * 2 / 3): config.drop_2_mult,
}
lr_schedule = optax.piecewise_constant_schedule(
  config.lr_start, lr_boundaries_and_scales
)

# clip gradients to global norm 1.0 before the AdamW update
tx = optax.chain(optax.clip_by_global_norm(1.0), optax.adamw(lr_schedule))
optimizer = nnx.Optimizer(model, tx)

In [None]:
from IPython.display import clear_output

# Elementwise error reductions, keyed by the name used in the config.
_loss_metrics = {
  'mse': lambda a, b: jnp.mean((a - b) ** 2),
  'mae': lambda a, b: jnp.mean(jnp.abs(a - b)),
}
if config.loss_type not in _loss_metrics:
  raise ValueError(f'Unknown loss type {config.loss_type}')
loss_metric = _loss_metrics[config.loss_type]


def loss_fn(model: Mixer2d, xt, t, noise):
  """Score the model's noise prediction for `xt` at `t` against the true `noise`."""
  return loss_metric(noise, model(xt, t))


@nnx.jit
def train_step(
  model: Mixer2d,
  ema: EMA[Mixer2d],
  optimizer: nnx.Optimizer,
  schedule: DDPMSchedule,
  x: jax.Array,
  rngs: nnx.Rngs,
):
  """Run one optimization step on batch `x`.

  Draws a random timestep per sample, noises the batch with `schedule`,
  computes the loss and gradients for `model`, evaluates the same loss for
  the EMA model, and applies the gradients through `optimizer`.

  Returns:
    A dict with the entries ``'loss'`` and ``'ema_loss'``.
  """
  # plain Python print: it runs when the function body is traced
  print("compiling 'train_step' ...")
  ema_model = ema.ema_model()
  batch_size = x.shape[0]
  # one timestep per sample, drawn from [0, config.timesteps)
  t = jax.random.randint(
    rngs.schedule(),
    (batch_size,),
    minval=0,
    maxval=config.timesteps,
    dtype=jnp.uint32,
  )
  xt, noise = schedule.forward(x, t)

  grad_fn = nnx.value_and_grad(loss_fn)
  loss, grads = grad_fn(model, xt, t, noise)
  ema_loss = loss_fn(ema_model, xt, t, noise)

  optimizer.update(grads)
  return dict(loss=loss, ema_loss=ema_loss)


# %%
import numpy as np
from tqdm import tqdm
from IPython import display

print(jax.devices())

# Loop state is kept at module level and the loop starts from `step`, so the
# loop section below can be re-run to continue without resetting the history
# or the figures.
axs_diffusion = None
axs_metrics = None
ds_iterator = ds.as_numpy_iterator()
logs = {}
history = {}
step = 0
disp_diffusion = None
disp_metrics = None

# %%

# `initial=step` keeps the progress bar consistent with `total` when the loop
# is resumed from a non-zero step.
for step in tqdm(
  range(step, config.total_steps),
  initial=step,
  total=config.total_steps,
  unit='step',
):
  if step % config.eval_every == 0:
    # --------------------
    # visualize progress
    # --------------------
    n_rows = 3
    n_cols = 5
    # fixed key: every evaluation starts from the same initial noise
    viz_key = jax.random.PRNGKey(1)
    x0 = jax.random.normal(viz_key, (n_rows * n_cols, *x_sample.shape[1:]))

    ts = np.arange(config.timesteps)[::-1]
    xf = sample(model, x0, ts, schedule, return_all=False)
    xf = np.asarray(xf)
    xf = einop(
      xf, '(row col) h w c -> (row h) (col w) c', row=n_rows, col=n_cols
    )

    if axs_diffusion is None:
      # create the figure and its display handle once, then update in place
      plt.figure(figsize=(3 * n_cols, 3 * n_rows))
      disp_diffusion = display.display("diffusion", display_id=True)
      axs_diffusion = plt.gca()

    axs_diffusion.clear()
    # map from [-1, 1] to [0, 1] and clip for imshow
    xf = (xf + 1.0) / 2.0
    xf = np.clip(xf, 0, 1)
    axs_diffusion.imshow(xf)
    disp_diffusion.update(axs_diffusion.figure)
  # --------------------
  # training step
  # --------------------
  x = next(ds_iterator)
  logs = train_step(model, ema, optimizer, schedule, x, rngs)
  logs = jax.tree.map(np.asarray, logs)  # convert to numpy
  for k, v in logs.items():
    history.setdefault(k, []).append(v)

  if (
    step >= config.ema_update_after_step and step % config.ema_update_every == 0
  ):
    ema.update()

  if step % config.log_every == 0:
    if axs_metrics is None:
      plt.figure(figsize=(12, 6))
      disp_metrics = display.display('metrics', display_id=True)
      axs_metrics = plt.gca()

    axs_metrics.clear()
    for k, v in history.items():
      axs_metrics.plot(v, label=k)
    # the curves are labelled, so show which is which
    axs_metrics.legend()

    disp_metrics.update(axs_metrics.figure)

# %%
# Sample a grid of images from the model after training.
n_rows, n_cols = 3, 5
viz_key = jax.random.PRNGKey(1)
noise_shape = (n_rows * n_cols, *x_sample.shape[1:])
# `t` holds the initial noise batch that the sampler starts from
t = jax.random.normal(viz_key, noise_shape)

# timesteps in descending order
ts = np.arange(config.timesteps)[::-1]
xf = np.asarray(sample(model, t, ts, schedule, return_all=False))
# tile the batch into an n_rows x n_cols grid image
xf = einop(xf, '(row col) h w c -> (row h) (col w) c', row=n_rows, col=n_cols)

plt.figure(figsize=(3 * n_cols, 3 * n_rows))
render_image(xf)

In [None]:
# Sample a batch from the model and show the first image.
n_rows = 3
n_cols = 5
viz_key = jax.random.PRNGKey(1)
x0 = jax.random.normal(viz_key, (n_rows * n_cols, *x_sample.shape[1:]))

ts = np.arange(config.timesteps)[::-1]
xf = sample(model, x0, ts, schedule, return_all=False)
xf = np.asarray(xf)
# the samples are on the [-1, 1] data scale; render_image maps them to
# 0..255 bytes, whereas a bare imshow would clip everything below 0
render_image(xf[0])

In [None]:
# Debug view of column 1 from the schedule demo: the noised image, the noise
# that was drawn for it, and their difference, each in its own figure.
_debug_panels = (x_data[1], x_data_noise[1], x_data[1] - x_data_noise[1])
for _panel_index, _panel in enumerate(_debug_panels):
  if _panel_index > 0:
    plt.figure()
  plt.imshow(_panel)