# CAIPI Denoising

TODO

- Add TensorBoard logging support
- Design more models
- Add performance metrics - PSNR, SSIM
- Add more data augmentation methods

In [None]:
# Notebook imports

import warnings
warnings.filterwarnings('ignore')

import sys
import logging
import os
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np

# 0 = all messages are logged (default behavior)
# 1 = INFO messages are not printed
# 2 = INFO and WARNING messages are not printed
# 3 = INFO, WARNING, and ERROR messages are not printed
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# Refresh before importing python scripts
%load_ext autoreload
%autoreload 2

import matplotlib as mpl
mpl.rc('image', cmap='gray')

# LOGGING WHEN NOTEBOOK DISCONNECTS
nblog = open("nb.log", "a+")
sys.stdout.echo = nblog
sys.stderr.echo = nblog

get_ipython().log.handlers[0].stream = nblog
get_ipython().log.setLevel(logging.INFO)

%autosave 5

print(tf.__version__)
gpus = tf.config.list_physical_devices('GPU')
print(gpus)

In [None]:
# Directory passed as save_path to the training preprocess_data() calls.
PREPROCESSED_SAVE_PATH = '/home/quahb/caipi_denoising/data/preprocessed'

# Patch geometry shared by the train and test settings below.
PATCH_SIZE = (32, 32)
EXTRACT_STEP = (32, 32)
PAD_VALUE = -0.5

# Training-time settings, keyed by preprocessing op name.
train_preprocessing_params = {
    'normalize': {},
    'pad_square': {'value': PAD_VALUE},
    'white_noise': {'mu': 0.0, 'sigma': 0.15},
    'extract_patches': {
        'patch_size': PATCH_SIZE,
        'extract_step': EXTRACT_STEP,
        'pad_before_ext': False,
        'pad_value': PAD_VALUE,
    },
}

# Author's note: extract patches before white noise if pad_before_ext==True;
# white noise can then be added to patches that were already loaded.
X_train_ops = ['normalize', 'pad_square', 'extract_patches', 'white_noise']

# Targets go through the same ops as the inputs, minus the added noise.
y_train_ops = ['normalize', 'pad_square', 'extract_patches']

# Test-time settings, keyed by preprocessing op name.
test_preprocessing_params = {
    'normalize': {},
    'pad_square': {},
    'white_noise': {},
    'extract_patches': {
        'patch_size': PATCH_SIZE,
        'extract_step': (1, 1),  # should always be 1, 1 for reconstruction
        'pad_before_ext': True,
    },
}

X_test_ops = ['extract_patches']


# Devices handed to tf.distribute.MirroredStrategy.
GPUS_TO_USE = ["/GPU:0", "/GPU:1", "/GPU:2", "/GPU:3"]

# Patch-sized input when patches are extracted for training, otherwise a
# fixed 384x384 single-channel input.
if 'extract_patches' in X_train_ops:
    MODEL_INPUT_SHAPE = (None, *PATCH_SIZE, 1)
else:
    MODEL_INPUT_SHAPE = (None, 384, 384, 1)

# Fraction of the preprocessed data that the split cell keeps for training.
VALID_SPLIT = 0.8
MODEL_TYPE = 0
PATIENCE = 5
MODEL_PATH = '/home/quahb/caipi_denoising/models'
MODEL_FILENAME = 'denoiser_ep{epoch:02d}.h5'

# Training loop settings.
BATCH_SIZE = 20
N_EPOCHS = 20
INIT_EPOCH = 0

# Load Training Data

In [None]:
from src.preparation.gen_data import get_data_dict, get_train_data, get_median_slices

In [None]:
# Load the training inputs and targets; median_slices=False is forwarded to
# the project loader (NOTE(review): presumably returns every slice rather than
# only median slices -- confirm in src.preparation.gen_data).
X_train, y_train = get_train_data(median_slices=False)

# Plot median slices for each subject

In [None]:
#print(dicoms_dict.keys(), dicoms_dict['1_01_016-V1'].keys(), len(dicoms_dict['1_01_016-V1']['3D_T2STAR_segEPI']))

# Plot median slice for each subject for single modality

# Load the dataset description from data.json; plot_subj_slices() uses its
# keys as the subject labels.
dicoms_dict = get_data_dict('/home/quahb/caipi_denoising/data/data.json')

