# MRI reconstruction
In this notebook we present our implementation of a neural network that reconstructs MRI images from undersampled k-space data (stored in h5 files), so that MRI scans can be acquired and processed faster.

In [4]:
import h5py, os
import pytorch_ssim
from functions import transforms as T
from functions.subsample import MaskFunc
from scipy.io import loadmat
from torch.utils.data import DataLoader
import numpy as np
import torch
from matplotlib import pyplot as plt
from torch.nn import functional as F
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader, random_split, Subset
from torchvision import transforms
from torchvision.utils import save_image
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # check whether a GPU is available
from skimage.measure import compare_ssim
import sys
def ssim(gt, pred):
    """Structural Similarity Index (SSIM) of `pred` against the ground truth `gt`.

    Both arrays are laid out (C, H, W); the channel axis is moved last so scikit-image
    treats it as the channel dimension. `data_range` is taken from the ground truth's
    maximum, so the argument order matters: ground truth first.
    NOTE(review): `compare_ssim` was removed from newer scikit-image releases in favour of
    `skimage.metrics.structural_similarity` — confirm the installed version.
    """
    gt_hwc = gt.transpose(1, 2, 0)
    pred_hwc = pred.transpose(1, 2, 0)
    return compare_ssim(gt_hwc, pred_hwc, multichannel=True, data_range=gt.max())


print(device)

cpu


#### Data Loader methods
Provided helper methods to import the data.

We also define the `batch_size` constant used for the data-loader batches.

In [5]:
batch_size = 8
current_mask = 8  # acceleration factor to train/evaluate for: 4 or 8
# File-name stem for the saved model parameters (PATH + ".pth") and the loss plot.
# "unet" matches the checkpoint name the notebook reloads later ('unet.pth').
PATH = "unet"

In [6]:
class MRIDataset(torch.utils.data.Dataset):
    """Map-style dataset returning one undersampled / fully-sampled slice pair per index.

    This is a `Dataset`, not a `DataLoader`: it only indexes samples, and batching,
    shuffling and worker processes are handled by the `DataLoader` that wraps it.

    Args:
        data_list: list of (fname, file_path, slice_index) tuples, as built by `load_data_path`.
        acceleration: undersampling acceleration factor passed to the mask function.
        center_fraction: fraction of central k-space columns the mask keeps.
        use_seed: if True, the mask is seeded from the file name; otherwise it is random.
    """

    def __init__(self, data_list, acceleration, center_fraction, use_seed):
        self.data_list = data_list
        self.acceleration = acceleration
        self.center_fraction = center_fraction
        self.use_seed = use_seed

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        # The h5 file is opened lazily, one slice per call, inside get_epoch_batch.
        subject_id = self.data_list[index]
        return get_epoch_batch(subject_id, self.acceleration, self.center_fraction, self.use_seed)

In [7]:
def get_epoch_batch(subject_id, acc, center_fract, use_seed=True, target_width=512, crop_size=320):
    """Load one k-space slice, undersample it and return normalised image/k-space data.

    Args:
        subject_id: (fname, file_path, slice_index) tuple, as produced by `load_data_path`.
        acc: acceleration factor of the undersampling mask.
        center_fract: fraction of central k-space columns the mask keeps.
        use_seed: if True, seed the mask from `fname`; otherwise draw a random mask.
        target_width: k-space width after zero-padding, so every slice has the same width
            and slices can be stacked into a batch.
        crop_size: side length of the centre crop applied to both images.

    Returns:
        img_gt: fully-sampled complex image, centre-cropped to (crop_size, crop_size, ...),
            divided by `norm`.
        img_und: zero-filled complex image, cropped and normalised the same way.
        rawdata_und: undersampled (padded) k-space divided by `norm`.
        masks: the sampling mask repeated to the k-space shape
            (assumes MaskFunc returns a (1, 1, W, 1) column mask — TODO confirm).
        norm: 0-dim tensor, max magnitude of the zero-filled image, clamped to >= 1e-6.
    """
    fname, rawdata_name, slice_idx = subject_id

    with h5py.File(rawdata_name, 'r') as data:
        rawdata = data['kspace'][slice_idx]

    slice_kspace = T.to_tensor(rawdata).unsqueeze(0)
    _, _, Nx, _ = slice_kspace.shape
    # Zero-pad the width to exactly `target_width` so the data set has a consistent size.
    # Any odd column goes to the right; an even split of an odd difference would leave the
    # slice one column short and make batch collation fail.
    pad_left = (target_width - Nx) // 2
    pad_right = target_width - Nx - pad_left
    pad = nn.ZeroPad2d((pad_left, pad_right, 0, 0))
    # ZeroPad2d pads the last dimension, so move the width there and back.
    slice_kspace = pad(slice_kspace.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
    S, Ny, Nx, ps = slice_kspace.shape

    # apply the (optionally seeded) undersampling mask
    shape = np.array(slice_kspace.shape)
    mask_func = MaskFunc(center_fractions=[center_fract], accelerations=[acc])
    seed = tuple(map(ord, fname)) if use_seed else None
    mask = mask_func(shape, seed)

    # undersample: zero every k-space column the mask drops
    masked_kspace = torch.where(mask == 0, torch.Tensor([0]), slice_kspace)
    masks = mask.repeat(S, Ny, 1, ps)

    img_gt, img_und = T.ifft2(slice_kspace), T.ifft2(masked_kspace)

    # Normalise by the zero-filled image only: at inference time there is no ground truth.
    # Clamp (instead of assigning a Python float) so `norm` is always a tensor and
    # DataLoader's default collate can stack it across the batch.
    norm = T.complex_abs(img_und).max().clamp(min=1e-6)

    img_gt, img_und, rawdata_und = img_gt / norm, img_und / norm, masked_kspace / norm
    img_gt = T.complex_center_crop(img_gt.squeeze(0), [crop_size, crop_size])
    img_und = T.complex_center_crop(img_und.squeeze(0), [crop_size, crop_size])

    return img_gt, img_und, rawdata_und.squeeze(0), masks.squeeze(0), norm

In [8]:
def load_data_path(train_data_path, val_data_path, acceleration=None):
    """List every (file name, file path, slice index) triple in the train and val folders.

    Args:
        train_data_path: directory containing the training .h5 volumes.
        val_data_path: directory containing the validation/test .h5 volumes.
        acceleration: used only for files without a fully-sampled 'kspace' dataset, to pick
            which undersampled dataset gives the slice count ('kspace_4af' when 4, otherwise
            'kspace_8af'). Defaults to the notebook-level `current_mask`.

    Returns:
        dict with keys 'train' and 'val', each a list of (fname, path, slice_index) tuples.
        The first 5 slices of every volume are skipped.
    """
    if acceleration is None:
        acceleration = current_mask

    data_list = {}
    for split, which_data_path in zip(['train', 'val'], [train_data_path, val_data_path]):
        print("dataset-loader: opening ... ", which_data_path)
        data_list[split] = []

        for fname in sorted(os.listdir(which_data_path)):
            subject_data_path = os.path.join(which_data_path, fname)  # one h5 file per subject
            # Skip sub-directories and stray non-HDF5 files (e.g. .DS_Store); otherwise
            # h5py.File raises and aborts the whole listing.
            if not os.path.isfile(subject_data_path) or not h5py.is_hdf5(subject_data_path):
                continue

            with h5py.File(subject_data_path, 'r') as data:
                if 'kspace' in data:
                    kspace_key = 'kspace'
                else:
                    kspace_key = 'kspace_4af' if acceleration == 4 else 'kspace_8af'
                num_slice = data[kspace_key].shape[0]

            # the first 5 slices are mostly noise, so it is better to exclude them
            data_list[split] += [(fname, subject_data_path, s) for s in range(5, num_slice)]

    return data_list

---

### Initialize the DataLoader

In [9]:

# Absolute dataset locations on the course machine — adjust for your environment.
data_path_train = '/data/local/NC2019MRI/train'
data_path_val = '/data/local/NC2019MRI/test'
data_list = load_data_path(data_path_train, data_path_val)

mask4 = {'acc': 4, 'cen_fract': 0.08}
mask8 = {'acc': 8, 'cen_fract': 0.04}

mask = mask4 if current_mask == 4 else mask8
acc = mask['acc']
cen_fract = mask['cen_fract']
seed = False  # random masks for each slice
num_workers = 12  # data loading is faster using a bigger number for num_workers. 0 means using one cpu to load data

# One dataset over the training folder, split 80/20 into train / validation.
dataset = MRIDataset(data_list['train'], acceleration=acc, center_fraction=cen_fract, use_seed=seed)
len_dataset = len(dataset)
indx = np.arange(len_dataset)
# A single split point keeps the subsets disjoint and covering every slice. Slicing with
# indx[-int(0.2 * n):] would instead select the WHOLE array when int(0.2 * n) == 0
# (a slice starting at -0), and rounding could drop a slice.
# NOTE: slices are ordered by file name, so one volume may straddle the split point.
split_at = int(len_dataset * 0.8)
train_indx = indx[:split_at]
val_indx = indx[split_at:]

train_dataset = Subset(dataset, train_indx)
val_dataset = Subset(dataset, val_indx)

train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, num_workers=num_workers)
# The validation loss is an average over all batches, so order does not matter.
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=batch_size, num_workers=num_workers)


dataset-loader: opening ...  /data/local/NC2019MRI/train


FileNotFoundError: [WinError 3] The system cannot find the path specified: '/data/local/NC2019MRI/train'

## Our model

Here we'll construct our model:

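The model is a U-Net-style encoder–decoder. A stack of strided `ConvBlock`s downsamples the two-channel (real/imaginary) zero-filled image while increasing the number of feature channels. A few residual blocks then process the lowest-resolution features. Bilinear upsampling with skip connections to the encoder outputs restores the full resolution. Finally, two 1×1 convolutions map the features back to two output channels.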


In [None]:
## HELPER nested model
class ConvBlock(nn.Module):
    """
    Convolutional block: a 5x5 convolution (with the given stride) followed by a 3x3
    convolution, each followed by instance normalization and a LeakyReLU activation.
    There is no dropout.
    """

    def __init__(self, in_chans, out_chans, stride=1):
        """
        Args:
            in_chans (int): Number of channels in the input.
            out_chans (int): Number of channels in the output.
            stride (int): Stride of the first (5x5) convolution. With padding 2, stride 1
                keeps the spatial size and stride 2 halves it (rounding up).
        """
        super().__init__()

        self.in_chans = in_chans
        self.out_chans = out_chans
        self.stride = stride

        self.layers = nn.Sequential(
            nn.Conv2d(in_chans, out_chans, kernel_size=5, padding=2, stride=stride, bias=True),
            nn.InstanceNorm2d(out_chans),
            nn.LeakyReLU(),

            nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, stride=1, bias=True),
            nn.InstanceNorm2d(out_chans),
            nn.LeakyReLU()
        )

    def forward(self, input):
        """Apply the block to a (batch, in_chans, H, W) tensor."""
        return self.layers(input)

In [None]:

class MRIModel(nn.Module):
    """
    PyTorch implementation of a U-Net model with residual blocks at the bottleneck.

    Input and output are channels-first images; with in_chans = out_chans = 2 the two
    channels hold the real and imaginary parts of a complex MR image.
    """
    def __init__(self, in_chans, out_chans, chans, num_pool_layers=4, num_depth_blocks=3):
        """
        Args:
            in_chans (int): Number of channels of the input tensor.
            out_chans (int): Number of channels of the output tensor.
            chans (int): Channels produced by the first block; doubled at every
                down-sampling step.
            num_pool_layers (int): Number of encoder blocks (the first keeps the resolution,
                the others halve it) and, symmetrically, of decoder blocks.
            num_depth_blocks (int): Bottleneck depth: one strided block followed by
                ``num_depth_blocks - 1`` residual blocks.
        """
        super().__init__()
        self.chans = chans
        self.in_chans = in_chans
        self.out_chans = out_chans
        self.num_pool_layers = num_pool_layers
        self.num_depth_blocks = num_depth_blocks

        # First block has no reduction in feature-map size; it maps in_chans
        # (2 for complex input) to `chans` channels.
        self.phase_head = ConvBlock(in_chans=in_chans, out_chans=chans, stride=1)
        self.down_sample_layers = nn.ModuleList([self.phase_head])

        # Encoder: every further block halves the spatial size and doubles the channels,
        # so deeper features cover larger image regions.
        ch = chans
        for _ in range(num_pool_layers - 1):
            self.down_sample_layers.append(ConvBlock(in_chans=ch, out_chans=ch * 2, stride=2))
            ch *= 2

        # Bottleneck: size reduction happens at the start of a block, hence the stride here;
        # the remaining blocks are applied residually at the lowest resolution.
        self.mid_conv = ConvBlock(in_chans=ch, out_chans=ch, stride=2)
        self.middle_layers = nn.ModuleList(
            [ConvBlock(in_chans=ch, out_chans=ch, stride=1) for _ in range(num_depth_blocks - 1)]
        )

        # Decoder: each block consumes the upsampled features concatenated with the encoder
        # output of the same resolution (hence ch * 2 inputs) and halves the channel count.
        self.up_sample_layers = nn.ModuleList()
        for _ in range(num_pool_layers - 1):
            self.up_sample_layers.append(ConvBlock(in_chans=ch * 2, out_chans=ch // 2, stride=1))
            ch //= 2
        # Last up-sampling block keeps `chans` channels instead of halving again.
        self.up_sample_layers.append(ConvBlock(in_chans=ch * 2, out_chans=ch, stride=1))
        assert chans == ch, 'Channel indexing error!'

        # Two 1x1 convolutions (with a ReLU between them) map the features to out_chans.
        self.final_layers = nn.Sequential(
            nn.Conv2d(in_channels=ch, out_channels=ch, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=ch, out_channels=out_chans, kernel_size=1)
        )

    def forward(self, tensor):
        """
        Args:
            tensor (torch.Tensor): Input tensor of shape [batch_size, self.in_chans, height, width]
        Returns:
            (torch.Tensor): Output tensor of shape [batch_size, self.out_chans, height, width]
        """
        stack = []
        output = tensor

        # Down-sampling; keep every encoder output for the skip connections.
        for layer in self.down_sample_layers:
            output = layer(output)
            stack.append(output)

        # Middle blocks
        output = self.mid_conv(output)
        for layer in self.middle_layers:
            output = output + layer(output)  # residual connection

        # Up-sampling
        for layer in self.up_sample_layers:
            skip = stack.pop()
            # Resize to the skip tensor's exact size rather than by a fixed factor of 2:
            # strided convs round odd sizes up, so scale_factor=2 would make torch.cat fail
            # unless H and W are divisible by 2 ** num_pool_layers. For divisible sizes
            # this is the same 2x bilinear upsampling.
            output = F.interpolate(output, size=skip.shape[-2:], mode='bilinear', align_corners=False)
            output = torch.cat([output, skip], dim=1)
            output = layer(output)

        return self.final_layers(output)

*General purpose function to construct a train step function*

In [None]:
def generate_train_step_call(model, loss_fn, optimiser):
    """Build a closure that performs one optimisation step of `model`.

    Args:
        model: network taking channels-first (B, C, H, W) input and returning the same layout.
        loss_fn: callable loss(prediction, target) returning a scalar tensor.
        optimiser: optimiser over `model`'s parameters.

    Returns:
        train_step(img_inputs, img_target) -> float, where both images are channels-last
        (B, H, W, C) tensors (e.g. complex images with C = 2 for real/imaginary parts).
    """
    def train_step(img_inputs, img_target):
        # Channels-last (B, H, W, 2) -> channels-first (B, 2, H, W): the real and imaginary
        # parts become two input channels. Use the argument, not a notebook global.
        inputs_perm = img_inputs.permute(0, 3, 1, 2)

        optimiser.zero_grad()
        output_raw_pred = model(inputs_perm)
        img_pred_complex = output_raw_pred.permute(0, 2, 3, 1)  # back to channels-last

        # PyTorch loss modules take (input, target).
        loss = loss_fn(img_pred_complex, img_target)
        loss.backward()   # compute gradients
        optimiser.step()  # update parameters
        return loss.item()

    return train_step

### Train the model

__Details of training:__
* The number of training epochs is set by `epoches` (currently 2).
* The loss function is the mean-squared error (`nn.MSELoss`) between the predicted and ground-truth complex images.
* Adam is used for the optimization step (`weight_decay` is currently 0).




**Using generate_train_step_call to create generic step caller using model, criterion and optimizer**


In [None]:
# init essential constants
epoches = 2  # number of training epochs
lr = 1e-4  # learning rate
weight_decay = 0

# create the main components to generate a train step
model = MRIModel(
    in_chans=2,   # real + imaginary parts
    out_chans=2,
    chans=32,
    num_pool_layers=4
).to(device)
criterion = nn.MSELoss(reduction='mean')
#loss_fn = pytorch_ssim.SSIM(window_size=11)
# optimiser = torch.optim.RMSprop(model.parameters(), lr, weight_decay=weight_decay)
optimiser = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

train_step = generate_train_step_call(model, criterion, optimiser)

# train_loader / val_loader are built in the "Initialize the DataLoader" cell above;
# this notebook defines no `build_data_loaders` function, so it is not called here.

### Iterate the epoches

In [None]:
print("Training started")
losses = []
val_losses = []
min_valid_loss = float('inf')
for epoch in range(epoches):
    # ---- training ----
    model.train()
    running_loss = 0.0
    for img_gt, img_und, rawdata_und, masks, norm in train_loader:
        # send to GPU
        img_und = img_und.to(device)
        img_gt = img_gt.to(device)
        running_loss += train_step(img_und, img_gt)

    # ---- validation, once per epoch ----
    model.eval()
    running_valid_loss = 0.0
    with torch.no_grad():
        for img_gt, img_und, rawdata_und, masks, norm in val_loader:
            img_und = img_und.to(device)
            img_gt = img_gt.to(device)

            inputs = img_und.permute(0, 3, 1, 2)  # same channels-first layout as in training
            output_raw_pred = model(inputs).permute(0, 2, 3, 1)
            # .item() accumulates a plain float rather than a (possibly GPU) tensor,
            # so the loss history can be printed and plotted directly.
            running_valid_loss += criterion(output_raw_pred, img_gt).item()

    # Log the epoch loss values
    avg_loss = running_loss / len(train_loader)
    avg_valid_loss = running_valid_loss / len(val_loader)
    losses.append(avg_loss)
    val_losses.append(avg_valid_loss)
    print('epoch [{}/{}], loss:{:.4f}, val_loss:{:.4f}'.format(epoch + 1, epoches, avg_loss, avg_valid_loss))

    # Checkpoint only when the validation loss improves.
    if avg_valid_loss < min_valid_loss:
        min_valid_loss = avg_valid_loss
        torch.save(model.state_dict(), PATH + ".pth")
        print("Best valid loss")
    print("------------")


### Printing the loss over time
The plot shows the average training and validation loss (MSE) for each epoch.

In [None]:
# plot the per-epoch average losses
epoch_axis = range(1, epoches + 1)
plt.plot(epoch_axis, [float(v) for v in losses], label="Training loss")
plt.plot(epoch_axis, [float(v) for v in val_losses], label="Validation loss")
# xlabel/ylabel are functions: call them. Assigning a string (plt.xlabel = ...) replaces
# the function for the rest of the kernel session and never labels the axis.
plt.xlabel("Epochs")
plt.ylabel("Loss")

plt.legend()
plt.savefig('loss' + PATH + '.png')

#### Save the results to output folder

In [None]:

# Reload the best checkpoint written by the training loop (PATH + ".pth"); map_location lets
# a GPU-saved checkpoint load on a CPU-only machine.
model.load_state_dict(torch.load(PATH + '.pth', map_location=device))
model.eval()


In [None]:
def show_slices(data, slice_nums, cmap=None):
    """Visualise `data[i]` for each i in `slice_nums`, side by side in one 15x10 figure with axes hidden."""
    plt.figure(figsize=(15, 10))
    n_panels = len(slice_nums)
    for panel, slice_idx in enumerate(slice_nums, start=1):
        plt.subplot(1, n_panels, panel)
        plt.imshow(data[slice_idx], cmap=cmap)
        plt.axis('off')

In [None]:
file_path = '/data/local/NC2019MRI/test/file1000817.h5'

# Read both undersampled k-space volumes and their masks, and report shape/dtype.
with h5py.File(file_path, "r") as hf:
    volume_kspace_4af = hf['kspace_4af'][()]
    volume_kspace_8af = hf['kspace_8af'][()]
    mask_4af = hf['mask_4af'][()]
    mask_8af = hf['mask_8af'][()]
    print(volume_kspace_4af.shape, volume_kspace_4af.dtype,
          mask_4af.shape, mask_4af.dtype, sep='\n')

# numpy k-space -> torch tensor -> complex image via inverse FFT -> 320x320 centre crop
volume_kspace2 = T.to_tensor(volume_kspace_8af)
volume_image = T.ifft2(volume_kspace2)
cropped_volume_image_8abs = T.complex_center_crop(volume_image.squeeze(0), [320, 320])
print(cropped_volume_image_8abs.shape)

In [None]:
# Pick one random validation slice; compare zero-filled input, prediction and ground truth.
val_loader_test = DataLoader(val_dataset, shuffle=True, batch_size=1, num_workers=num_workers)
sample = next(iter(val_loader_test))

img_gt, img_und, _, masks, _ = sample
img_gt = T.complex_abs(img_gt.to(device))
img_undersampled = T.complex_abs(img_und.to(device))

inputs = img_und.to(device).permute(0, 3, 1, 2)  # passing the same way like in training

# feed forward the inputs (inference only: no autograd graph needed)
model.eval()
with torch.no_grad():
    output_raw_pred = model(inputs)
img_pred = T.complex_abs(output_raw_pred.permute(0, 2, 3, 1))

allimgs = torch.stack([img_undersampled.squeeze(0).cpu(),
                       img_pred.squeeze(0).cpu(),
                       img_gt.squeeze(0).cpu()
                      ], dim=0)
show_slices(allimgs, [0, 1, 2], cmap='gray')

# ssim(gt, pred): the ground truth goes first because it sets data_range.
gt_np = img_gt.cpu().numpy()
ssimloss1 = ssim(gt_np, img_undersampled.cpu().numpy())
print("Random sample SSIM score (undersampled):", ssimloss1)
ssimloss2 = ssim(gt_np, img_pred.cpu().numpy())
print("Random sample SSIM score (predicted):", ssimloss2)

In [None]:
# Run the model on the last slice of the test volume loaded above. The test files have no
# ground truth, so only the zero-filled input and the prediction are shown.
test_slice = cropped_volume_image_8abs[-1]
# Normalise like the training inputs (divide by the zero-filled image's max magnitude).
# Training used the max of the uncropped image, so this is an approximation.
test_norm = T.complex_abs(test_slice).max().clamp(min=1e-6)
inputs = (test_slice / test_norm).unsqueeze(0).permute(0, 3, 1, 2)
img_undersampled = T.complex_abs(inputs.permute(0, 2, 3, 1))
print(inputs.shape)
print(img_undersampled.shape)

# feed forward the inputs
with torch.no_grad():
    output_raw_pred = model(inputs.to(device))
img_pred = T.complex_abs(output_raw_pred.permute(0, 2, 3, 1))

allimgs = torch.stack([img_undersampled.squeeze(0).cpu(),
                       img_pred.squeeze(0).cpu(),
                       ], dim=0)
show_slices(allimgs, [0, 1], cmap='gray')

In [None]:
def save_reconstructions(reconstructions, out_dir, key=None):
    """
    Saves the reconstructions from a model into h5 files that is appropriate for submission
    to the leaderboard.
    Args:
        reconstructions (dict[str, np.array]): A dictionary mapping input filenames to
            corresponding reconstructions (of shape num_slices x height x width).
        out_dir (str or pathlib.Path): Existing directory where the reconstructions
            should be saved (one h5 file per input filename).
        key (str, optional): Name of the h5 dataset to write. Defaults to
            'recon_{current_mask}af', so the name matches the acceleration being reconstructed.
    """
    if key is None:
        key = f'recon_{current_mask}af'
    for fname, recons in reconstructions.items():
        subject_path = os.path.join(out_dir, fname)
        print(subject_path)
        with h5py.File(subject_path, 'a') as f:
            # 'a' keeps existing datasets; replace a stale one so re-running the notebook
            # does not fail because the name already exists.
            if key in f:
                del f[key]
            f.create_dataset(key, data=recons)

In [None]:
def predict(cropped_img, norm=None):
    """Reconstruct one slice with the global `model` and return its magnitude image.

    Args:
        cropped_img: complex zero-filled image of one slice, channels-last (H, W, 2).
        norm: intensity scale to normalise the input by. Defaults to the max magnitude of
            `cropped_img`, mirroring the zero-filled normalisation used for the training data
            (which used the uncropped image, so this is an approximation).

    Returns:
        (1, H, W) magnitude tensor on `device`, rescaled by `norm` back to the input's
        original intensity range.
    NOTE(review): training k-space was zero-padded to 512 columns before the inverse FFT;
    callers that skip that padding feed differently scaled images — confirm.
    """
    if norm is None:
        norm = max(float(T.complex_abs(cropped_img).max()), 1e-6)
    inputs = (cropped_img / norm).unsqueeze(0).permute(0, 3, 1, 2)
    # feed forward the inputs (inference only)
    with torch.no_grad():
        output_raw_pred = model(inputs.to(device))
    img_pred = T.complex_abs(output_raw_pred.permute(0, 2, 3, 1))
    # The network predicts normalised images; undo the scaling for the saved result.
    return img_pred * norm

In [None]:
file_path = '/data/local/NC2019MRI/test/'
out_dir = '../saved/'  # where you want to save your result.
os.makedirs(out_dir, exist_ok=True)

# Reconstruct every test volume slice by slice at the acceleration selected by current_mask.
kspace_key = f'kspace_{current_mask}af'
for fname in sorted(os.listdir(file_path)):
    subject_path = os.path.join(file_path, fname)
    with h5py.File(subject_path, "r") as hf:
        print(f'file {fname} key is {list(hf.keys())}')
        volume_kspace = hf[kspace_key][()]

    volume_kspace2 = T.to_tensor(volume_kspace)  # Convert from numpy array to pytorch tensor
    volume_image = T.ifft2(volume_kspace2)       # Inverse Fourier Transform -> complex image per slice
    # No squeeze(0): it would drop the slice axis of a single-slice volume.
    cropped_volume = T.complex_center_crop(volume_image, [320, 320])

    # Collect per-slice predictions and concatenate once; torch.cat inside the loop is quadratic.
    preds = [predict(cropped_volume[i]).cpu().detach() for i in range(cropped_volume.shape[0])]
    result = torch.cat(preds, 0)

    save_reconstructions({fname: result.numpy()}, out_dir)

            
            
            
        

In [None]:

# evaluate: mean SSIM of the predicted magnitude images over the validation set
ssim_vals = []

model.eval()
with torch.no_grad():
    for img_gt, img_und, _, masks, _ in val_loader:
        # get_epoch_batch already centre-crops both images to 320x320.
        img_target = T.complex_abs(img_gt).numpy()                  # (B, H, W)
        inputs = img_und.to(device).permute(0, 3, 1, 2)             # complex input as 2 channels
        output = model(inputs).permute(0, 2, 3, 1)
        img_pred = T.complex_abs(output).cpu().numpy()              # (B, H, W)
        # Score each slice on its own so data_range comes from that slice's ground truth.
        for gt_slice, pred_slice in zip(img_target, img_pred):
            ssim_vals.append(ssim(gt_slice[None], pred_slice[None]))
print("Average SSIM score: ", np.average(ssim_vals))