In [None]:
!pip install -qr requirements.txt

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

from ddim.models.unet import UNet
from ddim.models.model_utils import get_default_channel_mult
from ddim.diffusion.parameters import linear_beta_schedule, diff_params_from_betas
from ddim.diffusion.inference import denoising_loop
from ddim.training.steps import get_train_loop
from ddim.training.ema import EMA
from ddim.training.data import TimestepSampler, ImageDataset, DatasetSampler

import numpy as np
import random
import pickle

import jax
from jax import numpy as jnp
from jax.experimental.pjit import pjit, PartitionSpec
from jax.experimental.maps import mesh
import optax

import jmp
import time

In [None]:
# Basic jax mesh setup

# Root PRNG key for the notebook, built from a fixed seed.
prng = jax.random.PRNGKey(1337)

# Arrange the visible devices on a 1 x 4 x 2 grid. The reshape needs exactly
# eight devices and raises for any other device count.
devices = np.asarray(jax.devices()).reshape(1, 4, 2)

# Precision policy: bfloat16 for parameters, compute and outputs alike.
precision_policy = jmp.Policy(
    param_dtype=jnp.bfloat16,
    compute_dtype=jnp.bfloat16,
    output_dtype=jnp.bfloat16,
)

In [None]:
# Set up model
image_size = 128
model = UNet(
    dims=2,
    out_channels=3,
    model_channels=256,
    channel_mult=get_default_channel_mult(image_size),
    num_res_blocks=2,
    num_head_channels=64,
    attention_resolutions=(32, 16, 8),
    use_scale_shift_norm=True,
    dropout=0.0,
    dtype=precision_policy.compute_dtype,
)

# Initialize parameters
# model.init is wrapped in pjit: the PRNG key is left unpartitioned, the image
# batch is split over ("batch", "x", "y"), the timestep vector over "batch",
# and the output uses PartitionSpec(None).
init_pjit = pjit(
    model.init,
    [None, PartitionSpec("batch", "x", "y"), PartitionSpec("batch")],
    PartitionSpec(None),
)
# All-zero placeholder inputs for a single image / single timestep.
dummy_images = jnp.zeros((1, image_size, image_size, 3))
dummy_timesteps = jnp.zeros((1,))
# NOTE(review): `prng` is consumed here and is also the key handed to the
# training loop in the training cell -- consider splitting it if independent
# random streams are wanted.
with mesh(devices, ('batch', 'x', 'y')):
    params = init_pjit(prng, dummy_images, dummy_timesteps)

# Total number of scalar parameters across every leaf of the parameter tree.
param_count = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))
print(f"param count: {param_count}")

# Initialize diffusion
# Diffusion parameters derived from a linear beta schedule built with 1000.
betas = linear_beta_schedule(1000)
diff_params = diff_params_from_betas(betas)

In [None]:
# Set up training
batch_size = 1
train_steps_per_iter = 2_500     # batches run per call to train_loop
train_steps_total = 100_000_000  # overall step budget for the training cell
eval_every = 10_000              # step interval between progress samples
save_every = 100_000             # step interval between parameter checkpoints
ema = EMA()

# Ask the dataset to resize to image_size, except when image_size is 512, in
# which case no resize is requested.
# NOTE(review): presumably 512 is the native resolution of the source images
# -- confirm against the dataset.
resize = image_size if image_size != 512 else None
dataset = ImageDataset("data/portraits/*.jpg", resize=resize)
data_sampler = DatasetSampler(dataset, batch_size=batch_size)
timestep_sampler = TimestepSampler(1000, batch_size)

# Optimizer: Adam moment scaling followed by multiplication with -learning_rate.
learning_rate = 1e-5
adam_scaling = optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8)
opt = optax.chain(adam_scaling, optax.scale(-learning_rate))
# Despite the name, this holds the optimizer state returned by opt.init.
opt_params = opt.init(params)

# Build a train loop that runs for train_steps_per_iter batches
train_loop = get_train_loop(
    opt,
    model,
    diff_params,
    data_sampler,
    timestep_sampler,
    ema,
    how_many=train_steps_per_iter,
)

In [None]:
# Train
#
# Each outer iteration runs `train_steps_per_iter` batches via `train_loop`.
# Before each chunk the cell optionally renders a progress sample, logs the
# loss returned by the previous chunk, and optionally checkpoints the
# parameters.
loss = 0  # placeholder for the first log line, printed before any training
start = time.time()
num_iters = train_steps_total // train_steps_per_iter
with mesh(devices, ('batch', 'x', 'y')):
    for i in range(num_iters):
        # Number of training steps completed before this chunk starts.
        steps_done = i * train_steps_per_iter

        if steps_done % eval_every == 0:
            # Progress shot: denoise one image starting from fresh Gaussian noise.
            prng_img = jax.random.PRNGKey(random.randint(0, 2**32))
            images_in = jax.random.normal(prng_img, (1, image_size, image_size, 3))
            # NOTE(review): timesteps_in is not passed to denoising_loop in this
            # cell; kept in case a later cell relies on it.
            timesteps_in = jnp.array([0,], dtype=jnp.int32)
            out = denoising_loop(model, params, diff_params, images_in.astype(jnp.float32))
            # Map [-1, 1] -> [0, 255] in float32 on the host and clip before the
            # uint8 cast; without the clip, samples slightly outside [-1, 1]
            # wrap around (e.g. 256 -> 0) and show up as speckle artifacts.
            sample = np.array(out[0, :, :, :]).astype(np.float32)
            image = np.clip((sample + 1.0) * 127.5, 0.0, 255.0).astype(np.uint8)
            plt.figure(figsize=(8, 8))
            plt.imshow(image)
            plt.show()

        # Print loss
        # Average throughput since `start` (includes sampling and saving time).
        # Round the final value rather than rounding first and scaling by `i`.
        elapsed = max(time.time() - start, 1e-9)
        img_s = round(steps_done * batch_size / elapsed, 2)
        print(f"e: {steps_done}, l: {loss}, img/s: {img_s}")

        # Save parameters
        if steps_done % save_every == 0:
            with open(f"params_{steps_done}.pkl", "wb") as f:
                pickle.dump(params, f)

        # Run train_steps_per_iter training batches
        prng, params, opt_params, loss = train_loop(prng, params, opt_params)

    # Checkpoints above are written *before* a chunk is trained, so the
    # parameters produced by the last chunk would otherwise never be saved.
    if num_iters > 0:
        with open(f"params_{num_iters * train_steps_per_iter}.pkl", "wb") as f:
            pickle.dump(params, f)