In [1]:
from keras.models import Sequential
from keras.layers import Conv2D, Input, BatchNormalization
from keras.callbacks import ModelCheckpoint, Callback
from keras.optimizers import SGD, Adam
from keras.utils.training_utils import multi_gpu_model
from keras.preprocessing.image import ImageDataGenerator
from keras import initializers
import prepare_data as pd
import numpy
import math
import scipy.misc as spm

import matplotlib.image as mpimg
import matplotlib.pyplot as plt

from PIL import Image

from lr_multiplier import LearningRateMultiplier

import tensorflow as tf
from keras.backend.tensorflow_backend import set_session
# Set TensorFlow's per-process GPU memory fraction to 0.2 and register the
# resulting session as the one the Keras TensorFlow backend will use.
config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.2
set_session(tf.Session(config=config))

Using TensorFlow backend.


In [2]:
# see example https://github.com/keras-team/keras/issues/8649
class MyCbk(Callback):
    """Callback that saves a caller-supplied model at the end of every epoch.

    The model to save is passed in explicitly rather than taken from the
    model the callback is attached to, so that a different model object
    (e.g. the single-device template of a multi-GPU wrapper, as in
    keras-team/keras issue 8649) can be checkpointed.

    Parameters
    ----------
    model : keras model
        The model whose ``save`` method is called after each epoch.
    """

    def __init__(self, model):
        # Let the base Callback set up its own attributes before adding ours.
        super().__init__()
        self.model_to_save = model

    def on_epoch_end(self, epoch, logs=None):
        """Save the wrapped model to ``checkpoints/`` using the epoch index in the file name."""
        import os

        path = 'checkpoints/64-9-3-5_128-64-SRCNN_model_at_epoch_%d.h5' % epoch
        # Saving fails if the target directory is missing, so create it on demand.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        self.model_to_save.save(path)


In [3]:
def psnr(target, ref):
    """Return the peak signal-to-noise ratio in dB between two arrays.

    The peak value is fixed at 255, i.e. the inputs are treated as being on
    an 8-bit intensity scale.

    Parameters
    ----------
    target, ref : array-like
        Arrays of matching (or broadcastable) shape; converted to float.

    Returns
    -------
    float
        ``20 * log10(255 / rmse)``; ``inf`` when the inputs are identical
        (zero error), instead of raising ``ZeroDivisionError``.
    """
    target_data = numpy.asarray(target, dtype=float)
    ref_data = numpy.asarray(ref, dtype=float)

    mse = numpy.mean((ref_data - target_data) ** 2.)

    # Identical inputs: the ratio is unbounded, so report infinity.
    if mse == 0:
        return float('inf')

    return 20 * math.log10(255. / math.sqrt(mse))

