In [1]:
import tensorflow as tf
import numpy as np
import random
import time
import os


# model
class CAE(tf.keras.Model):
    """
    Convolutional autoencoder assembled with the Keras functional API.

    Although the class subclasses ``tf.keras.Model``, the network that is
    trained is the functional model stored in ``self.model``; the symbolic
    tensors of every stage are kept as attributes (``enc01`` ... ``dec01``).

    Encoder: three 3x3, stride-2, 'same'-padded convolutions, each followed by
    swish, then a flatten and a linear dense bottleneck of width ``code_dim``.
    Decoder: a dense layer back to the flattened encoder width (swish), a
    reshape to the last encoder feature map, two stride-2 transposed
    convolutions (swish) and a final stride-2 transposed convolution with no
    activation and ``in_shape[2]`` output channels.

    :param in_shape: input shape given to ``tf.keras.layers.Input``; index 2
        is used as the number of output channels
    :param filters: channel counts; ``filters[0]``..``filters[2]`` are used
    :param code_dim: width of the bottleneck code
    """

    def __init__(self, in_shape, filters, code_dim):
        super().__init__()
        self.in_shape = in_shape
        self.filters = filters
        self.code_dim = code_dim

        conv_args = dict(kernel_size=3, strides=2, padding='same')
        swish = tf.keras.activations.swish

        def down(tensor, n_filters):
            # strided conv halves height/width (rounding up), then swish
            layer = tf.keras.layers.Conv2D(filters=n_filters, **conv_args)
            return swish(layer(tensor))

        def up(tensor, n_filters):
            # strided transposed conv doubles height/width, then swish
            layer = tf.keras.layers.Conv2DTranspose(filters=n_filters, **conv_args)
            return swish(layer(tensor))

        # input
        self.inputs = tf.keras.layers.Input(shape=self.in_shape)

        # encoder
        self.enc01 = down(self.inputs, self.filters[0])
        self.enc02 = down(self.enc01, self.filters[1])
        self.enc03 = down(self.enc02, self.filters[2])
        self.enc03_flat = tf.keras.layers.Flatten()(self.enc03)

        # code: linear bottleneck
        self.code = tf.keras.layers.Dense(self.code_dim)(self.enc03_flat)

        # decoder: undo the flatten, then upsample back to the input size
        flat_units = self.enc03_flat.shape[-1]
        hidden = swish(tf.keras.layers.Dense(flat_units)(self.code))
        self.dec03 = tf.keras.layers.Reshape(target_shape=self.enc03.shape[1:])(hidden)
        self.dec02 = up(self.dec03, self.filters[1])
        self.dec01 = up(self.dec02, self.filters[0])

        # output: transposed conv without activation, one map per input channel
        self.outputs = tf.keras.layers.Conv2DTranspose(
            filters=self.in_shape[2], **conv_args)(self.dec01)

        self.model = tf.keras.Model(inputs=self.inputs, outputs=self.outputs)
        

In [2]:
# load data
def load_data(data_name, batch_size):
    """
    Load a Keras image dataset and wrap it in shuffled, batched tf.data pipelines.

    Labels are discarded; images are converted with ``preprocess_images``.

    :param data_name: one of 'MNIST', 'Fashion-MNIST' or 'CIFAR-10'
    :param batch_size: number of images per batch
    :return: (train_ds, test_ds, train_size) -- the training and test
        datasets (shuffled, batched, prefetched) and the number of training images
    :raises ValueError: if ``data_name`` is not one of the supported datasets
    """
    # public dataset name -> tf.keras.datasets module name
    dataset_modules = {
        'MNIST': 'mnist',
        'Fashion-MNIST': 'fashion_mnist',
        'CIFAR-10': 'cifar10',
    }
    if data_name not in dataset_modules:
        # without this check an unknown name fell through and crashed later
        # with an UnboundLocalError on the unassigned dataset variable
        raise ValueError('unknown data_name {!r}; expected one of {}'
                         .format(data_name, sorted(dataset_modules)))

    source = getattr(tf.keras.datasets, dataset_modules[data_name])
    (train_images, _), (test_images, _) = source.load_data()

    AUTOTUNE = tf.data.AUTOTUNE

    train_size = train_images.shape[0]
    test_size = test_images.shape[0]

    train_images = preprocess_images(data_name, train_images)
    test_images = preprocess_images(data_name, test_images)

    # full-size shuffle buffers; prefetch overlaps input preparation with training
    train_ds = (tf.data.Dataset.from_tensor_slices(train_images)
                .shuffle(train_size).batch(batch_size).prefetch(AUTOTUNE))
    test_ds = (tf.data.Dataset.from_tensor_slices(test_images)
               .shuffle(test_size).batch(batch_size).prefetch(AUTOTUNE))

    return train_ds, test_ds, train_size


