In [1]:
from argparse import Namespace
from tqdm import tqdm
from typing import List, Sequence, Tuple

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jrandom
import numpy as np
import matplotlib.pyplot as plt
%matplotlib notebook
import optax

from jax_learning.losses.supervised_loss import squared_loss
from jax_learning.models.layers import MLP

In [2]:
seed = 0
np.random.seed(seed)

# Noise Generation

In [3]:
betas = jnp.array((0.1, 0.1, 0.1))
alphas = 1 - betas
alpha_cum_prods = jnp.cumprod(alphas)

sqrt_alpha_cumprods = jnp.sqrt(alpha_cum_prods)
sqrt_one_minus_alpha_cumprods = jnp.sqrt(1 - alpha_cum_prods)

def add_noise(x_init: np.ndarray, noise: np.ndarray, t: np.int32):
    return sqrt_alpha_cumprods[t, None] * x_init + sqrt_one_minus_alpha_cumprods[t, None] * noise

# Model Creation

In [4]:
class Model(eqx.Module):
    obs_dim: int = eqx.static_field()
    model: eqx.Module

    def __init__(
        self,
        obs_dim: Sequence[int],
        context_dim: Sequence[int],
        hidden_dim: int,
        num_hidden: int,
        key: jrandom.PRNGKey,
    ):
        self.obs_dim = obs_dim
        self.model = MLP(self.obs_dim + context_dim + 1, self.obs_dim, hidden_dim, num_hidden, key)
        
    @jax.jit
    def predict(
        self,
        x: np.ndarray,
        t: float,
    ) -> np.ndarray:
        x_t = jnp.concatenate((x, jnp.array([t])))
        return self.model(x_t)

# Download MNIST
Reference: https://github.com/hsjeong5/MNIST-for-Numpy/blob/master/mnist.py

In [5]:
import gzip
import numpy as np
import os
import pickle

from urllib import request

filename = [
["training_images","train-images-idx3-ubyte.gz"],
["test_images","t10k-images-idx3-ubyte.gz"],
["training_labels","train-labels-idx1-ubyte.gz"],
["test_labels","t10k-labels-idx1-ubyte.gz"]
]

def download_mnist():
    base_url = "http://yann.lecun.com/exdb/mnist/"
    for name in filename:
        print("Downloading "+name[1]+"...")
        request.urlretrieve(base_url+name[1], name[1])
    print("Download complete.")

def save_mnist():
    mnist = {}
    for name in filename[:2]:
        with gzip.open(name[1], 'rb') as f:
            mnist[name[0]] = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1,28*28)
    for name in filename[-2:]:
        with gzip.open(name[1], 'rb') as f:
            mnist[name[0]] = np.frombuffer(f.read(), np.uint8, offset=8)
    with open("mnist.pkl", 'wb') as f:
        pickle.dump(mnist,f)
    print("Save complete.")

def init():
    download_mnist()
    save_mnist()

def load():
    with open("mnist.pkl",'rb') as f:
        mnist = pickle.load(f)
    return mnist["training_images"], mnist["training_labels"], mnist["test_images"], mnist["test_labels"]

if not os.path.isfile("mnist.pkl"):
    init()
    
print("loading MNIST dataset")
(train_x, train_y, test_x, test_y) = load()
train_x = train_x.astype(np.float) / 255
test_x = test_x.astype(np.float) / 255
print(train_x.shape, train_y.shape)
print(test_x.shape, test_y.shape)

loading MNIST dataset
(60000, 784) (60000,)
(10000, 784) (10000,)


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  train_x = train_x.astype(np.float) / 255
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  test_x = test_x.astype(np.float) / 255


# DDPM

In [6]:
@eqx.filter_grad(has_aux=True)
def compute_loss(model: Model, xs: np.ndarray, timesteps: np.ndarray, targs: np.ndarray):
    preds = jax.vmap(model.predict)(xs, timesteps)
    loss = jnp.sum(jax.vmap(squared_loss)(preds, targs)) / len(xs)
    return loss, {"loss": loss}

In [7]:
def train(
    train_x: np.ndarray,
    train_c: np.ndarray,
    model: Model,
    opt: optax.GradientTransformation,
    cfg: Namespace
) -> Tuple[Model, List]:
    """
    train_x: data to reconstruct
    train_c: data context
    model: the diffusion model
    opt: optimizer for changing the parameters of the model
    """
    opt_state = opt.init(model)
    losses = []
    
    for i in tqdm(range(cfg.num_iterations)):
        train_idxes = np.random.permutation(len(train_x))[:cfg.batch_size]
        curr_c = train_c[train_idxes]
        curr_x = train_x[train_idxes]
        timesteps = np.random.randint(cfg.max_t, size=cfg.batch_size)
        
        noise = np.random.randn(*curr_x.shape)
        noisy_x = add_noise(curr_x, noise, timesteps)
        noisy_x = jnp.concatenate((noisy_x, curr_c[:, None]), axis=1)
        grads, info = compute_loss(model, noisy_x, timesteps, noise)
        updates, opt_state = opt.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        losses.append(info["loss"])
        
    return model, losses

In [None]:
obs_dim = int(np.product(train_x.shape[1:]))
context_dim = 1
hidden_dim = 256
num_hidden = 4
key = jrandom.PRNGKey(seed)

model = Model(obs_dim, context_dim, hidden_dim, num_hidden, key)

lr = 3e-4
opt_transforms = [optax.scale_by_rms(), optax.scale(-lr)]
opt = optax.chain(*opt_transforms)

cfg_dict = {
    "num_iterations": 50000,
    "batch_size": 512,
    "max_t": 20,
}
cfg = Namespace(**cfg_dict)

trained_model, losses = train(train_x, train_y, model, opt, cfg)

 14%|███████▋                                                | 6897/50000 [03:03<18:59, 37.82it/s]

In [None]:
plt.plot(np.arange(len(losses)), losses)
plt.xlabel("Number of updates")
plt.ylabel("Loss")
plt.show()

In [None]:
pred = trained_model.predict(jnp.concatenate((np.random.randn(*test_x[0].shape), test_y[[0]])), 9)

In [None]:
plt.imshow(pred.reshape((28, 28)))
plt.show()