In [1]:
import argparse
import os
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.tensorboard import SummaryWriter

import fastmri
from fastmri.data import transforms
from fastmri.models.unet import Unet
from fastmri.models.varnet import *

import sigpy as sp
from sigpy import from_pytorch
import sigpy.plot as pl
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (10,10)
import skimage.metrics

from dloader import genDataLoader

import matplotlib.pyplot as plt

In [2]:
# # command line argument parser
# parser = argparse.ArgumentParser(
#     description = 'define parameters and roots for STL training'
# )

# parser.add_argument(
#     '--scarcities', default=[0, 1], type=int, nargs='+',
#     help='number of samples will be decreased by 1/2^N; match with roots'
#     )

# parser.add_argument(
#     '--undersample', default=[6], type=int, nargs='+',
#     help='undersampling factor of k-space'
#     )

# parser.add_argument(
#     '--dataroots', nargs='+',
#     help='paths of data files; match with scarcities',
# #     required = True
# )

# parser.add_argument(
#     '--epochs', default=100, type=int,
#     help='number of epochs to run'
# )

# parser.add_argument(
#     '--lr', default=0.001, type=float,
#     help='learning rate'
# )

# parser.add_argument(
#     '--modelpath', default='models/best-no-name.pt',
#     help='path to save best model'
# )

# parser.add_argument(
#     '--verbose', default=True, type=bool,
#     help='''if true, prints to console and creatues full TensorBoard
#     (if tensorboard is also True)'''
# )

# parser.add_argument(
#     '--tensorboard', default=True, type=bool,
#     help='if true, creates TensorBoard'
# )

# parser.add_argument(
#     '--savefreq', default=5, type=int,
#     help='how many epochs per saved recon image'
# )

# parser.add_argument(
#     '--experimentname', default='unnamed_experiment',
#     help='experiment name'
# )


# opt = parser.parse_args()

