In [1]:
%load_ext autoreload
%autoreload 2

import os
import time
from tqdm import tqdm_notebook as tqdm

import numpy as np
import tensorflow as tf
from tensorflow.keras import models, layers, losses, optimizers

import cv2

from core.layers import Downsampler, Upsampler, Skip
from core.losses import pixelwise_mse
from core.callbacks import SaveResultImage
from core.utils import crop_div_32

import imageio
import matplotlib.pyplot as plt

In [None]:
# --- Configuration -----------------------------------------------------------
INPUT_IMG_PATH = "data/lena.png"
OUTPUT_IMG_NAME = os.path.splitext(os.path.basename(INPUT_IMG_PATH))[0]
OUTPUT_DIR = "output/image_reconstruction"

# Fraction of pixels hidden from the network (zeroed in the corrupted image).
MISSING_FRACTION = 0.25

# Seed NumPy's global RNG (used for the pixel mask here and for the input noise
# in the training cell) and TensorFlow's RNG (weight initialisation) so that
# Restart & Run All reproduces the same corruption and starting point.
SEED = 0
np.random.seed(SEED)
tf.random.set_seed(SEED)

# exist_ok=True replaces the check-then-create pattern (exists + makedirs).
os.makedirs(OUTPUT_DIR, exist_ok=True)

# --- Load and corrupt the image ----------------------------------------------
original_img = imageio.imread(INPUT_IMG_PATH)
if original_img.ndim == 3 and original_img.shape[-1] == 4:
    # Drop an alpha channel: the loss mask is later extended to exactly 3 channels.
    original_img = np.ascontiguousarray(original_img[..., :3])
original_img = crop_div_32(original_img)

# mask == 1 keeps a pixel (probability 1 - MISSING_FRACTION), mask == 0 hides it.
# uint8 is required because OpenCV expects an 8-bit single-channel mask.
mask = np.random.choice(
    [1, 0], original_img.shape[:2], p=[1 - MISSING_FRACTION, MISSING_FRACTION]
).astype(np.uint8)
masked_img = cv2.bitwise_and(original_img, original_img, mask=mask)

# Rescale to float32 in [0, 1]; .numpy() keeps `img` a NumPy array so its
# shape is a plain tuple in the following cells.
img = tf.image.convert_image_dtype(masked_img, dtype=tf.float32).numpy()

fig, ax = plt.subplots()
ax.imshow(img)
ax.set_title(f"Corrupted input ({MISSING_FRACTION:.0%} of pixels removed)")
ax.axis("off")
plt.show()

In [None]:
# Network input shape: the image's spatial size with NOISE_DEPTH channels.
# tuple(...) gives plain ints whether `img` is a NumPy array or a tf.Tensor.
NOISE_DEPTH = 32
input_dim = tuple(img.shape[:2]) + (NOISE_DEPTH,)

# Encoder: one entry per scale (number of filters, kernel size).
NUM_SCALES = 5
num_filters_down = [128] * NUM_SCALES
ksizes_down = [3] * NUM_SCALES

# Decoder mirrors the encoder. Copy instead of aliasing so that editing (or an
# in-place mutation of) one list cannot silently change the other.
num_filters_up = list(num_filters_down)
ksizes_up = list(ksizes_down)

# Skip connections: 4 filters with 1x1 kernels at every scale.
num_filters_skip = [4] * NUM_SCALES
ksizes_skip = [1] * NUM_SCALES

# Passed to skip() as sigma_p; presumably the std of the noise perturbing the
# input during training (Deep Image Prior uses reg_noise_std = 1/30) — confirm.
sigma_p = 1/30
n_iter = 11000  # training epochs; with a single-sample batch, one optimizer step each
lr = 0.001      # Adam learning rate
upsampling_mode = "bilinear"

# NOTE(review): `skip` is neither imported nor defined in this notebook (only the
# `Skip` layer from core.layers is imported), so on a fresh kernel this raises
# NameError unless `skip` leaked from elsewhere. Confirm which module provides
# this model builder and add its import to the imports cell.
model = skip(input_dim,
             num_filters_down, ksizes_down,
             num_filters_up, ksizes_up,
             num_filters_skip, ksizes_skip,
             upsampling_mode, sigma_p=sigma_p)

model.summary()

In [None]:
# Fixed network input: uniform noise in [0, 0.1) shaped like the model input
# (derived from `input_dim`, so the two can't drift apart) plus a batch axis.
# float32 matches Keras' default float dtype.
z = np.random.uniform(0, 0.1, size=tuple(input_dim)).astype(np.float32)
z = np.expand_dims(z, axis=0)

# Target is the corrupted image. Presumably pixelwise_mse weights the error by
# the mask so hidden pixels do not contribute — confirm in core.losses.
y_true = np.expand_dims(img, axis=0)

callbacks = [
    SaveResultImage(n=200, input_tensor=z, output_dir=OUTPUT_DIR, img_name=OUTPUT_IMG_NAME)
]

# Broadcast the (H, W) pixel mask to (H, W, 3). Bind a new name rather than
# overwriting `mask`: rebinding in place made a second run of this cell try to
# transpose an already-stacked 4-D array and raise.
mask_rgb = np.repeat(mask[..., np.newaxis], 3, axis=-1)

# `learning_rate` replaces the `lr` alias, which recent Keras optimizers reject.
model.compile(loss=pixelwise_mse(mask=mask_rgb),
              optimizer=optimizers.Adam(learning_rate=lr))

# NOTE: compile() does not reset weights, so re-running this cell continues
# training the current model; re-run the model cell for a fresh start.
# verbose=0: the default prints a progress line per epoch (n_iter of them);
# SaveResultImage (n=200) presumably writes periodic snapshots to OUTPUT_DIR.
history = model.fit(z, y_true,
                    epochs=n_iter,
                    callbacks=callbacks,
                    verbose=0)
history.history["loss"][-1]