# Training the MoDL network 


In [None]:
# -*- coding: utf-8 -*-
"""
This is the training code to train the model as described in the following article:

MoDL: Model-Based Deep Learning Architecture for Inverse Problems
by H.K. Aggarwal, M.P. Mani, M. Jacob from University of Iowa.

Paper download link: https://arxiv.org/abs/1712.02862

This code solves the following optimization problem:

    argmin_x ||Ax-b||_2^2 + ||x-Dw(x)||^2_2

 'A' can be any measurement operator. Here we consider parallel imaging problem in MRI where
 the A operator consists of undersampling mask, FFT, and coil sensitivity maps.

Dw(x): it represents the residual learning CNN.

Here is the description of the parameters that you can modify below.

epochs: how many times to pass through the entire dataset

nLayers: number of layers of the convolutional neural network.
        Each layer will have filters of size 3x3. There will be 64 such filters,
        except at the first and the last layer.

gradientMethod: MG or AG. Set it to MG for the 'manual gradient' of the conjugate gradient (CG) block
                as discussed in section 3 of the above paper. Set it to AG if
                you want to rely on TensorFlow to calculate the gradient of CG.

K: it represents the number of iterations of the alternating strategy as
    described in Eq. 10 in the paper.  Also please see Fig. 1 in the above paper.
    A higher value requires a lot of GPU memory. Set the maximum value to 20
    for a GPU with 16 GB memory. The higher the value, the longer the training takes.

sigma: the standard deviation of Gaussian noise to be added in the k-space

batchSize: You can reduce the batch size to 1 if the model does not fit on GPU.

Output:

After running the code the output model will be saved in the subdirectory 'savedModels'.
You can give the name of the generated output directory in the tstDemo.py to
run the newly trained model on the test data.


@primary author: Hemant Kumar Aggarwal
"""

# import some libraries
import time
import numpy as np
import h5py as h5
 
import os,time
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
 
import tensorflow as tf
from datetime import datetime
from tqdm import tqdm
import supportingFunctions as sf
import model as mm

# Start from an empty default graph and let TensorFlow grow its GPU memory
# allocation on demand instead of reserving it all up front.
tf.reset_default_graph()
config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))

#--------------------------------------------------------------
#% SET THESE PARAMETERS CAREFULLY
nLayers = 5             # number of CNN layers
epochs = 100            # passes through the entire training set
batchSize = 1
gradientMethod = 'AG'   # 'AG' or 'MG', see the module docstring
K = 10                  # iterations of the alternating strategy
sigma = 0.0001          # std of the Gaussian noise added in k-space
ncoils = 1              # coil dimension of the test-graph csm placeholder
nx = 128                # image size used by the test-graph placeholders
ny = 128
restoreWeights = False  # when True, the training session calls sf.assignWts with `wts`

#%% to train the model with higher K values (K>1) such as K=5 or 10,
# start from previously trained weights by enabling the snippet below.
# NOTE: `wts` is only defined by this snippet, so it must be enabled
# whenever restoreWeights is set to True.
#if K>1:
#    restoreWeights=True
#    restoreFromModel='26Oct_0117am_5L_1K_100E_AG'
#if restoreWeights:
#    wts=sf.getWeights('savedModels/'+restoreFromModel)
#--------------------------------------------------------------------------
#%% Generate a meaningful directory name under which the trained models are saved
print ('*************************************************')
start_time=time.time()   # wall-clock reference for the total run time
saveDir='savedModels/'
cwd=os.getcwd()

# The directory name looks like '05Apr_1151pm_5L_10K_50E_AG'.
# '%P' (lower-case am/pm) is not one of the format codes documented by Python;
# it is a platform-specific extension and fails where the C library lacks it.
# The documented '%p' lower-cased produces the same text portably.
now=datetime.now()
timeStamp=now.strftime("%d%b_%I%M")+now.strftime("%p").lower()+'_'
directory=saveDir+timeStamp+ \
 str(nLayers)+'L_'+str(K)+'K_'+str(epochs)+'E_'+gradientMethod

# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(directory, exist_ok=True)
sessFileName= directory+'/model'

