# Image Inpainting Based on Partial Convolutions in Keras
---

In [None]:
from keras.callbacks import ModelCheckpoint, TensorBoard, CSVLogger
from keras.preprocessing.image import load_img, img_to_array
from inpainter_utils.pconv2d_data import DataGenerator, torch_preprocessing, torch_postprocessing
from inpainter_utils.pconv2d_model import pconv_model
import matplotlib.pyplot as plt
import numpy as np

# SETTINGS
# -- Input data and pretrained VGG16 weights
IMG_DIR_TRAIN = "data/images/train/"
IMG_DIR_VAL = "data/images/validation/"
IMG_DIR_TEST = "data/images/test/"
VGG16_WEIGHTS = "data/vgg16_weights/vgg16_pytorch2keras.h5"

# -- Output roots for the training callbacks (weights, TensorBoard, CSV logs);
#    each training stage writes into its own sub-directory below these.
WEIGHTS_DIR = "callbacks/weights/"
TB_DIR = "callbacks/tensorboard/"
CSV_DIR = "callbacks/csvlogger/"

# -- Training schedule
BATCH_SIZE = 5
STEPS_PER_EPOCH = 2500
EPOCHS_STAGE1 = 70  # initial training (BN enabled)
EPOCHS_STAGE2 = 50  # fine-tuning (BN frozen in encoder)
LR_STAGE1 = 2e-4
LR_STAGE2 = 5e-5

# -- Validation
STEPS_VAL = 100
BATCH_SIZE_VAL = 4

# -- Size (height, width) every image and mask is resized to
IMAGE_SIZE = (512, 512)

## Data generators

In [2]:
# DATA GENERATORS
# All three generators apply `torch_preprocessing` and resize to IMAGE_SIZE.

# Training: horizontal-flip augmentation enabled, no explicit seed.
train_datagen = DataGenerator(preprocessing_function=torch_preprocessing, horizontal_flip=True)
train_generator = train_datagen.flow_from_directory(
    IMG_DIR_TRAIN, target_size=IMAGE_SIZE, batch_size=BATCH_SIZE
)

# Validation: fixed seeds, no shuffling and a bounded number of steps, presumably
# so every evaluation sees the same images and masks (TODO confirm in DataGenerator).
val_datagen = DataGenerator(preprocessing_function=torch_preprocessing)
val_generator = val_datagen.flow_from_directory(
    IMG_DIR_VAL,
    shuffle=False,
    seed=22,
    mask_init_seed=1,
    total_steps=STEPS_VAL,
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE_VAL,
)

# Testing: same settings as training, minus augmentation.
test_datagen = DataGenerator(preprocessing_function=torch_preprocessing)
test_generator = test_datagen.flow_from_directory(
    IMG_DIR_TEST, target_size=IMAGE_SIZE, batch_size=BATCH_SIZE
)

## Training
### Stage 1. Initial training (BN enabled)

In [3]:
# Build the stage-1 training model with the stage-1 learning rate.
model = pconv_model(
    image_size=IMAGE_SIZE,
    lr=LR_STAGE1,
    vgg16_weights=VGG16_WEIGHTS,
)
# To resume an interrupted stage 1, load a saved checkpoint here, e.g.:
# model.load_weights("callbacks/weights/initial/weights.70-2.02-1.95.hdf5")

In [None]:
import os

# Stage 1: initial training. Callback outputs go to the "initial/" sub-directories.
csv_log_path = CSV_DIR + "initial/log.csv"
tensorboard_dir = TB_DIR + "initial/"
checkpoint_path = WEIGHTS_DIR + "initial/weights.{epoch:02d}-{val_loss:.2f}-{loss:.2f}.hdf5"

# Create the output directories up front. Keras' CSVLogger opens its file when
# training begins, and the HDF5 weight checkpoint is written only after the first
# epoch. A missing directory would otherwise abort the run, possibly after
# STEPS_PER_EPOCH steps of wasted work.
for target in (csv_log_path, tensorboard_dir, checkpoint_path):
    os.makedirs(os.path.dirname(target), exist_ok=True)

model.fit_generator(
    train_generator,
    steps_per_epoch=STEPS_PER_EPOCH,
    epochs=EPOCHS_STAGE1,
    validation_data=val_generator,
    validation_steps=STEPS_VAL,
    callbacks=[
        # append=True keeps earlier rows when training is resumed.
        CSVLogger(csv_log_path, append=True),
        TensorBoard(log_dir=tensorboard_dir, write_graph=True),
        # save_best_only is not set, so `monitor` does not filter anything. Weights
        # are saved every epoch, with the epoch and losses encoded in the file name.
        ModelCheckpoint(checkpoint_path, monitor="val_loss", verbose=1, save_weights_only=True),
    ],
)

### Stage 2. Fine-tuning (BN frozen in encoder)

In [3]:
# Stage 2 rebuilds the model in fine-tuning mode (smaller stage-2 learning rate)
# and starts from a stage-1 checkpoint.
# NOTE(review): this checkpoint is from epoch 80, but the stage-2 fit passes
# initial_epoch=EPOCHS_STAGE1 (70). Confirm which stage-1 epoch is intended so the
# stage-2 epoch numbering in logs and checkpoint names lines up.
LAST_CHECKPOINT = "{}initial/weights.80-1.94-1.83.hdf5".format(WEIGHTS_DIR)
model = pconv_model(
    fine_tuning=True,
    lr=LR_STAGE2,
    image_size=IMAGE_SIZE,
    vgg16_weights=VGG16_WEIGHTS,
)
model.load_weights(LAST_CHECKPOINT)

In [None]:
import os

