In [1]:
from keras.models import Sequential
from keras.layers import Conv2D, Input, BatchNormalization
from keras.callbacks import ModelCheckpoint, Callback
from keras.optimizers import SGD, Adam
from keras.utils.training_utils import multi_gpu_model
from keras.preprocessing.image import ImageDataGenerator
from keras import initializers
import prepare_data as pd
import numpy
import math
import scipy.misc as spm

import matplotlib.image as mpimg
import matplotlib.pyplot as plt

from PIL import Image

from lr_multiplier import LearningRateMultiplier

import tensorflow as tf
from keras.backend.tensorflow_backend import set_session
config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.2
set_session(tf.Session(config=config))

Using TensorFlow backend.


In [2]:
# see example https://github.com/keras-team/keras/issues/8649
class MyCbk(Callback):

    def __init__(self, model):
         self.model_to_save = model

    def on_epoch_end(self, epoch, logs=None):
        self.model_to_save.save('checkpoints/pick-64-9-3-5_128-64-SRCNN_model_at_epoch_%d.h5' % epoch)


In [3]:
def psnr(target, ref):

    target_data = numpy.array(target, dtype=float)
    ref_data = numpy.array(ref, dtype=float)

    diff = ref_data - target_data
    diff = diff.flatten('C')

    rmse = math.sqrt(numpy.mean(diff ** 2.))

    return 20 * math.log10(255. / rmse)

In [4]:
def model(optimizer='SGD', all_layers=True, processing='GPU', lr=0.0001):

    multipliers = {'conv2d_1': 1, 'conv2d_2': 1, 'conv2d_3': 0.1}
    
    optimizer = optimizer.upper()
    processing = processing.upper()
    
    print('Using {}.'.format(optimizer))

    if optimizer == 'SGD' and all_layers:
        opt = LearningRateMultiplier(SGD, lr_multipliers=multipliers, lr=lr)
    elif optimizer == 'SGD':
        opt = SGD(lr=0.0001)
    elif all_layers:
        opt = LearningRateMultiplier(Adam, lr_multipliers=multipliers, lr=lr)
    else:
        opt = Adam(lr=lr)
        
    SRCNN = Sequential()
    SRCNN.add(Conv2D(filters=128, kernel_size=(9, 9), 
                     kernel_initializer='glorot_uniform', bias_initializer='random_uniform',
                     activation='relu', padding='valid', use_bias=True, input_shape=(64, 64, 1), name='conv2d_1'))

    SRCNN.add(Conv2D(filters=64, kernel_size=(3, 3), 
                     kernel_initializer='glorot_uniform', bias_initializer='random_uniform',
                     activation='relu', padding='same', use_bias=True, name='conv2d_2'))
    
    SRCNN.add(Conv2D(filters=1, kernel_size=(5, 5),
                     kernel_initializer='glorot_uniform', bias_initializer='random_uniform',
                     activation='linear', padding='valid', use_bias=True,  name='conv2d_3'))
    SRCNN_MULTI = multi_gpu_model(SRCNN, gpus=8)
    if processing == 'GPU':    
        SRCNN_MULTI.compile(optimizer=opt, loss='mean_squared_error', metrics=['mean_squared_error'])
    else:
        SRCNN.compile(optimizer=opt, loss='mean_squared_error', metrics=['mean_squared_error'])
    return SRCNN, SRCNN_MULTI

In [5]:
def predict_model():
    
    sgd = SGD(lr=0.0001)
    
    SRCNN = Sequential()
    SRCNN.add(Conv2D(filters=128, kernel_size=(9, 9), 
                     kernel_initializer='glorot_uniform', bias_initializer='random_uniform',
                     activation='relu', padding='valid', use_bias=True, input_shape=(None, None, 1), name='conv2d_1'))
        
    SRCNN.add(Conv2D(filters=64, kernel_size=(3, 3), 
                     kernel_initializer='glorot_uniform', bias_initializer='random_uniform',
                     activation='relu', padding='same', use_bias=True, name='conv2d_2'))
        
    # SRCNN.add(BatchNormalization())
    SRCNN.add(Conv2D(filters=1, kernel_size=(5, 5),
                     kernel_initializer='glorot_uniform', bias_initializer='random_uniform',
                     activation='linear', padding='valid', use_bias=True,  name='conv2d_3'))
    adam = Adam(lr=0.0003)
    sgd_last = SGD(lr=0.00001)

    SRCNN.compile(optimizer=adam, loss='mean_squared_error', metrics=['mean_squared_error'])
    return SRCNN