#%% save test model
# Build a stand-alone inference graph with fixed-size, named placeholders and
# write it to '<directory>/modelTst'. The tensor names ('csm', 'mask', 'atb',
# 'predTst') are what the testing cell looks up after importing the meta graph.
tf.reset_default_graph()
csmT = tf.placeholder(tf.complex64, shape=(None, ncoils, nx, ny), name='csm')
maskT = tf.placeholder(tf.complex64, shape=(None, nx, ny), name='mask')
atbT = tf.placeholder(tf.float32, shape=(None, nx, ny, 2), name='atb')

# The fourth argument is False here and True for the training graph --
# presumably a training-mode flag; confirm against model.makeModel.
out = mm.makeModel(atbT, csmT, maskT, False, nLayers, K, gradientMethod)
# Expose the 'dc<K>' entry of the model outputs under a stable name.
predTst = tf.identity(out['dc' + str(K)], name='predTst')
sessFileNameTst = directory + '/modelTst'

# Only the graph structure matters here, so freshly initialised variables
# are saved; a separate 'checkpointTst' state file is used for this save.
saver = tf.train.Saver()
with tf.Session(config=config) as sess:
    sess.run(tf.global_variables_initializer())
    savedFile = saver.save(sess, sessFileNameTst, latest_filename='checkpointTst')
print('testing model saved:' + savedFile)
#%% read multi-channel dataset
trnOrg, trnAtb, trnCsm, trnMask = sf.getData('training')
# Convert the images with sf.c2r before feeding the float32 placeholders
# below (presumably complex -> two real channels; see supportingFunctions).
trnOrg = sf.c2r(trnOrg)
trnAtb = sf.c2r(trnAtb)

#%% placeholders through which the whole training set is fed to tf.data
tf.reset_default_graph()
csmP = tf.placeholder(tf.complex64, shape=(None,) * 4, name='csm')
maskP = tf.placeholder(tf.complex64, shape=(None,) * 3, name='mask')
atbP = tf.placeholder(tf.float32, shape=(None,) * 3 + (2,), name='atb')
orgP = tf.placeholder(tf.float32, shape=(None,) * 3 + (2,), name='org')


#%% creating the dataset
nTrn=trnOrg.shape[0]
# Number of full batches per pass over the data. Integer floor division is
# exact, unlike rounding through float32.
nBatch=nTrn//batchSize
nSteps=nBatch*epochs

trnData = tf.data.Dataset.from_tensor_slices((orgP,atbP,csmP,maskP))
trnData = trnData.cache()
# Shuffle BEFORE repeat: repeat() concatenates epochs without marking their
# boundaries, so a shuffle placed after it can draw a sample a second time
# before another has been seen once. Shuffling first, with a buffer covering
# the whole set, gives every epoch each sample exactly once in a fresh order.
trnData = trnData.shuffle(buffer_size=nTrn)
trnData = trnData.repeat(count=epochs)
trnData = trnData.batch(batchSize)
trnData = trnData.prefetch(5)
# Initializable iterator: the data is supplied through the placeholders when
# iterator.initializer is run.
iterator=trnData.make_initializable_iterator()
orgT,atbT,csmT,maskT = iterator.get_next('getNext')

#%% make training model

out=mm.makeModel(atbT,csmT,maskT,True,nLayers,K,gradientMethod)
# The 'dc<K>' entry of the model outputs is used as the prediction.
predT=out['dc'+str(K)]
predT=tf.identity(predT,name='pred')
# Squared error summed over axis 0 (the batch axis of the batched iterator
# output) and averaged over the remaining elements.
loss = tf.reduce_mean(tf.reduce_sum(tf.pow(predT-orgT, 2),axis=0))
tf.summary.scalar('loss', loss)
# Fetched alongside the train op in every training step.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

with tf.name_scope('optimizer'):
    optimizer = tf.train.AdamOptimizer()
    gvs = optimizer.compute_gradients(loss)
    # compute_gradients returns (None, var) for variables the loss does not
    # depend on, and tf.clip_by_value cannot take None. Skip those pairs;
    # there is no gradient to apply for them anyway.
    capped_gvs = [(tf.clip_by_value(grad, -1., 1.), var)
                  for grad, var in gvs if grad is not None]
    opToRun=optimizer.apply_gradients(capped_gvs)


#%% training code


