# Imports

In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sigpy.mri
import torch
import yaml

from dip import evaluate, fft_np, models, mri, plotting
from dip.dataset import MRDDataset, PhantomDataset
from dip.lps import LowRankPlusSparse
from dip.mdip import MDIP

# Enable cuDNN and let it benchmark convolution algorithms, keeping the fastest one per input
# shape. This speeds up training when input shapes stay fixed, at the cost of bit-exact
# run-to-run reproducibility.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

# Hyperparameters

In [None]:
# Reset parameters from a previous run: remove every name recorded in `params`, so that
# re-executing the notebook starts from the defaults below instead of stale values.
if 'params' not in locals():
    params = {}
for param in params:
    # pop() with a default instead of `del`: a parameter that was already removed by hand
    # must not abort the reset with a KeyError
    locals().pop(param, None)
# Snapshot of all names defined before the parameters; the next cell diffs against it
# to collect exactly the parameters below into `params`.
_cur_locals = list(locals().keys())

# These parameters are modifiable using papermill (papermill injects overriding values in a
# new cell directly after this one). Trailing symbols refer to the method's notation.
raw_folder = './data'
out_folder = './results'
filename = '<enter_filename_here>.h5'  # .h5 (MRD raw data) or .mat (MRXCAT phantom)
slice_idx = 0  # default: first slice; override (e.g. via papermill) to pick a slice near the middle of the stack
n_coils = 12  # number of coils kept after coil compression
zs_chans = 2  # c
zt_chans = 4  # K
n_bases = 16  # L
p_dropout = 0  # dropout probability used in all generator networks
noise_reg = 0.05  # sigma_0
lr_max = 1e-3  # eta_f
lr_min = 1e-6
lr_static_factor = 1  # eta_s / eta_f
weight_decay = 0
lambda_flow_spatial = 0.10  # lambda_s
lambda_flow_temporal = 0.05  # lambda_f
lambda_zt = 0
lambda_basis = 0
ksp_scale = 100  # k-space is normalized so that its maximum magnitude equals this value
n_iter = 10000  # N_iter
save_every = 0  # NOTE(review): presumably 0 disables intermediate saving - confirm in MDIP.optimize
activate_flow_after = 0  # N_def
batch_size = 96
cuda_num = 0  # index of the CUDA device used for training
phantom_acceleration = 8  # only used for mrxcat data
phantom_snr = 10  # in dB, only used for mrxcat data

In [None]:
# Collect the hyperparameters into a dictionary. This must stay the first statement after the
# parameters cell, but in a separate cell, so that values injected by papermill (which land in
# their own cell in between) are captured as well.
params = {
    name: value
    for name, value in locals().items()
    if not name.startswith('_') and name not in _cur_locals
}

# Create the output folder: <out_folder>/<file stem>/slice_<NN>
output_path = Path(out_folder, Path(filename).stem, f'slice_{slice_idx:02d}')
output_path.mkdir(exist_ok=True, parents=True)

# Compute device and floating-point precision used for all models and tensors
device = torch.device(f'cuda:{cuda_num}')
dtype = torch.float32

# Data loading and preprocessing

In [None]:
# Choose the dataset loader from the file extension (case-insensitive):
# .h5 -> MRD raw data, .mat -> MRXCAT phantom (undersampled / noise level set by the phantom_* parameters).
raw_path = Path(raw_folder) / filename
file_format = raw_path.suffix.lower()
if file_format == '.h5':
    data = MRDDataset(raw_path, apodize=True)
elif file_format == '.mat':
    data = PhantomDataset(raw_path, apodize=True, acceleration_rate=phantom_acceleration, snr=phantom_snr)
else:
    raise ValueError(f'Unknown file format: {filename!r} (expected a .h5 or .mat file)')

# crop readout oversampling
print('Cropping readout oversampling...')
data.crop_readout_oversampling()

# whiten k-space
print('Whitening...')
data.whiten()

print(f'Number of slices: {data.n_slices}')
print(f'Number of frames: {data.n_phases}')
print(f'Number of coils:  {data.n_coils}')
print(f'Matrix size:      {data.matrix_size}')

In [None]:
# select slice (data.k below presumably returns the k-space of the selected slice)
data.sl = slice_idx

# undersampled k-space data
k = data.k  # [frame, coil, kx, ky]