def preprocess_images(data_name, img):
    """
    Convert raw dataset images to single-channel float32 images in [0, 1].

    Assumes ``img`` is the uint8 array returned by ``tf.keras.datasets``
    (28x28 for MNIST / Fashion-MNIST, 32x32 RGB for CIFAR-10) -- TODO confirm
    for other callers.

    :param data_name: one of 'MNIST', 'Fashion-MNIST' or 'CIFAR-10'
    :param img: batch of raw images
    :return: float32 tensor of shape (N, 32, 32, 1)
    :raises ValueError: if ``data_name`` is not one of the supported datasets
    """
    if data_name in ('MNIST', 'Fashion-MNIST'):
        # add a channel axis, scale to [0, 1], then resize 28x28 -> 32x32
        img = (img.reshape((img.shape[0], 28, 28, 1)) / 255.).astype('float32')
        return tf.image.resize(img, [32, 32])
    if data_name == 'CIFAR-10':
        # collapse RGB to one channel, then rescale uint8 -> float32 in [0, 1]
        img = tf.image.rgb_to_grayscale(img)
        return tf.image.convert_image_dtype(img, tf.float32)
    # previously unknown names silently returned the unprocessed images
    raise ValueError('unknown data_name {!r}; expected one of '
                     "['CIFAR-10', 'Fashion-MNIST', 'MNIST']".format(data_name))


In [3]:
# gradient flatten
def flatgrad(grad_list):
    """
    Concatenate a list of gradient tensors into a single 1-D tensor.

    Every tensor is flattened with ``tf.reshape(..., [-1])`` and the pieces
    are joined in list order, so the result can be split back with the same
    variable ordering.

    :param grad_list: ([TensorFlow Tensor]) the gradients, e.g. from
        ``tape.gradient`` over a list of variables
    :return: (TensorFlow Tensor) 1-D tensor holding all gradient entries
    """
    pieces = []
    for grad in grad_list:
        pieces.append(tf.reshape(grad, [-1]))
    return tf.concat(pieces, axis=0)


def flat_to_grad_list(flat, var_list):
    """
    Split a flat vector back into tensors shaped like the given variables.

    The vector is cut into consecutive chunks whose sizes are the element
    counts of ``var_list`` (in order), and each chunk is reshaped to the
    matching variable's static shape.

    :param flat: (TensorFlow Tensor) 1-D tensor; its length must equal the
        total number of elements in ``var_list``
    :param var_list: ([TensorFlow Variable]) variables defining the order and
        shapes of the pieces
    :return: ([TensorFlow Tensor]) one tensor per variable
    """
    sizes = [numel(var) for var in var_list]
    chunks = tf.split(flat, sizes)
    result = []
    for chunk, var in zip(chunks, var_list):
        result.append(tf.reshape(chunk, var_shape(var)))
    return result


def var_shape(tensor):
    """
    Return the static shape of a TensorFlow tensor as a list of ints.

    :param tensor: (TensorFlow Tensor) the input tensor
    :return: ([int]) the shape; when assertions are enabled, an
        AssertionError is raised if any dimension is not an ``int``
        (e.g. an unknown ``None`` dimension)
    """
    dims = tensor.get_shape().as_list()
    assert all(isinstance(dim, int) for dim in dims), (
        "shape function assumes that shape is fully known")
    return dims


def numel(tensor):
    """
    Count the elements of a TensorFlow tensor from its static shape.

    :param tensor: (TensorFlow Tensor) the input tensor; its shape must be
        fully known (checked by ``var_shape``)
    :return: (int) product of the dimensions (1 for a scalar)
    """
    shape = var_shape(tensor)
    total = np.prod(shape)
    return int(total)


# loss computation
def compute_mse(train_model, x, training=True):
    """
    Mean squared reconstruction error of a model on a batch.

    The model's output is compared with its own input, as for an autoencoder.

    :param train_model: Keras model mapping ``x`` to a reconstruction
        (presumably of the same shape as ``x``)
    :param x: input batch, also used as the target
    :param training: forwarded to the model call
    :return: scalar tensor from ``tf.keras.losses.MeanSquaredError``
    """
    reconstruction = train_model(x, training=training)
    return tf.keras.losses.MeanSquaredError()(x, reconstruction)


