In [None]:
import nobrainer
from nobrainer.io import _is_gzipped
from nobrainer.volume import to_blocks

import tensorflow as tf
import glob
import numpy as np

# tf.data autotune sentinel; used below as the default `num_parallel_calls` of get_dataset.
AUTOTUNE = tf.data.experimental.AUTOTUNE

    
# function to apply augmentations to tf dataset
def apply_augmentations(features, labels):
    """Placeholder augmentation hook for the tf.data pipeline.

    No augmentation is implemented yet, so the inputs are passed through
    unchanged.  Returning the pair (rather than ``None``) matters because
    ``get_dataset`` calls this function from the ``true_fn`` of ``tf.cond``,
    whose two branches have to return the same structure as ``(x, y)``.

    Parameters
    ----------
    features:
        Image data of one example.
    labels:
        Label of the same example.

    Returns
    -------
    tuple
        ``(features, labels)``, unmodified.
    """
    # Sketch of the intended imgaug recipe (not implemented):
    #     iaa.SomeOf(
    #             (0, 3),
    #             [
    #                 iaa.Fliplr(0.5),
    #                 iaa.Flipud(0.5),
    #                 iaa.Noop(),
    #                 iaa.OneOf(
    #                     [
    #                         iaa.Affine(rotate=90),
    #                         iaa.Affine(rotate=180),
    #                         iaa.Affine(rotate=270),
    #                     ]
    #                 ),
    #                 # iaa.GaussianBlur(sigma=(0.0, 0.2)),
    #             ],
    #         )

    return features, labels


def get_dataset(file_pattern,
                n_classes,
                batch_size,
                volume_shape,
                plane,
                block_shape=None,
                n_epochs=None,
                mapping=None,
                augment=False,
                shuffle_buffer_size=None,
                num_parallel_calls=AUTOTUNE):
    """Build a ``tf.data.Dataset`` of single-channel 2D slices from volume TFRecords.

    Records are read with ``nobrainer.dataset.tfrecord_dataset`` (presumably
    one ``(volume, scalar label)`` pair per record -- confirm against
    nobrainer).  Each volume is re-oriented with ``structural_slice`` so that
    its leading axis indexes slices of `plane`, split into individual slices,
    given a trailing channel axis, then shuffled, batched and repeated.

    Parameters
    ----------
    file_pattern: str
        Glob pattern of the TFRecord files to read.
    n_classes:
        Accepted for interface compatibility; not used by this function.
    batch_size: int or None
        Number of slices per batch (an incomplete final batch is dropped).
        ``None`` leaves the dataset unbatched.
    volume_shape: tuple
        Shape of the stored volumes, forwarded to ``tfrecord_dataset``.
    plane: str
        ``'axial'``, ``'coronal'`` or ``'sagittal'``.
    block_shape:
        Not used (the block-splitting step is disabled).
    n_epochs: int or None
        Passed to ``Dataset.repeat``; ``None`` repeats indefinitely.
    mapping:
        Not used.
    augment: bool
        If true, each volume goes through ``apply_augmentations`` with
        probability 0.5.
    shuffle_buffer_size: int or None
        If truthy, ``shuffle=True`` is passed to ``tfrecord_dataset`` and the
        individual slices are shuffled with a buffer of this size.
    num_parallel_calls:
        Parallelism for the record reader and the ``map`` calls.

    Returns
    -------
    tf.data.Dataset
        Dataset of ``(slices, labels)``.

    Raises
    ------
    ValueError
        If no file matches `file_pattern`.
    """
    # Imported locally on purpose: other cells of this notebook rebind the
    # global name `glob` to the function (`from glob import glob`), which would
    # make `glob.glob` fail here depending on cell execution order.
    import glob as _glob

    files = _glob.glob(file_pattern)

    if not files:
        raise ValueError("no files found for pattern '{}'".format(file_pattern))

    # All shards are assumed to share the compression of the first one.
    compressed = _is_gzipped(files[0])
    shuffle = bool(shuffle_buffer_size)

    ds = nobrainer.dataset.tfrecord_dataset(
        file_pattern=file_pattern,
        volume_shape=volume_shape,
        shuffle=shuffle,
        scalar_label=True,
        compressed=compressed,
        num_parallel_calls=num_parallel_calls,
    )

    if augment:
        ds = ds.map(
            lambda x, y: tf.cond(
                # tf.cond documents `pred` as a scalar, so draw a scalar.
                tf.random.uniform(()) > 0.5,
                true_fn=lambda: apply_augmentations(x, y),
                false_fn=lambda: (x, y),
            ),
            num_parallel_calls=num_parallel_calls,
        )

    def _ss(x, y):
        return structural_slice(x, y, plane)

    ds = ds.map(_ss, num_parallel_calls)

    # Split every re-oriented volume (and its repeated label) into one element
    # per slice.
    ds = ds.unbatch()

    # Add a trailing channel dimension to each slice.
    ds = ds.map(lambda x, y: (tf.expand_dims(x, -1), y))

    # Shuffle individual slices *before* batching.  Shuffling after `batch`
    # only permutes whole batches, each of which would still hold adjacent
    # slices of one volume sharing a single label.
    if shuffle_buffer_size:
        ds = ds.shuffle(buffer_size=shuffle_buffer_size)

    if batch_size is not None:
        ds = ds.batch(batch_size=batch_size, drop_remainder=True)

    # Repeat the dataset n_epochs times (indefinitely when n_epochs is None).
    ds = ds.repeat(n_epochs)

    # Prefetch last so that ready-made batches are buffered; this also no
    # longer passes `batch_size=None` to prefetch for unbatched datasets.
    ds = ds.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

    return ds