In [3]:
# We can make one iteration block like this
class VarNetBlock(nn.Module):
    """
    One unrolled iteration of a variational network.

    Given the current k-space estimate ``k`` and measured k-space ``k0``,
    the block returns

        k - eta * where(mask, k - k0, 0) - expand(model(reduce(k)))

    i.e. a soft data-consistency step with learnable weight ``eta`` plus a
    learned regularization term evaluated on the coil-combined image.
    Stacking several of these blocks forms the full variational network.
    """

    def __init__(self, model: nn.Module):
        """
        Args:
            model: Module used as the learned "regularization" component;
                it is applied to the output of ``sens_reduce``.
        """
        super().__init__()

        self.model = model
        # learnable weight of the data-consistency term, initialised to 1
        self.eta = nn.Parameter(torch.ones(1))

    def sens_expand(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
        """Forward operator F*S: multiply by the sensitivity maps, then ``fastmri.fft2c``."""
        coil_images = fastmri.complex_mul(x, sens_maps)
        return fastmri.fft2c(coil_images)

    def sens_reduce(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
        """Adjoint S^H * F^H: ``fastmri.ifft2c``, multiply by conj(sens_maps), sum over dim 1 (presumably coils)."""
        coil_images = fastmri.ifft2c(x)
        weighted = fastmri.complex_mul(coil_images, fastmri.complex_conj(sens_maps))
        return weighted.sum(dim=1, keepdim=True)

    def forward(
        self,
        current_kspace: torch.Tensor,
        ref_kspace: torch.Tensor,
        mask: torch.Tensor,
        sens_maps: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            current_kspace: current k-space estimate.
            ref_kspace: measured k-space the estimate is pulled toward.
            mask: sampling mask; non-zero entries select where data
                consistency is enforced.
            sens_maps: coil sensitivity maps.

        Returns:
            Updated k-space estimate.
        """
        sampled = mask.bool()
        # 5-D zero so torch.where broadcasts against the k-space tensors
        zero = torch.zeros(1, 1, 1, 1, 1).to(current_kspace)
        residual = torch.where(sampled, current_kspace - ref_kspace, zero)
        soft_dc = residual * self.eta

        combined_image = self.sens_reduce(current_kspace, sens_maps)
        model_term = self.sens_expand(self.model(combined_image), sens_maps)

        return current_kspace - soft_dc - model_term

In [4]:
# now we can stack VarNetBlocks to make an unrolled VarNet (12 cascades by default)


class VarNet(nn.Module):
    """
    Full unrolled variational network.

    Chains ``num_cascades`` VarNetBlocks, each regularized by a NormUnet, and
    returns the final k-space estimate together with its coil-combined
    image. To use a regularizer other than a U-Net, build the cascade from
    VarNetBlock directly.
    """

    def __init__(
        self,
        num_cascades: int = 12,
        chans: int = 18,
        pools: int = 4,
    ):
        """
        Args:
            num_cascades: number of unrolled VarNetBlocks.
            chans: first positional argument of each NormUnet
                (presumably its base channel count).
            pools: second positional argument of each NormUnet
                (presumably its number of pooling levels).
        """
        super().__init__()

        blocks = []
        for _ in range(num_cascades):
            blocks.append(VarNetBlock(NormUnet(chans, pools)))
        self.cascades = nn.ModuleList(blocks)

    def forward(
        self,
        masked_kspace: torch.Tensor,
        mask: torch.Tensor,
        sens_maps: torch.Tensor
    ) -> torch.Tensor:
        """
        Returns:
            Tuple ``(kspace_pred, im_comb)``: the k-space after all cascades
            and its image obtained via ``fastmri.ifft2c``, multiplication by
            conj(sens_maps) and a sum over dim 1.
        """
        estimate = masked_kspace.clone()
        for block in self.cascades:
            # every cascade is anchored to the same measured k-space
            estimate = block(estimate, masked_kspace, mask, sens_maps)

        coil_images = fastmri.ifft2c(estimate)
        weighted = fastmri.complex_mul(coil_images, fastmri.complex_conj(sens_maps))
        im_comb = weighted.sum(dim=1, keepdim=True)

        return estimate, im_comb

In [5]:
def count_parameters(model):
    """Return the total number of scalar entries in ``model``'s trainable (``requires_grad``) parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total



def test_result(im_fs: torch.Tensor, im_us: torch.Tensor) -> np.ndarray:
    """
    Build a 2x2 mosaic for visually comparing a reconstruction to its reference.

    Both inputs are converted with ``from_pytorch(..., iscomplex=True)``,
    reduced to magnitude, squeezed and resized to 360x320 with ``sp.resize``.
    Before the final flip the mosaic is laid out as

        row 0: 5 * |[im_us, im_us] - [im_us, im_fs]|   (difference panels)
        row 1: [im_us, im_fs]                          (images)

    The whole mosaic is multiplied by 20 and then flipped along every axis
    (``np.flip`` with no axis argument).

    NOTE(review): the left difference panel is |im_us - im_us| and is
    therefore always zero; the only real error map is 5*|im_us - im_fs|
    under the ``im_fs`` column. Confirm this layout is intended.

    Args:
        im_fs: reference image tensor (real/imag in the last dim, as
            ``iscomplex=True`` expects).
        im_us: reconstructed image tensor, same layout.

    Returns:
        The scaled, flipped mosaic as an ndarray.
    """
    with torch.no_grad():
        recon, reference = (
            np.abs(from_pytorch(t.cpu().detach(), iscomplex=True)).squeeze()
            for t in (im_us, im_fs)
        )

        target_shape = [360, 320]
        recon = sp.resize(recon, target_shape)
        reference = sp.resize(reference, target_shape)

        images = np.concatenate((recon, reference), 1)
        baseline = np.concatenate((recon, recon), 1)
        differences = np.abs(baseline - images) * 5

        mosaic = np.concatenate((differences, images), axis=0) * 20

    return np.flip(mosaic)


def plot_quadrant(im_fs, im_us):
    """
    Render ``test_result(im_fs, im_us)`` as a grayscale matplotlib figure.

    The figure is closed before it is returned, so it is not displayed
    inline; the trainer passes it to ``SummaryWriter.add_figure``.
    """
    mosaic = test_result(im_fs, im_us)
    fig, ax = plt.subplots()
    ax.imshow(mosaic, cmap = 'gray')
    plt.close(fig)
    return fig


def param_dict(lr, epochs, undersampling, scarcities, center_fractions):
    """
    Flatten run settings into a flat dict for ``SummaryWriter.add_hparams``.

    List-valued settings are expanded into one scalar entry per element,
    e.g. ``undersampling=[4, 6]`` -> ``{'acceleration_0': 4, 'acceleration_1': 6}``.

    Args:
        lr: learning rate.
        epochs: number of training epochs.
        undersampling: sequence of k-space acceleration factors.
        scarcities: sequence of data-scarcity levels.
        center_fractions: sequence of fully-sampled center fractions.

    Returns:
        dict with keys ``lr``, ``epochs``, ``acceleration_i``,
        ``scarcity_i`` and ``center_fraction_i`` (in that order).
    """
    params = {'lr': lr, 'epochs': epochs}
    for prefix, values in (
        # key was previously misspelled 'accerlation_i'; runs logged before
        # this fix show the old name in TensorBoard's HParams view
        ('acceleration', undersampling),
        ('scarcity', scarcities),
        ('center_fraction', center_fractions),
    ):
        for i, value in enumerate(values):
            params[f'{prefix}_{i}'] = value
    return params


def write_tensorboard(avg_cost, model, total_epochs, opt, tb_writer=None, hparams=None):
    """
    Write per-epoch training/validation curves from ``avg_cost`` to TensorBoard.

    Args:
        avg_cost: dict mapping each contrast name, plus ``'overall'``, to an
            array of shape (total_epochs, 8). Columns 0-3 hold train
            [l1, ssim, psnr, nrmse] and columns 4-7 the val counterparts.
        model: model whose trainable-parameter count is logged as text.
        total_epochs: number of epoch rows to log.
        opt: unused; kept for call compatibility.
        tb_writer: SummaryWriter to log to. Defaults to the module-level
            ``writer`` so existing callers keep working.
        hparams: dict for ``add_hparams``. Defaults to the previously
            hard-coded settings, but with ``epochs`` taken from
            ``total_epochs`` instead of a fixed 2.
    """
    if tb_writer is None:
        tb_writer = writer  # module-level SummaryWriter (legacy behaviour)
    if hparams is None:
        # TODO: derive lr / acceleration / scarcity / center fraction from the run config
        hparams = param_dict(0.001, total_epochs, [6], [0], [0.06])

    # (tag, train column); the matching val column is always +4
    per_contrast_tags = (
        ('losses/l1', 0),
        ('metrics/ssim', 1),
        ('metrics/psnr', 2),
        ('metrics/nrmse', 3),
    )
    overall_tags = (
        ('overall/metrics/ssim', 1),
        ('overall/metrics/psnr', 2),
        ('overall/metrics/nrmse', 3),
        ('overall/losses/l1', 0),
    )
    # any number of contrasts, not just exactly two
    contrasts = [key for key in avg_cost if key != 'overall']

    for epoch in range(total_epochs):
        for tag, col in per_contrast_tags:
            scalars = {}
            for contrast in contrasts:
                scalars[f'train/{contrast}'] = avg_cost[contrast][epoch, col]
                scalars[f'val/{contrast}'] = avg_cost[contrast][epoch, col + 4]
            tb_writer.add_scalars(tag, scalars, epoch)

        for tag, col in overall_tags:
            tb_writer.add_scalars(
                tag, {
                    'train': avg_cost['overall'][epoch, col],
                    'val': avg_cost['overall'][epoch, col + 4],
                },
                epoch
            )

    tb_writer.add_text(
        'parameters',
        f'{count_parameters(model)} parameters'
    )

    # summarise the run by its best (lowest) overall val L1 rather than a 0 placeholder
    if total_epochs > 0:
        best_val_l1 = float(np.min(avg_cost['overall'][:total_epochs, 4]))
    else:
        best_val_l1 = 0.0
    tb_writer.add_hparams(
        hparams,
        {'overall/losses/l1': best_val_l1}
    )

In [6]:
def criterion(im_fs: torch.Tensor, im_us: torch.Tensor):
    '''
    Training loss between the reference and the network's reconstruction.

    @parameter im_fs: fully sampled (reference) image
    @parameter im_us: reconstructed image returned by the model
    Both tensors must be on the same device; the loss is computed there.

    Currently the mean absolute error (L1); additional loss terms can be
    added here later.
    '''
    return F.l1_loss(im_us, im_fs)

def metrics(im_fs: torch.Tensor, im_us: torch.Tensor):
    '''
    Compute image-quality metrics of a reconstruction against its reference.

    @parameter im_fs: fully sampled (reference) image tensor
    @parameter im_us: reconstructed image tensor, same layout as im_fs
    The tensors may live on any device; they are detached and copied to the
    CPU, and all metrics are computed with numpy/scikit-image.

    @returns (ssim, psnr, nrmse), computed on the magnitude images:
      - psnr uses the reference's dynamic range (max - min) as data_range
      - nrmse uses skimage's default normalization
      - ssim uses data_range=1 after min-max scaling each image to [0, 1];
        a constant image is only shifted to 0 instead of being divided by
        zero, so it no longer yields NaN
    '''
    def _to_magnitude_2d(im):
        # to complex ndarray (presumably real/imag in the last tensor dim),
        # then magnitude; the reshape requires leading dims of size 1,
        # i.e. a (1, 1, H, W) array
        mag = np.absolute(transforms.tensor_to_complex_np(im.cpu().detach()))
        return mag.reshape((mag.shape[2], mag.shape[3]))

    def _min_max_normalize(im):
        shifted = im - np.min(im)
        peak = np.max(shifted)
        # guard: a constant image (e.g. an all-zero prediction) would give 0/0
        return shifted / peak if peak > 0 else shifted

    im_us = _to_magnitude_2d(im_us)
    im_fs = _to_magnitude_2d(im_fs)

    psnr = skimage.metrics.peak_signal_noise_ratio(
        im_fs,
        im_us,
        data_range = np.max(im_fs) - np.min(im_fs)
    )

    nrmse = skimage.metrics.normalized_root_mse(im_fs, im_us)

    ssim = skimage.metrics.structural_similarity(
        _min_max_normalize(im_fs), _min_max_normalize(im_us), data_range = 1
    )

    return ssim, psnr, nrmse

In [7]:
"""
=========== Universal Single-task Trainer =========== 
code modified from https://github.com/lorenmt/mtan/blob/master/im2im_pred/utils.py
"""


def single_task_trainer(
    train_loader, val_loader,
    train_ratios, val_ratios,
    single_task_model, device,
    optimizer, scheduler,
    writer,
    opt = 0, total_epochs=2, ###opt###
    save_path='models/best-val.pt',
    recon_log_idx=18,
):
    """
    Train and validate one model, tracking per-contrast and overall metrics.

    In each epoch the function:
      1. makes one training pass with optimizer steps;
      2. makes one no-grad validation pass;
      3. saves a checkpoint whenever the overall validation L1 improves;
      4. steps the scheduler and prints a summary.

    After the last epoch, the curves are written via ``write_tensorboard``.

    Args:
        train_loader, val_loader: iterables yielding
            ``(kspace, mask, sens, im_fs, contrast)`` batches, where
            ``contrast[0]`` is the contrast name.
        train_ratios, val_ratios: dict mapping contrast -> divisor used to
            average that contrast's metrics (presumably its batch count;
            confirm against genDataLoader).
        single_task_model: called as ``model(kspace, mask, sens)`` and
            returning ``(kspace_pred, image)``.
        device: torch device the batches are moved to.
        optimizer: stepped once per training batch.
        scheduler: stepped once per epoch.
        writer: SummaryWriter used for the reconstruction figure.
        opt: options placeholder forwarded to ``write_tensorboard``.
        total_epochs: number of epochs to run.
        save_path: where the best-validation ``state_dict`` is written; its
            parent directory is created if missing.
        recon_log_idx: index of the validation batch whose reconstruction
            is logged as a figure every epoch.

    Returns:
        ``avg_cost``: dict mapping each contrast and ``'overall'`` to an
        array of shape (total_epochs, 8). Columns 0-3 hold train
        [l1, ssim, psnr, nrmse] and columns 4-7 the val counterparts.
    """
    train_batch = len(train_loader)
    val_batch = len(val_loader)

    # np.inf, not np.infty: the alias was removed in NumPy 2.0
    best_val_loss = np.inf

    # contains info for all epochs and contrasts
    avg_cost = {
        contrast : np.zeros([total_epochs, 8])
        for contrast in train_ratios.keys()
    }
    avg_cost['overall'] = np.zeros([total_epochs, 8])

    for epoch in range(total_epochs):
        # metrics of the current batch: [0:4] train, [4:8] val
        cost = np.zeros(8, dtype = np.float32)

        # ---- training pass ----
        single_task_model.train()
        for kspace, mask, sens, im_fs, contrast in train_loader:
            contrast = contrast[0] # torch dataset loader returns as tuple
            kspace, mask = kspace.to(device), mask.to(device)
            sens, im_fs = sens.to(device), im_fs.to(device)

            optimizer.zero_grad()
            _, im_us = single_task_model(kspace, mask, sens) # forward pass
            loss = criterion(im_fs, im_us)
            loss.backward()
            optimizer.step()

            cost[0] = loss.item() # L1 loss
            cost[1], cost[2], cost[3] = metrics(im_fs, im_us) # ssim, psnr, nrmse

            avg_cost[contrast][epoch, :4] += cost[:4] / train_ratios[contrast]
            avg_cost['overall'][epoch, :4] += cost[:4] / train_batch

        # ---- validation pass ----
        single_task_model.eval()
        with torch.no_grad():
            for val_idx, (kspace, mask, sens, im_fs, contrast) in enumerate(val_loader):
                contrast = contrast[0]
                kspace, mask = kspace.to(device), mask.to(device)
                sens, im_fs = sens.to(device), im_fs.to(device)

                _, im_us = single_task_model(kspace, mask, sens) # forward pass
                loss = criterion(im_fs, im_us)

                cost[4] = loss.item() # L1 loss
                cost[5], cost[6], cost[7] = metrics(im_fs, im_us) # ssim, psnr, nrmse

                avg_cost[contrast][epoch, 4:] += cost[4:] / val_ratios[contrast]
                avg_cost['overall'][epoch, 4:] += cost[4:] / val_batch

                # log the same validation reconstruction every epoch
                if val_idx == recon_log_idx:
                    writer.add_figure(
                        'recons',
                        plot_quadrant(im_fs, im_us),
                        epoch, close = True,
                    )

        # checkpoint whenever the overall validation L1 improves
        if avg_cost['overall'][epoch, 4] < best_val_loss:
            best_val_loss = avg_cost['overall'][epoch, 4]
            save_dir = os.path.dirname(save_path)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            torch.save(single_task_model.state_dict(), save_path)

        scheduler.step()

        print(f'''
        >Epoch: {epoch + 1:04d}
        TRAIN: loss {avg_cost['overall'][epoch, 0]:.4f} | ssim {avg_cost['overall'][epoch, 1]:.4f} | psnr {avg_cost['overall'][epoch, 2]:.4f} | nrmse {avg_cost['overall'][epoch, 3]:.4f} 
        VAL: loss {avg_cost['overall'][epoch, 4]:.4f} | ssim {avg_cost['overall'][epoch, 5]:.4f} | psnr {avg_cost['overall'][epoch, 6]:.4f} | nrmse {avg_cost['overall'][epoch, 7]:.4f}
        
        ''')

    # write to tensorboard
    write_tensorboard(avg_cost, single_task_model, total_epochs, opt)

    return avg_cost
        


In [8]:
# datasets
# Root of the dataset tree. Set the SUMMER_DSET_ROOT environment variable to
# use a different copy of the data; the default is the original server path.
DATA_ROOT = os.environ.get('SUMMER_DSET_ROOT', '/mnt/dense/vliu/summer_dset')

dataset_names = [
    'div_coronal_pd',
    'div_coronal_pd_fs',
]

basedirs = [
    f'{DATA_ROOT}/{dataset_name}'
    for dataset_name in dataset_names
]
basedirs

['/mnt/dense/vliu/summer_dset/div_coronal_pd',
 '/mnt/dense/vliu/summer_dset/div_coronal_pd_fs']

In [9]:
train_dloader = genDataLoader(
    [f'{basedir}/Train' for basedir in basedirs], # choose randomly
    [4, 4] # downsample
)

val_dloader = genDataLoader(
    [f'{basedir}/Val' for basedir in basedirs], # choose randomly
    [3, 3], # downsample
    shuffle = False,
)

# other inputs to STL wrapper
writer = SummaryWriter()

# Preferred GPU index. If that index is not visible on this host, fall back
# to cuda:0 (moving tensors to an invalid ordinal would raise); without CUDA,
# use the CPU.
GPU_INDEX = 2
if torch.cuda.is_available():
    gpu_index = GPU_INDEX if GPU_INDEX < torch.cuda.device_count() else 0
    device = torch.device(f'cuda:{gpu_index}')
else:
    device = torch.device('cpu')
varnet = VarNet().to(device)

LR = 0.001
LR_STEP_SIZE = 100  # in epochs: the trainer calls scheduler.step() once per epoch
LR_GAMMA = 0.5
optimizer = torch.optim.Adam(varnet.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=LR_STEP_SIZE, gamma=LR_GAMMA)

In [None]:
# Train and validate the VarNet. Index 0 of each genDataLoader result is the
# loader and index 1 the per-contrast ratio dict.
single_task_trainer(
    train_loader=train_dloader[0],
    val_loader=val_dloader[0],
    train_ratios=train_dloader[1],
    val_ratios=val_dloader[1],
    single_task_model=varnet,
    device=device,
    optimizer=optimizer,
    scheduler=scheduler,
    writer=writer,
    opt=0,
    total_epochs=3,
)
# push any buffered events to disk, then release the writer
writer.flush()
writer.close()


        >Epoch: 0001
        TRAIN: loss 0.0526 | ssim 0.6879 | psnr 28.4583 | nrmse 0.2727 
        VAL: loss 0.0463 | ssim 0.6956 | psnr 30.1173 | nrmse 0.2269
        
        


In [None]:
###############################################
# trying to get mr images onto tensorboard, unsuccessful
###############################################

# Quick standalone run on a fresh VarNet, limited to MAX_BATCHES batches per
# epoch. Names are chosen so the trained `varnet`, its `optimizer` and the
# `criterion` function defined above are not overwritten; re-running the
# trainer cell afterwards therefore still works.
SCRATCH_GPU_INDEX = 3
if torch.cuda.is_available():
    scratch_gpu = SCRATCH_GPU_INDEX if SCRATCH_GPU_INDEX < torch.cuda.device_count() else 0
    scratch_device = torch.device(f'cuda:{scratch_gpu}')
else:
    scratch_device = torch.device('cpu')

MAX_BATCHES = 17
Nepoch = 1

# keep the DataLoader itself: a one-shot iter() would be exhausted after epoch 1
dloader = val_dloader[2]
varnet_gpu = VarNet().to(scratch_device)
l1_loss = nn.L1Loss()
scratch_optimizer = torch.optim.Adam(varnet_gpu.parameters(), lr=0.0002)

for epoch in range(Nepoch):
    loss_epoch = 0
    n_batches = 0
    for kspace, mask, sens, im_true, contrast in dloader:
        kspace_gpu, mask_gpu, sens_gpu, im_true_gpu = (
            t.to(scratch_device) for t in (kspace, mask, sens, im_true)
        )
        scratch_optimizer.zero_grad()
        _, im_est_gpu = varnet_gpu(kspace_gpu, mask_gpu, sens_gpu)
        loss = l1_loss(im_true_gpu, im_est_gpu)
        loss.backward() # this performs the backprop
        scratch_optimizer.step() # this performs the gradient update
        loss_epoch += loss.item()
        n_batches += 1
        if n_batches == MAX_BATCHES:
            break
    # average over the batches actually processed, not len(dloader)
    print('epoch:{}/{} Mean Loss: {}'.format(epoch, Nepoch, loss_epoch / max(n_batches, 1))) # report loss for end of the epoch

In [None]:
###############################################
# trying to get mr images onto tensorboard, unsuccessful
###############################################

%matplotlib inline
# Let's look at the final produced image
im_est = transforms.tensor_to_complex_np(im_est_gpu.cpu().detach())
im_true = transforms.tensor_to_complex_np(im_true_gpu.cpu().detach())

pl.ImagePlot(im_est) # this is est image
pl.ImagePlot(im_true) # this is true image