In [4]:
# TADAM: trust-region-scaled, Adam-style parameter update
@tf.function
def train_step(model, tr_opt, x, lr, beta1, beta2, gamma, ls_h, ls, pr, dt, m_h, m, s, v, t, total_steps, eps=1e-8):
    """
    Executes one TADAM training step and returns the updated optimizer state.

    The MSE loss and its gradient are computed and the gradient is flattened
    into one vector. Adam-style first/second moments and a variance estimate
    are then updated. A trust-region scale ``dt`` is adapted from the ratio of
    the actual to the predicted loss reduction. Finally the preconditioned step
    ``g_h`` is passed to ``tr_opt.apply_gradients`` in place of the raw gradient.

    All optimizer state is passed in and returned explicitly; the caller must
    feed the returned values back in on the next call.

    :param model: Keras model; its ``trainable_variables`` are updated
    :param tr_opt: optimizer whose ``apply_gradients`` applies ``g_h``
        (the caller in this notebook passes SGD with learning rate ``lr``)
    :param x: input batch; also the reconstruction target of ``compute_mse``
    :param lr: learning rate; inside this function it only scales the
        predicted reduction ``pr``
    :param beta1: decay of the first moment ``m`` and of the loss average ``ls``
    :param beta2: decay of the second moment ``s`` and of the variance ``v``
    :param gamma: threshold on ``rho`` and base of the ``dt_min``/``dt_max`` schedule
    :param ls_h: bias-corrected moving average of the loss (previous step)
    :param ls: raw moving average of the loss (previous step)
    :param pr: predicted loss reduction computed at the previous step
    :param dt: current trust-region scale
    :param m_h: bias-corrected first moment from the previous step
    :param m: raw first-moment accumulator
    :param s: raw second-moment accumulator
    :param v: raw variance accumulator
        (m, s and v become flat vectors the size of the gradient after the
        first call; initial values may be scalars that broadcast)
    :param t: 1-based step counter (float)
    :param total_steps: horizon of the ``dt_min``/``dt_max`` schedule
    :param eps: small constant guarding divisions
    :return: (ls_h, ls, loss, pr, dt, m_h, m, s, v, t + 1.0) -- the new state;
        ``loss`` is the MSE of the current batch, computed before the update
    """
    with tf.GradientTape() as tape:
        loss = compute_mse(model,x, True)
    gradient = tape.gradient(loss, model.trainable_variables)
    # every per-parameter quantity below operates on this single flat vector
    g_flat = flatgrad(gradient)

    # bias correction: Adam-style factors 1 - beta**t
    bp1 = tf.pow(beta1, t)
    bp2 = tf.pow(beta2, t)
    bc1 = 1.0 - bp1
    bc2 = 1.0 - bp2

    # compute delta
    # rho = (previous smoothed loss - current loss) / reduction predicted at the
    # previous step. On the first step there is no prediction yet, so rho = 0.5,
    # which selects the "keep dt" branch below whenever gamma < 0.5.
    # NOTE(review): when t is a Python number this `if` is resolved at trace time;
    # when t is a tensor AutoGraph presumably converts it to a graph conditional.
    if t>1:
        rho = (ls_h - loss) / tf.maximum(pr, eps)
    else:
        rho = 0.5
    # bounds for dt: [1, 2] at t == 1, moving towards [1 - gamma, 1 + gamma]
    # as (t - 1) approaches total_steps
    dt_min = tf.pow(1.- gamma, (t-1) / total_steps)
    dt_max = 1. + tf.pow(gamma, (t-1) / total_steps)
    # r: dt_min when rho < gamma, dt_max when rho > 1 - gamma, otherwise 1.0.
    # NOTE(review): exclusive=True assumes both predicates never hold at once,
    # which is guaranteed only for gamma < 0.5.
    f1 = lambda: dt_min
    f2 = lambda: dt_max
    f3 = lambda: 1.0
    r = tf.case([(tf.less(rho, gamma), f1), (tf.greater(rho, 1. - gamma), f2)], default=f3, exclusive=True)
    # rescale dt and clip it into [dt_min, dt_max]
    dt = tf.minimum(tf.maximum(r * dt, dt_min), dt_max) 

    # moving variance
    # squared deviation from the *previous* bias-corrected mean m_h (m_h is only
    # updated further below), scaled by (beta2 - beta2**t) / (1 - beta2**t)
    dv = tf.square(g_flat - m_h) * (beta2 - bp2) / bc2 
    v = beta2 * v + (1.0 - beta2) * dv
    v_h = v / bc2

    # first moment
    m = beta1 * m + (1.0 - beta1) * g_flat
    m_h = m / bc1

    # second moment
    s = beta2 * s + (1.0 - beta2) * tf.square(g_flat)
    s_h = s / bc2

    # fisher vector: v_h scaled by the scalar 1 + sum(m_h**2 / v_h)
    f_h = (1.0 + tf.reduce_sum(tf.square(m_h) / (v_h + eps))) * v_h

    # apply trust region
    # elementwise denominator is the larger of dt * f_h and sqrt(s_h);
    # the step is the bias-corrected momentum scaled by dt over that denominator
    u_h = tf.maximum(dt * f_h, tf.sqrt(s_h))
    g_h = m_h * dt / (u_h + eps)

    # update: reshape the flat step into per-variable tensors and hand it to the
    # optimizer as if it were the gradient
    g_update = flat_to_grad_list(g_h, model.trainable_variables)
    grads_and_vars = [(grad, var) for grad, var in zip(g_update, model.trainable_variables)]
    tr_opt.apply_gradients(grads_and_vars)

    # moving avg of loss (pre-update batch loss); the returned ls_h is the
    # reference value for rho on the next call
    ls = beta1 * ls + (1.0 - beta1) * loss
    ls_h = ls / bc1

    # predict reduction
    # pr1 = m_h . g_h and pr2 = (m_h . g_h)**2 + sum(v_h * g_h**2),
    # i.e. g_h^T (diag(v_h) + m_h m_h^T) g_h
    # NOTE(review): if the applied step is -lr * g_h, a second-order model would
    # scale pr2 by lr**2 rather than lr -- confirm against the TADAM reference.
    pr1 = tf.reduce_sum(m_h * g_h) 
    pr2 = tf.square(pr1) + tf.reduce_sum(v_h * tf.square(g_h))
    pr = (pr1 - 0.5 * pr2) * lr

    return ls_h, ls, loss, pr, dt, m_h, m, s, v, t + 1.0