def structural_slice(x, y, plane):
    """Re-orient a 3D volume so its leading axis indexes slices of `plane`.

    Parameters
    ----------
    x:
        One 3D volume (anything ``tf.convert_to_tensor`` accepts).
    y:
        Label of the volume.
    plane: str
        ``'axial'`` (no transpose), ``'coronal'`` (axes ``[1, 0, 2]``) or
        ``'sagittal'`` (axes ``[2, 0, 1]``).

    Returns
    -------
    tuple
        ``(x, y)`` where ``x`` is the transposed volume and ``y`` has been
        repeated once per slice, i.e. ``x.shape[0]`` times.

    Raises
    ------
    ValueError
        If `plane` is not one of the three supported names.
    """
    # Axis permutation that moves the slicing axis of each plane to the front.
    perms = {
        'axial': None,
        'coronal': [1, 0, 2],
        'sagittal': [2, 0, 1],
    }

    # Validate first so a bad argument fails fast, before any tensor work.
    # The isinstance check also keeps unhashable values away from the lookup.
    if not isinstance(plane, str) or plane not in perms:
        raise ValueError("expected plane to be one of ['axial', 'coronal', 'sagittal']")

    x = tf.convert_to_tensor(x)

    perm = perms[plane]
    if perm is not None:
        x = tf.transpose(x, perm=perm)

    # After the transpose the leading dimension is the number of slices along
    # the requested plane; every slice inherits the label of its volume.
    y = tf.repeat(y, x.shape[0])

    return x, y
    
    

if __name__ == "__main__":
    # Smoke test: build the axial training dataset of fold 1 and show its
    # element spec.
    n_classes = 2
    global_batch_size = 4
    volume_shape = (64, 64, 64)

    dataset_train_axial = get_dataset(
        "tfrecords/tfrecords_fold_1/data-train_*",
        n_classes=n_classes,
        batch_size=global_batch_size,
        volume_shape=volume_shape,
        plane='axial',
        shuffle_buffer_size=3,
    )

    print(dataset_train_axial)
# dataset_train_coronal = get_dataset("tfrecords/tfrecords_fold_1/data-train_*",
#                             n_classes=n_classes,
#                             batch_size=global_batch_size,
#                             volume_shape=volume_shape,
#                             block_shape=block_shape,
#                             plane='coronal',
#                             shuffle_buffer_size=3)

# dataset_train_sagittal = get_dataset("tfrecords/tfrecords_fold_1/data-train_*",
#                             n_classes=n_classes,
#                             batch_size=global_batch_size,
#                             volume_shape=volume_shape,
#                             block_shape=block_shape,
#                             plane='sagittal',
#                             shuffle_buffer_size=3)


In [None]:
import tensorflow as tf

import nobrainer
from nobrainer import dataset, volume

