In [None]:
!pip install -qr requirements.txt

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

from ddim.models.unet import UNet
from ddim.diffusion.parameters import linear_beta_schedule, diff_params_from_betas
from ddim.diffusion.inference import denoising_loop
from ddim.training.steps import get_train_loop
from ddim.training.ema import EMA
from ddim.training.data import TimestepSampler, DataSampler

import numpy as np
import random
import pickle

import jax
from jax import numpy as jnp
from jax.experimental.pjit import pjit, PartitionSpec
from jax.experimental.maps import mesh
import optax

import jmp

from torchvision.datasets import MNIST

In [None]:
# Basic JAX setup: root PRNG key, device grid, and mixed-precision policy.
prng = jax.random.PRNGKey(1337)

# Lay the visible devices out as a 1 x 4 x 2 grid. The reshape raises unless
# exactly 8 devices are available.
devices = np.array(jax.devices()).reshape(1, 4, 2)

# Parameters, compute and outputs all use bfloat16.
precision_policy = jmp.Policy(
    param_dtype=jnp.bfloat16,
    compute_dtype=jnp.bfloat16,
    output_dtype=jnp.bfloat16,
)

In [None]:
# Load MNIST digits and transform to -1 -> 1 range
mnist_dataset = MNIST('/tmp/mnist/', download=True)

# Element 0 of each dataset item is the image; give it an explicit trailing
# channel axis. The conversion and stacking happen on the host with NumPy so
# that a single array is handed to JAX, instead of creating one small device
# array per digit and then re-assembling them from a Python list.
mnist_pixels = np.stack(
    [np.asarray(mnist_image[0]).reshape(28, 28, 1) for mnist_image in mnist_dataset]
)

# x / 127.5 - 1 maps 0 -> -1 and 255 -> 1 (assumes 8-bit pixel values).
mnist_data = jnp.asarray(mnist_pixels) / 127.5 - 1.0

# Sanity check: show the first digit. The channel axis is indexed away so
# imshow receives a plain 2-D array.
plt.imshow(mnist_data[0, :, :, 0])
plt.colorbar()

In [None]:
# Set up model; its dtype follows the precision policy's compute dtype.
model = UNet(
    dims = 2,
    model_channels = 64,
    channel_mult = (1, 2),
    use_scale_shift_norm = True,
    dropout = 0.0,
    num_head_channels = 8,
    num_res_blocks = 1,
    attention_resolutions = (2,),
    out_channels = 1,
    dtype = precision_policy.compute_dtype
)

# Initialize parameters from dummy inputs: a single 28x28x1 image and a
# length-1 vector for the model's second input (presumably the timestep --
# confirm against UNet's call signature).
image_in = jnp.zeros((1, 28, 28, 1))
embed_in = jnp.zeros((1,))
init_pjit = pjit(
    model.init,
    [None, PartitionSpec("batch", "x", "y"), PartitionSpec("batch")],
    PartitionSpec(None),
)
with mesh(devices, ('batch', 'x', 'y')):
    # Reuse the dummy inputs defined above instead of rebuilding identical zeros.
    params = init_pjit(prng, image_in, embed_in)

# Total number of scalar parameters across every leaf of the parameter tree.
# `.size` gives the element count directly, without flattening each array.
param_count = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))
print(f"param count: {param_count}")

# Initialize diffusion: parameters derived from a 1000-entry linear beta schedule.
diff_params = diff_params_from_betas(linear_beta_schedule(1000))

In [None]:
# Training configuration: batch size, EMA helper, and the two batch samplers.
batch_size = 32
ema = EMA()
data_sampler = DataSampler(mnist_data, batch_size)
# 1000 presumably matches the length of the beta schedule -- confirm.
timestep_sampler = TimestepSampler(1000, batch_size)

# Optimizer: Adam's update scaling followed by a constant negative scale,
# i.e. a descent step with a fixed learning rate.
learning_rate = 1e-4
adam_scaling = optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8)
descent_step = optax.scale(-learning_rate)
opt = optax.chain(adam_scaling, descent_step)
opt_params = opt.init(params)

# Each call to the resulting train loop runs 10000 batches.
train_loop = get_train_loop(
    opt, model, diff_params, data_sampler, timestep_sampler, ema,
    how_many=10000,
)

In [None]:
# Train: alternate between reporting progress and running one training round.
n_rounds = 10              # number of train_loop invocations
batches_per_round = 10000  # must equal the `how_many` the train loop was built with

loss = 0  # placeholder shown before the first round has produced a real loss
with mesh(devices, ('batch', 'x', 'y')):
    # One pass more than the number of training rounds, so that the loss and
    # parameters produced by the final round are also previewed, printed and
    # checkpointed instead of being silently dropped.
    for i in range(n_rounds + 1):
        step = i * batches_per_round

        # Progress shot: denoise one image starting from Gaussian noise.
        # getrandbits(32) stays within the 32-bit seed range, whereas
        # randint(0, 2**32) includes its upper bound and can overshoot it.
        prng_img = jax.random.PRNGKey(random.getrandbits(32))
        images_in = jax.random.normal(prng_img, (1, 28, 28, 1))
        out = denoising_loop(model, params, diff_params, images_in.astype(jnp.float32))
        # Shift/scale for display: maps -1 -> 0 and 1 -> 1.
        image = (out[0, :, :, :] + 1.0) / 2.0
        plt.figure(figsize=(8, 8))
        plt.imshow(image)
        plt.show()

        # Print the number of batches trained so far and the latest loss
        print(f"e: {step}, l: {loss}")

        # Save parameters
        with open(f"params_{step}.pkl", "wb") as f:
            pickle.dump(params, f)

        # The extra final pass only reports; it does not train again.
        if i == n_rounds:
            break

        # Run one round of training batches
        prng, params, opt_params, loss = train_loop(prng, params, opt_params)