In [1]:
#%load_ext memory_profiler

In [2]:
#%memit

In [3]:
import numpy as np
import math
import scipy.io

import os
import matplotlib.pyplot as plt
import time
from skimage.measure import compare_ssim as ssim

from IPython import display
from keras import regularizers
from keras import optimizers
from keras.models import Sequential, Model, load_model
from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras.layers import Input, Conv2D, Conv3D, Conv3DTranspose, Lambda, Reshape, Add, MaxPooling2D, UpSampling2D, Subtract, Activation
from keras.layers.merge import Concatenate
import keras.backend as K

# Data location (no cell visible in this notebook reads it -- NOTE(review): confirm
# where x_train / y_train / x_test / y_test are loaded from).
data_file_path = "./radon_spherique"

# Seed NumPy's global random generator so repeated runs draw the same numbers.
np.random.seed(10)

# Optimisation settings
num_epochs = 1000
batch_size = 32

# Problem dimensions
Nk, Nr = 32, 256  # Nk: sensor, Nr: time
N = 128           # side length of the square network input, see input_shape = (N, N, 1)

# Architecture selector; the model-building cell understands
# 'Unet', 'dualUnet' and 'tightUnet'.
tag = 'Unet'

def psnr(y_true, y_pred):
    """Keras-backend PSNR (in dB) used as a training metric.

    Both tensors are min-max rescaled before comparison. ``K.min`` / ``K.max``
    are called without an axis, so the extrema are taken over every element of
    the tensor that is passed in. The result is ``10 * log10(1 / MSE)`` of the
    rescaled tensors, computed with natural logs.

    Note: a tensor whose elements are all equal has zero range after the
    shift, so the division yields NaN; this mirrors the original behaviour.
    """
    # Prediction first, then target -- same op order as before.
    pred_shifted = y_pred - K.min(y_pred)
    pred_unit = pred_shifted / K.max(pred_shifted)
    true_shifted = y_true - K.min(y_true)
    true_unit = true_shifted / K.max(true_shifted)

    mse = K.mean(K.square(pred_unit - true_unit))
    # log10(x) expressed as ln(x) / ln(10).
    return 10.0 * K.log(1.0 / mse) / K.log(10.0)
    
def psnr_img(y_true,y_pred):
    """Return the PSNR (dB) between two NumPy images after min-max rescaling.

    Each input is shifted so its minimum is 0 and divided by its resulting
    maximum, so the peak value used in the formula is 1.0.

    Parameters
    ----------
    y_true, y_pred : array_like
        Images (any matching/broadcastable shape) to compare.

    Returns
    -------
    float or int
        ``20 * log10(1 / sqrt(MSE))`` of the rescaled images, or ``100`` when
        the rescaled images are identical (MSE == 0).

    Notes
    -----
    A constant image has zero range after the shift. Previously that caused a
    0/0 division and a NaN result; such an image is now treated as all zeros.
    """
    def _rescale(img):
        # Shift to a zero minimum, then scale to a unit maximum.
        shifted = np.asarray(img) - np.min(img)
        peak = np.max(shifted)
        if peak == 0:
            # Constant image: nothing to scale, avoid dividing by zero.
            return shifted
        return shifted / peak

    y_true = _rescale(y_true)
    y_pred = _rescale(y_pred)
    mse = np.mean( (y_true - y_pred) ** 2 )
    if mse == 0:
        # Identical images: PSNR is unbounded, report a fixed cap instead.
        return 100
    PIXEL_MAX = 1.0
    return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))

Using TensorFlow backend.


In [9]:
# Directory in which the best checkpoints are written / read back.
# NOTE(review): absolute, machine-specific path; consider making it configurable.
MODEL_DIR = "Z:/jj_git/tf-timereversal/ckpt-img2img-Radon-SLtype-overlap-10-4000-Nk%s-Nr%s-N%s.mat" % (Nk, Nr, N)
print(MODEL_DIR)
if not os.path.isdir(MODEL_DIR):
    os.makedirs(MODEL_DIR)


def _conv2d(filters, name, kernel_size=(3, 3), activation="relu"):
    """Return the bias-free, same-padded, stride-1 Conv2D layer shared by all blocks."""
    return Conv2D(filters, kernel_size,
                  use_bias=False, padding="same", activation=activation,
                  strides=1, kernel_initializer='glorot_uniform',
                  name=name)


def _conv_pair(tensor, block, filters_first, filters_second):
    """Apply the 3x3 ReLU layers ``block<block>_conv1`` then ``block<block>_conv2``."""
    tensor = _conv2d(filters_first, 'block%d_conv1' % block)(tensor)
    return _conv2d(filters_second, 'block%d_conv2' % block)(tensor)


