In [1]:
import sys
from preprocessing import *
import numpy as np
from utils.gpu import set_memory_growth
from model import create_model
from utils import load_points, segment_image, load_itk, world_to_voxel, create_fake_images
from os.path import join
import warnings
warnings.filterwarnings("ignore")
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import LearningRateScheduler
import math

In [27]:
training_dir = "./training"


def get_batch(dataset_idx, vessel_idx, batch_size=32):
    """
    :param dataset_idx: dataset to use
    :param vessel_idx: vessel to use
    :param batch_size: size of the batch gee
    :return: a batch of data from the given dataset and vessel
    """

#     print("-", end="")
    reference_points = load_reference_points("./preprocessing/reference_directions.txt")
    probs, radii, directions, input_data = [], [], [], []

    points_path = join(training_dir, "dataset0%d/vessel%s/reference.txt" % (dataset_idx, str(vessel_idx)))
    points = load_points(points_path)

    image, _, _ = load_itk(join(training_dir, "dataset0%d/image0%d.mhd" % (dataset_idx, dataset_idx)))
    idxs = np.random.randint(300, len(points) - 300, batch_size)
    for idx in idxs:
        radius, direction = create_sample(idx, points, reference_points)

        point = world_to_voxel(points[idx, :3])
        patch = segment_image(image, point).copy()

        if patch.shape == (19, 19, 19):
            input_data.append(patch)
            probs.append(1.)
            radii.append(radius)
            directions.append(direction)

    input_data = np.asarray(input_data).reshape(-1, 19, 19, 19, 1)
    radii = np.asarray(radii).reshape(-1, 1)
    directions = np.asarray(directions).reshape(-1, 500)
    probs = np.asarray(probs).reshape(-1, 1)
    # print(input_data.shape, radii.shape, directions.shape, probs.shape)

    return input_data, [probs, radii, directions]


epochs = 5
batch_size = 32
batches_per_epoch = 20
total_iterations = epochs * batch_size * batches_per_epoch
print("epochs: %s\nbatch size: %d\nbatches per epoch: %d\ntotal iterations: %d" % (
    epochs, batch_size, batches_per_epoch, total_iterations))

epochs: 5
batch size: 32
batches per epoch: 20
total iterations: 3200


In [28]:
def step_decay(epoch):
    initial_lrate = 1e-3
    epochs_drop = 10000 / batch_size / batches_per_epoch
    drop = 0.1
    lrate = initial_lrate * math.pow(drop, math.floor((1 + epoch) / epochs_drop))
    return lrate


lrate_callback = LearningRateScheduler(step_decay)
set_memory_growth()

1 Physical GPUs, 1 Logical GPUs


In [None]:
model = create_model()
dataset_idx, image_idx = 0, 0
for e in range(epochs):
    print("Epoch %d/%d\n[" % (e + 1, epochs), end="")
    for b in range(batches_per_epoch):
        print("-", end="")
        X_batch, y_batch = get_batch(dataset_idx % 8, image_idx % 4, batch_size=batch_size)
        model.fit(X_batch, y_batch, verbose=0, epochs=1, callbacks=[lrate_callback])
        dataset_idx += 1
        image_idx += 1
    print("]")

model.evaluate(X_batch, y_batch)
model.save_weights("./models/model1.h5")

Epoch 1/5
[--------------------]
Epoch 2/5
[--------------------]
Epoch 3/5
[--------------------]
Epoch 4/5
[--------------------]
Epoch 5/5
[----------

In [19]:
from time import time
start = time()
get_batch(0, 0)
end = time()

In [22]:
start = time()


reference_points = load_reference_points("./preprocessing/reference_directions.txt")
probs, radii, directions, input_data = [], [], [], []

points_path = join(training_dir, "dataset0%d/vessel%s/reference.txt" % (0, str(0)))
points = load_points(points_path)



point1 = time()



image, _, _ = load_itk(join(training_dir, "dataset0%d/image0%d.mhd" % (0, 0)))



point2 = time()



idxs = np.random.randint(300, len(points) - 300, batch_size)
for idx in idxs:
    p1 = time()
    
    radius, direction = create_sample(idx, points, reference_points)
    
    p2 = time()
    
    point = world_to_voxel(points[idx, :3])
    patch = segment_image(image, point).copy()
    
    p3 = time()

    if patch.shape == (19, 19, 19):
        input_data.append(patch)
        probs.append(1.)
        radii.append(radius)
        directions.append(direction)
        
    p4 = time()
        
        
point3 = time()


input_data = np.asarray(input_data).reshape(-1, 19, 19, 19, 1)
radii = np.asarray(radii).reshape(-1, 1)
directions = np.asarray(directions).reshape(-1, 500)
probs = np.asarray(probs).reshape(-1, 1)


end = time()

In [26]:
(point2 - point1), (point3 - point2), (end - point3)

(0.41507482528686523, 1.582848310470581, 0.0)