In [1]:
import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
from tqdm import trange
from sklearn.model_selection import train_test_split

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, metrics
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.callbacks import ReduceLROnPlateau, ModelCheckpoint
import tensorflow.keras.backend as K

import ants

#from preprocess import *
from model import *
from loss import *
#from train import *
#from inference import *

# Set this environment variable to allow ModelCheckpoint to work
# (disables HDF5 file locking, which can fail on some shared file systems)
os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE'

# Set this environment variable to only use the first available GPU
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# For tensorflow 2.x.x allow memory growth on GPU
###################################
gpus = tf.config.list_physical_devices('GPU')
# Iterate instead of indexing gpus[0] so the cell also runs on a CPU-only
# machine (an empty device list previously raised IndexError)
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
###################################

In [4]:
def get_paths_csv(base_dir, name_dict, output_csv):
    """Build a CSV that maps each patient folder to its image file paths.

    Every entry of ``base_dir`` is treated as one patient id. All files below
    that entry are scanned and a file is assigned to an image type when any of
    the substrings listed for that type occurs in its full path. If several
    files match the same type, the last one visited wins.

    Parameters
    ----------
    base_dir : str
        Directory containing one sub-directory per patient.
    name_dict : dict
        Maps column name (image type) -> list of substrings identifying it.
    output_csv : str
        Destination path of the CSV (columns: 'id' + keys of ``name_dict``).

    Returns
    -------
    int
        0 on success, 1 if any error occurred (the error is printed).
    """

    def get_files(path):
        # Recursively collect every file path below `path`
        files_list = list()
        for root, _, files in os.walk(path, topdown = False):
            for name in files:
                files_list.append(os.path.join(root, name))
        return files_list

    try:
        # Use the `name_dict` argument (not a notebook global) for the columns
        cols = ['id'] + list(name_dict.keys())
        rows = list()

        for patient_id in os.listdir(base_dir):
            # Fresh row per patient so a missing image type stays empty
            # instead of inheriting the previous patient's path
            row_dict = dict.fromkeys(cols)
            row_dict['id'] = patient_id

            for file in get_files(os.path.join(base_dir, patient_id)):
                for img_type, img_strings in name_dict.items():
                    if any(img_string in file for img_string in img_strings):
                        row_dict[img_type] = file

            rows.append(row_dict)

        # Build the frame once; DataFrame.append was removed in pandas 2.0
        df = pd.DataFrame(rows, columns = cols)
        df.to_csv(output_csv, index = False)
    except Exception as err:
        # Keep the status-code contract but report what actually went wrong
        print('ERROR! Returning non-zero exit status. ({}: {})'.format(type(err).__name__, err))
        return 1

    return 0

    ################# End of function #################

In [5]:
# Where the IvyGap patient folders live and where the path table is written
base_dir = '/rsrch1/ip/aecelaya/data/ivygap/IvyGap/'
output_csv = 'ivygap_paths.csv'

# Column name -> substrings that identify the corresponding file
names_dict = dict(mask = ['UPenn', 'Segm'],
                  t1 = ['t1_'],
                  t2 = ['t2_'],
                  tc = ['t1gd_'],
                  fl = ['flair_'])

get_paths_csv(base_dir = base_dir, name_dict = names_dict, output_csv = output_csv)

0

In [9]:
# Load the path table and hold out 20% of the patients for validation.
# The fixed random_state keeps the split reproducible.
data = pd.read_csv('ivygap_paths.csv')
train, val = (split.reset_index(drop = True)
              for split in train_test_split(data, test_size = 0.2, random_state = 42))

In [7]:
# def make_tfrecords(df, filename):
#     def _float_feature(value):
#         return tf.train.Feature(float_list=tf.train.FloatList(value=value))

#     # open the file
#     writer = tf.io.TFRecordWriter(filename)

#     for j in trange(len(df)):
#         patient = df.iloc[j].to_dict()
#         mask_info = ants.image_header_info(patient['mask'])
#         dims = mask_info['dimensions']
#         dims = tuple(int(d) for d in dims)
#         mask_labels = [0, 1, 2, 4]
#         patch_size = 64
#         radius = patch_size // 2

