In [1]:
from keras import layers
from keras import models
from keras.models import load_model, Model
from keras.layers import Input
import keras.optimizers as optimizers
from keras import backend as K
import numpy as np
import cv2
import os
from scipy.misc import imread, imresize, imsave

Using TensorFlow backend.


In [2]:
def PSNRLoss(y_true, y_pred):
    """
    PSNR is Peek Signal to Noise Ratio, which is similar to mean squared error.
    It can be calculated as
    PSNR = 20 * log10(MAXp) - 10 * log10(MSE)
    When providing an unscaled input, MAXp = 255. Therefore 20 * log10(255)== 48.1308036087.
    However, since we are scaling our input, MAXp = 1. Therefore 20 * log10(1) = 0.
    Thus we remove that component completely and only compute the remaining MSE component.
    """
    return -10. * K.log(K.mean(K.square(y_pred - y_true))) / K.log(10.)

def psnr(y_true, y_pred):
    assert y_true.shape == y_pred.shape, "Cannot calculate PSNR. Input shapes not same." \
                                         " y_true shape = %s, y_pred shape = %s" % (str(y_true.shape),
                                                                                   str(y_pred.shape))

    return 20. * np.log10(255.) -10. * np.log10(np.mean(np.square(y_pred - y_true)))


In [28]:
K.set_image_data_format('channels_first')

def create_model(img_dim_2, img_dim_1, weight_path = 'Weights/SRWeights2X.h5'):
    height = img_dim_2
    width = img_dim_1
    channels=3

    shape = (channels, width, height)
    init = Input(shape=shape)

    f1 = 9
    f2 = 1
    f3 = 5

    n1 = 64
    n2 = 32

    x = layers.Convolution2D(n1, (f1, f1), activation='relu', padding='same', name='level1')(init)
    x = layers.Convolution2D(n2, (f2, f2), activation='relu', padding='same', name='level2')(x)

    out = layers.Convolution2D(channels, (f3, f3), padding='same', name='output')(x)

    model = Model(init, out)

    adam = optimizers.Adam(lr=1e-3)
    model.compile(optimizer=adam, loss='mse', metrics=[PSNRLoss])
    model.load_weights(weight_path, reshape=True)
    return model


In [55]:
def upscale(img_path, save_intermediate=True, suffix="scaled", verbose=True):

        #destination path
        path = os.path.splitext(img_path)
        filename = path[0] + "_" + suffix + "(%dx)" % (2) + path[1]

        hd_img = imread(img_path, mode='RGB') #high resolution image
        hd_img_res = (hd_img.shape[0], hd_img.shape[1])
        scale_factor = 2
        lr_img_res = (hd_img.shape[0] // scale_factor, hd_img.shape[1] // scale_factor)
        
        lr_img = imresize(hd_img, lr_img_res) #low resolution image
        intermediate_img = imresize(lr_img, hd_img_res) #after bilinear interpolation

        if save_intermediate:
            if verbose: print("Saving intermediate image.")
            fn = path[0] + "_intermediate_" + path[1]
            imsave(fn, intermediate_img)
        
        print("\npsnr intermediate: ", psnr(hd_img, intermediate_img))
        intermediate_img = np.expand_dims(intermediate_img, axis=0) #for batch feed
            
        # transpose and process images
        if K.image_dim_ordering() == "th":
            img_conv = intermediate_img.transpose((0, 3, 1, 2)).astype(np.float32) / 255.
        else:
            img_conv = intermediate_img.astype(np.float32) / 255.

        model = create_model(hd_img.shape[1], hd_img.shape[0])
        if verbose: print("Model loaded.")

        # create prediction for image patches
        result = model.predict(img_conv, batch_size=128, verbose=verbose)       
        
        if K.image_dim_ordering() == "th":
            result = result.transpose((0, 2, 3, 1)).astype(np.float32) * 255.
        else:
            result = result.astype(np.float32) * 255.

        result = result[0, :, :, :] # access the 3 dimensional image vector

        result = np.clip(result, 0, 255).astype('uint8')
        print("\npsnr result after clip: ", psnr(hd_img, result))

        if verbose: print("Saving image.")
        imsave(filename, result)

In [56]:
img_path = 'Images/monarch.bmp'

upscale(img_path)

`imread` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imread`` instead.
  
`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``skimage.transform.resize`` instead.
  del sys.path[0]
`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``skimage.transform.resize`` instead.
  
`imsave` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imwrite`` instead.


Saving intermediate image.

psnr intermediate:  35.19102067757655
Model loaded.

psnr result after clip:  35.62402837416574
Saving image.


`imsave` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imwrite`` instead.
