In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [7]:
# coding: utf-8
import os
import cv2
import h5py
import numpy as np
from keras import backend as K
from keras.layers import (
    BatchNormalization,
    Conv2D,
    Input,
    Lambda,
    LeakyReLU,
    UpSampling2D,
    concatenate,
    Layer,
    InputSpec
)
from keras.models import Model
from keras.optimizers import Adam
import numpy.typing as npt
import matplotlib.pyplot as plt
from keras.utils import plot_model
from skimage.exposure import rescale_intensity

import tensorflow as tf
# from keras.engine import Layer
# from keras.engine import InputSpec

# https://stackoverflow.com/questions/50677544/reflection-padding-conv2d
class ReflectionPadding2D(Layer):
    def __init__(self, padding=(1, 1), **kwargs):
        self.padding = tuple(padding)
        self.input_spec = [InputSpec(ndim=4)]
        super(ReflectionPadding2D, self).__init__(**kwargs)

    def get_output_shape_for(self, s):
        """ If you are using "channels_last" configuration"""
        return (s[0], s[1] + 2 * self.padding[0], s[2] + 2 * self.padding[1], s[3])

    def call(self, x, mask=None):
        w_pad,h_pad = self.padding
        return tf.pad(x, [[0, 0], [0, w_pad], [0, h_pad], [0, 0] ], 'REFLECT')


# def reflection_padding(x, padding):
#     reflected = Lambda(lambda x: x[:, :, ::-1, :])(x)
#     reflected = Lambda(lambda x: x[:, :, : padding[1], :])(reflected)
#     upper_row = concatenate([x, reflected], axis=2)
#     lower_row = Lambda(lambda x: x[:, ::-1, :, :])(upper_row)
#     lower_row = Lambda(lambda x: x[:, : padding[0], :, :])(lower_row)
#     padded = concatenate([upper_row, lower_row], axis=1)
#     return padded


