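# End-to-End Model: Training

Trains the hybrid frequency/spatial W-Net (`fsnet.wnet`), which maps masked k-space to both a reconstructed k-space and a reconstructed image.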

### Pulkit Mathur, Vaibhav Saxena, Sumeet Ranka, Duc Thong

In [2]:
%matplotlib inline
import os
import pickle
import sys

import matplotlib.pylab as plt
import numpy as np
from keras.optimizers import Adam

# Directory with the project's own modules (frequency_spatial_network lives here).
MY_UTILS_PATH = "/home/sumeet_ranka47_gmail_com/Hybrid-CS-Model-MRI/Modules"
# Append only once, so re-running this cell does not keep growing sys.path.
if MY_UTILS_PATH not in sys.path:
    sys.path.append(MY_UTILS_PATH)
import frequency_spatial_network as fsnet

# Importing callbacks and data augmentation utils
from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras.preprocessing.image import ImageDataGenerator

## Loading Data

In [None]:
def imagetokspace(image):
    """Return the 2-D FFT of ``image`` with real/imaginary parts on a new last axis.

    Parameters
    ----------
    image : array-like
        A single image of shape (H, W), or a stack of images of shape
        (..., H, W). The FFT is taken over the last two axes.

    Returns
    -------
    np.ndarray
        float64 array of shape ``image.shape + (2,)``. ``[..., 0]`` holds the
        real part and ``[..., 1]`` the imaginary part of the (unshifted) FFT.

    Raises
    ------
    ValueError
        If ``image`` has fewer than two dimensions.
    """
    image = np.asarray(image)
    if image.ndim < 2:
        raise ValueError(
            "imagetokspace expects at least 2 dims, got shape %s" % (image.shape,))
    spectrum = np.fft.fft2(image)  # default axes=(-2, -1): works per image in a stack
    kspace = np.empty(image.shape + (2,), dtype=np.float64)
    kspace[..., 0] = spectrum.real
    kspace[..., 1] = spectrum.imag
    return kspace

# NOTE(review): the original cell called `pickle.load()` with no file argument,
# which raises TypeError. The paths below are placeholders — TODO set them to
# the real pickle locations.
VAL_MASKED_IMAGES_PKL = "val_masked_images.pkl"
VAL_ORIGINAL_IMAGES_PKL = "val_original_images.pkl"
TRAIN_MASKED_IMAGES_PKL = "train_masked_images.pkl"
TRAIN_ORIGINAL_IMAGES_PKL = "train_original_images.pkl"


def load_pickle(path):
    """Load and return one pickled object from ``path``.

    Only use on trusted files: unpickling can execute arbitrary code.
    """
    import pickle
    with open(path, "rb") as fh:
        return pickle.load(fh)


# validation data
mi_val = load_pickle(VAL_MASKED_IMAGES_PKL)    # presumably masked/undersampled images
oi_val = load_pickle(VAL_ORIGINAL_IMAGES_PKL)  # presumably original images
mk_val = imagetokspace(mi_val)  # network input: k-space of the masked images
mi_val = None  # release the image stack; only its k-space is used below
# The k-space target must be computed from the ORIGINAL images. The original cell
# used mk_val here, i.e. it took the FFT of an array that was already k-space.
ok_val = imagetokspace(oi_val)

# training data (image stacks; the generator converts them to k-space per batch)
x_train = load_pickle(TRAIN_MASKED_IMAGES_PKL)
y_train = load_pickle(TRAIN_ORIGINAL_IMAGES_PKL)

## Loading training-data statistics (k-space and image mean/std)

In [None]:
# Per-dataset statistics passed to the model constructor below:
#   stats[0], stats[1] -> mean, std of the masked k-space in the training data
#   stats[2], stats[3] -> mean, std of the actual images in the training data
STATS_PATH = "/home/sumeet_ranka47_gmail_com/git/fastMRI/training_data_stats20.npy"
stats = np.load(STATS_PATH)

## Loading Model

In [6]:
epochs = 50
batch_size = 16

# NOTE(review): `under_rate` was used below but never defined in this notebook,
# so a fresh kernel raised NameError. "20" matches the stats file name
# (training_data_stats20.npy) — TODO confirm this is the intended rate.
under_rate = "20"

# wnet receives the stats loaded above: masked k-space mean/std, image mean/std.
model = fsnet.wnet(stats[0], stats[1], stats[2], stats[3],
                   kshape=(5, 5), kshape2=(3, 3))
opt = Adam(lr=1e-3, decay=1e-7)
# Two outputs (k-space, image) are both scored with NRMSE; the image term dominates.
model.compile(loss=[fsnet.nrmse, fsnet.nrmse], optimizer=opt,
              loss_weights=[0.01, 0.99])

model_name = "/home/sumeet_ranka47_gmail_com/Hybrid-CS-Model-MRI/Models/wnet_" + under_rate + ".hdf5"
# Resume from an existing checkpoint, if there is one.
if os.path.isfile(model_name):
    model.load_weights(model_name)
    print("weights loaded")  # report only after loading actually succeeded

# summary() prints itself and returns None; wrapping it in print() added a stray "None".
model.summary()

# Early stopping: stop training after 20 epochs with no val_loss improvement.
earlyStopping = EarlyStopping(monitor='val_loss', patience=20,
                              verbose=0, mode='min')

# Checkpoint: keep only the best (lowest val_loss) weights in model_name.
checkpoint = ModelCheckpoint(model_name, mode='min', monitor='val_loss', verbose=0,
                             save_best_only=True, save_weights_only=True)

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 256, 256, 2)  0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 256, 256, 48) 2448        input_1[0][0]                    
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 256, 256, 48) 57648       conv2d_1[0][0]                   
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 256, 256, 48) 57648       conv2d_2[0][0]                   
__________________________________________________________________________________________________
max_poolin

In [None]:
def generator(blurred_images, actual_images, batch_size):
    """Yield random training batches forever (for ``model.fit_generator``).

    Each batch is ``(masked_kspace, [original_kspace, images])``:
      - ``masked_kspace``: k-space of randomly chosen ``blurred_images`` (input),
      - ``original_kspace``: k-space of the matching ``actual_images`` (target 1),
      - ``images``: the matching ``actual_images`` themselves (target 2).
    Samples are drawn with replacement using numpy's global RNG, so
    ``np.random.seed`` makes the sequence reproducible.

    Parameters
    ----------
    blurred_images, actual_images : array-like
        Indexable stacks of equal length. Items are assumed to be 2-D images of
        a common shape — TODO confirm against the pickled data.
    batch_size : int
        Number of samples per batch (must be >= 1).

    Raises
    ------
    ValueError
        On the first ``next()`` call, if the inputs are empty, have different
        lengths, or ``batch_size`` < 1.
    """
    n_samples = len(blurred_images)
    if n_samples == 0:
        raise ValueError("generator needs at least one training sample")
    if len(actual_images) != n_samples:
        raise ValueError("blurred_images and actual_images must have the same length")
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")

    while True:
        # A scalar index per sample (the original got a length-1 array back from
        # choice(n, 1), which added an extra axis to every item).
        indices = np.random.randint(0, n_samples, size=batch_size)
        # Build fresh arrays for every batch: Keras may still hold earlier batches
        # in its queue, so refilling shared buffers in place could corrupt them.
        # Shapes follow the data instead of a hard-coded 320x320.
        batch_masked_kspace = np.stack([imagetokspace(blurred_images[i]) for i in indices])
        batch_original_kspace = np.stack([imagetokspace(actual_images[i]) for i in indices])
        batch_actual_images = np.stack(
            [np.asarray(actual_images[i], dtype=np.float64) for i in indices])
        # The yield must sit inside the loop; the original placed it after
        # `while True`, so it was never reached.
        yield batch_masked_kspace, [batch_original_kspace, batch_actual_images]
    
# creating generator
gen = generator(x_train, y_train, batch_size)

# Pull one batch and show a target image next to its log-magnitude k-space.
sample_masked_kspace, (sample_kspace, sample_images) = next(gen)
print(sample_masked_kspace.shape, sample_images.shape)

# The original cell showed sample 10; clamp so that small batches still work.
sample_idx = min(10, len(sample_images) - 1)
target_image = sample_images[sample_idx]
if target_image.ndim == 3:  # drop a trailing channel axis if the data has one
    target_image = target_image[..., 0]
# Real and imaginary channels must come from the SAME sample. The original
# mixed the real part of sample 10 with the imaginary part of sample 8.
kspace_mag = np.abs(sample_kspace[sample_idx, :, :, 0]
                    + 1j * sample_kspace[sample_idx, :, :, 1])

fig, (ax_img, ax_k) = plt.subplots(1, 2)
ax_img.imshow(target_image, cmap='gray')
ax_img.set_title("target image")
ax_img.axis("off")
ax_k.imshow(np.log(1 + kspace_mag), cmap='gray')
ax_k.set_title("target k-space (log magnitude)")
ax_k.axis("off")
plt.show()

## Train model

In [None]:
# Keras expects an integer step count. Round up so that one epoch draws at
# least len(x_train) samples.
steps_per_epoch = int(np.ceil(len(x_train) / batch_size))

hist = model.fit_generator(gen,
                           steps_per_epoch=steps_per_epoch,
                           epochs=epochs,
                           verbose=1,
                           # validation targets follow the model's output order:
                           # [k-space, image]
                           validation_data=(mk_val, [ok_val, oi_val]),
                           callbacks=[checkpoint, earlyStopping])