In [None]:
!pip install -qr requirements.txt

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

from ddim.models.unet import UNet
from ddim.models.model_utils import load_params, get_default_channel_mult
from ddim.diffusion.parameters import linear_beta_schedule, diff_params_from_betas
from ddim.diffusion.inference import denoising_loop

import numpy as np
import random

import jax
from jax import numpy as jnp
from jax.experimental.pjit import pjit, PartitionSpec
from jax.experimental.maps import mesh

import jmp

In [None]:
# Jax setup
# Arrange the accelerators into a 3-D logical mesh whose axes are named
# ('batch', 'x', 'y') when the mesh is entered. The batch axis has size 1 (not
# split); the two image axes are split 4-way and 2-way, so exactly
# prod(MESH_SHAPE) devices are required.
MESH_SHAPE = (1, 4, 2)
_all_devices = jax.devices()
_required_devices = int(np.prod(MESH_SHAPE))
if len(_all_devices) != _required_devices:
    # Fail with an actionable message instead of numpy's generic reshape error.
    raise RuntimeError(
        f"This notebook expects {_required_devices} devices for a {MESH_SHAPE} "
        f"mesh, but JAX found {len(_all_devices)}."
    )
devices = np.array(_all_devices).reshape(MESH_SHAPE)

# Mixed precision: parameters, computation and outputs all use bfloat16.
precision_policy = jmp.Policy(
    compute_dtype=jnp.bfloat16,
    param_dtype=jnp.bfloat16,
    output_dtype=jnp.bfloat16,
)

In [None]:
# Create model, load parameters, set up diffusion
# Resolution the UNet is configured for; must match the checkpoint below
# (its file name suggests 128 px — confirm if swapping checkpoints).
image_size = 128
PARAMS_PATH = "params_danbooru2019_128_300000.pkl"
# Argument to `linear_beta_schedule` (presumably the number of diffusion
# timesteps — confirm against ddim.diffusion.parameters).
BETA_SCHEDULE_STEPS = 1000

model = UNet(
    dims=2,
    model_channels=256,
    channel_mult=get_default_channel_mult(image_size),
    use_scale_shift_norm=True,
    dropout=0.0,  # inference only: no dropout
    num_head_channels=64,
    num_res_blocks=2,
    attention_resolutions=(32, 16, 8),
    out_channels=3,
    dtype=precision_policy.compute_dtype,
)

# NOTE(review): if `load_params` unpickles this file, only load checkpoints
# from a trusted source — unpickling can execute arbitrary code.
params = load_params(PARAMS_PATH, devices, precision_policy)
diff_params = diff_params_from_betas(linear_beta_schedule(BETA_SCHEDULE_STEPS))

In [None]:
# pjit the denoising loop
# Model parameters and diffusion constants are passed with PartitionSpec(None)
# (not partitioned, i.e. replicated); the image batch and the result are
# sharded over the ('batch', 'x', 'y') mesh axes.
_replicated_spec = PartitionSpec(None)
_image_spec = PartitionSpec("batch", "x", "y")


def _denoise_with_model(params, diff_params, images):
    """Run the reverse-diffusion loop using the notebook-level `model`."""
    return denoising_loop(model, params, diff_params, images)


# TODO: static_argnums=(0,) was considered for this pjit call; not enabled.
denoising_loop_pjit = pjit(
    _denoise_with_model,
    [_replicated_spec, _replicated_spec, _image_spec],
    _image_spec,
)

In [None]:
# Single image
# Denoise one Gaussian-noise image with the plain (non-pjit) loop.
# `random.randint` includes BOTH endpoints, so the upper bound is 2**32 - 1;
# 2**32 itself does not fit in a 32-bit seed.
seed = random.randint(0, 2**32 - 1)
prng_img = jax.random.PRNGKey(seed)
images_in = jax.random.normal(prng_img, (1, image_size, image_size, 3))
with mesh(devices, ('batch', 'x', 'y')):
    out = denoising_loop(model, params, diff_params, images_in.astype(jnp.float32))

# Take the single sample explicitly (index 0) rather than relying on a loop
# variable left over from another cell. Map [-1, 1] -> [0, 255] and clip first:
# nothing visible here guarantees the sampler stays in [-1, 1], and casting
# out-of-range floats to uint8 wraps around instead of saturating.
image = np.clip(
    (np.asarray(out[0], dtype=np.float32) + 1.0) * 127.5, 0, 255
).astype(np.uint8)
plt.figure(figsize=(8, 8))
plt.imshow(image)
plt.title(f"seed = {seed}")  # record the seed so the sample can be regenerated
plt.show()

In [None]:
# Batch inference
# Denoise a grid of noise images in parallel with the pjit-compiled loop.
temp = 0.9  # scales the initial noise; < 1.0 starts from lower-variance noise
grid_size = 5
num_images = grid_size * grid_size
# `random.randint` includes both endpoints; 2**32 - 1 keeps the seed in 32 bits.
seed = random.randint(0, 2**32 - 1)
prng_img = jax.random.PRNGKey(seed)
images_in = jax.random.normal(prng_img, (num_images, image_size, image_size, 3)) * temp
with mesh(devices, ('batch', 'x', 'y')):
    out = denoising_loop_pjit(params, diff_params, images_in.astype(jnp.float32))

# One device->host transfer for the whole batch, then map [-1, 1] -> [0, 255]
# with clipping so out-of-range values saturate instead of wrapping in uint8.
images_out = np.clip(
    (np.asarray(out, dtype=np.float32) + 1.0) * 127.5, 0, 255
).astype(np.uint8)

fig, axes = plt.subplots(grid_size, grid_size, figsize=(16, 16), squeeze=False)
for ax, image in zip(axes.flat, images_out):
    # RGB images: imshow ignores `cmap` for 3-channel data, so none is passed.
    ax.imshow(image)
    ax.set_xticks([])
    ax.set_yticks([])
fig.tight_layout()
fig.savefig("out.png")
plt.show()