In [1]:
%cd ~/Workspace/fastMRI/

/home/chengjiun/Workspace/fastMRI


In [2]:

import logging
import pathlib
import random
import shutil
import time

import numpy as np
import torch
import torchvision
from tensorboardX import SummaryWriter
from torch.nn import functional as F
from torch.utils.data import DataLoader

from common.args import Args
from common.subsample import MaskFunc
from data import transforms
from data.mri_data import SliceData
from models.unet.unet_model import UnetModel

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class DataTransform:
    """
    Data Transformer for training U-Net models.
    """

    def __init__(self, mask_func, resolution, which_challenge, use_seed=True):
        """
        Args:
            mask_func (common.subsample.MaskFunc): A function that can create a mask of
                appropriate shape.
            resolution (int): Resolution of the image.
            which_challenge (str): Either "singlecoil" or "multicoil" denoting the dataset.
            use_seed (bool): If true, this class computes a pseudo random number generator seed
                from the filename. This ensures that the same mask is used for all the slices of
                a given volume every time.
        """
        if which_challenge not in ('singlecoil', 'multicoil'):
            raise ValueError(f'Challenge should either be "singlecoil" or "multicoil"')
        self.mask_func = mask_func
        self.resolution = resolution
        self.which_challenge = which_challenge
        self.use_seed = use_seed

    def __call__(self, kspace, target, attrs, fname, slice):
        """
        Args:
            kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
                data or (rows, cols, 2) for single coil data.
            target (numpy.array): Target image
            attrs (dict): Acquisition related information stored in the HDF5 object.
            fname (str): File name
            slice (int): Serial number of the slice.
        Returns:
            (tuple): tuple containing:
                image (torch.Tensor): Zero-filled input image.
                target (torch.Tensor): Target image converted to a torch Tensor.
                mean (float): Mean value used for normalization.
                std (float): Standard deviation value used for normalization.
                norm (float): L2 norm of the entire volume.
        """
        kspace = transforms.to_tensor(kspace)
        # Apply mask
        seed = None if not self.use_seed else tuple(map(ord, fname))
        masked_kspace, mask = transforms.apply_mask(kspace, self.mask_func, seed)
        # Inverse Fourier Transform to get zero filled solution
        image = transforms.ifft2(masked_kspace)
        # Crop input image
        image = transforms.complex_center_crop(image, (self.resolution, self.resolution))
        # Absolute value
        image = transforms.complex_abs(image)
        # Apply Root-Sum-of-Squares if multicoil data
        if self.which_challenge == 'multicoil':
            image = transforms.root_sum_of_squares(image)
        # Normalize input
        image, mean, std = transforms.normalize_instance(image, eps=1e-11)
        image = image.clamp(-6, 6)

        target = transforms.to_tensor(target)
        # Normalize target
        target = transforms.normalize(target, mean, std, eps=1e-11)
        target = target.clamp(-6, 6)
        return image, target, mean, std, attrs['norm'].astype(np.float32)


def create_datasets(args):
    train_mask = MaskFunc(args.center_fractions, args.accelerations)
    dev_mask = MaskFunc(args.center_fractions, args.accelerations)

    train_data = SliceData(
        root=args.data_path / f'{args.challenge}_train',
        transform=DataTransform(train_mask, args.resolution, args.challenge),
        sample_rate=args.sample_rate,
        challenge=args.challenge
    )
    dev_data = SliceData(
        root=args.data_path / f'{args.challenge}_val',
        transform=DataTransform(dev_mask, args.resolution, args.challenge, use_seed=True),
        sample_rate=args.sample_rate,
        challenge=args.challenge,
    )
    return dev_data, train_data


