In [1]:
import tensorflow as tf
import numpy as np
import random
import time
import os

from tensorflow.keras import backend as K


class Tadam():
    """
    Adam-style optimizer with an adaptive trust-region factor, applied by hand.

    All trainable variables of ``model`` are handled as one flattened vector.
    The optimizer keeps exponential moving averages of the gradient (``m``),
    of the squared gradient (``s``), of the squared deviation of the gradient
    from the bias-corrected mean (``v``) and of the loss (``ls``).  Each call
    to :meth:`apply_grad` first rescales the trust-region factor ``dt`` by
    comparing the observed loss change with the reduction predicted on the
    previous call, then updates the model variables in place.
    """

    def __init__(self, model, total_steps, lr=1e-3, beta1=0.9, beta2=0.999, gamma=0.25, eps=1e-8, name='TADAM',**kwargs):
        """
        Create the optimizer state for ``model``.

        :param model: object exposing ``trainable_variables``; these variables
            are updated in place by :meth:`apply_grad`
        :param total_steps: (int) divisor that normalises the step counter in
            the schedule of the lower/upper bounds of ``dt``
        :param lr: (float) step size applied to the scaled update
        :param beta1: (float) decay rate of the gradient and loss averages
        :param beta2: (float) decay rate of the squared-gradient and variance averages
        :param gamma: (float) threshold compared against ``rho`` and base of
            the ``dt`` bound schedule
        :param eps: (float) small constant added to denominators
        :param name: accepted for signature compatibility; not used
        :param kwargs: accepted for signature compatibility; not used
        """
        self.model = model
        self.total_steps = total_steps
        self.lr = lr
        self.bt1 = beta1
        self.bt2 = beta2
        self.gamma = gamma
        self.eps = eps

        # Every state vector has the shape of the flattened trainable variables.
        self.shape = self.flatgrad(self.model.trainable_variables).get_shape()
        self.zero_init = tf.constant(0.0, shape=self.shape)

        with K.name_scope(self.__class__.__name__):
            self.m_h = K.variable(self.zero_init, name='m_h')    # bias-corrected first moment
            self.m = K.variable(self.zero_init, name='m')        # first moment (EMA of gradient)
            self.s = K.variable(self.zero_init, name='s')        # second moment (EMA of squared gradient)
            self.v = K.variable(self.zero_init + eps, name='v')  # EMA of squared deviation from m_h
            self.bp1 = K.variable(beta1, name='bp1')             # running power beta1 ** t
            self.bp2 = K.variable(beta2, name='bp2')             # running power beta2 ** t
            self.ls_h = K.variable(0.0, name='ls_h')             # bias-corrected loss average
            self.ls = K.variable(0.0, name='ls')                 # EMA of the loss
            self.pr = K.variable(0.0, name='pr')                 # predicted reduction of the last step
            self.dt = K.variable(1.0, name='dt')                 # trust-region factor
            self.t = K.variable(1.0, name='t')                   # step counter, starts at 1

    def apply_grad(self, loss, gradient):
        """
        Perform one optimization step and update all internal statistics.

        :param loss: (TensorFlow Tensor) scalar loss of the current batch
        :param gradient: ([TensorFlow Tensor or None]) gradients aligned
            one-to-one with ``model.trainable_variables``; ``None`` entries
            are treated as zero gradients
        :return: None
        :raises ValueError: if the number of gradients differs from the number
            of trainable variables
        """
        variables = self.model.trainable_variables
        if len(gradient) != len(variables):
            raise ValueError(
                "expected one gradient per trainable variable, got {} gradients for {} variables"
                .format(len(gradient), len(variables)))

        # tape.gradient returns None for variables the loss does not depend on;
        # substitute zeros so the flattened vector keeps its expected length.
        gradient = [tf.zeros_like(var) if grad is None else grad
                    for grad, var in zip(gradient, variables)]

        # gradient
        g_flat = self.flatgrad(gradient)

        # delta: kept at 1.0 on the very first step (t == 1), afterwards adapted
        # from the loss average and predicted reduction stored by the previous call
        d1 = lambda: 1.0
        d2 = lambda: self.compute_delta(loss, self.dt)
        self.dt.assign(tf.case([(tf.less(self.t, 1.1), d1)], default=d2, exclusive=True))

        # bias correction (bp1/bp2 still hold the powers for the current step)
        bc1 = 1.0 - self.bp1
        bc2 = 1.0 - self.bp2

        # moving variance (uses m_h from the previous step, so it must run
        # before the first-moment update below)
        dv = tf.square(g_flat - self.m_h) * (self.bt2 - self.bp2) / bc2
        self.v.assign(self.bt2 * self.v + (1.0 - self.bt2) * dv)
        v_h = self.v / bc2

        # first moment
        self.m.assign(self.bt1 * self.m + (1.0 - self.bt1) * g_flat)
        self.m_h.assign(self.m / bc1)

        # second moment
        self.s.assign(self.bt2 * self.s + (1.0 - self.bt2) * tf.square(g_flat))
        s_h = self.s / bc2

        # fisher vector
        f_h = (1.0 + tf.reduce_sum(tf.square(self.m_h) / (v_h + self.eps))) * v_h

        # apply trust region
        u_h = tf.maximum(self.dt * f_h, tf.sqrt(s_h))
        g_h = self.m_h * self.dt / (u_h + self.eps)

        # update the model variables in place
        g_update = self.flat_to_grad_list(g_h, variables)
        for grad, var in zip(g_update, variables):
            var.assign(var - self.lr * grad)

        # moving avg of loss
        self.ls.assign(self.bt1 * self.ls + (1.0 - self.bt1) * loss)
        self.ls_h.assign(self.ls / bc1)

        # predict reduction (read by compute_rho on the next call)
        pr1 = tf.reduce_sum(self.m_h * g_h)
        pr2 = tf.square(pr1) + tf.reduce_sum(v_h * tf.square(g_h))
        self.pr.assign((pr1 - 0.5 * pr2) * self.lr)

        # beta update
        self.bp1.assign(self.bp1*self.bt1)
        self.bp2.assign(self.bp2*self.bt2)
        self.t.assign_add(1.0)

        return

    def compute_delta(self, loss, dt):
        """
        Compute the next trust-region factor.

        The factor is multiplied by ``dt_min`` when ``rho < gamma``, by
        ``dt_max`` when ``rho > 1 - gamma`` and left unchanged otherwise, then
        clipped to ``[dt_min, dt_max]``.

        :param loss: (TensorFlow Tensor) scalar loss of the current batch
        :param dt: current trust-region factor
        :return: (TensorFlow Tensor) the new trust-region factor
        """
        rho = self.compute_rho(loss)
        dt_min = tf.pow(1.0 - self.gamma, (self.t - 1.0) / self.total_steps)
        dt_max = 1.0 + tf.pow(self.gamma, (self.t - 1.0) / self.total_steps)
        r1 = lambda: dt_min
        r2 = lambda: dt_max
        r3 = lambda: 1.0
        # First-match semantics: for gamma > 0.5 both predicates can hold at
        # once, which an exclusive case would reject with an error.
        r = tf.case([(tf.less(rho, self.gamma), r1), (tf.greater(rho, 1.0 - self.gamma), r2)], default=r3, exclusive=False)
        dt = tf.minimum(tf.maximum(r * dt, dt_min), dt_max)
        return dt

    def compute_rho(self, loss):
        """
        Ratio of the observed loss decrease to the stored predicted reduction.

        :param loss: (TensorFlow Tensor) scalar loss of the current batch
        :return: (TensorFlow Tensor) ``(ls_h - loss) / max(pr, eps)``
        """
        return (self.ls_h - loss) / tf.maximum(self.pr, self.eps)

    def flatgrad(self, grad_list):
        """
        flattens gradients.
        :param grad_list: ([TensorFlow Tensor]) the gradients
        :return: ([TensorFlow Tensor]) flattened gradient
        """
        return tf.concat(axis=0, values=[tf.reshape(grad, [-1]) for grad in grad_list])

    def flat_to_grad_list(self, flat, var_list):
        """
        converts a flat vector back to tensors shaped like the entries of var_list.
        :param flat: (TensorFlow Tensor) flat vector whose length is the total
            number of elements of var_list
        :param var_list: ([TensorFlow Tensor]) tensors providing the target shapes
        :return: ([TensorFlow Tensor]) one tensor per entry of var_list
        """
        splits = tf.split(flat, [self.numel(w) for w in var_list])
        return [tf.reshape(t, self.var_shape(w)) for t, w in zip(splits, var_list)]

    def var_shape(self, tensor):
        """
        get TensorFlow Tensor shape
        :param tensor: (TensorFlow Tensor) the input tensor
        :return: ([int]) the shape
        """
        out = tensor.get_shape().as_list()
        assert all(isinstance(a, int) for a in out), \
            "shape function assumes that shape is fully known"
        return out

    def numel(self, tensor):
        """
        get TensorFlow Tensor's number of elements
        :param tensor: (TensorFlow Tensor) the input tensor
        :return: (int) the number of elements
        """
        return int(np.prod(self.var_shape(tensor)))
    

