The model:

In [31]:
from models.unet.unet_model import UnetModel

Building the model:

In [2]:
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

In [3]:
import numpy as np
from data import transforms

class DataTransform:
    """
    Data Transformer for training U-Net models.
    """

    def __init__(self, mask_func, resolution, which_challenge, use_seed=True):
        """
        Args:
            mask_func (common.subsample.MaskFunc): A function that can create a mask of
                appropriate shape.
            resolution (int): Resolution of the image.
            which_challenge (str): Either "singlecoil" or "multicoil" denoting the dataset.
            use_seed (bool): If true, this class computes a pseudo random number generator seed
                from the filename. This ensures that the same mask is used for all the slices of
                a given volume every time.
        """
        if which_challenge not in ('singlecoil', 'multicoil'):
            raise ValueError(f'Challenge should either be "singlecoil" or "multicoil"')
        self.mask_func = mask_func
        self.resolution = resolution
        self.which_challenge = which_challenge
        self.use_seed = use_seed

    def __call__(self, kspace, target, attrs, fname, slice):
        """
        Args:
            kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
                data or (rows, cols, 2) for single coil data.
            target (numpy.array): Target image
            attrs (dict): Acquisition related information stored in the HDF5 object.
            fname (str): File name
            slice (int): Serial number of the slice.
        Returns:
            (tuple): tuple containing:
                image (torch.Tensor): Zero-filled input image.
                target (torch.Tensor): Target image converted to a torch Tensor.
                mean (float): Mean value used for normalization.
                std (float): Standard deviation value used for normalization.
                norm (float): L2 norm of the entire volume.
        """
        target = transforms.to_tensor(target)
        kspace = transforms.to_tensor(kspace)
        # Apply mask
        seed = None if not self.use_seed else tuple(map(ord, fname))
        masked_kspace, mask = transforms.apply_mask(kspace, self.mask_func, seed)
        # Inverse Fourier Transform to get zero filled solution
        image = transforms.ifft2(masked_kspace)
        # Crop input image to given resolution if larger
        smallest_width = min(min(self.resolution, image.shape[-2]), target.shape[-1])
        smallest_height = min(min(self.resolution, image.shape[-3]), target.shape[-2])
        crop_size = (smallest_height, smallest_width)
        image = transforms.complex_center_crop(image, crop_size)
        target = transforms.center_crop(target, crop_size)

        # Absolute value
        image = transforms.complex_abs(image)
        # Apply Root-Sum-of-Squares if multicoil data
        if self.which_challenge == 'multicoil':
            image = transforms.root_sum_of_squares(image)
        # Normalize input
        image, mean, std = transforms.normalize_instance(image, eps=1e-11)
        image = image.clamp(-6, 6)

        # Normalize target
        target = transforms.normalize(target, mean, std, eps=1e-11)
        target = target.clamp(-6, 6)
        return image, target, mean, std, attrs['norm'].astype(np.float32)

In [4]:
from torch.utils.data import DataLoader
from common.subsample import create_mask_for_mask_type
from data.mri_data import SliceData

def create_datasets(data_path,
                    mask_type='random',
                    center_fractions=[0.08, 0.04],
                    accelerations=[4, 8],
                    resolution=320,
                    challenge='singlecoil',
                    sample_rate=1.):
    # mask_type: 'random' or 'equispaced'
    # center_fractions: Fraction of low-frequency k-space columns to be sampled.
    #                   Should have the same length as accelerations
    # accelerations: Ratio of k-space columns to be sampled.
    #                If multiple values are provided, then one of those is chosen
    #                uniformly at random for each volume.
    # resolution: Resolution of images
    # challenge: 'singlecoil' or 'multicoil'
    # sample_rate: Fraction of total volumes to include
    train_mask = create_mask_for_mask_type(mask_type, center_fractions, accelerations)
    dev_mask = create_mask_for_mask_type(mask_type, center_fractions, accelerations)

    train_data = SliceData(
        root=data_path / 'singlecoil_train',
        transform=DataTransform(train_mask, resolution, challenge),
        sample_rate=sample_rate,
        challenge=challenge
    )
    dev_data = SliceData(
        root=data_path / 'singlecoil_val',
        transform=DataTransform(dev_mask, resolution, challenge, use_seed=True),
        sample_rate=sample_rate,
        challenge=challenge,
    )
    return dev_data, train_data