def _append_unit_axis(tensor):
    """Reshape a (H, W, C) feature map to (H, W, C, 1) so 3-D convs can act on it."""
    shape = K.int_shape(tensor)
    return Reshape((shape[1], shape[2], shape[3], 1))(tensor)


def _wavelet_dec(tensor, name, padding):
    """Frozen stride-(2,2,1) Conv3D with 4 filters; weights are filled in by a later cell."""
    return Conv3D(4, (2, 2, 1),
                  use_bias=False, padding=padding, activation="relu",
                  strides=(2, 2, 1), kernel_initializer='glorot_uniform',
                  name=name, trainable=False)(tensor)


def _split_subbands(tensor):
    """Slice the last axis of a wavelet-decomposition output into its 4 channels."""
    band_ll = Lambda(lambda x : x[:,:,:,:,0])(tensor)
    band_lh = Lambda(lambda x : x[:,:,:,:,1])(tensor)
    band_hl = Lambda(lambda x : x[:,:,:,:,2])(tensor)
    band_hh = Lambda(lambda x : x[:,:,:,:,3])(tensor)
    return band_ll, band_lh, band_hl, band_hh


def _wavelet_rec(tensor, name):
    """Frozen stride-(2,2,1) Conv3DTranspose with 4 filters, folded back to 4-D.

    The last two axes of the transposed-conv output are merged into a single
    channel axis so the result can be concatenated with 2-D feature maps.
    """
    tensor = Conv3DTranspose(4, (2, 2, 1),
                             use_bias=False, padding="same", activation="relu",
                             strides=(2, 2, 1), kernel_initializer='glorot_uniform',
                             name=name, trainable=False)(tensor)
    shape = K.int_shape(tensor)
    return Reshape((shape[1], shape[2], shape[3] * shape[4]))(tensor)


# Strings are compared with ``==``: ``is`` tests object identity and only
# appeared to work because of CPython literal interning.
if tag == "tightUnet":
    print(tag)
    # Build model - tightUnet
    input_shape = (N,N,1)
    model_input = Input(shape=input_shape)

    # ---- encoder ----
    conv1 = _conv_pair(model_input, 1, 64, 64)
    conv1_1 = _append_unit_axis(conv1)
    wavdec1 = _wavelet_dec(conv1_1, 'wavdec1', padding="valid")
    wavdec1_LL, wavdec1_LH, wavdec1_HL, wavdec1_HH = _split_subbands(wavdec1)

    conv2 = _conv_pair(wavdec1_LL, 2, 128, 128)
    conv2_1 = _append_unit_axis(conv2)
    wavdec2 = _wavelet_dec(conv2_1, 'wavdec2', padding="same")
    wavdec2_LL, wavdec2_LH, wavdec2_HL, wavdec2_HH = _split_subbands(wavdec2)

    conv3 = _conv_pair(wavdec2_LL, 3, 256, 256)
    conv3_1 = _append_unit_axis(conv3)
    wavdec3 = _wavelet_dec(conv3_1, 'wavdec3', padding="same")
    wavdec3_LL, wavdec3_LH, wavdec3_HL, wavdec3_HH = _split_subbands(wavdec3)

    # NOTE(review): levels 1-2 pass the LL band downwards, but levels 3-4 pass
    # HH. Kept as in the original; confirm whether LL was intended here.
    conv4 = _conv_pair(wavdec3_HH, 4, 512, 512)
    conv4_1 = _append_unit_axis(conv4)
    wavdec4 = _wavelet_dec(conv4_1, 'wavdec4', padding="same")
    wavdec4_LL, wavdec4_LH, wavdec4_HL, wavdec4_HH = _split_subbands(wavdec4)

    # ---- bottleneck + decoder ----
    conv5 = _conv_pair(wavdec4_HH, 5, 1024, 512)
    conv5 = _append_unit_axis(conv5)
    wavrec4 = _wavelet_rec(conv5, 'wavrec4')
    merge1 = Concatenate()([conv4,wavrec4])

    conv6 = _conv_pair(merge1, 6, 512, 256)
    conv6 = _append_unit_axis(conv6)
    wavrec3 = _wavelet_rec(conv6, 'wavrec3')
    merge2 = Concatenate()([conv3,wavrec3])

    conv7 = _conv_pair(merge2, 7, 256, 128)
    conv7 = _append_unit_axis(conv7)
    wavrec2 = _wavelet_rec(conv7, 'wavrec2')
    merge3 = Concatenate()([conv2,wavrec2])

    conv8 = _conv_pair(merge3, 8, 128, 64)
    conv8 = _append_unit_axis(conv8)
    wavrec1 = _wavelet_rec(conv8, 'wavrec1')
    merge4 = Concatenate()([conv1,wavrec1])

    conv9 = _conv_pair(merge4, 9, 64, 64)
    model_output = _conv2d(1, 'block9_conv3', kernel_size=(1, 1), activation="linear")(conv9)
