In [1]:
import sys
import keras
import cv2
import numpy as np
import matplotlib
import skimage

Using TensorFlow backend.


In [2]:
# Image necessary packages
from keras.models import Sequential
from keras.layers import Conv2D, Input
from keras.optimizers import SGD, Adam
from skimage.measure import compare_ssim as ssim
from matplotlib import pyplot as plt
import cv2
import math
import os

# python magic functions
%matplotlib inline

In [3]:
# define a function for peak signal to noise ratio (PSNR)
def psnr(target, ref):
    
    # assume a RGB/BGR Image 
    target_data = target.astype(float)
    ref_data = ref.astype(float)
    
    diff = ref_data - target_data
    diff = diff.flatten('C')
    
    rmse = math.sqrt(np.mean(diff ** 2.))
    
    # formula
    return 20 * math.log10(255. / rmse)

# Define function for mean squared error(MSE)

def mse(target, ref):
    
    # assume a RGB/BGR Image 
    err = np.sum((target.astype(float)- ref.astype(float)) ** 2)
    err /= float(target.shape[0] * target.shape[1])
    
    return err

# define a function that combines all 3 image quality metrics
def compare_images(target, ref):
    scores = []
    scores.append(psnr(target,ref))
    scores.append(mse(target,ref))
    # Structural similarity index
    scores.append(ssim(target,ref, multichannel = True))
    
    return scores
    

In [4]:
# Prepare degraded images by introducing quality distortion via resizing

def prepare_images(path, factor):
    
    # Loop through the files in directory
    for file in os.listdir(path):
        
        # Open the file
        img = cv2.imread(path + '/' + file)
        
        # find the old and new image dimensions
        h, w, c = img.shape
        new_height = int(h / factor)
        new_width = int(w / factor)
        
        # resize images - down
        img = cv2.resize(img, (new_width, new_height), interpolation = cv2.INTER_LINEAR)
        
        # resize image - up
        img = cv2.resize(img, (w, h), interpolation = cv2.INTER_LINEAR)
        
        # save the image
        print("Saving {}".format(file))
        cv2.imwrite("images/{}".format(file),img)
        

In [5]:
prepare_images("source_images",2)

Saving baboon.bmp
Saving baby_GT.bmp
Saving barbara.bmp
Saving bird_GT.bmp
Saving butterfly_GT.bmp
Saving coastguard.bmp
Saving comic.bmp
Saving face.bmp
Saving flowers.bmp
Saving foreman.bmp
Saving head_GT.bmp
Saving lenna.bmp
Saving monarch.bmp
Saving pepper.bmp
Saving ppt3.bmp
Saving woman_GT.bmp
Saving zebra.bmp


In [6]:
# testing the generated images using image quality metrics

for file in os.listdir("images/"):
    
    # Open the target and ref image
    target = cv2.imread("images/{}".format(file))
    ref = cv2.imread("source_images/{}".format(file))
    
    # calculate the scores
    scores = compare_images(target, ref)
    
    # Print the file names
    print('{}\nPSNR: {}\n MSE: {}\nSSIM: {}\n'.format(file,scores[0],scores[1],scores[2]))


baboon.bmp
PSNR: 22.157084083442548
 MSE: 1187.1161333333334
SSIM: 0.629277587900277

baby_GT.bmp
PSNR: 34.37180640966199
 MSE: 71.28874588012695
SSIM: 0.9356987872724932

barbara.bmp
PSNR: 25.906629837568126
 MSE: 500.65508535879627
SSIM: 0.8098632646406401

bird_GT.bmp
PSNR: 32.896644728720005
 MSE: 100.12375819830247
SSIM: 0.9533644866026473

butterfly_GT.bmp
PSNR: 24.782076560337416
 MSE: 648.6254119873047
SSIM: 0.8791344763843051

coastguard.bmp
PSNR: 27.161600663887082
 MSE: 375.00887784090907
SSIM: 0.756950063354931

comic.bmp
PSNR: 23.799861502225532
 MSE: 813.2338836565096
SSIM: 0.8347335416398209

face.bmp
PSNR: 30.99220650287191
 MSE: 155.23189718546524
SSIM: 0.8008439492289884

flowers.bmp
PSNR: 27.454504805386147
 MSE: 350.55093922651935
SSIM: 0.8697286286974628

foreman.bmp
PSNR: 30.14456532664372
 MSE: 188.6883483270202
