### Define Environment Variables and Hyper Parameters

In [110]:
import os
import time
import h5py
import torch
import numpy as np
from functions import transforms as T
from functions.subsample import MaskFunc
from torch import nn
from torch.nn import Conv2d, Sequential, InstanceNorm2d, ReLU, Dropout2d, Module, ModuleList, functional as F
from torch.utils.data import DataLoader
from torch.optim import RMSprop
from torch.optim.lr_scheduler import StepLR
from torchsummary import summary
from scipy.io import loadmat
from skimage.measure import compare_ssim 
from matplotlib import pyplot as plt

In [111]:
# Dataset locations.
# NOTE(review): val_data_path points at the same directory as train_data_path,
# so the "validation" set is not held out from the training data.
train_data_path = '/data/local/NC2019MRI/train'
val_data_path = '/data/local/NC2019MRI/train'
test_data_path = '/data/local/NC2019MRI/test'

# Undersampling mask settings (acceleration factor / centre fraction):
# for mask 4AF - acc = 4, cen = 0.08
# for mask 8AF - acc = 8, cen = 0.04
acc = 8
cen_fract = 0.04
seed = True # when True, get_epoch_batch derives the mask seed from the file name (so presumably one fixed mask per volume); False passes seed=None
num_workers = 12 # number of DataLoader worker processes; 0 loads the data in the main process

# Model parameters
in_chans = 1
out_chans = 1
chans = 8
# Kernel size read (as a global) by ConvBlock for both of its convolutions.
# NOTE(review): a kernel larger than (1, 1) only works if ConvBlock pads its
# convolutions; an unpadded Conv2d shrinks H/W, and the skip-connection
# torch.cat in UnetModel.forward would then receive mismatched spatial sizes.
kernel_size=(1, 1)

# Hyperparameters
epochs = 10
dropout_prob = 0.001
learning_rate = 0.001
weight_decay = 0.0
# StepLR period. NOTE(review): step_size (15) exceeds epochs (10), so the
# decay would never trigger within `epochs` epochs of per-epoch stepping.
step_size = 15
lr_gamma = 0.1 # multiplicative learning-rate decay factor passed to StepLR
# NOTE(review): the model-construction cell passes num_pool_layers=4 as a
# literal, so this value is not what the instantiated model uses.
num_pool_layers = 3

# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 

### Data Visualisation

In [112]:
def show_slices(data, slice_nums, cmap=None):
    """Display the selected entries of `data` side by side in one row.

    Args:
        data: Indexable collection of images; `data[num]` is handed to
            `plt.imshow` for every `num` in `slice_nums`.
        slice_nums: Sequence of indices into `data`, one subplot each.
        cmap: Optional colormap forwarded to `plt.imshow`.
    """
    plt.figure(figsize=(15,10))
    # Subplot positions are 1-based, hence start=1.
    for position, slice_idx in enumerate(slice_nums, start=1):
        plt.subplot(1, len(slice_nums), position)
        plt.imshow(data[slice_idx], cmap=cmap)
        plt.axis('off')

### Data Loading and Processing

In [113]:
def load_data_path(train_data_path, val_data_path, num_skip_slices=5):
    """List every usable (file name, file path, slice index) triple per subset.

    Each regular file directly inside a subset directory is opened as an HDF5
    file and the length of the first axis of its 'kspace' dataset is taken as
    the number of slices. Sub-directories are ignored.

    Args:
        train_data_path (str): Directory holding the training volumes.
        val_data_path (str): Directory holding the validation volumes.
        num_skip_slices (int): Number of leading slices dropped from every
            volume. Defaults to 5; the original author notes that the first
            five slices are mostly noise.

    Returns:
        dict: {'train': [...], 'val': [...]} where each list contains
        (fname, path, slice_index) tuples, ordered by sorted file name and
        then by ascending slice index.
    """
    data_list = {}
    subsets = (('train', train_data_path), ('val', val_data_path))

    for subset_name, subset_dir in subsets:
        entries = []

        for fname in sorted(os.listdir(subset_dir)):
            subject_data_path = os.path.join(subset_dir, fname)
            if not os.path.isfile(subject_data_path):
                continue

            # Only the dataset shape is needed, so the file is closed right away.
            with h5py.File(subject_data_path, 'r') as data:
                num_slice = data['kspace'].shape[0]

            entries.extend(
                (fname, subject_data_path, slice_idx)
                for slice_idx in range(num_skip_slices, num_slice)
            )

        data_list[subset_name] = entries

    return data_list

