In [1]:
# from https://objax.readthedocs.io/en/latest/notebooks/Custom_Networks.html
import os
from tqdm import tqdm

import numpy as np
import tensorflow_datasets as tfds

import jax.numpy as jnp

import objax
from objax.util import EasyDict
from objax.zoo.dnnet import DNNet

In [2]:
DATA_DIR = os.path.join(os.environ['HOME'], 'TFDS')
data = tfds.as_numpy(
    tfds.load(name='mnist', batch_size=-1, data_dir=DATA_DIR))

def prepare(x):
    """Pads 2 pixels to the left, right, top, and bottom of each image, scales pixel value to [-1, 1], and converts to NCHW format."""
    s = x.shape
    x_pad = np.zeros((s[0], 32, 32, 1))
    x_pad[:, 2:-2, 2:-2, :] = x
    return objax.util.image.nchw(
        np.concatenate([x_pad.astype('f') * (1 / 127.5) - 1] * 3, axis=-1))

train = EasyDict(image=prepare(
    data['train']['image']), label=data['train']['label'])
test = EasyDict(image=prepare(
    data['test']['image']), label=data['test']['label'])
ndim = train.image.shape[-1]

del data

In [3]:
from objax import random

def normal_0_1(shape):
    return random.normal(shape, mean=0.0, stddev=1.0)

In [4]:
class Generator(objax.Module):
    def __init__(self):
        self.conv_block_1 = objax.nn.Sequential([
            objax.nn.ConvTranspose2D(100, 64*8, k=4, strides=1, padding=objax.constants.ConvPadding.VALID, use_bias=False, w_init=normal_0_1),
            objax.nn.BatchNorm2D(64*8)
        ])

        self.conv_block_2 = objax.nn.Sequential([
            objax.nn.ConvTranspose2D(64*8, 64*4, k=4, strides=2, padding=objax.constants.ConvPadding.SAME, use_bias=False, w_init=normal_0_1),
            objax.nn.BatchNorm2D(64*4)
        ])

        self.conv_block_3 = objax.nn.Sequential([
            objax.nn.ConvTranspose2D(64*4, 64*2, k=4, strides=2, padding=objax.constants.ConvPadding.SAME, use_bias=False, w_init=normal_0_1),
            objax.nn.BatchNorm2D(64*2)
        ])

        self.conv_block_4 = objax.nn.Sequential([
            objax.nn.ConvTranspose2D(64*2, 64, k=4, strides=2, padding=objax.constants.ConvPadding.SAME, use_bias=False, w_init=normal_0_1),
            objax.nn.BatchNorm2D(64)
        ])

        self.out_conv = objax.nn.ConvTranspose2D(64, 3, k=4, strides=2, padding=objax.constants.ConvPadding.SAME, w_init=normal_0_1, use_bias=False)

    def __call__(self, x, training):
        x = self.conv_block_1(x, training=training)
        x = objax.functional.relu(x)

        x = self.conv_block_2(x, training=training)
        x = objax.functional.relu(x)

        x = self.conv_block_3(x, training=training)
        x = objax.functional.relu(x)

        x = self.conv_block_4(x, training=training)
        x = objax.functional.relu(x)

        x = self.out_conv(x)
        x = objax.functional.tanh(x)

        return x

In [5]:
class Discriminator(objax.Module):
    def __init__(self):
        self.conv_block_1 = objax.nn.Conv2D(3, 64, k=4, strides=2, padding=objax.constants.ConvPadding.SAME, use_bias=False, w_init=normal_0_1)

        self.conv_block_2 = objax.nn.Sequential([
            objax.nn.Conv2D(64, 64*2, k=4, strides=2, padding=objax.constants.ConvPadding.SAME, use_bias=False, w_init=normal_0_1),
            objax.nn.BatchNorm2D(64*2)
        ])

        self.conv_block_3 = objax.nn.Sequential([
            objax.nn.Conv2D(64*2, 64*4, k=4, strides=2, padding=objax.constants.ConvPadding.SAME, use_bias=False, w_init=normal_0_1),
            objax.nn.BatchNorm2D(64*4)
        ])

        self.conv_block_4 = objax.nn.Sequential([
            objax.nn.Conv2D(64*4, 64*8, k=4, strides=2, padding=objax.constants.ConvPadding.SAME, use_bias=False, w_init=normal_0_1),
            objax.nn.BatchNorm2D(64*8)
        ])

        self.out_conv = objax.nn.Conv2D(64*8, 1, k=4, strides=1, padding=objax.constants.ConvPadding.VALID, w_init=normal_0_1, use_bias=False)

    def __call__(self, x, training):
        x = self.conv_block_1(x)
        x = objax.functional.leaky_relu(x, 0.02)

        x = self.conv_block_2(x, training=training)
        x = objax.functional.leaky_relu(x, 0.02)

        x = self.conv_block_3(x, training=training)
        x = objax.functional.leaky_relu(x, 0.02)

        x = self.conv_block_4(x, training=training)
        x = objax.functional.leaky_relu(x, 0.02)

        x = self.out_conv(x)
        x = objax.functional.sigmoid(x)

        x = jnp.reshape(x, [-1, 1])

        return x