#         mask = ants.image_read(patient['mask'])
#         nz = mask.nonzero()
#         mask = mask.numpy()
#         mask_numpy = np.empty((*dims, len(mask_labels)))
#         for i in range(len(mask_labels)):
#             mask_numpy[..., i] = mask == mask_labels[i]

#         images = list(patient.values())[2:len(patient)]
#         images_numpy = np.empty((*dims, len(images)))
#         for i in range(len(images)):
#             ants_image = ants.image_read(images[i])
#             ants_image = ants_image.numpy()
#             ants_image_nz = ants_image[ants_image != 0]
#             mean = np.mean(ants_image_nz)
#             std = np.std(ants_image_nz)
#             ants_image = (ants_image - mean) / std
#             ants_image = np.multiply(mask, ants_image)
#             images_numpy[..., i] = ants_image

#         idx = np.arange(0, len(nz[0]))

#         num_points = 20
#         cnt = 0
#         while cnt < num_points:
#             idx_sample = np.random.choice(idx, size = 1)[0]
#             point = (nz[0][idx_sample], nz[1][idx_sample], nz[2][idx_sample])
#             point_upper = [point[i] + radius in range(0, dims[i] + 1) for i in range(len(point))]
#             point_lower = [point[i] - radius in range(0, dims[i] + 1) for i in range(len(point))]
#             if False in point_upper or False in point_lower:
#                 continue
#             else:
#                 image_patch = images_numpy[point[0] - radius:point[0] + radius, 
#                                            point[1] - radius:point[1] + radius, 
#                                            point[2] - radius:point[2] + radius, 
#                                            ...]
#                 mask_patch = mask_numpy[point[0] - radius:point[0] + radius, 
#                                         point[1] - radius:point[1] + radius, 
#                                         point[2] - radius:point[2] + radius, 
#                                         ...]

#                 # Create a feature
#                 feature = {'image': _float_feature(image_patch.ravel()),
#                             'mask': _float_feature(mask_patch.ravel())}

#                 # Create an example protocol buffer
#                 example = tf.train.Example(features=tf.train.Features(feature=feature))

#                 # Serialize to string and write on the file
#                 writer.write(example.SerializeToString())

#                 cnt += 1

#     writer.close()

In [11]:
def get_points(mask, num_points):
    """Randomly sample voxel coordinates from the non-zero region of a mask.

    Sampling uses the global NumPy random state and is done with replacement,
    so the same voxel may be returned more than once.

    Parameters
    ----------
    mask : array-like
        Volume whose non-zero entries are the candidate points. Only the
        first three axes are used to build each coordinate.
    num_points : int
        Number of points to draw.

    Returns
    -------
    list of tuple
        ``num_points`` coordinate tuples ``(x, y, z)``.
    """
    # Per-axis index arrays of every non-zero voxel
    coords = np.nonzero(mask)

    # Draw positions into those index arrays
    candidates = np.arange(0, len(coords[0]))
    chosen = np.random.choice(candidates, size = num_points)

    # Assemble one (x, y, z) tuple per drawn position
    return [(coords[0][c], coords[1][c], coords[2][c]) for c in chosen]

def normalize(image, brainmask):
    """Z-score an image using the statistics of its non-zero voxels.

    The mean and standard deviation are computed over voxels where
    ``image != 0``; the standardized image is then multiplied by
    ``brainmask`` so that voxels outside the mask become zero.

    Degenerate inputs are handled explicitly so no NaN/inf values leak into
    the result: an image without any non-zero voxel yields all zeros, and an
    image whose non-zero voxels are constant (std == 0) is only mean-centered.

    Parameters
    ----------
    image : numpy.ndarray
        Image intensities.
    brainmask : numpy.ndarray
        Mask multiplied element-wise with the normalized image.

    Returns
    -------
    numpy.ndarray
        Normalized, masked image.
    """
    nonzeros = image[image != 0]

    # Nothing to estimate statistics from: return a blank (masked) image
    if nonzeros.size == 0:
        return np.multiply(brainmask, np.zeros_like(image, dtype = float))

    mean = np.mean(nonzeros)
    std = np.std(nonzeros)

    # Constant foreground: avoid dividing by zero, just mean-center
    if std == 0:
        std = 1.0

    image = (image - mean) / std
    image = np.multiply(brainmask, image)
    return image