In [2]:
# model
class CAE(tf.keras.Model):
    """
    Convolutional autoencoder assembled with the Keras functional API.

    Three stride-2 convolutions (each followed by swish) feed a dense code
    layer.  A dense layer, a reshape and three stride-2 transposed
    convolutions decode the code; the last one has ``in_shape[2]`` filters and
    no activation.  The assembled network is exposed as ``self.model``; the
    intermediate symbolic tensors are kept as attributes.
    """

    def __init__(self, in_shape, filters, code_dim):
        """
        :param in_shape: input shape passed to the Input layer; index 2 is
            used as the number of output filters
        :param filters: sequence whose first three entries are the filter
            counts of the encoder convolutions
        :param code_dim: (int) number of units of the code layer
        """
        super().__init__()
        self.in_shape = in_shape
        self.filters = filters
        self.code_dim = code_dim

        layers = tf.keras.layers
        swish = tf.keras.activations.swish
        # Settings shared by every (transposed) convolution of the network.
        conv_kwargs = dict(kernel_size=3, strides=2, padding='same')

        def down(tensor, n_filters):
            # Strided convolution followed by swish.
            return swish(layers.Conv2D(filters=n_filters, **conv_kwargs)(tensor))

        def up(tensor, n_filters):
            # Strided transposed convolution followed by swish.
            return swish(layers.Conv2DTranspose(filters=n_filters, **conv_kwargs)(tensor))

        # input
        self.inputs = layers.Input(shape=self.in_shape)

        # encoder
        self.enc01 = down(self.inputs, self.filters[0])
        self.enc02 = down(self.enc01, self.filters[1])
        self.enc03 = down(self.enc02, self.filters[2])
        self.enc03_flat = layers.Flatten()(self.enc03)

        # code
        self.code = layers.Dense(self.code_dim)(self.enc03_flat)

        # decoder: widen the code to the flattened encoder size, then restore
        # the spatial layout of the last encoder feature map
        widened = swish(layers.Dense(self.enc03_flat.shape[-1])(self.code))
        self.dec03 = layers.Reshape(target_shape=self.enc03.shape[1:])(widened)
        self.dec02 = up(self.dec03, self.filters[1])
        self.dec01 = up(self.dec02, self.filters[0])

        # output
        self.outputs = layers.Conv2DTranspose(
            filters=self.in_shape[2], **conv_kwargs)(self.dec01)

        self.model = tf.keras.Model(inputs=self.inputs, outputs=self.outputs, name='ConvAE')
        

