## References
- [Anime-Face-Dataset](https://github.com/bchao1/Anime-Face-Dataset)
- [Anime-Face-GAN](https://github.com/yashyenugu/Anime-Face-GAN)
- [jax-dcgan](https://github.com/bilal2vec/jax-dcgan)

## Setup

In [None]:
!pip install flax

## Dataset

In [None]:
# Fetch the anime-face archive, unpack it, and rename the extracted folder to
# ./data (the directory `collect_data` reads images from).
!gdown https://drive.google.com/uc?id=1HG7YnakUkjaxtNMclbl2t5sJwGLcHYsI # Download `data.tgz` file
!tar -xzf data.tgz # Extract
!mv cropped data # Rename extracted folder
!rm data.tgz

In [None]:
import imghdr
import os
from typing import Dict, List, Tuple

class Anime:
    """A single dataset sample, identified by the path of its image file."""

    def __init__(self, image_path: str):
        # Location of the image on disk; `make_slices` collects these paths.
        self.image_path = image_path

def _walk_children(path: str, full_path: bool, walk_dirs: bool) -> List[str]:
    """Return the sorted immediate children of ``path``.

    Args:
        path: Directory to list (not searched recursively).
        full_path: If true, each entry is joined onto ``path``; otherwise only
            the bare name is returned.
        walk_dirs: If true, list sub-directories; otherwise list files.
    """
    # Only the first tuple from os.walk is needed: the top directory itself.
    _, dir_names, file_names = next(os.walk(path))
    names = dir_names if walk_dirs else file_names
    if full_path:
        names = [os.path.join(path, name) for name in names]
    return sorted(names)

def collect_data(data_path: str = "data") -> List[Anime]:
    """Collect one ``Anime`` record per image file directly inside ``data_path``.

    Args:
        data_path: Directory holding the images (defaults to ``"data"``, the
            folder produced by the dataset download cell). Sub-directories
            are not descended into.

    Returns:
        ``Anime`` records in sorted filename order, skipping files that
        ``imghdr`` does not recognise as images.
    """
    # NOTE(review): `imghdr` is deprecated (PEP 594) and removed in Python 3.13;
    # replace this content sniffing when upgrading the runtime.
    return [
        Anime(image_path)
        for image_path in _walk_children(data_path, True, False)
        if imghdr.what(image_path) is not None
    ]

def make_slices(data_list: List[Anime]) -> Dict[str, List]:
    """Turn samples into the column dict consumed by ``tf.data.Dataset.from_tensor_slices``.

    Returns:
        ``{"image": [path, ...]}`` with one path per sample, in input order.
    """
    return {"image": [sample.image_path for sample in data_list]}

In [None]:
import tensorflow as tf

class Reader:
    """Turn dataset elements (dicts of file paths) into resized float images."""

    def __init__(self, image_size: Tuple[int, int]):
        # Target (height, width) handed to `tf.image.resize`.
        self.image_size = image_size

    def read(self, sources: Dict[str, tf.Tensor]) -> tf.Tensor:
        """``tf.data`` map function: load the image named by ``sources["image"]``."""
        image_path = sources["image"]
        return self._read_image(image_path)

    def _read_image(self, source: tf.Tensor) -> tf.Tensor:
        """Read, decode to 3 channels, resize, and divide pixel values by 255."""
        encoded = tf.io.read_file(source)
        # expand_animations=False yields a 3-D tensor even for animated files,
        # so every element has the same rank.
        decoded = tf.image.decode_image(encoded, channels=3, expand_animations=False)
        resized = tf.image.resize(decoded, self.image_size)
        return resized / 255

def generate_dataset(slices: Dict[str, List], reader: Reader, batch_size: int, shuffle: bool) -> tf.data.Dataset:
    """Build a (optionally shuffled) batched, prefetched pipeline over ``slices``.

    Args:
        slices: Column dict such as the one returned by ``make_slices``.
        reader: Provides ``read``, mapped over every element in parallel.
        batch_size: Elements per batch; the final batch may be smaller.
        shuffle: Whether to shuffle elements before decoding.
    """
    autotune = tf.data.experimental.AUTOTUNE
    dataset = tf.data.Dataset.from_tensor_slices(slices)
    if shuffle:
        # A buffer as large as the longest column covers the whole dataset.
        buffer_size = max(map(len, slices.values()))
        dataset = dataset.shuffle(buffer_size=buffer_size)
    dataset = dataset.map(reader.read, num_parallel_calls=autotune)
    dataset = dataset.batch(batch_size)
    return dataset.prefetch(autotune)

## Model

In [None]:
from flax import linen as fl

class Discriminator(fl.Module):
    """Map a batch of channels-last images to one unnormalised logit per image."""

    @fl.compact
    def __call__(self, x):
        # Five 3x3 stride-2 convolutions with ReLU, widening 32 -> 512 channels.
        # Submodules are created in the same order as before, so Flax's
        # auto-generated names (Conv_0 .. Conv_4) are unchanged.
        for features in (32, 64, 128, 256, 512):
            x = fl.relu(fl.Conv(features, (3, 3), strides=2)(x))
        # Flatten all non-batch axes and project to a single logit.
        flat = x.reshape((x.shape[0], -1))
        return fl.Dense(1)(flat)

class Generator(fl.Module):
    """Map latent vectors to 64x64 RGB images with pixel values in (0, 1)."""

    @fl.compact
    def __call__(self, x):
        # Project the latent vector and reshape it to a 2x2x512 feature map.
        x = fl.relu(fl.Dense(2 * 2 * 512)(x))
        x = x.reshape((x.shape[0], 2, 2, 512))
        # Each stride-2 transposed conv doubles H and W: 2 -> 4 -> 8 -> 16 -> 32.
        # Creation order matches the original, so ConvTranspose_0.. names match.
        for features in (256, 128, 64, 32):
            x = fl.relu(fl.ConvTranspose(features, (3, 3), strides=(2, 2))(x))
        # Last upsample to 64x64x3; sigmoid bounds pixels like Reader's /255 inputs.
        rgb = fl.ConvTranspose(3, (3, 3), strides=(2, 2))(x)
        return fl.sigmoid(rgb)

In [None]:
import jax
import jax.numpy as jnp
import numpy as np
from flax.training import train_state as ft

# Length of each noise vector fed to the Generator (training and sampling).
LATENT_SIZE = 128

def _bce_logit(x: jax.Array, y: np.ndarray) -> jax.Array:
    """Mean binary cross-entropy computed directly from logits.

    Args:
        x: Logits with a leading batch axis and one value per example; they
            are reshaped to ``(batch,)`` so a ``(batch, 1)`` output lines up
            with ``y``.
        y: Target labels of shape ``(batch,)`` (1 = real, 0 = fake).

    Returns:
        Scalar mean loss over the batch.
    """
    x = x.reshape(x.shape[0])
    # Algebraically equal to x - y*x + log(1 + exp(-x)), but that form overflows
    # exp(-x) to inf for large negative logits (inf loss, NaN gradients).
    # Splitting off max(x, 0) keeps the exponent non-positive.
    return jnp.mean(jnp.maximum(x, 0) - x * y + jnp.log1p(jnp.exp(-jnp.abs(x))))

# @jax.jit
def _disc_train_batch(
    disc_state: ft.TrainState,
    gen_state: ft.TrainState,
    real: np.ndarray,
    noise: np.ndarray,
) -> Tuple[ft.TrainState, jax.Array]:
    """Apply one discriminator update on real images plus generated fakes.

    Args:
        disc_state: Discriminator train state; the only state updated here.
        gen_state: Generator train state; used read-only to produce fakes.
        real: Batch of real images, labelled 1.
        noise: Latent vectors for the generator; the resulting fakes are labelled 0.

    Returns:
        The updated discriminator state and the batch loss.
    """
    # The generator's params are not differentiated here, so producing the
    # fakes outside the loss leaves the gradient unchanged.
    fake = gen_state.apply_fn({"params": gen_state.params}, noise)
    # jnp (not np) keeps the data on device and avoids pushing traced values
    # through NumPy, which fails under `jax.jit`.
    inputs = jnp.concatenate([real, fake])
    labels = jnp.concatenate([jnp.ones(real.shape[0]), jnp.zeros(fake.shape[0])])

    def _disc_loss(params):
        pred = disc_state.apply_fn({"params": params}, inputs)
        return _bce_logit(pred, labels)

    loss, grads = jax.value_and_grad(_disc_loss)(disc_state.params)
    disc_state = disc_state.apply_gradients(grads=grads)
    return disc_state, loss

# @jax.jit
def _gen_train_batch(
    disc_state: ft.TrainState, gen_state: ft.TrainState, noise: np.ndarray
) -> Tuple[ft.TrainState, jax.Array]:
    """Apply one generator update against the current (fixed) discriminator.

    Args:
        disc_state: Discriminator train state; used read-only.
        gen_state: Generator train state; the only state updated here.
        noise: Latent vectors fed to the generator.

    Returns:
        The updated generator state and the batch loss.
    """
    def _gen_loss(params):
        fake = gen_state.apply_fn({"params": params}, noise)
        pred = disc_state.apply_fn({"params": disc_state.params}, fake)
        # Non-saturating objective: score the fakes against the "real" label.
        return _bce_logit(pred, jnp.ones(pred.shape[0]))

    loss, grads = jax.value_and_grad(_gen_loss)(gen_state.params)
    gen_state = gen_state.apply_gradients(grads=grads)
    return gen_state, loss

def train_epoch(
        disc_state: ft.TrainState, gen_state: ft.TrainState, dataset
    ) -> Tuple[ft.TrainState, ft.TrainState, np.float32, np.float32]:
    """Run one pass over ``dataset``, alternating D and G updates per batch.

    For each batch of real images the discriminator is updated once, then the
    generator is updated once against that freshly updated discriminator.
    Noise is drawn per step and sized to the current batch, so a smaller final
    batch is handled correctly.

    Args:
        disc_state: Discriminator train state.
        gen_state: Generator train state.
        dataset: Iterable of real-image batches with a leading batch axis.

    Returns:
        Updated discriminator and generator states, and the mean discriminator
        and generator losses over the epoch (NaN if ``dataset`` was empty).
    """
    disc_losses = []
    gen_losses = []
    for real in dataset:
        batch_size = real.shape[0]
        disc_noise = np.random.normal(size=(batch_size, LATENT_SIZE)).astype(np.float32)
        disc_state, disc_loss = _disc_train_batch(disc_state, gen_state, real, disc_noise)
        gen_noise = np.random.normal(size=(batch_size, LATENT_SIZE)).astype(np.float32)
        gen_state, gen_loss = _gen_train_batch(disc_state, gen_state, gen_noise)
        disc_losses.append(disc_loss)
        gen_losses.append(gen_loss)
    if not disc_losses:
        # Nothing to average; avoid NumPy's mean-of-empty warning.
        nan = np.float32(np.nan)
        return disc_state, gen_state, nan, nan
    disc_epoch_loss = np.mean(jax.device_get(disc_losses))
    gen_epoch_loss = np.mean(jax.device_get(gen_losses))
    return disc_state, gen_state, disc_epoch_loss, gen_epoch_loss

`sigmoid`:
$$
  p(x)=\frac{1}{1+e^{-x}}
$$
`bce`:
$$
  l(p, y)=-(y\log(p)+(1-y)\log(1-p))
$$
`bce_logit`:
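$$
  l(x, y)=-(y\log(\frac{1}{1+e^{-x}})+(1-y)\log(1-\frac{1}{1+e^{-x}}))=x-yx+\log(1+e^{-x})
$$
Because $e^{-x}$ overflows for large negative $x$, the same loss is safer to evaluate in the algebraically equivalent form
$$
  l(x, y)=\max(x, 0)-yx+\log(1+e^{-|x|})
$$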

## Training

In [None]:
import optax
import tensorflow_datasets as tfds
from matplotlib import pyplot as plt

IMAGE_SIZE = (64, 64)  # (height, width) every image is resized to
BATCH_SIZE = 8
EPOCHS = 10

# Dataset: images from ./data, shuffled, decoded and batched as NumPy arrays.
data_list = collect_data()
data_slices = make_slices(data_list)
dataset = tfds.as_numpy(
    generate_dataset(data_slices, Reader(IMAGE_SIZE), BATCH_SIZE, True)
)

# Models and their optimiser states; both networks use Adam with lr 1e-3.
rng = jax.random.PRNGKey(0)
rng, disc_rng, gen_rng = jax.random.split(rng, 3)
disc_model = Discriminator()
gen_model = Generator()


def _create_state(model, init_rng, dummy_input):
    """Initialise ``model`` on ``dummy_input`` and wrap it in an Adam TrainState."""
    return ft.TrainState.create(
        apply_fn=model.apply,
        params=model.init(init_rng, dummy_input)["params"],
        tx=optax.adam(0.001),
    )


disc_state = _create_state(disc_model, disc_rng, jnp.zeros([1, *IMAGE_SIZE, 3]))
gen_state = _create_state(gen_model, gen_rng, jnp.zeros([1, LATENT_SIZE]))

In [None]:
def _generate():
    """Sample 12 images from the current global ``gen_state`` and plot them in a row."""
    latent = np.random.normal(size=(12, LATENT_SIZE))
    samples = gen_state.apply_fn({"params": gen_state.params}, latent)
    # Place the per-sample (H, W, C) images side by side along the width axis.
    strip = np.concatenate(list(samples), axis=1)
    plt.figure(figsize=(24, 2))
    plt.imshow(strip)
    plt.show()

# Training: show samples before any update, then after every epoch.
_generate()
for epoch in range(EPOCHS):
    disc_state, gen_state, disc_loss, gen_loss = train_epoch(
        disc_state, gen_state, dataset
    )
    print(
        f"Epoch {epoch}, disc_loss: {disc_loss:.4f}, gen_loss: {gen_loss:.4f}"
    )
    _generate()