# Imports

In [None]:
import gc
import os
import random as rn
import time
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sigpy.mri
import tensorflow as tf
import yaml
from IPython.display import clear_output
from tensorflow.python.keras.layers import Conv1D, Conv2D, Dropout, InputSpec, Layer, ReLU, UpSampling2D
from tensorflow.python.keras.models import Model

from dip import fft_np as fft
from dip import evaluate, mri, plotting
from dip.dataset import MRDDataset, PhantomDataset

# Make sure tensorflow is loaded properly: print the installed version as a
# quick sanity check, then clear any Keras state left over from a previous run
# of the notebook before new models are built.
print(tf.__version__)
tf.keras.backend.clear_session()

# Hyperparameters

In [None]:
# Reset parameters.
# `params` holds the parameters captured by a previous run of the notebook (it
# is rebuilt in the cell that follows the parameters). Deleting those names
# here makes sure stale values cannot leak into this run. At module level
# `locals()` is the module namespace, so `del locals()[param]` removes the
# global variable.
# NOTE(review): raises KeyError if one of the names was already deleted by hand.
if 'params' not in locals():
    params = {}
for param in params:
    del locals()[param]
# Snapshot of all names that exist BEFORE the parameters are defined; every
# name created afterwards (and not starting with '_') is treated as a parameter.
_cur_locals = list(locals().keys())

# these parameters are modifiable using papermill
raw_folder = './data'       # folder that contains the input file
out_folder = './results'    # root folder for all outputs
filename = '<enter_filename_here>.h5'  # '.h5' is read with MRDDataset, '.mat' with PhantomDataset
slice_idx = 0               # index of the slice to reconstruct
n_coils = 12                # data with more coils than this is coil-compressed down to n_coils
rankk = 64                # subspace rank
inputDepth = 32           # feature maps in the input to the DIP
batch = 8                 # number of frames in each minibatch
expw = 0.95               # weight of the previous average in the exponential moving average of the DIP output (updated every epoch)
dropoutPerc = 0.05        # dropout level (number between 0-1) important for reducing noise (0.1 for 0.55T and 0.05 for 1.5T)
gradClip = 0.01           # gradient clipping
numEpochs = 3000          # total number of epochs
saveFreq = 3000           # set this equal to numEpochs to save only the final result; otherwise set to a smaller number to save results every X epochs
showFreq = 50             # display images every `showFreq` epochs
lr = 0.001                # learning rate
levelsV = 5               # number of downsampling steps in temporal basis network
cuda_num = 0              # GPU number
phantom_acceleration = 8  # only used for mrxcat data ('.mat' phantom files)
phantom_snr = 10          # in dB, only used for mrxcat data ('.mat' phantom files)

In [None]:
# Collect the run parameters: every name that appeared since the snapshot in
# `_cur_locals` was taken, except private names (leading underscore).
# Keep this as the first statement after the parameters cell, in a cell of its
# own, so that values injected by papermill are captured as well.
# The comprehension form is deliberate: its loop variables do not become
# module-level names while `locals()` is being iterated.
params = {
    name: value
    for name, value in locals().items()
    if not (name in _cur_locals or name.startswith('_'))
}

# results are written to <out_folder>/<input file stem>/slice_XX
output_path = Path(out_folder, Path(filename).stem, f'slice_{slice_idx:02d}')
output_path.mkdir(parents=True, exist_ok=True)

# select the GPU that TensorFlow is allowed to see
os.environ["CUDA_VISIBLE_DEVICES"] = f'{cuda_num}'

In [None]:
# Store the parameterization next to the results as a single YAML document
# (block style, with an explicit '---' document start).
with (output_path / 'params_lrdip.yaml').open('w') as f:
    yaml.dump_all([{'params': params}], f, explicit_start=True, default_flow_style=False)

# Data loading and preprocessing

In [None]:
# Choose the dataset reader from the file extension (text after the last dot)
extension = filename.rsplit('.', 1)[-1]
if extension not in ('h5', 'mat'):
    raise ValueError('Unknown file format')

source_file = Path(raw_folder) / filename
if extension == 'h5':
    data = MRDDataset(source_file, apodize=True)
else:
    # '.mat' phantom: undersampling and noise are controlled by the phantom_* parameters
    data = PhantomDataset(source_file, apodize=True, acceleration_rate=phantom_acceleration, snr=phantom_snr)

# remove the readout oversampling
print('Cropping readout oversampling...')
data.crop_readout_oversampling()

# whiten the k-space data
print('Whitening...')
data.whiten()

# summary of what was loaded
print(
    f'Number of slices: {data.n_slices}',
    f'Number of frames: {data.n_phases}',
    f'Number of coils:  {data.n_coils}',
    f'Matrix size:      {data.matrix_size}',
    sep='\n',
)

# image matrix size and number of time frames, used throughout the notebook
Nx, Ny = data.matrix_size[0], data.matrix_size[1]
Nex = data.n_phases

In [None]:
# From here on, work on a single slice
data.sl = slice_idx

# undersampled k-space data of that slice
k = data.k  # [frame, coil, kx, ky]

# Sampling mask: 1 wherever the first coil holds a non-zero sample
m = (np.abs(k[:, 0]) > 0).astype(np.int8)  # [frame, kx, ky]

# Acceleration rate, estimated on the central kx index of every frame:
# number of mask entries divided by the number of sampled entries
m_tmp = m[:, m.shape[1] // 2]
r = m_tmp.size / m_tmp.sum()
print(f'Acceleration rate: {r}')

In [None]:
# Compress the coil axis (axis 1) when more than `n_coils` coils were acquired
n_acquired_coils = k.shape[1]
if n_acquired_coils > n_coils:
    print(f'Coil compression: {n_acquired_coils} coils -> {n_coils} coils')
    k = mri.coil_compression(k, n_coils, ch_axis=1)

# Show a coil-combined image of the data averaged along axis 0 (frames).
# NOTE(review): `mri.rss` is presumably a root-sum-of-squares coil combination - confirm.
kspace_avg = mri.average_data(k, 0)
img_avg = fft.ifftnc(kspace_avg, axes=[1, 2])
img_combined = mri.rss(img_avg, 0)  # type: ignore
plt.imshow(img_combined, cmap='gray')
plt.show()

# Estimate the coil sensitivity maps from the averaged k-space with ESPIRiT
espirit_app = sigpy.mri.app.EspiritCalib(
    kspace_avg,
    calib_width=24,
    thresh=0.02,
    crop=0,
    show_pbar=False,
)
sens_maps = espirit_app.run()
assert isinstance(sens_maps, np.ndarray), 'ESPIRiT failed'

In [None]:
coilmaps = sens_maps.copy()
DATA = k.copy().transpose(1, 2, 3, 0)  # [coils, x, y, frames]

# Coil-combined images of the undersampled data; they are only needed to find
# the normalization constant and are discarded right afterwards.
imunder = fft.ifftnc(k, axes=[2, 3]) * coilmaps[None].conj()
imunder = np.sum(imunder, axis=1)

print("Normalizing data")
maxValue = np.abs(imunder).max()
DATA /= maxValue  # after scaling, the largest magnitude in those images is 1
del imunder
gc.collect()

print("Converting data to Tensorflow format")
DATA = tf.transpose(
    tf.convert_to_tensor(DATA, dtype=tf.complex64),  # [coils, x, y, TRs]
    perm=[0, 3, 1, 2],
)  # [coils, TRs, x, y]
coilmaps = tf.convert_to_tensor(coilmaps, dtype=tf.complex64)
MASK = tf.convert_to_tensor(m, dtype=tf.complex64)  # [TRs, x, y]

# LR-DIP Reconstruction

In [None]:
# Learning-rate schedule: `lr` for every epoch, preceded by a linear warm-up
# that starts at 1e-4 and spans the first (up to) 10 epochs.
learnRates = np.ones((numEpochs+1,)) * lr
n_warmup = min(10, learnRates.size)  # the schedule may be shorter than the warm-up
learnRates[:n_warmup] = np.linspace(1e-4, lr, 10)[:n_warmup]


def _pad_to_multiple(size, multiple):
    """Return (padded_size, pad_before, pad_after) rounding `size` up to a multiple of `multiple`.

    The padding is split as evenly as possible; when it is odd, the extra
    sample goes to `pad_after`.
    """
    padded_size = int(-(-size // multiple) * multiple)  # integer ceil
    total_pad = padded_size - size
    pad_before = total_pad // 2
    return padded_size, pad_before, total_pad - pad_before


# Each encoding level of a u-net downsamples by a factor of 2, so the size fed
# to a network with L levels must be divisible by 2^L. The network inputs are
# created at the padded size and the outputs are cropped back to the original
# size using the left_pad* offsets.
# The temporal u-net is built with `levelsV` levels, so the required multiple
# follows that parameter (32 for the default of 5 levels).
upper_frames, left_pad, right_pad = _pad_to_multiple(Nex, 2 ** levelsV)
difference_frames = upper_frames - Nex

# The spatial u-net is built with its default of 5 levels: 2^5 = 32.
upper_x, left_pad_x, right_pad_x = _pad_to_multiple(Nx, 32)
upper_y, left_pad_y, right_pad_y = _pad_to_multiple(Ny, 32)
difference_x = upper_x - Nx
difference_y = upper_y - Ny

# Number of iterations per epoch (1 epoch = 1 pass over the time frames).
# Integer division: when Nex is not a multiple of `batch`, the last
# Nex % batch frames of the (shuffled) frame order are skipped in that epoch.
nsteps = Nex // batch

## Initialize the networks

In [None]:
class ReflectionPadding1D(Layer):
    """Mirror-pad the length axis of a 3-D tensor (`tf.pad` in 'REFLECT' mode).

    Args:
        padding: 1-tuple; number of samples added to each end of the length axis.
        data_format: 'channels_first' for [batch, channels, length] inputs or
            'channels_last' for [batch, length, channels] inputs.
        **kwargs: forwarded to the Keras `Layer` constructor.

    Raises:
        ValueError: if `data_format` is neither of the two supported values.
    """

    def __init__(self, padding=(1,), data_format='channels_first', **kwargs):
        # Initialise the Keras base class before assigning any attribute.
        super(ReflectionPadding1D, self).__init__(**kwargs)
        if data_format not in ('channels_first', 'channels_last'):
            raise ValueError(f'Unknown data_format: {data_format!r}')
        self.padding = tuple(padding)
        self.input_spec = [InputSpec(ndim=3)]
        self.data_format = data_format

    def compute_output_shape(self, s):
        """Return the input shape `s` with the length axis grown by 2 * padding."""
        pad = self.padding[0]
        if self.data_format == 'channels_first':
            # the length is axis 2 here (axis 1 holds the channels)
            return (s[0], s[1], s[2] + 2 * pad)
        return (s[0], s[1] + 2 * pad, s[2])

    def call(self, x, mask=None):
        """Pad `x` along its length axis; `mask` is accepted but not used."""
        pad = self.padding[0]
        if self.data_format == 'channels_first':
            paddings = [[0, 0], [0, 0], [pad, pad]]
        else:
            paddings = [[0, 0], [pad, pad], [0, 0]]
        return tf.pad(x, paddings, 'REFLECT')

class ReflectionPadding2D(Layer):
    """Mirror-pad the two spatial axes of a 4-D tensor (`tf.pad` in 'REFLECT' mode).

    Args:
        padding: 2-tuple `(pad_0, pad_1)`; `pad_0` is added to each end of the
            first spatial axis and `pad_1` to each end of the second one. This
            is the order used by `compute_output_shape` and by the kernel sizes
            that `MyConv2D` passes in.
        data_format: 'channels_first' for [batch, channels, d0, d1] inputs or
            'channels_last' for [batch, d0, d1, channels] inputs.
        **kwargs: forwarded to the Keras `Layer` constructor.

    Raises:
        ValueError: if `data_format` is neither of the two supported values.
    """

    def __init__(self, padding=(1, 1), data_format='channels_first', **kwargs):
        # Initialise the Keras base class before assigning any attribute.
        super(ReflectionPadding2D, self).__init__(**kwargs)
        if data_format not in ('channels_first', 'channels_last'):
            raise ValueError(f'Unknown data_format: {data_format!r}')
        self.padding = tuple(padding)
        self.input_spec = [InputSpec(ndim=4)]
        self.data_format = data_format

    def compute_output_shape(self, s):
        """Return the input shape `s` with both spatial axes grown by 2 * padding."""
        pad_0, pad_1 = self.padding
        if self.data_format == 'channels_first':
            return (s[0], s[1], s[2] + 2 * pad_0, s[3] + 2 * pad_1)
        return (s[0], s[1] + 2 * pad_0, s[2] + 2 * pad_1, s[3])

    def call(self, x, mask=None):
        """Pad `x` along its spatial axes; `mask` is accepted but not used."""
        # padding[0] belongs to the first spatial axis, padding[1] to the second,
        # so that the padded tensor matches compute_output_shape.
        pad_0, pad_1 = self.padding
        if self.data_format == 'channels_first':
            paddings = [[0, 0], [0, 0], [pad_0, pad_0], [pad_1, pad_1]]
        else:
            paddings = [[0, 0], [pad_0, pad_0], [pad_1, pad_1], [0, 0]]
        return tf.pad(x, paddings, 'REFLECT')

class MyConv1D(Model):
    """Reflection padding -> dropout -> 1-D convolution -> ReLU on channels-first tensors.

    Args:
        myFilterSize: number of output channels of the convolution.
        myKernelSize: 1-tuple with the kernel length; the reflection padding is
            kernel length // 2 on each side.
        myKernelStride: 1-tuple with the stride of the convolution.
        myDropoutType: 'Dropout' for element-wise dropout; any other value
            selects channel-wise (spatial) dropout.
        myDropoutFraction: dropout rate.
    """

    def __init__(self, myFilterSize=128, myKernelSize=(3,), myKernelStride=(1,),
                 myDropoutType='Dropout', myDropoutFraction=0.05):
        super(MyConv1D, self).__init__()

        self.conv = Conv1D(filters=myFilterSize, kernel_size=myKernelSize, strides=myKernelStride,
                           padding='valid', data_format='channels_first')
        self.reflect = ReflectionPadding1D(padding=(myKernelSize[0]//2,))
        if myDropoutType == 'Dropout':
            self.drop = Dropout(myDropoutFraction)
        else:
            # Channel-wise dropout for [batch, channels, length] input: the mask
            # has size 1 along the length axis, so a channel is kept or dropped
            # as a whole. `None` entries take the size of the input.
            self.drop = Dropout(myDropoutFraction, noise_shape=(None, None, 1))
        self.act = ReLU()

    def call(self, x, trainingDrop=True):
        """Apply the block to `x`; `trainingDrop` switches the dropout on or off."""
        x = self.reflect(x)
        x = self.drop(x, training=trainingDrop)
        x = self.conv(x)
        x = self.act(x)
        return x

class MyConv2D(Model):
    """Reflection padding -> dropout -> 2-D convolution -> ReLU on channels-first tensors.

    Args:
        myFilterSize: number of output channels of the convolution.
        myKernelSize: 2-tuple with the kernel size; the reflection padding is
            kernel size // 2 on each side of the corresponding axis.
        myKernelStride: 2-tuple with the strides of the convolution.
        myDropoutType: 'Dropout' for element-wise dropout; any other value
            selects channel-wise (spatial) dropout.
        myDropoutFraction: dropout rate.
    """

    def __init__(self, myFilterSize=128, myKernelSize=(3, 3), myKernelStride=(1, 1),
                 myDropoutType='Dropout', myDropoutFraction=0.05):
        super(MyConv2D, self).__init__()

        self.conv = Conv2D(filters=myFilterSize, kernel_size=myKernelSize, strides=myKernelStride,
                           padding='valid', data_format='channels_first')
        self.reflect = ReflectionPadding2D(padding=(myKernelSize[0]//2, myKernelSize[1]//2))
        if myDropoutType == 'Dropout':
            self.drop = Dropout(myDropoutFraction)
        else:
            # Channel-wise dropout for [batch, channels, d0, d1] input: the mask
            # has size 1 along both spatial axes, so a feature map is kept or
            # dropped as a whole. `None` entries take the size of the input.
            self.drop = Dropout(myDropoutFraction, noise_shape=(None, None, 1, 1))
        self.act = ReLU()

    def call(self, x, trainingDrop=True):
        """Apply the block to `x`; `trainingDrop` switches the dropout on or off."""
        x = self.reflect(x)
        x = self.drop(x, training=trainingDrop)
        x = self.conv(x)
        x = self.act(x)
        return x

class SpatialBasisNetwork(Model):
    """2-D u-net (encoder/decoder with skip connections) that generates the spatial basis images.

    The subspace rank and the dropout rate are read from the module-level
    `rankk` and `dropoutPerc` when the network is constructed.

    Args:
        ups: per-level resampling factor; used as the stride of the first
            encoder convolution and as the upsampling factor in the decoder.
        filters: per-level number of feature maps of the encoder/decoder convolutions.
        skips: per-level number of feature maps of the skip connections.
    """

    def __init__(self, ups=5*[2], filters=5*[128], skips=5*[4]):
        super(SpatialBasisNetwork, self).__init__()

        # 2*rankk output channels: the first rankk are used as the real parts
        # and the remaining rankk as the imaginary parts of the K subspace images.
        self.convOut = Conv2D(2*rankk, (1, 1), (1, 1), data_format='channels_first')

        self.ups = ups
        self.filters = filters

        # Per-level layers, keyed by level index (0 = full resolution).
        # Note: `self.skips` holds the skip-connection layers themselves, not
        # the `skips` argument (which only lists their filter counts).
        self.nlayers   = len(ups)
        self.encoder1  = {}
        self.encoder2  = {}
        self.decoder1  = {}
        self.decoder2  = {}
        self.skips     = {}
        self.upsample  = {}

        for i in range(self.nlayers):
            # encoder1 downsamples through its stride; all other convolutions keep the size
            self.encoder1[i] = MyConv2D(filters[i], (3, 3), (ups[i], ups[i]), 'Dropout', dropoutPerc)
            self.encoder2[i] = MyConv2D(filters[i], (3, 3), (1, 1), 'Dropout', dropoutPerc)
            self.skips[i]    = MyConv2D(skips[i], (1, 1), (1, 1), 'Dropout', dropoutPerc)
            self.decoder1[i] = MyConv2D(filters[i], (3, 3), (1, 1), 'Dropout', dropoutPerc)
            self.decoder2[i] = MyConv2D(filters[i], (3, 3), (1, 1), 'Dropout', dropoutPerc)
            self.upsample[i] = UpSampling2D(size=(ups[i], ups[i]), interpolation='nearest',
                                            data_format='channels_first')

    def call(self, x, trainingDrop=True):
        """Run the u-net on `x` ([batch, channels, X, Y]) and return [batch, X, Y, 2*K].

        `trainingDrop` switches the dropout of every convolution block on or off.
        """
        # Encoder: store a skip branch of the input of every level, then downsample
        xskips = {}
        for i in range(self.nlayers):
            xskips[i] = self.skips[i](x, trainingDrop=trainingDrop)
            x = self.encoder1[i](x, trainingDrop=trainingDrop)
            x = self.encoder2[i](x, trainingDrop=trainingDrop)

        # Decoder: walk the levels in reverse, upsample and append the matching skip branch
        for i in range(self.nlayers-1, -1, -1):
            x = self.upsample[i](x)
            x = tf.concat([x, xskips[i]], axis=1)
            x = self.decoder1[i](x, trainingDrop=trainingDrop)
            x = self.decoder2[i](x, trainingDrop=trainingDrop)

        x = self.convOut(x)  # [batch, 2*K, X, Y]
        x = tf.transpose(x, [0, 2, 3, 1])  # [batch, X, Y, 2*K]
        return x

class TemporalBasisNetwork(Model):
    """1-D u-net (encoder/decoder with skip connections) that generates the temporal basis functions.

    The subspace rank and the dropout rate are read from the module-level
    `rankk` and `dropoutPerc` when the network is constructed.

    Args:
        ups: per-level resampling factor; used as the stride of the first
            encoder convolution and as the upsampling factor in the decoder.
        filters: per-level number of feature maps of the encoder/decoder convolutions.
        skips: per-level number of feature maps of the skip connections.
    """

    def __init__(self, ups=5*[2], filters=5*[128], skips=5*[4]):
        super(TemporalBasisNetwork, self).__init__()

        # 2*rankk output channels: the first rankk are used as the real parts
        # and the remaining rankk as the imaginary parts of the K temporal basis functions.
        self.convOut = Conv1D(2*rankk, (1,), (1,), data_format='channels_first')

        self.ups = ups
        self.filters = filters

        # Per-level layers, keyed by level index (0 = full temporal resolution).
        # Note: `self.skips` holds the skip-connection layers themselves, not
        # the `skips` argument (which only lists their filter counts).
        self.nlayers   = len(ups)
        self.encoder1  = {}
        self.encoder2  = {}
        self.decoder1  = {}
        self.decoder2  = {}
        self.skips     = {}
        self.upsample  = {}

        for i in range(self.nlayers):
            # encoder1 downsamples through its stride; all other convolutions keep the length
            self.encoder1[i] = MyConv1D(filters[i], (3,), (ups[i],), 'Dropout', dropoutPerc)
            self.encoder2[i] = MyConv1D(filters[i], (3,), (1,), 'Dropout', dropoutPerc)
            self.skips[i]    = MyConv1D(skips[i], (1,), (1,), 'Dropout', dropoutPerc)
            self.decoder1[i] = MyConv1D(filters[i], (3,), (1,), 'Dropout', dropoutPerc)
            self.decoder2[i] = MyConv1D(filters[i], (3,), (1,), 'Dropout', dropoutPerc)
            # A 2-D upsampling layer with factor 1 on a dummy axis is used to
            # upsample the channels-first 1-D signal (see `call`).
            self.upsample[i] = UpSampling2D(size=(1, ups[i]), interpolation='nearest',
                                            data_format='channels_first')

    def call(self, x, trainingDrop=True):
        """Run the u-net on `x` ([batch, channels, frames]) and return [batch, frames, 2*K].

        `trainingDrop` switches the dropout of every convolution block on or off.
        """
        # Encoder: store a skip branch of the input of every level, then downsample
        xskips = {}
        for i in range(self.nlayers):
            xskips[i] = self.skips[i](x, trainingDrop=trainingDrop)
            x = self.encoder1[i](x, trainingDrop=trainingDrop)
            x = self.encoder2[i](x, trainingDrop=trainingDrop)

        # Decoder: walk the levels in reverse, upsample and append the matching skip branch
        for i in range(self.nlayers-1, -1, -1):
            # insert a dummy axis for the 2-D upsampling layer and drop it again
            x = self.upsample[i](x[:, :, None, :])[:, :, 0, :]
            x = tf.concat([x, xskips[i]], axis=1)
            x = self.decoder1[i](x, trainingDrop=trainingDrop)
            x = self.decoder2[i](x, trainingDrop=trainingDrop)

        x = self.convOut(x)  # [batch, 2*K, frames]
        x = tf.transpose(x, [0, 2, 1])  # [batch, frames, 2*K]
        return x

# u-net that generates the spatial basis images (default depth)
modelU = SpatialBasisNetwork()
# u-net that generates the temporal basis functions, `levelsV` levels deep
modelV = TemporalBasisNetwork(
    ups=levelsV*[2],
    filters=levelsV*[128],
    skips=levelsV*[4],
)

# Network inputs: uniform random values in [-1, 1) at the padded sizes.
# They are drawn once here and stay fixed during training.
input_spatial = tf.random.uniform(shape=[1, inputDepth, upper_x, upper_y], minval=-1, maxval=1)
input_temporal = tf.random.uniform(shape=[1, inputDepth, upper_frames], minval=-1, maxval=1)


## Network Inference
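Here we define a function that evaluates both networks with dropout disabled and blends their output into an exponential moving average of the spatial and temporal basis functions. It is called once per epoch during training.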

In [None]:
# Running averages of the basis functions; they start at zero and are updated by TestNetwork
Uavg = tf.zeros(shape=(Nx, Ny, rankk), dtype=tf.complex64)  # final spatial basis functions [X, Y, rank]
Vavg = tf.zeros(shape=(rankk, Nex), dtype=tf.complex64)     # final temporal basis functions [rank, frames]

@tf.function
def TestNetwork(Uavg,Vavg):
    """Evaluate both networks without dropout and blend the result into the running averages.

    Args:
        Uavg: current average of the spatial basis functions, [X, Y, rank].
        Vavg: current average of the temporal basis functions, [rank, frames].

    Returns:
        The updated pair (Uavg, Vavg): new * (1 - expw) + old * expw.
    """
    # Spatial basis: drop the batch axis, crop away the padding, then join the
    # first `rankk` channels (real part) with the remaining ones (imaginary part).
    U = modelU(input_spatial, trainingDrop=False)[0]
    U = U[left_pad_x:left_pad_x + Nx, left_pad_y:left_pad_y + Ny]
    U = tf.cast(tf.complex(U[..., :rankk], U[..., rankk:]), tf.complex64)  # [X, Y, rank]

    # Temporal basis: same steps along the frame axis, then transpose
    V = modelV(input_temporal, trainingDrop=False)[0]
    V = V[left_pad:left_pad + Nex]
    V = tf.cast(tf.complex(V[:, :rankk], V[:, rankk:]), tf.complex64)  # [frames, rank]
    V = tf.transpose(V)  # [rank, frames]

    # exponentially weighted average with the previous runs
    new_weight = 1 - expw
    return U * new_weight + Uavg * expw, V * new_weight + Vavg * expw

# Time one application of the networks. The first run takes longer, so run
# once as a warm-up and time the second run.
# Note that both calls already feed the running averages Uavg/Vavg.
Uavg,Vavg = TestNetwork(Uavg,Vavg)
# perf_counter is a monotonic, high-resolution clock, which suits an interval
# of a few milliseconds better than the wall clock.
startTime = time.perf_counter()
Uavg,Vavg = TestNetwork(Uavg,Vavg)
inferenceTime = time.perf_counter() - startTime
print("Inference time = {:.1f} ms".format(inferenceTime*1000))


## Define the network training function

In [None]:
optimizer = tf.keras.optimizers.Adam( learnRates[0] )  # optimizer for the image reconstruction network

@tf.function
def trainDeepImagePrior(TRIndex_Batch, loss_scale=1.):
    """Run one optimization step of both networks on a minibatch of time frames.

    The networks generate the spatial (U) and temporal (V) basis functions; the
    frames U*V are multiplied by the coil sensitivities, transformed to
    k-space, masked, and compared with the acquired data.

    Args:
        TRIndex_Batch: integer tensor with the indices of the frames in the minibatch.
        loss_scale: factor applied to the summed squared k-space error.

    Returns:
        The scaled loss of this step as a scalar float32 tensor.
    """
    # acquired (undersampled) k-space data and sampling mask of the current minibatch
    acquiredData = tf.gather(DATA, TRIndex_Batch, axis=1)  # [coils, timepoints, x, y]
    batchMask = tf.gather(MASK, TRIndex_Batch, axis=0)[None]  # [1, timepoints, x, y]

    # the gradient is requested only once, so a non-persistent tape is sufficient
    with tf.GradientTape() as tape:

        # compute the spatial basis functions
        U = modelU(input_spatial,trainingDrop=True)[0,:,:,:] # [x-pixels, y-pixels, 2*K]. Note: K is the subspace rank
        U = tf.cast(tf.complex(U[left_pad_x:left_pad_x+Nx,left_pad_y:left_pad_y+Ny,0:rankk],
                               U[left_pad_x:left_pad_x+Nx,left_pad_y:left_pad_y+Ny,rankk:]),tf.complex64) # [x-pixels, y-pixels, K]

        # compute the temporal basis functions
        V = modelV(input_temporal,trainingDrop=True)[0,:,:] # [frames, 2*K]
        V = tf.cast(tf.complex(V[left_pad:left_pad+Nex,0:rankk],V[left_pad:left_pad+Nex,rankk:]),tf.complex64)
        V = tf.transpose(V) # [K, frames]
        V = tf.gather(V,TRIndex_Batch,axis=1) # [K, batch]

        # multiply U*V to get the dynamic images
        y = tf.transpose(tf.matmul(U,V),[2,0,1]) # TRs, X, Y

        # multiply by coil sensitivities
        y = y[None,:,:,:] * coilmaps[:,None,:,:] # image size: [coils, TRs, x-pixels, y-pixels]

        # centered forward FFT (image to k-space), scaled by 1/sqrt(number of pixels)
        y = tf.signal.ifftshift(y,axes=[2,3])
        y = tf.signal.fft2d(y) / tf.sqrt(tf.cast(tf.size(y[0, 0]), tf.complex64))
        y = tf.signal.fftshift(y,axes=[2,3])

        # sum of squared errors in k-space, restricted to the sampled positions
        y = batchMask * y
        loss = tf.reduce_sum(abs(acquiredData-y)**2) * loss_scale

    # collect the weights after the forward pass, when both networks are certainly built
    trainable = modelU.trainable_variables + modelV.trainable_variables
    gradients = tape.gradient(loss, trainable)
    # clip with the configured `gradClip` hyperparameter
    gradients, _ = tf.clip_by_global_norm(gradients, gradClip)
    optimizer.apply_gradients(zip(gradients, trainable))

    # explicit cast: the result is a tensor, not a Python float
    return tf.cast(loss, tf.float32)


# It helps to scale the loss value to a reasonable range: run one step on the
# first minibatch with the default scale of 1 and choose `loss_scale` such that
# this loss would have been 100. Note that this step already updates the weights.
timerange = np.arange(Nex).astype('int')  # list of time frames
lstep = 0
t = timerange[lstep * batch:(lstep + 1) * batch]
timeIndex = tf.cast(tf.convert_to_tensor(t), tf.int32)
currLoss = trainDeepImagePrior(timeIndex)
loss_scale = 100. / currLoss

# DIP Training

In [None]:
startTime = time.time()
losses = []        # keep running list of the (epoch-averaged) loss function value

for e in range(numEpochs + 1):

    rn.shuffle(timerange) # shuffle the time frames (in place)
    runningLoss = 0

    optimizer.lr = learnRates[e]

    for lstep in range(nsteps): # Step through all image frames for training

        # Select time frames for the current minibatch
        t = timerange[lstep*batch:(lstep+1)*batch]
        timeIndex = tf.cast(tf.convert_to_tensor(t),tf.int32)

        # Compute the loss and update the DIP weights
        currLoss = trainDeepImagePrior(timeIndex, loss_scale)
        runningLoss += currLoss/nsteps # update the running loss

    losses.append(runningLoss)

    # update the running averages of the basis functions
    Uavg,Vavg = TestNetwork(Uavg,Vavg)

    is_last_epoch = (e == numEpochs)
    show_images = (e % showFreq == 0)
    save_results = (e % saveFreq == 0 and e>0) or is_last_epoch

    # The full image series is only needed for display and saving, so it is
    # only assembled in those epochs instead of in every epoch.
    if show_images or save_results:
        imNet = tf.matmul(Uavg,Vavg)  # Multiply U*V
        imNet *= maxValue  # scale the image back to the original scale

    # Display progress so far....
    elapsedTime = (time.time()-startTime) / 60.0 # elapsed time, in minutes

    if e % 5 == 0:
        print("Epoch {} / {}: Loss {:.6g}, {:.1f} minutes".format(e,numEpochs,losses[-1], elapsedTime))

    # Display images so far....
    if show_images:
        clear_output(wait=True)

        framesToShow = np.linspace(0,Nex-1,5).astype('int') # display images at a few time points while the recon is running

        fs = 5
        plt.figure(figsize=(fs*5,fs*1))
        for lt, t in enumerate(framesToShow):
            frame = abs(imNet[:,:,t])
            plt.subplot(1,len(framesToShow),lt+1)
            # the upper display limit is 85% of the maximum of the frame
            plt.imshow(frame,cmap='gray',vmin=0,vmax=0.85*np.max(frame))
            plt.axis('off')
        plt.show()

    # Save results periodically...
    if save_results:
        print("Saving...")
        # intermediate results carry the epoch in their file name, the final result does not
        suffix = '' if is_last_epoch else '_epoch_{:04d}'.format(e)
        imNet_npy = imNet.numpy().transpose(2, 0, 1)  # [frames, x, y]
        if is_last_epoch:
            # only the final result is cropped to the reconstruction size and stored as an array
            imNet_npy = mri.center_crop(imNet_npy, data.recon_size, (1, 2))
            np.save(output_path / 'cine_lrdip.npy', imNet_npy)
        plotting.save_gif(np.abs(imNet_npy), output_path / f'cine_lrdip{suffix}.gif', normalize=True,
                          equalize_histogram=True, duration=min(round(data.tres), 200))

        Uavg_np = Uavg.numpy()
        plotting.plot_multichannel(Uavg_np, channel_axis=2, columns=5, complex='abs', figheight_per_row=4,
                                   save_path=output_path / f'lrdip_U{suffix}', show=is_last_epoch, cmap='gray')

        Vavg_np = Vavg.numpy()
        plotting.plot_stacked(Vavg_np, channel_axis=0, save_path=output_path / f'lrdip_V{suffix}', show=is_last_epoch)

In [None]:
# Loss curve on a logarithmic axis; the stored values are converted to Python floats first
losses_tmp = list(map(float, losses))
plt.semilogy(losses_tmp, linewidth=0.6)
plt.ylim(5e-1, 1e2)  # (bottom, top)
plt.title('Loss')
plt.savefig(output_path / 'loss_lrdip.png')
plt.show()

# Only for phantom data: Quantitative evaluation

In [None]:
if isinstance(data, PhantomDataset):
    # magnitude images of the ground truth and of the reconstruction
    cine_gt = np.abs(data.ground_truth[slice_idx])
    cine_lrdip = np.abs(imNet_npy)

    # The annotations are looked up by the name of the per-file results folder
    # (the stem of the input file name).
    with open('mrxcat_annotations.yaml', 'r') as f:
        annotations = yaml.safe_load(f)[output_path.parent.name]
    bbox, center = annotations['bbox'], annotations['center']

    # metrics on the whole cine, inside the bounding box, and on the temporal profiles
    reconstruction = ('LR-DIP', cine_lrdip)
    metrics_cine = evaluate.get_metrics(cine_gt, reconstruction)
    metrics_roi = evaluate.get_metrics(cine_gt, reconstruction, bbox=bbox)
    metrics_profiles = evaluate.get_metrics(cine_gt, reconstruction, center=center)
    with pd.option_context('display.float_format', '{:.4f}'.format):
        print(
            'Cine:', metrics_cine,
            '\nROI:', metrics_roi,
            '\nTemporal Profiles:', metrics_profiles,
            sep='\n',
        )

    # save metrics to csv
    evaluate.update_metrics_csv(
        output_path / 'metrics.csv',
        ('cine', metrics_cine),
        ('roi', metrics_roi),
        ('profiles', metrics_profiles),
    )

    # ROI animation, temporal profiles and 10x error map of the reconstruction
    frame_duration = min(data.tres, 200)
    evaluate.save_cine_roi(cine_lrdip, bbox, output_path / 'ROI_cine_lrdip.gif', frame_duration)
    evaluate.save_temporal_profiles(cine_lrdip, center, output_path / 'profile_cine_lrdip.png')
    evaluate.save_error_map(cine_gt, cine_lrdip, output_path / 'error10x_lrdip.gif', min(data.tres, 200), scale=10)

# END