def create_data_loaders(data_path, batch_size=16):
    dev_data, train_data = create_datasets(data_path)
    display_data = [dev_data[i] for i in range(0, len(dev_data), len(dev_data) // 16)]

    train_loader = DataLoader(
        dataset=train_data,
        batch_size=batch_size,
        shuffle=True,
        num_workers=8,
        pin_memory=True,
    )
    dev_loader = DataLoader(
        dataset=dev_data,
        batch_size=batch_size,
        num_workers=8,
        pin_memory=True,
    )
    display_loader = DataLoader(
        dataset=display_data,
        batch_size=16,
        num_workers=8,
        pin_memory=True,
    )
    return train_loader, dev_loader, display_loader

In [5]:
def build_model(device, num_channels=32, num_pools=4, dropout_prob=0.0):
    model = UnetModel(
        in_chans=1,
        out_chans=1,
        chans=num_channels,
        num_pool_layers=num_pools,
        drop_prob=dropout_prob
    ).to(device)
    return model

def build_optim(params, lr=0.001, weight_decay=0.):
    optimizer = torch.optim.RMSprop(params, lr, weight_decay=weight_decay)
    return optimizer

In [6]:
import time
import torchvision
import shutil



def train_epoch(epoch, num_epochs, model, device, data_loader, optimizer, writer, report_interval=100):
    model.train()
    avg_loss = 0.
    start_epoch = start_iter = time.perf_counter()
    global_step = epoch * len(data_loader)
    for iter, data in enumerate(data_loader):
        input, target, mean, std, norm = data
        input = input.unsqueeze(1).to(device)
        target = target.to(device)

        output = model(input).squeeze(1)
        loss = F.l1_loss(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        avg_loss = 0.99 * avg_loss + 0.01 * loss.item() if iter > 0 else loss.item()
        writer.add_scalar('TrainLoss', loss.item(), global_step + iter)

        if iter % report_interval == 0:
            logging.info(
                f'Epoch = [{epoch:3d}/{num_epochs:3d}] '
                f'Iter = [{iter:4d}/{len(data_loader):4d}] '
                f'Loss = {loss.item():.4g} Avg Loss = {avg_loss:.4g} '
                f'Time = {time.perf_counter() - start_iter:.4f}s',
            )
        start_iter = time.perf_counter()
    return avg_loss, time.perf_counter() - start_epoch


def evaluate(epoch, model, device, data_loader, writer):
    model.eval()
    losses = []
    start = time.perf_counter()
    with torch.no_grad():
        for iter, data in enumerate(data_loader):
            input, target, mean, std, norm = data
            input = input.unsqueeze(1).to(device)
            target = target.to(device)
            output = model(input).squeeze(1)

            mean = mean.unsqueeze(1).unsqueeze(2).to(device)
            std = std.unsqueeze(1).unsqueeze(2).to(device)
            target = target * std + mean
            output = output * std + mean

            norm = norm.unsqueeze(1).unsqueeze(2).to(device)
            loss = F.mse_loss(output / norm, target / norm, size_average=False)
            losses.append(loss.item())
        writer.add_scalar('Dev_Loss', np.mean(losses), epoch)
    return np.mean(losses), time.perf_counter() - start


def visualize(epoch, model, device, data_loader, writer):
    def save_image(image, tag):
        image -= image.min()
        image /= image.max()
        grid = torchvision.utils.make_grid(image, nrow=4, pad_value=1)
        writer.add_image(tag, grid, epoch)

    model.eval()
    with torch.no_grad():
        for iter, data in enumerate(data_loader):
            input, target, mean, std, norm = data
            input = input.unsqueeze(1).to(device)
            target = target.unsqueeze(1).to(device)
            output = model(input)
            save_image(target, 'Target')
            save_image(output, 'Reconstruction')
            save_image(torch.abs(target - output), 'Error')
            break

            
def save_model(exp_dir, epoch, model, optimizer, best_dev_loss, is_new_best):
    torch.save(
        {
            'epoch': epoch,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'best_dev_loss': best_dev_loss,
            'exp_dir': exp_dir
        },
        f=exp_dir / 'model.pt'
    )
    if is_new_best:
        shutil.copyfile(exp_dir / 'model.pt', exp_dir / 'best_model.pt')

In [8]:
! mkdir unet/summary
! tensorboard --logdir=unet/summary

mkdir: cannot create directory ‘unet/summary’: File exists
TensorFlow installation not found - running with reduced feature set.
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.1.1 at http://localhost:6006/ (Press CTRL+C to quit)
W0307 00:16:00.550040 139813137123072 plugin_event_accumulator.py:588] Detected out of order event.step likely caused by a TensorFlow restart. Purging 1000 expired tensor events from Tensorboard display between the previous step: 51381 (timestamp: 1583430061.852064) and current step: 0 (timestamp: 1583453025.966422).
W0307 00:16:00.572877 139813137123072 plugin_event_accumulator.py:588] Detected out of order event.step likely caused by a TensorFlow restart. Purging 174 expired tensor events from Tensorboard display between the previous step: 173 (timestamp: 1583453205.461077) and current step: 0 (timestamp: 1583453996.433022).
W0307 00:16:00.687378 139813137123072 plugin_event_accumulator.py:588] Detecte

