In [67]:
import tensorflow as tf
from tensorflow.keras import layers, Model
import cv2
import numpy as np
import matplotlib.pyplot as plt
from keras.layers import Add

In [41]:
def load_image(path, size=(94, 94)):
    """Read an image file and return it as a normalised single-image batch.

    Args:
        path: Path of the image file; passed straight to ``tf.io.read_file``.
        size: ``(height, width)`` the image is resized to. Defaults to
            ``(94, 94)``.

    Returns:
        A ``float32`` tensor of shape ``(1, height, width, 3)`` whose values
        are the resized pixel values divided by 255.
    """
    raw_bytes = tf.io.read_file(path)
    # channels=3 forces RGB, so grayscale or RGBA files still produce the three
    # channels the model input expects. expand_animations=False guarantees a
    # rank-3 result (a GIF decodes to its first frame only) instead of the
    # rank-4 tensor decode_image otherwise returns for GIFs.
    image = tf.image.decode_image(raw_bytes, channels=3, expand_animations=False)
    image = tf.image.resize(image, size)
    image = tf.cast(image, tf.float32) / 255.0
    # Add the leading batch axis so the result can be fed to a Keras model.
    return tf.expand_dims(image, axis=0)

In [81]:
# Per-sample shape handed to the Encoder's InputLayer: (height, width, channels)
# under Keras' default channels-last layout. Height and width match the
# (94, 94) size that load_image resizes to.
input_shape = (94, 94, 3)
# Size of the latent vector; passed to both Encoder and Decoder when they are
# instantiated in the demo cell.
latent_dim = 32

class Encoder(Model):
    """Convolutional encoder producing the two latent heads of a VAE.

    A small conv/pool stack flattens the image into a feature vector, which
    is then projected by two independent Dense layers into ``mean`` and
    ``log_var``, each with ``latent_dim`` units.
    """

    def __init__(self, latent_dim):
        """Build the feature extractor and the two Dense heads.

        Args:
            latent_dim: Number of units in each of the ``mean`` and
                ``log_var`` Dense layers.
        """
        super().__init__()
        self.latent_dim = latent_dim

        # Feature extractor: two conv + 3x3 max-pool stages, then flatten.
        # The input shape comes from the module-level ``input_shape``.
        feature_layers = [
            layers.InputLayer(input_shape=input_shape),
            layers.Conv2D(filters=32, kernel_size=3, activation='relu', padding='same'),
            layers.MaxPooling2D(pool_size=3, padding='same'),
            layers.Conv2D(filters=64, kernel_size=3, activation='relu', padding='same'),
            layers.MaxPooling2D(pool_size=3, padding='same'),
            layers.Flatten(),
        ]
        self.encoder = tf.keras.Sequential(feature_layers)

        # Linear heads (no activation) for the latent mean and log-variance.
        self.mean = layers.Dense(units=latent_dim)
        self.log_var = layers.Dense(units=latent_dim)

    def call(self, X):
        """Encode a batch of images.

        Args:
            X: Batch of images accepted by the internal Sequential stack.

        Returns:
            Tuple ``(mean, log_var)``; each is the output of a Dense layer
            with ``latent_dim`` units applied to the flattened features.
        """
        features = self.encoder(X)
        return self.mean(features), self.log_var(features)

