In [1]:
import tensorflow as tf

In [2]:
from VisionEngine.utils.config import process_config

In [3]:
class BaseModel(object):
    """Minimal base class pairing a config object with a lazily built model.

    ``self.model`` starts out as ``None``; subclasses are expected to assign
    it and to override ``load`` and ``build_model``.
    """

    def __init__(self, config):
        # No model exists until a subclass creates one.
        self.model = None
        self.config = config

    def save(self, checkpoint_path):
        """Write the weights of ``self.model`` to ``checkpoint_path``.

        Raises:
            Exception: if ``self.model`` has not been assigned yet.
        """
        model = self.model
        if model is not None:
            print("Saving model...")
            model.save_weights(checkpoint_path)
            print("Model saved")
            return
        raise Exception("You need to build the model first.")

    def load(self, checkpoint_path):
        """Restore weights from ``checkpoint_path``; must be overridden."""
        raise NotImplementedError

    def build_model(self):
        """Construct the underlying model; must be overridden."""
        raise NotImplementedError

        
class SqueezeExcite(tf.keras.layers.Layer):
    """Squeeze-and-excitation gate mapping a feature map to per-channel weights.

    The layer global-average-pools its input and passes the pooled vector
    through ``Dense(c // r) -> relu -> Dense(c) -> sigmoid``.  The resulting
    ``(batch, c)`` tensor is returned as-is; it is NOT multiplied back into
    the input feature map.

    Args:
        c: number of units of the final Dense layer (size of the output).
        r: reduction ratio of the bottleneck Dense layer.
    """

    def __init__(self, c, r=16, **kwargs):
        super(SqueezeExcite, self).__init__(**kwargs)
        self.c = c
        self.r = r

    def build(self, input_shape):
        # Clamp the bottleneck to at least one unit: ``c // r`` is 0 whenever
        # c < r, which is not a usable Dense width.
        bottleneck_units = max(1, self.c // self.r)
        self.se = tf.keras.Sequential([
            tf.keras.layers.GlobalAveragePooling2D(),
            tf.keras.layers.Dense(bottleneck_units, use_bias=False),
            tf.keras.layers.Activation('relu'),
            tf.keras.layers.Dense(self.c, use_bias=False),
            tf.keras.layers.Activation('sigmoid')])
        super(SqueezeExcite, self).build(input_shape)

    def call(self, layer_inputs, **kwargs):
        """Return the ``(batch, c)`` sigmoid gate computed from ``layer_inputs``."""
        return self.se(layer_inputs)

    def compute_output_shape(self, input_shape):
        # The spatial axes are pooled away, so only the batch axis survives
        # and the last Dense layer fixes the width to ``c``.
        batch = tf.TensorShape(input_shape)[0]
        return tf.TensorShape([batch, self.c])

    def get_config(self):
        """Return the constructor arguments merged into the base layer config."""
        config = {
            'c': self.c,
            'r': self.r
        }
        base_config = \
            super(SqueezeExcite, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


class SpectralNormalization(tf.keras.layers.Wrapper):
    """Wrapper that divides the wrapped layer's kernel by a spectral-norm estimate.

    On every call the kernel is flattened to 2D, ``sigma`` is estimated with
    power iteration, the wrapped layer is run with ``kernel / sigma`` and the
    original kernel value is written back afterwards.

    Args:
        layer: a ``tf.keras.layers.Layer`` exposing a ``kernel`` weight.
        iteration: number of power-iteration steps per call.
        eps: constant added to the norms to avoid division by zero.
        training: when False the stored ``u``/``v`` vectors are used without
            running power iteration.

    Raises:
        ValueError: if ``layer`` is not a Keras layer, or (at build time) if
            it has no ``kernel`` attribute.
    """

    def __init__(self, layer, iteration=1, eps=1e-12, training=True, **kwargs):
        if not isinstance(layer, tf.keras.layers.Layer):
            raise ValueError(
                'Please initialize `SpectralNormalization` layer with a '
                '`Layer` instance. You passed: {input}'.format(input=layer))
        super(SpectralNormalization, self).__init__(layer, **kwargs)
        self.iteration = iteration
        self.eps = eps
        self.do_power_iteration = training

    def build(self, input_shape):
        if not self.layer.built:
            self.layer.build(input_shape)

        if not hasattr(self.layer, 'kernel'):
            raise ValueError(
                '`SpectralNormalization` requires a layer with a `kernel` '
                'weight. You passed: {input}'.format(input=self.layer))

        self.w = self.layer.kernel
        self.w_shape = self.w.shape.as_list()

        # Number of rows of the kernel once flattened to (-1, out_channels).
        # Computed from every leading axis so kernels of any rank work, not
        # only rank-4 Conv2D kernels.
        flat_rows = 1
        for dim in self.w_shape[:-1]:
            flat_rows *= dim

        self.v = self.add_weight(shape=(1, flat_rows),
                                 initializer=tf.initializers.TruncatedNormal(stddev=0.02),
                                 trainable=False,
                                 name='sn_v',
                                 dtype=tf.float32)

        self.u = self.add_weight(shape=(1, self.w_shape[-1]),
                                 initializer=tf.initializers.TruncatedNormal(stddev=0.02),
                                 trainable=False,
                                 name='sn_u',
                                 dtype=tf.float32)

        super(SpectralNormalization, self).build()

    def call(self, inputs):
        original_kernel = self.update_weights()
        output = self.layer(inputs)
        # Restore weights because of this formula "W = W - alpha * W_SN`"
        self.restore_weights(original_kernel)
        return output

    def update_weights(self):
        """Normalise the wrapped kernel in place and return its previous value.

        Returns:
            A tensor snapshot of the kernel taken before normalisation, to be
            handed to ``restore_weights``.
        """
        # ``self.w`` IS the layer's kernel variable, so its value has to be
        # copied before the assign below; otherwise there is nothing left to
        # restore and the kernel shrinks permanently on every call.
        original_kernel = tf.identity(self.w)
        w_reshaped = tf.reshape(original_kernel, [-1, self.w_shape[-1]])

        u_hat = self.u
        v_hat = self.v  # init v vector

        if self.do_power_iteration:
            for _ in range(self.iteration):
                v_ = tf.matmul(u_hat, tf.transpose(w_reshaped))
                v_hat = v_ / (tf.reduce_sum(v_**2)**0.5 + self.eps)

                u_ = tf.matmul(v_hat, w_reshaped)
                u_hat = u_ / (tf.reduce_sum(u_**2)**0.5 + self.eps)

        sigma = tf.matmul(tf.matmul(v_hat, w_reshaped), tf.transpose(u_hat))
        self.u.assign(u_hat)
        self.v.assign(v_hat)

        self.layer.kernel.assign(original_kernel / sigma)
        return original_kernel

    def restore_weights(self, original_kernel=None):
        """Write ``original_kernel`` back into the wrapped layer's kernel.

        Called without an argument it re-assigns the kernel's current value,
        which is what the zero-argument form always did.
        """
        value = self.w if original_kernel is None else original_kernel
        self.layer.kernel.assign(value)

    def get_config(self):
        """Return the wrapper arguments merged into the base Wrapper config."""
        config = {
            'iteration': self.iteration,
            'eps': self.eps,
            'training': self.do_power_iteration,
        }
        base_config = super(SpectralNormalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


class PerceptualLossLayer(tf.keras.layers.Layer):
    """Pass-through layer that adds a feature-space reconstruction loss.

    Inputs are ``[y_true, y_pred]``.  Both are run through a frozen VGG16
    whose selected layer outputs are batch-normalised; the weighted squared
    differences of those features are registered with ``add_loss`` and
    ``add_metric``.  The inputs are returned unchanged.

    Args:
        perceptual_loss_model: backbone identifier; only ``'vgg'`` is handled.
        pereceptual_loss_layers: indices into the backbone's ``layers`` list.
        perceptual_loss_layer_weights: one weight per selected layer.
        model_input_shape: input shape handed to the backbone; ``None`` falls
            back to ``[256, 256, 3]``.
        name: optional layer name.
    """

    def __init__(self, perceptual_loss_model,
                 pereceptual_loss_layers, perceptual_loss_layer_weights,
                 model_input_shape, name=None, **kwargs):
        # Forward ``name`` so the layer is actually given the requested name.
        super(PerceptualLossLayer, self).__init__(name=name, **kwargs)
        self.loss_model_type = perceptual_loss_model
        self.layers = pereceptual_loss_layers
        self.layer_weights = perceptual_loss_layer_weights
        # Honour the caller's shape; keep the former hard-coded value as the
        # fallback for callers that passed None.
        if model_input_shape is None:
            self.model_input_shape = [256, 256, 3]
        else:
            self.model_input_shape = list(model_input_shape)

    def build(self, input_shape):
        if self.loss_model_type == 'vgg':
            self.loss_model_ = tf.keras.applications.VGG16(
                weights='imagenet',
                include_top=False,
                input_shape=self.model_input_shape
                )
            self.loss_model_.trainable = False

            for layer in self.loss_model_.layers:
                layer.trainable = False

            # NOTE(review): these BatchNormalization layers are newly created
            # and are not frozen like the backbone — confirm that is intended.
            self.loss_layers = [
                tf.keras.layers.BatchNormalization()(self.loss_model_.layers[i].output)
                for i in self.layers
                ]

            self.loss_model = tf.keras.Model(
                self.loss_model_.inputs,
                self.loss_layers,
                )
        else:
            raise NotImplementedError
        super(PerceptualLossLayer, self).build(input_shape)

    def call(self, layer_inputs, **kwargs):
        y_true = layer_inputs[0]
        y_pred = layer_inputs[1]

        self.sample_ = self.loss_model(y_true)
        self.reconstruction_ = self.loss_model(y_pred)

        # A single selected layer makes the model return a bare tensor;
        # wrap it so the loop below always iterates over feature maps.
        if not isinstance(self.sample_, (list, tuple)):
            self.sample_ = [self.sample_]
        if not isinstance(self.reconstruction_, (list, tuple)):
            self.reconstruction_ = [self.reconstruction_]

        self.perceptual_loss = 0.
        for i in range(len(self.reconstruction_)):
            shape = tf.cast(tf.shape(self.reconstruction_[i]), dtype='float32')
            # Normalise by height * width * channels of the feature map.
            feature_size = shape[1] * shape[2] * shape[3]
            self.perceptual_loss += tf.math.reduce_mean(
                self.layer_weights[i] *
                (tf.math.reduce_sum(tf.math.square(self.sample_[i] - self.reconstruction_[i])) /
                    feature_size)
            )

        perceptual_loss = tf.cast(self.perceptual_loss, dtype='float32')
        self.add_loss(perceptual_loss)
        # Keywords keep this valid whether ``aggregation`` or ``name`` is the
        # second positional parameter of ``add_metric``.
        self.add_metric(perceptual_loss, aggregation='mean', name='perceptual_loss')

        return [y_true, y_pred]

    def compute_output_shape(self, input_shape):
        # The layer returns its inputs untouched.
        return input_shape

    def get_config(self):
        """Return the constructor arguments merged into the base layer config."""
        config = {
            'perceptual_loss_model': self.loss_model_type,
            'pereceptual_loss_layers': self.layers,
            'perceptual_loss_layer_weights':
                self.layer_weights,
            'model_input_shape': self.model_input_shape,
        }
        base_config = \
            super(PerceptualLossLayer, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


class NormalVariational(tf.keras.layers.Layer):
    """Gaussian latent layer with optional KL and MMD regularisers.

    Two Dense heads produce ``mu`` and ``log_sigma``; the layer returns the
    reparameterised sample ``mu + exp(0.5 * log_sigma) * eps``.  Depending on
    the flags, a KL divergence to the prior and/or an RBF-kernel maximum mean
    discrepancy against prior samples is registered via ``add_loss`` and
    ``add_metric``.

    Args:
        size: width of the latent vector.
        mu_prior: prior mean, broadcast to ``(size,)``.
        sigma_prior: prior standard deviation, broadcast to ``(size,)``.
        use_kl: add the KL term.
        kl_coef: initial value of the non-trainable KL weight variable.
        use_mmd: add the MMD term.
        mmd_coef: weight of the MMD term.
        name: optional layer name.
    """

    def __init__(self, size=2, mu_prior=0., sigma_prior=1.,
                    use_kl=False, kl_coef=1.0,
                    use_mmd=True, mmd_coef=100.0, name=None, **kwargs):
        # Forward ``name`` so the layer is actually given the requested name.
        super().__init__(name=name, **kwargs)
        # Raw constructor arguments, kept so get_config() can round-trip.
        self.size = size
        self._mu_prior_arg = mu_prior
        self._sigma_prior_arg = sigma_prior
        self._kl_coef_arg = kl_coef

        self.mu_layer = tf.keras.layers.Dense(size)
        self.sigma_layer = tf.keras.layers.Dense(size)
        if use_kl is True:
            self.kl_coef = tf.Variable(kl_coef, trainable=False, name='kl_coef')
        self.mu_prior = tf.constant(mu_prior, dtype=tf.float32, shape=(size,))
        self.sigma_prior = tf.constant(
            sigma_prior, dtype=tf.float32, shape=(size,)
            )

        self.use_kl = use_kl
        self.use_mmd = use_mmd
        self.mmd_coef = mmd_coef
        self.kernel_f = self._rbf

    def _rbf(self, x, y):
        """Return the ``(len(x), len(y))`` matrix of RBF kernel values."""
        x_size = tf.shape(x)[0]
        y_size = tf.shape(y)[0]
        dim = tf.shape(x)[1]
        tiled_x = tf.tile(tf.reshape(x, tf.stack([x_size, 1, dim])),
                          tf.stack([1, y_size, 1]))

        tiled_y = tf.tile(tf.reshape(y, tf.stack([1, y_size, dim])),
                          tf.stack([x_size, 1, 1]))

        return tf.exp(-tf.reduce_mean(tf.square(tiled_x - tiled_y), axis=2) /
                      tf.cast(dim, tf.float32))

    def use_kl_divergence(self, q_mu, q_sigma, p_mu, p_sigma):
        """Register ``kl_coef * KL(q || p)`` for diagonal Gaussians as loss and metric."""
        r = q_mu - p_mu
        kl = self.kl_coef * tf.reduce_mean(
            tf.reduce_sum(
                tf.math.log(p_sigma) -
                tf.math.log(q_sigma) -
                .5 * (1. - (q_sigma**2 + r**2) / p_sigma**2), axis=1
                )
            )
        self.add_loss(kl)
        self.add_metric(kl, aggregation='mean', name='kl_divergence')

    def add_mm_discrepancy(self, z, z_prior):
        """Register ``mmd_coef * MMD(z, z_prior)`` as loss and metric."""
        k_prior = self.kernel_f(z_prior, z_prior)
        k_post = self.kernel_f(z, z)
        k_prior_post = self.kernel_f(z_prior, z)
        mmd = tf.reduce_mean(k_prior) + \
            tf.reduce_mean(k_post) - \
            2 * tf.reduce_mean(k_prior_post)

        mmd = tf.multiply(self.mmd_coef,  mmd, name='mmd')
        self.add_loss(mmd)
        self.add_metric(mmd, aggregation='mean', name='mmd_discrepancy')

    def call(self, inputs):
        # Computed unconditionally so ``z`` exists whatever the flags are.
        mu = self.mu_layer(inputs)
        log_sigma = self.sigma_layer(inputs)
        sigma = tf.exp(log_sigma * 0.5)
        # Reparameterisation: scale the noise by sigma, not by log_sigma.
        z = mu + sigma * tf.random.normal(shape=tf.shape(sigma))

        if self.use_mmd:
            # One diagonal-Gaussian prior sample per element of the batch,
            # drawn with core TensorFlow so no extra package is required.
            z_prior = self.mu_prior + \
                self.sigma_prior * tf.random.normal(shape=tf.shape(z))
            self.add_mm_discrepancy(z, z_prior)

        if self.use_kl:
            self.use_kl_divergence(
                mu,
                sigma,
                self.mu_prior,
                self.sigma_prior)

        return z

    def get_config(self):
        """Return the constructor arguments merged into the base layer config.

        ``kl_coef`` is reported as the value given at construction time, not
        the current value of the variable.
        """
        base_config = super(NormalVariational, self).get_config()
        config = {
            'size': self.size,
            'mu_prior': self._mu_prior_arg,
            'sigma_prior': self._sigma_prior_arg,
            'use_kl': self.use_kl,
            'kl_coef': self._kl_coef_arg,
            'use_mmd': self.use_mmd,
            'mmd_coef': self.mmd_coef,
        }

        return dict(list(base_config.items()) + list(config.items()))


class SaltAndPepper(tf.keras.layers.Layer):
    """Training-time noise layer that zeroes some values and adds ones to others.

    In the training phase the input is multiplied by a binary keep-mask drawn
    with probability ``ratio`` and a second binary mask drawn with probability
    0.1 is added.  Outside the training phase the input passes through
    unchanged.

    Args:
        ratio: probability that an element is kept by the multiplicative mask.
    """

    def __init__(self, ratio=0.9, **kwargs):
        # get_config() emits 'masking', which the base Layer constructor does
        # not accept; drop it so from_config(get_config()) round-trips.
        kwargs.pop('masking', None)
        super(SaltAndPepper, self).__init__(**kwargs)
        self.masking = True
        self.ratio = ratio

    def call(self, inputs, training=None):
        def noised():
            # Masks use the per-sample shape (batch axis dropped), so the same
            # pattern is broadcast across every element of the batch.
            shp = tf.keras.backend.shape(inputs)[1:]
            mask_select = tf.keras.backend.random_binomial(
                shape=shp, p=self.ratio)

            # salt and pepper have the same chance
            mask_noise = tf.keras.backend.random_binomial(shape=shp, p=0.1)
            out = (inputs * (mask_select)) + mask_noise
            return out

        return tf.keras.backend.in_train_phase(
            noised, inputs, training=training)

    def get_config(self):
        """Return ``ratio`` and ``masking`` merged into the base layer config."""
        config = {'ratio': self.ratio,
                  'masking': self.masking}
        base_config = super(SaltAndPepper, self).get_config()

        return dict(list(base_config.items()) + list(config.items()))


class Encoder(BaseModel):
    """Four-stage convolutional encoder emitting one vector per stage.

    Each stage is a ``Sequential`` of spectrally normalised 3x3 convolutions
    followed by average pooling; the stage output is also fed to a
    ``SqueezeExcite`` layer whose result becomes one of the four model
    outputs.  When ``config.model.denoise`` is True the input is first passed
    through salt-and-pepper and Gaussian noise layers.

    Attributes set by ``make_encoder``: ``encoder_inputs``,
    ``encoder_outputs`` and ``encoder`` (the ``tf.keras.Model``).
    """

    # (name scope, Sequential name, ((filters, use_batch_norm), ...)) per stage.
    # NOTE(review): the third convolution of h_2 has no BatchNormalization,
    # unlike every other unit. Kept as-is so parameter counts and existing
    # checkpoints stay valid — confirm whether it is intentional.
    _STAGES = (
        ('z_1', 'h_1', ((8, True), (16, True), (16, True), (32, True))),
        ('z_2', 'h_2', ((32, True), (64, True), (64, False), (128, True))),
        ('z_3', 'h_3', ((128, True), (256, True), (256, True), (512, True))),
        ('z_4', 'h_4', ((512, True), (1024, True), (1024, True), (2048, True))),
    )

    def __init__(self, config):
        super(Encoder, self).__init__(config)
        self.make_encoder()

    def _make_stage(self, name, units, input_shape=None):
        """Build one stage: [Input] + (SN-conv, [BN], activation)* + AveragePooling2D."""
        activation = self.config.model.encoder_activations
        stack = []
        if input_shape is not None:
            stack.append(tf.keras.layers.Input(input_shape))
        for filters, use_batch_norm in units:
            stack.append(SpectralNormalization(tf.keras.layers.Conv2D(
                filters, 3, padding='same')))
            if use_batch_norm:
                stack.append(tf.keras.layers.BatchNormalization())
            stack.append(tf.keras.layers.Activation(activation))
        stack.append(tf.keras.layers.AveragePooling2D())
        return tf.keras.Sequential(stack, name=name)

    def make_encoder(self):
        """Assemble the encoder graph and store it on ``self.encoder``."""
        model_cfg = self.config.model
        with tf.name_scope('encoder'):
            self.encoder_inputs = tf.keras.layers.Input(
                shape=model_cfg.input_shape, name='input')

            features = self.encoder_inputs
            if model_cfg.denoise is True:
                with tf.name_scope('noise_layer'):
                    noise_layers = tf.keras.Sequential([
                        SaltAndPepper(),
                        tf.keras.layers.GaussianNoise(model_cfg.noise_ratio)
                        ], name='noise_layer')

                    # The first stage now really consumes the noisy tensor;
                    # a misspelt config attribute used to bypass it.
                    features = noise_layers(self.encoder_inputs)

            self.encoder_outputs = []
            for index, (scope, stage_name, units) in enumerate(self._STAGES):
                with tf.name_scope(scope):
                    # Only the first stage declares an explicit Input layer.
                    stage = self._make_stage(
                        stage_name, units,
                        input_shape=model_cfg.input_shape if index == 0 else None)
                    features = stage(features)
                    # Width of the gate equals the stage's final filter count.
                    self.encoder_outputs.append(
                        SqueezeExcite(c=units[-1][0])(features))

            self.encoder = tf.keras.Model(
                self.encoder_inputs, self.encoder_outputs, name='encoder')

In [6]:
# Location of the experiment's JSON config file, handed to process_config.
# NOTE(review): this is an absolute, machine-specific path, so the cell fails
# on any other checkout — consider deriving it from a configurable base dir.
config_file = '/home/etheredge/Workspace/VisionEngine/VisionEngine/configs/butterfly_periodic_config.json'
config = process_config(config_file)

In [7]:
# Instantiate the encoder; its constructor builds the Keras model immediately.
encoder = Encoder(config)

In [8]:
# Print the layer table of the assembled encoder model.
encoder.encoder.summary()

Model: "encoder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input (InputLayer)              [(None, 256, 256, 3) 0                                            
__________________________________________________________________________________________________
h_1 (Sequential)                (None, 128, 128, 32) 9099        input[0][0]                      
__________________________________________________________________________________________________
h_2 (Sequential)                (None, 64, 64, 128)  141440      h_1[1][0]                        
__________________________________________________________________________________________________
h_3 (Sequential)                (None, 32, 32, 512)  2225664     h_2[0][0]                        
____________________________________________________________________________________________

In [115]:
class Decoder(BaseModel):
    """Four-input decoder that upsamples and merges latents coarse to fine.

    Every latent input (``z_4`` down to ``z_1``) is projected by a Dense
    layer, reshaped to a feature map and, except for ``z_4``, concatenated
    with the previous stage's output before entering that stage's
    ``Sequential`` of upsampling and spectrally normalised convolutions.

    Attributes set by ``make_decoder``: ``z_1_input`` .. ``z_4_input``,
    ``decoder_inputs``, ``decoder_outputs`` and ``decoder`` (the Keras model).
    """

    def __init__(self, config):
        super(Decoder, self).__init__(config)
        self.make_decoder()

    def make_decoder(self):
        """Assemble the decoder graph and store it on ``self.decoder``."""
        model_cfg = self.config.model

        def sn_conv(filters):
            # 3x3 'same' convolution wrapped in spectral normalisation.
            return SpectralNormalization(tf.keras.layers.Conv2D(
                filters,
                kernel_size=3,
                padding='same'))

        def stage_stack(widths):
            # UpSampling2D, then conv + BN per width with an activation
            # after every second convolution.
            stack = [tf.keras.layers.UpSampling2D()]
            for position, filters in enumerate(widths):
                stack.append(sn_conv(filters))
                stack.append(tf.keras.layers.BatchNormalization())
                if position % 2 == 1:
                    stack.append(tf.keras.layers.Activation(
                        model_cfg.decoder_activations))
            return stack

        with tf.name_scope('decoder'):

            self.z_1_input = tf.keras.layers.Input(
                (model_cfg.latent_size,), name='z_1')

            self.z_2_input = tf.keras.layers.Input(
                (model_cfg.latent_size,), name='z_2')

            self.z_3_input = tf.keras.layers.Input(
                (model_cfg.latent_size,), name='z_3')

            self.z_4_input = tf.keras.layers.Input(
                (model_cfg.latent_size,), name='z_4')

            # (stage name, latent input, reshaped feature-map shape, conv widths),
            # listed from the coarsest stage to the finest.
            stages = [
                ('z_tilde_4', self.z_4_input, (16, 16, 2048), [2048, 1024, 1024, 512]),
                ('z_tilde_3', self.z_3_input, (32, 32, 512), [512, 256, 256, 128]),
                ('z_tilde_2', self.z_2_input, (64, 64, 128), [128, 64, 64, 32]),
                ('z_tilde_1', self.z_1_input, (128, 128, 32), [32, 16, 16, 8]),
            ]
            final_stage = stages[-1][0]

            carried = None
            for stage_name, latent, grid, widths in stages:
                with tf.name_scope(stage_name):
                    rows, cols, channels = grid
                    seed = tf.keras.layers.Dense(
                        rows * cols * channels, activation=None)(latent)
                    seed = tf.keras.layers.Reshape(grid)(seed)

                    stack = stage_stack(widths)
                    if stage_name == final_stage:
                        # Project to the image's channel count and apply the
                        # output activation.
                        stack.append(SpectralNormalization(tf.keras.layers.Conv2D(
                            model_cfg.input_shape[2], 3, 1, padding='same')))
                        stack.append(tf.keras.layers.Activation(
                            model_cfg.last_activation))
                    stage = tf.keras.Sequential(stack, name=stage_name)

                    if carried is None:
                        stage_input = seed
                    else:
                        stage_input = tf.keras.layers.Concatenate()([carried, seed])
                    carried = stage(stage_input)

            self.decoder_outputs = carried
            self.decoder_inputs = [
                self.z_1_input, self.z_2_input, self.z_3_input, self.z_4_input
                ]

            self.decoder = tf.keras.Model(
                self.decoder_inputs,
                self.decoder_outputs,
                name='decoder')

In [116]:
# Instantiate the decoder; its constructor builds the Keras model immediately.
decoder = Decoder(config)

In [118]:
# Print the layer table of the assembled decoder model.
decoder.decoder.summary()

Model: "decoder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
z_4 (InputLayer)                [(None, 10)]         0                                            
__________________________________________________________________________________________________
dense_56 (Dense)                (None, 524288)       5767168     z_4[0][0]                        
__________________________________________________________________________________________________
z_3 (InputLayer)                [(None, 10)]         0                                            
__________________________________________________________________________________________________
reshape_53 (Reshape)            (None, 16, 16, 2048) 0           dense_56[0][0]                   
____________________________________________________________________________________________

## decoder.decoder.summary()

In [2]:
# NOTE(review): neither constant is referenced in the visible part of this
# notebook; presumably the latent width and the input image shape
# (height, width, channels) for the cells below — confirm.
z_dim = 512
input_dim = (256,256,3)

In [36]:
class ConvBlock(tf.keras.layers.Layer):
    """Two 3x3 'same' convolutions (``c // 2`` then ``c`` filters) with BN and swish.

    Layer order: Conv2D(c // 2) -> BN -> Conv2D(c) -> BN -> swish.

    Args:
        c: number of filters of the second convolution, i.e. output channels.
    """

    def __init__(self, c, **kwargs):
        super(ConvBlock, self).__init__(**kwargs)
        self.c = c

    def build(self, input_shape):
        self.cb = tf.keras.Sequential([
            tf.keras.layers.Conv2D(
                # Clamp so c == 1 does not request a zero-filter convolution.
                max(1, self.c // 2),
                kernel_size=3,
                padding='same'
            ),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Conv2D(
                self.c,
                kernel_size=3,
                padding='same'
            ),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation('swish')
        ])
        super(ConvBlock, self).build(input_shape)

    def call(self, layer_inputs, **kwargs):
        return self.cb(layer_inputs)

    def compute_output_shape(self, input_shape):
        # 'same' padding keeps every leading axis; the last convolution sets
        # the final axis to ``c`` (assumes channels-last data — TODO confirm).
        dims = tf.TensorShape(input_shape).as_list()
        return tf.TensorShape(dims[:-1] + [self.c])

    def get_config(self):
        """Return ``c`` merged into the base layer config."""
        config = {
            'c': self.c,
        }
        base_config = \
            super(ConvBlock, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


class EncoderCell(tf.keras.layers.Layer):
    """Sequential chain of ``ConvBlock`` layers, one per entry of ``cs``.

    Args:
        cs: iterable of channel counts; block ``i`` is ``ConvBlock(cs[i])``.
    """

    def __init__(self, cs, **kwargs):
        super(EncoderCell, self).__init__(**kwargs)
        self.cs = cs

    def build(self, input_shape):
        self.ec = tf.keras.Sequential()
        for cs_ in self.cs:
            self.ec.add(ConvBlock(cs_))
        super(EncoderCell, self).build(input_shape)

    def call(self, layer_inputs, **kwargs):
        return self.ec(layer_inputs)

    def compute_output_shape(self, input_shape):
        # The last block decides the channel count; with no blocks the input
        # passes through unchanged (assumes channels-last — TODO confirm).
        dims = tf.TensorShape(input_shape).as_list()
        if len(self.cs) > 0:
            dims[-1] = self.cs[-1]
        return tf.TensorShape(dims)

    def get_config(self):
        """Return ``cs`` merged into the base layer config."""
        config = {
            'cs': self.cs,
        }
        base_config = \
            super(EncoderCell, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


class UpsampleBlock(tf.keras.layers.Layer):
    """Stride-2 transposed 3x3 convolution followed by batch normalisation.

    Args:
        c: number of filters of the transposed convolution.
    """

    def __init__(self, c, **kwargs):
        super(UpsampleBlock, self).__init__(**kwargs)
        self.c = c

    def build(self, input_shape):
        # NOTE(review): padding is left at the Keras default ('valid'), so
        # each spatial axis grows to 2n + 1 rather than exactly 2n; pass
        # padding='same' if exact doubling is what is wanted.
        self.ub = tf.keras.Sequential([
            tf.keras.layers.Conv2DTranspose(
                self.c,
                kernel_size=3,
                strides=2  # the keyword is ``strides``; ``stride`` is rejected
            ),
            tf.keras.layers.BatchNormalization()
        ])
        super(UpsampleBlock, self).build(input_shape)

    def call(self, layer_inputs, **kwargs):
        return self.ub(layer_inputs)

    def compute_output_shape(self, input_shape):
        # Transposed conv, 'valid' padding: out = (n - 1) * stride + kernel
        # = 2n + 1 per spatial axis (assumes channels-last — TODO confirm).
        batch, rows, cols, _ = tf.TensorShape(input_shape).as_list()

        def grown(length):
            return None if length is None else 2 * length + 1

        return tf.TensorShape([batch, grown(rows), grown(cols), self.c])

    def get_config(self):
        """Return ``c`` merged into the base layer config."""
        config = {
            'c': self.c,
        }
        base_config = \
            super(UpsampleBlock, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


class DecoderCell(tf.keras.layers.Layer):
    """Sequential chain of ``UpsampleBlock`` layers, one per entry of ``cs``.

    Args:
        cs: iterable of channel counts; block ``i`` is ``UpsampleBlock(cs[i])``.
    """

    def __init__(self, cs, **kwargs):
        super(DecoderCell, self).__init__(**kwargs)
        self.cs = cs

    def build(self, input_shape):
        self.dc = tf.keras.Sequential()
        for cs_ in self.cs:
            # UpsampleBlock is a module-level class, not an attribute of self.
            self.dc.add(UpsampleBlock(cs_))
        super(DecoderCell, self).build(input_shape)

    def call(self, layer_inputs, **kwargs):
        # Run the chain created in build(); this layer has no ``se`` attribute.
        return self.dc(layer_inputs)

    def compute_output_shape(self, input_shape):
        # NOTE(review): every UpsampleBlock changes the spatial size and the
        # channel count, so returning the input shape looks inaccurate; left
        # untouched because the right value depends on UpsampleBlock's padding.
        return input_shape

    def get_config(self):
        """Return ``cs`` merged into the base layer config."""
        config = {
            'cs': self.cs
        }
        base_config = \
            super(DecoderCell, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


class EncoderResidualCell(tf.keras.layers.Layer):
    """BN -> conv -> BN -> swish -> conv -> SqueezeExcite, applied in sequence.

    NOTE(review): despite the name, ``call`` adds no skip connection, and the
    chain ends in ``SqueezeExcite`` — confirm the output shape is as intended.

    Args:
        c: number of filters for both convolutions and for ``SqueezeExcite``.
    """

    def __init__(self, c, **kwargs):
        super(EncoderResidualCell, self).__init__(**kwargs)
        self.c = c

    def build(self, input_shape):
        def conv3x3():
            # 3x3 convolution keeping the spatial size.
            return tf.keras.layers.Conv2D(
                self.c,
                kernel_size=3,
                padding='same')

        chain = tf.keras.Sequential()
        chain.add(tf.keras.layers.BatchNormalization())
        chain.add(conv3x3())
        chain.add(tf.keras.layers.BatchNormalization())
        chain.add(tf.keras.layers.Activation('swish'))
        chain.add(conv3x3())
        chain.add(SqueezeExcite(self.c))
        self.erc = chain

    def call(self, layer_inputs, **kwargs):
        outputs = self.erc(layer_inputs)
        return outputs

    def compute_output_shape(self, input_shape):
        # Reported shape is the input shape, unchanged.
        return input_shape

    def get_config(self):
        """Return ``c`` merged into the base layer config."""
        merged = dict(super(EncoderResidualCell, self).get_config())
        merged.update({'c': self.c})
        return merged


class DecoderResidualCell(tf.keras.layers.Layer):
    """Expand (1x1 conv) -> depthwise 5x5 -> project (1x1 conv) -> SqueezeExcite.

    Args:
        c: number of output channels of the cell.
        e: expansion factor; the first 1x1 convolution produces `c * e`
            channels, which the depthwise convolution then filters.
        **kwargs: forwarded to `tf.keras.layers.Layer`.
    """

    def __init__(self, c, e, **kwargs):
        super(DecoderResidualCell, self).__init__(**kwargs)
        self.c = c
        self.e = e

    def build(self, input_shape):
        expanded = self.c * self.e
        self.drc = tf.keras.Sequential([
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Conv2D(
                expanded,
                kernel_size=1,
                padding='same'),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation('swish'),
            # DepthwiseConv2D takes no filter count: its first positional
            # argument is kernel_size and the stride keyword is `strides`.
            # The channel count (c * e) is inherited from the previous conv.
            # padding='same' keeps the spatial size, consistent with
            # compute_output_shape below.
            tf.keras.layers.DepthwiseConv2D(
                kernel_size=5,
                strides=1,
                padding='same'),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation('swish'),
            tf.keras.layers.Conv2D(
                self.c,
                kernel_size=1,
                padding='same',
                use_bias=False,
                activation=None),
            tf.keras.layers.BatchNormalization(),
            SqueezeExcite(self.c)])

    # The three methods below must live at class level; nested inside
    # `build` they would only be local functions and never be invoked.
    def call(self, layer_inputs, **kwargs):
        """Apply the sequential stack built in `build` to the inputs."""
        return self.drc(layer_inputs)

    def compute_output_shape(self, input_shape):
        return input_shape

    def get_config(self):
        """Return the base layer config plus the `c` and `e` arguments."""
        config = {
            'c': self.c,
            'e': self.e
        }
        base_config = \
            super(DecoderResidualCell, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


# NOTE: a class named SqueezeExcite is also defined near the top of this
# notebook; once this cell runs, this later definition is the one in effect.
class SqueezeExcite(tf.keras.layers.Layer):
    """Squeeze-and-excitation gate that rescales the channels of its input.

    A per-channel gate in (0, 1) is computed from the globally averaged
    input and multiplied back onto the input, so the output has the same
    shape as the input.

    Args:
        c: number of gate channels; assumed to equal the channel count of
            the input (channels-last layout) — TODO confirm with callers.
        r: reduction ratio of the bottleneck Dense layer.
        **kwargs: forwarded to `tf.keras.layers.Layer`.
    """

    def __init__(self, c, r=16, **kwargs):
        super(SqueezeExcite, self).__init__(**kwargs)
        self.c = c
        self.r = r

    def build(self, input_shape):
        # Guard against a zero-width bottleneck when c < r.
        bottleneck = max(1, self.c // self.r)
        self.se = tf.keras.Sequential([
            tf.keras.layers.GlobalAveragePooling2D(),
            tf.keras.layers.Dense(bottleneck, use_bias=False),
            tf.keras.layers.Activation('relu'),
            tf.keras.layers.Dense(self.c, use_bias=False),
            tf.keras.layers.Activation('sigmoid'),
            # (batch, c) -> (batch, 1, 1, c) so the gate broadcasts over
            # the spatial axes of a channels-last feature map.
            tf.keras.layers.Reshape((1, 1, self.c))])

    def call(self, layer_inputs, **kwargs):
        """Return the inputs scaled channel-wise by the learned gate."""
        return layer_inputs * self.se(layer_inputs)

    def compute_output_shape(self, input_shape):
        return input_shape

    def get_config(self):
        """Return the base layer config plus the `c` and `r` arguments."""
        config = {
            'c': self.c,
            'r': self.r
        }
        base_config = \
            super(SqueezeExcite, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


class NouveauVAE(tf.keras.Model):
    """Convolutional VAE assembled from the encoder/decoder cells above.

    The encoder maps an input batch to a vector of size `2 * z_dim`
    (mean and log-variance); the decoder maps a latent vector of size
    `z_dim` to 3-channel logits.

    Args:
        z_dim: size of the latent vector; also drives the channel widths
            of the encoder and decoder cells.
        input_dim: shape of one input example (without the batch axis).
    """

    def __init__(self, z_dim, input_dim):
        super(NouveauVAE, self).__init__()
        self.z_dim = z_dim
        self.input_dim = input_dim

        # ---- encoder ----
        self.enc_inputs = tf.keras.layers.Input(
            shape=self.input_dim, name='input')

        encoder_stack = [
            EncoderCell([self.z_dim // 16, self.z_dim // 8]),
            EncoderCell([self.z_dim // 4, self.z_dim // 2]),
            EncoderCell([self.z_dim])
        ]
        encoder_res_stack = [
            EncoderResidualCell(self.z_dim // 8),
            EncoderResidualCell(self.z_dim // 2),
            EncoderResidualCell(self.z_dim)
        ]

        self.enc = tf.keras.Sequential(name='encoder')
        self.enc.add(self.enc_inputs)
        # Each encoder cell is followed by its residual cell: r(e(x)).
        for e, r in zip(encoder_stack, encoder_res_stack):
            self.enc.add(e)
            self.enc.add(r)

        # Collapse the feature map to a vector and project it to the
        # concatenated (mean, log-variance) parameters. After global
        # pooling the tensor is 2-D, so the projection is a Dense layer.
        self.condition_x = tf.keras.Sequential([
            tf.keras.layers.GlobalAveragePooling2D(),
            tf.keras.layers.Activation('swish'),
            tf.keras.layers.Dense(self.z_dim * 2)
        ])
        self.enc.add(self.condition_x)

        # ---- decoder ----
        self.dec = tf.keras.Sequential(name='decoder')
        # The decoder consumes a sampled latent of width z_dim (see
        # `sample` and `reparameterize`), not the 2 * z_dim encoder output.
        self.dec.add(tf.keras.layers.Input(shape=(self.z_dim,)))
        # Turn the latent vector into a 1x1 feature map for the conv cells.
        # NOTE(review): the final spatial size depends on UpsampleBlock;
        # confirm it matches `input_dim` as the reconstruction loss needs.
        self.dec.add(tf.keras.layers.Reshape((1, 1, self.z_dim)))

        decoder_stack = [
            DecoderCell([self.z_dim // 2]),
            DecoderCell([self.z_dim // 4, self.z_dim // 8]),
            DecoderCell([self.z_dim // 16, self.z_dim // 32])
        ]
        decoder_res_stack = [
            DecoderResidualCell(self.z_dim // 2, e=1),
            DecoderResidualCell(self.z_dim // 8, e=2),
            DecoderResidualCell(self.z_dim // 32, e=4),
        ]

        for d, r in zip(decoder_stack, decoder_res_stack):
            self.dec.add(d)
            self.dec.add(r)

        # Final 1x1 convolution producing 3-channel logits.
        self.x_hat = tf.keras.layers.Conv2D(3, kernel_size=1)
        self.dec.add(self.x_hat)

    def sample(self, eps=None):
        """Decode latent vectors into sigmoid-activated outputs.

        Args:
            eps: latent batch of shape (n, z_dim); when None, 100 vectors
                are drawn from a standard normal.
        """
        if eps is None:
            eps = tf.random.normal(shape=(100, self.z_dim))
        return self.decode(eps, apply_sigmoid=True)

    def encode(self, x):
        """Return `(mu, logvar)`, the two halves of the encoder output."""
        mu, logvar = tf.split(self.enc(x), num_or_size_splits=2, axis=1)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample z = mu + eps * exp(logvar / 2) with eps ~ N(0, I)."""
        # tf.shape works even when the static batch size is unknown.
        eps = tf.random.normal(shape=tf.shape(mu))
        return eps * tf.exp(logvar * .5) + mu

    def decode(self, z, apply_sigmoid=False):
        """Run the decoder on `z`; optionally apply a sigmoid to the logits."""
        logits = self.dec(z)
        if apply_sigmoid:
            probs = tf.sigmoid(logits)
            return probs
        return logits

In [37]:
# Adam optimizer with a learning rate of 1e-4.
optimizer = tf.keras.optimizers.Adam(1e-4)


def log_normal_pdf(sample, mean, logvar, raxis=1):
  """Log-density of `sample` under a normal with the given mean and log-variance.

  The element-wise log-densities are summed over axis `raxis`.
  """
  log2pi = tf.math.log(2. * np.pi)
  squared_error = (sample - mean) ** 2.
  per_element = -.5 * (squared_error * tf.exp(-logvar) + logvar + log2pi)
  return tf.reduce_sum(per_element, axis=raxis)


def compute_loss(model, x):
  """Return the negative single-sample ELBO estimate averaged over the batch.

  The reconstruction term is the sigmoid cross-entropy between the decoder
  logits and `x`, summed over axes 1, 2 and 3.
  """
  mu, log_var = model.encode(x)
  latent = model.reparameterize(mu, log_var)
  logits = model.decode(latent)
  bce = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=x)
  log_px_given_z = -tf.reduce_sum(bce, axis=[1, 2, 3])
  # Prior is a standard normal: mean 0 and log-variance 0.
  log_prior = log_normal_pdf(latent, 0., 0.)
  log_posterior = log_normal_pdf(latent, mu, log_var)
  elbo = log_px_given_z + log_prior - log_posterior
  return -tf.reduce_mean(elbo)


@tf.function
def train_step(model, x, optimizer):
  """Executes one training step and returns the loss.

  This function computes the loss and gradients, and uses the latter to
  update the model's parameters.

  Args:
    model: object exposing `encode`, `reparameterize`, `decode` and
      `trainable_variables`.
    x: batch of inputs passed to `compute_loss`.
    optimizer: optimizer whose `apply_gradients` performs the update.

  Returns:
    The scalar loss computed before the parameter update.
  """
  with tf.GradientTape() as tape:
    loss = compute_loss(model, x)
  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))
  # The docstring promises the loss; the original fell through returning None.
  return loss

In [38]:
epochs = 10
# number of latent vectors drawn below for generation
num_examples_to_generate = 16

# keeping the random vector constant for generation (prediction) so
# it will be easier to see the improvement.
# NOTE(review): z_dim and input_dim are not assigned in this cell —
# presumably defined in an earlier cell; confirm on a fresh kernel.
random_vector_for_generation = tf.random.normal(
    shape=[num_examples_to_generate, z_dim])
model = NouveauVAE(z_dim, input_dim)

ValueError: ('Input has undefined rank:', TensorShape(None))