Single Image Super-Resolution

In [0]:
import tensorflow as tf
#from keras.layers import Lambda
from math import ceil, floor
from keras.layers import *
from keras.models import Model, load_model
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import Adadelta
import numpy as np
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm
from skimage import io

Some constants and hyperparameters.

In [0]:
# --- Dataset locations -------------------------------------------------
train_dir = '/content/train/'
val_dir = '/content/val'

# --- Network architecture ----------------------------------------------
filters = 256        # feature maps produced by every inner convolution
kernel_size = 3
strides = 1
res_blocks = 1       # residual blocks built by the loop in get_model()
subpix_scale = 2     # upscaling factor of one sub-pixel step (applied twice)

# --- Training schedule -------------------------------------------------
batch_size = 8       # images per batch read from disk
inner_batch = 2      # random crops taken from each image by batch_generator()
epochs = 15
num_training_samples = 200
num_validation_samples = 100

# --- Image geometry ----------------------------------------------------
input_size = 300     # side length images are resized to when loaded
HR_size = 256        # side length of a high-resolution patch
crop_length_l = HR_size
LR_size = 64         # side length of a low-resolution patch
crop_length_s = LR_size

scale_fact = 4       # HR_size / LR_size
img_depth = 3        # colour channels
overlap = 16         # pixels shared by neighbouring tiles at prediction time

# Pixel values are mapped from [0, 255] to [0, 1] when images are loaded.
img_datagen = ImageDataGenerator(rescale=1.0 / 255)

Subpixel convolution layer and crop function.

In [0]:
def SubpixelConv2D(input_shape, scale=4):
    """Build a Keras ``Lambda`` layer that performs sub-pixel upscaling.

    The layer rearranges channel blocks into spatial blocks: an input of
    shape ``(batch, H, W, C)`` becomes ``(batch, H*scale, W*scale,
    C // scale**2)``.

    Args:
        input_shape: Accepted for call-site compatibility; it is not used,
            because the layer derives its output shape from the shape Keras
            hands to ``subpixel_shape``.
        scale: Integer upscaling factor per spatial dimension.

    Returns:
        A ``Lambda`` layer wrapping the depth-to-space operation.
    """
    def _scaled(dim, factor):
        # Unknown (None) dimensions must stay unknown instead of raising.
        return None if dim is None else dim * factor

    def subpixel_shape(input_shape):
        channels = input_shape[3]
        return (
            input_shape[0],
            _scaled(input_shape[1], scale),
            _scaled(input_shape[2], scale),
            None if channels is None else channels // scale**2,
        )

    def subpixel(x):
        # `tf.depth_to_space` is missing from the TF 2.x top-level namespace;
        # prefer the `tf.nn` location and fall back to the legacy alias.
        depth_to_space = getattr(tf.nn, 'depth_to_space', None)
        if depth_to_space is None:
            depth_to_space = tf.depth_to_space
        return depth_to_space(x, scale)

    return Lambda(subpixel, output_shape=subpixel_shape)

def random_crop(img, random_crop_size):
    """Cut a randomly positioned window out of an image.

    Args:
        img: Array whose first two axes are (height, width).
        random_crop_size: ``(crop_height, crop_width)`` of the window.

    Returns:
        The ``crop_height x crop_width`` slice of ``img`` (a view, not a copy)
        at an offset drawn uniformly from all positions that fit.
    """
    img_h, img_w = img.shape[:2]
    crop_h, crop_w = random_crop_size
    # The horizontal offset is drawn first, then the vertical one.
    left = np.random.randint(0, img_w - crop_w + 1)
    top = np.random.randint(0, img_h - crop_h + 1)
    bottom, right = top + crop_h, left + crop_w
    return img[top:bottom, left:right]

Batch generator and model initialization.

