In [1]:
# Print the version of the CUDA compiler driver (nvcc) available on this machine.
! nvcc -V

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2021 NVIDIA Corporation
Built on Mon_May__3_19:15:13_PDT_2021
Cuda compilation tools, release 11.3, V11.3.109
Build cuda_11.3.r11.3/compiler.29920130_0


In [2]:
try:
    import jax
except ModuleNotFoundError:
    ! pip install --user --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
    ! pip install --user flax
    
! pip list | grep ax    

Looking in links: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Collecting jax[cuda]
  Using cached jax-0.3.14-py3-none-any.whl
Collecting etils[epath]
  Using cached etils-0.6.0-py3-none-any.whl (98 kB)
Collecting jaxlib==0.3.14+cuda11.cudnn82
  Using cached https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.3.14%2Bcuda11.cudnn82-cp37-none-manylinux2014_x86_64.whl (161.9 MB)
Installing collected packages: etils, jaxlib, jax
Successfully installed etils-0.6.0 jax-0.3.14 jaxlib-0.3.14+cuda11.cudnn82
flax                                  0.5.2
jax                                   0.3.14
jaxlib                                0.3.14+cuda11.cudnn82
jupyter-server-mathjax                0.2.5
optax                                 0.1.3


In [3]:
from tqdm import tqdm
from pathlib import Path
from typing import Tuple, Any
from datetime import datetime
from functools import partial

import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
from flax.training.checkpoints import save_checkpoint
import optax
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

from model import DiffusionModel

In [20]:
def create_output_dir(output_dir: Path) -> Tuple[Path, Path, Path]:
    """Create a timestamped run directory with `models` and `logs` subfolders.

    Args:
        output_dir: Base directory under which the run directory is created.
            Missing parent directories are created as needed.

    Returns:
        A tuple `(run_dir, ckpt_dir, log_dir)` where `run_dir` is
        `output_dir/<YYYYmmdd-HHMMSS>` (current local time), `ckpt_dir` is
        `run_dir/models` and `log_dir` is `run_dir/logs`. All three
        directories exist when the function returns.
    """
    run_dir = output_dir / datetime.now().strftime('%Y%m%d-%H%M%S')
    ckpt_dir = run_dir / 'models'
    log_dir = run_dir / 'logs'

    # Create every leaf unconditionally. Guarding on the run directory alone
    # would skip the subfolders whenever that directory already exists (for
    # example two calls within the same second), handing back paths that were
    # never created.
    for directory in (ckpt_dir, log_dir):
        directory.mkdir(parents=True, exist_ok=True)

    return (run_dir, ckpt_dir, log_dir)


def preprocess_image(data, image_size):
    """Center-crop an image to a square, resize it and scale it to [0, 1].

    Args:
        data: Mapping with an 'image' entry; axis 0 is treated as height and
            axis 1 as width.
        image_size: Side length of the square output image.

    Returns:
        The resized image divided by 255 and clipped to the range [0, 1].
        # assumes input pixel values lie in [0, 255] - TODO confirm
    """
    raw = data['image']
    dims = tf.shape(raw)
    rows, cols = dims[0], dims[1]

    # Largest centered square that fits inside the image.
    side = tf.minimum(rows, cols)
    top = (rows - side) // 2
    left = (cols - side) // 2
    square = tf.image.crop_to_bounding_box(raw, top, left, side, side)

    # Antialiasing matters when downsampling, so it is switched on here.
    resized = tf.image.resize(square,
                              size=(image_size, image_size),
                              antialias=True)
    scaled = resized / 255.0
    return tf.clip_by_value(scaled, 0.0, 1.0)


def prepare_datasets(image_size: int = 64,
                     batch_size: int = 64):
    """Build the training and validation input pipelines.

    The first 80% of each of the `train`, `validation` and `test` splits of
    `oxford_flowers102` form the training set; the remaining 20% of each form
    the validation set. Both pipelines preprocess every example with
    `preprocess_image`, cache the result, batch with incomplete batches
    dropped, and prefetch. Only the training pipeline shuffles examples.

    Args:
        image_size: Side length passed on to `preprocess_image`.
        batch_size: Number of images per batch.

    Returns:
        A tuple `(ds_train, ds_val)` of datasets wrapped by `tfds.as_numpy`.
    """
    dataset_name = 'oxford_flowers102'
    split_train = 'train[:80%]+validation[:80%]+test[:80%]'
    split_val = 'train[80%:]+validation[80%:]+test[80%:]'

    preprocess_fn = partial(preprocess_image, image_size=image_size)

    # Training pipeline: map -> cache -> shuffle -> batch -> prefetch.
    train_pipeline = tfds.load(dataset_name, split=split_train,
                               shuffle_files=True)
    train_pipeline = train_pipeline.map(
        preprocess_fn, num_parallel_calls=tf.data.AUTOTUNE)
    train_pipeline = train_pipeline.cache()
    train_pipeline = train_pipeline.shuffle(buffer_size=10*batch_size)
    train_pipeline = train_pipeline.batch(batch_size, drop_remainder=True)
    train_pipeline = train_pipeline.prefetch(buffer_size=tf.data.AUTOTUNE)

    # Validation pipeline: same steps without the example-level shuffle.
    val_pipeline = tfds.load(dataset_name, split=split_val,
                             shuffle_files=True)
    val_pipeline = val_pipeline.map(
        preprocess_fn, num_parallel_calls=tf.data.AUTOTUNE)
    val_pipeline = val_pipeline.cache()
    val_pipeline = val_pipeline.batch(batch_size, drop_remainder=True)
    val_pipeline = val_pipeline.prefetch(buffer_size=tf.data.AUTOTUNE)

    return tfds.as_numpy(train_pipeline), tfds.as_numpy(val_pipeline)