elif tag == "dualUnet":
    print(tag)
    # Build model - dual Unet
    input_shape = (N,N,1)
    model_input = Input(shape=input_shape)

    # ---- encoder ----
    conv1 = _conv_pair(model_input, 1, 64, 64)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = _conv_pair(pool1, 2, 128, 128)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    conv3 = _conv_pair(pool2, 3, 256, 256)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    conv4 = _conv_pair(pool3, 4, 512, 512)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

    # ---- bottleneck + decoder: the pooled encoder map of the same level is
    # subtracted before each upsampling step ----
    conv5 = _conv_pair(pool4, 5, 1024, 512)
    resid1 = Subtract()([conv5,pool4])
    up1 = UpSampling2D(size = (2,2))(resid1)
    merge1 = Concatenate()([conv4,up1])

    conv6 = _conv_pair(merge1, 6, 512, 256)
    resid2 = Subtract()([conv6,pool3])
    up2 = UpSampling2D(size = (2,2))(resid2)
    merge2 = Concatenate()([conv3,up2])

    conv7 = _conv_pair(merge2, 7, 256, 128)
    resid3 = Subtract()([conv7,pool2])
    up3 = UpSampling2D(size = (2,2))(resid3)
    merge3 = Concatenate()([conv2,up3])

    conv8 = _conv_pair(merge3, 8, 128, 64)
    resid4 = Subtract()([conv8,pool1])
    up4 = UpSampling2D(size = (2,2))(resid4)
    merge4 = Concatenate()([conv1,up4])

    conv9 = _conv_pair(merge4, 9, 64, 64)
    model_output = _conv2d(1, 'block9_conv3', kernel_size=(1, 1), activation="linear")(conv9)
elif tag == "Unet":
    print(tag)
    # Build model - Unet
    input_shape = (N,N,1)
    model_input = Input(shape=input_shape)

    # ---- encoder ----
    conv1 = _conv_pair(model_input, 1, 64, 64)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = _conv_pair(pool1, 2, 128, 128)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    conv3 = _conv_pair(pool2, 3, 256, 256)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    conv4 = _conv_pair(pool3, 4, 512, 512)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

    # ---- bottleneck + decoder with skip concatenations ----
    conv5 = _conv_pair(pool4, 5, 1024, 512)
    up1 = UpSampling2D(size = (2,2))(conv5)
    merge1 = Concatenate()([conv4,up1])

    conv6 = _conv_pair(merge1, 6, 512, 256)
    up2 = UpSampling2D(size = (2,2))(conv6)
    merge2 = Concatenate()([conv3,up2])

    conv7 = _conv_pair(merge2, 7, 256, 128)
    up3 = UpSampling2D(size = (2,2))(conv7)
    merge3 = Concatenate()([conv2,up3])

    conv8 = _conv_pair(merge3, 8, 128, 64)
    up4 = UpSampling2D(size = (2,2))(conv8)
    merge4 = Concatenate()([conv1,up4])

    conv9 = _conv_pair(merge4, 9, 64, 64)
    model_output = _conv2d(1, 'block9_conv3', kernel_size=(1, 1), activation="linear")(conv9)
else:
    # Fail loudly instead of hitting a NameError (or silently reusing tensors
    # left over from an earlier run) further down.
    raise ValueError("Unknown tag %r; expected 'Unet', 'dualUnet' or 'tightUnet'" % (tag,))

model = Model(model_input, model_output)
model.summary()

Z:/jj_git/tf-timereversal/ckpt-img2img-Radon-SLtype-overlap-10-4000-Nk32-Nr256-N128.mat
Unet
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 128, 128, 1)  0                                            
__________________________________________________________________________________________________
block1_conv1 (Conv2D)           (None, 128, 128, 64) 576         input_1[0][0]                    
__________________________________________________________________________________________________
block1_conv2 (Conv2D)           (None, 128, 128, 64) 36864       block1_conv1[0][0]               
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 64, 64, 64)   0           block1_conv2[0][0]               
________________

