<a href="https://colab.research.google.com/github/dsuess/stylegan2.jax/blob/master/notebooks/Train%20Example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!rm -r stylegan2.jax
!git clone https://github.com/dsuess/stylegan2.jax
!pip install stylegan2.jax/

In [None]:
# Make sure the Colab Runtime is set to Accelerator: TPU.
import requests
import os
# TPU_DRIVER_MODE doubles as an "already configured" marker: the version
# request below is only sent the first time this cell runs in a session.
if 'TPU_DRIVER_MODE' not in globals():
  # Keep only the part of COLAB_TPU_ADDR before the first ':' and talk to
  # port 8475 on it to request the pinned TPU driver version.
  url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206'
  # NOTE(review): the response is stored but its status is never checked.
  resp = requests.post(url)
  TPU_DRIVER_MODE = 1

# The following is required to use TPU Driver as JAX's backend.
# NOTE(review): `jax.config.FLAGS` and the "tpu_driver" backend look tied to
# older JAX releases -- confirm against the installed version.
from jax.config import config
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
# Echo the backend target so the configured address is visible in the output.
print(config.FLAGS.jax_backend_target)

In [None]:
import tensorflow_datasets as tfds
import jax 
import functools as ft
import optax
import haiku as hk

from collections import namedtuple
from jax import numpy as jnp
from stylegan2.train import setup_models, initialize_params, GAN
from tqdm import tqdm
from stylegan2 import networks as nx
from haiku._src.data_structures import frozendict

In [None]:
# Display how many devices JAX can see with the configured backend.
jax.device_count()

In [None]:
def Components(g, d, s):
    """Bundle the generator, discriminator and style entries into one mapping.

    Args:
        g: Value stored under the key ``"g"`` (generator slot).
        d: Value stored under the key ``"d"`` (discriminator slot).
        s: Value stored under the key ``"s"`` (style-network slot).

    Returns:
        A ``frozendict`` with exactly the keys ``"g"``, ``"d"`` and ``"s"``.
    """
    slots = {"g": g, "d": d, "s": s}
    return frozendict(**slots)

class StyleGan2:
    """Generator, discriminator and style network with jitted training steps.

    The three networks are stored as Haiku-transformed functions in
    ``self.models`` under the keys ``g`` (generator), ``d`` (discriminator)
    and ``s`` (style embedding). Parameters and optimizer states are passed
    around explicitly as ``Components`` mappings with the same three keys.
    """

    # Width of the latent vectors fed to the style-embedding network.
    LATENT_SIZE = 32
    # Number of copies of each style vector handed to the generator.
    NUM_STYLE_LAYERS = 8

    def __init__(self):
        """Build the three transformed networks and the shared optimizer."""
        self.batch_size = 4
        self.models = Components(
            hk.transform(
                lambda latents: nx.SkipGenerator(32, max_hidden_feature_size=128)(latents)
            ),
            # The discriminator and style network are wrapped so that their
            # `apply` takes no rng argument; the generator's `apply` does.
            hk.without_apply_rng(
                hk.transform(
                    lambda images: nx.ResidualDiscriminator(
                        32, max_hidden_feature_size=128
                    )(images)
                )
            ),
            hk.without_apply_rng(
                hk.transform(
                    lambda latents: nx.style_embedding_network(
                        final_embedding_size=128, intermediate_latent_size=128
                    )(latents)
                )
            ),
        )

        # One optimizer definition is shared; each network keeps its own state.
        self.optim = optax.sgd(0.01, momentum=0.9)

    def initialize_params(self, rng):
        """Initialise parameters and optimizer states for all three networks.

        Args:
            rng: PRNG key; split into one key per network.

        Returns:
            ``(model_state, optim_state)``, both ``Components`` mappings keyed
            by ``g``, ``d`` and ``s``.
        """
        rngs = Components(*jax.random.split(rng, num=3))

        # Each network is initialised on the output of the previous one, so
        # the example inputs are chained: latents -> styles -> images.
        latents = jnp.zeros((self.batch_size, self.LATENT_SIZE), dtype=jnp.float32)
        params_s = self.models.s.init(rngs.s, latents)

        styles = self.models.s.apply(params_s, latents)
        styles = jnp.tile(styles[:, None, :], (1, self.NUM_STYLE_LAYERS, 1))
        params_g = self.models.g.init(rngs.g, styles)

        # These images are only used as the example input for the
        # discriminator's init call.
        images = self.models.g.apply(params_g, rngs.g, styles)
        params_d = self.models.d.init(rngs.d, images)

        model_state = Components(params_g, params_d, params_s)
        optim_state = Components(
            **{name: self.optim.init(params) for name, params in model_state.items()}
        )

        return model_state, optim_state

    def _fake_logits(self, model_state, batch_size, rng):
        """Sample a batch of fake images and return the discriminator logits."""
        # Independent keys for the latent draw and for the generator's own
        # randomness; a PRNG key must not be consumed twice.
        rng_latents, rng_generator = jax.random.split(rng)
        latents = jax.random.normal(rng_latents, (batch_size, self.LATENT_SIZE))
        styles = self.models.s.apply(model_state.s, latents)
        styles = jnp.tile(styles[:, None], (1, self.NUM_STYLE_LAYERS, 1))

        fake_images = self.models.g.apply(model_state.g, rng_generator, styles)
        return self.models.d.apply(model_state.d, fake_images)

    @staticmethod
    def _sigmoid_cross_entropy(logits, target_is_real):
        """Mean sigmoid cross entropy of ``logits`` against an all-1 or all-0 label.

        Uses the numerically stable form
        ``max(x, 0) - x * z + log(1 + exp(-|x|))`` with ``z`` fixed to 1
        (``target_is_real=True``) or 0 (``target_is_real=False``).
        """
        linear = jnp.maximum(logits, 0)
        if target_is_real:
            linear = linear - logits
        return jnp.mean(linear + jnp.log(1 + jnp.exp(-jnp.abs(logits))))

    def discriminator_loss(self, model_state, images, rng):
        """Discriminator loss: fakes labelled 0, ``images`` labelled 1.

        Args:
            model_state: Mapping with the parameters under ``g``, ``d``, ``s``.
            images: Batch of real images; its leading dimension also sets the
                number of fake samples.
            rng: PRNG key used to draw the fake batch.

        Returns:
            Scalar: the average of the fake and the real loss.
        """
        fake_logits = self._fake_logits(model_state, images.shape[0], rng)
        fake_loss = self._sigmoid_cross_entropy(fake_logits, target_is_real=False)

        real_logits = self.models.d.apply(model_state.d, images)
        real_loss = self._sigmoid_cross_entropy(real_logits, target_is_real=True)

        return (fake_loss + real_loss) / 2

    @ft.partial(jax.jit, static_argnums=[0])
    def discriminator_step(self, model_state, optim_state, images, rng):
        """Apply one optimizer update to the discriminator parameters only.

        Returns:
            ``(loss, model_state, optim_state)`` where only the ``d`` entries
            differ from the inputs.
        """
        val, grads = jax.value_and_grad(self.discriminator_loss)(model_state, images, rng)
        update_d, opt_state_d = self.optim.update(grads.d, optim_state.d)
        state_d = optax.apply_updates(model_state.d, update_d)

        model = Components(model_state.g, state_d, model_state.s)
        optim = Components(optim_state.g, opt_state_d, optim_state.s)
        return val, model, optim

    def generator_loss(self, model_state, images, rng):
        """Generator loss: fakes scored against the "real" label.

        Args:
            model_state: Mapping with the parameters under ``g``, ``d``, ``s``.
            images: Only its leading dimension is read, as the batch size.
            rng: PRNG key used to draw the fake batch.

        Returns:
            Scalar mean loss.
        """
        fake_logits = self._fake_logits(model_state, images.shape[0], rng)
        return self._sigmoid_cross_entropy(fake_logits, target_is_real=True)

    @ft.partial(jax.jit, static_argnums=[0])
    def generator_step(self, model_state, optim_state, images, rng):
        """Apply one optimizer update to the generator and style parameters.

        Returns:
            ``(loss, model_state, optim_state)`` where the ``g`` and ``s``
            entries are updated and ``d`` is passed through unchanged.
        """
        val, grads = jax.value_and_grad(self.generator_loss)(model_state, images, rng)
        update_g, opt_state_g = self.optim.update(grads.g, optim_state.g)
        state_g = optax.apply_updates(model_state.g, update_g)

        update_s, opt_state_s = self.optim.update(grads.s, optim_state.s)
        state_s = optax.apply_updates(model_state.s, update_s)

        model = Components(state_g, model_state.d, state_s)
        optim = Components(opt_state_g, optim_state.d, opt_state_s)
        return val, model, optim
    
    
