<a href="https://colab.research.google.com/github/myagues/potpourri/blob/master/jax/flax_linen_TPU_dcgan_celeba.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# DCGAN tutorial with CelebA dataset

An implementation of Deep Convolutional Generative Adversarial Networks (DCGAN), based on [PyTorch's tutorial](https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html), written with Flax and its [Linen API](https://github.com/google/flax/tree/master/flax/linen).

A. Radford, L. Metz and S. Chintala. "Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks". (2015). [arXiv:1511.06434v2 [cs.LG]](https://arxiv.org/abs/1511.06434v2).

In [None]:
!pip install --upgrade -q tensorflow-datasets ml_collections git+https://github.com/google/flax

In [None]:
# Make sure the Colab Runtime is set to Accelerator: TPU.
import requests
import os
# Ask the Colab TPU host (address taken from COLAB_TPU_ADDR) for a specific TPU
# driver version. The `TPU_DRIVER_MODE` global is a guard so that re-running
# this cell in the same session does not post the request again.
# NOTE(review): the HTTP response is stored but its status is never checked.
if "TPU_DRIVER_MODE" not in globals():
    url = "http://" + os.environ["COLAB_TPU_ADDR"].split(":")[0] + ":8475/requestversion/tpu_driver0.1-dev20191206"
    resp = requests.post(url)
    TPU_DRIVER_MODE = 1

# The following is required to use TPU Driver as JAX's backend.
from jax.config import config as jax_config
jax_config.FLAGS.jax_xla_backend = "tpu_driver"
jax_config.FLAGS.jax_backend_target = "grpc://" + os.environ["COLAB_TPU_ADDR"]
# Echo the backend target so the address in use is visible in the cell output.
print(jax_config.FLAGS.jax_backend_target)

In [None]:
import functools
import glob
import imageio
import io
import jax
import time

import matplotlib.pyplot as plt
import ml_collections
import tensorflow as tf
import tensorflow_datasets as tfds
import ipywidgets as widgets

from absl import logging
from flax import jax_utils, linen as nn, optim, struct
from flax.training import checkpoints, common_utils
from jax import numpy as jnp, lax, random
from typing import Any

# NOTE(review): omnistaging is switched on explicitly, which presumably means
# this notebook targets an older JAX release — confirm before upgrading JAX.
jax_config.enable_omnistaging()

# Render inline matplotlib figures as SVG.
%config InlineBackend.figure_format = "svg"
# Show absl log messages at INFO level (the training loop reports via logging.info).
logging.set_verbosity(logging.INFO)

In [None]:
# Mount Google Drive at /content/drive; the checkpoint and dataset directories
# configured in this notebook live under that mount point.
from google.colab import drive
drive.mount("/content/drive")

In [None]:
# Root PRNG key; it is split further down for parameter init and noise sampling.
rng = random.PRNGKey(0)

config = ml_collections.ConfigDict()
config.num_epochs = 10
# An effective batch of 128 * 8 does not produce good images unless training
# runs for a large number of epochs, and a batch size of 16 images per core
# is a bit of a waste :(
# NOTE(review): the `jnp.sqrt(1)` and `* 1` factors look like placeholders for
# scaling the learning rate together with the batch size — confirm.
config.learning_rate = 2e-4 * jnp.sqrt(1)
config.batch_size = 128 * 1
config.z_vec = 100  # size of the latent noise vector fed to the generator
config.ngf = 64  # base feature-map width of the generator
config.ndf = 64  # base feature-map width of the discriminator
config.half_precision = True  # selects bfloat16 (True) or float32 (False) below
config.ckpt_dir = "/content/drive/My Drive/Colab Notebooks/ckpt"
config.data_dir = "/content/drive/My Drive/Colab Notebooks/tfds"

# The batch is sharded across local devices, so it must split evenly.
n_devices = jax.local_device_count()
if config.batch_size % n_devices > 0:
    raise ValueError("Batch size must be divisible by the number of devices.")
# Share of the global batch handled by this host.
local_batch_size = config.batch_size // jax.host_count()

# dtypes for the tf.data pipeline and for the Flax modules, respectively.
if config.half_precision:
    data_dtype = tf.bfloat16
    model_dtype = jnp.bfloat16
else:
    data_dtype = tf.float32
    model_dtype = jnp.float32

# Dummy input shapes used to initialise the generator (latent vectors) and the
# discriminator (64x64 images with 3 channels).
input_shape_g = (local_batch_size, 1, 1, config.z_vec)
input_shape_d = (local_batch_size, 64, 64, 3)

In [None]:
def adjust_dynamic_range(data, drange_in=(0, 255), drange_out=(-1, 1)):
    """Linearly map ``data`` from one value range onto another.

    Args:
        data: Scalar or array-like value supporting ``*`` and ``+``.
        drange_in: ``(low, high)`` pair describing the range of the input.
        drange_out: ``(low, high)`` pair describing the target range.

    Returns:
        ``data * scale + bias`` such that ``drange_in`` maps onto
        ``drange_out``. When both ranges hold the same bounds, ``data`` is
        returned untouched (no arithmetic, hence no dtype change), whether the
        ranges were given as lists or tuples.
    """
    # Defaults are immutable tuples; the bounds are compared by value so that
    # a list and a tuple describing the same range count as equal.
    in_low, in_high = drange_in[0], drange_in[1]
    out_low, out_high = drange_out[0], drange_out[1]
    if (in_low, in_high) == (out_low, out_high):
        return data

    scale = (out_high - out_low) / (in_high - in_low)
    bias = out_low - in_low * scale
    return data * scale + bias


def img_transforms(image, dtype, drange_out=(-1, 1)):
    """Resize an image to 64x64, cast it and rescale its pixel values.

    Args:
        image: Image tensor; assumed to be a single (height, width, channels)
            image with a static height and pixel values in [0, 255] — TODO
            confirm against the dataset.
        dtype: TensorFlow dtype the image is cast to before rescaling.
        drange_out: ``(low, high)`` output range for the pixel values.

    Returns:
        The transformed image tensor.
    """
    # Keep the original height while fitting the width to 64, then crop/pad
    # the result around its centre to a 64x64 square.
    image = tf.image.resize_with_pad(image, image.shape[0], 64)
    image = tf.image.resize_with_crop_or_pad(image, 64, 64)
    image = tf.cast(image, dtype=dtype)
    # Input range is left at adjust_dynamic_range's default.
    image = adjust_dynamic_range(image, drange_out=drange_out)
    return image


def build_ds(data_dir, batch_size, dtype=tf.float32, shuffle=True):
    """Create a batched CelebA input pipeline spanning all three splits.

    Args:
        data_dir: Directory handed to ``tfds.load`` as its data directory.
        batch_size: Number of images per batch; incomplete batches are dropped.
        dtype: TensorFlow dtype of the produced images.
        shuffle: Whether to shuffle examples with a buffer sized to the
            total number of examples.

    Returns:
        A ``tf.data.Dataset`` of preprocessed image batches.
    """
    autotune = tf.data.experimental.AUTOTUNE

    # NOTE(review): files are shuffled at load time regardless of `shuffle`.
    splits, ds_info = tfds.load(
        "celeb_a",
        split=["train", "validation", "test"],
        shuffle_files=True,
        data_dir=data_dir,
        with_info=True,
    )
    ds_train, ds_validation, ds_test = splits
    pipeline = ds_train.concatenate(ds_validation).concatenate(ds_test)

    def preprocess(example):
        return img_transforms(example["image"], dtype)

    # Cache after preprocessing so the transforms are not repeated per epoch.
    pipeline = pipeline.map(preprocess, num_parallel_calls=autotune).cache()
    if shuffle:
        pipeline = pipeline.shuffle(ds_info.splits.total_num_examples)
    return pipeline.batch(batch_size, drop_remainder=True).prefetch(autotune)

In [None]:
# Preview: draw the first 64 images of one batch (built with shuffle=False)
# in an 8x8 grid.
dataset = build_ds(config.data_dir, config.batch_size, shuffle=False)
batch = next(iter(dataset))

# Zero spacing between subplots; the same kwargs are reused for the sample
# grids drawn during training.
gridspec_kw = {"wspace": 0.0, "hspace": 0.0}
fig, axes = plt.subplots(nrows=8, ncols=8, figsize=(8, 8), gridspec_kw=gridspec_kw)
for idx, ax in enumerate(axes.flat):
    # Map pixels from [-1, 1] back to [0, 255] so they can be shown as uint8.
    batch_ = adjust_dynamic_range(batch[idx, ...]._numpy(), [-1, 1], [0, 255])
    ax.imshow(batch_.astype(jnp.uint8))
    ax.axis("off")
fig.tight_layout()
# Drop the references to the preview data.
del batch, batch_

In [None]:
@struct.dataclass
class DCGANConfig:
    """Hyperparameters consumed by the Generator and Discriminator modules."""

    dtype: Any = jnp.float32  # dtype passed to the conv and batch-norm layers
    train: bool = False  # when False, batch norm uses its running averages
    ngf: int = 64  # base feature-map width of the generator
    ndf: int = 64  # base feature-map width of the discriminator
    z_vec: int = 100  # size of the latent noise vector
    n_channels: int = 3  # number of channels of the generated image


class Generator(nn.Module):
    """DCGAN generator: transposed-conv stack that ends in a tanh.

    Four ConvTranspose -> BatchNorm -> ReLU stages (widths 8, 4, 2 and 1
    times ``ngf``) are followed by a final transposed convolution to
    ``n_channels`` channels and a tanh activation.
    """

    config: DCGANConfig

    @nn.compact
    def __call__(self, inputs):
        hp = self.config
        up_conv = functools.partial(
            nn.ConvTranspose,
            kernel_size=(4, 4),
            strides=(2, 2),
            use_bias=False,
            dtype=hp.dtype,
            kernel_init=nn.initializers.normal(stddev=2e-2),
        )
        batch_norm = functools.partial(
            nn.BatchNorm, use_running_average=not hp.train, dtype=hp.dtype
        )

        # (width multiplier, keyword overrides) per stage. Only the first
        # stage departs from the stride-2 defaults set on `up_conv`.
        stages = (
            (8, {"strides": (1, 1), "padding": "VALID"}),
            (4, {}),
            (2, {}),
            (1, {}),
        )

        x = inputs
        for multiplier, overrides in stages:
            x = up_conv(multiplier * hp.ngf, **overrides)(x)
            x = batch_norm()(x)
            x = nn.relu(x)

        # Project to image channels; tanh bounds the output to [-1, 1].
        return nn.tanh(up_conv(hp.n_channels)(x))


class Discriminator(nn.Module):
    """DCGAN discriminator: strided-conv stack that returns raw logits.

    A Conv -> LeakyReLU stem is followed by three Conv -> BatchNorm ->
    LeakyReLU stages (widths 2, 4 and 8 times ``ndf``) and a final
    single-feature convolution. No sigmoid is applied to the output.
    """

    config: DCGANConfig

    @nn.compact
    def __call__(self, inputs):
        hp = self.config
        down_conv = functools.partial(
            nn.Conv,
            kernel_size=(4, 4),
            strides=(2, 2),
            use_bias=False,
            dtype=hp.dtype,
            kernel_init=nn.initializers.normal(stddev=2e-2),
        )
        batch_norm = functools.partial(
            nn.BatchNorm, use_running_average=not hp.train, dtype=hp.dtype
        )

        # Stem: no batch norm on the first layer.
        x = nn.leaky_relu(down_conv(1 * hp.ndf)(inputs), 0.2)

        for multiplier in (2, 4, 8):
            x = down_conv(multiplier * hp.ndf)(x)
            x = batch_norm()(x)
            x = nn.leaky_relu(x, 0.2)

        # Single-feature head; the result is left as unnormalised logits.
        return down_conv(1, strides=(1, 1), padding="VALID")(x)

In [None]:
@jax.vmap
def bce_with_logits(labels, logits):
    """Element-wise binary cross-entropy computed directly from logits.

    Evaluates the numerically stable form
    ``max(x, 0) - x * z + log1p(exp(-|x|))`` with ``x = logits`` and
    ``z = labels``. ``log1p`` is used for the last term so that the small
    contribution of very confident logits is not rounded away, which is
    what happens with ``log(1 + exp(-|x|))`` in low precision.

    The function is vmapped over the leading axis of both arguments.

    Args:
        labels: Target values, same shape as ``logits``.
        logits: Unnormalised discriminator outputs.

    Returns:
        The per-element losses with the same shape as the inputs; no
        reduction is applied.
    """
    return (
        jnp.maximum(logits, 0)
        - logits * labels
        + jnp.log1p(jnp.exp(-jnp.abs(logits)))
    )


def train_step(apply_fn_g, apply_fn_d, state_g, state_d, input_images, rng=None):
    """Run one joint generator/discriminator update on a batch of real images.

    The function calls ``lax.pmean`` over the axis named ``"batch"``, so it
    has to run under a ``pmap`` (or equivalent) that defines that axis name.

    Args:
        apply_fn_g: Generator apply function, called as
            ``apply_fn_g(variables, noise, mutable=[...])``.
        apply_fn_d: Discriminator apply function, called the same way on images.
        state_g: Generator train state (``.optimizer`` and ``.model_state``).
        state_d: Discriminator train state.
        input_images: Batch of real images; its leading dimension sets the
            number of noise vectors and its dtype sets the noise dtype.
        rng: PRNG key. NOTE(review): the default of ``None`` is passed straight
            to ``random.split``, so callers presumably must always supply a key.

    Returns:
        Tuple ``(new_rng, new_state_g, new_state_d, metrics)`` where ``metrics``
        holds the cross-replica mean of ``loss_g`` and ``loss_d``.
    """
    # One half of the split drives this step's noise; the other is handed back
    # to the caller for the next step.
    rng, new_rng = random.split(rng)
    # The latent size is read from the module-level `config`.
    noise = random.normal(rng, (input_images.shape[0], 1, 1, config.z_vec))
    noise = noise.astype(input_images.dtype)

    # Let both networks update their batch-norm statistics during the forward pass.
    apply_fn_g_ = functools.partial(apply_fn_g, mutable=["batch_stats"])
    apply_fn_d_ = functools.partial(apply_fn_d, mutable=["batch_stats"])
    optimizer_g = state_g.optimizer
    optimizer_d = state_d.optimizer

    def loss_fn_g(params_g):
        # Generator loss; the discriminator loss and gradient are computed
        # inside and returned as auxiliary outputs so that the generator's
        # forward pass is shared by both updates.
        vars_g = {"params": params_g, **state_g.model_state}
        generated_images, new_vars_g = apply_fn_g_(vars_g, noise)

        def loss_fn_d(params_d):
            # Discriminator loss: real images labelled 1, generated images 0.
            vars_d = {"params": params_d, **state_d.model_state}
            real_output, vars_d1 = apply_fn_d_(vars_d, input_images)
            fake_output, vars_d2 = apply_fn_d_(vars_d, generated_images)
            # Average the mutated collections from the real and the fake pass.
            new_vars_d = jax.tree_multimap(lambda x, y: (x + y) / 2, vars_d1, vars_d2)

            real_loss = bce_with_logits(jnp.ones_like(real_output), real_output).mean()
            fake_loss = bce_with_logits(jnp.zeros_like(fake_output), fake_output).mean()
            loss_d = real_loss + fake_loss
            return loss_d, (fake_output, new_vars_d)

        aux, grad_d = jax.value_and_grad(loss_fn_d, has_aux=True)(optimizer_d.target)
        grad_d = lax.pmean(grad_d, axis_name="batch")
        loss_d, (fake_output, new_vars_d) = aux

        # Generator objective: the discriminator should label fakes as real (1).
        # `fake_output` comes from the discriminator parameters before this
        # step's update.
        loss_g = bce_with_logits(jnp.ones_like(fake_output), fake_output).mean()
        return loss_g, (loss_d, grad_d, new_vars_d, new_vars_g)

    # Differentiates loss_g w.r.t. the generator parameters only; everything
    # in the aux tuple is passed through undifferentiated.
    aux, grad_g = jax.value_and_grad(loss_fn_g, has_aux=True)(optimizer_g.target)
    grad_g = lax.pmean(grad_g, axis_name="batch")
    loss_g, (loss_d, grad_d, new_vars_d, new_vars_g) = aux

    # Apply both gradient updates and store the refreshed mutable collections.
    new_state_g = state_g.replace(
        optimizer=optimizer_g.apply_gradient(grad_g), model_state=new_vars_g
    )
    new_state_d = state_d.replace(
        optimizer=optimizer_d.apply_gradient(grad_d), model_state=new_vars_d
    )
    metrics = {"loss_g": loss_g, "loss_d": loss_d}
    metrics = lax.pmean(metrics, axis_name="batch")
    return (new_rng, new_state_g, new_state_d, metrics)

In [None]:
@struct.dataclass
class TrainState:
    """Training state of one network: its optimizer plus non-parameter variables."""

    optimizer: optim.Optimizer  # optimizer; its `.target` holds the parameters
    model_state: Any  # remaining variable collections, e.g. "batch_stats"


def sync_batch_stats(state):
    """Average the ``batch_stats`` collection over all replicas.

    Args:
        state: Replicated train state whose ``model_state`` contains a
            ``"batch_stats"`` entry.

    Returns:
        A copy of ``state`` in which ``batch_stats`` has been replaced by its
        cross-replica mean.
    """
    cross_replica_mean = jax.pmap(lambda stats: lax.pmean(stats, "x"), "x")
    averaged_stats = cross_replica_mean(state.model_state["batch_stats"])
    synced_model_state = state.model_state.copy({"batch_stats": averaged_stats})
    return state.replace(model_state=synced_model_state)


def create_train_state(rng, optimizer, input_shape, model, dtype):
    """Initialise ``model`` and wrap its variables in a ``TrainState``.

    Args:
        rng: PRNG key used for parameter initialisation.
        optimizer: Optimizer definition; ``optimizer.create(params)`` builds
            the optimizer that holds the parameters.
        input_shape: Shape of the dummy input used to initialise the model.
        model: Module to initialise.
        dtype: dtype of the dummy input.

    Returns:
        A ``TrainState`` with the new optimizer and the model's remaining
        (non-parameter) variables.
    """
    init_params, init_model_state = initialized(rng, input_shape, model, dtype)
    return TrainState(
        optimizer=optimizer.create(init_params),
        model_state=init_model_state,
    )


@functools.partial(jax.jit, static_argnums=(1, 2, 3))
def initialized(key, input_shape, model, dtype):
    """Initialise ``model`` on an all-ones input and split off its parameters.

    Jitted with ``input_shape``, ``model`` and ``dtype`` as static arguments.

    Args:
        key: PRNG key supplied as the ``"params"`` rng stream.
        input_shape: Shape of the all-ones dummy input.
        model: Module providing ``init``.
        dtype: dtype of the dummy input.

    Returns:
        Tuple ``(params, model_state)``: the ``"params"`` collection and the
        remaining variables.
    """
    dummy_input = jnp.ones(input_shape, dtype)
    init_variables = model.init({"params": key}, dummy_input)
    # `pop` hands back the remaining collections first, the popped entry second.
    remaining, params = init_variables.pop("params")
    return params, remaining

In [None]:
# Module hyperparameters for training; the eval variant differs only in `train`.
train_config = DCGANConfig(
    dtype=model_dtype,
    ngf=config.ngf,
    ndf=config.ndf,
    z_vec=config.z_vec,
    train=True,
)
eval_config = train_config.replace(train=False)

# Epoch to start from; overwritten below by the checkpoint restore.
init_ep = 0
rng, rng_g, rng_d, seed_rng = random.split(rng, 4)
optimizer_def = optim.Adam(learning_rate=config.learning_rate, beta1=0.5)

# Variables are initialised with the eval-mode modules.
state_g = create_train_state(
    rng_g, optimizer_def, input_shape_g, Generator(eval_config), model_dtype
)
state_d = create_train_state(
    rng_d, optimizer_def, input_shape_d, Discriminator(eval_config), model_dtype
)

# NOTE(review): presumably this returns the freshly initialised values above
# when no checkpoint exists in `ckpt_dir` — confirm against the Flax docs.
state_g, state_d, init_ep = checkpoints.restore_checkpoint(
    config.ckpt_dir, [state_g, state_d, init_ep]
)
# Copy both states to every local device for use under pmap.
state_g = jax_utils.replicate(state_g)
state_d = jax_utils.replicate(state_d)

# Parallel train step: the apply functions are bound up front, and the
# "batch" axis name matches the pmean calls inside train_step.
p_train_step = jax.pmap(
    functools.partial(
        train_step, Generator(train_config).apply, Discriminator(train_config).apply
    ),
    axis_name="batch",
)
# One PRNG key per local device.
rngs = random.split(rng, n_devices)

# Fixed latent vectors used to render the 8x8 sample grid after each epoch.
seed = random.normal(seed_rng, (64, 1, 1, config.z_vec)).astype(eval_config.dtype)

In [None]:
# Processing and caching the whole dataset in memory (for better shuffling)
# takes ~3m.
dataset = build_ds(config.data_dir, config.batch_size, data_dtype, shuffle=True)
fig_list = []  # one sample-grid figure per finished epoch
loss_g = []  # per-step generator losses over all epochs
loss_d = []  # per-step discriminator losses over all epochs

for ep in range(init_ep, config.num_epochs):
    epoch_metrics = []
    t_epoch = time.perf_counter()
    for idx, batch in enumerate(dataset):
        # Split the batch across local devices before the pmapped step.
        batch_ = common_utils.shard(batch._numpy())
        (rngs, state_g, state_d, metrics) = p_train_step(
            state_g, state_d, batch_, rng=rngs
        )
        epoch_metrics.append(metrics)
        if (idx + 1) % 50 == 0:
            # Report the running mean of the losses over the epoch so far.
            epoch_metrics_ = common_utils.get_metrics(epoch_metrics)
            summary = jax.tree_map(lambda x: x.mean(), epoch_metrics_)
            logging.info(
               "Epoch: %2d\tStep: %4d\tLoss_D: %2.4f\tLoss_G: %2.4f",
               ep + 1, idx + 1, float(summary["loss_d"]), float(summary["loss_g"])
            )
    logging.info("Epoch time: %.4fs", time.perf_counter() - t_epoch)

    # Collect the metrics of the WHOLE epoch. Relying on the snapshot taken
    # inside the `% 50` branch would drop every step after the last multiple
    # of 50 and would fail (or reuse stale values) for epochs shorter than
    # 50 steps.
    if epoch_metrics:
        epoch_metrics_ = common_utils.get_metrics(epoch_metrics)
        loss_g.extend(epoch_metrics_["loss_g"])
        loss_d.extend(epoch_metrics_["loss_d"])

    # Average batch-norm statistics across replicas before saving/sampling.
    state_g = sync_batch_stats(state_g)
    state_d = sync_batch_stats(state_d)
    if jax.host_id() == 0:
        # Unreplicate the generator state once and reuse it for both the
        # checkpoint and the sample grid.
        state_g_ = jax_utils.unreplicate(state_g)
        checkpoints.save_checkpoint(
            config.ckpt_dir,
            [state_g_, jax_utils.unreplicate(state_d), ep + 1],
            ep + 1,
        )
        # Render the fixed seed vectors with the eval-mode generator.
        generated_images = Generator(eval_config).apply(
            {"params": state_g_.optimizer.target, **state_g_.model_state},
            seed,
            mutable=False,
        )
        generated_images = adjust_dynamic_range(generated_images, [-1, 1], [0, 255])

        fig, axes = plt.subplots(nrows=8, ncols=8, figsize=(8, 8), gridspec_kw=gridspec_kw)
        for img, ax in zip(generated_images.astype(jnp.uint8), axes.flat):
            ax.imshow(img)
            ax.axis("off")
        fig.tight_layout()
        fig.savefig(f"image_at_epoch_{ep + 1:02d}.png")
        fig_list.append(fig)
        # Close this figure explicitly so it is not also rendered inline; the
        # object stays in `fig_list` for later display.
        plt.close(fig)

In [None]:
# Slider to browse the sample grids stored in `fig_list`; the slider value is
# 1-based, hence the `- 1` when indexing.
widgets.interactive(
    lambda epoch: display(fig_list[epoch - 1]),
    epoch=widgets.IntSlider(min=1, max=len(fig_list)),
)

In [None]:
# Plot the generator and discriminator losses collected during training.
plt.plot(loss_g, label="G")
plt.plot(loss_d, label="D")
plt.title("Train losses")
plt.legend()
plt.show()

In [None]:
# Assemble the saved sample grids into an animated GIF. The epoch number in
# the file names is zero-padded to two digits, so a lexicographic sort gives
# chronological order.
filenames = sorted(glob.glob("image_at_epoch_*.png"))

# NOTE(review): `duration=0.5` is presumably the display time per frame —
# confirm against the installed imageio version.
with imageio.get_writer("dcgan_celeba.gif", mode="I", duration=0.5) as writer:
    for filename in filenames:
        image = imageio.imread(filename)
        writer.append_data(image)

In [None]:
# Read the GIF through a context manager so the file handle is closed, then
# show it; the widget stays the last expression so the cell still displays it.
with open("dcgan_celeba.gif", "rb") as gif_file:
    gif_bytes = gif_file.read()
widgets.Image(value=gif_bytes, format="gif")