class TrainState(train_state.TrainState):
    """Flax `TrainState` extended with a `batch_stats` field.

    The extra field lets the 'batch_stats' variable collection travel with
    the parameters and optimizer state, so `train_step` can feed it to the
    model and store the updated values after each step.
    """
    # Current value of the model's 'batch_stats' variable collection.
    batch_stats: Any


def model(**kwargs):
    """Construct a `DiffusionModel`, forwarding all keyword arguments to it."""
    network = DiffusionModel(**kwargs)
    return network


def l1_loss(predictions, targets):
    """Return the element-wise absolute error between predictions and targets.

    No reduction is applied; callers reduce the result themselves
    (for example with `.mean()`).
    """
    difference = predictions - targets
    return jnp.abs(difference)


def kernel_inception_distance():
    """Placeholder for a kernel inception distance metric.

    Raises:
        NotImplementedError: Always; the metric has not been written yet.
    """
    raise NotImplementedError


def update_ema(p_cur, p_new, momentum: float = 0.999):
    """Blend a running average with a new value.

    Args:
        p_cur: Current running average.
        p_new: New value to fold in.
        momentum: Weight kept on `p_cur`; `p_new` receives `1 - momentum`.

    Returns:
        `momentum * p_cur + (1 - momentum) * p_new`.
    """
    kept = momentum*p_cur
    added = (1-momentum)*p_new
    return kept + added


@jax.jit
def train_step(state, batch, rng):
    """Run one jit-compiled optimization step.

    Args:
        state: Train state providing `apply_fn`, `params`, `batch_stats` and
            `apply_gradients`.
        batch: Batch passed to the model as its first input.
        rng: Random key passed to the model as its second input.

    Returns:
        A tuple `(state, loss)` with the updated train state and the scalar
        loss, which is the sum of the mean L1 noise loss and the mean L1
        image loss.
    """
    def compute_loss(params):
        variables = {'params': params, 'batch_stats': state.batch_stats}
        # 'batch_stats' is declared mutable so the forward pass can return
        # its updated values next to the model outputs.
        outputs, updates = state.apply_fn(
            variables, batch, rng, train=True, mutable=['batch_stats'])
        noises, images, pred_noises, pred_images = outputs

        noise_term = l1_loss(pred_noises, noises).mean()
        image_term = l1_loss(pred_images, images).mean()
        return noise_term + image_term, updates

    # has_aux=True: compute_loss returns (loss, aux); only loss is differentiated.
    value_and_grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
    (loss, updates), grads = value_and_grad_fn(state.params)

    new_state = state.apply_gradients(grads=grads,
                                      batch_stats=updates['batch_stats'])
    return new_state, loss
        

@partial(jax.jit, static_argnums=4)
def evaluate(params,
             batch_stats,
             rng,
             images,
             diffusion_steps: int):
    """Generate images with the model under jit.

    Args:
        params: Model parameters to bind.
        batch_stats: Value of the 'batch_stats' collection to bind.
        rng: Random key handed to `generate`.
        images: Array whose shape determines the shape of the generated
            batch; only `images.shape` is read.
        diffusion_steps: Number of diffusion steps handed to `generate`;
            declared static (positional index 4) for jit.

    Returns:
        Whatever the model's `generate` method returns.
    """
    def sample_images(bound_model):
        # TODO: quantitative metrics
        return bound_model.generate(rng, images.shape, diffusion_steps)

    run_bound = nn.apply(sample_images, model())
    return run_bound({'params': params, 'batch_stats': batch_stats})