In [4]:
def model(optimizer='SGD', all_layers=True, processing='GPU', lr=0.0001, gpus=8):
    """Build and compile the three-layer SRCNN used for training.

    The network takes 64x64 single-channel patches and stacks a 9x9 conv
    (128 filters, 'valid'), a 3x3 conv (64 filters, 'same') and a 5x5 conv
    (1 filter, 'valid', linear activation).

    Parameters
    ----------
    optimizer : str
        'SGD' (case-insensitive) selects SGD; any other value selects Adam.
    all_layers : bool
        When true the optimizer is wrapped in ``LearningRateMultiplier`` with
        a multiplier of 0.1 for 'conv2d_3' and 1 for the other two layers
        (presumably per-layer learning-rate scaling -- see ``lr_multiplier``).
    processing : str
        'GPU' (case-insensitive) compiles a multi-GPU replica of the model;
        anything else compiles the single-device model.
    lr : float
        Learning rate passed to the optimizer.
    gpus : int
        Number of GPUs handed to ``multi_gpu_model`` in GPU mode.

    Returns
    -------
    (SRCNN, SRCNN_MULTI)
        The single-device model and the model to train on. In GPU mode only
        ``SRCNN_MULTI`` is compiled. In any other mode no multi-GPU replica
        is built (so no GPUs are required) and ``SRCNN_MULTI`` is the same
        object as the compiled ``SRCNN``.
    """
    multipliers = {'conv2d_1': 1, 'conv2d_2': 1, 'conv2d_3': 0.1}

    optimizer = optimizer.upper()
    processing = processing.upper()

    print('Using {}.'.format(optimizer))

    # Every branch honours the requested learning rate.
    optimizer_class = SGD if optimizer == 'SGD' else Adam
    if all_layers:
        opt = LearningRateMultiplier(optimizer_class, lr_multipliers=multipliers, lr=lr)
    else:
        opt = optimizer_class(lr=lr)

    SRCNN = Sequential()
    SRCNN.add(Conv2D(filters=128, kernel_size=(9, 9),
                     kernel_initializer='glorot_uniform', bias_initializer='random_uniform',
                     activation='relu', padding='valid', use_bias=True, input_shape=(64, 64, 1), name='conv2d_1'))

    SRCNN.add(Conv2D(filters=64, kernel_size=(3, 3),
                     kernel_initializer='glorot_uniform', bias_initializer='random_uniform',
                     activation='relu', padding='same', use_bias=True, name='conv2d_2'))

    SRCNN.add(Conv2D(filters=1, kernel_size=(5, 5),
                     kernel_initializer='glorot_uniform', bias_initializer='random_uniform',
                     activation='linear', padding='valid', use_bias=True, name='conv2d_3'))

    if processing == 'GPU':
        # Only replicate across GPUs when the GPU path was actually requested.
        SRCNN_MULTI = multi_gpu_model(SRCNN, gpus=gpus)
        SRCNN_MULTI.compile(optimizer=opt, loss='mean_squared_error', metrics=['mean_squared_error'])
    else:
        SRCNN.compile(optimizer=opt, loss='mean_squared_error', metrics=['mean_squared_error'])
        # Keep the two-value return contract without touching any GPU.
        SRCNN_MULTI = SRCNN
    return SRCNN, SRCNN_MULTI

In [5]:
def predict_model():
    """Build the SRCNN for inference on images of arbitrary spatial size.

    The layer stack (9x9/128 'valid' -> 3x3/64 'same' -> 5x5/1 'valid',
    layers named 'conv2d_1'..'conv2d_3') mirrors the training network, but
    the input shape is ``(None, None, 1)`` so any height/width is accepted.

    Returns
    -------
    keras Sequential model compiled with Adam (lr=0.0003) and MSE loss.
    """
    SRCNN = Sequential()
    SRCNN.add(Conv2D(filters=128, kernel_size=(9, 9),
                     kernel_initializer='glorot_uniform', bias_initializer='random_uniform',
                     activation='relu', padding='valid', use_bias=True, input_shape=(None, None, 1), name='conv2d_1'))

    SRCNN.add(Conv2D(filters=64, kernel_size=(3, 3),
                     kernel_initializer='glorot_uniform', bias_initializer='random_uniform',
                     activation='relu', padding='same', use_bias=True, name='conv2d_2'))

    SRCNN.add(Conv2D(filters=1, kernel_size=(5, 5),
                     kernel_initializer='glorot_uniform', bias_initializer='random_uniform',
                     activation='linear', padding='valid', use_bias=True, name='conv2d_3'))

    # Only one optimizer is ever used; the unused SGD instances were dropped.
    adam = Adam(lr=0.0003)

    SRCNN.compile(optimizer=adam, loss='mean_squared_error', metrics=['mean_squared_error'])
    return SRCNN

