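Purpose: Load and view a dataset stored either as 2D slices or as full 3D volumes (whole images or patches).

Usage: Set `DIMS` and `FULL_OR_PATCH` in the load cell to load the dataset as
1. Slices: **2, 'full'**, i.e. (N_SLICES, 384, 384, 1)
2. Subjects: **3, 'full'**, i.e. (N_SUBJECTS, 384, 384, 256, 1)

To load patches, use `'patches'` instead of `'full'`, e.g. (N_PATCHES, 256, 256, 1) for 2D or (N_PATCHES, 256, 256, 256, 1) for 3D.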

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os
import pdb
import sys
sys.path.append('/home/quahb/caipi_denoising/src')

from utils.vizualization_tools import plot2, plot4, plot_slices, plot_patches, add_center_window, center_window_std
from preparation.data_io import load_dataset
from preparation.preprocessing_pipeline import rescale_magnitude, fourier_transform as ft, inverse_fourier_transform as ift

%load_ext autoreload
%autoreload 2

# Load Dataset

In [None]:
# Dataset selection: DIMS and FULL_OR_PATCH are passed straight to load_dataset.
DIMS = 3
FULL_OR_PATCH = 'patches'
dataset_name = '3D_patches_complexfreq'
# Passed to load_dataset as `subset` for both images and labels
# (presumably limits how many samples are loaded — confirm in load_dataset).
SUBSET = 10

DATA_PATH = f'/home/quahb/caipi_denoising/data/datasets/unaccelerated/{dataset_name}/'

# Keep both name lists: previously the second call overwrote the first,
# silently discarding the image names.
images, image_names = load_dataset(os.path.join(DATA_PATH, 'images'), DIMS, FULL_OR_PATCH, return_names=True, subset=SUBSET)
labels, label_names = load_dataset(os.path.join(DATA_PATH, 'labels'), DIMS, FULL_OR_PATCH, return_names=True, subset=SUBSET)
# Preserve the original `names` binding (the label names) for any later cells using it.
names = label_names

print(images.shape, images.dtype)
print(labels.shape, labels.dtype)

# View Dataset

In [None]:
# Just run this cell; no parameters need to be set.
# How the data is displayed is inferred from `images`, `dataset_name` and
# `FULL_OR_PATCH` from the load cell.


def _log_magnitude(x, eps=1e-8):
    """Return log(|x| + eps) for displaying (complex) k-space.

    The previous np.abs(np.log(x)) is not a log-magnitude for complex x:
    |log x| = |ln|x| + i*arg(x)| mixes the phase into the image and folds
    |x| < 1 onto positive values. `eps` keeps zeros finite instead of -inf.
    """
    return np.log(np.abs(x) + eps)


# 1. Init viewing params
b_COMPLEX = np.iscomplexobj(images)     # complex data -> show magnitude and phase
b_FREQ = 'freq' in dataset_name         # dataset name marks frequency-domain storage
b_PATCHES = FULL_OR_PATCH == 'patches'  # not used below; kept for later cells

# 2. Draw plots
if DIMS == 2:
    # Full: [N_SLICES, 384, 384, 1]
    # Patches: [N_PATCHES, 256, 256, 1]
    SLICES_TO_SHOW = 5
    INDICES = range(0, images.shape[0], 50)[:SLICES_TO_SHOW]
    imgs = [images[i] for i in INDICES]
    lbls = [labels[i] for i in INDICES]

    if b_COMPLEX:  # complex data, full or patch
        if b_FREQ:
            # Convert stored frequency data to image space for magnitude/phase display.
            imgs = [ift(i) for i in imgs]
            lbls = [ift(i) for i in lbls]

        for img, lbl in zip(imgs, lbls):
            if b_FREQ:
                plot2(_log_magnitude(ft(img)), _log_magnitude(ft(lbl)))
            plot2(np.abs(img), np.abs(lbl))
            plot2(np.angle(img), np.angle(lbl))
    else:  # magnitude data, full or patch
        for img, lbl in zip(imgs, lbls):
            plot2(img, lbl)

elif DIMS == 3:
    # Full: [N_SUBJECTS, 384, 384, 256, 1]
    # Patches: [N_PATCHES, 256, 256, 256, 1]

    # Works for full or patches; clamped so a small `subset` does not raise IndexError.
    INDEX = min(5, images.shape[0] - 1)
    img, lbl = images[INDEX], labels[INDEX]

    if b_COMPLEX:
        if b_FREQ:
            ift_img, ift_lbl = ift(img), ift(lbl)
        else:
            # Complex data not stored as frequency data is viewed as-is
            # (ift_img was previously undefined on this path -> NameError).
            ift_img, ift_lbl = img, lbl

        # Step through the third axis (shape[-2], given the trailing channel axis).
        for i in range(0, img.shape[-2], 50):
            if b_FREQ:
                plot2(_log_magnitude(img[:, :, i]), _log_magnitude(lbl[:, :, i]))
            # Slice before abs/angle so the whole volume is not reprocessed every iteration.
            img_slice, lbl_slice = ift_img[:, :, i], ift_lbl[:, :, i]
            plot2(np.abs(img_slice), np.abs(lbl_slice))
            plot2(np.angle(img_slice), np.angle(lbl_slice))
    else:
        for i in range(0, img.shape[-2], 50):
            plot2(img[:, :, i], lbl[:, :, i])

else:
    raise ValueError(f"Unsupported DIMS={DIMS}; expected 2 or 3")