# `os` and `pd` are imported here because this cell runs before the cells that
# import them; on a fresh kernel the cell would otherwise raise NameError.
import os

import pandas as pd

dir_path = os.path.abspath("csv/faced_defaced/train_test_fold_1/csv/")
csv_path = os.path.join(dir_path, "training.csv")

# Parse the CSV once and take both columns from the same frame.
training_df = pd.read_csv(csv_path)
labels = training_df["Y"].values
paths = training_df["X"].values

n_classes = 2
volume_shape = (256, 256, 256)
block_shape = (128, 128, 128)

# Materialise the (path, label) pairs: a bare zip object can be consumed only
# once and prints as an opaque `<zip object ...>`.
training_paths = list(zip(paths, labels))

# Show the size and a small sample instead of the whole list.
print(len(training_paths), training_paths[:5])

In [None]:
import os, sys
sys.path.append("..")
import binascii
from helpers.utils import load_vol, save_vol
from preprocessing.normalization import standardize_volume, normalize_volume
from preprocessing.conform import conform_data
import numpy as np
import nibabel as nb
from glob import glob
from pathlib import Path
from shutil import *
import subprocess


# Source trees holding the original (face) and defaced T1w volumes.
orig_data_face = "/work/01329/poldrack/data/mriqc-net/data/face/T1w"
orig_data_deface = "/work/01329/poldrack/data/mriqc-net/data/defaced"

# Destination trees for the conformed / normalised copies.
save_data_face = "/work/06850/sbansal6/maverick2/mriqc-shared/face"
save_data_deface = "/work/06850/sbansal6/maverick2/mriqc-shared/deface"

# Target grid passed to conform_data as `out_size`.
conform_size = (64, 64, 64)

# Make sure both destination roots exist before anything is written.
for _out_dir in (save_data_face, save_data_deface):
    os.makedirs(_out_dir, exist_ok=True)

def is_gz_file(filepath):
    """Tell whether `filepath` is named ``*.gz`` and really holds gzip data.

    The extension is checked first.  Only for a ``.gz`` name is the file
    opened and its first two bytes compared with the gzip magic number
    ``1f 8b``.

    Parameters
    ----------
    filepath:
        Path of the file to inspect.

    Returns
    -------
    bool
        ``True`` only when the name ends in ``.gz`` and the content starts
        with the gzip magic bytes; ``False`` otherwise (files without a
        ``.gz`` extension are never opened).
    """
    _, extension = os.path.splitext(filepath)
    if extension != '.gz':
        return False

    with open(filepath, 'rb') as handle:
        magic = handle.read(2)

    return binascii.hexlify(magic) == b'1f8b'

def preprocess(pth, conform_size, save_data_path):
    """Conform, intensity-scale and save one volume as a new NIfTI file.

    Parameters
    ----------
    pth: str
        Path of the input volume; its last ``/``-separated component is reused
        as the output file name.
    conform_size: tuple
        Target grid, passed to ``conform_data`` as ``out_size`` and used to
        build the output affine.
    save_data_path: str
        Directory the result is written into.

    Returns
    -------
    str
        Path of the written file.
    """
    # Output keeps the input's file name.
    out_name = pth.rsplit("/", 1)[-1]

    print('Confirmation step')
    conformed = conform_data(pth, out_size=conform_size)

    print("Normalize/Standardize step")
    standardized = standardize_volume(conformed)
    scaled = normalize_volume(standardized)

    save_path = os.path.join(save_data_path, out_name)

    # Identity orientation with a translation of -(n - 1) / 2 per axis, which
    # places the centre of the voxel grid at the world origin.
    affine = np.eye(4)
    affine[:3, 3] = -0.5 * (np.array(conform_size) - 1)

    print("Save new affine")
    nb.Nifti1Image(scaled, affine, None).to_filename(save_path)
    return save_path

        
# print(list.count(deface_orig))

#     if tempfile not in deface_C:
#         print(tempfile)
#         deface_NC.append(tempfile)
        
#     if not is_gz_file(path):
#         tempname = path.split("/")[-1]
#         rename_file = os.path.splitext(tempname)[0]
#         dst = os.path.join(save_data_face, rename_file)
#         print(dst)
#         subprocess.call(['cp', path, dst])
#         print(preprocess(dst, conform_size))
#     else:
#         print(preprocess(path, conform_size))