# Stage 2: fine-tuning. Epochs continue from EPOCHS_STAGE1, and callback outputs
# go to the "fine_tuning/" sub-directories.
csv_log_path = CSV_DIR + "fine_tuning/log.csv"
tensorboard_dir = TB_DIR + "fine_tuning/"
checkpoint_path = WEIGHTS_DIR + "fine_tuning/weights.{epoch:02d}-{val_loss:.2f}-{loss:.2f}.hdf5"

# Create the output directories up front. Keras' CSVLogger opens its file when
# training begins, and the HDF5 weight checkpoint is written only after the first
# epoch. A missing directory would otherwise abort the run, possibly after
# STEPS_PER_EPOCH steps of wasted work.
for target in (csv_log_path, tensorboard_dir, checkpoint_path):
    os.makedirs(os.path.dirname(target), exist_ok=True)

model.fit_generator(
    train_generator,
    steps_per_epoch=STEPS_PER_EPOCH,
    initial_epoch=EPOCHS_STAGE1,
    epochs=EPOCHS_STAGE1 + EPOCHS_STAGE2,
    validation_data=val_generator,
    validation_steps=STEPS_VAL,
    callbacks=[
        # append=True keeps earlier rows when training is resumed.
        CSVLogger(csv_log_path, append=True),
        TensorBoard(log_dir=tensorboard_dir, write_graph=True),
        # save_best_only is not set, so `monitor` does not filter anything. Weights
        # are saved every epoch, with the epoch and losses encoded in the file name.
        ModelCheckpoint(checkpoint_path, monitor="val_loss", verbose=1, save_weights_only=True),
    ],
)

---
## Prediction
### Load the model:

In [13]:
# Rebuild the network for inference only and load the final fine-tuned weights.
LAST_CHECKPOINT = "{}fine_tuning/weights.120-1.73-1.78.hdf5".format(WEIGHTS_DIR)
model = pconv_model(image_size=IMAGE_SIZE, predict_only=True)
model.load_weights(LAST_CHECKPOINT)

# Running index that gives each saved example figure its own file name.
k = 1

### First, try images with random masks from the test set:

In [None]:
# Make a prediction for one batch of test images with the generator's masks:
(input_img, mask), orig_img = next(test_generator)
output_img = model.predict([input_img, mask])

# Post-processing back to displayable images:
orig_img = torch_postprocessing(orig_img)
# Post-processing turns the (0,0,0) masked pixels grey; re-applying the mask shows them black.
input_img = torch_postprocessing(input_img) * mask
output_img = torch_postprocessing(output_img)
# Composite: known pixels from the input, hole pixels from the prediction.
output_comp = input_img.copy()
output_comp[mask == 0] = output_img[mask == 0]

# One row per example, one column per panel. Add ("Original", orig_img) and/or
# ("Composite", output_comp) to this list to display those views as well.
panels = [
    ("Masked image", input_img),
    ("Prediction", output_img),
]
n_rows = input_img.shape[0]
# squeeze=False keeps `axes` 2-D even when the batch holds a single image.
# The size scales with the grid and equals the previous 15x29 for 2 panels x 5 rows.
fig, axes = plt.subplots(
    n_rows,
    len(panels),
    figsize=(7.5 * len(panels), 5.8 * n_rows),
    squeeze=False,
)
for col, (title, images) in enumerate(panels):
    axes[0, col].set_title(title)
    for row in range(n_rows):
        axes[row, col].imshow(images[row])
        axes[row, col].tick_params(bottom=False, left=False, labelbottom=False, labelleft=False)
plt.tight_layout()
plt.savefig("data/examples/{}_result.png".format(k), bbox_inches='tight', pad_inches=0)
plt.show()
k += 1

### Second, try on your own images and masks:

In [None]:
img_fname = "data/examples/own_image.jpg"
mask_fname = "data/examples/own_mask.jpg"
# Mask is assumed to have masked pixels in black and valid pixels in white.

# Loading (each array gets a leading batch axis of size 1):
orig_img = img_to_array(load_img(img_fname, target_size=IMAGE_SIZE))[None, ...]

# Binarise the mask at mid-grey instead of testing for exactly 255. Lossy formats
# such as JPEG leave near-white pixels that are not exactly 255, and an equality
# test would wrongly turn them into holes. Exact 0/255 masks are unaffected.
mask = img_to_array(load_img(mask_fname, target_size=IMAGE_SIZE))
mask = (mask > 127.5).astype(np.float32)[None, ...]

# Prediction: the model receives the pre-processed image with holes zeroed, plus the mask.
model_input = torch_preprocessing(orig_img.copy()) * mask
output_img = model.predict([model_input, mask])

# Post-processing:
output_img = torch_postprocessing(output_img)
input_img = orig_img * mask  # raw 0-255 values, displayed via astype('uint8') below
# Build the composite in the prediction's (post-processed) value range, the same
# way the test-set cell does. Starting from the raw 0-255 input would mix two
# different value ranges in one image.
output_comp = torch_postprocessing(model_input) * mask
output_comp[mask == 0] = output_img[mask == 0]

# Plot:
fig, axes = plt.subplots(2, 2, figsize=(20, 20))
axes[0, 0].imshow(orig_img[0].astype('uint8'))
axes[0, 0].set_title('Original image')
axes[0, 1].imshow(mask[0])
axes[0, 1].set_title('Mask')
axes[1, 0].imshow(input_img[0].astype('uint8'))
axes[1, 0].set_title('Masked image')
axes[1, 1].imshow(output_img[0])
axes[1, 1].set_title('Prediction')
for ax in axes.flatten():
    ax.tick_params(bottom=False, left=False, labelbottom=False, labelleft=False)
plt.tight_layout()
plt.savefig("data/examples/own_image_result.png", bbox_inches='tight', pad_inches=0)
plt.show()