In [7]:
from pathlib import Path
from torch.utils.tensorboard import SummaryWriter

def main(num_epochs=50):
    device = torch.device('cpu')
    if torch.cuda.is_available():
        device = torch.device('cuda')        
        
    outdir = Path('unet')
    outdir.mkdir(parents=True, exist_ok=True)
    log_dir = outdir / 'summary'
    writer = SummaryWriter(log_dir=log_dir)

    model = build_model(device)
    optimizer = build_optim(model.parameters())
    best_dev_loss = 1e9
    start_epoch = 0
    logging.info(model)
    logging.info(optimizer)

    train_loader, dev_loader, display_loader = create_data_loaders(Path('../data'))
    
    lr_step_size = 40
    lr_gamma = 0.1
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, lr_step_size, lr_gamma)

    for epoch in range(start_epoch, num_epochs):
        train_loss, train_time = train_epoch(epoch, num_epochs, model, device, train_loader, optimizer, writer)
        dev_loss, dev_time = evaluate(epoch, model, device, dev_loader, writer)
        visualize(epoch, model, device, display_loader, writer)

        is_new_best = dev_loss < best_dev_loss
        best_dev_loss = min(best_dev_loss, dev_loss)
        save_model(outdir, epoch, model, optimizer, best_dev_loss, is_new_best)
        logging.info(
            f'Epoch = [{epoch:4d}/{num_epochs:4d}] TrainLoss = {train_loss:.4g} '
            f'DevLoss = {dev_loss:.4g} TrainTime = {train_time:.4f}s DevTime = {dev_time:.4f}s',
        )
        scheduler.step(epoch)

    writer.close()


    
main(num_epochs=2)

INFO:root:UnetModel(
  (down_sample_layers): ModuleList(
    (0): ConvBlock(in_chans=1, out_chans=32, drop_prob=0.0)
    (1): ConvBlock(in_chans=32, out_chans=64, drop_prob=0.0)
    (2): ConvBlock(in_chans=64, out_chans=128, drop_prob=0.0)
    (3): ConvBlock(in_chans=128, out_chans=256, drop_prob=0.0)
  )
  (conv): ConvBlock(in_chans=256, out_chans=256, drop_prob=0.0)
  (up_sample_layers): ModuleList(
    (0): ConvBlock(in_chans=512, out_chans=128, drop_prob=0.0)
    (1): ConvBlock(in_chans=256, out_chans=64, drop_prob=0.0)
    (2): ConvBlock(in_chans=128, out_chans=32, drop_prob=0.0)
    (3): ConvBlock(in_chans=64, out_chans=32, drop_prob=0.0)
  )
  (conv2): Sequential(
    (0): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1))
    (1): Conv2d(16, 1, kernel_size=(1, 1), stride=(1, 1))
    (2): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))
  )
)
INFO:root:RMSprop (
Parameter Group 0
    alpha: 0.99
    centered: False
    eps: 1e-08
    lr: 0.001
    momentum: 0
    weight_decay: 0.0