In [9]:
def make_tfrecords(df, filename):
    """Write randomly sampled 64^3 training patches for every patient to a TFRecord file.

    For each row of ``df`` the segmentation mask is one-hot encoded over the
    labels [0, 1, 2, 4], every image is normalized with ``normalize`` and all
    volumes are zero-padded by half a patch on each side. Twenty patch centers
    are then drawn from the non-zero voxels of the padded mask and the
    corresponding image/mask patches are serialized as flat float lists under
    the keys 'image' and 'mask'.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per patient with a 'mask' column; the image paths are taken
        positionally from the third column onward.
    filename : str
        Path of the TFRecord file to create.
    """
    def _float_feature(value):
        return tf.train.Feature(float_list = tf.train.FloatList(value = value))

    # Loop-invariant settings
    mask_labels = [0, 1, 2, 4]
    patch_size = 64
    radius = patch_size // 2
    num_points = 20

    # The context manager closes the file even if a patient fails midway
    with tf.io.TFRecordWriter(filename) as writer:
        for i in trange(len(df)):
            patient = df.iloc[i].to_dict()

            # Pad the label map so that any non-zero voxel can be a patch center
            mask_numpy = ants.image_read(patient['mask']).numpy()
            mask_numpy = np.pad(mask_numpy, radius)
            padded_shape = mask_numpy.shape

            # One-hot encode the padded label map (padding counts as label 0)
            mask = np.empty((*padded_shape, len(mask_labels)))
            for j, label in enumerate(mask_labels):
                mask[..., j] = mask_numpy == label

            # Skip the first two columns; the remaining ones hold image paths
            image_list = list(patient.values())[2:len(patient)]
            image = np.empty((*padded_shape, len(image_list)))
            for j, image_path in enumerate(image_list):
                image_ants = ants.image_read(image_path)
                brainmask = ants.get_mask(image_ants, cleanup = 0).numpy()
                image_numpy = normalize(image_ants.numpy(), brainmask)

                # Pad each image with radius to ensure each point in mask can be picked
                image[..., j] = np.pad(image_numpy, radius)

            # Points are in padded coordinates, so the slices below never
            # run off the array
            for point in get_points(mask_numpy, num_points):
                window = tuple(slice(c - radius, c + radius) for c in point)
                image_patch = image[window]
                mask_patch = mask[window]

                # Create a feature
                feature = {'image': _float_feature(image_patch.ravel()),
                           'mask': _float_feature(mask_patch.ravel())}

                # Create an example protocol buffer
                example = tf.train.Example(features = tf.train.Features(feature = feature))

                # Serialize to string and write on the file
                writer.write(example.SerializeToString())

In [10]:
# Serialize the sampled patches of each split to its own TFRecord file
make_tfrecords(df = train, filename = 'train.tfrecords')
make_tfrecords(df = val, filename = 'val.tfrecords')

100%|██████████| 27/27 [04:30<00:00, 10.00s/it]
100%|██████████| 7/7 [01:03<00:00,  9.06s/it]


In [2]:
def decode(serialized_example):
    """Parse one serialized TFRecord example into an ``(image, mask)`` pair.

    Both features are read as fixed-length float32 tensors of shape
    (64, 64, 64, 4), so no casting is required afterwards.
    """
    # NOTE: this shape must match the dimensions of the stored patches
    patch_shape = [64, 64, 64, 4]
    feature_spec = {'image': tf.io.FixedLenFeature(patch_shape, tf.float32),
                    'mask': tf.io.FixedLenFeature(patch_shape, tf.float32)}

    parsed = tf.io.parse_single_example(serialized_example, features = feature_spec)
    return parsed['image'], parsed['mask']

