<a href="https://colab.research.google.com/github/furrypython/PConv-Tensorflow2/blob/master/image_inpainting.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Dogs vs Cats Image Inpainting With Partial Convolution  
Unofficial implementation of [Liu et al., 2018. Image Inpainting for Irregular Holes Using Partial Convolutions](https://arxiv.org/abs/1804.07723).

# Mount Google Drive to Google Colaboratory 
Load your data on Google Drive into Google Colaboratory. You can skip this part if you don't use Colab.

In [None]:
import sys

from google.colab import drive

# Mount Google Drive at /gdrive, forcing a remount if it is already mounted.
drive.mount(mountpoint='/gdrive', force_remount=True)
# Put the project's helper modules (imported in the cells below) on the import path.
sys.path.append('/gdrive/My Drive/PConv-Tensorflow2/libs')

# Import packages

In [2]:
import os
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

In [3]:
from process_data import create_input_pipeline, create_input_pipeline_test, create_input_dataset
from pconv2d_layer import PConv2D
from pconv_model import build_pconv_unet
from loss import get_vgg16_weights, StyleModel
from train import fit

# Load Data
Following <a href="https://colab.research.google.com/github/tensorflow/examples/blob/master/courses/udacity_intro_to_tensorflow_for_deep_learning/l05c01_dogs_vs_cats_without_augmentation.ipynb#scrollTo=KwQtSOz0VrVX" target="_blank">this</a> Colab, we use a filtered version of the <a href="https://www.kaggle.com/c/dogs-vs-cats/data" target="_blank">Dogs vs. Cats</a> dataset. 
Let's directly download the dataset from a URL and unzip it to the Colab filesystem.

In [None]:
# Fetch the filtered Dogs vs. Cats archive and extract it next to the downloaded zip.
_URL = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
zip_dir = tf.keras.utils.get_file(
    fname='cats_and_dogs_filterted.zip',  # NOTE: local file name is misspelled; it only names the cached copy
    origin=_URL,
    extract=True,
)

In [None]:
# The archive is extracted beside the downloaded zip, so the dataset root is a
# sibling of that file.
zip_dir_base = os.path.dirname(zip_dir)
base_dir = os.path.join(zip_dir_base, 'cats_and_dogs_filtered')

In [None]:
# Train / validation sub-directories of the extracted dataset.
train_dir, validation_dir = (os.path.join(base_dir, split) for split in ('train', 'validation'))

# Setting Model Parameters

In [None]:
# IMG_SHAPE is passed to build_pconv_unet(img_shape=...) — presumably the square
# image side in pixels (confirm against pconv_model); BATCH_SIZE is shared by all pipelines.
IMG_SHAPE, BATCH_SIZE = 256, 5
# Epoch budgets: Part 1 (initial training) and Part 2 (fine-tuning).
EPOCHS, EPOCHS_FT = 30, 8

In [None]:
# Directory on Google Drive holding the weight files (the VGG16 loss-network
# weights are loaded from here).
weights_dir = '/gdrive/My Drive/PConv-Tensorflow2/weights'
# Sub-directory where training checkpoints are saved (passed to `fit` as save_dir)
# and later reloaded with model.load_weights.
checkpoints_dir = os.path.join(weights_dir, 'ckpts')

The VGG16 weights are ported from PyTorch as described <a href="https://github.com/ezavarygin/vgg16_pytorch2keras" target="_blank">here</a>. The resulting weights file is used to compute the loss.

In [None]:
# VGG16 feature extractor used by the loss.
# get_vgg16_weights presumably makes vgg16_pytorch2keras.h5 available in
# weights_dir — confirm against libs/loss.py.
get_vgg16_weights(weights_dir)
vgg16_weights = f'{weights_dir}/vgg16_pytorch2keras.h5'
vgg16 = StyleModel(weights=vgg16_weights)

# Prepare Data

In [None]:
# Build the batched input pipelines for the training and validation splits
# (presumably tf.data datasets — they are consumed with .take() below).
train_dataset, val_dataset = (
    create_input_pipeline(split_dir, batch_size=BATCH_SIZE)
    for split_dir in (train_dir, validation_dir)
)

## Visualize Training images  
Let's visualize what a single batch looks like.

In [None]:
def display(display_list, titles=None):
  """Plot images side by side in one row, one titled subplot per image.

  Note: this shadows IPython's built-in ``display`` in the notebook namespace.

  Args:
    display_list: sequence of images; each is converted with
      ``tf.keras.preprocessing.image.array_to_img`` (presumably HxWxC arrays
      or tensors — confirm against create_input_dataset / model output).
    titles: optional sequence of subplot titles matched to ``display_list`` by
      position. Defaults to masked input, mask, ground truth and prediction.
      Images beyond the last title are drawn without a title.
  """
  if titles is None:
    titles = ['Input Masked Image', 'Input Mask Image', 'Ground Truth', 'Predicted Image']

  n_images = len(display_list)
  plt.figure(figsize=(8, 8))
  for i, image in enumerate(display_list):
    plt.subplot(1, n_images, i + 1)
    plt.title(titles[i] if i < len(titles) else '')
    plt.imshow(tf.keras.preprocessing.image.array_to_img(image))
    plt.axis('off')
  # Compute the layout once, after every subplot exists.
  plt.tight_layout()
  plt.show()

In [None]:
# Grab one batch of ground-truth images and derive the masked inputs from it.
for raw_batch in train_dataset.take(1):
  (masked_batch, mask_batch), target_batch = create_input_dataset(
      raw_batch, batch_size=BATCH_SIZE)

# One row per sample: masked input, mask, ground truth.
for sample_idx in range(BATCH_SIZE):
  triplet = [masked_batch[sample_idx], mask_batch[sample_idx], target_batch[sample_idx]]
  display(triplet)

# Train Model
The model was trained in two steps:  
- **Part 1: Initial training** 
  - Batch Normalization enabled throughout the network.
  - 30 epochs with a learning rate of 0.0002.  
- **Part 2: Fine-tuning**  
  - Batch Normalization parameters frozen in the encoder part of the network.
  - 8 epochs with a learning rate of 0.00005.

## Initial training

In [None]:
# Part 1 — initial training of a freshly built network with Adam at lr 2e-4.
model = build_pconv_unet(img_shape=IMG_SHAPE)
opt_train_01 = tf.keras.optimizers.Adam(learning_rate=2e-4)

# NOTE(review): the step counts are hard-coded. Assuming one batch per step,
# each epoch draws 400 * BATCH_SIZE training and 250 * BATCH_SIZE validation
# images; this presumably relies on create_input_pipeline repeating — confirm.
history_01 = fit(
    model=model,
    input_data=iter(train_dataset),
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    steps_per_epoch=400,
    validation_data=iter(val_dataset),
    validation_steps=250,
    vgg16=vgg16,
    optimizer=opt_train_01,
    save_dir=checkpoints_dir,
)

## Fine-tuning

In [None]:
# Part 2 — fine-tuning: rebuild with fine_tuning=True and resume from a Part 1
# checkpoint at a lower learning rate.
model = build_pconv_unet(img_shape=IMG_SHAPE, fine_tuning=True)
# Timestamped checkpoint from one specific Part 1 run (epoch-29 is presumably the
# last of EPOCHS=30, zero-indexed); replace it with the file your own run wrote.
model.load_weights(checkpoints_dir + '/epoch-29-2020-07-24-05-25-31.h5')
opt_train_02 = tf.keras.optimizers.Adam(learning_rate=0.00005)

# Use the configured fine-tuning budget (EPOCHS_FT) rather than a hard-coded
# epoch count, so this agrees with the described schedule and the learning-curve plot.
history_02 = fit(model=model, 
                 input_data=iter(train_dataset), 
                 batch_size=BATCH_SIZE, 
                 epochs=EPOCHS_FT, 
                 steps_per_epoch=400, 
                 validation_data=iter(val_dataset), 
                 validation_steps=250, 
                 vgg16=vgg16,  
                 optimizer=opt_train_02, 
                 save_dir=checkpoints_dir)

# Visualize Results of the Training

In [None]:
# The test images directory
test_dir = '/gdrive/My Drive/PConv-Tensorflow2/dataset/test'
test_dataset = create_input_pipeline_test(test_dir, batch_size=BATCH_SIZE)

In [None]:
# Rebuild the network and restore a checkpoint from the fine-tuning run for inference
# (timestamp-specific; point this at the checkpoint your own run wrote).
# NOTE(review): built without fine_tuning=True even though the weights come from the
# fine-tuning run; presumably the flag only affects trainability — confirm.
model = build_pconv_unet(img_shape=IMG_SHAPE)
model.load_weights(f'{checkpoints_dir}/epoch-7-2020-07-24-06-42-17.h5')

In [None]:
# Take one test batch, mask it, and inpaint the masked images.
for raw_test_batch in test_dataset.take(1):
  (test_masked_batch, test_mask_batch), test_target_batch = create_input_dataset(
      raw_test_batch, batch_size=BATCH_SIZE)
test_result = model.predict([test_masked_batch, test_mask_batch])

In [None]:
# Show masked input, mask, ground truth and prediction for each test image.
# Iterate over the actual number of predictions: a user-supplied test directory
# may hold fewer than BATCH_SIZE images, and indexing up to BATCH_SIZE would fail.
for idx in range(len(test_result)):
  display([test_masked_batch[idx], test_mask_batch[idx], test_target_batch[idx], test_result[idx]])

## Visualize Learning Curves

In [None]:
# Concatenate the per-epoch losses of both training phases (initial + fine-tuning).
loss_01 = np.array(history_01['loss'])
loss_02 = np.array(history_02['loss'])
loss = np.concatenate([loss_01, loss_02])

val_loss_01 = np.array(history_01['val_loss'])
val_loss_02 = np.array(history_02['val_loss'])
val_loss = np.concatenate([val_loss_01, val_loss_02])

# Derive the x-axis from the recorded histories instead of EPOCHS + EPOCHS_FT,
# so the plot stays valid if either phase ran a different number of epochs.
total_epochs = np.arange(len(loss))
val_epochs = np.arange(len(val_loss))

fig, ax = plt.subplots()
ax.set_title('Learning Curves')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss value')
ax.plot(total_epochs, loss, label='train')
ax.plot(val_epochs, val_loss, label='validation')
# Border between initial training and fine-tuning: the first fine-tuning epoch index.
ax.axvline(x=len(loss_01), linewidth=1, color='gray', linestyle='--')
ax.legend()
plt.show()