# Local import: this cell only has `from shutil import *`, which does not bind
# the module name itself.
import shutil

for path in glob(orig_data_deface + "/*/*.nii*"):
    print("Orig Path: ", path)

    # Outputs are grouped by the dataset folder the input came from.
    ds = os.path.basename(os.path.dirname(path))
    ds_save_path = os.path.join(save_data_deface, ds)
    os.makedirs(ds_save_path, exist_ok=True)

    src = path
    if os.path.splitext(path)[1] == '.gz' and not is_gz_file(path):
        # The file is named *.gz but does not contain gzip data.  Copy it
        # under a name without the misleading suffix and process the copy.
        # shutil.copyfile raises on failure, whereas the exit status of the
        # previous `cp` subprocess was silently ignored.  The copy is left in
        # place afterwards, as before.
        rename_file = os.path.splitext(os.path.basename(path))[0]
        dst = os.path.join(save_data_deface, rename_file)
        print(dst)
        shutil.copyfile(path, dst)
        src = dst

    print(preprocess(src, conform_size, save_data_path=ds_save_path))

In [None]:
import nibabel as nb

# Spot check: one file from the defaced output tree, loaded with nibabel below.
in_file = '/work/06850/sbansal6/maverick2/mriqc-shared/deface/ds001912_anat/sub-01_ses-01_run-01_T1w.nii.gz'
# dst_path = '/work/06850/sbansal6/maverick2/mriqc-shared/deface'

# print(is_gz_file(in_file))

# if not is_gz_file(in_file):
#     filename = in_file.split("/")[-1]
#     print(filename)
#     rename_file = os.path.splitext(filename)[0]
#     dst = os.path.join(dst_path, rename_file)
    
#     subprocess.call(['cp', in_file, dst])
    
    
# if isinstance(in_file, (str, Path)):
# Load the image and print the resulting nibabel object.
i_file = nb.load(in_file)
print(i_file)



In [None]:
def _dataset_relative_names(root):
    """Return ``'<dataset>/<file>'`` for every ``*.nii*`` file one level below `root`."""
    return ["/".join(found.split("/")[-2:]) for found in glob(root + "/*/*.nii*")]


# Conformed outputs.
face_C = _dataset_relative_names(save_data_face)
print(len(face_C))

# Original inputs.
face_O = _dataset_relative_names(orig_data_face)
print(len(face_O))

# An original counts as processed when some conformed name is a substring of
# it; list and count the originals with no such match.
count = 0
for f in face_O:
    if not any(fc in f for fc in face_C):
        count += 1
        print(f)
print(count)

In [None]:
import os, sys
sys.path.append("..")
import binascii
from helpers.utils import load_vol, save_vol
from preprocessing.normalization import standardize_volume, normalize_volume
from preprocessing.conform import conform_data
import numpy as np
import nibabel as nb
from glob import glob
from pathlib import Path
from shutil import *
import subprocess

face_path = "/work/06850/sbansal6/maverick2/mriqc-shared/face"
deface_path = "/work/06850/sbansal6/maverick2/mriqc-shared/deface"

# glob returns matches in arbitrary, filesystem-dependent order.  Sort them so
# the path/label lists (and anything split from them) are reproducible across
# runs and machines.
deface_files = sorted(glob(deface_path + "/*/*.nii*"))
face_files = sorted(glob(face_path + "/*/*.nii*"))

# Label convention: 0 = defaced volume, 1 = volume with face.
paths = deface_files + face_files
labels = [0] * len(deface_files) + [1] * len(face_files)

print(len(paths))
print(len(labels))

In [None]:
from operator import itemgetter
from sklearn.model_selection import KFold
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import train_test_split
import pandas as pd

save_path = "./csv/"

os.makedirs(save_path, exist_ok=True)

# Full table of every volume and its label (index column kept, as before).
df = pd.DataFrame({"X": paths, "Y": labels})
df.to_csv(os.path.join(save_path, "all.csv"))

SPLITS = 10
skf = StratifiedKFold(n_splits=SPLITS)