In [14]:
def run_model(batch_size, pocket):
    """Train the DenseNet segmentation model on the patch TFRecords.

    Reads 'train.tfrecords' and 'val.tfrecords', compiles the model with the
    weighted L2 Dice loss, and trains for 35 epochs while halving the learning
    rate on validation-loss plateaus and checkpointing the best model to
    'pocket.h5' or 'full.h5'.

    Parameters
    ----------
    batch_size : int
        Number of patches per batch.
    pocket : bool
        Selects the pocket or the full architecture; also selects the log
        message and the checkpoint file name.
    """

    def get_dataset(tfrecords_file, batch_size):
        # Decode -> shuffle -> batch -> repeat forever -> prefetch
        ds = tf.data.TFRecordDataset(tfrecords_file).map(decode, num_parallel_calls = tf.data.AUTOTUNE)
        ds = ds.shuffle(buffer_size = 25)
        ds = ds.batch(batch_size = batch_size, drop_remainder = True)
        ds = ds.repeat()
        ds = ds.prefetch(tf.data.AUTOTUNE)
        return ds

    train_ds = get_dataset('train.tfrecords', batch_size)
    val_ds = get_dataset('val.tfrecords', batch_size)

    # Log message and checkpoint name both follow the chosen architecture
    if pocket:
        print('Running pocket u-net with batch size ' + str(int(batch_size)))
        modelName = 'pocket.h5'
    else:
        print('Running full u-net with batch size ' + str(int(batch_size)))
        modelName = 'full.h5'

    # Create model
    # NOTE(review): the last positional argument was hard-coded to True, so
    # `pocket = False` trained the same network but saved it as 'full.h5'.
    # It is presumably DenseNet's pocket flag - confirm against model.py.
    model = DenseNet((64, 64, 64, 4), 16, 4, 4, pocket).get_model()

    # Compile model with Dice loss
    model.compile(optimizer = 'adam', loss = [dice_loss_l2_weighted], metrics = [dice_loss_l2])

    # Halve the learning rate if the validation loss does not improve for 5 epochs
    reduceLr = ReduceLROnPlateau(monitor = 'val_loss',
                                 mode = 'min',
                                 factor = 0.5,
                                 patience = 5,
                                 min_lr = 0.000001,
                                 verbose = 1)

    # Keep only the checkpoint with the lowest validation loss
    saveBestModel = ModelCheckpoint(filepath = modelName,
                                    monitor = 'val_loss',
                                    verbose = 1,
                                    save_best_only = True)

    # Train model
    # The datasets repeat forever, so the epoch length is set explicitly.
    # NOTE(review): 135 / 35 steps presumably equal (patients * 20 patches) /
    # batch size for a batch size of 4 - adjust if either changes.
    model.fit(train_ds,
              epochs = 35,
              steps_per_epoch = 135,
              validation_data = val_ds,
              validation_steps = 35,
              callbacks = [reduceLr, saveBestModel],
              verbose = 1)

    ##### END OF FUNCTION #####

In [15]:
# Train with the pocket flag set and 4 patches per batch
run_model(batch_size = 4, pocket = True)

Running pocket u-net with batch size 4
Epoch 1/35

Epoch 00001: val_loss improved from inf to 0.26263, saving model to pocket.h5
Epoch 2/35

Epoch 00002: val_loss did not improve from 0.26263
Epoch 3/35

Epoch 00003: val_loss did not improve from 0.26263
Epoch 4/35

Epoch 00004: val_loss improved from 0.26263 to 0.17074, saving model to pocket.h5
Epoch 5/35

Epoch 00005: val_loss did not improve from 0.17074
Epoch 6/35

Epoch 00006: val_loss did not improve from 0.17074
Epoch 7/35

Epoch 00007: val_loss improved from 0.17074 to 0.15623, saving model to pocket.h5
Epoch 8/35

Epoch 00008: val_loss did not improve from 0.15623
Epoch 9/35

Epoch 00009: val_loss did not improve from 0.15623
Epoch 10/35

Epoch 00010: val_loss improved from 0.15623 to 0.15350, saving model to pocket.h5
Epoch 11/35

Epoch 00011: val_loss improved from 0.15350 to 0.15153, saving model to pocket.h5
Epoch 12/35

Epoch 00012: val_loss improved from 0.15153 to 0.14414, saving model to pocket.h5
Epoch 13/35