In [3]:
# load data
def load_data(batch_size, epochs):
    """
    Load MNIST images and build shuffled, batched train/test datasets.

    Labels are discarded; images are passed through ``preprocess_images``.

    :param batch_size: (int) number of images per batch
    :param epochs: (int) number of epochs the caller intends to train for
    :return: (train_ds, test_ds, total_steps) where ``total_steps`` is the
        number of training batches produced over ``epochs`` epochs
    """
    (train_images, _), (test_images, _) = tf.keras.datasets.mnist.load_data()

    train_size = train_images.shape[0]
    test_size = test_images.shape[0]

    train_images = preprocess_images(train_images)
    test_images = preprocess_images(test_images)

    train_ds = (tf.data.Dataset.from_tensor_slices(train_images)
                     .shuffle(train_size).batch(batch_size))
    test_ds = (tf.data.Dataset.from_tensor_slices(test_images)
                    .shuffle(test_size).batch(batch_size))

    # batch() keeps the final partial batch, so an epoch has
    # ceil(train_size / batch_size) steps; flooring would under-count.
    steps_per_epoch = -(-train_size // batch_size)
    total_steps = int(steps_per_epoch * epochs)

    return train_ds, test_ds, total_steps


def preprocess_images(img):
    """
    Scale images to [0, 1] floats with a channel axis and resize to 32x32.

    :param img: array reshapeable to (len(img), 28, 28, 1); values are divided by 255
    :return: (TensorFlow Tensor) result of resizing the float32 images to 32x32
    """
    n_images = len(img)
    scaled = img.reshape((n_images, 28, 28, 1)) / 255.
    return tf.image.resize(scaled.astype('float32'), [32,32])


In [4]:
# loss computation
def compute_mse(train_model, x, training=True):
    """
    Mean squared error between ``x`` and the model output for ``x``.

    :param train_model: callable model, invoked as ``train_model(x, training=...)``
    :param x: batch of inputs, also used as the target
    :param training: (bool) forwarded to the model call
    :return: (TensorFlow Tensor) the loss
    """
    reconstruction = train_model(x, training=training)
    return tf.keras.losses.MeanSquaredError()(x, reconstruction)

# optimization
@tf.function
def train_step(model, opt, x):
    """
    Run a single optimization step on one batch.

    The loss is recorded under a gradient tape, differentiated with respect
    to the model's trainable variables and handed to ``opt.apply_grad``.

    :param model: model being trained
    :param opt: optimizer exposing ``apply_grad(loss, gradient)`` and ``dt``
    :param x: batch of inputs
    :return: (loss, delta) the batch loss and the optimizer's ``dt`` after the step
    """
    trainable = model.trainable_variables
    with tf.GradientTape() as tape:
        loss = compute_mse(model, x, True)

    opt.apply_grad(loss, tape.gradient(loss, trainable))
    return loss, opt.dt

In [5]:
def training(in_shape, filters, code_dim, batch_size, epochs=1, 
             init_lr=1e-3, beta1=0.9, beta2=0.999, gamma=0.25, seed=0):
    """
    Train the convolutional autoencoder on MNIST with the Tadam optimizer.

    Prints the run configuration, the model summary and, after every epoch,
    the mean training loss, the mean test loss, the last trust-region factor
    and the wall-clock time of the training pass.

    :param in_shape: input shape of the autoencoder
    :param filters: filter counts of the autoencoder
    :param code_dim: (int) size of the code layer
    :param batch_size: (int) batch size for both datasets
    :param epochs: (int) number of passes over the training set
    :param init_lr: (float) learning rate handed to the optimizer
    :param beta1: (float) optimizer beta1
    :param beta2: (float) optimizer beta2
    :param gamma: (float) optimizer gamma
    :param seed: (int) seed for TensorFlow, numpy and Python's random module
    :return: None
    """
    run_header = 'LOSS: {}, DATA: {},  BATCH: {}, EPOCH: {}, SEED: {}'
    hyper_header = 'INIT_LR: {:.0e}, BETA1: {:.4f}, BETA2: {:.4f}, GAMMA: {:.4f}'
    epoch_report = 'Epoch: {}, loss_t: {:.4e}, loss_v: {:.4e}, dt: {:.2f}, time: {:.2f}'

    print(run_header.format('MSE', 'MNIST', batch_size, epochs, seed))
    print(hyper_header.format(init_lr, beta1, beta2, gamma))

    # Seed every random number generator in use.
    tf.random.set_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # Build the network, the data pipelines and the optimizer.
    model = CAE(in_shape, filters, code_dim).model
    model.summary()
    train_dataset, test_dataset, total_steps = load_data(batch_size, epochs)
    opt = Tadam(model, total_steps, init_lr, beta1, beta2, gamma)

    for epoch in range(1, epochs + 1):
        # Training pass; only this part is timed.
        train_metric = tf.keras.metrics.Mean()
        started = time.time()
        for train_x in train_dataset:
            batch_loss, delta = train_step(model, opt, train_x)
            train_metric.update_state(batch_loss)
        elapsed = time.time() - started
        tloss = train_metric.result()

        # Evaluation pass over the test set, without updating the model.
        test_metric = tf.keras.metrics.Mean()
        for test_x in test_dataset:
            test_metric.update_state(compute_mse(model, test_x, False))
        vloss = test_metric.result()

        print(epoch_report.format(epoch, tloss, vloss, delta, elapsed))

In [6]:
# Order CUDA devices by PCI bus id and expose only device 0 to this process.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})

# Autoencoder configuration.
in_shape = (32, 32, 1)
filters = [16, 32, 64]
code_dim = 16

# Launch the training run.
training(
    in_shape, filters, code_dim,
    batch_size=128, epochs=100,
    init_lr=1e-3, beta1=0.9, beta2=0.999, gamma=0.25,
)


LOSS: MSE,  DATA: MNIST,  BATCH: 128, EPOCH: 100, SEED: 0
INIT_LR: 1e-03, BETA1: 0.9000, BETA2: 0.9990, GAMMA: 0.2500
Model: "ConvAE"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 32, 32, 1)]       0         
                                                                 
 conv2d (Conv2D)             (None, 16, 16, 16)        160       
                                                                 
 tf.nn.silu (TFOpLambda)     (None, 16, 16, 16)        0         
                                                                 
 conv2d_1 (Conv2D)           (None, 8, 8, 32)          4640      
                                                                 
 tf.nn.silu_1 (TFOpLambda)   (None, 8, 8, 32)          0         
                                                                 
 conv2d_2 (Conv2D)           (None, 4, 4, 64)          18496     
        