def run(epochs: int,
        image_size: int,
        batch_size: int,
        learning_rate: float,
        weight_decay: float,
        val_diffusion_steps: int,
        output_dir: Path):
    """Train the diffusion model, logging and checkpointing once per epoch.

    Args:
        epochs: Number of passes over the training dataset.
        image_size: Side length of the square training images.
        batch_size: Number of images per batch.
        learning_rate: Learning rate for the AdamW optimizer.
        weight_decay: Weight decay for the AdamW optimizer.
        val_diffusion_steps: Diffusion steps used when generating the
            per-epoch sample images.
        output_dir: Base directory; a timestamped run directory holding
            checkpoints and TensorBoard logs is created beneath it.

    Side effects:
        Writes a 'loss' scalar and a 'generated' image summary per epoch to
        the log directory and saves a checkpoint of the train state per
        epoch to the checkpoint directory.
    """
    output_dir, ckpt_dir, log_dir = create_output_dir(output_dir)
    summary_writer = tf.summary.create_file_writer(str(log_dir))

    rng = jax.random.PRNGKey(0)
    rng, key_init, key_diffusion = jax.random.split(rng, 3)

    # The validation dataset is built as well but not used here.
    ds_train, _ = prepare_datasets(image_size, batch_size)

    # A constant batch is enough to initialize the variables; it is reused
    # later to tell `evaluate` which shape to generate.
    image_shape = (batch_size, image_size, image_size, 3)
    dummy = jnp.ones(image_shape, dtype=jnp.float32)

    variables = model().init(key_init, dummy, key_diffusion,
                             train=True)

    state = TrainState.create(
        apply_fn=model().apply,
        params=variables['params'],
        batch_stats=variables['batch_stats'],
        tx=optax.adamw(learning_rate, weight_decay=weight_decay)
    )
    # Start the EMA from a copy of the initial parameters. An identity
    # tree_map rebuilds the container for any pytree type, whereas
    # `.copy(add_or_replace=...)` exists only on FrozenDict.
    ema_params = jax.tree_util.tree_map(lambda leaf: leaf, state.params)
    rng, rng_train, rng_val = jax.random.split(rng, 3)

    for epoch in range(epochs):
        losses = []
        pbar = tqdm(ds_train, desc=f'Epoch {epoch}')
        for images in pbar:
            rng_train, key = jax.random.split(rng_train)
            state, loss = train_step(state, images, key)

            pbar.set_postfix({'loss': f'{loss:.5f}'})
            losses.append(loss)
            # `jax.tree_util.tree_map` is the stable spelling; the top-level
            # `jax.tree_map` alias is deprecated and absent in newer JAX.
            ema_params = jax.tree_util.tree_map(update_ema, ema_params,
                                                state.params)

        # `rng_val` is never split, so every epoch samples with the same key.
        generated_images = evaluate(ema_params,
                                    state.batch_stats,
                                    rng=rng_val,
                                    images=dummy,
                                    diffusion_steps=val_diffusion_steps)

        with summary_writer.as_default():
            tf.summary.scalar('loss', np.mean(losses), step=epoch)
            tf.summary.image('generated', generated_images, step=epoch,
                             max_outputs=8)
        # NOTE(review): only `state` is checkpointed; the EMA parameters used
        # for generation above are not persisted - confirm this is intended.
        save_checkpoint(ckpt_dir, state, step=epoch)

In [21]:
# Hyperparameters and output location handed to `run` as keyword arguments.
args = dict(
    epochs=50,
    image_size=64,
    batch_size=64,
    learning_rate=1e-3,
    weight_decay=1e-4,
    val_diffusion_steps=20,
    output_dir=Path('./outputs'),
)

In [22]:
# Hide every GPU from TensorFlow before training starts.
# NOTE(review): presumably so that TensorFlow, used here for the input
# pipeline and summaries, does not reserve GPU memory that JAX needs - confirm.
tf.config.experimental.set_visible_devices([], 'GPU')

# Start training with the hyperparameters defined in `args`.
run(**args)

Epoch 0: 100%|██████████| 102/102 [00:22<00:00,  4.55it/s, loss=0.51078]
Epoch 1: 100%|██████████| 102/102 [00:08<00:00, 11.43it/s, loss=0.42015]
Epoch 2: 100%|██████████| 102/102 [00:08<00:00, 11.36it/s, loss=0.47383]
Epoch 3: 100%|██████████| 102/102 [00:09<00:00, 11.30it/s, loss=0.36899]
Epoch 4: 100%|██████████| 102/102 [00:08<00:00, 11.43it/s, loss=0.34633]
Epoch 5: 100%|██████████| 102/102 [00:08<00:00, 11.53it/s, loss=0.36658]
Epoch 6: 100%|██████████| 102/102 [00:08<00:00, 11.40it/s, loss=0.32866]
Epoch 7: 100%|██████████| 102/102 [00:08<00:00, 11.43it/s, loss=0.30865]
Epoch 8: 100%|██████████| 102/102 [00:08<00:00, 11.39it/s, loss=0.31624]
Epoch 9: 100%|██████████| 102/102 [00:08<00:00, 11.35it/s, loss=0.31032]
Epoch 10: 100%|██████████| 102/102 [00:08<00:00, 11.38it/s, loss=0.37706]
Epoch 11: 100%|██████████| 102/102 [00:08<00:00, 11.45it/s, loss=0.32258]
Epoch 12: 100%|██████████| 102/102 [00:08<00:00, 11.45it/s, loss=0.30739]
Epoch 13: 100%|██████████| 102/102 [00:08<00:00,