def create_data_loaders(args):
    dev_data, train_data = create_datasets(args)
    display_data = [dev_data[i] for i in range(0, len(dev_data), len(dev_data) // 16)]

    train_loader = DataLoader(
        dataset=train_data,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=8,
        pin_memory=True,
    )
    dev_loader = DataLoader(
        dataset=dev_data,
        batch_size=args.batch_size,
        num_workers=8,
        pin_memory=True,
    )
    display_loader = DataLoader(
        dataset=display_data,
        batch_size=16,
        num_workers=8,
        pin_memory=True,
    )
    return train_loader, dev_loader, display_loader


def train_epoch(args, epoch, model, data_loader, optimizer, writer):
    model.train()
    avg_loss = 0.
    start_epoch = start_iter = time.perf_counter()
    global_step = epoch * len(data_loader)
    for iter, data in enumerate(data_loader):
        input, target, mean, std, norm = data
        input = input.unsqueeze(1).to(args.device)
        target = target.to(args.device)

        output = model(input).squeeze(1)
        loss = F.l1_loss(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        avg_loss = 0.99 * avg_loss + 0.01 * loss.item() if iter > 0 else loss.item()
        writer.add_scalar('TrainLoss', loss.item(), global_step + iter)

        if iter % args.report_interval == 0:
            logging.info(
                f'Epoch = [{epoch:3d}/{args.num_epochs:3d}] '
                f'Iter = [{iter:4d}/{len(data_loader):4d}] '
                f'Loss = {loss.item():.4g} Avg Loss = {avg_loss:.4g} '
                f'Time = {time.perf_counter() - start_iter:.4f}s',
            )
        start_iter = time.perf_counter()
    return avg_loss, time.perf_counter() - start_epoch


def evaluate(args, epoch, model, data_loader, writer):
    model.eval()
    losses = []
    start = time.perf_counter()
    with torch.no_grad():
        for iter, data in enumerate(data_loader):
            input, target, mean, std, norm = data
            input = input.unsqueeze(1).to(args.device)
            target = target.to(args.device)
            output = model(input).squeeze(1)

            mean = mean.unsqueeze(1).unsqueeze(2).to(args.device)
            std = std.unsqueeze(1).unsqueeze(2).to(args.device)
            target = target * std + mean
            output = output * std + mean

            norm = norm.unsqueeze(1).unsqueeze(2).to(args.device)
            loss = F.mse_loss(output / norm, target / norm, size_average=False)
            losses.append(loss.item())
        writer.add_scalar('Dev_Loss', np.mean(losses), epoch)
    return np.mean(losses), time.perf_counter() - start


def visualize(args, epoch, model, data_loader, writer):
    def save_image(image, tag):
        image -= image.min()
        image /= image.max()
        grid = torchvision.utils.make_grid(image, nrow=4, pad_value=1)
        writer.add_image(tag, grid, epoch)

    model.eval()
    with torch.no_grad():
        for iter, data in enumerate(data_loader):
            input, target, mean, std, norm = data
            input = input.unsqueeze(1).to(args.device)
            target = target.unsqueeze(1).to(args.device)
            output = model(input)
            save_image(target, 'Target')
            save_image(output, 'Reconstruction')
            save_image(torch.abs(target - output), 'Error')
            break


def save_model(args, exp_dir, epoch, model, optimizer, best_dev_loss, is_new_best):
    torch.save(
        {
            'epoch': epoch,
            'args': args,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'best_dev_loss': best_dev_loss,
            'exp_dir': exp_dir
        },
        f=exp_dir / 'model.pt'
    )
    if is_new_best:
        shutil.copyfile(exp_dir / 'model.pt', exp_dir / 'best_model.pt')


def build_model(args):
    model = UnetModel(
        in_chans=1,
        out_chans=1,
        chans=args.num_chans,
        num_pool_layers=args.num_pools,
        drop_prob=args.drop_prob
    ).to(args.device)
    return model


def load_model(checkpoint_file):
    checkpoint = torch.load(checkpoint_file)
    args = checkpoint['args']
    model = build_model(args)
    if args.data_parallel:
        model = torch.nn.DataParallel(model)
    model.load_state_dict(checkpoint['model'])

    optimizer = build_optim(args, model.parameters())
    optimizer.load_state_dict(checkpoint['optimizer'])
    return checkpoint, model, optimizer


def build_optim(args, params):
    optimizer = torch.optim.RMSprop(params, args.lr, weight_decay=args.weight_decay)
    return optimizer


In [3]:
torch.__version__

'0.4.1.post2'

In [4]:
def seed_everything(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
seed_everything(42)

In [18]:
class args():
    num_pools = 4 # , help='Number of U-Net pooling layers')
    drop_prob = 0.0 # , help='Dropout probability')
    num_chans = 32 # , help='Number of U-Net channels')
    batch_size = 16 # , type=int, help='Mini batch size')
    num_epochs = 15 # , help='Number of training epochs')
    lr = 0.001# , help='Learning rate')
    lr_step_size = 40 # , help='Period of learning rate decay')
    lr_gamma = 0.1 # , help='Multiplicative factor of learning rate decay')
    weight_decay = 0. # , help='Strength of weight decay regularization')
    report_interval = 100 # , help='Period of loss reporting')
    data_parallel = True # , help='If set, use multiple GPUs using data parallelism')
    device = 'cuda' # , help='Which device to train on. Set to "cuda" to use the GPU')
    exp_dir = pathlib.Path('input/checkpoints') # , help='Path where model and results should be saved')
    resume = False # , help='If set, resume the training from a previous model checkpoint. ''"--checkpoint" should be set with this')
    checkpoint = '' # , type=str, help='Path to an existing checkpoint. Used along with "--resume"')
    
    
    seed = 42 #, type=int, help='Seed for random number generators')
    resolution=320# , type=int, help='Resolution of images'
    challenge='singlecoil' # , 'multicoil'], required=True, help='Which challenge')
    data_path = pathlib.Path('./input') # , required=True, help='Path to the dataset')
    sample_rate=1.# ,help='Fraction of total volumes to include')
    accelerations = [4, 8] #[4, 8], type=int,
    # help='Ratio of k-space columns to be sampled. If multiple values are '
    # provided, then one of those is chosen uniformly at random for each volume.')
    center_fractions = [0.08, 0.04] #, type=float,
    # help='Fraction of low-frequency k-space columns to be sampled. Should '
    #   'have the same length as accelerations')

    
   
    

# train model

In [16]:
args.exp_dir.mkdir(parents=True, exist_ok=True)
writer = SummaryWriter(log_dir=args.exp_dir / 'summary')

if args.resume:
    checkpoint, model, optimizer = load_model(args.checkpoint)
    args = checkpoint['args']
    best_dev_loss = checkpoint['best_dev_loss']
    start_epoch = checkpoint['epoch']
    del checkpoint
else:
    model = build_model(args)
    if args.data_parallel:
        model = torch.nn.DataParallel(model)
    optimizer = build_optim(args, model.parameters())
    best_dev_loss = 1e9
    start_epoch = 0
logging.info(args)
logging.info(model)

train_loader, dev_loader, display_loader = create_data_loaders(args)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_step_size, args.lr_gamma)

INFO:root:<class '__main__.args'>
INFO:root:DataParallel(
  (module): UnetModel(
    (down_sample_layers): ModuleList(
      (0): ConvBlock(in_chans=1, out_chans=32, drop_prob=0.0)
      (1): ConvBlock(in_chans=32, out_chans=64, drop_prob=0.0)
      (2): ConvBlock(in_chans=64, out_chans=128, drop_prob=0.0)
      (3): ConvBlock(in_chans=128, out_chans=256, drop_prob=0.0)
    )
    (conv): ConvBlock(in_chans=256, out_chans=256, drop_prob=0.0)
    (up_sample_layers): ModuleList(
      (0): ConvBlock(in_chans=512, out_chans=128, drop_prob=0.0)
      (1): ConvBlock(in_chans=256, out_chans=64, drop_prob=0.0)
      (2): ConvBlock(in_chans=128, out_chans=32, drop_prob=0.0)
      (3): ConvBlock(in_chans=64, out_chans=32, drop_prob=0.0)
    )
    (conv2): Sequential(
      (0): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1))
      (1): Conv2d(16, 1, kernel_size=(1, 1), stride=(1, 1))
      (2): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))
    )
  )
)