def _timestamp():
    """Return the current local time, e.g. '05-Apr-2021 11:51 pm'.

    '%P' (lower-case am/pm) is a platform-specific strftime extension, so the
    documented '%p' is used and lower-cased to produce the same text portably.
    """
    now = datetime.now()
    return now.strftime("%d-%b-%Y %I:%M ") + now.strftime("%p").lower()


print ('training started at', _timestamp())
print ('parameters are: Epochs:',epochs,' BS:',batchSize,'nSteps:',nSteps,'nSamples:',nTrn)

saver = tf.train.Saver(max_to_keep=100)
totalLoss,ep=[],0
# Placeholder + summary op used to log the per-epoch average loss.
lossT = tf.placeholder(tf.float32)
lossSumT = tf.summary.scalar("TrnLoss", lossT)

with tf.Session(config=config) as sess:
    sess.run(tf.global_variables_initializer())
    if restoreWeights:
        # `wts` is only assigned by the commented-out restore snippet in the
        # parameter section; it must be defined before enabling this path.
        sess=sf.assignWts(sess,nLayers,wts)

    # Load the full training set into the tf.data pipeline.
    feedDict={orgP:trnOrg,atbP:trnAtb, maskP:trnMask,csmP:trnCsm}
    sess.run(iterator.initializer,feed_dict=feedDict)
    savedFile=saver.save(sess, sessFileName)
    print("Model meta graph saved in::%s" % savedFile)

    writer = tf.summary.FileWriter(directory, sess.graph)
    try:
        for step in tqdm(range(nSteps)):
            try:
                tmp,_,_=sess.run([loss,update_ops,opToRun])
            except tf.errors.OutOfRangeError:
                # The dataset iterator is exhausted: stop training.
                break
            totalLoss.append(tmp)
            if (step+1) % nBatch == 0:
                # End of an epoch: log the mean loss and start a new average.
                ep=ep+1
                avgTrnLoss=np.mean(totalLoss)
                lossSum=sess.run(lossSumT,feed_dict={lossT:avgTrnLoss})
                writer.add_summary(lossSum,ep)
                totalLoss=[] #after each epoch empty the list of total loss
        # Final checkpoint, tagged with the number of completed epochs.
        savedFile=saver.save(sess, sessFileName,global_step=ep,write_meta_graph=True)
    finally:
        # Close the event file even if training is interrupted by an error.
        writer.close()

end_time = time.time()
print ('Training completed in minutes ', ((end_time - start_time) / 60))
print ('training completed at', _timestamp())
print ('*************************************************')

#%%



# Testing the network output

## selecting the saved models and loading the test data

In [None]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import supportingFunctions as sf
import scipy.io
import time
import h5py as h5
import mat73


cwd=os.getcwd()
tf.reset_default_graph()

#%% choose a model from savedModels directory
subDirectory= '05Apr_1151pm_5L_10K_50E_AG' # R=4 for sagittal slices super-resolution training

#%% Read the testing data from an HDF5 file

#tstOrg is the original ground truth
#tstAtb: it is the aliased/noisy image
#tstCsm: this is coil sensitivity maps
#tstMask: it is the undersampling mask

#tstOrg,tstAtb,tstCsm,tstMask=sf.getTestingData()
#you can also read more testing data from dataset.hdf5 (see readme) file using the command

filename= 'modl_dataset_SR_size64_testing.hdf5' # R=4 for sagittal slices super-resolution testing

# Open read-only explicitly: older h5py releases defaulted to append mode,
# which needs write access to the file and can create it if it is missing.
with h5.File(filename, 'r') as f:
    tstOrg,tstCsm,tstMask=f['tstOrg'][:],f['tstCsm'][:],f['tstMask'][:]

# Build the network input from the ground truth, coil maps and mask, then
# convert it with sf.c2r for the float32 'atb' placeholder.
atb=sf.generateUndersampled(tstOrg,tstCsm,tstMask,sigma=0.001)
tstAtb=sf.c2r(atb)

## Loading the model and doing the slice by slice recon

In [None]:
modelDir= cwd+'/savedModels/'+subDirectory #complete path
rec=np.empty(tstAtb.shape,dtype=np.complex64) #rec variable will have output

tf.reset_default_graph()
# Most recent checkpoint recorded in modelDir; its weights are restored into
# the inference graph imported from 'modelTst.meta'.
loadChkPoint=tf.train.latest_checkpoint(modelDir)
config = tf.ConfigProto()
config.gpu_options.allow_growth=True