for fold_no, (train_index, test_index) in enumerate(skf.split(paths, labels), start=1):
    out_path = os.path.join(save_path, "train_test_fold_{}".format(fold_no), "csv")
    os.makedirs(out_path, exist_ok=True)

    # Select rows by position from the table.  This keeps X and Y aligned and,
    # unlike itemgetter(*index)(seq), still yields a table when a fold holds a
    # single sample (itemgetter then returns a bare scalar, not a tuple).
    df_train = df.iloc[train_index]
    df_train.to_csv(os.path.join(out_path, "training.csv"), index=False)

    df_validation = df.iloc[test_index]
    df_validation.to_csv(os.path.join(out_path, "validation.csv"), index=False)

In [None]:
import random
import nobrainer

def _shuffled_pairs(csv_path):
    """Read columns ``X`` and ``Y`` of `csv_path` and return them as a shuffled list of pairs."""
    table = pd.read_csv(csv_path)
    pairs = list(zip(table["X"].values, table["Y"].values))
    random.shuffle(pairs)
    return pairs


for fold in range(1, SPLITS+1):

    dir_path = "./csv/train_test_fold_{}/csv/".format(fold)

    tf_records_dir = "./tfrecords/tfrecords_fold_{}/".format(fold)
    os.makedirs(tf_records_dir, exist_ok=True)

    # Load and shuffle both splits before anything is written.
    train_D = _shuffled_pairs(os.path.join(dir_path, "training.csv"))
    valid_D = _shuffled_pairs(os.path.join(dir_path, "validation.csv"))

    # Training shards hold 3 examples each.
    nobrainer.tfrecord.write(
        features_labels=train_D,
        filename_template=os.path.join(tf_records_dir, 'data-train_shard-{shard:03d}.tfrec'),
        examples_per_shard=3)

    # Validation shards hold 1 example each.
    nobrainer.tfrecord.write(
        features_labels=valid_D,
        filename_template=os.path.join(tf_records_dir, 'data-valid_shard-{shard:03d}.tfrec'),
        examples_per_shard=1)

In [None]:
from nobrainer import dataset, volume
from nobrainer.io import _is_gzipped
from nobrainer.volume import to_blocks

import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler, TensorBoard
from tensorflow.keras import metrics
from tensorflow.keras import losses
from tensorflow.keras.models import load_model

import matplotlib.pyplot as plt
import glob

# Notebook-level training configuration; later cells read these names as globals.
# NOTE(review): n_classes is 1 here but 2 in other cells; get_dataset accepts it without using it.
n_classes = 1
batch_size = 4
volume_shape = (64, 64, 64)
block_shape = (32, 32, 32)
n_epochs = 20
num_parallel_calls = 4

# Global batch size = per-replica batch size x strategy.num_replicas_in_sync.
strategy = tf.distribute.MirroredStrategy()
BATCH_SIZE_PER_REPLICA = batch_size
global_batch_size = BATCH_SIZE_PER_REPLICA*strategy.num_replicas_in_sync

# Used as the default `num_parallel_calls` of get_dataset below.
AUTOTUNE = tf.data.experimental.AUTOTUNE

# function to apply augmentations to tf dataset
def apply_augmentations(x, y):
    """Placeholder augmentation hook: returns ``(x, y)`` unchanged.

    No augmentation is implemented yet.  The pair is returned (rather than
    ``None``) because ``get_dataset`` uses this function as the ``true_fn`` of
    ``tf.cond``, whose branches must both return the same structure as
    ``(x, y)``.

    Parameters
    ----------
    x:
        Image data of one example.
    y:
        Label of the same example.

    Returns
    -------
    tuple
        ``(x, y)``, unmodified.
    """
    # Sketch of the intended imgaug recipe (not implemented):
    #     iaa.SomeOf(
    #             (0, 3),
    #             [
    #                 iaa.Fliplr(0.5),
    #                 iaa.Flipud(0.5),
    #                 iaa.Noop(),
    #                 iaa.OneOf(
    #                     [
    #                         iaa.Affine(rotate=90),
    #                         iaa.Affine(rotate=180),
    #                         iaa.Affine(rotate=270),
    #                     ]
    #                 ),
    #                 # iaa.GaussianBlur(sigma=(0.0, 0.2)),
    #             ],
    #         )

    return x, y
    