In [6]:
def train(processing='GPU', optimizer='SGD', use_data_augmentation=False, epochs=200, lr=0.0001):
    """Train the SRCNN on the 64x64 patch dataset and return the training history.

    Parameters
    ----------
    processing : str
        'GPU' (case-insensitive) trains the multi-GPU model returned by
        ``model``; anything else trains the single-device model.
    optimizer : str
        Optimizer name forwarded to ``model``.
    use_data_augmentation : bool
        When true, batches come from an ``ImageDataGenerator`` and the
        separate validation file is used; otherwise the arrays are fitted
        directly with a 20% validation split.
    epochs : int
        Number of training epochs.
    lr : float
        Learning rate forwarded to ``model``.

    Returns
    -------
    The History object returned by ``fit`` / ``fit_generator``.

    Notes
    -----
    Checkpointing: in GPU mode or with augmentation, ``MyCbk`` saves the
    single-device model after every epoch; otherwise a ``ModelCheckpoint``
    keeps the best model by ``val_loss`` in "64SRCNN_check.h5".
    """
    optimizer = optimizer.upper()
    processing = processing.upper()

    srcnn_model, srcnn_multi_model = model(optimizer=optimizer, processing=processing, lr=lr)

    use_gpu = processing == 'GPU'
    # The model that is actually compiled and fitted for this processing mode.
    training_model = srcnn_multi_model if use_gpu else srcnn_model
    batch_size = 1024

    # summary() prints by itself; wrapping it in print() only adds a stray "None".
    training_model.summary()

    if use_gpu or use_data_augmentation:
        # see example https://github.com/keras-team/keras/issues/8649
        callbacks_list = [MyCbk(srcnn_model)]
    else:
        callbacks_list = [ModelCheckpoint("64SRCNN_check.h5", monitor='val_loss',
                                          verbose=2,
                                          save_best_only=True,
                                          save_weights_only=False)]

    data, label = pd.read_training_data("./64crop_train.h5")
    val_data, val_label = pd.read_training_data("./64test.h5")

    # Scale every array by the maximum of the training inputs.
    max_value = data.max()

    data = data / max_value
    label = label / max_value

    val_data = val_data / max_value
    val_label = val_label / max_value

    if use_data_augmentation:
        # NOTE(review): ImageDataGenerator.flow applies its random transforms
        # to `data` only; `label` is passed through unchanged, so augmented
        # inputs may no longer line up with their target images -- confirm
        # this is intended.
        datagen = ImageDataGenerator(rotation_range=45,
                                     zoom_range=0.15,
                                     width_shift_range=0.2,
                                     height_shift_range=0.2,
                                     horizontal_flip=True,
                                     vertical_flip=True)
        history = training_model.fit_generator(datagen.flow(data, label, batch_size=batch_size),
                                               validation_data=(val_data, val_label),
                                               steps_per_epoch=len(data) // batch_size,
                                               epochs=epochs,
                                               verbose=2,
                                               shuffle=True,
                                               callbacks=callbacks_list,
                                               # multiprocessing was only enabled on the GPU path
                                               use_multiprocessing=use_gpu)
    else:
        # NOTE(review): this path validates on a 20% split of the training
        # data; the arrays loaded from "./64test.h5" are not used here.
        history = training_model.fit(data, label, batch_size=batch_size,
                                     validation_split=0.2,
                                     shuffle=True,
                                     epochs=epochs,
                                     verbose=2,
                                     callbacks=callbacks_list)
    return history

In [7]:
# Load the training and validation arrays into the notebook namespace.
data, label = pd.read_training_data("./64crop_train.h5")
val_data, val_label = pd.read_training_data("./64test.h5")

# Scale all four arrays by the maximum of the training inputs.
max_value = data.max()

data, label = data / max_value, label / max_value
val_data, val_label = val_data / max_value, val_label / max_value