Epoch 0

In [5]:
# Reload the saved checkpoint; the custom Dice losses must be passed via
# custom_objects so Keras can resolve them by name
model = load_model('pocket.h5', custom_objects = {'dice_loss_l2': dice_loss_l2, 'dice_loss_l2_weighted': dice_loss_l2_weighted})

In [6]:
def get_strides(dims, patch_size):
    """Choose a sliding-window stride per axis so patches tile the volume exactly.

    For each axis the stride is the largest divisor of ``dim - patch_size``
    that is smaller than ``patch_size``. Stepping from 0 with that stride
    therefore ends with a patch whose far edge is exactly ``dim``.

    Parameters
    ----------
    dims : iterable of int
        Size of the volume along each axis.
    patch_size : int
        Edge length of the (cubic) patch; must be positive.

    Returns
    -------
    list of int
        One stride per entry of ``dims``. When an axis is exactly
        ``patch_size`` long a single patch covers it and ``patch_size`` is
        returned as its stride.

    Raises
    ------
    ValueError
        If ``patch_size`` is not positive or an axis is shorter than the patch.
    """
    if patch_size <= 0:
        raise ValueError('patch_size must be positive, got {}'.format(patch_size))

    strides = list()
    for dim in dims:
        span = int(dim) - patch_size
        if span < 0:
            raise ValueError('dimension {} is smaller than patch size {}'.format(dim, patch_size))

        if span == 0:
            # A single patch covers this axis; any positive stride works
            strides.append(patch_size)
            continue

        # Largest divisor of `span` below the patch size. 1 always divides,
        # so the default is only reached when patch_size == 1.
        upper = min(span, patch_size - 1)
        divisors = (factor for factor in range(1, upper + 1) if span % factor == 0)
        strides.append(max(divisors, default = 1))

    return strides

In [12]:
def inference(model, df):
    """Run sliding-window segmentation for every patient and save the results.

    For each row of ``df`` the images are normalized with ``normalize`` and
    stacked along the last axis, the model is applied to 64^3 patches placed
    with the strides from ``get_strides``, and the arg-max class indices are
    mapped back to the labels [0, 1, 2, 4]. The label map is written to
    '<id>_pred_W.nii.gz' using the mask image as the reference for
    ``new_image_like``. Where patches overlap, the patch visited last
    overwrites earlier predictions.

    Parameters
    ----------
    model : keras model
        Trained network; called through ``model.predict`` on one patch at a time.
    df : pandas.DataFrame
        One row per patient with 'id' and 'mask' columns; the image paths are
        taken positionally from the third column onward.

    Returns
    -------
    The prediction image of the last patient, or None if ``df`` is empty.
    """
    mask_labels = [0, 1, 2, 4]
    label_lookup = np.array(mask_labels)
    patch_size = 64

    # Stays None when df has no rows (previously an UnboundLocalError)
    pred_image = None

    for row in trange(len(df)):
        patient = df.iloc[row].to_dict()
        mask_info = ants.image_header_info(patient['mask'])
        dims = tuple(int(d) for d in mask_info['dimensions'])

        # Skip the first two columns; the remaining ones hold image paths
        image_list = list(patient.values())[2:len(patient)]
        num_channels = len(image_list)
        image = np.empty((*dims, num_channels))
        for channel, image_path in enumerate(image_list):
            image_ants = ants.image_read(image_path)
            brainmask = ants.get_mask(image_ants, cleanup = 0).numpy()
            image[..., channel] = normalize(image_ants.numpy(), brainmask)

        strides = get_strides(dims, patch_size)
        pred = np.empty((*dims, len(mask_labels)))
        for x in range(0, dims[0] - patch_size + 1, strides[0]):
            for y in range(0, dims[1] - patch_size + 1, strides[1]):
                for z in range(0, dims[2] - patch_size + 1, strides[2]):
                    window = (slice(x, x + patch_size),
                              slice(y, y + patch_size),
                              slice(z, z + patch_size))

                    # Add the batch axis; the channel count comes from the
                    # input images, not from the number of output labels
                    patch = image[window].reshape((1, patch_size, patch_size, patch_size, num_channels))
                    pred[window] = model.predict(patch)[0]

        # Class index -> original label value (index 3 maps to label 4)
        labels = label_lookup[pred.argmax(axis = -1)]

        original = ants.image_read(patient['mask'])
        pred_image = original.new_image_like(data = labels.astype(np.float32))
        pred_name = patient['id'] + '_pred_W.nii.gz'
        ants.image_write(pred_image, pred_name)

    return pred_image