with tf.Session(config=config) as sess:
    new_saver = tf.train.import_meta_graph(modelDir+'/modelTst.meta')
    new_saver.restore(sess, loadChkPoint)
    graph = tf.get_default_graph()
    # Look up the input/output tensors by name in the imported graph.
    predT =graph.get_tensor_by_name('predTst:0')
    maskT =graph.get_tensor_by_name('mask:0')
    atbT=graph.get_tensor_by_name('atb:0')
    csmT   =graph.get_tensor_by_name('csm:0')
    wts=sess.run(tf.global_variables())
    na=np.newaxis   # adds the leading batch axis to each slice
    start_time=time.time()
    # Reconstruct one slice at a time.
    for i,atb in enumerate(tstAtb):
        csm,mask=tstCsm[i],tstMask[i]
        dataDict={atbT:atb[na],maskT:mask[na] ,csmT:csm[na] }
        rec[i]=sess.run(predT,feed_dict=dataDict)
    # Stop the clock before the session is torn down.
    end_time=time.time()
rec=sf.r2c(rec.squeeze())
# This cell performs reconstruction, not training; the message says so.
print ('Reconstruction completed in seconds ', ((end_time - start_time)))
print('Reconstruction done')

## Quantitative evaluation w.r.t. ground truth and Visualization for a representative slice

In [None]:
def _normalize01(imgs):
    """Rescale every entry along axis 0 of a real array to the range [0, 1].

    Each slice is shifted by its own minimum and divided by its own value
    range. A constant slice maps to all zeros instead of dividing by zero.
    """
    imgs = np.asarray(imgs)
    flat = imgs.reshape(len(imgs), -1)
    lo = flat.min(axis=1)
    span = flat.max(axis=1) - lo
    span[span == 0] = 1
    bshape = (-1,) + (1,) * (imgs.ndim - 1)
    return (imgs - lo.reshape(bshape)) / span.reshape(bshape)


def plot(x):
    """Show one image in grey scale with a fixed display range of 0 to 0.8."""
    return plt.imshow(x,cmap=plt.cm.gray, clim=(0.0, 0.8))


# normOrg / normAtb / normRec are not defined anywhere else in the visible
# notebook, so they are derived here as per-slice [0, 1]-normalised magnitude
# images. NOTE(review): presumably this matches the normalisation originally
# intended for the PSNR computation -- confirm.
normOrg=_normalize01(np.abs(tstOrg))
normAtb=_normalize01(np.abs(sf.r2c(tstAtb)))
normRec=_normalize01(np.abs(rec))

num=17   # index of the representative slice
psnrAtb=sf.myPSNR(normAtb[num],normOrg[num])
psnrRec=sf.myPSNR(normRec[num],normOrg[num])

print ('*****************')
print ('  ' + 'Org-Atb ' + 'Org-Rec')
print ('  {0:.2f}    {1:.2f}'.format(psnrAtb,psnrRec))
#print ('{0:.2f}'.format(psnrRec))
print ('*****************')
plt.clf()
plt.figure(figsize = (15,15))
plt.subplot(141)
plot(np.fft.fftshift(tstMask[num]))
#plot(tstMask[0])
plt.axis('off')
plt.title('Mask')
plt.subplot(142)
plot(np.fliplr(normOrg[num]))
#plot(np.fft.ifftshift(normOrg[num], axes=0))
plt.axis('off')
plt.title('ground-truth')
plt.subplot(143)
plot(np.fliplr(normAtb[num]))
#plot(np.fft.ifftshift(normAtb[:,:,num], axes=0))
plt.title('input \n PSNR='+ str(psnrAtb.round(2)) +' dB')
#plt.title('recon from Bart')
plt.axis('off')
plt.subplot(144)
# NOTE(review): ground truth and input are shown left-right flipped but the
# output is not -- confirm whether this difference is intended.
plot(normRec[num])
#plot(np.fft.ifftshift(normRec[:,:,num], axes=0))
#plt.title('MoDL, PSNR='+ str(psnrRec.round(2)) +' dB')
plt.title('output \n PSNR='+ str(psnrRec.round(2)) +' dB')
plt.axis('off')
plt.subplots_adjust(left=0, right=1, top=1, bottom=0,wspace=.01)
plt.show()