In [8]:
def predict(weights_path="checkpoints/64-9-3-5_128-64-SRCNN_model_at_epoch_999.h5",
            ground_truth_path='ground_data/pick/recon/slice_256.b',
            input_path="sub_data/pick/recon/slice_256.b",
            output_path="pre2.png"):
    """Super-resolve one slice with the trained SRCNN and compare it to bicubic upscaling.

    Steps: load the weights, read the ground-truth and low-resolution slices
    (raw uint8 files holding a square image), upscale the low-resolution
    slice to the ground-truth size with bicubic interpolation, run the
    network on it, paste the prediction into the interior of the bicubic
    image, write the result to ``output_path``, print PSNR for both
    bicubic and SRCNN against the ground truth, and plot all three images.

    Parameters
    ----------
    weights_path : str
        HDF5 weights file loaded into the inference model.
    ground_truth_path : str
        Raw uint8 file with the reference slice.
    input_path : str
        Raw uint8 file with the low-resolution slice.
    output_path : str
        Image file the SRCNN result is written to.

    Raises
    ------
    ValueError
        If either raw file does not contain a square number of bytes.
    IOError
        If the output image cannot be written.
    """
    import cv2  # only used here; kept as a function-local import

    srcnn_model = predict_model()
    srcnn_model.load_weights(weights_path)

    ground_truth = _load_square_slice(ground_truth_path)
    low_res = _load_square_slice(input_path)

    shape = ground_truth.shape
    bicubic = spm.imresize(low_res, size=shape, interp='bicubic')

    # Working copy that receives the network output; `bicubic` stays pristine.
    Y_img = bicubic.copy().reshape((shape[0], shape[1], 1))

    # Network input: one image, one channel, intensities scaled to [0, 1].
    Y = numpy.zeros((1, shape[0], shape[1], 1), dtype=float)
    Y[0, :, :, 0] = Y_img[:, :, 0].astype(float) / 255.
    pre = srcnn_model.predict(Y, batch_size=1) * 255.

    pre = numpy.clip(pre, 0, 255).astype(float)
    print(Y_img.shape, pre.shape)

    # The two 'valid' convolutions (9x9 and 5x5) trim 4 + 2 = 6 pixels per side,
    # so the prediction fills the interior and the border keeps bicubic values.
    # Round before the integer cast: a bare float-to-uint8 assignment truncates.
    Y_img[6: -6, 6: -6, 0] = numpy.rint(pre[0, :, :, 0]).astype(Y_img.dtype)
    if not cv2.imwrite(output_path, Y_img):
        raise IOError("could not write %s" % output_path)

    # psnr calculation (on the in-memory result; no need to re-read the file):
    im1 = ground_truth
    im2 = bicubic
    im3 = Y_img[:, :, 0]

    print(im1.shape, im2.shape, im3.shape)

    print ("bicubic:")
    print (cv2.PSNR(im1, im2))
    print ("SRCNN:")
    print (cv2.PSNR(im1, im3))

    print(im1.shape)
    fig, axes = plt.subplots(3, 1, figsize=[15, 30])
    # The SRCNN panel is cropped by 10 pixels per side, as before.
    panels = (('Original', im1), ('Bicubic', im2), ('SRCNN', im3[10:-10, 10:-10]))
    for ax, (title, image) in zip(axes, panels):
        ax.set_title(title)
        ax.imshow(image)
    plt.show()


def _load_square_slice(path):
    """Read a raw uint8 file and return its contents as a square 2-D array."""
    raw = numpy.fromfile(path, dtype='uint8')
    side = int(round(math.sqrt(raw.shape[0])))
    if side * side != raw.shape[0]:
        raise ValueError("%s holds %d bytes, which is not a square image" % (path, raw.shape[0]))
    return raw.reshape((side, side))

In [None]:
# Full training run: 1000 epochs with Adam (lr=1e-4) on the GPU path, without augmentation.
train(processing='GPU', optimizer='adam', use_data_augmentation=False, epochs=1000, lr=0.0001)

Using ADAM.
Instructions for updating:
Colocations handled automatically by placer.
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
conv2d_1_input (InputLayer)     (None, 64, 64, 1)    0                                            
__________________________________________________________________________________________________
lambda_1 (Lambda)               (None, 64, 64, 1)    0           conv2d_1_input[0][0]             
__________________________________________________________________________________________________
lambda_2 (Lambda)               (None, 64, 64, 1)    0           conv2d_1_input[0][0]             
__________________________________________________________________________________________________
lambda_3 (Lambda)               (None, 64, 64, 1)    0           conv2d_1_input[0][0]             
_________________________

Epoch 33/1000
 - 18s - loss: 3.1749e-04 - mean_squared_error: 3.1749e-04 - val_loss: 2.5462e-05 - val_mean_squared_error: 2.5462e-05
Epoch 34/1000
 - 16s - loss: 3.1741e-04 - mean_squared_error: 3.1741e-04 - val_loss: 2.5796e-05 - val_mean_squared_error: 2.5796e-05
Epoch 35/1000
 - 18s - loss: 3.1751e-04 - mean_squared_error: 3.1751e-04 - val_loss: 2.5611e-05 - val_mean_squared_error: 2.5611e-05
