# Group 23: Fast MRI Notebook
Members: Sam Barrett, Bianca Bunaciu, Katie Potts, Aled, James Chamberlain, Vilmos Czeroczky

In [2]:
# imports

import torch.nn as nn
from torch.nn import functional as F
from layers.conv_layer import ConvLayer
import torch

# Define Model(s)
1. Super-Resolution Convolutional Neural Network (SRCNN), following the paper "Learning a Deep Convolutional Network for Image Super-Resolution" (Dong et al., ECCV 2014)
2. U-Net, based on the baseline model in the Facebook AI fastMRI paper ("fastMRI: An Open Dataset and Benchmarks for Accelerated MRI", Zbontar et al., 2018)

The first layer of the SRCNN (model 1) is formally expressed as:

$$F_1(\textbf{Y}) = \max(0,W_1 * \textbf{Y} + B_1)$$

where:

- $\textbf{Y}$ is our interpolated (undersampled) image
- $\textbf{X}$ is our ground truth image
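- $W_1$ denotes the filters (convolution weights)
    - of size $c\times f_1 \times f_1 \times n_1$
    - $c$ is the number of channels in the input image
    - $f_1$ is the spatial size of each filter 
    - $n_1$ is the number of filters
- $*$ denotes the convolution operation
- $B_1$ is the bias (one value per filter, i.e. an $n_1$-dimensional vector)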

- We want to recover from $\textbf{Y}$ an image $F(\textbf{Y}) \approx \textbf{X}$ 

In [4]:
class ConvNet(torch.nn.Module):
    """SRCNN-style super-resolution network: feature extraction -> non-linear mapping -> reconstruction.

    Layer 1 computes F_1(Y) = max(0, W_1 * Y + B_1) with ``n_1`` filters of spatial size
    ``f_1``. Layer 2 maps the ``n_1`` feature maps to ``n_2`` maps with ``f_2``-sized filters
    followed by a ReLU. Layer 3 reconstructs ``c_in`` channels with ``f_3``-sized filters.

    Every convolution uses ``padding = f // 2``, so for odd kernel sizes the output has the
    same spatial size as the input and can be compared pixel-for-pixel with the ground truth.

    All arguments have defaults, so ``ConvNet()`` keeps working.

    Args:
        c_in: number of image channels (used for both input and output).
        n_1: number of filters in the first layer.
        n_2: number of filters in the second layer.
        f_1: spatial kernel size of the first layer (should be odd).
        f_2: spatial kernel size of the second layer (should be odd).
        f_3: spatial kernel size of the third layer (should be odd).
        verbose: if True, ``forward`` prints the shape after each layer.
    """

    def __init__(self, c_in: int = 1, n_1: int = 64, n_2: int = 32,
                 f_1: int = 9, f_2: int = 1, f_3: int = 5, verbose: bool = False) -> None:
        super().__init__()
        self.verbose = verbose

        # W_1 * Y is ONE convolution with n_1 output channels (n_1 filters applied in
        # parallel), not n_1 single-channel convolutions chained in series.
        self.hidden1 = nn.Sequential(
            nn.Conv2d(in_channels=c_in, out_channels=n_1, kernel_size=f_1, padding=f_1 // 2, stride=1),
            nn.ReLU(),
        )

        # Non-linear mapping from n_1 to n_2 feature maps.
        self.hidden2 = nn.Sequential(
            nn.Conv2d(in_channels=n_1, out_channels=n_2, kernel_size=f_2, padding=f_2 // 2, stride=1),
            nn.ReLU(),
        )

        # Reconstruction back to the original number of image channels (no activation).
        self.hidden3 = nn.Sequential(
            nn.Conv2d(in_channels=n_2, out_channels=c_in, kernel_size=f_3, padding=f_3 // 2, stride=1),
        )

    def forward(self, x):
        """Map an interpolated (undersampled) image batch ``x`` to its reconstruction."""
        out = x
        for tag, layer in (("l1", self.hidden1), ("l2", self.hidden2), ("l3", self.hidden3)):
            out = layer(out)
            if self.verbose:
                print(f"{tag}: {out.shape}")
        return out


In [3]:
class UNet(nn.Module):
    """U-Net for image-to-image reconstruction, laid out like the fastMRI baseline U-Net.

    Encoder: ``n_pool_layers`` ConvLayers. The first maps ``c_in -> c`` and each following
    one doubles the channels, ending at ``c * 2**(n_pool_layers - 1)``. Every encoder output
    is kept as a skip connection and then max-pooled by 2.

    Bottleneck: one 3x3 conv / instance-norm / ReLU / dropout block.

    Decoder: ``n_pool_layers`` ConvLayers. Each upsamples, concatenates the matching skip
    connection along the channel axis, and reduces the channels, ending at ``c``.

    Head: three 1x1 convolutions mapping ``c -> c // 2 -> c_out -> c_out``.

    Args:
        c_in: number of input channels.
        c_out: number of output channels.
        c: channels produced by the first encoder layer.
        n_pool_layers: number of down-/up-sampling stages.
        drop_prob: dropout probability used by every ConvLayer and the bottleneck.
    """

    def __init__(self, c_in: int, c_out: int, c: int, n_pool_layers: int, drop_prob: float) -> None:
        super().__init__()
        self.c_in = c_in
        self.c_out = c_out
        self.c = c
        self.n_pool_layers = n_pool_layers
        self.drop_prob = drop_prob

        # Encoder: channel count doubles after the first layer.
        self.down_sample_layers = nn.ModuleList([ConvLayer(self.c_in, self.c, self.drop_prob)])
        channels = self.c
        for _ in range(self.n_pool_layers - 1):
            self.down_sample_layers += [ConvLayer(channels, channels * 2, self.drop_prob)]
            channels *= 2

        self.conv = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.InstanceNorm2d(channels),
            nn.ReLU(),
            nn.Dropout2d(self.drop_prob)
        )

        # Decoder: input is (upsampled features + skip connection), hence ``channels * 2``.
        self.up_sample_layers = nn.ModuleList()
        for _ in range(self.n_pool_layers - 1):
            self.up_sample_layers += [ConvLayer(channels * 2, channels // 2, self.drop_prob)]
            channels //= 2
        self.up_sample_layers += [ConvLayer(channels * 2, channels, self.drop_prob)]

        self.conv2 = nn.Sequential(
            nn.Conv2d(channels, channels // 2, kernel_size=1),
            nn.Conv2d(channels // 2, self.c_out, kernel_size=1),
            nn.Conv2d(self.c_out, self.c_out, kernel_size=1)
        )

    def forward(self, x):
        """Run the encoder / bottleneck / decoder on the image batch ``x``."""
        skips = []
        y = x
        for layer in self.down_sample_layers:
            y = layer(y)
            skips.append(y)
            y = F.max_pool2d(y, kernel_size=2)

        y = self.conv(y)
        for layer in self.up_sample_layers:
            skip = skips.pop()
            # Upsample to the skip connection's exact spatial size rather than by a fixed
            # factor of 2. Max-pooling floors odd sizes, so a fixed factor would make the
            # concatenation fail for inputs not divisible by 2**n_pool_layers. For even
            # sizes this is the same as scale_factor=2.
            y = F.interpolate(y, size=skip.shape[-2:], mode='bilinear', align_corners=False)
            y = torch.cat([y, skip], dim=1)
            y = layer(y)

        return self.conv2(y)


In [5]:
class ConvLayer(nn.Module):
    """Two identical stages of: 3x3 convolution -> instance norm -> ReLU -> 2-D dropout.

    The first stage maps ``c_in`` to ``c_out`` channels; the second keeps ``c_out``.
    A 3x3 kernel with padding 1 leaves the spatial size unchanged.

    Args:
        c_in: number of input channels.
        c_out: number of output channels.
        drop_probability: channel-dropout probability applied after each ReLU.
    """

    def __init__(self, c_in: int, c_out: int, drop_probability: float) -> None:
        super().__init__()

        self.c_in, self.c_out, self.drop_probability = c_in, c_out, drop_probability

        stages = []
        for stage_in in (c_in, c_out):
            stages += [
                nn.Conv2d(stage_in, c_out, kernel_size=3, padding=1),
                nn.InstanceNorm2d(c_out),
                nn.ReLU(),
                nn.Dropout2d(self.drop_probability),
            ]
        self.torch_modules = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both stages to ``x``."""
        result = self.torch_modules(x)
        return result

# Training

In [6]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
from torch import optim
from torch.autograd import Variable
from torch.utils.data import DataLoader

import pytorch_ssim
from UNET import UNet
from utils.data_loader import collate_batches, MRIDataset, load_data_path
import matplotlib.pyplot as plt

# ANSI escape sequences used to colour console output.
RED = '\033[0;31m'
GREEN = '\033[0;32m'
YELLOW = '\033[0;33m'
ENDC = '\033[0m'  # resets the terminal colour


def info(x):
    """Print ``x`` without colour."""
    print(x)


def warn(x):
    """Print ``x`` in yellow. Non-string messages are formatted with ``str``."""
    print(f"{YELLOW}{x}{ENDC}")


def success(x):
    """Print ``x`` in green. Non-string messages are formatted with ``str``."""
    print(f"{GREEN}{x}{ENDC}")


def error(x):
    """Print ``x`` in red. Non-string messages are formatted with ``str``."""
    print(f"{RED}{x}{ENDC}")


def advance_epoch(model, data_loader, optimizer, device=None, criterion=None):
    """Train ``model`` for one pass over ``data_loader`` and return the mean batch loss.

    The loss is the negative of ``criterion(output, ground_truth)`` (negative SSIM by
    default), so lower is better. This is the same sign as the value ``evaluate`` returns,
    so train and validation curves are directly comparable.

    Args:
        model: network to train; it is switched to training mode.
        data_loader: iterable of ``(img_gt, img_und, rawdata_und, masks, norm)`` batches.
            Only ``img_gt`` and ``img_und`` are used.
        optimizer: optimiser updating ``model``'s parameters.
        device: where to place each batch. Defaults to the device of the model's first
            parameter (CPU if the model has none).
        criterion: callable ``(output, target) -> similarity`` where higher is better.
            Defaults to ``pytorch_ssim.SSIM()``.

    Returns:
        Mean of the per-batch losses (NaN, with a numpy warning, for an empty loader).
    """
    model.train()

    if device is None:
        # Follow the model instead of hard-coding CUDA, so training also runs on CPU.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
    if criterion is None:
        # Build the criterion once per epoch rather than once per batch.
        criterion = pytorch_ssim.SSIM()

    losses = []
    for img_gt, img_und, _rawdata_und, _masks, _norm in data_loader:
        img_in = torch.as_tensor(img_und).float().to(device)
        ground_truth = torch.as_tensor(img_gt).float().to(device)

        output = model(img_in)
        # Maximise similarity by minimising its negative.
        loss = -criterion(output, ground_truth)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

    return np.mean(losses)


def evaluate(device, model, data_loader, criterion=None):
    """Return the mean loss of ``model`` over ``data_loader`` without updating it.

    The loss is the negative of ``criterion(output, target)`` (negative SSIM by default),
    so lower is better.

    Args:
        device: device the batches are moved to (should match the model's device).
        model: network to evaluate; it is switched to eval mode.
        data_loader: iterable of ``(img_gt, img_und, rawdata_und, masks, norm)`` batches.
            Only ``img_gt`` and ``img_und`` are used.
        criterion: callable ``(output, target) -> similarity`` where higher is better.
            Defaults to ``pytorch_ssim.ssim``.

    Returns:
        Mean of the per-batch losses (NaN, with a numpy warning, for an empty loader).
    """
    if criterion is None:
        criterion = pytorch_ssim.ssim

    model.eval()
    losses = []

    with torch.no_grad():
        for img_gt, img_und, _rawdata_und, _masks, _norm in data_loader:
            output = model(img_und.to(device))
            loss = -criterion(output, img_gt.to(device))
            losses.append(loss.item())
    return np.mean(losses)


# Training hyper-parameters used by main().
CENTRE_FRACTION = 0.08  # passed to MRIDataset as center_fraction
ACCELERATION = 4  # passed to MRIDataset as acceleration (presumably the undersampling factor — confirm in MRIDataset)
EPSILON = 0.0001  # NOTE: used as the Adam learning rate in main(), not as a numerical-stability epsilon
GAMMA = 0.1  # StepLR decay factor; currently unused (only referenced by a commented-out scheduler)
STEP_SIZE = 10  # StepLR step size; currently unused (only referenced by a commented-out scheduler)
BATCH_SIZE = 14  # NOTE(review): the logged run with this value hit CUDA out-of-memory on a ~4 GB GPU
NUMBER_EPOCHS = 30  # number of training epochs
NUMBER_POOL_LAYERS = 4  # U-Net depth, passed to UNet as n_pool_layers
DROP_PROB = 0  # dropout probability passed to UNet


def main():
    """Train a U-Net on the NC2019MRI data, save its weights and plot the loss curves.

    Each epoch calls ``advance_epoch`` on the training loader and ``evaluate`` on the
    validation loader. The final ``state_dict`` is written to ``./models/``, and the
    per-epoch train and validation losses are plotted.
    """
    import os

    warn("Data loading...")
    # Use the first training-data directory that exists on this machine. A bare
    # ``except`` here would also hide unrelated errors (including KeyboardInterrupt).
    candidate_paths = ['/data/local/NC2019MRI/train',
                       '/home/sam/datasets/FastMRI/NC2019MRI/train']
    data_path_train = next((p for p in candidate_paths if os.path.isdir(p)), candidate_paths[-1])
    data_list = load_data_path(data_path_train, data_path_train)

    data_list_train = data_list['train']
    data_list_val = data_list['val']

    seed = False  # random masks for each slice

    num_workers = 8
    # create data loader for training set. It applies same to validation set as well
    train_dataset = MRIDataset(data_list_train, acceleration=ACCELERATION, center_fraction=CENTRE_FRACTION,
                               use_seed=seed)
    train_loader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE, num_workers=num_workers,
                              collate_fn=collate_batches)

    val_dataset = MRIDataset(data_list_val, acceleration=ACCELERATION,
                             center_fraction=CENTRE_FRACTION, use_seed=seed)
    # Validation order does not matter, so there is no need to shuffle.
    val_loader = DataLoader(val_dataset, shuffle=False, batch_size=BATCH_SIZE, num_workers=num_workers,
                            collate_fn=collate_batches)

    success("Data loaded")

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    warn("Constructing model")
    model = UNet(1, 1, 32, NUMBER_POOL_LAYERS, DROP_PROB).to(device)
    success("Constructed model")

    # EPSILON is the learning rate.
    optimiser = optim.Adam(model.parameters(), lr=EPSILON)

    warn("Starting training")

    # scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimiser, step_size=STEP_SIZE, gamma=GAMMA)
    val_losses = []
    train_losses = []
    for epoch in range(NUMBER_EPOCHS):
        success(f"EPOCH: {epoch}")
        error("-" * 10)
        train_loss = advance_epoch(model, train_loader, optimiser)
        train_losses.append(train_loss)
        # scheduler.step(epoch)
        dev_loss = evaluate(device, model, val_loader)
        val_losses.append(dev_loss)
        info(
            f'Epoch = [{epoch:4d}/{NUMBER_EPOCHS:4d}] TrainLoss = {train_loss:.4g} '
            f'ValLoss = {dev_loss:.4g}',
        )

    # torch.save does not create missing directories; without this the trained
    # weights would be lost after the whole training run.
    os.makedirs("./models", exist_ok=True)
    torch.save(model.state_dict(),
               f"./models/UNET-B{BATCH_SIZE}e-{NUMBER_EPOCHS}-ssim-adam")

    epochs = range(NUMBER_EPOCHS)
    plt.plot(epochs, train_losses, label="train")
    plt.plot(epochs, val_losses, label="validation")
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.title("UNet training curves")
    plt.legend()
    plt.show()


if __name__ == "__main__":
    main()


[0;33mData loading...[0m
[0;32mData loaded[0m
[0;33mConstructing model[0m
[0;32mConstructed model[0m
[0;33mStarting training[0m
[0;32mEPOCH: 0[0m
[0;31m----------[0m


RuntimeError: CUDA out of memory. Tried to allocate 88.00 MiB (GPU 0; 3.95 GiB total capacity; 1.27 GiB already allocated; 56.00 MiB free; 2.53 MiB cached)

# Baseline Loss

In [None]:
# T is only imported inside evaluate() elsewhere, so import it here for this cell.
import fastMRI.functions.transforms as T

print("Data loading...")
# data_path_train = '/data/local/NC2019MRI/train'
data_path_train = '/home/sam/datasets/FastMRI/NC2019MRI/train'
data_list = load_data_path(data_path_train, data_path_train)

# Split the training list into train / validation subsets.
validation_split = 0.1
dataset_len = len(data_list['train'])
indices = list(range(dataset_len))

# Randomly splitting indices:
val_len = int(np.floor(validation_split * dataset_len))
validation_idx = np.random.choice(indices, size=val_len, replace=False)
train_idx = sorted(set(indices) - set(validation_idx))

# Both subsets index data_list['train'], the list the indices were drawn from.
data_list_train = [data_list['train'][i] for i in train_idx]
data_list_val = [data_list['train'][i] for i in validation_idx]

criterion = pytorch_ssim.SSIM()

acc = 8
cen_fract = 0.04
seed = False  # random masks for each slice
num_workers = 8
# create data loaders for the training and validation subsets
train_dataset = MRIDataset(data_list_train, acceleration=acc, center_fraction=cen_fract, use_seed=seed)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=10, num_workers=num_workers)

val_dataset = MRIDataset(data_list_val, acceleration=acc, center_fraction=cen_fract, use_seed=seed)
val_loader = DataLoader(val_dataset, shuffle=True, batch_size=10, num_workers=num_workers)

data_loaders = {"train": train_loader, "val": val_loader}
print("loaded data")


def _magnitude_crop(img):
    """Complex image batch -> centre-cropped 320x320 magnitude with a channel axis at dim 1.

    NOTE(review): assumes ``img`` is (N, H, W, 2) with real/imag in the last axis — confirm.
    """
    return T.center_crop(T.complex_abs(img).unsqueeze(0), [320, 320]).transpose(0, 1)


# Baseline: SSIM between the zero-filled undersampled image and the ground truth (no model).
# Despite the names, these lists hold per-batch SSIM values (higher is better).
train_loss = []
val_loss = []
with torch.no_grad():
    for phase, phase_scores in (("train", train_loss), ("val", val_loss)):
        for img_gt, img_und, rawdata_und, masks, norm in data_loaders[phase]:
            score = criterion(_magnitude_crop(img_und), _magnitude_crop(img_gt))
            phase_scores.append(score.item())

print(f"Train baseline SSIM: mean {np.mean(train_loss):.4f} over {len(train_loss)} batches")

print("\n\n")

print(f"Val baseline SSIM: mean {np.mean(val_loss):.4f} over {len(val_loss)} batches")

# Model Assessment

In [None]:
# Module providing the UNet class the checkpoint was saved from; it is referenced as
# UNET.UNet below but was never imported under that name.
import UNET

data_path_train = '/home/sam/datasets/FastMRI/NC2019MRI/train'
# data_path_train = '/data/local/NC2019MRI/train'

data_path_val = data_path_train
data_list = load_data_path(data_path_train, data_path_val)

acc = 4
cen_fract = 0.08
seed = False  # random masks for each slice
num_workers = 8

val_dataset = MRIDataset(data_list['val'], acceleration=acc, center_fraction=cen_fract, use_seed=seed)
val_loader = DataLoader(val_dataset, shuffle=True, batch_size=1, num_workers=num_workers, collate_fn=collate_batches)

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
optimiser = 'SGD'  # selects which saved checkpoint to load
model = UNET.UNet(1, 1, 32, 4, 0).to(device)
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
model.load_state_dict(torch.load(f"./vary-optim/models/UNET-lr0.0001-{optimiser}.pkl", map_location=device))
model.eval()

SHOW_THRESHOLD = 0.9  # only display reconstructions with SSIM above this
MAX_SHOWN = 4  # display at most this many slices

fig = plt.figure()
shown = 0
ssims = []
# Inference only: no autograd graphs, which also keeps GPU memory down.
with torch.no_grad():
    for img_gt, img_und, rawdata_und, masks, norm in val_loader:
        output = model(img_und.to(device))
        ssim = pytorch_ssim.ssim(output, img_gt.to(device))
        ssims.append(ssim.item())

        if ssim.item() > SHOW_THRESHOLD and shown < MAX_SHOWN:
            # from left to right: undersampled input, reconstruction, ground truth
            all_imgs = torch.stack([img_und.squeeze(), output.squeeze().cpu(), img_gt.squeeze()], dim=0)
            # NOTE(review): show_slices is not defined in this notebook's visible cells — confirm it is imported.
            show_slices(all_imgs, [0, 1, 2], cmap='gray')
            # plt.savefig(f"./vary-optim/reconstructions/{optimiser}-ssim/{ssim:.2f}-{optimiser}30.png")
            plt.pause(1)
            shown += 1

if ssims:
    print(f"Max ssim : {max(ssims)}")
    print(f"Average ssim is : {np.average(ssims)}")
    print(f"Variance in ssim is {np.var(ssims)}")
    print(f"Standard Deviation in ssim is {np.std(ssims)}")
else:
    print("No validation slices were evaluated.")