In [17]:
args.num_epochs

50

In [19]:
for epoch in range(start_epoch, args.num_epochs):
    scheduler.step(epoch)
    train_loss, train_time = train_epoch(args, epoch, model, train_loader, optimizer, writer)
    dev_loss, dev_time = evaluate(args, epoch, model, dev_loader, writer)
    visualize(args, epoch, model, display_loader, writer)

    is_new_best = dev_loss < best_dev_loss
    best_dev_loss = min(best_dev_loss, dev_loss)
    save_model(args, args.exp_dir, epoch, model, optimizer, best_dev_loss, is_new_best)
    logging.info(
        f'Epoch = [{epoch:4d}/{args.num_epochs:4d}] TrainLoss = {train_loss:.4g} '
        f'DevLoss = {dev_loss:.4g} TrainTime = {train_time:.4f}s DevTime = {dev_time:.4f}s',
    )
writer.close()

INFO:root:Epoch = [  0/ 15] Iter = [   0/2172] Loss = 1.12 Avg Loss = 1.12 Time = 8.2612s
INFO:root:Epoch = [  0/ 15] Iter = [ 100/2172] Loss = 0.297 Avg Loss = 0.6849 Time = 0.2725s
INFO:root:Epoch = [  0/ 15] Iter = [ 200/2172] Loss = 0.4155 Avg Loss = 0.4972 Time = 5.1924s
INFO:root:Epoch = [  0/ 15] Iter = [ 300/2172] Loss = 0.4079 Avg Loss = 0.4277 Time = 0.2727s
INFO:root:Epoch = [  0/ 15] Iter = [ 400/2172] Loss = 0.3599 Avg Loss = 0.3958 Time = 5.1033s
INFO:root:Epoch = [  0/ 15] Iter = [ 500/2172] Loss = 0.2685 Avg Loss = 0.3834 Time = 0.2735s
INFO:root:Epoch = [  0/ 15] Iter = [ 600/2172] Loss = 0.356 Avg Loss = 0.3719 Time = 5.0398s
INFO:root:Epoch = [  0/ 15] Iter = [ 700/2172] Loss = 0.2949 Avg Loss = 0.3623 Time = 0.2772s
INFO:root:Epoch = [  0/ 15] Iter = [ 800/2172] Loss = 0.436 Avg Loss = 0.3619 Time = 4.6569s
INFO:root:Epoch = [  0/ 15] Iter = [ 900/2172] Loss = 0.3895 Avg Loss = 0.3592 Time = 0.2739s
INFO:root:Epoch = [  0/ 15] Iter = [1000/2172] Loss = 0.3067 Avg Lo

