In [1]:
import os
import random

import numpy as np
import nibabel as nib
import tensorflow as tf
from sklearn.model_selection import train_test_split

from preprocess.get_subvolume import get_training_sub_volumes
# NOTE(review): presumably the source of unet_3D, dice_coeff and ModelCheckpoint
# used below — prefer explicit imports over the wildcard.
from unet3d import *
# `makedirs()` was a SyntaxError; the notebook calls `make_dirs()` further down.
from utils import make_dirs

np.set_printoptions(precision=2, suppress=True)

# Seed every RNG the notebook touches so Restart & Run All is reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [2]:
# Show which TensorFlow build is running and which GPUs it can see
# (the GPU list is the cell's displayed value).
print(tf.__version__)
tf.config.list_physical_devices(device_type="GPU")

2.5.0


[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

In [3]:
# Create the working directories if they do not exist yet. The only requirement
# is that the database sits in a folder named NFBS_Dataset in the same
# directory as this notebook.
paths = make_dirs()

# Fail fast with a readable message instead of a KeyError or a
# train_test_split error in the next cell.
assert paths.keys() >= {"SUBVOLUME_FOLDER", "SUBVOLUME_MASK_FOLDER", "SAMPLES"}, (
    f"make_dirs() returned unexpected keys: {sorted(paths)}"
)
assert len(paths["SAMPLES"]) > 0, "No samples found - is NFBS_Dataset next to this notebook?"
paths.keys()


dict_keys(['SUBVOLUME_FOLDER', 'SUBVOLUME_MASK_FOLDER', 'DATABASE_DIR', 'SAMPLES'])

In [4]:
# Optional cap on the number of subjects per split for a quick smoke run
# (e.g. 2). None uses every subject.
MAX_SUBJECTS = None


def collect_subvolume_paths(samples, image_root, mask_root):
    """Return (image_paths, mask_paths) for every sub-volume of the given subjects.

    Each subject has one folder under ``image_root`` and one under ``mask_root``.
    Both folders are listed in sorted order, so the i-th image is paired with
    the i-th mask.

    Raises:
        ValueError: if a subject has a different number of image and mask files,
            which would otherwise silently shift every later image/mask pair.
    """
    image_paths, mask_paths = [], []
    for sample in samples:
        image_dir = os.path.join(image_root, sample)
        mask_dir = os.path.join(mask_root, sample)
        image_names = sorted(os.listdir(image_dir))
        mask_names = sorted(os.listdir(mask_dir))
        if len(image_names) != len(mask_names):
            raise ValueError(
                f"{sample}: {len(image_names)} image sub-volumes "
                f"but {len(mask_names)} mask sub-volumes"
            )
        image_paths.extend(os.path.join(image_dir, name) for name in image_names)
        mask_paths.extend(os.path.join(mask_dir, name) for name in mask_names)
    return image_paths, mask_paths


# Split by subject (not by sub-volume) so no subject appears in both sets.
train_files, val_files = train_test_split(paths["SAMPLES"], test_size=0.2, random_state=42)
if MAX_SUBJECTS is not None:
    train_files, val_files = train_files[:MAX_SUBJECTS], val_files[:MAX_SUBJECTS]

train_images, train_images_mask = collect_subvolume_paths(
    train_files, paths["SUBVOLUME_FOLDER"], paths["SUBVOLUME_MASK_FOLDER"]
)
val_images, val_images_mask = collect_subvolume_paths(
    val_files, paths["SUBVOLUME_FOLDER"], paths["SUBVOLUME_MASK_FOLDER"]
)


In [5]:
# Number of validation sub-volumes (displayed as the cell output).
n_val_subvolumes = len(val_images)
n_val_subvolumes

7416

In [6]:
def load_image(file, label):
    """Load one image/mask sub-volume pair from NIfTI files.

    Called through ``tf.py_function``, so ``file`` and ``label`` arrive as
    scalar string tensors holding the file paths.

    Returns:
        (image, mask) as int16 numpy arrays.
    """
    image_path = file.numpy().decode("utf-8")
    mask_path = label.numpy().decode("utf-8")
    # NOTE(review): get_fdata() returns floats; the int16 cast truncates any
    # fractional intensities — confirm the sub-volumes are stored as integers.
    image = nib.load(image_path).get_fdata().astype(np.int16)
    mask = nib.load(mask_path).get_fdata().astype(np.int16)
    return image, mask


@tf.autograph.experimental.do_not_convert
def load_image_wrapper(file, label):
    """Graph-mode wrapper around ``load_image`` for use in ``Dataset.map``.

    ``tf.py_function`` drops static shape information, so it is restored here.
    Every sub-volume is expected to be 128x128x16, the size the model is built for.
    """
    image, mask = tf.py_function(load_image, [file, label], [tf.int16, tf.int16])
    image.set_shape(tf.TensorShape([128, 128, 16]))
    mask.set_shape(tf.TensorShape([128, 128, 16]))
    return image, mask


def make_dataset(image_paths, mask_paths, batch_size=10, shuffle=False, seed=42):
    """Build a batched, prefetched tf.data pipeline of (image, mask) sub-volumes.

    Shuffling is applied to the path strings before any file is read, so a
    buffer covering the whole list is cheap. ``reshuffle_each_iteration``
    gives a new order every epoch.
    """
    dataset = tf.data.Dataset.from_tensor_slices((image_paths, mask_paths))
    if shuffle:
        dataset = dataset.shuffle(
            max(len(image_paths), 1), seed=seed, reshuffle_each_iteration=True
        )
    dataset = dataset.map(load_image_wrapper, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    # Overlap file loading with training on the current batch.
    return dataset.prefetch(tf.data.AUTOTUNE)


# The path lists are built subject by subject. Without shuffling, each training
# batch would hold neighbouring sub-volumes of one subject, in the same order
# every epoch.
train_dataset = make_dataset(train_images, train_images_mask, shuffle=True)
val_dataset = make_dataset(val_images, val_images_mask)

In [7]:
# Build the 3-D U-Net for 128x128x16 sub-volumes (an input size of
# (60, 160, 16) was also tried) and compile it for binary segmentation
# ("categorical_crossentropy" was the earlier loss alternative).
model_unet = unet_3D(128, 128, 16)
model_unet.compile(
    optimizer="adam",
    loss="binary_crossentropy",
    metrics=["accuracy", dice_coeff],
)

(None, 128, 128, 16, 1) 

(None, 128, 128, 16, 16)
(None, 128, 128, 16, 16)
(None, 64, 64, 8, 16)
(None, 64, 64, 8, 32)
(None, 64, 64, 8, 32)
(None, 32, 32, 4, 32)
(None, 32, 32, 4, 64)
(None, 32, 32, 4, 64)
(None, 16, 16, 2, 64)
(None, 16, 16, 2, 128)
(None, 16, 16, 2, 128)
(None, 8, 8, 1, 128) 

(None, 8, 8, 1, 256)
(None, 8, 8, 1, 256) 

(None, 16, 16, 2, 128)
(None, 16, 16, 2, 256)
(None, 16, 16, 2, 128)
(None, 16, 16, 2, 128)
(None, 32, 32, 4, 64)
(None, 32, 32, 4, 128)
(None, 32, 32, 4, 64)
(None, 32, 32, 4, 64)
(None, 64, 64, 8, 32)
(None, 64, 64, 8, 64)
(None, 64, 64, 8, 32)
(None, 64, 64, 8, 32)
(None, 128, 128, 16, 16)
(None, 128, 128, 16, 32)
(None, 128, 128, 16, 16)
(None, 128, 128, 16, 16) 

(None, 128, 128, 16, 1)


In [8]:
# Keep only the best model seen so far, judged by validation Dice (higher is better).
# NOTE(review): the checkpoint filepath is the notebook's working directory itself,
# so the saved model files land directly in it — consider a dedicated sub-folder.
PARENT_DIR = os.getcwd()
callbacks = ModelCheckpoint(
    filepath=PARENT_DIR,
    monitor="val_dice_coeff",
    mode="max",
    save_best_only=True,
    verbose=1,
)

In [9]:
# Early stopping must MAXIMISE Dice. With the default mode="auto", Keras only
# treats a monitor as "higher is better" when its name contains "acc".
# "val_dice_coeff" would therefore be handled like a loss: every epoch where
# Dice fails to drop counts against patience, and training stops after about
# 10 epochs even while Dice keeps improving.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_dice_coeff",
    mode="max",
    patience=10,
    min_delta=0.0005,
    # Keep the best validation epoch in memory too, not only on disk.
    restore_best_weights=True,
    verbose=1,
)
history = model_unet.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=100,
    callbacks=[callbacks, early_stopping],
)

Epoch 1/100