In [6]:
def train(processing='GPU', optimizer='SGD', use_data_augmentation=False, epochs=200, lr=0.0001):

    optimizer = optimizer.upper()
    processing = processing.upper()
    
    srcnn_model, srcnn_multi_model = model(optimizer=optimizer, processing=processing, lr=lr)

    callbacks_list = []
    BS = 1024
    
    if processing == 'GPU' or use_data_augmentation:    
        print(srcnn_multi_model.summary())
        # see example https://github.com/keras-team/keras/issues/8649
        cbk = MyCbk(srcnn_model)
        
    else:
        print(srcnn_model.summary())
        checkpoint = ModelCheckpoint("64SRCNN_check.h5", monitor='val_loss', 
                                     verbose=2, 
                                     save_best_only=True,
                                     save_weights_only=False)
        callbacks_list.append(checkpoint)
        
    data, label = pd.read_training_data("./64_52_pick_crop_train.h5")
    val_data, val_label = pd.read_training_data("./64test.h5")
    
    # construct the training image generator for data augmentation
    datagen = ImageDataGenerator(rotation_range=45,
                                 zoom_range=0.15,
                                 width_shift_range=0.2,
                                 height_shift_range=0.2,
                                 horizontal_flip=True, 
                                 vertical_flip=True,
                                 validation_split=0.2)

    if use_data_augmentation and processing == 'GPU':
        history = srcnn_multi_model.fit_generator(datagen.flow(data, label, batch_size=BS),
                                                  steps_per_epoch=len(data) // BS,
                                                  epochs=epochs, 
                                                  verbose=2, 
                                                  shuffle=True,
                                                  callbacks=[cbk],
                                                  use_multiprocessing=True)
    elif processing == 'GPU':
        srcnn_multi_model.fit(data, label, batch_size=BS,
                              validation_split=0.2,
                              shuffle=True, 
                              epochs=epochs, 
                              verbose=2,
                              callbacks=[cbk])
    elif use_data_augmentation:
        history = srcnn_model.fit_generator(datagen.flow(data, label, batch_size=BS),
                                          steps_per_epoch=len(data) // BS,
                                          epochs=epochs, 
                                          verbose=2, 
                                          shuffle=True,
                                          callbacks=[cbk])
    else:
        srcnn_model.fit(data, label, batch_size=BS,
                        validation_split=0.2,
                        shuffle=True, 
                        epochs=epochs, 
                        verbose=2,
                        callbacks=callbacks_list)

In [7]:
def predict():

    srcnn_model = predict_model()
    srcnn_model.load_weights("checkpoints/64-9-3-5_128-64-SRCNN_model_at_epoch_999.h5")
    IMG_NAME = numpy.fromfile('ground_data/pick/recon/slice_256.b', dtype = 'uint8')
    INPUT_NAME = numpy.fromfile("sub_data/pick/recon/slice_256.b", dtype = 'uint8')
    OUTPUT_NAME = "pre2.png"

    dimension_img = int(IMG_NAME.shape[0]**(1/2))
    dimension_input = int(INPUT_NAME.shape[0]**(1/2))
    
    IMG_NAME = IMG_NAME.reshape((dimension_img, dimension_img))
    INPUT_NAME = INPUT_NAME.reshape((dimension_input, dimension_input))
    
    import cv2
    img = INPUT_NAME.copy()
    shape = IMG_NAME.shape
    Y_img = img.copy()
    Y_img = spm.imresize(Y_img, size=shape, interp='bicubic')
    bicubic = Y_img.copy()

    Y_img = Y_img.reshape((dimension_img, dimension_img, 1))
    
    for i in range(Y_img.shape[-1]):
        Y = numpy.zeros((1, shape[0], shape[1], 1), dtype=float)
        Y[0,:,:,0] = Y_img[:,:,i].astype(float) / 255.
        pre = srcnn_model.predict(Y, batch_size=1) * 255.
        
    pre[pre[:] > 255] = 255
    pre[pre[:] < 0] = 0
    pre = pre.astype(float)
    print(Y_img.shape, pre.shape)
    Y_img[6: -6, 6: -6, 0] = pre[0, :, :, 0].copy()
    cv2.imwrite(OUTPUT_NAME, Y_img)

    # psnr calculation:
    im1 = IMG_NAME
    im2 = bicubic
    im3 = cv2.imread(OUTPUT_NAME)[:,:,0]
    
    print(im1.shape, im2.shape, im3.shape)

    print ("bicubic:")
    print (cv2.PSNR(im1, im2))
    print ("SRCNN:")
    print (cv2.PSNR(im1, im3))
    
    import matplotlib.pyplot as plt
    print(im1.shape)
    plt.figure(figsize=[15,30])
    plt.subplot(311)

    plt.title('Original')
    plt.imshow(im1)
    plt.subplot(312)

    plt.title('Bicubic')    
    plt.imshow(im2)
    plt.subplot(313)

    plt.title('SRCNN')
    plt.imshow(im3[10:-10,10:-10])
    plt.show()