In [0]:
def batch_generator(batches, crop_length_l=256, crop_length_s=64):
    """Yield (low-res, high-res) training pairs forever.

    For every image of every batch pulled from ``batches``, ``inner_batch``
    random high-resolution crops are taken and each is downscaled with
    ``cv2.resize`` to form its low-resolution counterpart. One pair of
    arrays is yielded per source image.

    Args:
        batches: Iterator producing arrays of images, one image per index
            along the first axis.
        crop_length_l: Side length of the high-resolution crops.
        crop_length_s: Side length of the low-resolution inputs.

    Yields:
        Tuple ``(x, y)`` where ``x`` has shape
        ``(inner_batch, crop_length_s, crop_length_s, 3)`` and ``y`` has shape
        ``(inner_batch, crop_length_l, crop_length_l, 3)``.
    """
    hr_size = (crop_length_l, crop_length_l)
    lr_size = (crop_length_s, crop_length_s)
    while True:
        for image in next(batches):
            # Fresh arrays per yield so earlier batches are never overwritten.
            lr_batch = np.zeros((inner_batch,) + lr_size + (3,))
            hr_batch = np.zeros((inner_batch,) + hr_size + (3,))
            for slot in range(inner_batch):
                hr_patch = random_crop(image, hr_size)
                lr_batch[slot] = cv2.resize(hr_patch, lr_size)
                hr_batch[slot] = hr_patch
            yield (lr_batch, hr_batch)

def get_model():
    """Build and compile the super-resolution network.

    Architecture (all feature convolutions use ``filters`` channels,
    ``kernel_size`` kernels, ``strides`` and 'same' padding):

    * Head: one convolution on the ``crop_length_s`` x ``crop_length_s`` input.
    * Body: one conv-ReLU-conv residual unit, followed by ``res_blocks``
      further residual units, a convolution and a long skip connection from
      the head.
    * Tail: two conv-ReLU-subpixel stages, each upscaling by
      ``subpix_scale``, then a 1x1 convolution down to 3 output channels.

    The model is compiled with Adadelta and mean squared error.

    Returns:
        The compiled Keras ``Model``.
    """
    def _conv(tensor):
        # Every feature convolution shares the same configuration.
        return Conv2D(filters, kernel_size, strides=strides, padding='same')(tensor)

    # Head
    input_ = Input(name='input', shape=(crop_length_s, crop_length_s, img_depth))
    conv0 = _conv(input_)

    # Body
    res = _conv(conv0)
    act = ReLU()(res)
    res = _conv(act)
    res_rec = Add()([conv0, res])
    for _ in range(res_blocks):
        res1 = _conv(res_rec)
        act = ReLU()(res1)
        res2 = _conv(act)
        res_rec = Add()([res_rec, res2])
    conv = _conv(res_rec)
    add = Add()([conv0, conv])  # long skip connection from the head

    # Tail
    conv = _conv(add)
    act = ReLU()(conv)
    up = SubpixelConv2D(input_shape=act.shape, scale=subpix_scale)(act)
    conv = _conv(up)
    act = ReLU()(conv)
    up = SubpixelConv2D(input_shape=act.shape, scale=subpix_scale)(act)
    output = Conv2D(name='output', filters=3, kernel_size=1, strides=1, padding='same')(up)

    model = Model(inputs=input_, outputs=output)
    optimizer = Adadelta(lr=1.0, rho=0.95, decay=0.0)
    # NOTE(review): 'accuracy' carries little meaning for a regression loss;
    # it is kept so the keys of the training history stay unchanged.
    model.compile(optimizer=optimizer, loss='mean_squared_error', metrics=['accuracy'])

    # summary() prints by itself and returns None, so it must not be wrapped
    # in print() (that appended a stray "None" to the output).
    model.summary()
    return model

Initialize training and validation batches and the neural network.

In [0]:
# Stream images from disk, resized to input_size x input_size; class_mode=None
# makes the iterators yield image arrays only (no labels).
train_generator = img_datagen.flow_from_directory(
    train_dir,
    target_size=(input_size, input_size),
    batch_size=batch_size,
    class_mode=None,
)
val_generator = img_datagen.flow_from_directory(
    val_dir,
    target_size=(input_size, input_size),
    batch_size=batch_size,
    class_mode=None,
)

