In [13]:
import os
import numpy as np
import SimpleITK as sitk
import datetime

In [None]:
# Read the four pre-split arrays from disk in one pass; the generator yields
# them in the same order as the names on the left-hand side.
x_train, x_test, y_train, y_test = (
    np.load(npy_path)
    for npy_path in (
        'data/train/x_train.npy',
        'data/train/x_test.npy',
        'data/train/y_train.npy',
        'data/train/y_test.npy',
    )
)

In [None]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv3D, MaxPooling3D, Dropout, concatenate, UpSampling3D

def _double_conv(tensor, filters):
    """Apply two consecutive 3x3x3 ReLU convolutions ('same' padding) to ``tensor``."""
    for _ in range(2):
        tensor = Conv3D(filters, 3, activation='relu', padding='same')(tensor)
    return tensor


def unet_3d_model(input_size=(256, 256, 256, 1), base_filters=64):
    """Build a three-level 3D U-Net with a single-channel sigmoid output.

    The network has three encoder stages (double convolution followed by
    2x2x2 max-pooling), a bottleneck double convolution, and three decoder
    stages (2x upsampling, concatenation with the matching encoder output on
    the last axis, double convolution). A final 1x1x1 convolution with a
    sigmoid activation produces one output channel.

    Parameters
    ----------
    input_size : sequence of 4 ints (entries may be None)
        Shape of one sample without the batch axis. The first three entries
        are treated as the spatial dimensions and the last as channels
        (assumes Keras' default channels-last layout, which the ``axis=-1``
        concatenations rely on). Every known spatial dimension must be a
        positive multiple of 8, because the three pooling steps halve it
        three times and the upsampled tensors must line up with the skip
        connections.
    base_filters : int, optional
        Number of filters in the first encoder stage. Each deeper stage
        doubles it, so the default 64 gives 64/128/256 in the encoder and 512
        in the bottleneck.

    Returns
    -------
    Model
        The uncompiled Keras model mapping the input volume to the sigmoid
        output volume.

    Raises
    ------
    ValueError
        If ``input_size`` does not have four entries, if a known spatial
        dimension is not a positive multiple of 8, or if ``base_filters`` is
        smaller than 1.
    """
    num_levels = 3
    pool_factor = 2 ** num_levels

    # Fail early with a readable message instead of an opaque shape mismatch
    # raised from deep inside the graph when the skip connections are joined.
    if len(input_size) != 4:
        raise ValueError(
            "input_size must have 4 entries (3 spatial dimensions + channels), "
            "got %r" % (input_size,)
        )
    invalid_dims = [
        dim for dim in input_size[:3]
        if dim is not None and (dim <= 0 or dim % pool_factor != 0)
    ]
    if invalid_dims:
        raise ValueError(
            "spatial dimensions must be positive multiples of %d, got %r"
            % (pool_factor, tuple(input_size[:3]))
        )
    if base_filters < 1:
        raise ValueError("base_filters must be >= 1, got %r" % (base_filters,))

    # Input layer
    inputs = Input(input_size)

    # Contracting path: remember each stage's output for the skip connections.
    stage_filters = [base_filters * 2 ** level for level in range(num_levels)]
    skip_tensors = []
    tensor = inputs
    for filters in stage_filters:
        tensor = _double_conv(tensor, filters)
        skip_tensors.append(tensor)
        tensor = MaxPooling3D(pool_size=(2, 2, 2))(tensor)

    # Bottom of the U
    tensor = _double_conv(tensor, base_filters * 2 ** num_levels)

    # Expansive path: walk the encoder stages in reverse, deepest first.
    for filters, skip in zip(reversed(stage_filters), reversed(skip_tensors)):
        upsampled = UpSampling3D(size=(2, 2, 2))(tensor)
        tensor = concatenate([upsampled, skip], axis=-1)
        tensor = _double_conv(tensor, filters)

    # Output layer
    outputs = Conv3D(1, 1, activation='sigmoid')(tensor)

    # Create the model
    return Model(inputs=inputs, outputs=outputs)

# Instantiate the network with an explicit (256, 256, 256, 1) input shape.
model = unet_3d_model(input_size=(256, 256, 256, 1))

# Adam optimiser, mean-squared-error loss, mean-absolute-error reported as a metric.
model.compile(optimizer='adam', loss='mean_squared_error', metrics=['mae'])

# Wall-clock timestamp taken immediately before training begins.
trainstarttime = datetime.datetime.today()
print("")
print("Train start time: " + str(trainstarttime))

# Train one sample per step for 75 epochs; x_test / y_test are passed as the
# validation data.
history = model.fit(
    x=x_train,
    y=y_train,
    batch_size=1,
    epochs=75,
    verbose=1,
    validation_data=(x_test, y_test),
)

# Report when training ended and how long it took overall.
trainfinishtime = datetime.datetime.today()
print("Train finish time: " + str(trainfinishtime))
ttime = str(trainfinishtime - trainstarttime)
print("Training time: " + ttime)
print("")

In [None]:
# Persist the trained model as two files sharing one path stem: the
# architecture as JSON and the weights as HDF5.
file = 'data/train/models/simple_unet_model'

json_string = model.to_json()
fj = file + '.json'
fh = file + '.h5'

# Make sure the target directory exists so neither write fails on a fresh
# checkout where 'data/train/models' has not been created yet.
os.makedirs(os.path.dirname(file), exist_ok=True)

# Use a context manager so the handle is flushed and closed deterministically
# instead of relying on garbage collection of an anonymous file object.
with open(fj, 'w') as json_file:
    json_file.write(json_string)

# NOTE(review): Keras 3 appears to require weight filenames ending in
# '.weights.h5' — confirm against the installed version before upgrading.
model.save_weights(fh)

print("Model saved.")