# coil compression, only if the data has more coils than requested
if k.shape[1] > n_coils:
    print(f'Coil compression: {k.shape[1]} coils -> {n_coils} coils')
    k = mri.coil_compression(k, n_coils, ch_axis=1)

# sampling mask: a sample counts as acquired if it is non-zero. Only coil 0 is evaluated,
# i.e. the sampling pattern is assumed to be identical for all coils.
m = (np.abs(k[:, 0]) > 0).astype(np.int8)  # [frame, kx, ky]

# physio data
ecg = data.get_physio(0)  # [time, channel]
ecg_t = data.get_physio_t(0)  # [time,]
resp = data.get_physio(2)  # [time, channel]
resp_t = data.get_physio_t(2)  # [time,]

# acceleration rate: total / acquired samples in the [frame, ky] plane at the central kx position
m_tmp = m[:, m.shape[1] // 2, :]
r = m_tmp.size / m_tmp.sum()
print(f'Acceleration rate: {r}')

In [None]:
# Plot three panels side by side: the first-frame k-space magnitude of coil 0 (raised to the
# power 0.2 to compress its dynamic range), the first-frame sampling mask, and the [frame, ky]
# sampling pattern at the central kx position.
panels = (
    np.abs(k[0, 0, :, :]) ** 0.2,
    m[0, :, :],
    m[:, m.shape[1] // 2, :],
)
for panel_idx, panel in enumerate(panels, start=1):
    plt.subplot(1, 3, panel_idx)
    plt.imshow(panel, cmap='gray')
plt.savefig(output_path / 'kspace.png')
plt.show()

In [None]:
# Average the k-space along axis 0 (frames), transform to image space over the spatial
# axes, and show the magnitude image of every (compressed) coil.
kspace_avg = mri.average_data(k, 0)
img_avg = fft_np.ifftnc(kspace_avg, axes=[1, 2])
plotting.plot_multichannel(
    img_avg,
    channel_axis=0,
    columns=6,
    figsize=(10, None),  # type: ignore
    figheight_per_row=3,
    complex='abs',
    cmap='gray',
    save_path=output_path / 'coils_compressed.png',
    show=True,
)

In [None]:
# Combine the averaged coil images along the coil axis (mri.rss, root-sum-of-squares),
# then display and save the result.
img_combined = mri.rss(img_avg, 0)  # type: ignore
combined_png = output_path / 'averaged_image.png'
plt.imshow(img_combined, cmap='gray')
plt.savefig(combined_png)
plt.show()
plt.close()

In [None]:
# Coil sensitivity estimation with ESPIRiT on the time-averaged k-space
# (calibration width 24, threshold 0.02, crop=0 so the maps are not masked by eigenvalue).
sens_maps = sigpy.mri.app.EspiritCalib(kspace_avg, calib_width=24, thresh=0.02, crop=0, show_pbar=False).run()
# Explicit check instead of `assert`, which is skipped under `python -O`. It also narrows the
# type for static checkers. A non-NumPy result is possible, e.g. a CuPy array if sigpy runs on a GPU.
if not isinstance(sens_maps, np.ndarray):
    raise TypeError(f'ESPIRiT failed: expected a numpy.ndarray, got {type(sens_maps).__name__}')
plotting.plot_multichannel(sens_maps, channel_axis=0, columns=6, figheight_per_row=3, figsize=(10, None), complex='abs',
                           save_path=output_path / 'sens_maps.png', show=True, cmap='gray')

# M-DIP reconstruction

In [None]:
# Spatial basis generator: U-Net from the static code vector (zs_chans channels) to
# 2 * n_bases output channels (presumably real/imaginary parts of the n_bases spatial bases).
basis_gen = models.UNet(
    enc_channels=[zs_chans, 32, 64, 64, 64],
    dec_channels=[64, 64, 64, 32, 16],
    out_channels=2 * n_bases,
    kernel_size=3,
    n_convs_per_block=2,
    p_dropout=p_dropout,
    interpolation_mode='bilinear',
).to(dtype=dtype)

# Basis coefficient generator: MLP from the temporal code vector (zt_chans features) to
# 2 * n_bases coefficients. With a single basis there is nothing to weight, so it is None.
coeff_gen = (
    models.MLP(
        feature_lengths=[zt_chans, 32, 64, 128, 256, 128, 64, 2 * n_bases],
        last_activation=False,
        p_dropout=p_dropout,
    ).to(dtype=dtype)
    if n_bases > 1
    else None
)

# Spatial size the code vector zs must have so the U-Net produces an output covering the
# matrix size (the meaning of the second argument `2` is defined in dip.models).
code_vector_size = basis_gen.required_input_size(data.matrix_size, 2)

# Flow generator: its convolutional part starts at the U-Net bottleneck resolution.
unet_bottleneck_size = basis_gen.get_bottleneck_size(code_vector_size)
flow_gen = models.FlowGenerator(
    mlp_features=[zt_chans, 32, 64, 64, 64],
    conv_input_size=unet_bottleneck_size,
    conv_channels=[64, 64, 64, 64, 64],
    n_convs_per_block=3,
    p_dropout=p_dropout,
    interpolation_mode='nearest',
).to(dtype=dtype)

# Spatial transformer operating at the U-Net output size.
output_size = basis_gen.get_output_size(code_vector_size)
transformer = models.SpatialTransformer(output_size).to(dtype=dtype)

# Report the parameter count of every generator that exists.
for net_label, net in (
    ('basis generator', basis_gen),
    ('coefficient generator', coeff_gen),
    ('flow generator', flow_gen),
):
    if net is not None:
        n_params = sum(p.numel() for p in net.parameters())
        print(f'Number of params in {net_label}: {n_params}')

In [None]:
# Complex counterpart of `dtype` (complex64 for float32).
complex_dtype = torch.promote_types(dtype, torch.complex32)

# k-space, normalized so that its maximum magnitude equals ksp_scale. k_max is kept as a
# separate variable to undo the scaling later, including for the phantom ground truth.
k_tor = torch.from_numpy(k)  # [frame, coil, kx, ky]
k_max = k_tor.abs().max().item()
k_tor = (k_tor * ksp_scale / k_max).to(dtype=complex_dtype)

# sampling mask with a singleton coil axis for broadcasting
m_tor = torch.from_numpy(m).unsqueeze(1).to(dtype=dtype)  # [frame, coil=1, kx, ky]

# coil sensitivity maps
sen_tor = torch.from_numpy(sens_maps).to(dtype=complex_dtype)  # [coil, x, y]

# static code vector, randomly initialized from U(0, 0.1)
zs = torch.empty(1, zs_chans, *code_vector_size, dtype=dtype).uniform_(0, 0.1)  # [batch=1, channel, x, y]

# temporal code vector, zero-initialized
zt = torch.zeros(data.n_phases, zt_chans, dtype=dtype)  # [time, channel]

In [None]:
# Write the hyperparameters and model configurations to params.yaml as a multi-document
# YAML file: one document per entry, each starting with '---', in block style.
with open(output_path / 'params.yaml', 'w') as yaml_file:
    config_documents = [
        {'params': params},
        {'basis_gen': basis_gen.config},
        {'coeff_gen': None if coeff_gen is None else coeff_gen.config},
        {'flow_gen': flow_gen.config},
    ]
    yaml.dump_all(config_documents, yaml_file, default_flow_style=False, explicit_start=True)

In [None]:
# Assemble the M-DIP model from the generators, code vectors and training settings,
# then move it to the compute device.
mdip = MDIP(
    # networks and code vectors
    basis_gen=basis_gen,
    coeff_gen=coeff_gen,
    flow_gen=flow_gen,
    transformer=transformer,
    zs=zs,
    zt=zt,
    # geometry and timing; if data.tres is in ms, imaging_fs is the frame rate in Hz (TODO confirm)
    matrix_size=data.matrix_size,
    n_frames=data.n_phases,
    imaging_fs=1000 / data.tres,
    # loss weights and input-noise regularization
    lambda_flow_spatial=lambda_flow_spatial,
    lambda_flow_temporal=lambda_flow_temporal,
    lambda_zt=lambda_zt,
    lambda_basis=lambda_basis,
    noise_reg=noise_reg,
    # optimizer settings
    lr_max=lr_max,
    lr_min=lr_min,
    lr_static_factor=lr_static_factor,
    weight_decay=weight_decay,
    output_path=output_path,
).to_device(device)

In [None]:
# Fit M-DIP to the undersampled k-space and save the result. For phantom data, training is
# additionally monitored against the ground truth, brought to the normalized k-space scale
# (ksp_scale / k_max). For other data, monitor_every=-1 presumably disables monitoring.
if isinstance(data, PhantomDataset):
    monitor_every, monitor_gt = 50, data.ground_truth[slice_idx] * ksp_scale / k_max
else:
    monitor_every, monitor_gt = -1, None

mdip.optimize(
    k=k_tor.to(device=device),
    sens=sen_tor.to(device=device),
    mask=m_tor.to(device=device),
    n_iter=n_iter,
    save_every=save_every,
    activate_flow_after=activate_flow_after,
    batch_size=batch_size,
    monitor_every=monitor_every,
    monitor_gt=monitor_gt,
)
mdip.save()

In [None]:
# Plot the training curves: total loss and its active components (left), residual (right).
plt.figure(figsize=(12, 5))

plt.subplot(121)
plt.semilogy(mdip.metrics['total_loss'], linewidth=0.6, label='$L$')
plt.semilogy(mdip.metrics['kspace_loss'], linewidth=0.6, label='$L_k$')
# regularization terms are only plotted when their weight is non-zero
for weight, metric_key, curve_label in (
    (lambda_flow_spatial, 'flow_loss_spatial', '$L_{def,s}$'),
    (lambda_flow_temporal, 'flow_loss_temporal', '$L_{def,t}$'),
    (lambda_zt, 'zt_loss', '$L_{zt}$'),
    (lambda_basis, 'basis_loss', '$L_{b}$'),
):
    if weight > 0:
        plt.semilogy(mdip.metrics[metric_key], linewidth=0.6, label=curve_label)
plt.ylim(bottom=1e-5, top=2e0)
plt.legend()
plt.title('Loss')

plt.subplot(122)
plt.semilogy(mdip.metrics['residual'], linewidth=0.6)
plt.ylim(bottom=6e-3, top=2e0)
plt.title('Residual')
# tight_layout must run after all titles are set, otherwise the 'Residual' title is not accounted for
plt.tight_layout()
plt.savefig(output_path / 'loss.png')
plt.show()

In [None]:
# Inference: generate the cine and its components with gradients disabled and the networks
# in evaluation mode, convert everything to NumPy, and save the results.
with mdip.no_grad_and_eval():
    cine, basis, coeffs, flow = (tensor.cpu().numpy() for tensor in mdip.forward_())

    # undo the k-space normalization applied before training
    cine = cine / ksp_scale * k_max

    # predicted k-space of coil 0 (first frame, magnitude ** 0.2) next to the first cine frame
    kpred = fft_np.fftnc(cine[:, None] * sens_maps[None], [2, 3])  # type: ignore
    for panel_idx, panel in enumerate((np.abs(kpred[0, 0]) ** 0.2, np.abs(cine[0])), start=1):
        plt.subplot(1, 2, panel_idx)
        plt.imshow(panel, cmap='gray')
    plt.savefig(output_path / 'kpred.png')
    plt.show()

    # crop to recon size
    cine = mri.center_crop(cine, data.recon_size, (1, 2))

    mdip.save_cine(cine, equalize_histogram=True)
    mdip.save_basis(basis, show=True)
    if n_bases > 1:
        mdip.save_coeffs(coeffs, show=True)
    mdip.save_flow(flow)

    # physio time stamps relative to their first sample, divided by 1000 (presumably ms -> s)
    ecg_t_ = (ecg_t - ecg_t.min()) / 1000
    resp_t_ = (resp_t - resp_t.min()) / 1000

    mdip.save_static_code_vector()
    mdip.save_temporal_code_vector()
    mdip.save_temporal_code_vector((ecg[:, 0], ecg_t_), (resp[:, 0], resp_t_), name='zt_ecg_resp', show=True)

# Low-rank plus sparse (L+S) reconstruction

In [None]:
# Low-rank plus sparse (L+S) baseline reconstruction from the same normalized data.
lps = LowRankPlusSparse(k_tor, sen_tor, m_tor).to_device(device)
low_rank, sparse = (
    part.cpu().numpy()
    for part in lps.run(max_iter=500, lambda_l=0.5, lambda_s=0.05, tol=1e-5)
)
# sum both components, scale back to the original k-space, and crop to the recon size
cine_lps = mri.center_crop((low_rank + sparse) / ksp_scale * k_max, data.recon_size, (1, 2))
mdip.save_cine(cine_lps, name='cine_lps', equalize_histogram=True)

# Quantitative evaluation (phantom data only)

In [None]:
if isinstance(data, PhantomDataset):
    # save ground truth vs noisy image (first frame of the selected slice)
    plt.subplot(121)
    plt.imshow(np.abs(data.ground_truth[slice_idx, 0]), cmap='gray')
    plt.title('Ground truth')
    plt.axis('off')
    plt.subplot(122)
    plt.imshow(np.abs(data.noisy_img[slice_idx, 0]), cmap='gray')
    plt.title(f'Noisy image (SNR={data.snr_db} dB)')
    plt.axis('off')
    plt.tight_layout()
    plt.savefig(output_path / 'ground_truth_vs_noisy.png')
    plt.close()

    # save ground truth cine
    cine_gt = data.ground_truth[slice_idx]
    mdip.save_cine(cine_gt, name='cine_gt', equalize_histogram=True)

    # save noisy cine
    cine_noisy = data.noisy_img[slice_idx]
    mdip.save_cine(cine_noisy, name='cine_noisy', equalize_histogram=True)

    # all metrics are computed on magnitude images
    # NOTE(review): cine and cine_lps are center-cropped to data.recon_size but the ground truth is not;
    # this assumes recon size == matrix size for phantom data - confirm.
    cine_gt = np.abs(cine_gt)
    cine_noisy = np.abs(cine_noisy)
    cine_mdip = np.abs(cine)
    cine_lps_ = np.abs(cine_lps)

    # ROI bounding box and profile center, keyed by the data file stem (output_path.parent.name).
    # NOTE(review): the annotation file is resolved relative to the working directory.
    with open('mrxcat_annotations.yaml', 'r', encoding='utf-8') as f:
        annotations = yaml.safe_load(f)[output_path.parent.name]
    bbox = annotations['bbox']
    center = annotations['center']

    # quantitative evaluation on the whole cine, the ROI, and the temporal profiles
    preds_tuples = [('M-DIP', cine_mdip), ('LPS', cine_lps_), ('Noisy', cine_noisy)]
    metrics_cine = evaluate.get_metrics(cine_gt, *preds_tuples)
    metrics_roi = evaluate.get_metrics(cine_gt, *preds_tuples, bbox=bbox)
    metrics_profiles = evaluate.get_metrics(cine_gt, *preds_tuples, center=center)
    with pd.option_context('display.float_format', '{:.4f}'.format):
        print('Cine:')
        print(metrics_cine)
        print('\nROI:')
        print(metrics_roi)
        print('\nTemporal Profiles:')
        print(metrics_profiles)

    # save metrics to csv
    evaluate.update_metrics_csv(
        output_path / 'metrics.csv', ('cine', metrics_cine), ('roi', metrics_roi), ('profiles', metrics_profiles),
    )

    # plot monitored metrics
    mdip.save_metrics(show=True)

    # reconstructions compared against the ground truth, with the tag used in the output file names
    named_preds = (('mdip', cine_mdip), ('lps', cine_lps_), ('noisy', cine_noisy))

    # save ROIs
    for tag, cine_arr in (('gt', cine_gt), *named_preds):
        evaluate.save_cine_roi(cine_arr, bbox, output_path / f'ROI_cine_{tag}.gif', data.tres)

    # save temporal profiles
    for tag, cine_arr in (('gt', cine_gt), *named_preds):
        evaluate.save_temporal_profiles(cine_arr, center, output_path / f'profile_cine_{tag}.png')

    # save error maps, amplified with scale=10
    for tag, cine_arr in named_preds:
        evaluate.save_error_map(cine_gt, cine_arr, output_path / f'error10x_{tag}.gif', data.tres, scale=10)

    # save overview image (first ground-truth frame with ROI and profile location)
    evaluate.save_overview_image(cine_gt[0], bbox, center, output_path / 'overview.png')

# Cleanup

In [None]:
# Free GPU memory: move both models back to the CPU, then release PyTorch's cached CUDA blocks.
mdip, lps = (model.to_device(torch.device('cpu')) for model in (mdip, lps))
torch.cuda.empty_cache()

# END