In [8]:
train(processing='GPU', optimizer='adam', use_data_augmentation=True, epochs=1000, lr=0.0001)

Using ADAM.
Instructions for updating:
Colocations handled automatically by placer.
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
conv2d_1_input (InputLayer)     (None, 64, 64, 1)    0                                            
__________________________________________________________________________________________________
lambda_1 (Lambda)               (None, 64, 64, 1)    0           conv2d_1_input[0][0]             
__________________________________________________________________________________________________
lambda_2 (Lambda)               (None, 64, 64, 1)    0           conv2d_1_input[0][0]             
__________________________________________________________________________________________________
lambda_3 (Lambda)               (None, 64, 64, 1)    0           conv2d_1_input[0][0]             
_________________________

Epoch 54/1000
 - 33s - loss: 19.5386 - mean_squared_error: 19.5386
Epoch 55/1000
 - 33s - loss: 19.5333 - mean_squared_error: 19.5333
Epoch 56/1000
 - 34s - loss: 19.5352 - mean_squared_error: 19.5352
Epoch 57/1000
 - 34s - loss: 19.5249 - mean_squared_error: 19.5249
Epoch 58/1000
 - 34s - loss: 19.5242 - mean_squared_error: 19.5242
Epoch 59/1000
 - 33s - loss: 19.5225 - mean_squared_error: 19.5225
Epoch 60/1000
 - 33s - loss: 19.5164 - mean_squared_error: 19.5164
Epoch 61/1000
 - 33s - loss: 19.5105 - mean_squared_error: 19.5105
Epoch 62/1000
 - 33s - loss: 19.5080 - mean_squared_error: 19.5080
Epoch 63/1000
 - 33s - loss: 19.5062 - mean_squared_error: 19.5062
Epoch 64/1000
 - 33s - loss: 19.5016 - mean_squared_error: 19.5016
Epoch 65/1000
 - 33s - loss: 19.4970 - mean_squared_error: 19.4970
Epoch 66/1000
 - 33s - loss: 19.4927 - mean_squared_error: 19.4927
Epoch 67/1000
 - 33s - loss: 19.4913 - mean_squared_error: 19.4913
Epoch 68/1000
 - 33s - loss: 19.4826 - mean_squared_error: 19.

 - 33s - loss: 19.3384 - mean_squared_error: 19.3384
Epoch 176/1000
 - 34s - loss: 19.3627 - mean_squared_error: 19.3627
Epoch 177/1000
 - 34s - loss: 19.3441 - mean_squared_error: 19.3441
Epoch 178/1000
 - 34s - loss: 19.3621 - mean_squared_error: 19.3621
Epoch 179/1000
 - 33s - loss: 19.3494 - mean_squared_error: 19.3494
Epoch 180/1000
 - 33s - loss: 19.3488 - mean_squared_error: 19.3488
Epoch 181/1000
 - 33s - loss: 19.3633 - mean_squared_error: 19.3633
Epoch 182/1000
 - 33s - loss: 19.3371 - mean_squared_error: 19.3371
Epoch 183/1000
 - 33s - loss: 19.3522 - mean_squared_error: 19.3522
Epoch 184/1000
 - 33s - loss: 19.3410 - mean_squared_error: 19.3410
Epoch 185/1000
 - 33s - loss: 19.3360 - mean_squared_error: 19.3360
Epoch 186/1000
 - 33s - loss: 19.3632 - mean_squared_error: 19.3632
Epoch 187/1000
 - 33s - loss: 19.3333 - mean_squared_error: 19.3333
Epoch 188/1000
 - 33s - loss: 19.3429 - mean_squared_error: 19.3429
Epoch 189/1000
 - 34s - loss: 19.3536 - mean_squared_error: 19.

Epoch 296/1000
 - 33s - loss: 19.3525 - mean_squared_error: 19.3525