def structural_slice(x, y, plane):
    """Re-orient a 3D volume so its leading axis indexes slices of `plane`.

    Parameters
    ----------
    x:
        One 3D volume (anything ``tf.convert_to_tensor`` accepts).
    y:
        Label of the volume.
    plane: str
        ``'axial'`` (no transpose), ``'coronal'`` (axes ``[1, 0, 2]``) or
        ``'sagittal'`` (axes ``[2, 0, 1]``).

    Returns
    -------
    tuple
        ``(x, y)`` with ``x`` transposed and ``y`` repeated once per slice,
        i.e. ``x.shape[0]`` times.

    Raises
    ------
    ValueError
        If `plane` is not one of the three supported names.
    """
    # Axis permutation that brings the slicing axis of each plane to the front.
    perms = {
        'axial': None,
        'coronal': [1, 0, 2],
        'sagittal': [2, 0, 1],
    }

    # Reject a bad `plane` up front, before any tensor work is done.  The
    # isinstance check also keeps unhashable values away from the dict lookup.
    if not isinstance(plane, str) or plane not in perms:
        raise ValueError("expected plane to be one of ['axial', 'coronal', 'sagittal']")

    x = tf.convert_to_tensor(x)

    perm = perms[plane]
    if perm is not None:
        x = tf.transpose(x, perm=perm)

    # The leading dimension now counts the slices along the requested plane;
    # each slice inherits the label of its volume.
    y = tf.repeat(y, x.shape[0])

    return x, y
    
    

def get_dataset(file_pattern,
                n_classes,
                batch_size,
                volume_shape,
                plane,
                block_shape=None,
                n_epochs=None,
                mapping=None,
                augment=False,
                shuffle_buffer_size=None,
                num_parallel_calls=AUTOTUNE):
    """Build a ``tf.data.Dataset`` of single-channel 2D slices from volume TFRecords.

    Records are read with ``dataset.tfrecord_dataset`` from nobrainer
    (presumably one ``(volume, scalar label)`` pair per record -- confirm
    against nobrainer).  Each volume is re-oriented with ``structural_slice``,
    split into individual slices, given a trailing channel axis, and then
    optionally shuffled, batched and repeated.

    Parameters
    ----------
    file_pattern: str
        Glob pattern of the TFRecord files to read.
    n_classes:
        Accepted for interface compatibility; not used by this function.
    batch_size: int or None
        Number of slices per batch (an incomplete final batch is dropped).
        ``None`` leaves the dataset unbatched.
    volume_shape: tuple
        Shape of the stored volumes, forwarded to ``tfrecord_dataset``.
    plane: str
        ``'axial'``, ``'coronal'`` or ``'sagittal'``.
    block_shape:
        Not used (the block-splitting step is disabled).
    n_epochs: int or None
        If given, the dataset is repeated that many times.  ``None`` (the
        default) leaves a single pass over the data, as before.
    mapping:
        Not used.
    augment: bool
        If true, each volume goes through ``apply_augmentations`` with
        probability 0.5.
    shuffle_buffer_size: int or None
        If truthy, ``shuffle=True`` is passed to ``tfrecord_dataset`` and the
        individual slices are shuffled with a buffer of this size.
    num_parallel_calls:
        Parallelism for the record reader and the ``map`` calls.

    Returns
    -------
    tf.data.Dataset
        Dataset of ``(slices, labels)``.

    Raises
    ------
    ValueError
        If no file matches `file_pattern`.
    """
    # Imported locally on purpose: other cells of this notebook rebind the
    # global name `glob` to the function (`from glob import glob`), which would
    # make `glob.glob` fail here depending on cell execution order.
    import glob as _glob

    files = _glob.glob(file_pattern)

    if not files:
        raise ValueError("no files found for pattern '{}'".format(file_pattern))

    # All shards are assumed to share the compression of the first one.
    compressed = _is_gzipped(files[0])
    shuffle = bool(shuffle_buffer_size)

    ds = dataset.tfrecord_dataset(
        file_pattern=file_pattern,
        volume_shape=volume_shape,
        shuffle=shuffle,
        scalar_label=True,
        compressed=compressed,
        num_parallel_calls=num_parallel_calls,
    )

    if augment:
        ds = ds.map(
            lambda x, y: tf.cond(
                # tf.cond documents `pred` as a scalar, so draw a scalar.
                tf.random.uniform(()) > 0.5,
                true_fn=lambda: apply_augmentations(x, y),
                false_fn=lambda: (x, y),
            ),
            num_parallel_calls=num_parallel_calls,
        )

    def _sp(x, y):
        return structural_slice(x, y, plane)

    ds = ds.map(_sp, num_parallel_calls)

    # Split every re-oriented volume (and its repeated label) into one element
    # per slice.
    ds = ds.unbatch()

    # Add a trailing channel dimension to each slice.
    ds = ds.map(lambda x, y: (tf.expand_dims(x, -1), y))

    # Shuffle individual slices before batching; without this every batch
    # holds adjacent slices of one volume, all sharing a single label.
    if shuffle_buffer_size:
        ds = ds.shuffle(buffer_size=shuffle_buffer_size)

    if batch_size is not None:
        ds = ds.batch(batch_size=batch_size, drop_remainder=True)

    # Honour `n_epochs` when the caller asks for it.  The default keeps the
    # previous single-pass behaviour instead of repeating forever.
    if n_epochs is not None:
        ds = ds.repeat(n_epochs)

    # Prefetch last so that ready-made batches are buffered; this also avoids
    # passing `batch_size=None` to prefetch for unbatched datasets.
    ds = ds.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

    return ds