def conv_bn_relu(x, size, filters, kernel_size, strides):
    padding = [0, 0]
    padding[0] = (int(size[0] / strides[0]) - 1) * strides[0] + kernel_size - size[0]
    padding[1] = (int(size[1] / strides[1]) - 1) * strides[1] + kernel_size - size[1]
    x = ReflectionPadding2D(padding=padding)(x)

    x = Conv2D(filters, kernel_size, strides=strides)(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(alpha=0.2)(x)

    new_size = [int(size[0] / strides[0]), int(size[1] / strides[1])]
    return x, new_size


def down_sampling(x, size, filters, kernel_size):
    new_size = [size[0], size[1]]
    if size[0] % 2 != 0:
        x = ReflectionPadding2D((1, 0))(x)
        new_size[0] = size[0] + 1
    if size[1] % 2 != 0:
        x = ReflectionPadding2D((0, 1))(x)
        new_size[1] = size[1] + 1
    size = new_size
    x, size = conv_bn_relu(x, size, filters, kernel_size, (2, 2))
    x, size = conv_bn_relu(x, size, filters, kernel_size, (1, 1))
    return x, size


def upsample(x, size, inter):
    x = UpSampling2D(size=(2, 2))(x)
    if inter == "bilinear":
        x_padded = ReflectionPadding2D((1, 1))(x)
        x = Lambda(
            lambda x: (
                x[:, :-1, 1:, :]
                + x[:, 1:, :-1, :]
                + x[:, :-1, :-1, :]
                + x[:, :-1, :-1, :]
            )
            / 4.0
        )(x_padded)
    return x, [size[0] * 2, size[1] * 2]


def up_sampling(x, size, filters, kernel_size, inter):
    x, size = upsample(x, size, inter)
    x, size = conv_bn_relu(x, size, filters, kernel_size, (1, 1))
    x, size = conv_bn_relu(x, size, filters, 1, (1, 1))
    return x, size


def skip(x, size, filters, kernel_size):
    x, size = conv_bn_relu(x, size, filters, kernel_size, (1, 1))
    return x, size


def define_model(
    num_u,
    num_d,
    kernel_u,
    kernel_d,
    num_s,
    kernel_s,
    height,
    width,
    inter,
    lr,
    input_channel=32,
):
    depth = len(num_u)
    size = [height, width]

    inputs = Input(shape=(height, width, input_channel))

    x = inputs
    down_sampled = []
    sizes = [size]
    for i in range(depth):
        x, size = down_sampling(x, size, num_d[i], kernel_d[i])
        down_sampled.append(x)
        sizes.append(size)

    for i in range(depth - 1, -1, -1):
        if num_s[i] != 0:
            skipped, size = skip(down_sampled[i], size, num_s[i], kernel_s[i])
            x = concatenate([x, skipped], axis=3)
        x, size = up_sampling(x, size, num_u[i], kernel_u[i], inter)

        if sizes[i] != size:
            x = Lambda(lambda x: x[:, : sizes[i][0], : sizes[i][1], :])(x)
            size = sizes[i]

    x = Conv2D(3, 1)(x)
    model = Model(inputs, x)

    return model

def define_denoising_model(height, width):
    num_u = [128, 128, 128, 128, 128]
    num_d = [128, 128, 128, 128, 128]
    kernel_u = [3, 3, 3, 3, 3]
    kernel_d = [3, 3, 3, 3, 3]
    num_s = [4, 4, 4, 4, 4]
    kernel_s = [1, 1, 1, 1, 1]
    lr = 0.01
    inter = "bilinear"

    model = define_model(
        num_u, num_d, kernel_u, kernel_d, num_s, kernel_s, height, width, inter, lr
    )
    model.compile(loss="mse", optimizer=Adam(learning_rate=lr))

    return model


def denoising(image:npt.NDArray):
    height, width = image.shape[:2]
    model = define_denoising_model(height, width)
    input_noise = np.random.uniform(0, 0.1, (1, height, width, 32))

    print("Starting training:")
    for i in range(1800):
        x = input_noise + np.random.normal(0, 1 / 30.0, (height, width, 32))
        
        if i % 100 == 0:
          output = rescale_intensity(model.predict_on_batch(x)[0], out_range="uint8")
          cv2.imwrite(f"/content/drive/MyDrive/deep image prior/img_003_SRF_2_HR_denoised - {i}.png", output)

        metrics = model.train_on_batch(
            x=x,
            y=image[None, :, :, :],
            return_dict=True,
        )

        if i % 100 == 0:
          print("Epoch: %d, Loss: %f" % (i, metrics["loss"]))

    return rescale_intensity(model.predict(input_noise)[0], out_range="uint8")

In [None]:
model = define_denoising_model(320, 512)
plot_model(model, to_file="/content/drive/MyDrive/deep image prior/model.png", show_shapes=True, show_layer_activations=True)

In [5]:
from skimage.metrics import peak_signal_noise_ratio, mean_squared_error, structural_similarity
from typing import Dict
def img_compare(img1: npt.NDArray, img2: npt.NDArray) -> Dict[str, float]:
    """
    Calculates and returns the MSE and SSIM similarity metrics between two images.

    Args:
        img1 (npt.NDArray): Image 1 to compare
        img2 (npt.NDArray): Image 2 to compare

    Returns:
        Dict[str, float]: MSE, SSIM, PSNR
    """
    mse = mean_squared_error(img1, img2)
    ssim = structural_similarity(
        img1,
        img2,
        channel_axis=2,
        gaussian_weights=True,
        sigma=1.5,
        use_sample_covariance=False,
        multichannel=True
    )
    psnr = peak_signal_noise_ratio(img1, img2)

    return {
        "mse": mse,
        "ssim": ssim,
        "psnr": psnr,
    }

In [8]:
# Image dimensions individually must be a multiple of 32, but it doesn't need to be square.
# eg. 320 x 512 works. 330 x 512 doesn't
# Image will be cropped to nearest multiple of 32 for each dimension

original_img = cv2.imread("/content/drive/MyDrive/deep image prior/img_003_SRF_2_HR.png")
height, width = original_img.shape[:2]
height_cropped = height - height % 32
width_cropped = width - width % 32
original_img = original_img[:height_cropped, :width_cropped, :]
print(original_img.shape)

noisy_img = cv2.imread("/content/drive/MyDrive/deep image prior/img_003_SRF_2_HR_noisy.png")
noisy_img = original_img[:height_cropped, :width_cropped, :]

denoised_img = denoising(noisy_img)
cv2.imwrite("/content/drive/MyDrive/deep image prior/img_003_SRF_2_HR_denoised.png", denoised_img)

print("Clean vs Noisy", img_compare(original_img, noisy_img))
print("Denoised vs Noisy", img_compare(denoised_img, noisy_img))

(320, 512, 3)
Starting training:
Epoch: 0, Loss: 12430.856445
Epoch: 100, Loss: 967.370483
Epoch: 200, Loss: 620.559937
Epoch: 300, Loss: 478.816559
Epoch: 400, Loss: 385.919800
Epoch: 500, Loss: 322.532532
Epoch: 600, Loss: 284.199493
Epoch: 700, Loss: 245.544388
Epoch: 800, Loss: 217.631958
Epoch: 900, Loss: 191.074905
Epoch: 1000, Loss: 174.652176
Epoch: 1100, Loss: 152.921402
Epoch: 1200, Loss: 144.790970
Epoch: 1300, Loss: 127.774269
Epoch: 1400, Loss: 117.783981
Epoch: 1500, Loss: 106.582741
Epoch: 1600, Loss: 98.969193
Epoch: 1700, Loss: 91.347252
Clean vs Noisy {'mse': 0.0, 'ssim': 1.0, 'psnr': inf}
Denoised vs Noisy {'mse': 147.88871053059896, 'ssim': 0.8495263817920445, 'psnr': 26.43145338604619}


  return 10 * np.log10((data_range ** 2) / err)