# Wrap the image streams so they produce (low-res, high-res) crop pairs.
train_batches = batch_generator(train_generator)
val_batches = batch_generator(val_generator)

model = get_model()

Train, save or load the model.

In [0]:
# NOTE(review): batch_generator yields one batch per source image, so these
# step counts visit fewer images per epoch than the sample counts suggest -
# confirm this is intended.
history = model.fit_generator(
    train_batches,
    steps_per_epoch=ceil(inner_batch * num_training_samples / batch_size),
    epochs=epochs,
    validation_data=val_batches,
    validation_steps=ceil(inner_batch * num_validation_samples / batch_size),
)

# Uncomment to restore previously trained weights instead of training.
#model.load_weights('weights.h5')

model.save_weights('weights.h5')

Predict images.

In [0]:
def predict(image_name):
    """Run the model tile by tile over an image file and reassemble the result.

    The image is read, split into overlapping tiles and processed by
    ``predict_crops``; the processed tiles are stitched back together and the
    padding added for tiling is cut away so the result has the same height
    and width as the source image.

    Args:
        image_name: Path or URL accepted by ``skimage.io.imread``.

    Returns:
        Tuple ``(image, reconstructed_image, downscaled)`` where ``image`` is
        the array as read, ``reconstructed_image`` is the stitched model
        output cropped to the size of ``image``, and ``downscaled`` is
        ``image`` resized by ``1 / scale_fact`` with ``cv2.resize``.
    """
    image = io.imread(image_name)
    SR, crops = predict_crops(image)

    # Stitch the tiles row by row, following the layout of `crops`.
    # This is done here without the extra border trim that `reconstruct`
    # applies: that trim can make the mosaic smaller than the source image
    # (whenever height or width modulo the tile stride is large), which
    # turned the padding below negative and produced a garbage slice.
    # For sizes that worked before, the centred crop below gives the same
    # pixels as before.
    tiles = iter(SR)
    reconstructed_image = np.concatenate(
        [np.concatenate([next(tiles) for _ in crop_row], axis=1) for crop_row in crops],
        axis=0,
    )

    # Remove the (possibly odd) padding symmetrically: the larger half was
    # added in front, the smaller half behind.
    height_pad = (reconstructed_image.shape[0] - image.shape[0]) / 2
    width_pad = (reconstructed_image.shape[1] - image.shape[1]) / 2
    reconstructed_image = reconstructed_image[
        ceil(height_pad): reconstructed_image.shape[0] - floor(height_pad),
        ceil(width_pad): reconstructed_image.shape[1] - floor(width_pad),
    ]

    downscaled = cv2.resize(image, (image.shape[1] // scale_fact, image.shape[0] // scale_fact))
    return image, reconstructed_image, downscaled
  
def predict_crops(image):
    """Tile a padded copy of the image and run the model on every tile.

    The image is zero-padded (split as evenly as possible between both
    sides) before being cut into ``HR_size`` tiles by ``seq_crop``. Each
    tile is shrunk to ``LR_size`` with ``cv2.resize``, passed through
    ``model`` and stripped of ``overlap // 2`` pixels on every side.

    Args:
        image: Array indexed as (height, width, channels).

    Returns:
        Tuple ``(SR, crops)``: ``SR`` is a flat, row-major list of processed
        tiles and ``crops`` is the nested list of tiles from ``seq_crop``.
    """
    stride = HR_size - overlap
    margin = overlap // 2

    img_h, img_w = image.shape[:2]
    extra_h = HR_size - (img_h % stride)
    extra_w = HR_size - (img_w % stride)
    padding = (
        (ceil(extra_h / 2), floor(extra_h / 2)),
        (ceil(extra_w / 2), floor(extra_w / 2)),
        (0, 0),  # channels are never padded
    )
    crops = seq_crop(np.pad(image, padding, 'constant'))

    SR = []
    for crop_row in crops:
        for crop in tqdm(crop_row):
            low_res = cv2.resize(crop, (LR_size, LR_size))
            upscaled = model.predict(np.expand_dims(low_res, 0))[0]
            # Keep only the tile centre; the borders overlap with neighbours.
            SR.append(upscaled[margin:HR_size - margin, margin:HR_size - margin])
    return SR, crops

def seq_crop(img):
    """Cut an image into a grid of overlapping ``HR_size`` tiles.

    Tiles start every ``HR_size - overlap`` pixels along both axes and are
    produced by ``crop_precise``. Tiles that reach past the image border are
    shorter, as array slicing stops at the edge.

    An image whose height or width does not exceed ``HR_size`` yields exactly
    one tile along that axis. (The previous loop tested the offset of the
    preceding tile and therefore produced no tiles at all in that case.)

    Args:
        img: Array indexed as (height, width, ...).

    Returns:
        List of rows, each row a list of tiles, ordered top-to-bottom and
        left-to-right.
    """
    stride = HR_size - overlap

    def _tile_count(length):
        # One tile at offset 0, plus one for every stride needed to reach
        # the far edge of the image.
        extra = length - HR_size
        if extra <= 0:
            return 1
        return -(-extra // stride) + 1  # integer ceil(extra / stride) + 1

    n_rows = _tile_count(img.shape[0])
    n_cols = _tile_count(img.shape[1])
    return [
        [
            crop_precise(img, col * stride, row * stride, HR_size, HR_size)
            for col in range(n_cols)
        ]
        for row in range(n_rows)
    ]

def crop_precise(img, coord_x, coord_y, width_length, height_length):
    """Extract a window at a fixed position and scale it with ``float_im``.

    Args:
        img: Array indexed as (row, column, ...).
        coord_x: Column of the window's left edge.
        coord_y: Row of the window's top edge.
        width_length: Window width in pixels.
        height_length: Window height in pixels.

    Returns:
        The selected window passed through ``float_im``.
    """
    rows = slice(coord_y, coord_y + height_length)
    cols = slice(coord_x, coord_x + width_length)
    return float_im(img[rows, cols])

def reconstruct(predictions, crops):
    """Assemble a flat list of tiles into one image and trim its border.

    Args:
        predictions: Flat, row-major sequence of tile arrays indexed as
            (height, width, channels).
        crops: Nested list whose row/column layout tells how the flat
            ``predictions`` are arranged; only its structure is used.

    Returns:
        The stitched image with ``overlap * (scale_fact - 1)`` pixels removed
        from every side. The dtype matches the first tile.
    """
    # Re-nest the flat tile list so it mirrors the layout of `crops`.
    tile_iter = iter(predictions)
    grid = [[next(tile_iter) for _ in row] for row in crops]

    # Cumulative tile heights/widths give the far edge of every row/column;
    # sizes are taken from the first column and the first row respectively.
    row_ends = np.cumsum([row[0].shape[0] for row in grid])
    col_ends = np.cumsum([tile.shape[1] for tile in grid[0]])
    row_starts = [0, *row_ends[:-1]]
    col_starts = [0, *col_ends[:-1]]

    first = grid[0][0]
    canvas = np.empty((row_ends[-1], col_ends[-1], first.shape[2]), first.dtype)
    for top, bottom, row in zip(row_starts, row_ends, grid):
        for left, right, tile in zip(col_starts, col_ends, row):
            canvas[top:bottom, left:right] = tile

    margin = overlap * (scale_fact - 1)
    return canvas[margin:canvas.shape[0] - margin, margin:canvas.shape[1] - margin]
  
def float_im(img):
    """Scale pixel values by 1/255.

    Args:
        img: Array-like of pixel values.

    Returns:
        ``img`` divided element-wise by 255 as floating point.
    """
    scaled = np.true_divide(img, 255.0)
    return scaled