In [5]:
def training(data_name, in_shape, filters, code_dim, batch_size=128, epochs=1,
             init_lr=1e-3, beta1=0.9, beta2=0.999, gamma=0.25, eps=1e-8, seed=0):
    """
    Train a convolutional autoencoder with the TADAM step and print per-epoch
    training/validation MSE, the current trust-region scale ``dt`` and the
    wall-clock training time of the epoch.

    :param data_name: dataset for ``load_data`` ('MNIST', 'Fashion-MNIST' or 'CIFAR-10')
    :param in_shape: model input shape, e.g. (32, 32, 1)
    :param filters: encoder channel counts passed to ``CAE``
    :param code_dim: bottleneck width passed to ``CAE``
    :param batch_size: images per batch
    :param epochs: number of passes over the training set
    :param init_lr: learning rate of the SGD optimizer that applies the TADAM step
    :param beta1: decay of the first moment and of the loss moving average
    :param beta2: decay of the second moment and variance estimates
        (default 0.999; it was previously mistyped as .0999)
    :param gamma: trust-region threshold / schedule parameter for ``train_step``
    :param eps: numerical stabilizer: initial variance, also forwarded to ``train_step``
    :param seed: seed for TensorFlow, NumPy and Python's ``random``
    :return: None (results are printed)
    """
    print('DATA: {}, BATCH: {}, EPOCH: {}'
          .format(data_name, batch_size, epochs))
    print('INIT_LR: {:.0e}, BETA1: {:.4f}, BETA2: {:.4f}, GAMMA: {:.4f}'
          .format(init_lr, beta1, beta2, gamma))

    # fix seed
    tf.random.set_seed(seed)  # Tensorflow
    np.random.seed(seed)  # numpy
    random.seed(seed)  # Python

    # set model
    model = CAE(in_shape, filters, code_dim).model
    model.summary()

    # load data
    train_dataset, test_dataset, train_size = load_data(data_name, batch_size)
    # Dataset.batch keeps the final partial batch, so an epoch runs
    # ceil(train_size / batch_size) steps. total_steps is the horizon of the
    # dt_min/dt_max schedule in train_step, so it must not undercount.
    steps_per_epoch = -(-train_size // batch_size)
    total_steps = steps_per_epoch * epochs

    # optimizer: plain SGD applies the preconditioned TADAM step
    train_opt = tf.keras.optimizers.SGD(learning_rate=init_lr)

    # initial TADAM state
    ls_h = 0.0
    ls = 0.0
    pr = 0.0
    m_h = 0.0
    m = 0.0
    s = 0.0
    v = eps
    dt = 1.0
    t = 1.0

    # train model
    for epoch in range(1, epochs + 1):
        tmloss = tf.keras.metrics.Mean()
        start_time = time.time()
        for train_x in train_dataset:
            ls_h, ls, loss, pr, dt, m_h, m, s, v, t = train_step(
                model, train_opt, train_x, init_lr, beta1, beta2, gamma,
                ls_h, ls, pr, dt, m_h, m, s, v, t, total_steps, eps=eps)
            tmloss(loss)
        train_time = time.time() - start_time
        tloss = tmloss.result()

        vmloss = tf.keras.metrics.Mean()
        for test_x in test_dataset:
            vmloss(compute_mse(model, test_x, False))
        vloss = vmloss.result()

        print('Epoch: {}, loss_t: {:.4e}, loss_v: {:.4e}, dt: {:.2f}, time: {:.2f}'
              .format(epoch, tloss, vloss, dt, train_time))


In [6]:
# Device selection for this run.
# NOTE(review): tensorflow is already imported in the first cell when these
# variables are set, and a recorded run of this cell still logged creating
# GPU:0 -- confirm whether the GPU is really meant to be hidden
# (tf.config.set_visible_devices([], 'GPU') would hide it in-process).
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "",
})