Epoch 297/1000
 - 33s - loss: 19.3272 - mean_squared_error: 19.3272
Epoch 298/1000
 - 33s - loss: 19.3239 - mean_squared_error: 19.3239
Epoch 299/1000
 - 33s - loss: 19.3209 - mean_squared_error: 19.3209
Epoch 300/1000
 - 33s - loss: 19.3358 - mean_squared_error: 19.3358
Epoch 301/1000
 - 33s - loss: 19.3174 - mean_squared_error: 19.3174
Epoch 302/1000
 - 35s - loss: 19.3410 - mean_squared_error: 19.3410
Epoch 303/1000
 - 35s - loss: 19.3295 - mean_squared_error: 19.3295
Epoch 304/1000
 - 35s - loss: 19.3299 - mean_squared_error: 19.3299
Epoch 305/1000
 - 35s - loss: 19.3309 - mean_squared_error: 19.3309
Epoch 306/1000
 - 35s - loss: 19.3223 - mean_squared_error: 19.3223
Epoch 307/1000
 - 35s - loss: 19.3237 - mean_squared_error: 19.3237
Epoch 308/1000
 - 35s - loss: 19.3343 - mean_squared_error: 19.3343
Epoch 309/1000
 - 35s - loss: 19.3203 - mean_squared_error: 19.3203
Epoch 310/1000
 - 35s - loss: 19.3302 - mean_squ

Epoch 417/1000
 - 33s - loss: 19.3093 - mean_squared_error: 19.3093
Epoch 418/1000
 - 33s - loss: 19.2977 - mean_squared_error: 19.2977
Epoch 419/1000
 - 33s - loss: 19.3101 - mean_squared_error: 19.3101
Epoch 420/1000
 - 33s - loss: 19.3005 - mean_squared_error: 19.3005
Epoch 421/1000
 - 33s - loss: 19.2976 - mean_squared_error: 19.2976
Epoch 422/1000
 - 33s - loss: 19.3004 - mean_squared_error: 19.3004
Epoch 423/1000
 - 33s - loss: 19.3034 - mean_squared_error: 19.3034
Epoch 424/1000
 - 33s - loss: 19.2966 - mean_squared_error: 19.2966
Epoch 425/1000
 - 33s - loss: 19.3002 - mean_squared_error: 19.3002
Epoch 426/1000
 - 33s - loss: 19.2981 - mean_squared_error: 19.2981
Epoch 427/1000
 - 33s - loss: 19.3022 - mean_squared_error: 19.3022
Epoch 428/1000
 - 33s - loss: 19.3017 - mean_squared_error: 19.3017
Epoch 429/1000
 - 33s - loss: 19.3045 - mean_squared_error: 19.3045
Epoch 430/1000
 - 33s - loss: 19.2983 - mean_squared_error: 19.2983
Epoch 431/1000
 - 33s - loss: 19.3044 - mean_squ

Epoch 538/1000
 - 34s - loss: 19.3047 - mean_squared_error: 19.3047
Epoch 539/1000
 - 33s - loss: 19.3072 - mean_squared_error: 19.3072
Epoch 540/1000
 - 33s - loss: 19.3092 - mean_squared_error: 19.3092
Epoch 541/1000
 - 33s - loss: 19.3053 - mean_squared_error: 19.3053
Epoch 542/1000
 - 35s - loss: 19.3093 - mean_squared_error: 19.3093
Epoch 543/1000
 - 35s - loss: 19.3146 - mean_squared_error: 19.3146
Epoch 544/1000
 - 35s - loss: 19.3014 - mean_squared_error: 19.3014
Epoch 545/1000
 - 35s - loss: 19.3043 - mean_squared_error: 19.3043
Epoch 546/1000
 - 35s - loss: 19.3151 - mean_squared_error: 19.3151
Epoch 547/1000
 - 35s - loss: 19.3119 - mean_squared_error: 19.3119
Epoch 548/1000
 - 35s - loss: 19.3078 - mean_squared_error: 19.3078
Epoch 549/1000
 - 35s - loss: 19.2998 - mean_squared_error: 19.2998
Epoch 550/1000
 - 35s - loss: 19.3097 - mean_squared_error: 19.3097
Epoch 551/1000
 - 35s - loss: 19.3083 - mean_squared_error: 19.3083
Epoch 552/1000
 - 35s - loss: 19.3093 - mean_squ