class Decoder(Model):
    """Transposed-convolution decoder mapping latent vectors to images.

    The latent input is projected by a Dense layer onto a small spatial grid
    with 64 channels, then upsampled twice by a factor of 2 with transposed
    convolutions in between. A final sigmoid transposed convolution produces
    the image.

    With the default ``image_shape`` the architecture is the original fixed
    one (7x7 seed grid -> 28x28x1 output). Pass ``image_shape`` to decode to
    another size, e.g. the ``(94, 94, 3)`` shape the Encoder consumes.
    """

    def __init__(self, latent_dim, image_shape=(28, 28, 1)):
        """Build the decoder stack.

        Args:
            latent_dim: Size of the latent vector. The model input is declared
                as ``(1, latent_dim)`` per sample, so callers pass tensors of
                shape ``(batch, 1, latent_dim)``.
            image_shape: ``(height, width, channels)`` of the decoded image.
                Defaults to ``(28, 28, 1)``.

        Raises:
            ValueError: If ``image_shape`` is not three positive integers.
        """
        super().__init__()
        if len(image_shape) != 3 or min(image_shape) < 1:
            raise ValueError(
                f"image_shape must be (height, width, channels) with positive "
                f"entries, got {image_shape!r}"
            )
        self.latent_dim = latent_dim

        height, width, channels = image_shape
        # Two 2x upsampling stages multiply the seed grid by 4, so the seed is
        # ceil(size / 4); any overshoot is cropped away after the last layer.
        seed_h = (height + 3) // 4
        seed_w = (width + 3) // 4
        extra_h = seed_h * 4 - height
        extra_w = seed_w * 4 - width

        stack = [
            layers.InputLayer(input_shape=(1, latent_dim)),
            layers.Dense(seed_h * seed_w * 64, activation='relu'),
            layers.Reshape((seed_h, seed_w, 64)),
            layers.Conv2DTranspose(64, (3, 3), activation='relu', padding='same'),
            layers.UpSampling2D((2, 2)),
            layers.Conv2DTranspose(32, (3, 3), activation='relu', padding='same'),
            layers.UpSampling2D((2, 2)),
            layers.Conv2DTranspose(channels, (3, 3), activation='sigmoid', padding='same'),
        ]
        # Only add a crop when the target size is not a multiple of 4, so the
        # default configuration keeps exactly the original layer stack.
        if extra_h or extra_w:
            stack.append(layers.Cropping2D((
                (extra_h // 2, extra_h - extra_h // 2),
                (extra_w // 2, extra_w - extra_w // 2),
            )))
        self.decoder = tf.keras.Sequential(stack)

    def call(self, X):
        """Decode latent vectors.

        Args:
            X: Tensor of shape ``(batch, 1, latent_dim)``.

        Returns:
            Tensor of shape ``(batch, height, width, channels)`` with values
            in ``(0, 1)`` from the final sigmoid activation.
        """
        return self.decoder(X)

In [82]:
# Smoke-test the untrained encoder/decoder pair on a single image.
image = load_image('./data/2_2000.jpg')
encoder = Encoder(latent_dim=latent_dim)
mean, log_var = encoder(image)

# Reparameterisation trick: z = mean + sigma * eps, with sigma = exp(log_var / 2)
# and eps ~ N(0, I). Summing mean and log_var directly does not draw a sample
# from the latent distribution the encoder parameterises.
epsilon = tf.random.normal(shape=tf.shape(mean))
z = mean + tf.exp(0.5 * log_var) * epsilon

decoder = Decoder(latent_dim=latent_dim)
# The Decoder's InputLayer is declared as (1, latent_dim) per sample, so the
# (batch, latent_dim) sample needs one extra axis before decoding.
decoder(tf.expand_dims(z, axis=0))

<tf.Tensor: shape=(1, 28, 28, 1), dtype=float32, numpy=
array([[[[0.5023215 ],
         [0.5012458 ],
         [0.502522  ],
         [0.50047255],
         [0.49950603],
         [0.50233066],
         [0.50212896],
         [0.50198793],
         [0.5035004 ],
         [0.5027773 ],
         [0.5014228 ],
         [0.50191003],
         [0.50206953],
         [0.50100285],
         [0.50034547],
         [0.49785408],
         [0.4982706 ],
         [0.49785233],
         [0.49828836],
         [0.50055414],
         [0.5005758 ],
         [0.5034149 ],
         [0.50391567],
         [0.503051  ],
         [0.5024207 ],
         [0.50215864],
         [0.50125813],
         [0.49957442]],

        [[0.50333726],
         [0.5005828 ],
         [0.5023642 ],
         [0.5008646 ],
         [0.5018056 ],
         [0.5052104 ],
         [0.5045795 ],
         [0.5026359 ],
         [0.50200164],
         [0.49921715],
         [0.5004836 ],
         [0.49989742],
         [0.49896738],

In [72]:
# Display z with an extra leading axis, i.e. the tensor that is passed to the decoder.
z[tf.newaxis, ...]

<tf.Tensor: shape=(1, 1, 32), dtype=float32, numpy=
array([[[-0.04353073,  0.31795764, -0.6333571 ,  0.34215534,
          0.1790723 ,  0.09230009, -0.7705061 ,  0.1187809 ,
         -0.07959988,  0.49292642,  0.12898296, -0.21240151,
          0.07456356,  0.30315936,  0.03565378,  0.499841  ,
         -0.27301455, -0.08600289, -0.7540641 ,  0.10767657,
          0.07046899,  0.11565666,  0.8243308 , -0.3151492 ,
         -0.14584413, -0.1511961 , -0.2056031 , -0.19813327,
         -0.3145393 , -0.12720178,  0.2785978 , -0.04838055]]],
      dtype=float32)>

<tf.Tensor: shape=(1, 32), dtype=float32, numpy=
array([[ 0.30966386, -0.00167593, -0.36596885,  0.04478392, -0.28022134,
        -0.22825491,  0.00319269, -0.13159831, -0.13743316,  0.21807921,
         0.16215336, -0.15709433,  0.4973972 ,  0.14520139,  0.05401339,
        -0.0888254 ,  0.07673918, -0.6179297 ,  0.26593304,  0.24588312,
         0.04651898,  0.08651538,  0.33737385,  0.04766154, -0.25304472,
         0.10589579,  0.2339116 , -0.01081447, -0.39712906, -0.39639586,
         0.16809629,  0.092714  ]], dtype=float32)>