In [6]:
generator = Generator()
discriminator = Discriminator()

In [18]:
z = random.normal([1, 100, 1, 1])
img = random.normal([1, 3, 64, 64])

In [7]:
lr = 0.03  # learning rate
batch = 4
epochs = 10

In [41]:
def train_model(generator, discriminator):

    g_opt = objax.optimizer.Momentum(generator.vars())
    d_opt = objax.optimizer.Momentum(discriminator.vars())

    def d_loss(x, z):
        d_loss_real = objax.functional.loss.cross_entropy_logits(discriminator(x, training=True), jnp.ones_like(x)).mean()

        fake_img = generator(z, training=False)
        d_loss_fake = objax.functional.loss.cross_entropy_logits(discriminator(fake_img, training=True), jnp.zeros_like(x)).mean()

        d_loss = d_loss_real + d_loss_fake

        return d_loss

    def g_loss(x, z):
        fake_img = generator(z, training=True)
        return objax.functional.loss.cross_entropy_logits_sparse(discriminator(fake_img, training=False), jnp.ones_like(x)).mean()

    d_gv = objax.GradValues(d_loss, discriminator.vars())
    g_gv = objax.GradValues(g_loss, generator.vars())

    def d_train_op(x, z):        
        g, v = d_gv(x, z)  # returns gradients, loss
        d_opt(lr, g)
        return v

    def g_train_op(x, z):        
        g, v = g_gv(x, z)  # returns gradients, loss
        g_opt(lr, g)
        return v

    d_train_op = objax.Jit(d_train_op, d_gv.vars() + d_opt.vars())
    g_train_op = objax.Jit(g_train_op, g_gv.vars() + g_opt.vars())

    for epoch in range(epochs):
        d_avg_loss = 0
        g_avg_loss = 0

        shuffle_idx = np.random.permutation(train.image.shape[0])
        for it in tqdm(range(0, train.image.shape[0], batch)):
            sel = shuffle_idx[it: it + batch]

            z = random.normal([batch, 100, 1, 1])
            img = train.image[sel]
            
            g_avg_loss += float(g_train_op(train.image[sel], z)[0]) * len(sel)
            d_avg_loss += float(d_train_op(train.image[sel], z)[0]) * len(sel)


        d_avg_loss /= it + len(sel)
        g_avg_loss /= it + len(sel)

        print('Epoch %04d d Loss %.2f g Loss %.2f' % (epoch + 1, d_avg_loss, g_avg_loss))

In [42]:
train_model(generator, discriminator)

  0%|          | 0/15000 [00:00&lt;?, ?it/s]


UnexpectedTracerError: Encountered an unexpected tracer. Perhaps this tracer escaped through global state from a previously traced function.
The functions being transformed should not save traced values to global state. Detail: tracer created on line /Users/bilal/miniconda3/envs/pytorch/lib/python3.7/site-packages/objax/variable.py:181 (split).

In [124]:
out = generator(x, training=True)

In [125]:
out = discriminator(img, training=True)

In [126]:
jnp.shape(out)

(1, 1)

In [56]:
out.shape

(1, 64, 13, 13)

In [48]:
print(generator.vars())

(Generator).conv_block_1(Sequential)[0](ConvTranspose2D).w          819200 (4, 4, 512, 100)
(Generator).conv_block_1(Sequential)[1](BatchNorm2D).running_mean      512 (1, 512, 1, 1)
(Generator).conv_block_1(Sequential)[1](BatchNorm2D).running_var       512 (1, 512, 1, 1)
(Generator).conv_block_1(Sequential)[1](BatchNorm2D).beta              512 (1, 512, 1, 1)
(Generator).conv_block_1(Sequential)[1](BatchNorm2D).gamma             512 (1, 512, 1, 1)
(Generator).conv_block_2(Sequential)[0](ConvTranspose2D).w         2097152 (4, 4, 256, 512)
(Generator).conv_block_2(Sequential)[1](BatchNorm2D).running_mean      256 (1, 256, 1, 1)
(Generator).conv_block_2(Sequential)[1](BatchNorm2D).running_var       256 (1, 256, 1, 1)
(Generator).conv_block_2(Sequential)[1](BatchNorm2D).beta              256 (1, 256, 1, 1)
(Generator).conv_block_2(Sequential)[1](BatchNorm2D).gamma             256 (1, 256, 1, 1)
(Generator).conv_block_3(Sequential)[0](ConvTranspose2D).w          524288 (4, 4, 128, 256)
(Gen