In [None]:
import tensorflow as tf
from tensorflow.keras import *
import numpy as np
import matplotlib.pyplot as plt

In [2]:
def conv_block(filters, kernel_size, strides=(1, 1), padding="valid"):
    """Build a Conv2D -> BatchNormalization -> ReLU unit.

    Args:
        filters: Number of output channels of the convolution.
        kernel_size: Convolution kernel size (int or pair).
        strides: Convolution strides; defaults to (1, 1).
        padding: Keras padding mode; with the default "valid" and stride 1,
            each spatial dim shrinks by ``kernel_size - 1``.

    Returns:
        A ``Sequential`` wrapping the three layers, usable as a single layer
        inside a larger model.
    """
    conv = layers.Conv2D(filters, kernel_size, strides=strides, padding=padding)
    norm = layers.BatchNormalization()
    activation = layers.ReLU()
    return Sequential([conv, norm, activation])

In [None]:
def deconv_block(filters, kernel_size, strides=(1, 1), padding="valid"):
    """Build a Conv2DTranspose -> BatchNormalization -> LeakyReLU(0.2) unit.

    Args:
        filters: Number of output channels of the transposed convolution.
        kernel_size: Kernel size (int or pair).
        strides: Transposed-convolution strides; defaults to (1, 1).
        padding: Keras padding mode; with the default "valid" and stride 1,
            each spatial dim grows by ``kernel_size - 1``.

    Returns:
        A ``Sequential`` wrapping the three layers, usable as a single layer
        inside a larger model.
    """
    deconv = layers.Conv2DTranspose(
        filters, kernel_size, strides=strides, padding=padding
    )
    norm = layers.BatchNormalization()
    # Negative slope passed positionally: the keyword is `alpha` in Keras 2
    # but `negative_slope` in Keras 3.
    activation = layers.LeakyReLU(0.2)
    return Sequential([deconv, norm, activation])

In [19]:
def get_loss(y, y_pred):
    """Return half the sum of squared errors between ``y`` and ``y_pred``.

    The sum runs over every element of the batch (it is not a mean), so the
    value scales with batch size and image resolution.

    Args:
        y: Target tensor (the clean images in this notebook's training data).
        y_pred: Model output, expected to have the same shape as ``y``.

    Returns:
        A scalar tensor equal to ``sum((y - y_pred) ** 2) / 2``.
    """
    residual = y - y_pred
    squared_error = residual ** 2
    return tf.reduce_sum(squared_error) / 2

In [8]:
def add_noise(img, noise_rate=0.3):
    """Corrupt ``img`` with additive Gaussian noise and clip back to [0, 1].

    Args:
        img: Floating-point image tensor/array with values in [0, 1]; either a
            batch (e.g. ``(batch, height, width, channels)``) or a single image.
        noise_rate: Standard deviation of the added zero-mean Gaussian noise.

    Returns:
        A tensor with the same shape and dtype as ``img`` in which every element
        has been perturbed by independent noise and clipped to [0, 1].
    """
    img = tf.convert_to_tensor(img)
    # Sample noise for the FULL dynamic shape, batch axis included, so each
    # image in a batch gets its own independent noise. Sampling only
    # `img.shape[1:]` would broadcast one identical pattern over the batch.
    noise = tf.random.normal(tf.shape(img), dtype=img.dtype)
    noisy_img = img + noise_rate * noise
    return tf.clip_by_value(noisy_img, 0, 1)

In [None]:
# Encoder: (256, 256, 3) image -> (1, 1, 256) code.
# Five stages of [conv_block, conv_block, 2x max-pool]. Every convolution uses
# "valid" padding, so the spatial size goes 256 -> 126 -> 61 -> 28 -> 12 -> 4,
# and the final 4x4 Conv2D reduces 4x4 to 1x1.
# The 4x4 kernel opening the 64-filter stage gives 61 -> 58 -> 56 (even before
# pooling); the decoder's 4x4 transpose conv mirrors it (56 -> 58 -> 61).
enc = Sequential(
    [layers.Input([256, 256, 3])]
    + [
        layer
        for filters, first_kernel in ((16, 3), (32, 3), (64, 4), (128, 3), (256, 3))
        for layer in (
            conv_block(filters, first_kernel),
            conv_block(filters, 3),
            layers.MaxPool2D(),
        )
    ]
    + [layers.Conv2D(256, 4)]
)

In [4]:
# Decoder: (1, 1, 256) code -> (256, 256, 3) image, roughly mirroring `enc`.
# Each deconv_block is a stride-1 "valid" Conv2DTranspose, which grows each
# spatial dim by (kernel_size - 1); UpSampling2D doubles it. Sizes:
#   1 -> 4 -> [x2] 8 -> 12 -> [x2] 24 -> 28 -> [x2] 56 -> 61
#     -> [x2] 122 -> 126 -> [x2] 252 -> 254 -> 256
# NOTE(review): the last layer is linear while the training targets are scaled
# to [0, 1]; outputs can fall outside that range (a sigmoid would bound them).
dec = Sequential(
    [layers.Input([1, 1, 256]), deconv_block(256, 4)]
    + [
        layer
        for filters, first_kernel, second_kernel in (
            (256, 3, 3),
            (128, 3, 3),
            (64, 3, 4),
            (32, 3, 3),
        )
        for layer in (
            layers.UpSampling2D(),
            deconv_block(filters, first_kernel),
            deconv_block(filters, second_kernel),
        )
    ]
    + [layers.UpSampling2D(), deconv_block(16, 3), layers.Conv2DTranspose(3, 3)]
)

In [5]:
# Full denoising autoencoder: (256, 256, 3) image -> enc -> dec -> image.
model = Sequential()
model.add(layers.Input([256, 256, 3]))
model.add(enc)
model.add(dec)

In [20]:
# Adam with learning rate 5e-4, minimizing the custom sum-of-squares loss.
model.compile(
    loss=get_loss,
    optimizer=optimizers.Adam(learning_rate=5e-4),
)

In [10]:
# Unlabelled cat images, resized to 256x256 and batched by 128; Keras yields
# float pixels in [0, 255].
raw_dataset = utils.image_dataset_from_directory(
    "../input/cat-and-dog/training_set/cats/",
    label_mode=None,
    batch_size=128,
    image_size=(256, 256)
)


def _to_denoising_pair(img):
    """Turn a raw pixel batch into a ``(noisy_input, clean_target)`` pair in [0, 1]."""
    clean = img / 255.0
    return add_noise(clean), clean


# `add_noise` runs inside the pipeline, so new noise is drawn on every pass over
# the data. Parallel map + prefetch let the CPU prepare upcoming batches while
# the model trains on the current one.
dataset = (
    raw_dataset
    .map(_to_denoising_pair, num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
# `dataset` is already batched (batch_size=128 in image_dataset_from_directory);
# Keras documents that `batch_size` must not be passed to `fit` for dataset inputs.
history = model.fit(dataset, epochs=100)

In [9]:
raw = utils.load_img("../input/cat-and-dog/testing_set/cats/cat.1.jpg",
                     target_size=(256, 256))
img = utils.img_to_array(raw)
img /= 255.0
img = img.reshape([1] + list(img.shape))
plt.imshow(add_noise(img, 0.35)[0])

In [40]:
# result after 100 epochs
plt.imshow(model(img)[0])

In [41]:
# `model.fit` returns a `History` callback object; per-epoch metrics live in its
# `.history` dict (the object itself is not subscriptable).
fig, ax = plt.subplots()
ax.plot(history.history["loss"])
ax.set(title="Denoising autoencoder training loss", xlabel="epoch", ylabel="loss");