# dataset_train_axial = get_dataset("tfrecords/tfrecords_fold_1/data-train_*",
#                             n_classes=n_classes,
#                             batch_size=global_batch_size,
#                             volume_shape=volume_shape,
#                             block_shape=block_shape,
#                             plane='axial',
#                             shuffle_buffer_size=3)

# dataset_train_coronal = get_dataset("tfrecords/tfrecords_fold_1/data-train_*",
#                             n_classes=n_classes,
#                             batch_size=global_batch_size,
#                             volume_shape=volume_shape,
#                             block_shape=block_shape,
#                             plane='coronal',
#                             shuffle_buffer_size=3)

# dataset_train_sagittal = get_dataset("tfrecords/tfrecords_fold_1/data-train_*",
#                             n_classes=n_classes,
#                             batch_size=global_batch_size,
#                             volume_shape=volume_shape,
#                             block_shape=block_shape,
#                             plane='sagittal',
#                             shuffle_buffer_size=3)


In [None]:
# Only `dataset_train_axial` is ever assigned in this notebook; the coronal and
# sagittal datasets exist only in commented-out code.  Printing all three
# unconditionally therefore raises NameError on a fresh kernel, so report the
# missing ones instead.
for _name in ("dataset_train_axial", "dataset_train_coronal", "dataset_train_sagittal"):
    if _name in globals():
        print(globals()[_name])
    else:
        print("{} is not defined (build it with get_dataset first)".format(_name))

In [None]:
import glob 

# TFRecord files of fold 1 matching the train / valid prefixes.
# NOTE(review): these are shard files, so len(tpaths) counts files, not volumes
# (training shards are written with examples_per_shard=3 earlier in this notebook).
tpaths = glob.glob('tfrecords/tfrecords_fold_1/data-train_*')
vpaths = glob.glob('tfrecords/tfrecords_fold_1/data-valid_*')

In [None]:
import sys, os
sys.path.append("..")
from models import modelN

from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler, TensorBoard
from tensorflow.keras import metrics
from tensorflow.keras import losses
from tensorflow.keras.models import load_model