# Build the model, then call each jitted step once on an all-zero batch so the
# first (tracing) call happens here rather than inside a timing cell.
model = StyleGan2()
key = jax.random.PRNGKey(0)
images = jnp.zeros((model.batch_size, 32, 32, 3), dtype=jnp.float32)
state, optim_state = model.initialize_params(key)

# The results are discarded on purpose; only the calls themselves matter here.
for _step in (model.generator_step, model.discriminator_step):
    _step(state, optim_state, images, key)

print("DONE")

In [None]:
%%timeit 
model.discriminator_step(state, optim_state, images, key)

In [None]:
# `model.discriminator_loss` is a bound method, so `self` is already applied
# and positional argument 0 is the parameter tree. It has to be traced like
# any other array input; marking it static would bake the parameters into the
# compiled function.
fun = jax.jit(model.discriminator_loss)

In [None]:
%%timeit 
fun(state, images, key)

In [None]:
# CIFAR-10 training images scaled by 1/255, capped at 2**14 examples, then
# shuffled and grouped into batches.
batch_size = 64
data = tfds.load("cifar10", split="train")
data = data.map(lambda example: example["image"] / 255)
data = data.repeat().take(2 ** 14)
data = data.shuffle(1024).batch(batch_size)

rngkey, rnginit = jax.random.split(jax.random.PRNGKey(42))
trainer = setup_models()
num_devices = jax.device_count()
state = initialize_params(rnginit, trainer, 1)
# Replicate every leaf once per device so the state can be mapped by pmap.
state = jax.tree_util.tree_map(lambda x: jnp.stack([x] * num_devices), state)

# NOTE(review): `discriminator_loss` is not imported in any visible cell; it
# has to be brought into scope (presumably from `stylegan2.train`) before this
# cell can run -- confirm where it lives.
# The model definitions are closed over instead of being passed through pmap:
# only the per-device keys, parameters and images carry a leading device axis.
# Built once, outside the loop, so every step reuses the same mapped function.
parallel_disc_loss = jax.pmap(ft.partial(discriminator_loss, trainer.model))

for epoch in range(10):
    for images in tqdm(data.as_numpy_iterator(), total=2 ** 14 // batch_size):
        # Shard the batch: (batch, ...) -> (num_devices, batch // num_devices, ...).
        pimages = images.reshape((num_devices, -1, *images.shape[1:]))
        # Advance the carried key every step, then derive one key per device.
        rngkey, rngstep = jax.random.split(rngkey)
        rng = jax.random.split(rngstep, num=num_devices)
        loss = parallel_disc_loss(rng, state.model, pimages)

        # TODO: replace this loss-only evaluation with the actual
        # discriminator/generator update steps and report both losses.
        print(loss)
