In [34]:
# Preparations
import tensorflow as tf
import numpy as np
import cv2
import os
np.random.seed(222)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
assert tf.__version__.startswith('2.')

In [42]:
# Const def
batch_size = 128
h_size = 10
learning_rate = 1e-3

In [92]:
# Utils
## Save 10 images
def save_images(images, name):
    new_img = np.zeros((320,100,3), dtype=np.float32)
    idx = 0
    for i in range(0, 320, 32):
        im_o = images[idx]
        im_n = images[idx+batch_size]
        new_img[i:i+32,0:32] = im_o.copy()
        new_img[i:i+32,50:82] = im_n.copy()
    new_img = new_img.clip(0,255).astype(np.uint8)
    cv2.imwrite(name, new_img)

In [70]:
# Data
## Load data to memory
db = tf.keras.datasets.cifar10
(_x, y), (_xt, yt) = db.load_data()
in_shape = _x.shape
o_shape = y.shape
print(in_shape, o_shape)
## Regualizatiom
x, xt = _x.astype(np.float32)/255., _xt.astype(np.float32)/255.
## Create dateset
train_x = tf.data.Dataset.from_tensor_slices(x)
train_x = train_x.shuffle(2000).batch(batch_size)
test_x = tf.data.Dataset.from_tensor_slices(xt)
test_x = test_x.batch(batch_size)

(50000, 32, 32, 3) (50000, 1)


In [95]:
class DAE(tf.keras.Model):
    def __init__(self):
        super(DAE, self).__init__()
        
        self.conv = tf.keras.layers.Conv2D(32, (3,3), padding='same', strides=1)
        self.bn = tf.keras.layers.BatchNormalization()
        self.relu = tf.nn.relu()
        
        self.block = tf.keras.Sequential()
        self.block.add(self.conv)
        self.block.add(self.bn)
        self.block.add(relu)
        
        self.encoder_layer = tf.keras.Sequential()
        self.encoder_layer(tf.keras.layers.UpSampling2D((2,2)))
        self.encoder_layer(self.block)
        
        self.decoder_layer = tf.keras.Sequential()
        self.decoder_layer.add(tf.keras.layers.MaxPool2D((2,2)))
        self.decoder_layer.add(self.block)
        
        self.encoder = tf.keras.Sequential()
        self.encoder.add(encoder_layer)
        self.encoder.add(encoder_layer)
        self.encoder
        
        
    def call(self, inputs, training=None):
        h = self.encoder(inputs)
        x_ = self.decoder(h)
        
        return x_

In [96]:
model = DAE()
model.build(input_shape=(None, 32, 32, 3))
model.summary()

Model: "dae_6"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
sequential_10 (Sequential)   (None, 10)                3844682   
_________________________________________________________________
sequential_11 (Sequential)   (None, 3072)              3847744   
Total params: 7,692,426
Trainable params: 7,692,426
Non-trainable params: 0
_________________________________________________________________


In [100]:
# Train
optimizer = tf.optimizers.Adam(lr=learning_rate)
for epoch in range(20):
    for step, x in enumerate(train_x):
        x=tf.reshape(x, [-1, 32*32*3])
        
        with tf.GradientTape() as tape:
            x_logits = model(x)
            rec_loss = tf.losses.binary_crossentropy(x, x_logits, from_logits=True)
            rec_loss = tf.reduce_mean(rec_loss)
        grad = tape.gradient(rec_loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grad, model.trainable_variables))
        
        if step % 200 == 0:
            print(epoch, step, float(rec_loss))
            
    # Evaluation
    x = next(iter(test_x))
    x_flat = tf.reshape(x, [-1, 32*32*3])
    logits = model(x_flat)
    x_ = tf.sigmoid(logits)
    x_img = tf.reshape(x_, [-1, 32, 32, 3])
    
    # Concate original img & new img to compare
    x_toshow = tf.concat([x, x_img], axis=0)
    #cv2.imshow(x)
    
    x_toshow = x_toshow.numpy() * 255. 
    x_toshow = x_toshow.astype(np.uint8)
    save_images(x_toshow, 'Tmp/Result_%d.jpg'%epoch)

0 0 0.5978628993034363
0 200 0.5973314046859741
1 0 0.6093522310256958
1 200 0.5990936160087585
2 0 0.599271297454834
2 200 0.6011543869972229
3 0 0.5994608402252197
3 200 0.5955417156219482
4 0 0.6054425835609436
4 200 0.5852481126785278
5 0 0.593393087387085
5 200 0.6010866761207581
6 0 0.5944775342941284
6 200 0.6061115264892578
7 0 0.6002156734466553
7 200 0.5908162593841553
8 0 0.5991817712783813
8 200 0.5983622074127197
9 0 0.6012532711029053
9 200 0.5996942520141602
10 0 0.6023633480072021
10 200 0.5912548899650574
11 0 0.599865198135376
11 200 0.5971051454544067
12 0 0.5959611535072327
12 200 0.5939546823501587
13 0 0.6047912240028381
13 200 0.5931234359741211
14 0 0.6033704876899719
14 200 0.5980308651924133
15 0 0.6040289998054504
15 200 0.5934886932373047
16 0 0.5922679305076599
16 200 0.595300555229187
17 0 0.5968851447105408
17 200 0.5943707227706909
18 0 0.5974487662315369
18 200 0.5859962105751038
19 0 0.6001567840576172
19 200 0.5819618105888367
