In [None]:
import os

import tensorflow as tf
from tensorflow import keras
from keras.datasets import mnist

import matplotlib.pyplot as plt
import numpy as np

In [None]:
def display_digits(digits, rows=5, cols=10):
    """Plot the first ``rows * cols`` images of ``digits`` as a grayscale grid.

    Parameters
    ----------
    digits : array-like or tf.Tensor
        Position-indexable collection of images supporting ``len()``. Each
        image is passed through ``tf.squeeze`` so singleton axes (e.g. a
        trailing channel axis) are dropped before ``imshow``.
    rows, cols : int, optional
        Grid shape, filled row by row. The defaults reproduce the original
        5 x 10 layout (50 images).

    If fewer than ``rows * cols`` images are supplied, the surplus tiles are
    hidden instead of raising ``IndexError``. Nothing is returned, so a call
    on the last line of a cell does not render the figure a second time.
    """
    # squeeze=False guarantees a 2-D array of Axes even when rows or cols == 1.
    fig, axes = plt.subplots(rows, cols, squeeze=False)
    # Negative padding (kept from the original) shrinks the margins so the
    # tiles sit tightly together.
    fig.tight_layout(pad=-1)
    plt.gray()

    n_digits = len(digits)
    # axes.flat walks the grid row-major, matching the original i/j loop order.
    for index, ax in enumerate(axes.flat):
        if index >= n_digits:
            ax.set_visible(False)
            continue
        ax.imshow(tf.squeeze(digits[index]))
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)

def add_noise(data, noise_factor, rng=None):
    """Corrupt ``data`` with additive Gaussian noise and clip into [0, 1].

    Parameters
    ----------
    data : np.ndarray
        Input values; ``data.shape`` sets the shape of the sampled noise.
        Presumably already scaled to [0, 1] -- confirm for new callers.
    noise_factor : float
        Standard deviation of the added zero-mean Gaussian noise.
    rng : np.random.Generator, optional
        Source of randomness. When None (the default), the legacy global
        ``np.random`` state is used, exactly as before. Pass e.g.
        ``np.random.default_rng(seed)`` to get reproducible noise.

    Returns
    -------
    tf.Tensor
        ``data + noise_factor * N(0, 1)`` clipped elementwise to [0.0, 1.0].
    """
    # np.random.normal and Generator.normal share the (loc, scale, size) signature.
    sampler = np.random if rng is None else rng
    noise = sampler.normal(loc=0.0, scale=1, size=data.shape)
    noisy_data = data + noise_factor * noise
    return tf.clip_by_value(noisy_data, clip_value_min=0.0, clip_value_max=1.0)

def get_and_preprocess_data(noise_strength=0.3):
    """Load MNIST and build (noisy input, clean target) pairs for denoising.

    The images are divided by 255 and given a trailing channel axis. Noisy
    copies are then produced with ``add_noise`` at ``noise_strength``. The
    digit labels returned by ``mnist.load_data`` are discarded.

    Returns
    -------
    ((noisy_train, clean_train), (noisy_test, clean_test))
    """
    (train_images, _), (test_images, _) = mnist.load_data()

    def _prepare(images):
        # Scale to [0, 1] (assuming 8-bit pixel values) and add a channel axis.
        return (images / 255)[..., tf.newaxis]

    clean_train = _prepare(train_images)
    clean_test = _prepare(test_images)

    # Train noise is drawn before test noise, as in the original ordering.
    noisy_train = add_noise(clean_train, noise_strength)
    noisy_test = add_noise(clean_test, noise_strength)

    return (noisy_train, clean_train), (noisy_test, clean_test)