In [13]:
# Segment every validation patient; `pred` holds what inference() returns
pred = inference(model, val)

100%|██████████| 7/7 [01:34<00:00, 13.56s/it]


In [18]:
# Display the validation split (patient ids and file paths)
val

Unnamed: 0,id,mask,t1,t2,tc,fl
0,W8,/rsrch1/ip/aecelaya/data/ivygap/IvyGap/W8/W8_1...,/rsrch1/ip/aecelaya/data/ivygap/IvyGap/W8/W8_1...,/rsrch1/ip/aecelaya/data/ivygap/IvyGap/W8/W8_1...,/rsrch1/ip/aecelaya/data/ivygap/IvyGap/W8/W8_1...,/rsrch1/ip/aecelaya/data/ivygap/IvyGap/W8/W8_1...
1,W6,/rsrch1/ip/aecelaya/data/ivygap/IvyGap/W6/W6_1...,/rsrch1/ip/aecelaya/data/ivygap/IvyGap/W6/W6_1...,/rsrch1/ip/aecelaya/data/ivygap/IvyGap/W6/W6_1...,/rsrch1/ip/aecelaya/data/ivygap/IvyGap/W6/W6_1...,/rsrch1/ip/aecelaya/data/ivygap/IvyGap/W6/W6_1...
2,W2,/rsrch1/ip/aecelaya/data/ivygap/IvyGap/W2/W2_1...,/rsrch1/ip/aecelaya/data/ivygap/IvyGap/W2/W2_1...,/rsrch1/ip/aecelaya/data/ivygap/IvyGap/W2/W2_1...,/rsrch1/ip/aecelaya/data/ivygap/IvyGap/W2/W2_1...,/rsrch1/ip/aecelaya/data/ivygap/IvyGap/W2/W2_1...
3,W18,/rsrch1/ip/aecelaya/data/ivygap/IvyGap/W18/W18...,/rsrch1/ip/aecelaya/data/ivygap/IvyGap/W18/W18...,/rsrch1/ip/aecelaya/data/ivygap/IvyGap/W18/W18...,/rsrch1/ip/aecelaya/data/ivygap/IvyGap/W18/W18...,/rsrch1/ip/aecelaya/data/ivygap/IvyGap/W18/W18...
4,W7,/rsrch1/ip/aecelaya/data/ivygap/IvyGap/W7/W7_1...,/rsrch1/ip/aecelaya/data/ivygap/IvyGap/W7/W7_1...,/rsrch1/ip/aecelaya/data/ivygap/IvyGap/W7/W7_1...,/rsrch1/ip/aecelaya/data/ivygap/IvyGap/W7/W7_1...,/rsrch1/ip/aecelaya/data/ivygap/IvyGap/W7/W7_1...
5,W55,/rsrch1/ip/aecelaya/data/ivygap/IvyGap/W55/W55...,/rsrch1/ip/aecelaya/data/ivygap/IvyGap/W55/W55...,/rsrch1/ip/aecelaya/data/ivygap/IvyGap/W55/W55...,/rsrch1/ip/aecelaya/data/ivygap/IvyGap/W55/W55...,/rsrch1/ip/aecelaya/data/ivygap/IvyGap/W55/W55...
6,W30,/rsrch1/ip/aecelaya/data/ivygap/IvyGap/W30/W30...,/rsrch1/ip/aecelaya/data/ivygap/IvyGap/W30/W30...,/rsrch1/ip/aecelaya/data/ivygap/IvyGap/W30/W30...,/rsrch1/ip/aecelaya/data/ivygap/IvyGap/W30/W30...,/rsrch1/ip/aecelaya/data/ivygap/IvyGap/W30/W30...