def plot_subj_slices(modality=0):
    """Plot one slice per subject from ``X_train`` in a 4-column grid.

    For the i-th key of ``dicoms_dict`` the slice ``X_train[i * 256 + 128]``
    is drawn and titled with that key.

    Args:
        modality: Accepted for call compatibility but not used by the body;
            the plotted slices do not depend on it.
    """
    # NOTE(review): assumes every subject contributes 256 consecutive slices
    # to X_train, with index 128 as its middle slice -- confirm against loader.
    slices_per_subj = 256
    median_i = 128

    subj_list = list(dicoms_dict.keys())

    columns = 4
    rows = len(subj_list) // columns + 1

    plt.figure(figsize=(23, 120))
    for i, subj in enumerate(subj_list):
        plt.subplot(rows, columns, i + 1)
        plt.title(subj)

        plt.imshow(X_train[i * slices_per_subj + median_i], cmap=plt.cm.gray)
        
#plot_subj_slices(0) # ['3D_T2STAR_segEPI', 'CAIPI1x2', 'CAIPI1x3', 'CAIPI2x2']

# Preprocess Training Data

In [None]:
from src.preparation.np_preprocessing_pipeline import preprocess_data

In [None]:
# Display the shapes of the loaded arrays before any preprocessing is applied.
X_train.shape, y_train.shape

In [None]:
# One random permutation of the sample indices; the same array is passed as
# `shuffle` to both the X_train and y_train preprocess_data() calls.
shuffle_idx = np.random.permutation(len(X_train))

In [None]:
# Run the input pipeline (X_train_ops) on the training inputs.
# NOTE(review): the target call uses the same save_path -- confirm that
# preprocess_data does not write both results to the same file.
X_train_pp = preprocess_data(
    X_train,
    train_preprocessing_params,
    ops=X_train_ops,
    shuffle=shuffle_idx,
    save_path=PREPROCESSED_SAVE_PATH,
)

In [None]:
# Run the target pipeline (y_train_ops, no added noise) on the training
# targets, with the same shuffle indices as the inputs.
y_train_pp = preprocess_data(
    y_train,
    train_preprocessing_params,
    ops=y_train_ops,
    shuffle=shuffle_idx,
    save_path=PREPROCESSED_SAVE_PATH,
)

## Plot Preprocessed X_train vs y_train

In [None]:
def plot_pp_io(first=16, last=20):
    """Show preprocessed inputs beside their targets, one pair per row.

    Each row shows ``X_train_pp[i]`` on the left and ``y_train_pp[i]`` on the
    right for ``i`` in ``range(first, last)``. The defaults reproduce the
    original fixed view of entries 16..19.

    Args:
        first: Index of the first entry to display.
        last: Index one past the final entry to display.

    Raises:
        ValueError: If ``last`` is not greater than ``first``.
    """
    n_pairs = last - first
    if n_pairs <= 0:
        raise ValueError(
            'last must be greater than first (got first=%r, last=%r)'
            % (first, last))

    cols = 2
    # 10 inches of height per row keeps the default 4-row figure at (27, 40).
    plt.figure(figsize=(27, 10 * n_pairs))
    for row, i in enumerate(range(first, last)):
        plt.subplot(n_pairs, cols, 2 * row + 1)
        plt.title('X_train_processed')
        plt.imshow(X_train_pp[i], cmap=plt.cm.gray)

        plt.subplot(n_pairs, cols, 2 * row + 2)
        plt.title('y_train_processed')
        plt.imshow(y_train_pp[i], cmap=plt.cm.gray)
    
plot_pp_io()

In [None]:
# VALID_SPLIT is the fraction kept for training: entries before the cut index
# form the training set, the remainder the validation set.
valid_i = int(VALID_SPLIT * len(X_train_pp))

X_train_f = X_train_pp[:valid_i]
y_train_f = y_train_pp[:valid_i]
X_valid_f = X_train_pp[valid_i:]
y_valid_f = y_train_pp[valid_i:]

In [None]:
# Display the shapes of the training and validation subsets after the split.
X_train_f.shape, y_train_f.shape, X_valid_f.shape, y_valid_f.shape

# Model Instantiation

In [None]:
import sys
sys.path.insert(1, '/home/quahb/caipi_denoising/src')

import tensorflow as tf
from modeling.get_model import get_model
from modeling.callbacks import get_training_cb