SSIM: 0.933268417388899

head_GT.bmp
PSNR: 31.020502848237534
 MSE: 154.2237755102041
SSIM: 0.8011121330733371

lenna.bmp
PSNR: 31.47349297867539
 MSE: 1

In [7]:
# define the srcnn model

def model():
    
    # define model type
    SRCNN = Sequential()
    
    # add model layers
    SRCNN.add(Conv2D(filters=128, kernel_size = (9,9), kernel_initializer='glorot_uniform',
                    activation='relu', padding = 'valid', use_bias=True, input_shape = (None,None,1)))
    
    SRCNN.add(Conv2D(filters=64, kernel_size = (3,3), kernel_initializer='glorot_uniform',
                    activation='relu', padding = 'same', use_bias=True))
    
    SRCNN.add(Conv2D(filters=1, kernel_size = (5,5), kernel_initializer='glorot_uniform',
                    activation='linear', padding = 'valid', use_bias=True, input_shape = (None,None,1)))
    
    # Define optimizer
    adam = Adam(lr = 0.0003)
    
    # Compile model
    SRCNN.compile(optimizer=adam, loss='mean_squared_error', metrics=['mean_squared_error'])
    
    return SRCNN



In [8]:
# Define necessary images processing functions

def modcrop(img, scale):
    
    tmpsz = img.shape
    sz = tmpsz[0:2]
    sz = sz - np.mod(sz, scale)
    img = img[0:sz[0],1:sz[1]]
    
    return img

def shave(image, border):
    img = image[border: -border, border: -border]
    return img

    
    

In [18]:
# defining the main prediction function

def predict(image_path):
    
    # load the SRCNN model with weights
    srcnn = model()
    srcnn.load_weights('3051crop_weight_200.h5')
    
    # load in our degraded and reference image
    path, file = os.path.split(image_path)
    degraded = cv2.imread(image_path)
    ref = cv2.imread('source_images/{}'.format(file))
    
    # Preprocess the images with modcrop
    ref = modcrop(ref, 3)
    degraded = modcrop(degraded, 3)
    
    # Convert the image to YCrCb (SRCNN is trained on Y channel)
    temp = cv2.cvtColor(degraded, cv2.COLOR_BGR2YCrCb)  
    
    # create image size and normalize 
    Y = np.zeros((1, temp.shape[0], temp.shape[1], 1), dtype=float)
    
    Y[0, :, :, 0] = temp[:, :, 0].astype(float) / 255;
    
    # Perform super resolution with SRCNN network
    pre = srcnn.predict(Y, batch_size = 1)
    
    # Post process output
    pre *= 255
    pre[pre[:] > 255] = 255
    pre[pre[:] < 0] = 0
    pre = pre.astype(np.uint8)
    
    # Copy Y channel back to image and convert o BGR
    temp = shave(temp, 6)
    temp[:, :, 0] = pre[0, : , :, 0]
    output = cv2.cvtColor(temp, cv2.COLOR_YCrCb2BGR)
    
    # Remove border from ref and degraded image
    ref = shave(ref.astype(np.uint8),6)
    degraded = shave(ref.astype(np.uint8),6)
    
    # Image quality calcuations
    scores = []
    scores.append(compare_images(degraded, ref))
    scores.append(compare_images(output, ref))
    
    # return images and scores
    
    return ref, degraded, output, scores    

In [19]:
ref, degraded, output , scores = predict('source_images/flowers.bmp')

# Print all scores
print('Degraded images \nPSNR: {}\n MSE: {}\nSSIM: {}\n'.format(scores[0][0],scores[0][1],scores[0][2]))
print('Reconstructed images \nPSNR: {}\n MSE: {}\nSSIM: {}\n'.format(scores[1][0],scores[1][1],scores[1][2]))

# display images as subplots
fig, axs = plt.subplots(1, 3, figsize = (20,8))
axs[0].imshow(cv2.cvColor(ref, cv2.COLOR_BGR2RGB))
axs[0].set_title('Original')
axs[1].imshow(cv2.cvtColor(degraded, cv2.COLOR_BGR2RGB))
axs[1].set_title('Degraded')
axs[2].imshow(cv2.cvColor(output, cv2.COLOR_BGR2RGB))

#Remove the x and y tick marks
for ax in axs:
    ax.set_xticks([])
    ax.set_yticks([])
    

ValueError: operands could not be broadcast together with shapes (348,485,3) (336,473,3) 