In [1]:
%load_ext autoreload
%autoreload 2

import os
from tqdm import tqdm_notebook as tqdm

import numpy as np
import tensorflow as tf
from tensorflow.keras import models, layers, losses, optimizers

from core.layers import Downsampler, Upsampler, Skip
from core.utils import crop_div_32

import imageio
import matplotlib.pyplot as plt

In [2]:
OUTPUT_DIR = "output/jpeg_denoising"

img = imageio.imread("data/snail.jpg")
img = crop_div_32(img)
img = tf.image.convert_image_dtype(img, dtype=tf.float32)

factor = 4

In [3]:
num_filters_down = [8, 16, 32, 64, 128]
ksizes_down = [3, 3, 3, 3, 3]

num_filters_up = num_filters_down
ksizes_up = ksizes_down

num_filters_skip = [0, 0, 0, 4, 4]
ksizes_skip = [0, 0, 0, 1, 1]

sigma_p = 1/30
n_iter = 2000
lr = 0.01
upsampling_mode = "bilinear"

In [4]:
skip_outputs = [None] * len(num_filters_skip)
model_input = layers.Input(shape=img.shape[:2] + (32,), dtype=tf.float32)
x = model_input

for i in range(len(num_filters_down)):
    x = Downsampler(num_filters_down[i],
                    ksizes_down[i])(x)
    if num_filters_skip[i]:
        skip_outputs[i] = Skip(num_filters_skip[i], ksizes_skip[i])(x)

    
for i in range(len(num_filters_up) - 1, -1, -1):
    if num_filters_skip[i]:
        x = tf.concat((x, skip_outputs[i]), axis=3)
    x = Upsampler(num_filters_up[i],
                  ksizes_up[i],
                  scale_factor=2,
                  upsampling_mode=upsampling_mode)(x)

    
# Transform to original 3-channel image
model_output = layers.Conv2D(filters=3, kernel_size=1, 
                             strides=1, padding='SAME', 
                             activation='sigmoid')(x)
model = models.Model(inputs=model_input, outputs=model_output)

model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 256, 384, 32 0                                            
__________________________________________________________________________________________________
downsampler (Downsampler)       (None, 128, 192, 8)  2960        input_1[0][0]                    
__________________________________________________________________________________________________
downsampler_1 (Downsampler)     (None, 64, 96, 16)   3616        downsampler[0][0]                
__________________________________________________________________________________________________
downsampler_2 (Downsampler)     (None, 32, 48, 32)   14144       downsampler_1[0][0]              
______________________________________________________________________________________________

In [5]:
opt = optimizers.Adam(lr=lr)

def pixelwise_mse(y_true, y_pred):
    batch_size = y_true.shape[0]
    y_true = tf.reshape(y_true, (batch_size, -1))
    y_pred = tf.reshape(y_pred, (batch_size, -1))
    
    return tf.keras.losses.mean_squared_error(y_true, y_pred)

z = np.random.uniform(0, 0.1, img.shape[:2] + (32,))
z = np.expand_dims(z, axis=0)
y_true = np.expand_dims(img, axis=0)
loss_vals = []

for it in tqdm(range(1, n_iter + 1)):
    with tf.GradientTape() as tape:
        y_pred = model(z, training=True)
        main_loss = pixelwise_mse(y_true, y_pred)
        loss_vals.append(main_loss.numpy())
        loss = tf.add_n([main_loss] + model.losses)
    grads = tape.gradient(loss, model.trainable_variables)
    opt.apply_gradients(zip(grads, model.trainable_variables))
    
    # Add perturb noise to z
    z += np.random.randn(*z.shape) * sigma_p
    if it % 100 == 0 or it == 1:
        print(f"Iter {it}: loss={loss.numpy()}")
        output_img = tf.image.convert_image_dtype(y_pred.numpy().squeeze(), dtype=tf.uint8)
        output_path = os.path.join(OUTPUT_DIR, f"snail_{it}.jpg")
        imageio.imsave(output_path, output_img)

HBox(children=(IntProgress(value=0, max=2000), HTML(value='')))



To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.

Iter 1: loss=[0.20074014]
Iter 100: loss=[0.02769649]
Iter 200: loss=[0.01566191]
Iter 300: loss=[0.00766597]
Iter 400: loss=[0.00514044]
Iter 500: loss=[0.00475298]
Iter 600: loss=[0.00411579]
Iter 700: loss=[0.00393403]
Iter 800: loss=[0.00350419]
Iter 900: loss=[0.00382191]
Iter 1000: loss=[0.00309782]
Iter 1100: loss=[0.00329131]
Iter 1200: loss=[0.00284968]
Iter 1300: loss=[0.00294562]
Iter 1400: loss=[0.00410508]
Iter 1500: loss=[0.00285121]
Iter 1600: loss=[0.00264457]
Iter 1700: loss=[0.0027393]
Iter 1800: loss=[0.00277589]
Iter 1900: loss=[0.00259325]
Iter 2000: loss=[0.00294586]

