<a href="https://colab.research.google.com/github/baroodb/code/blob/main/auto_encoders.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import os
from tensorflow.keras.layers import Conv2D, ReLU, BatchNormalization, Dropout, Dense, Input, Flatten, Reshape, Conv2DTranspose, Activation
from tensorflow.keras.models import Model
from tensorflow.keras import backend as K

In [None]:
# Convolutional auto-encoder: a strided-convolution encoder compresses an image
# into a `latent_dim`-sized vector and a transposed-convolution decoder expands
# that vector back to the input shape.

# --- Configuration ------------------------------------------------------------
# Input shape as (height, width, channels); assumes Keras' default channels-last
# image format, since no `data_format` is passed to the conv layers below.
IMG_SHAPE = (224, 224, 3)
latent_dim = 64      # width of the bottleneck vector produced by the encoder
filters = [32, 64]   # one stride-2 convolution stage per entry

# --- Encoder: image -> latent vector ------------------------------------------
inputs = Input(shape=IMG_SHAPE)
x = inputs

# Each stage uses strides=2 with 'same' padding, i.e. it halves height and width
# (rounding up). The loop variable is deliberately NOT named `filter`: a
# top-level loop variable stays bound after the loop and would shadow the
# Python builtin `filter` for every later cell of the notebook.
for n_filters in filters:
  x = Conv2D(n_filters, kernel_size=3, strides=2, padding='same')(x)
  x = ReLU()(x)
  x = BatchNormalization()(x)

# Static shape of the last feature map (batch axis first); the decoder needs it
# to turn the flat latent projection back into a feature volume.
volume_size = K.int_shape(x)
x = Flatten()(x)
latent_space = Dense(latent_dim)(x)

encoder = Model(inputs, latent_space, name='encoder')

# --- Decoder: latent vector -> image ------------------------------------------
decoder_input = Input(shape=(latent_dim,))

# Project the latent vector to as many units as the encoder's last feature map
# holds, then restore its (height, width, channels) layout. `np.prod` returns a
# NumPy scalar, so cast it to a plain int before handing it to Keras.
x = Dense(units=int(np.prod(volume_size[1:])))(decoder_input)
x = Reshape(target_shape=(volume_size[1], volume_size[2], volume_size[3]))(x)

# Mirror the encoder: walk the filter list backwards, doubling height and width
# at every stage. NOTE: the reconstruction only ends up with the input's spatial
# size when height and width are divisible by 2 ** len(filters).
for n_filters in filters[::-1]:
  x = Conv2DTranspose(n_filters, kernel_size=3, strides=2, padding='same')(x)
  x = ReLU()(x)
  x = BatchNormalization()(x)

# Final stride-1 layer maps back to the input's channel count, taken from
# IMG_SHAPE instead of a hard-coded 3 so that changing IMG_SHAPE keeps the
# reconstruction shape in sync with the input.
x = Conv2DTranspose(IMG_SHAPE[-1], kernel_size=3, padding='same')(x)
# Sigmoid bounds every output value to (0, 1); presumably the training images
# are scaled to the same range -- confirm against the data pipeline.
decoder_output = Activation('sigmoid')(x)

decoder = Model(decoder_input, decoder_output)

# --- End-to-end model: encoder followed by decoder ----------------------------
auto_encoder = Model(inputs, decoder(encoder(inputs)))

In [None]:
# Print the summary of the end-to-end model. The encoder and decoder are nested
# sub-models, so each appears as a single row with its own parameter count.
auto_encoder.summary()

Model: "model_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_22 (InputLayer)        [(None, 224, 224, 3)]     0         
_________________________________________________________________
encoder (Functional)         (None, 64)                12864896  
_________________________________________________________________
model_3 (Functional)         (None, 224, 224, 3)       13102403  
Total params: 25,967,299
Trainable params: 25,966,915
Non-trainable params: 384
_________________________________________________________________
