# In [None]:
import os
import numpy as np
import torch
import matplotlib.pyplot as plt
import supportingFunctions as sf

# TF_CPP_MIN_LOG_LEVEL only silences TensorFlow's C++ logging; it has no effect
# on PyTorch (presumably a leftover from a TensorFlow version of this script).
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# Working directory; the saved models are looked up relative to it.
cwd = os.getcwd()

# Let cuDNN benchmark and pick the fastest convolution algorithms for the
# (fixed) input size used during inference.
torch.backends.cudnn.benchmark = True

# Release unused memory held by PyTorch's CUDA caching allocator. (PyTorch has
# no TensorFlow-style "default graph" to reset.)
torch.cuda.empty_cache()

# Choose a model (sub-folder name) from the savedModels directory.
subDirectory = '04Jun_0356pm_5L_10K_50E_AG'

# Read the testing data from the dataset.hdf5 file:
#   tstOrg  - the original ground truth
#   tstAtb  - the aliased/noisy image (converted with sf.r2c before display,
#             so presumably stored in a real-valued representation)
#   tstCsm  - the coil sensitivity maps
#   tstMask - the undersampling mask
tstOrg, tstAtb, tstCsm, tstMask = sf.getTestingData()

# More testing data can be read from dataset.hdf5 (see the readme) with:
# tstOrg, tstAtb, tstCsm, tstMask = sf.getData('testing', num=100)

# Load the existing model, then do the reconstruction.
print('Now loading the model ...')

modelDir = os.path.join(cwd, 'savedModels', subDirectory)  # complete path

# Load the checkpoint onto the CPU so loading does not depend on which GPU it
# was saved from; the network itself is moved to the GPU right after.
loadChkPoint = torch.load(os.path.join(modelDir, 'modelTst.pth'),
                          map_location='cpu')

# Create the network, load the trained weights, and put it in eval mode
# (inference behaviour for layers such as dropout / batch-norm).
net = sf.Net()
net.load_state_dict(loadChkPoint['state_dict'])
net.cuda()
net.eval()

# Move the network inputs to the GPU.
atbT = torch.from_numpy(tstAtb).cuda()
maskT = torch.from_numpy(tstMask).cuda()
csmT = torch.from_numpy(tstCsm).cuda()

# Forward pass without building an autograd graph.
with torch.no_grad():
    dataDict = {'atb': atbT, 'mask': maskT, 'csm': csmT}
    recT = net(dataDict)

# Bring the output back to NumPy, drop singleton dimensions and convert it to
# complex with sf.r2c (assumes the network emits a real-valued representation
# of the complex image — TODO confirm against supportingFunctions).
rec = sf.r2c(recT.cpu().numpy().squeeze())

print('Reconstruction done')

# Score the network input and the reconstruction against the ground truth.
# Each magnitude image is first passed through sf.normalize01 (presumably a
# rescale to [0, 1] — see supportingFunctions) so the PSNR values compare
# images on the same intensity scale.
print('Now calculating the PSNR (dB) values')

normOrg = sf.normalize01(np.abs(tstOrg))           # ground truth
normAtb = sf.normalize01(np.abs(sf.r2c(tstAtb)))   # aliased/noisy input
normRec = sf.normalize01(np.abs(rec))              # network reconstruction

# PSNR of input and output, both measured against the ground truth.
psnrAtb, psnrRec = (sf.myPSNR(normOrg, img) for img in (normAtb, normRec))

# Small two-column table: input ("Noisy") vs. reconstruction ("Recon").
print('*****************')
print('  Noisy Recon')
print(f'  {psnrAtb:.2f} {psnrRec:.2f}')
print('*****************')

# Display the mask, ground truth, input and reconstruction side by side.


def plot(x):
    """Show *x* in grayscale with the display range fixed to [0.0, 0.8]."""
    return plt.imshow(x, cmap='gray', vmin=0.0, vmax=.8)


# (title, image) for each panel, left to right. The PSNR values are formatted
# with a format spec (as in the console report), which works for both NumPy
# scalars and plain Python floats.
panels = [
    ('Mask', np.fft.fftshift(tstMask[0])),
    ('Original', normOrg),
    (f'Input, PSNR={psnrAtb:.2f} dB', normAtb),
    (f'Output, PSNR={psnrRec:.2f} dB', normRec),
]

plt.clf()
for panelIdx, (panelTitle, panelImg) in enumerate(panels, start=1):
    plt.subplot(1, len(panels), panelIdx)
    plot(panelImg)
    plt.axis('off')
    plt.title(panelTitle)
plt.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=.01)
plt.show()