def train(
    image_size=(64, 64),
    dropout=0.4,
    batch_size=8,
    n_classes=2,
    n_epochs=30
):
    """Train one 2D sub-model per anatomical plane on fold 1, then build the combined classifier.

    For each of the axial, coronal and sagittal planes a ``modelN.Submodel`` is
    created and compiled under a ``MirroredStrategy``, trained on the fold-1
    TFRecords, and its best weights (lowest ``val_loss``) are saved to
    ``./model_save_dir/weights/<plane>/best-wts.h5``.  TensorBoard logs go to
    ``./model_save_dir/tb_logs/<plane>``.

    Relies on notebook-level names defined in other cells: ``tf``,
    ``get_dataset``, ``volume_shape`` and ``block_shape``.

    Parameters
    ----------
    image_size: tuple
        Passed to the models as ``input_shape``.
    dropout: float
        Dropout rate passed to the models.
    batch_size: int
        Per-replica batch size; multiplied by the number of replicas to get
        the global batch size.
    n_classes: int
        Forwarded to ``get_dataset``.
    n_epochs: int
        Number of training epochs per sub-model.

    Returns
    -------
    The ``modelN.CombinedClassifier`` built from the saved sub-model weights.
    It is only constructed here, not compiled or trained.
    """
    planes = ['axial', 'coronal', 'sagittal']

    strategy = tf.distribute.MirroredStrategy()
    global_batch_size = batch_size * strategy.num_replicas_in_sync

    model_save_path = './model_save_dir'
    cp_save_path = os.path.join(model_save_path, 'weights')
    logdir_path = os.path.join(model_save_path, "tb_logs")
    # Creates model_save_path as well, since it is the parent directory.
    os.makedirs(logdir_path, exist_ok=True)

    for plane in planes:

        logdir = os.path.join(logdir_path, plane)
        os.makedirs(logdir, exist_ok=True)

        tbCallback = TensorBoard(
            log_dir=logdir,
            histogram_freq=0,
            write_graph=True,
            write_images=False,
        )

        os.makedirs(os.path.join(cp_save_path, plane), exist_ok=True)

        model_checkpoint = ModelCheckpoint(
            os.path.join(cp_save_path, plane, "best-wts.h5"),
            monitor="val_loss",
            save_best_only=True,
            save_weights_only=True,
            mode="min",
        )

        with strategy.scope():

            lr = 1e-4
            model = modelN.Submodel(
                input_shape=image_size,
                dropout=dropout,
                name=plane,
                include_top=True,
                weights=None,
            )

            print("Submodel: ", plane)
            # summary() prints by itself; wrapping it in print() added a stray "None".
            model.summary()

            METRICS = [
                metrics.TruePositives(name='tp'),
                metrics.FalsePositives(name='fp'),
                metrics.TrueNegatives(name='tn'),
                metrics.FalseNegatives(name='fn'),
                metrics.BinaryAccuracy(name='accuracy'),
                metrics.Precision(name='precision'),
                metrics.Recall(name='recall'),
                metrics.AUC(name='auc'),
            ]

            model.compile(
                loss=tf.keras.losses.binary_crossentropy,
                # `lr` used to be assigned but never reached the optimizer
                # (the string "adam" uses Keras' default rate instead).
                optimizer=Adam(learning_rate=lr),
                metrics=METRICS
            )

        print("GLOBAL BATCH SIZE: ", global_batch_size)
        # n_epochs=1 asks for exactly one pass over the data, so that Keras can
        # detect the end of each epoch itself (see the fit call below).
        dataset_train = get_dataset("tfrecords/tfrecords_fold_1/data-train_*",
                                    n_classes=n_classes,
                                    batch_size=global_batch_size,
                                    volume_shape=volume_shape,
                                    block_shape=block_shape,
                                    plane=plane,
                                    n_epochs=1,
                                    shuffle_buffer_size=global_batch_size)

        dataset_valid = get_dataset("tfrecords/tfrecords_fold_1/data-valid_*",
                                    n_classes=n_classes,
                                    batch_size=global_batch_size,
                                    volume_shape=volume_shape,
                                    block_shape=block_shape,
                                    plane=plane,
                                    n_epochs=1,
                                    shuffle_buffer_size=global_batch_size)

        # No steps_per_epoch / validation_steps: the previous values were
        # derived from the number of TFRecord shard files and from 3D block
        # geometry, whereas the datasets yield 2D slices.  With finite
        # datasets Keras simply runs every epoch until the data is exhausted.
        model.fit(
            dataset_train,
            epochs=n_epochs,
            validation_data=dataset_valid,
            callbacks=[tbCallback,
                       model_checkpoint]
        )

        # Free the sub-model and its graph state before building the next one.
        del model
        K.clear_session()

    # NOTE(review): a learning rate of 5e-5 was earmarked for the combined
    # model, but no compile/fit step for it exists yet.
    model = modelN.CombinedClassifier(input_shape=image_size, dropout=dropout, wts_root=cp_save_path)

    return model
    
# Run training with the default arguments declared in train()'s signature.
train()