In [114]:
def get_epoch_batch(subject_id, acc, center_fract, use_seed):
    """Load one k-space slice, undersample it and return normalised tensors.

    Args:
        subject_id: (fname, path, slice_index) triple as produced by
            `load_data_path`.
        acc: Acceleration factor passed to `MaskFunc`.
        center_fract: Centre fraction passed to `MaskFunc`.
        use_seed: When truthy, the mask seed is the tuple of character codes
            of `fname`; otherwise `None` is passed as the seed.

    Returns:
        tuple: (img_gt, img_und, rawdata_und, masks, norm) where the first
        three are divided by `norm`, the maximum magnitude of the zero-filled
        image (floored at 1e-6), and every tensor has had the leading
        singleton axis removed again.
    """
    fname, h5_path, slice_idx = subject_id

    with h5py.File(h5_path, 'r') as h5_file:
        raw_slice = h5_file['kspace'][slice_idx]

    # A leading singleton axis is added so the tensor unpacks into four dims.
    full_kspace = T.to_tensor(raw_slice).unsqueeze(0)
    n_slices, n_rows, _, n_parts = full_kspace.shape

    # Build the sampling mask for this slice.
    kspace_shape = np.array(full_kspace.shape)
    sampler = MaskFunc(center_fractions=[center_fract], accelerations=[acc])
    mask_seed = tuple(map(ord, fname)) if use_seed else None
    mask = sampler(kspace_shape, mask_seed)

    # Zero out every k-space entry where the mask is 0.
    masked_kspace = torch.where(mask == 0, torch.Tensor([0]), full_kspace)
    masks = mask.repeat(n_slices, n_rows, 1, n_parts)

    img_gt = T.ifft2(full_kspace)
    img_und = T.ifft2(masked_kspace)

    # Normalise by the zero-filled reconstruction: at inference time there is
    # no ground-truth image to take the scale from.
    norm = T.complex_abs(img_und).max()
    if norm < 1e-6:
        norm = 1e-6

    img_gt = img_gt / norm
    img_und = img_und / norm
    rawdata_und = masked_kspace / norm

    return (
        img_gt.squeeze(0),
        img_und.squeeze(0),
        rawdata_und.squeeze(0),
        masks.squeeze(0),
        norm,
    )

In [115]:
class MRIDataset(torch.utils.data.Dataset):
    """Map-style dataset that yields one undersampled slice per index.

    This is the object handed to `DataLoader` as its dataset, so it derives
    from `torch.utils.data.Dataset`. Deriving from `DataLoader` (as before)
    left the loader base class uninitialised.

    Args:
        data_list: Sequence of (fname, path, slice_index) triples, as produced
            by `load_data_path`.
        acceleration: Acceleration factor forwarded to `get_epoch_batch`.
        center_fraction: Centre fraction forwarded to `get_epoch_batch`.
        use_seed: Forwarded to `get_epoch_batch` to select seeded masks.
    """

    def __init__(self, data_list, acceleration, center_fraction, use_seed):
        super().__init__()
        self.data_list = data_list
        self.acceleration = acceleration
        self.center_fraction = center_fraction
        self.use_seed = use_seed

    def __len__(self):
        """Return the number of (volume, slice) entries."""
        return len(self.data_list)

    def __getitem__(self, idx):
        """Return the `get_epoch_batch` result for entry `idx`."""
        subject_id = self.data_list[idx]
        return get_epoch_batch(subject_id, self.acceleration, self.center_fraction, self.use_seed)

## Model

- Unet: Neural networks with downsampling and upsampling. ref: https://github.com/facebookresearch/fastMRI/blob/master/models/unet/