INFO:root:Epoch = [  3/ 15] Iter = [1600/2172] Loss = 0.3294 Avg Loss = 0.3168 Time = 4.5278s
INFO:root:Epoch = [  3/ 15] Iter = [1700/2172] Loss = 0.351 Avg Loss = 0.3149 Time = 0.2784s
INFO:root:Epoch = [  3/ 15] Iter = [1800/2172] Loss = 0.2959 Avg Loss = 0.3155 Time = 4.5735s
INFO:root:Epoch = [  3/ 15] Iter = [1900/2172] Loss = 0.3089 Avg Loss = 0.3147 Time = 0.2761s
INFO:root:Epoch = [  3/ 15] Iter = [2000/2172] Loss = 0.26 Avg Loss = 0.3149 Time = 4.9215s
INFO:root:Epoch = [  3/ 15] Iter = [2100/2172] Loss = 0.2428 Avg Loss = 0.3146 Time = 0.2769s
INFO:root:Epoch = [   3/  15] TrainLoss = 0.3224 DevLoss = 0.0203 TrainTime = 1818.4685s DevTime = 164.3877s
INFO:root:Epoch = [  4/ 15] Iter = [   0/2172] Loss = 0.3544 Avg Loss = 0.3544 Time = 7.4274s
INFO:root:Epoch = [  4/ 15] Iter = [ 100/2172] Loss = 0.442 Avg Loss = 0.3326 Time = 0.2756s
INFO:root:Epoch = [  4/ 15] Iter = [ 200/2172] Loss = 0.2338 Avg Loss = 0.3188 Time = 4.0543s
INFO:root:Epoch = [  4/ 15] Iter = [ 300/2172] Lo

INFO:root:Epoch = [  7/ 15] Iter = [1100/2172] Loss = 0.3165 Avg Loss = 0.3132 Time = 0.2736s
INFO:root:Epoch = [  7/ 15] Iter = [1200/2172] Loss = 0.2641 Avg Loss = 0.314 Time = 0.2752s
INFO:root:Epoch = [  7/ 15] Iter = [1300/2172] Loss = 0.2566 Avg Loss = 0.3108 Time = 0.2743s
INFO:root:Epoch = [  7/ 15] Iter = [1400/2172] Loss = 0.3083 Avg Loss = 0.3109 Time = 0.2761s
INFO:root:Epoch = [  7/ 15] Iter = [1500/2172] Loss = 0.3138 Avg Loss = 0.3148 Time = 0.2756s
INFO:root:Epoch = [  7/ 15] Iter = [1600/2172] Loss = 0.2916 Avg Loss = 0.3052 Time = 0.2739s
INFO:root:Epoch = [  7/ 15] Iter = [1700/2172] Loss = 0.2625 Avg Loss = 0.308 Time = 0.2739s
INFO:root:Epoch = [  7/ 15] Iter = [1800/2172] Loss = 0.317 Avg Loss = 0.3061 Time = 0.8798s
INFO:root:Epoch = [  7/ 15] Iter = [1900/2172] Loss = 0.2699 Avg Loss = 0.2994 Time = 0.2785s
INFO:root:Epoch = [  7/ 15] Iter = [2000/2172] Loss = 0.2576 Avg Loss = 0.3075 Time = 1.5299s
INFO:root:Epoch = [  7/ 15] Iter = [2100/2172] Loss = 0.3192 Av

