<a href="https://colab.research.google.com/github/dsuess/stylegan2.jax/blob/master/notebooks/Train%20Example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!rm -r stylegan2.jax
!git clone https://github.com/dsuess/stylegan2.jax
!pip install stylegan2.jax/

In [None]:
# Make sure the Colab Runtime is set to Accelerator: TPU.
import requests
import os
# Send the driver-version request only once per kernel session: after the
# first run `TPU_DRIVER_MODE` exists in the notebook globals, so re-running
# this cell skips the branch.
if 'TPU_DRIVER_MODE' not in globals():
  # Build the `requestversion` URL on port 8475 of the host named in
  # COLAB_TPU_ADDR (the port contained in that variable is dropped).
  # Raises KeyError when COLAB_TPU_ADDR is not set, i.e. no TPU runtime.
  url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206'
  # NOTE(review): the response is stored but never inspected, so a failed
  # request goes unnoticed here — consider checking the status code.
  resp = requests.post(url)
  TPU_DRIVER_MODE = 1

# The following is required to use TPU Driver as JAX's backend.
from jax.config import config
config.FLAGS.jax_xla_backend = "tpu_driver"
# Point JAX at the TPU through gRPC, using the full host:port from the env.
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
print(config.FLAGS.jax_backend_target)

In [None]:
import tensorflow_datasets as tfds
import jax 
import functools as ft

from jax import numpy as jnp
from stylegan2.train import setup_models, initialize_params, GAN
from tqdm import tqdm

In [None]:
# Number of devices visible to JAX; displayed as this cell's output.
jax.device_count()

In [None]:
@ft.partial(jax.jit, static_argnums=[0])
def discriminator_loss(model, rng, model_state, images):
    """Sigmoid cross-entropy discriminator loss for one batch.

    Generated images are scored against label 0 and real images against
    label 1, i.e. ``mean(softplus(D(fake))) + mean(softplus(-D(real)))``.

    Args:
        model: Container exposing ``.s``, ``.g`` and ``.d`` sub-models, each
            with an ``apply`` function. It is a static argument of the jitted
            function, so it must be hashable.
        rng: JAX PRNG key. It is split internally so that latent sampling and
            the generator receive independent keys.
        model_state: Container exposing ``.s``, ``.g`` and ``.d`` with the
            state handed to the matching ``apply`` function.
        images: Batch of real images; the leading axis is the batch size.

    Returns:
        Scalar loss: mean fake loss plus mean real loss.
    """
    # A PRNG key must not be consumed twice. Using the same key for the
    # latents and for the generator would correlate the two random draws.
    rng_latents, rng_generator = jax.random.split(rng)

    # NOTE(review): 32 is presumably the latent size and 8 the number of
    # style inputs of the generator — confirm against `setup_models`.
    latents = jax.random.normal(rng_latents, (images.shape[0], 32))
    styles = model.s.apply(model_state.s, latents)
    # Repeat each style vector 8 times along a new axis 1.
    styles = jnp.tile(styles[:, None], (1, 8, 1))

    fake_images = model.g.apply(model_state.g, rng_generator, styles)
    fake_logits = model.d.apply(model_state.d, fake_images)
    # softplus(x) == max(x, 0) + log(1 + exp(-|x|)), evaluated stably.
    fake_loss = jnp.mean(jax.nn.softplus(fake_logits))

    real_logits = model.d.apply(model_state.d, images)
    # softplus(-x) == max(x, 0) - x + log(1 + exp(-|x|)).
    real_loss = jnp.mean(jax.nn.softplus(-real_logits))

    return fake_loss + real_loss

In [None]:
batch_size = 64
data = tfds.load("cifar10", split="train")
# Input pipeline: divide the image tensor by 255, keep 2**14 examples per
# pass, shuffle with a 1024-element buffer and batch.
data = (
    data.map(lambda x: x["image"] / 255)
    .repeat()
    .take(2 ** 14)
    .shuffle(1024)
    # Drop a ragged final batch so every batch has the same static shape.
    # The training loop reshapes each batch across devices, which requires a
    # fixed batch size, and a changing shape would also force a recompile.
    .batch(batch_size, drop_remainder=True)
)

# Two independent PRNG streams: `rnginit` seeds parameter initialisation,
# `rngkey` is consumed and refreshed by the training loop below.
rngkey, rnginit = jax.random.split(jax.random.PRNGKey(42))
trainer = setup_models()
num_devices = jax.device_count()
state = initialize_params(rnginit, trainer, 1)
# Replicate every leaf along a new leading axis of length `num_devices` so
# the state can be mapped over devices by `jax.pmap`.
state = jax.tree_util.tree_map(lambda x: jnp.stack([x] * num_devices), state)

# Build the parallel loss once, outside the loop. The model (argument 0) is a
# static Python object rather than an array, so it is broadcast to all
# devices instead of being mapped over.
parallel_discriminator_loss = jax.pmap(discriminator_loss, static_broadcasted_argnums=0)

for epoch in range(10):
    for images in tqdm(data.as_numpy_iterator(), total=2 ** 14 // batch_size):
        # Shard the batch: (batch, ...) -> (num_devices, batch // num_devices, ...).
        pimages = images.reshape((num_devices, -1, *images.shape[1:]))
        # Advance the loop key and derive one fresh key per device, so no key
        # is reused across steps or across devices.
        keys = jax.random.split(rngkey, num=num_devices + 1)
        rngkey, rng = keys[0], keys[1:]
        loss = parallel_discriminator_loss(trainer.model, rng, state.model, pimages)

        # TODO: the generator/discriminator update steps are not wired in
        # yet; this loop only evaluates the discriminator loss per device.
        print(loss)


In [None]:
# One batch for benchmarking, plus a copy sharded across devices:
# (batch, ...) -> (num_devices, batch // num_devices, ...).
images = next(data.as_numpy_iterator())
pimages = images.reshape((jax.device_count(), -1, *images.shape[1:]))

# `state` was replicated with a leading device axis. The single-device
# baseline must therefore use one replica of the discriminator state.
d_state_single = jax.tree_util.tree_map(lambda x: x[0], state.model.d)
fn = ft.partial(trainer.model.d.apply, d_state_single)

# For the parallel version the replicated state has to be a mapped argument
# of pmap. Closing over it would hand every device the full stacked state.
pfn = ft.partial(jax.pmap(trainer.model.d.apply), state.model.d)

In [None]:
%%timeit
# JAX dispatches work asynchronously; block on the result so the timing
# covers the computation itself and not just the dispatch.
pfn(pimages).block_until_ready()

In [None]:
%%timeit
# JAX dispatches work asynchronously; block on the result so the timing
# covers the computation itself and not just the dispatch.
fn(images).block_until_ready()

In [None]:
# Show the structure of the model state as a tree of array shapes instead of
# dumping every parameter value into the cell output.
jax.tree_util.tree_map(jnp.shape, state.model)