In [None]:
# Load fixed 2x2 Haar filters into the frozen wavelet layers of the tightUnet.
# ``==`` is used because ``is`` compares object identity, not string contents.
if tag == 'tightUnet':
    # 1-D filters as row vectors: L = [1, 1]/sqrt(2), H = [-1, 1]/sqrt(2).
    Harr_wav_L = 1/np.sqrt(2) * np.ones((1,2))
    Harr_wav_H = 1/np.sqrt(2) * np.ones((1,2))
    Harr_wav_H[0,0] = Harr_wav_H[0,0]*-1

    # (2,1) * (1,2) broadcasts to the 2x2 outer product of the two 1-D filters.
    Harr_wav_LL = np.transpose(Harr_wav_L)*Harr_wav_L
    Harr_wav_LH = np.transpose(Harr_wav_L)*Harr_wav_H
    Harr_wav_HL = np.transpose(Harr_wav_H)*Harr_wav_L
    Harr_wav_HH = np.transpose(Harr_wav_H)*Harr_wav_H
    print(Harr_wav_LL,'\n',Harr_wav_LH,'\n',Harr_wav_HL,'\n',Harr_wav_HH)

    # Stack to (2, 2, 4), then add the singleton axes of a Conv3D kernel.
    # The leading axis of length 1 plays the role of the "list of weight arrays".
    Harr_wav_filters = np.stack((Harr_wav_LL,Harr_wav_LH, Harr_wav_HL,Harr_wav_HH),axis = 2)
    Harr_wav_filters = Harr_wav_filters.reshape((1,2,2,1,1,4))

    layer_dict = {layer.name: layer for layer in model.layers}
    # set_weights expects a list of arrays; pass one explicitly rather than
    # relying on iteration over the ndarray's leading axis.
    for wav_layer_name in ('wavdec1', 'wavdec2', 'wavdec3', 'wavdec4'):
        layer_dict[wav_layer_name].set_weights(list(Harr_wav_filters))

    # Transposed-conv kernels carry the 4 filters on the second-to-last axis.
    Harr_wav_filters = Harr_wav_filters.reshape((1,2,2,1,4,1))
    for wav_layer_name in ('wavrec1', 'wavrec2', 'wavrec3', 'wavrec4'):
        layer_dict[wav_layer_name].set_weights(list(Harr_wav_filters))

    # Sanity check: the first filter stored in wavdec1 is the LL filter.
    # (The original ended with a bare expression here, which is never
    # displayed inside an ``if`` body.)
    np.testing.assert_allclose(
        layer_dict['wavdec1'].get_weights()[0][:, :, 0, 0, 0],
        Harr_wav_LL, rtol=1e-6)

In [None]:
# Compile with an MSE loss and the Adam optimiser; track 'acc' and the custom
# psnr metric.
# NOTE(review): 'acc' is of doubtful meaning for an MSE regression target, yet
# it drives both checkpoint2 and early stopping -- confirm this is intended.
model.compile(loss='mean_squared_error', optimizer="adam", metrics=['acc', psnr])

# Keep the weights of the best epoch according to validation PSNR ...
psnr_ckpt_file = os.path.join(MODEL_DIR, "model-best-%s-psnr.h5" % tag)
checkpoint1 = ModelCheckpoint(
    filepath=psnr_ckpt_file,
    monitor='val_psnr',
    verbose=0,
    save_best_only=True,
    mode='max',
)
# ... and, separately, according to validation 'acc'.
acc_ckpt_file = os.path.join(MODEL_DIR, "model-best-%s-acc.h5" % tag)
checkpoint2 = ModelCheckpoint(
    filepath=acc_ckpt_file,
    monitor='val_acc',
    verbose=0,
    save_best_only=True,
    mode='auto',
)
# Stop once val_acc has not improved for 20 consecutive epochs.
earlystop = EarlyStopping(
    monitor='val_acc',
    min_delta=0,
    patience=20,
    verbose=1,
    mode='auto',
)

# Train the model
# NOTE(review): x_train / y_train / x_test / y_test are not defined in any cell
# visible in this notebook -- they must exist in the kernel before this runs.
history = model.fit(
    x_train, y_train,
    batch_size=batch_size,
    epochs=num_epochs,
    verbose=1,
    validation_data=(x_test, y_test,),
    callbacks=[checkpoint1, checkpoint2, earlystop],
)

Wall time: 0 ns
Train on 3000 samples, validate on 1000 samples
Epoch 1/1000
Epoch 2/1000

In [None]:
# Reload the checkpoint with the best validation PSNR; the custom metric has
# to be supplied by name so Keras can rebuild the compiled model.
best_model_file = MODEL_DIR + '/model-best-%s-psnr.h5' % tag
model = load_model(best_model_file, custom_objects={"psnr": psnr})

# Evaluate on the held-out set; score follows the order of model.metrics_names
# (loss first, then the metrics passed to compile()).
score = model.evaluate(x_test, y_test, batch_size=batch_size, verbose=1)
print(model.metrics_names)
test_loss, test_acc, test_psnr = score[0], score[1], score[2]
print("Test loss: ", test_loss)
print("Test accuracy: ", test_acc)
print("Test PSNR: ", test_psnr)