In [None]:
import os

import numpy as np
import tensorflow as tf

In [None]:
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import backend as K

from tensorflow.keras.utils import plot_model

In [None]:
# Work from the project root so that the `src` package and the relative
# `./data` and `models/` paths used further down resolve.
# Guarded so that re-running this cell (or launching the kernel from the root)
# does not climb a second directory level.
# NOTE(review): assumes the notebook sits one level below the project root — confirm.
if not os.path.isdir('src'):
    os.chdir('..')

In [None]:
from src.volumetric_unet import create_volumetric_unet

In [None]:
# Look up the GPUs TensorFlow can see once, report how many there are,
# and leave the device list as the cell's displayed value.
gpu_devices = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpu_devices))
gpu_devices

In [None]:
# Reproducibility: drive the hash seed, NumPy and TensorFlow from one value
SEED = 123

# NOTE(review): PYTHONHASHSEED is normally read when the interpreter starts, so
# setting it here presumably only reaches child processes — confirm it is needed.
os.environ['PYTHONHASHSEED'] = str(SEED)

np.random.seed(SEED)
tf.random.set_seed(SEED)

In [None]:
# Read the combined training inputs and targets from disk
train_paths = ('./data/X_train_combined.npy', './data/y_train_combined.npy')
X_train_combined, y_train_combined = (np.load(path) for path in train_paths)

In [None]:
# Read the validation inputs and targets from disk
val_paths = ('./data/X_val.npy', './data/y_val.npy')
X_val, y_val = (np.load(path) for path in val_paths)

# Display the shapes of all four loaded arrays
tuple(arr.shape for arr in (X_train_combined, y_train_combined, X_val, y_val))

In [None]:
# Keep a random 90% of each split, drawn without replacement.
SAMPLE_FRACTION = 0.9

n_train_total = X_train_combined.shape[0]
n_val_total = X_val.shape[0]

sample_size_train = int(SAMPLE_FRACTION * n_train_total)
sample_size_val = int(SAMPLE_FRACTION * n_val_total)

# Both draws consume the global NumPy RNG, so the train draw must stay ahead
# of the validation draw for a given seed to yield the same indices.
indices_train = np.random.choice(n_train_total, sample_size_train, replace=False)
indices_val = np.random.choice(n_val_total, sample_size_val, replace=False)

# Apply the same index set to inputs and targets so pairs stay aligned
X_train_sample, y_train_sample = X_train_combined[indices_train], y_train_combined[indices_train]
X_val_sample, y_val_sample = X_val[indices_val], y_val[indices_val]

# Display the shapes of the sampled arrays
tuple(arr.shape for arr in (X_train_sample, y_train_sample, X_val_sample, y_val_sample))

In [None]:
# Parameters
Ngrid = 32  # grid size: number of points along each of the three equal axes of the model input

In [None]:
# Cubic input of Ngrid points per axis with a trailing axis of size 1
# (presumably the channel axis — confirm against create_volumetric_unet).
input_shape = (Ngrid,) * 3 + (1,)
unet_model = create_volumetric_unet(input_shape, num_classes=1)

In [None]:
# Compile the model: default-configured Adam optimizer, mean-squared-error loss.
# NOTE(review): run_eagerly=True is generally a debugging setting and tends to
# slow training; confirm the model actually needs it before keeping it.
optimizer = Adam()
unet_model.compile(loss='mse', optimizer=optimizer, run_eagerly=True)

In [None]:
# Define callbacks (all three monitor the validation loss)

# Stop training once val_loss has gone 5 epochs without improving.
early_stopping = EarlyStopping(monitor='val_loss', patience=5)

# Multiply the learning rate by 0.2 after 2 epochs without improvement.
# - patience must be shorter than the early-stopping patience, otherwise
#   training stops on the very epoch the first reduction would happen.
# - min_lr must be below the optimizer's starting rate (0.001 for a
#   default-configured Adam), otherwise the rate can never be lowered.
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=2, min_lr=1e-6)

# Write the model to disk whenever val_loss reaches a new best value.
model_checkpoint = ModelCheckpoint('models/unet/060324_90p_samp_augment.hdf5',
                                   monitor='val_loss',
                                   verbose=1,
                                   save_best_only=True)

In [None]:
# Train the model on the sampled training split, validating on the sampled
# validation split after each epoch; the returned object is kept in `history`.
fit_options = dict(
    batch_size=32,
    epochs=50,
    validation_data=(X_val_sample, y_val_sample),
    callbacks=[early_stopping, reduce_lr, model_checkpoint],
)
history = unet_model.fit(X_train_sample, y_train_sample, **fit_options)