In [1]:
from enum import auto
import matplotlib as matplt
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, Conv2D, MaxPool2D, Flatten, Reshape, InputLayer, Conv2DTranspose, UpSampling2D


In [2]:
# --- Configuration -------------------------------------------------------------
img_size = 256          # images are resized to img_size x img_size; the autoencoder needs a multiple of 8
batch_size = 64
train_dir = 'testing_database'
random_seed = 42        # fixes batch shuffling and weight initialisation for reproducible runs

input_shape = (img_size, img_size, 3)

# Seed Python, NumPy and TensorFlow in one call so Restart-&-Run-All reproduces
# the same initial weights and the same batch order.
tf.keras.utils.set_random_seed(random_seed)

# Pixels are rescaled to [0, 1] to match the decoder's sigmoid output.
# class_mode='input' makes the iterator yield (image, image) pairs, i.e. the
# reconstruction target is the input itself.
image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1.0/255)
train_ds = image_generator.flow_from_directory(
    train_dir,
    class_mode='input',
    target_size=(img_size, img_size),
    batch_size=batch_size,
    shuffle=True,
    seed=random_seed,
)

Found 225 images belonging to 1 classes.


In [3]:

class AutoEncoder(tf.keras.models.Model):
    """Convolutional autoencoder for square RGB images.

    Encoder:
        Three ``Conv2D`` (3x3, ReLU, 'same') + 2x2 ``MaxPool2D`` stages with
        16, 8 and 4 filters. The resulting feature map is flattened and
        projected onto a ``latent_dim``-wide ReLU dense code.

    Decoder:
        Mirrors the encoder. A dense layer restores the bottleneck feature
        map, then three 2x ``UpSampling2D`` + ``Conv2DTranspose`` stages
        (4, 8, 16 filters) follow. A final sigmoid ``Conv2D`` emits 3
        channels in [0, 1].

    Args:
        image_size: Side length of the square input images. When ``None``
            (default) the notebook-level ``img_size`` is used. It must be
            divisible by 8 so that the three 2x poolings and the three 2x
            upsamplings round-trip to the same resolution.
        latent_dim: Width of the dense bottleneck code (default 32).

    Raises:
        ValueError: If ``image_size`` is not a positive multiple of 8.
    """

    def __init__(self, image_size=None, latent_dim=32):
        super().__init__()
        if image_size is None:
            image_size = img_size  # notebook configuration cell

        encoder_filters = (16, 8, 4)
        # Each encoder stage halves the resolution; the decoder doubles it back.
        scale = 2 ** len(encoder_filters)
        if image_size <= 0 or image_size % scale:
            raise ValueError(
                f"image_size must be a positive multiple of {scale}, got {image_size}"
            )
        # Shape of the last encoder feature map, e.g. (32, 32, 4) for 256x256
        # input. The decoder must rebuild exactly this shape, so it is derived
        # here instead of being hard-coded.
        bottleneck_side = image_size // scale
        bottleneck_channels = encoder_filters[-1]
        bottleneck_units = bottleneck_side * bottleneck_side * bottleneck_channels

        self.encoder = tf.keras.Sequential()
        self.encoder.add(InputLayer(input_shape=(image_size, image_size, 3)))
        for filters in encoder_filters:
            self.encoder.add(Conv2D(filters, (3, 3), activation='relu', padding='same'))
            self.encoder.add(MaxPool2D(pool_size=(2, 2), strides=(2, 2)))
        self.encoder.add(Flatten())
        self.encoder.add(Dense(units=latent_dim, activation="relu"))

        self.decoder = tf.keras.Sequential()
        # input_shape must be a tuple: (latent_dim,), not (latent_dim).
        self.decoder.add(InputLayer(input_shape=(latent_dim,)))
        self.decoder.add(Dense(units=bottleneck_units, activation="relu"))
        self.decoder.add(Reshape((bottleneck_side, bottleneck_side, bottleneck_channels)))
        for filters in reversed(encoder_filters):
            self.decoder.add(UpSampling2D(size=(2, 2)))
            self.decoder.add(Conv2DTranspose(filters, (3, 3), activation='relu', padding='same'))
        self.decoder.add(Conv2D(3, kernel_size=(3, 3), activation='sigmoid', padding='same'))

    def call(self, x):
        """Encode ``x`` to the latent code and decode it back to an image batch."""
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded



In [4]:

class PlotLearning(tf.keras.callbacks.Callback):
    """Callback that plots the training-loss curve after every epoch.

    Every metric Keras reports in ``logs`` is appended to ``self.metrics``
    (one list per metric name). The ``"loss"`` history is drawn against the
    epoch number and written to ``filename``, overwriting the previous
    epoch's image.

    Args:
        filename: Path of the image the plot is saved to
            (default ``"training_progress.png"``).
    """

    def __init__(self, filename="training_progress.png"):
        super().__init__()
        self.filename = filename
        self.metrics = {}

    def on_train_begin(self, logs=None):
        # Start a fresh history for every call to fit().
        self.metrics = {metric: [] for metric in (logs or {})}

    def on_epoch_end(self, epoch, logs=None):
        # Store this epoch's metrics.
        for metric, value in (logs or {}).items():
            self.metrics.setdefault(metric, []).append(value)

        losses = self.metrics.get("loss")
        if not losses:
            return  # nothing to plot yet

        # `epoch` is 0-based. Anchor the x axis at the current epoch so the
        # axis length matches `losses` even when fit() was started with
        # initial_epoch > 0 (e.g. after resuming from a checkpoint).
        epoch_numbers = range(epoch + 2 - len(losses), epoch + 2)

        fig, ax = plt.subplots(figsize=(15, 5))
        ax.plot(epoch_numbers, losses)
        ax.grid()
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss")
        ax.set_title("Progress of loss function during training")
        fig.tight_layout()
        fig.savefig(self.filename)
        # Close exactly the figure we created. Otherwise one figure per epoch
        # stays open for the whole training run.
        plt.close(fig)

In [5]:


autoencoder = AutoEncoder()
autoencoder.compile(optimizer='adam', loss='mse')

autoencoder.encoder.summary()
autoencoder.decoder.summary()

checkpoint_path = "ae.ckpt"
# Set to True to continue from the last saved weights instead of starting from scratch.
resume_from_checkpoint = False
if resume_from_checkpoint:
    autoencoder.load_weights(checkpoint_path)

# Keep only the best (lowest training loss) weights, checked every 10 epochs.
# `period` is deprecated in tf.keras. An integer `save_freq` counts batches,
# so convert epochs to steps with the number of batches per epoch.
checkpoint_every_epochs = 10
steps_per_epoch = max(1, len(train_ds))
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=True,
    save_freq=checkpoint_every_epochs * steps_per_epoch,
    verbose=1,
    monitor='loss',
    mode='min',
    save_best_only=True)

# `train_ds` already yields batches of `batch_size`. Keras docs say not to pass
# `batch_size` to fit() for generator / Sequence inputs, so it is omitted here.
history = autoencoder.fit(train_ds,
                          epochs=1000,
                          shuffle=True,
                          callbacks=[cp_callback, PlotLearning()],
                          verbose=1
                          )

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d (Conv2D)             (None, 256, 256, 16)      448       
                                                                 
 max_pooling2d (MaxPooling2D  (None, 128, 128, 16)     0         
 )                                                               
                                                                 
 conv2d_1 (Conv2D)           (None, 128, 128, 8)       1160      
                                                                 
 max_pooling2d_1 (MaxPooling  (None, 64, 64, 8)        0         
 2D)                                                             
                                                                 
 conv2d_2 (Conv2D)           (None, 64, 64, 4)         292       
                                                                 
 max_pooling2d_2 (MaxPooling  (None, 32, 32, 4)        0

KeyboardInterrupt: 

In [6]:
# Record which TensorFlow build this notebook was executed with.
tf_version = tf.__version__
print(tf_version)

2.10.1


In [7]:
# Count the GPUs TensorFlow can see (0 means training runs on the CPU).
gpu_devices = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpu_devices))

Num GPUs Available:  0


In [8]:
# Display the visible GPU devices themselves (an empty list means none).
tf.config.list_physical_devices(device_type='GPU')

[]