Epoch 36/1000
 - 17s - loss: 3.1741e-04 - mean_squared_error: 3.1741e-04 - val_loss: 2.6029e-05 - val_mean_squared_error: 2.6029e-05
Epoch 37/1000
 - 18s - loss: 3.1756e-04 - mean_squared_error: 3.1756e-04 - val_loss: 2.5467e-05 - val_mean_squared_error: 2.5467e-05
Epoch 38/1000
 - 17s - loss: 3.1749e-04 - mean_squared_error: 3.1749e-04 - val_loss: 2.5495e-05 - val_mean_squared_error: 2.5495e-05
Epoch 39/1000
 - 17s - loss: 3.1757e-04 - mean_squared_error: 3.1757e-04 - val_loss: 2.5675e-05 - val_mean_squared_error: 2.5675e-05
Epoch 40/1000
 - 16s - loss: 3.1748e-04 - mean_squared_error: 3.1748e

Epoch 95/1000
 - 18s - loss: 3.1714e-04 - mean_squared_error: 3.1714e-04 - val_loss: 2.6590e-05 - val_mean_squared_error: 2.6590e-05
Epoch 96/1000
 - 16s - loss: 3.1732e-04 - mean_squared_error: 3.1732e-04 - val_loss: 2.6061e-05 - val_mean_squared_error: 2.6061e-05
Epoch 97/1000
 - 18s - loss: 3.1708e-04 - mean_squared_error: 3.1708e-04 - val_loss: 2.5743e-05 - val_mean_squared_error: 2.5743e-05
Epoch 98/1000
 - 17s - loss: 3.1730e-04 - mean_squared_error: 3.1730e-04 - val_loss: 2.9439e-05 - val_mean_squared_error: 2.9439e-05
Epoch 99/1000
 - 18s - loss: 3.1731e-04 - mean_squared_error: 3.1731e-04 - val_loss: 2.5711e-05 - val_mean_squared_error: 2.5711e-05
Epoch 100/1000
 - 17s - loss: 3.1664e-04 - mean_squared_error: 3.1664e-04 - val_loss: 2.5791e-05 - val_mean_squared_error: 2.5791e-05
Epoch 101/1000
 - 17s - loss: 3.1713e-04 - mean_squared_error: 3.1713e-04 - val_loss: 2.6255e-05 - val_mean_squared_error: 2.6255e-05
Epoch 102/1000
 - 17s - loss: 3.1716e-04 - mean_squared_error: 3.17

Epoch 157/1000
 - 18s - loss: 3.1547e-04 - mean_squared_error: 3.1547e-04 - val_loss: 2.7127e-05 - val_mean_squared_error: 2.7127e-05
Epoch 158/1000
 - 17s - loss: 3.1542e-04 - mean_squared_error: 3.1542e-04 - val_loss: 2.6639e-05 - val_mean_squared_error: 2.6639e-05
Epoch 159/1000
 - 18s - loss: 3.1519e-04 - mean_squared_error: 3.1519e-04 - val_loss: 2.5603e-05 - val_mean_squared_error: 2.5603e-05
Epoch 160/1000
 - 17s - loss: 3.1531e-04 - mean_squared_error: 3.1531e-04 - val_loss: 2.6019e-05 - val_mean_squared_error: 2.6019e-05
Epoch 161/1000
 - 18s - loss: 3.1523e-04 - mean_squared_error: 3.1523e-04 - val_loss: 2.5988e-05 - val_mean_squared_error: 2.5988e-05
Epoch 162/1000
 - 17s - loss: 3.1530e-04 - mean_squared_error: 3.1530e-04 - val_loss: 2.6570e-05 - val_mean_squared_error: 2.6570e-05
Epoch 163/1000
 - 18s - loss: 3.1523e-04 - mean_squared_error: 3.1523e-04 - val_loss: 2.5459e-05 - val_mean_squared_error: 2.5459e-05
Epoch 164/1000
 - 17s - loss: 3.1528e-04 - mean_squared_error:

In [None]:
# Run inference; this needs the "..._at_epoch_999.h5" checkpoint under checkpoints/ to exist.
predict()