In [None]:
# Replicate across the devices listed in GPUS_TO_USE.
# NOTE(review): the list is hard-coded to four GPUs -- confirm it matches the
# devices reported by the setup cell.
strategy = tf.distribute.MirroredStrategy(devices=GPUS_TO_USE)

# Build the model inside the strategy scope so it is created under it.
with strategy.scope():
    model = get_model(
        model_type=MODEL_TYPE,
        input_shape=MODEL_INPUT_SHAPE,
    )

# Model Training

In [None]:
# Training callbacks from the project helper, configured with the patience
# and the checkpoint directory / filename pattern from the config cell.
cb_list = get_training_cb(
    patience=PATIENCE,
    save_path=MODEL_PATH,
    save_filename=MODEL_FILENAME,
)

In [None]:
# Call the model's summary() method (assumes get_model returns a Keras model,
# for which this prints the layer table -- confirm).
model.summary()

In [None]:
%%time

# Train on the training subset and validate on the held-out subset each epoch.
history = model.fit(
    X_train_f,
    y_train_f,
    validation_data=(X_valid_f, y_valid_f),
    batch_size=BATCH_SIZE,
    epochs=N_EPOCHS,
    initial_epoch=INIT_EPOCH,
    callbacks=cb_list,
    shuffle=True,
)
# Prints the object returned by fit(), not the per-epoch metric values.
print(history)

In [None]:
#model.load_weights(r'/home/quahb/caipi_denoising/models/2022-06-10/denoiser_ep06.h5')

# Preprocess Testing Data

In [None]:
from src.preparation.gen_data import get_test_data

In [None]:
# Load the test inputs through the project loader.
X_test = get_test_data()

In [None]:
# Apply only the ops in X_test_ops to the test data. The op list is passed as
# `ops`, the keyword used by both training-time preprocess_data() calls in
# this notebook; the previous `X_ops` keyword did not match them.
# NOTE(review): shuffle=True is kept unchanged, but the training calls pass an
# index permutation here, and shuffling presumably breaks the patch order
# needed for reconstruction -- confirm against preprocess_data.
X_test = preprocess_data(X_test,
                         test_preprocessing_params,
                         ops=X_test_ops,
                         shuffle=True)

# Run model inference

In [None]:
# Display the shape of the preprocessed test data.
X_test.shape

In [None]:
# Only the first SHOW_N_SLICES test entries are run through the model.
SHOW_N_SLICES = 50
y_test = model.predict(
    X_test[:SHOW_N_SLICES],
    verbose=1,
    batch_size=30,
)

In [None]:
# Print min / max / mean of the first input and then of the first prediction,
# one line each, to compare their value ranges.
for arr in (X_test[0], y_test[0]):
    print(np.min(arr), np.max(arr), np.mean(arr))

In [None]:
# Draw a random sample of up to 50 predictions WITHOUT reordering y_test.
# The plotting cells show X_test[i] next to y_test[i], so shuffling y_test in
# place would break that index pairing.
sample_idx = np.random.permutation(len(y_test))[:50]
y_mini = y_test[sample_idx]

In [None]:
# Grid layout: 25 rows x 2 columns gives 50 subplot positions, so the largest
# valid position index is 50.
SHOW_N_SLICES = 50
columns = 2
rows = 25
plt.figure(figsize=(28, 55))

# Step of 2 keeps each input/prediction pair on one grid row (positions i + 1
# and i + 2). If y_test holds 50 entries this visits i = 46 and 48 only.
for i in range(46, y_test.shape[0], 2):
    input_slc, output_slc = X_test[i], y_test[i]

    for offset, slc in enumerate((input_slc, output_slc), start=1):
        plt.subplot(rows, columns, i + offset)
        plt.imshow(slc, cmap='gray')

In [None]:
# Same side-by-side view as the fixed-grid cell, but the number of grid rows
# is derived from how many predictions y_test holds.
columns = 2
SHOW_N_SLICES = 50
plt.figure(figsize=(30,70))

n_rows = y_test.shape[0] // columns + 1

# Input goes in position i + 1 and its prediction in i + 2; with a step of 2
# each pair stays on one grid row.
for i in range(46, y_test.shape[0], 2):
    input_slc = X_test[i]
    output_slc = y_test[i]

    for position, slc in zip((i + 1, i + 2), (input_slc, output_slc)):
        plt.subplot(n_rows, columns, position)
        plt.imshow(slc, cmap='gray')