def make_model():
    """Build and compile the convolutional denoising autoencoder.

    Spatial path for a 28x28x1 input:
      - Encoder: 28 -> 14 -> 7, via two stride-2 "same" convolutions.
      - Decoder: 7 -> 14, via a stride-2 transposed conv with "valid" padding
        and a 2x2 kernel. Then 14 -> 28, via a stride-2 transposed conv with
        "same" padding.

    The network ends in a single-channel convolution. It is compiled with MSE
    loss, the Adam optimizer, and MSE also tracked as the ``"mse"`` metric.

    NOTE(review): the output layer uses ``relu``, which is unbounded above
    while the clean targets lie in [0, 1]; ``sigmoid`` may fit better -- left
    unchanged here to keep the architecture identical.
    """
    layers = keras.layers

    def conv(filters, strides=(1, 1)):
        # Every Conv2D in this model shares kernel size, activation and padding.
        return layers.Conv2D(filters, (2, 2), activation="relu", padding="same", strides=strides)

    def norm():
        return layers.BatchNormalization()

    # Layers are instantiated left to right, in the same order as before, so
    # Keras auto-generated layer names are unchanged.
    stack = [
        layers.Input(shape=(28, 28, 1)),

        # Encoder
        conv(128), norm(),
        conv(128, strides=(2, 2)), norm(),  # 28 -> 14
        conv(128), norm(),
        conv(128, strides=(2, 2)), norm(),  # 14 -> 7
        conv(512),

        # Decoder
        layers.Conv2DTranspose(512, (2, 2), strides=(2, 2), activation="relu"), norm(),  # 7 -> 14
        conv(256), norm(),
        conv(256), norm(),
        conv(128), norm(),
        layers.Conv2DTranspose(128, (2, 2), strides=(2, 2), activation="relu", padding="same"),  # 14 -> 28
        conv(64), norm(),

        conv(1),
    ]

    autoencoder = keras.models.Sequential(stack)
    autoencoder.compile(loss=keras.losses.MeanSquaredError(), optimizer="adam", metrics=["mse"])
    return autoencoder

In [None]:
# Testing Model
# NOTE(review): this cell loads a previously saved model file; no code visible
# in this notebook writes it, so confirm it exists before a Restart & Run All.

MODEL_PATH = "MNIST-denoising-autoencoder-10-epochs.model"
N_PREVIEW = 50  # display_digits' default grid is 5 x 10

model = keras.models.load_model(MODEL_PATH)

# Only the test split is needed here. The noise level (0.5) is deliberately
# different from the 0.3 used in the training cell.
Xtest, Ytest = get_and_preprocess_data(noise_strength=0.5)[1]

original = Ytest[:N_PREVIEW]
noised = Xtest[:N_PREVIEW]
denoised = model.predict(noised)

# display_digits opens a new figure each call; label it so the three grids
# can be told apart.
for images, title in (
    (noised, "Noisy input"),
    (original, "Clean target"),
    (denoised, "Model output"),
):
    display_digits(images)
    plt.suptitle(title, y=1.02)

In [None]:
# Training and Visualizing

(Xtrain, Ytrain), (Xtest, Ytest) = get_and_preprocess_data(noise_strength=0.3)

# Output folder for per-batch snapshots (presumably a Colab path; adjust locally).
TRAINING_SNAPSHOT_DIR = "/content/training"


class TrainingVisualization(keras.callbacks.Callback):
    """Periodically save the model's reconstruction of one fixed sample.

    Every ``every_n_batches`` training batches, the current model denoises
    ``sample``. The result is written as a grayscale PNG named
    ``{snapshot_index}-{mse}.png`` in ``output_dir``.

    Parameters
    ----------
    sample : array-like or tf.Tensor
        A single image without a batch axis; a batch axis of size 1 is added.
    output_dir : str
        Destination folder; created at the start of training if missing.
    every_n_batches : int
        Snapshot interval in batches (batch index modulo this value == 0).
    """

    def __init__(self, sample, output_dir=TRAINING_SNAPSHOT_DIR, every_n_batches=5):
        super().__init__()
        # Store as a numpy batch of one so predict() gets a leading batch axis.
        self.sample = np.asarray(sample)[np.newaxis, ...]
        self.output_dir = output_dir
        self.every_n_batches = every_n_batches
        # Per-instance counter replaces the former module-level `global i`.
        self.snapshot_index = 0

    def on_train_begin(self, logs=None):
        # plt.imsave does not create parent directories.
        os.makedirs(self.output_dir, exist_ok=True)

    def on_train_batch_end(self, batch, logs=None):
        if batch % self.every_n_batches != 0:
            return

        mse = (logs or {}).get("mse")
        denoised = self.model.predict(self.sample, verbose=0)

        # Explicit cmap: otherwise the colormap depends on whether plt.gray()
        # was called earlier in the session (hidden kernel state).
        plt.imsave(
            os.path.join(self.output_dir, f"{self.snapshot_index}-{mse}.png"),
            denoised[0, ..., 0],
            cmap="gray",
        )
        self.snapshot_index += 1


model = make_model()
model.fit(
    Xtrain,
    Ytrain,
    batch_size=256,
    epochs=10,
    callbacks=[TrainingVisualization(Xtest[1])],
)