In [1]:
import tensorflow as tf

In [2]:
z_dim = 512
input_dim = (256,256,3)

In [15]:
class NouveauVAE(tf.keras.Model):

    def __init__(self, z_dim, input_dim):
        super(NouveauVAE, self).__init__()
        self.z_dim = z_dim
        self.input_dim = input_dim

        # encoder
        self.enc = tf.keras.Sequential()

        self.enc.add(tf.keras.layers.Input(shape=self.input_dim))

        encoder_stack = [
            self.encoder_cell([self.z_dim // 16, self.z_dim // 8]),
            self.encoder_cell([self.z_dim // 4, self.z_dim // 2]),
            self.encoder_cell([self.z_dim])
        ]
        encoder_res_stack = [
            self.encoder_residual_cell(self.z_dim // 8),
            self.encoder_residual_cell(self.z_dim // 2),
            self.encoder_residual_cell(self.z_dim)
        ]

        for e, r in zip(encoder_stack, encoder_res_stack):
            x = r(e)
            self.enc.add(x)

        self.condition_x = tf.keras.layers.Sequential([
            tf.keras.layers.GlobalAveragePooling2D(),
            tf.keras.layers.Activation('swish'),
            tf.keras.layers.Conv2D(self.z_dim * 2)
        ])

        self.enc.add(self.condition_x)
        
        # decoder
        self.dec = tf.keras.Sequential()
        self.dec.add(tf.keras.layers.Input(shape=self.z_dim * 2))
        decoder_stack = [
            self.decoder_cell([z_dim // 2]),
            self.decoder_cell([z // 4, z // 8]),
            self.decoder_cell([z // 16, z // 32])
        ]
        decoder_res_stack = [
            self.decoder_residual_cell(z_dim // 2, e=1),
            self.decoder_residual_cell(z_dim // 8, e=2),
            self.decoder_residual_cell(z_dim // 32, e=4),
        ]

        for d, r in zip(decoder_stack, decoder_res_stack):
            x = r(d)
            self.dec.add(x)

        self.x_hat = tf.keras.layers.Conv2D(3, kernel_size=1)

        self.dec.add(self.x_hat)
    

    def conv_block(self, c):
        cb = tf.keras.Sequential([
            tf.keras.layers.Conv2D(
                c // 2,
                kernel_size=3,
                padding='same'
            ),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Conv2D(
                c,
                kernel_size=3,
                padding='same'
            ),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation('swish')
        ])
        return cb

    def encoder_cell(self, cs):
        ec = tf.keras.Sequential()
        for cs_ in cs:
            ec.add(self.conv_block(cs_))
        return ec

    def upsample_block(self, c):
        ub = tf.keras.Sequential([
            tf.keras.layers.Conv2DTranspose(
                c,
                kernel_size=3,
                stride=2
            ),
            tf.keras.layers.BatchNormalization()
        ])
        return ub

    def decoder_cell(self, cs):
        dc = tf.keras.Sequential()
        for cs_ in cs:
            dc.add(self.upsample_block(cs_))
        return dc

    def squeeze_excite(self, c, r=16):
        se = tf.keras.Sequential([
            tf.keras.layers.GlobalAveragePooling2D(),
            tf.keras.layers.Dense(c // r, bias=False),
            tf.keras.layers.Activation('relu'),
            tf.keras.layers.Dense(c, bias=False),
            tf.keras.layers.Activation('sigmoid')])
        return se

    def encoder_residual_cell(self, c):
        erc = tf.keras.Sequential([
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Conv2D(
                c,
                kernel_size=3,
                padding='same'),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation('swish'),
            tf.keras.layers.Conv2D(
                c,
                kernel_size=3,
                padding='same')])
        return erc

    def decoder_residual_cell(self, c, e):
        drc = tf.keras.Sequential([
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Conv2D(
                c * e,
                kernel_size=1,
                padding='same'),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation('swish'),
            tf.keras.layers.DepthwiseConv2D(
                c * e,
                kernel_size=5,
                stride=1),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation('swish'),
            tf.keras.layers.Conv2D(
                c,
                kernel_size=1,
                padding='same',
                use_bias=False,
                activation=None),
            tf.keras.layers.BatchNormalization(),
            self.squeeze_excite(c)])

        return drc

    def sample(self, eps=None):
        if eps is None:
            eps = tf.random.normal(shape=(100, self.z_dim))
        return self.decode(eps, appy_sigmoid=True)

    def encode(self, x):
        mu, logvar = tf.split(self.encoder(x), num_or_size_of_splits=2, axis=1)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        eps = tf.random.normal(shape=mu.shape)
        return eps * tf.exp(logvar * .5) + mean
    
    def decode(self, z, apply_sigmoid=False):
        logits = self.decode(z)
        if apply_sigmoid:
            probs = tf.sigmoid(logits)
            return probs
        return logits

In [16]:
optimizer = tf.keras.optimizers.Adam(1e-4)


def log_normal_pdf(sample, mean, logvar, raxis=1):
  log2pi = tf.math.log(2. * np.pi)
  return tf.reduce_sum(
      -.5 * ((sample - mean) ** 2. * tf.exp(-logvar) + logvar + log2pi),
      axis=raxis)


def compute_loss(model, x):
  mean, logvar = model.encode(x)
  z = model.reparameterize(mean, logvar)
  x_logit = model.decode(z)
  cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(logits=x_logit, labels=x)
  logpx_z = -tf.reduce_sum(cross_ent, axis=[1, 2, 3])
  logpz = log_normal_pdf(z, 0., 0.)
  logqz_x = log_normal_pdf(z, mean, logvar)
  return -tf.reduce_mean(logpx_z + logpz - logqz_x)


@tf.function
def train_step(model, x, optimizer):
  """Executes one training step and returns the loss.

  This function computes the loss and gradients, and uses the latter to
  update the model's parameters.
  """
  with tf.GradientTape() as tape:
    loss = compute_loss(model, x)
  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

In [17]:
epochs = 10
# set the dimensionality of the latent space to a plane for visualization later
num_examples_to_generate = 16

# keeping the random vector constant for generation (prediction) so
# it will be easier to see the improvement.
random_vector_for_generation = tf.random.normal(
    shape=[num_examples_to_generate, z_dim])
model = NouveauVAE(z_dim, input_dim)

TypeError: The added layer must be an instance of class Layer. Found: <bound method NouveauVAE.squeeze_excite of <__main__.NouveauVAE object at 0x7fc8ea882d60>>