In [0]:
# Initialization: seed every random number generator the training pipeline may
# draw from (Python's `random`, NumPy's global RNG and TensorFlow's global seed)
# so that weight init, dropout and any Python-level shuffling are repeatable.
# NOTE(review): GPU kernels may still be non-deterministic; consider
# tf.config.experimental.enable_op_determinism() if exact repeatability matters.
import random

import numpy as np
seed = 256
random.seed(seed)
np.random.seed(seed)
import tensorflow as tf
tf.random.set_seed(seed)

In [0]:
# Imports
import os
from keras.callbacks import ModelCheckpoint, CSVLogger
from sklearn.utils import class_weight
from keras.optimizers import Adam
from keras.utils import to_categorical

In [0]:
# Google Drive mount: the helper notebooks, data directory and saved models used
# below all live under this mount point.
from google.colab import drive

drive_mount_point = '/content/drive'
drive.mount(drive_mount_point)

In [0]:
# Links functions: execute the helper notebooks so that their top-level
# definitions become available in this notebook.
# NOTE(review): these are presumably where the names used below but never
# imported here come from (e.g. load_train_data, unet_3d_multiclass,
# soft_dice_loss, dice_coef, plot_loss, save_predictions) — confirm per notebook.
# The order is kept as-is in case a later notebook relies on an earlier one.
%run /content/drive/My\ Drive/segmentation/functions.ipynb
%run /content/drive/My\ Drive/segmentation/models.ipynb
%run /content/drive/My\ Drive/segmentation/metrics.ipynb

In [0]:
# Config: every value a reader might tune for this run.

# Training
num_classes = 26
epochs = 40
batch_size = 1

# Input volume size (x, y, z); the tuple form `dim` is what the create_* data
# builders receive.
dim_x, dim_y, dim_z = 96, 96, 128
dim = (dim_x, dim_y, dim_z)

# Paths
data_path = '/content/drive/My Drive/segmentation/data'
save_dir = '/content/drive/My Drive/segmentation/saved_models/'
model_name = 'UNet3D-vertebrae'
model_dir = os.path.join(save_dir, model_name)

# Set to True to rebuild the train/test arrays with the create_* helpers.
create_data = False

In [0]:
# Create data: only when `create_data` is True, rebuild the train/test image and
# mask arrays under `data_path` at volume size `dim`.
if create_data:
  for build_fn in (create_train_imgs, create_train_masks, create_test_imgs, create_test_masks):
    build_fn(data_path, dim)

In [0]:
# Load data: training image volumes and their label masks.
train_imgs, train_masks = load_train_data(data_path)
# Images and masks are paired by index (both go to model.fit), so catch a count
# mismatch here with a clear message instead of deep inside Keras.
assert len(train_imgs) == len(train_masks), 'train image/mask count mismatch'

In [0]:
# Normalization: scale image intensities by 1/255 (NOTE(review): assumes an
# 8-bit source range — confirm) and one-hot encode the label masks into
# `num_classes` channels. Not idempotent: run this cell only once per load.
pixel_scale = 255.
train_imgs = np.divide(train_imgs.astype(np.float32), pixel_scale)
train_masks = to_categorical(train_masks.astype(np.float32), num_classes=num_classes)

In [0]:
# Checkpoints: ensure the run directory exists, then configure callbacks that
# keep the best model by validation loss and log per-epoch metrics to CSV.
model_dir = os.path.join(save_dir, model_name)
# makedirs also creates a missing `save_dir` parent (os.mkdir would raise) and
# is a no-op when the directory already exists.
os.makedirs(model_dir, exist_ok=True)

# NOTE(review): without save_weights_only=True this saves the full model, despite
# the file name 'weights.h5' — confirm which is intended.
model_checkpoint = ModelCheckpoint(os.path.join(model_dir, 'weights.h5'), monitor='val_loss', save_best_only=True)
# append=True keeps rows from earlier runs in the same log file.
csv_logger = CSVLogger(os.path.join(model_dir, 'logs.txt'), separator=',', append=True)

In [0]:
# Load/compile: build the 3D U-Net for single-channel volumes of shape
# (dim_x, dim_y, dim_z, 1) with `num_classes` classes, trained with the
# soft Dice loss and monitored with the Dice coefficient.
model = unet_3d_multiclass((dim_x, dim_y, dim_z, 1), num_classes)
# `learning_rate` replaces the deprecated `lr` alias, which newer Keras
# optimizers reject.
model.compile(optimizer=Adam(learning_rate=0.00001), loss=soft_dice_loss, metrics=[dice_coef])

In [0]:
# Fit model: Keras holds out the last 20% of samples (taken before any
# shuffling) as validation data; its val_loss drives the checkpoint callback.
# NOTE(review): shuffle=False trains on the same sample order every epoch —
# confirm this is intentional.
fit_kwargs = dict(
    batch_size=batch_size,
    epochs=epochs,
    verbose=1,
    shuffle=False,
    validation_split=0.2,
    callbacks=[model_checkpoint, csv_logger],
)
model.fit(train_imgs, train_masks, **fit_kwargs)

In [0]:
# Plot losses from the CSV log that the CSVLogger callback writes during training.
loss_log_path = os.path.join(model_dir, 'logs.txt')
plot_loss(loss_log_path)

In [0]:
# Save model: store the in-memory model as it is after the last epoch. The
# best-by-val_loss version is saved separately by the ModelCheckpoint callback.
final_model_path = os.path.join(model_dir, 'model.h5')
model.save(final_model_path)

Predict on the test set

In [0]:
# Load test data: held-out image volumes and their label masks.
test_imgs, test_masks = load_test_data(data_path)
# Images and masks are paired by index for evaluation/visualization.
assert len(test_imgs) == len(test_masks), 'test image/mask count mismatch'

In [0]:
# Normalization: apply the same preprocessing as for training to the *test*
# arrays (the original rebuilt these from train_imgs/train_masks, silently
# replacing the test set with training data). NOTE(review): the 1/255 scaling
# assumes an 8-bit source range — confirm.
test_imgs = test_imgs.astype(np.float32) / 255.
test_masks = to_categorical(test_masks.astype(np.float32), num_classes)

In [0]:
# Prediction: run the model on the first test volume. The visualization cell
# compares against test_imgs/test_masks, so the input must come from the test
# set, not the training set.
predicted_masks = model.predict(test_imgs[:1], batch_size=1, verbose=1)

In [0]:
# Collapse per-class scores on the last axis to the index of the highest-scoring class.
predicted_masks = predicted_masks.argmax(axis=-1)

In [0]:
# Save predictions
# Hand the class-index volumes (no trailing channel axis yet; it is added in the
# next cell) to the project helper save_predictions, together with model_dir.
save_predictions(model_dir, predicted_masks)

In [0]:
# Re-add a trailing channel axis so predictions have the same (..., 1) layout as
# the ground-truth labels passed to plot_predictions.
predicted_masks = predicted_masks[..., np.newaxis]

In [0]:
# Visualize predictions: the ground truth is the one-hot test masks collapsed
# back to class indices, with a trailing channel axis to match predicted_masks.
# The leading 0 is presumably the sample index to plot — confirm in plot_predictions.
true_labels = np.argmax(test_masks, axis=-1)[..., np.newaxis]
plot_predictions(0, predicted_masks, test_imgs, true_labels)