Epoch 659/1000
 - 33s - loss: 19.2964 - mean_squared_error: 19.2964
Epoch 660/1000
 - 33s - loss: 19.3017 - mean_squared_error: 19.3017
Epoch 661/1000
 - 33s - loss: 19.3061 - mean_squared_error: 19.3061
Epoch 662/1000
 - 34s - loss: 19.4716 - mean_squared_error: 19.4716
Epoch 663/1000
 - 33s - loss: 19.2955 - mean_squared_error: 19.2955
Epoch 664/1000
 - 34s - loss: 19.2946 - mean_squared_error: 19.2946
Epoch 665/1000
 - 33s - loss: 19.2928 - mean_squared_error: 19.2928
Epoch 666/1000
 - 33s - loss: 19.2931 - mean_squared_error: 19.2931
Epoch 667/1000
 - 33s - loss: 19.2899 - mean_squared_error: 19.2899
Epoch 668/1000
 - 33s - loss: 19.2946 - mean_squared_error: 19.2946
Epoch 669/1000
 - 33s - loss: 19.2926 - mean_squared_error: 19.2926
Epoch 670/1000
 - 33s - loss: 19.2953 - mean_squared_error: 19.2953
Epoch 671/1000
 - 33s - loss: 19.3110 - mean_squared_error: 19.3110
Epoch 672/1000
 - 33s - loss: 19.2983 - mean_squared_error: 19.2983
Epoch 673/1000
 - 34s - loss: 19.3013 - mean_squ

Epoch 780/1000
 - 33s - loss: 19.2913 - mean_squared_error: 19.2913
Epoch 781/1000
 - 33s - loss: 19.2988 - mean_squared_error: 19.2988
Epoch 782/1000
 - 35s - loss: 19.3047 - mean_squared_error: 19.3047
Epoch 783/1000
 - 35s - loss: 19.2995 - mean_squared_error: 19.2995
Epoch 784/1000
 - 35s - loss: 19.2960 - mean_squared_error: 19.2960
Epoch 785/1000
 - 35s - loss: 19.2896 - mean_squared_error: 19.2896
Epoch 786/1000
 - 35s - loss: 19.3039 - mean_squared_error: 19.3039
Epoch 787/1000
 - 35s - loss: 19.3090 - mean_squared_error: 19.3090
Epoch 788/1000
 - 35s - loss: 19.2984 - mean_squared_error: 19.2984
Epoch 789/1000
 - 35s - loss: 19.2972 - mean_squared_error: 19.2972
Epoch 790/1000
 - 35s - loss: 19.3019 - mean_squared_error: 19.3019
Epoch 791/1000
 - 35s - loss: 19.2921 - mean_squared_error: 19.2921
Epoch 792/1000
 - 35s - loss: 19.2991 - mean_squared_error: 19.2991
Epoch 793/1000
 - 33s - loss: 19.2960 - mean_squared_error: 19.2960
Epoch 794/1000
 - 33s - loss: 19.2977 - mean_squ

Epoch 901/1000
 - 33s - loss: 19.3012 - mean_squared_error: 19.3012
Epoch 902/1000
 - 33s - loss: 19.2999 - mean_squared_error: 19.2999
Epoch 903/1000
 - 33s - loss: 19.2938 - mean_squared_error: 19.2938
Epoch 904/1000
 - 33s - loss: 19.2936 - mean_squared_error: 19.2936
Epoch 905/1000
 - 34s - loss: 19.2979 - mean_squared_error: 19.2979
Epoch 906/1000
 - 34s - loss: 19.2939 - mean_squared_error: 19.2939
Epoch 907/1000
 - 33s - loss: 19.2964 - mean_squared_error: 19.2964
Epoch 908/1000
 - 33s - loss: 19.2960 - mean_squared_error: 19.2960
Epoch 909/1000
 - 34s - loss: 19.3048 - mean_squared_error: 19.3048
Epoch 910/1000
 - 33s - loss: 19.2940 - mean_squared_error: 19.2940
Epoch 911/1000
 - 34s - loss: 19.2937 - mean_squared_error: 19.2937
Epoch 912/1000
 - 33s - loss: 19.2918 - mean_squared_error: 19.2918
Epoch 913/1000
 - 33s - loss: 19.3000 - mean_squared_error: 19.3000
Epoch 914/1000
 - 33s - loss: 19.2946 - mean_squared_error: 19.2946
Epoch 915/1000
 - 33s - loss: 19.2998 - mean_squ

In [None]:
predict()