INFO:root:Epoch = [ 11/ 15] Iter = [ 600/2172] Loss = 0.2655 Avg Loss = 0.3092 Time = 3.2179s
INFO:root:Epoch = [ 11/ 15] Iter = [ 700/2172] Loss = 0.5342 Avg Loss = 0.3132 Time = 0.2789s
INFO:root:Epoch = [ 11/ 15] Iter = [ 800/2172] Loss = 0.3022 Avg Loss = 0.3138 Time = 4.9413s
INFO:root:Epoch = [ 11/ 15] Iter = [ 900/2172] Loss = 0.2289 Avg Loss = 0.31 Time = 0.2756s
INFO:root:Epoch = [ 11/ 15] Iter = [1000/2172] Loss = 0.2751 Avg Loss = 0.3105 Time = 4.6572s
INFO:root:Epoch = [ 11/ 15] Iter = [1100/2172] Loss = 0.2308 Avg Loss = 0.3067 Time = 0.2778s
INFO:root:Epoch = [ 11/ 15] Iter = [1200/2172] Loss = 0.3462 Avg Loss = 0.3084 Time = 5.1777s
INFO:root:Epoch = [ 11/ 15] Iter = [1300/2172] Loss = 0.3089 Avg Loss = 0.3037 Time = 0.2778s
INFO:root:Epoch = [ 11/ 15] Iter = [1400/2172] Loss = 0.2985 Avg Loss = 0.3044 Time = 4.7949s
INFO:root:Epoch = [ 11/ 15] Iter = [1500/2172] Loss = 0.2446 Avg Loss = 0.304 Time = 0.2768s
INFO:root:Epoch = [ 11/ 15] Iter = [1600/2172] Loss = 0.4377 Av

# evaluate validation model

In [23]:
args.mask_kspace = True # ', action='store_true', help='Whether to apply a mask (set to True for val data and False for test data')
args.data_split = 'val' # , choices=['val', 'test'], required=True, help='Which data partition to run on: "val" or "test"')
args.checkpoint = pathlib.Path('input/checkpoints/best_model.pt') # ', type=pathlib.Path, required=True, help='Path to the U-Net model')
args.out_dir = pathlib.Path('input/reconstruction_val/') # ', type=pathlib.Path, required=True, help='Path to save the reconstructions to')


In [24]:
from models.unet.run_unet import run_unet, save_reconstructions, load_model, create_data_loaders


In [25]:
data_loader = create_data_loaders(args)
model = load_model(args.checkpoint)
reconstructions = run_unet(args, model, data_loader)
save_reconstructions(reconstructions, args.out_dir)

# evaluate validation

In [28]:
from common.evaluate import evaluate

In [32]:

args.target_path = pathlib.Path('input/singlecoil_val/') # ', type=pathlib.Path, required=True, help='Path to the ground truth data')
args.predictions_path = pathlib.Path('input/reconstruction_val/') # ', type=pathlib.Path, required=True, help='Path to reconstructions')
args.acquisition = None# ', choices=['CORPD_FBK', 'CORPDFS_FBK'], default=None, help='If set, only volumes of the specified acquisition type are used for evaluation. By default, all volumes are included.')

recons_key = 'reconstruction_rss' if args.challenge == 'multicoil' else 'reconstruction_esc'
metrics = evaluate(args, recons_key)
print(metrics)

  cropped = ar[slices]


MSE = 1.488e-10 +/- 3.53e-10 NMSE = 0.04256 +/- 0.05487 PSNR = 30.67 +/- 5.924 SSIM = 0.6822 +/- 0.2848


# run on test set

In [33]:
args.mask_kspace = False # ', action='store_true', help='Whether to apply a mask (set to True for val data and False for test data')
args.data_split = 'test_v2' # , choices=['val', 'test'], required=True, help='Which data partition to run on: "val" or "test"')
args.checkpoint = pathlib.Path('input/checkpoints/best_model.pt') # ', type=pathlib.Path, required=True, help='Path to the U-Net model')
args.out_dir = pathlib.Path('input/reconstruction_test/') # ', type=pathlib.Path, required=True, help='Path to save the reconstructions to')


In [34]:
data_loader = create_data_loaders(args)
model = load_model(args.checkpoint)
reconstructions = run_unet(args, model, data_loader)
save_reconstructions(reconstructions, args.out_dir)