In [None]:
import scipy.io
import numpy as np
from celluloid import Camera
import matplotlib.pyplot as plt
from IPython.display import HTML
import matplotlib.animation as animation

from supermri.utils.tensorboard import remove_spines, remove_ticks

%matplotlib inline

In [None]:
def rotate(scan: np.ndarray):
    """Turn a 4D scan 90 degrees to the left in the plane spanned by axes 1 and 3.

    Axes 0 and 2 are left untouched. The result is a fresh C-ordered array
    rather than the strided view ``np.rot90`` returns, because that view's
    negative strides cause errors downstream.
    """
    assert scan.ndim == 4
    rotated_view = np.rot90(scan, k=1, axes=(1, 3))
    return rotated_view.copy(order="C")


def load_mat(filename):
    """Read a MATLAB .mat file and return its FLAIR, T1 and T2 volumes as one array.

    The arrays stored under "FLAIRarray", "T1array" and "T2array" are stacked,
    in that order, along a new leading channel axis. The stack is cast to
    float32 and then passed through `rotate`.
    """
    contents = scipy.io.loadmat(filename)
    modality_keys = ("FLAIRarray", "T1array", "T2array")
    stacked = np.stack([contents[key] for key in modality_keys])
    return rotate(stacked.astype(np.float32))


def animate(scan, dim=0, filename="mri.gif"):
    """Save a slice-by-slice animation of a 3D scan.

    Args:
        scan: 3D array; one frame is produced for every index along `dim`.
        dim: axis (0, 1 or 2) to step through.
        filename: output path handed to `ArtistAnimation.save`.
    """
    # Local import so this function keeps working even if a notebook cell
    # rebinds the module-level name `animation` (e.g. to a Camera result).
    import matplotlib.animation as mpl_animation

    assert 0 <= dim < 3 and len(scan.shape) == 3
    figure, ax = plt.subplots(figsize=(6, 6), dpi=240)
    # Each frame is a one-element artist list, as ArtistAnimation expects.
    frames = [
        [ax.imshow(np.take(scan, i, axis=dim), cmap="gray", animated=True)]
        for i in range(scan.shape[dim])
    ]
    ani = mpl_animation.ArtistAnimation(
        figure, frames, interval=50, blit=True, repeat_delay=1000
    )
    try:
        ani.save(filename)
    finally:
        # Release the figure so it is not leaked. This also stops the inline
        # backend from rendering it as a stray static image.
        plt.close(figure)


def plot(scan, index=70):
    """Display three orthogonal slices of a 3D scan next to each other.

    Panel k shows `scan` indexed at `index` along axis k. Each panel is
    titled with the matching indexing expression.
    """
    figure, axes = plt.subplots(
        nrows=1,
        ncols=3,
        figsize=(8, 3),
        gridspec_kw={"width_ratios": [1, 1, 1], "wspace": 0.25},
        squeeze=False,
        dpi=120,
    )
    figure.patch.set_facecolor("white")
    everything = slice(None)
    panels = (
        ((index, everything, everything), f"scan[{index}, : ,:]"),
        ((everything, index, everything), f"scan[:, {index} ,:]"),
        ((everything, everything, index), f"scan[:, : ,{index}]"),
    )
    for ax, (key, title) in zip(axes.flat, panels):
        ax.imshow(scan[key], cmap="gray", interpolation="none")
        ax.set_title(title)
    plt.show()

In [None]:
merge_scan = load_mat("../data/affine/SR_031_BRIC1_V0_affine.mat")
print(merge_scan.shape)
# One figure per channel: 0 = FLAIR, 1 = T1, 2 = T2 (the stacking order used by load_mat).
for channel in merge_scan:
    plot(channel, index=70)
# Example: save a T1 (channel 1) animation stepping through axis 1.
# animate(merge_scan[1], dim=1, filename='t1_dim1.mp4')

In [None]:
scan = load_mat("../data/warp/SR_002_NHSRI_V1.mat")
print(scan.shape)
index = 70
# Plot every channel of the stack at the same slice index.
for modality in scan:
    plot(modality, index=index)

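## Side-by-side comparison animation

Animate the FLAIR channel of the true low-resolution scan, the super-resolved output and the true high-resolution scan next to each other, stepping through the slices along axis 1.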

In [None]:
# Keep only the FLAIR volume (channel 0 of each load_mat stack).
lr = load_mat("../data/affine/SR_002_NHSRI_V0_affine.mat")[0]
sr = load_mat("../runs/003_agunet_16f_gelu/scans/SR_002_NHSRI.mat")[0]
hr = load_mat("../data/affine/SR_002_NHSRI_V1_affine.mat")[0]

In [None]:
figure, axes = plt.subplots(
    nrows=1,
    ncols=3,
    figsize=(8, 3.5),
    gridspec_kw={
        "wspace": 0.05,
    },
    dpi=240,
    facecolor="white",
)

camera = Camera(figure)

PAD = 4
volumes = (lr, sr, hr)
titles = ("true low-resolution", "super-resolution", "true high-resolution")
for ax, title in zip(axes, titles):
    ax.set_title(title, pad=PAD)
for ax in axes:
    remove_spines(axis=ax)
    remove_ticks(axis=ax)

# NOTE(review): the frame count comes from lr.shape[1]; this assumes sr and hr
# have at least as many slices along axis 1. Confirm the three shapes match.
for s in range(lr.shape[1]):
    for ax, volume in zip(axes, volumes):
        ax.imshow(volume[:, s, :], cmap="gray", interpolation="none", aspect="equal")
    # Each snap() call records one frame of the animation.
    camera.snap()

# Named so it does not shadow the `matplotlib.animation` module, which the
# first cell imports as `animation`.
comparison_animation = camera.animate()
comparison_animation.save(
    "animation.gif",
    fps=8,
    dpi=240,
    savefig_kwargs={
        "pad_inches": 0,
    },
)