Epoch: 1, loss_t: 3.0960e-02, loss_v: 8.5160e-03, dt: 1.98, time: 7.87
Epoch: 2, loss_t: 7.5089e-03, loss_v: 6.6965e-03, dt: 1.97, time: 4.10
Epoch: 3, loss_t: 6.3080e-03, loss_v: 5.8726e-03, dt: 1.93, time: 4.15
Epoch: 4, loss_t: 5.7805e-03, loss_v: 5.6573e-03, dt: 1.88, time: 4.04
Epoch: 5, loss_t: 5.4316e-03, loss_v: 5.2908e-03, dt: 1.93, time: 4.11
Epoch: 6, loss_t: 5.1926e-03, loss_v: 5.0891e-03, dt: 1.92, time: 4.11
Epoch: 7, loss_t: 5.0272e-03, loss_v: 4.9590e-03, dt: 1.87, time: 4.14
Epoch: 8, loss_t: 4.8847e-03, loss_v: 4.8127e-03, dt: 1.77, time: 4.09
Epoch: 9, loss_t: 4.7653e-03, loss_v: 4.7324e-03, dt: 1.83, time: 4.07
Epoch: 10, loss_t: 4.6741e-03, loss_v: 4.6151e-03, dt: 1.77, time: 4.14
Epoch: 11, loss_t: 4.5990e-03, loss_v: 4.5685e-03, dt: 1.74, time: 4.16
Epoch: 12, loss_t: 4.5281e-03, loss_v: 4.5227e-03, dt: 1.78, time: 4.04
Epoch: 13, loss_t: 4.4550e-03, loss_v: 4.4895e-03, dt: 1.83, time: 4.09
Epoch: 14, loss_t: 4.4084e-03, loss_v: 4.4145e-03, dt: 1.75, time: 4.05
E