# experiment configuration
datas = ['MNIST', 'Fashion-MNIST', 'CIFAR-10']
in_shape = (32, 32, 1)
filters = [16, 32, 64]
code_dim = 16

training(
    datas[0],
    in_shape,
    filters,
    code_dim,
    batch_size=128,
    epochs=100,
    init_lr=1e-3,
    beta1=0.9,
    beta2=0.999,
    gamma=0.25,
)


DATA: MNIST, BATCH: 128, EPOCH: 100
INIT_LR: 1e-03, BETA1: 0.9000, BETA2: 0.9990, GAMMA: 0.2500


2023-06-24 09:43:45.415647: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-06-24 09:43:46.155092: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 10410 MB memory:  -> device: 0, name: NVIDIA GeForce GTX 1080 Ti, pci bus id: 0000:0a:00.0, compute capability: 6.1


Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 32, 32, 1)]       0         
                                                                 
 conv2d (Conv2D)             (None, 16, 16, 16)        160       
                                                                 
 tf.nn.silu (TFOpLambda)     (None, 16, 16, 16)        0         
                                                                 
 conv2d_1 (Conv2D)           (None, 8, 8, 32)          4640      
                                                                 
 tf.nn.silu_1 (TFOpLambda)   (None, 8, 8, 32)          0         
                                                                 
 conv2d_2 (Conv2D)           (None, 4, 4, 64)          18496     
                                                                 
 tf.nn.silu_2 (TFOpLambda)   (None, 4, 4, 64)          0     

Epoch: 1, loss_t: 3.0891e-02, loss_v: 8.6022e-03, dt: 1.98, time: 8.63
Epoch: 2, loss_t: 7.5140e-03, loss_v: 6.5878e-03, dt: 1.97, time: 4.13
Epoch: 3, loss_t: 6.2882e-03, loss_v: 5.8424e-03, dt: 1.94, time: 4.11
Epoch: 4, loss_t: 5.7518e-03, loss_v: 5.5353e-03, dt: 1.95, time: 4.20
Epoch: 5, loss_t: 5.4284e-03, loss_v: 5.2315e-03, dt: 1.93, time: 4.22
Epoch: 6, loss_t: 5.1980e-03, loss_v: 5.1990e-03, dt: 1.85, time: 4.07
Epoch: 7, loss_t: 5.0339e-03, loss_v: 4.9703e-03, dt: 1.83, time: 4.27
Epoch: 8, loss_t: 4.9003e-03, loss_v: 4.8134e-03, dt: 1.85, time: 4.08
Epoch: 9, loss_t: 4.7933e-03, loss_v: 4.7303e-03, dt: 1.79, time: 4.09
Epoch: 10, loss_t: 4.7016e-03, loss_v: 4.6963e-03, dt: 1.87, time: 4.04
Epoch: 11, loss_t: 4.6278e-03, loss_v: 4.6031e-03, dt: 1.86, time: 4.12
Epoch: 12, loss_t: 4.5490e-03, loss_v: 4.5586e-03, dt: 1.66, time: 4.09
Epoch: 13, loss_t: 4.4901e-03, loss_v: 4.5198e-03, dt: 1.83, time: 4.14
Epoch: 14, loss_t: 4.4351e-03, loss_v: 4.4338e-03, dt: 1.82, time: 4.15
E