In [116]:
class UnetModel(nn.Module):
    """
    U-Net built from `ConvBlock`s: a max-pooled encoder, a bottleneck block and
    a bilinearly up-sampled decoder with skip connections, followed by three
    1x1 convolutions.

    Reference:
        Olaf Ronneberger, Philipp Fischer, and Thomas Brox. U-net: Convolutional networks
        for biomedical image segmentation. In International Conference on Medical image
        computing and computer-assisted intervention, pages 234–241. Springer, 2015.
    """

    def __init__(self, in_chans, out_chans, chans, num_pool_layers, drop_prob):
        """
        Args:
            in_chans (int): Number of channels in the input to the U-Net model.
            out_chans (int): Number of channels in the output to the U-Net model.
            chans (int): Number of output channels of the first convolution block.
            num_pool_layers (int): Number of down-sampling and up-sampling layers.
            drop_prob (float): Dropout probability handed to every ConvBlock.
        """
        super().__init__()

        self.in_chans = in_chans
        self.out_chans = out_chans
        self.chans = chans
        self.num_pool_layers = num_pool_layers
        self.drop_prob = drop_prob

        # Encoder: the channel count doubles with every additional block.
        encoder_blocks = [ConvBlock(in_chans, chans, drop_prob)]
        width = chans
        for _ in range(num_pool_layers - 1):
            encoder_blocks.append(ConvBlock(width, width * 2, drop_prob))
            width *= 2
        self.down_sample_layers = nn.ModuleList(encoder_blocks)

        # Bottleneck keeps the channel count of the deepest encoder block.
        self.conv = ConvBlock(width, width, drop_prob)

        # Decoder: each block receives the up-sampled features concatenated
        # with the matching skip tensor, hence `width * 2` input channels.
        decoder_blocks = []
        for _ in range(num_pool_layers - 1):
            decoder_blocks.append(ConvBlock(width * 2, width // 2, drop_prob))
            width //= 2
        decoder_blocks.append(ConvBlock(width * 2, width, drop_prob))
        self.up_sample_layers = nn.ModuleList(decoder_blocks)

        # Output head: 1x1 convolutions down to `out_chans`.
        self.conv2 = nn.Sequential(
            nn.Conv2d(width, width // 2, kernel_size=1),
            nn.Conv2d(width // 2, out_chans, kernel_size=1),
            nn.Conv2d(out_chans, out_chans, kernel_size=1),
        )

    def forward(self, input):
        """
        Args:
            input (torch.Tensor): Input tensor of shape [batch_size, self.in_chans, height, width]
        Returns:
            (torch.Tensor): Output tensor of shape [batch_size, self.out_chans, height, width]
        """
        skips = []
        features = input

        # Encoder path: remember each block's output before pooling it.
        for encoder_block in self.down_sample_layers:
            features = encoder_block(features)
            skips.append(features)
            features = F.max_pool2d(features, kernel_size=2)

        features = self.conv(features)

        # Decoder path: up-sample, join with the most recent skip, convolve.
        for decoder_block in self.up_sample_layers:
            upsampled = F.interpolate(features, scale_factor=2, mode='bilinear', align_corners=False)
            merged = torch.cat([upsampled, skips.pop()], dim=1)
            features = decoder_block(merged)

        return self.conv2(features)

In [117]:
class ConvBlock(Module):
    """
    A Convolutional Block that consists of two convolution layers each followed by
    instance normalization, relu activation and dropout.

    Both convolutions are padded with half the kernel size, so for odd kernel
    sizes the block preserves the spatial size of its input. Without padding
    only a 1x1 kernel keeps height and width intact, which is what the U-Net
    skip connections require.
    """

    def __init__(self, in_chans, out_chans, drop_prob, conv_kernel_size=None):
        """
        Args:
            in_chans (int): Number of channels in the input.
            out_chans (int): Number of channels in the output.
            drop_prob (float): Dropout probability.
            conv_kernel_size (int or tuple, optional): Kernel size of both
                convolutions. When omitted, the notebook-level `kernel_size`
                setting is used, as before.
        """
        super().__init__()

        self.in_chans = in_chans
        self.out_chans = out_chans
        self.drop_prob = drop_prob

        if conv_kernel_size is None:
            # Fall back to the global configured in the settings cell.
            conv_kernel_size = kernel_size
        padding = self._same_padding(conv_kernel_size)

        self.layers = Sequential(
            Conv2d(in_chans, out_chans, kernel_size=conv_kernel_size, padding=padding),
            InstanceNorm2d(out_chans),
            ReLU(),
            Dropout2d(drop_prob),
            Conv2d(out_chans, out_chans, kernel_size=conv_kernel_size, padding=padding),
            InstanceNorm2d(out_chans),
            ReLU(),
            Dropout2d(drop_prob)
        )

    @staticmethod
    def _same_padding(size):
        """Return half the kernel size per dimension (0 for a 1x1 kernel)."""
        if isinstance(size, int):
            return size // 2
        return tuple(extent // 2 for extent in size)

    def forward(self, input):
        """
        Args: input (torch.Tensor): Input tensor of shape [batch_size, self.in_chans, height, width]
        Returns: (torch.Tensor): Output tensor of shape [batch_size, self.out_chans, height, width]
        """
        return self.layers(input)

    def __repr__(self):
        return f'ConvBlock(in_chans={self.in_chans}, out_chans={self.out_chans}, ' \
            f'drop_prob={self.drop_prob})'

## Main Methods

In [119]:
def training_epoch(epoch, model, data_loader, optimizer):
    """Train `model` for one pass over `data_loader` with an L1 loss.

    Relies on the notebook-level names `device`, `report_interval` and
    `epochs`, and on the `T` transform helpers.

    Args:
        epoch (int): Index of the current epoch; only used in progress output.
        model (torch.nn.Module): Network mapping the undersampled magnitude
            image to a reconstruction.
        data_loader: Iterable with a length, yielding
            (img_gt, img_und, rawdata_und, masks, norm) batches.
        optimizer (torch.optim.Optimizer): Optimizer over `model`'s parameters.

    Returns:
        tuple: (avg_loss, elapsed_seconds) where avg_loss is an exponential
        moving average of the per-batch loss (0.99 / 0.01 weighting, seeded
        with the first batch's loss).
    """
    model.train()
    avg_loss = 0
    start_epoch = start_iter = time.perf_counter()
    crop_size = [320, 320]

    for step, data_sample in enumerate(data_loader):
        # img ground truth, img undersampled, raw data undersampled, masks, norm
        img_gt, img_und, rawdata_und, masks, norm = data_sample

        # Magnitude image, cropped to the central region, then given a channel
        # axis right after the batch axis. `unsqueeze(1)` matches the previous
        # `[None, ...]` for batch_size 1 and stays (B, 1, H, W) for larger
        # batches (assumes complex_abs/center_crop keep the batch axis first).
        net_input = T.center_crop(T.complex_abs(img_und), crop_size)
        net_input = net_input.unsqueeze(1).to(device, dtype=torch.float)

        target = T.center_crop(T.complex_abs(img_gt), crop_size)
        target = target.unsqueeze(1).to(device, dtype=torch.float)

        output = model(net_input)
        loss = F.l1_loss(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        avg_loss = 0.99 * avg_loss + 0.01 * batch_loss if step > 0 else batch_loss

        if step % report_interval == 0:
            print('Epoch: ' + str(epoch) + "/" + str(epochs)  + " \n Iteration" + str(step/len(data_loader)) +
                  " \n Loss: " + str(batch_loss) + "Avg Loss: " + str(avg_loss) +
                  " \n Time: " + str(time.perf_counter() - start_iter)
            )
        # Restart the per-iteration timer so the report covers one step.
        start_iter = time.perf_counter()

    return avg_loss, time.perf_counter() - start_epoch
    

In [120]:
if __name__ == '__main__':

    # Reuse the dataset locations from the configuration cell instead of
    # repeating the literal paths here, so there is one place to change them.
    data_path_train = train_data_path
    data_path_val = val_data_path
    data_list = load_data_path(data_path_train, data_path_val) # first load all file names, paths and slices.
    train_data = data_list['train']
    val_data = data_list['val']

    # create data loader for training and validation sets
    train_dataset = MRIDataset(train_data, acceleration=acc, center_fraction=cen_fract, use_seed=seed)
    train_loader = DataLoader(train_dataset, shuffle=True, batch_size=1, num_workers=num_workers)

    val_dataset = MRIDataset(val_data, acceleration=acc, center_fraction=cen_fract, use_seed=seed)
    val_loader = DataLoader(val_dataset, shuffle=True, batch_size=1, num_workers=num_workers)

    # create model object
    # NOTE(review): num_pool_layers is passed as the literal 4 here, while the
    # configuration cell sets num_pool_layers = 3 — confirm which is intended.
    model = UnetModel(in_chans=in_chans, out_chans=out_chans, chans=chans, num_pool_layers=4, drop_prob=dropout_prob).to(device)
    # use RMSprop as optimizer
    optimizer = RMSprop(model.parameters(), learning_rate, weight_decay=weight_decay)

    print(model)

UnetModel(
  (down_sample_layers): ModuleList(
    (0): ConvBlock(in_chans=1, out_chans=8, drop_prob=0.001)
    (1): ConvBlock(in_chans=8, out_chans=16, drop_prob=0.001)
    (2): ConvBlock(in_chans=16, out_chans=32, drop_prob=0.001)
    (3): ConvBlock(in_chans=32, out_chans=64, drop_prob=0.001)
  )
  (conv): ConvBlock(in_chans=64, out_chans=64, drop_prob=0.001)
  (up_sample_layers): ModuleList(
    (0): ConvBlock(in_chans=128, out_chans=32, drop_prob=0.001)
    (1): ConvBlock(in_chans=64, out_chans=16, drop_prob=0.001)
    (2): ConvBlock(in_chans=32, out_chans=8, drop_prob=0.001)
    (3): ConvBlock(in_chans=16, out_chans=8, drop_prob=0.001)
  )
  (conv2): Sequential(
    (0): Conv2d(8, 4, kernel_size=(1, 1), stride=(1, 1))
    (1): Conv2d(4, 1, kernel_size=(1, 1), stride=(1, 1))
    (2): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))
  )
)


In [121]:
# Print a layer-by-layer summary of the model for a single 1-channel 320x320
# input. input_size=(channels, H, W); the batch dimension is given separately.
summary(model, input_size=(1, 320, 320), batch_size=1, device=str(device))  

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [1, 8, 320, 320]              16
    InstanceNorm2d-2           [1, 8, 320, 320]               0
              ReLU-3           [1, 8, 320, 320]               0
         Dropout2d-4           [1, 8, 320, 320]               0
            Conv2d-5           [1, 8, 320, 320]              72
    InstanceNorm2d-6           [1, 8, 320, 320]               0
              ReLU-7           [1, 8, 320, 320]               0
         Dropout2d-8           [1, 8, 320, 320]               0
         ConvBlock-9           [1, 8, 320, 320]               0
           Conv2d-10          [1, 16, 160, 160]             144
   InstanceNorm2d-11          [1, 16, 160, 160]               0
             ReLU-12          [1, 16, 160, 160]               0
        Dropout2d-13          [1, 16, 160, 160]               0
           Conv2d-14          [1, 16, 1

In [123]:
# Epochs
# Step-decay schedule: the learning rate is multiplied by lr_gamma every
# step_size scheduler steps (one scheduler step per epoch below).
scheduler = StepLR(optimizer, step_size, lr_gamma)
current_epoch = 0

# run model epochs
report_interval = 100
# NOTE(review): the upper bound is the literal 1, so a single epoch is run
# even though `epochs` is 10 — presumably a smoke test; use `epochs` for a
# full run.
for epoch in range(current_epoch, 1):
    print(epoch)
    train_loss, train_time = training_epoch(epoch, model, train_loader, optimizer)
    # Advance the schedule only after this epoch's optimizer steps, and
    # without the deprecated `epoch` argument. Starting from epoch 0 this
    # gives each epoch the same learning rate as stepping up-front did.
    scheduler.step()

    print("Epoch: " + str(epoch) + "/" + str(epochs) + "\n TrainLoss: " + str(train_loss) +
          "\n TrainTime: " + str(train_time))

0
Epoch: 0/10 
 Iteration0.0 
 Loss: 0.0928804948925972Avg Loss: 0.0928804948925972 
 Time: 0.4140533310128376
Epoch: 0/10 
 Iteration0.046860356138706656 
 Loss: 0.09897296875715256Avg Loss: 0.08218212777058857 
 Time: 0.025186471990309656
Epoch: 0/10 
 Iteration0.09372071227741331 
 Loss: 0.05659743398427963Avg Loss: 0.07842177511376608 
 Time: 0.025159848970361054
Epoch: 0/10 
 Iteration0.14058106841611998 
 Loss: 0.09576474875211716Avg Loss: 0.0784915899757033 
 Time: 0.02494194102473557
Epoch: 0/10 
 Iteration0.18744142455482662 
 Loss: 0.0495634451508522Avg Loss: 0.07728860304948956 
 Time: 0.02533127903006971
Epoch: 0/10 
 Iteration0.23430178069353327 
 Loss: 0.08508935570716858Avg Loss: 0.0790906653573928 
 Time: 0.025148739921860397
Epoch: 0/10 
 Iteration0.28116213683223995 
 Loss: 0.09547880291938782Avg Loss: 0.07702456993270267 
 Time: 0.02466635894961655
Epoch: 0/10 
 Iteration0.3280224929709466 
 Loss: 0.07885479182004929Avg Loss: 0.07732173751324974 
 Time: 0.02506514801

## Evaluation
We can evaluate SSIM over the whole volume, within the region of interest (the central 320x320 region), with respect to the ground truth. As can be seen, the more aggressive the undersampling, the lower the SSIM value.

In [28]:
def ssim(gt, pred):
    """Compute the Structural Similarity Index Metric (SSIM) of `pred` against `gt`.

    Both arrays have their first axis moved to the end and that axis is
    treated as the channel axis; the data range is the maximum of `gt`.

    NOTE(review): `compare_ssim` is, as far as I know, deprecated/removed in
    newer scikit-image releases in favour of
    `skimage.metrics.structural_similarity` — confirm the pinned version.
    """
    # (slices, H, W) -> (H, W, slices), presumably; the first axis goes last.
    gt_channels_last = gt.transpose(1, 2, 0)
    pred_channels_last = pred.transpose(1, 2, 0)
    return compare_ssim(
        gt_channels_last,
        pred_channels_last,
        multichannel=True,
        data_range=gt.max(),
    )