In [None]:
import albumentations as A
from datetime import date
import numpy as np
import matplotlib.pyplot as plt
import json
import os
import pickle
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

In [None]:
def load_data(positive=True,
              data_path='../../data/training_data/segmentation_masks/data',
              label_path='../../data/training_data/segmentation_masks/masks'):
    """Load pickled image frames and their PNG segmentation masks.

    Files are selected purely by name: frames must contain '.pkl', masks must
    contain '.png', and the substring 'neg' marks a negative example. Both
    directory listings are sorted, so a frame and its mask are paired only by
    their position in the sorted order.

    Args:
        positive: If truthy, load the files WITHOUT 'neg' in their name;
            otherwise load only the files WITH 'neg' in their name.
        data_path: Directory containing the pickled frames.
        label_path: Directory containing the mask images.

    Returns:
        Tuple ``(data, labels)`` of numpy arrays: the stacked unpickled frames
        and the first channel of every mask image. Either array is empty when
        no file matches (previously an empty data directory raised IndexError).
    """
    want_positive = bool(positive)

    def _select(folder, extension):
        # Keep files of the requested kind; 'neg' in the name means negative.
        return [f for f in sorted(os.listdir(folder))
                if extension in f and ('neg' not in f) == want_positive]

    data_files = _select(data_path, '.pkl')
    label_files = _select(label_path, '.png')

    # Load the data.
    # NOTE: pickle.load can execute arbitrary code - only point this at
    # trusted files.
    data = []
    for name in data_files:
        with open(os.path.join(data_path, name), 'rb') as frame:
            data.append(pickle.load(frame))
    data = np.array(data)

    # Load the labels, keeping only the first channel of each mask image.
    labels = np.array(
        [plt.imread(os.path.join(label_path, name))[:, :, 0] for name in label_files]
    )
    return data, labels

def build_augmentation():
    """Create the albumentations pipeline applied to every image/mask pair.

    Kept as its own function so the exact pipeline can also be inspected or
    recorded outside of ``augment_data``.
    """
    return A.Compose([
        # Random crop that is resized to 48x48.
        A.RandomSizedCrop(min_max_height=(24, 38), height=48, width=48, p=1),
        A.HorizontalFlip(p=0.5),
        # interpolation=1 / border_mode=4 are OpenCV enum values
        # (cv2.INTER_LINEAR / cv2.BORDER_REFLECT_101).
        A.ShiftScaleRotate(shift_limit=0.25, scale_limit=0.5, rotate_limit=180, interpolation=1, border_mode=4, p=1),
        A.VerticalFlip(p=0.5),
        # With probability 0.8 apply exactly one of the distortions below.
        A.OneOf(
            [
                A.ElasticTransform(p=0.5, alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03),
                A.GridDistortion(p=0.5),
                A.OpticalDistortion(distort_limit=1, shift_limit=0.5, p=.1),
            ],
            p=0.8,
        ),
    ])


def augment_data(x, y, iterations=1000, aug=None):
    """Pre-generate an augmented dataset from image/mask pairs.

    It would be nice to feed this to training as a generator; that proved
    challenging, so a large dataset is generated up front and trained on.

    Args:
        x: Sequence/array of images.
        y: Sequence/array of masks, aligned with ``x`` by position.
        iterations: Number of passes over the full set of pairs; the result
            holds ``iterations * len(x)`` samples.
        aug: Optional callable invoked as ``aug(image=..., mask=...)`` that
            returns a dict with 'image' and 'mask' entries. Defaults to
            ``build_augmentation()``.

    Returns:
        Tuple ``(aug_x, aug_y)`` of numpy arrays with the augmented images and
        masks, in generation order.

    Raises:
        ValueError: If ``x`` and ``y`` differ in length. zip() would otherwise
            silently drop the surplus and may pair images with wrong masks.
    """
    if len(x) != len(y):
        raise ValueError(
            f"Number of images ({len(x)}) and masks ({len(y)}) must match"
        )
    if aug is None:
        aug = build_augmentation()

    aug_x = []
    aug_y = []
    for _ in range(iterations):
        for image, mask in zip(x, y):
            augmented = aug(image=image, mask=mask)
            aug_x.append(augmented['image'])
            aug_y.append(augmented['mask'])

    return np.array(aug_x), np.array(aug_y)

In [None]:
# Positive examples: load them, then make 100 augmentation passes over the set.
positive_data, positive_labels = load_data(True)
aug_pos_data, aug_pos_labels = augment_data(
    x=positive_data, y=positive_labels, iterations=100
)

# Negative examples: same procedure, but only 10 augmentation passes.
negative_data, negative_labels = load_data(False)
aug_neg_data, aug_neg_labels = augment_data(
    x=negative_data, y=negative_labels, iterations=10
)

In [None]:
# Assemble the training arrays: augmented positives first, negatives after.

# Inputs are divided by 3000 and clamped into [0, 1].
x = np.clip(np.concatenate((aug_pos_data, aug_neg_data)) / 3000, 0, 1)

# Targets use the same ordering, are cast to int and get a trailing axis of
# length 1. astype() already returns a fresh array, so no explicit copy is made.
# NOTE(review): astype(int) truncates toward zero, so any fractional mask
# value becomes 0 - confirm the masks only contain exact 0/1 values.
y = np.concatenate((aug_pos_labels, aug_neg_labels)).astype(int)[..., np.newaxis]

In [None]:
# Preview five distinct, randomly chosen samples: input on the left, mask on
# the right (titled with the sample index).
for i in np.random.choice(range(len(x)), 5, replace=False):
    fig, (image_ax, mask_ax) = plt.subplots(1, 2, figsize=(10, 5), dpi=100)

    # Channels 3, 2, 1 (in that order) are displayed as the colour planes.
    image_ax.imshow(x[i, :, :, 3:0:-1])
    image_ax.axis('off')

    mask_ax.imshow(y[i, :, :])
    mask_ax.axis('off')
    mask_ax.set_title(i)

    plt.show()

In [None]:
# Create the model. Code from keras example
def get_model(img_size, num_classes, num_channels=12, max_conv=256):
    """Build a U-Net-like segmentation network with residual connections.

    The entry convolution (stride 2) and the three max-pooled blocks reduce the
    spatial resolution by a factor of 16 in total; the four UpSampling2D(2)
    blocks enlarge it by 16 again. For the output to have the same height and
    width as the input, each spatial dimension should be divisible by 16.

    Args:
        img_size: (height, width) of the input images; any sequence of two
            ints is accepted.
        num_classes: Number of channels of the per-pixel softmax output.
        num_channels: Number of channels of the input images.
        max_conv: Maximum number of filters in a layer; the blocks use
            max_conv // 4, // 2, // 1 on the way down and
            max_conv, // 2, // 4, // 8 on the way up.

    Returns:
        An uncompiled ``keras.Model`` mapping
        (height, width, num_channels) inputs to per-pixel class probabilities.
    """
    inputs = keras.Input(shape=tuple(img_size) + (num_channels,))

    ### [First half of the network: downsampling inputs] ###

    # Entry block
    x = layers.Conv2D(32, 12, strides=2, padding="same")(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    previous_block_activation = x  # Set aside residual

    # Blocks 1, 2, 3 are identical apart from the feature depth.
    for filters in [max_conv // 4, max_conv // 2, max_conv]:
        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.MaxPooling2D(3, strides=2, padding="same")(x)

        # Project residual: the strided 1x1 convolution matches both the
        # halved resolution and the new filter count before the addition.
        residual = layers.Conv2D(filters, 1, strides=2, padding="same")(
            previous_block_activation
        )
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    ### [Second half of the network: upsampling inputs] ###

    for filters in [max_conv, max_conv // 2, max_conv // 4, max_conv // 8]:
        x = layers.Activation("relu")(x)
        x = layers.Conv2DTranspose(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.Conv2DTranspose(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.UpSampling2D(2)(x)

        # Project residual: upsample it to the doubled resolution, then use a
        # 1x1 convolution to match the filter count.
        residual = layers.UpSampling2D(2)(previous_block_activation)
        residual = layers.Conv2D(filters, 1, padding="same")(residual)
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    # Add a per-pixel classification layer
    outputs = layers.Conv2D(num_classes, 3, activation="softmax", padding="same")(x)

    # Define the model
    model = keras.Model(inputs, outputs)
    return model


# Reset Keras' global state first, so re-running the model definition cells
# does not keep piling up memory.
keras.backend.clear_session()

# Instantiate the network for square 48 px inputs and two output classes,
# then print its layer overview.
resolution = 48
model = get_model(img_size=(resolution, resolution), num_classes=2)
model.summary()

In [None]:
class DiceLoss(tf.keras.losses.Loss):
    """Dice-style overlap loss.

    Computes, over ALL elements of the batch at once,

        1 - (2 * sum(p * t) + smooth) / (sum(p ** gamma) + sum(t ** gamma) + smooth)

    where ``p`` is ``y_pred`` and ``t`` is ``y_true`` (both cast to float32).

    NOTE(review): the element-wise product broadcasts. In this notebook the
    labels have a trailing axis of length 1 while the model emits a 2-channel
    softmax, so the product would sum over both channels - confirm the intended
    target encoding before enabling this loss.

    Args:
        smooth: Constant added to numerator and denominator; keeps the ratio
            defined when both sums are zero.
        gamma: Exponent applied to predictions and targets in the denominator.
        name: Name of the loss instance.
        **kwargs: Forwarded to ``tf.keras.losses.Loss`` (e.g. ``reduction``).
    """

    def __init__(self, smooth=1e-6, gamma=2, name='NDL', **kwargs):
        # Pass the name through the base constructor instead of overwriting
        # the attribute afterwards.
        super().__init__(name=name, **kwargs)
        self.smooth = smooth
        self.gamma = gamma

    def call(self, y_true, y_pred):
        """Return the scalar loss for one batch."""
        y_true = tf.cast(y_true, dtype=tf.float32)
        y_pred = tf.cast(y_pred, dtype=tf.float32)
        numerator = 2 * tf.reduce_sum(tf.multiply(y_pred, y_true)) + self.smooth
        denominator = (
            tf.reduce_sum(y_pred ** self.gamma)
            + tf.reduce_sum(y_true ** self.gamma)
            + self.smooth
        )
        return 1 - tf.divide(numerator, denominator)

    def get_config(self):
        """Include the hyperparameters so the loss can be re-created from config."""
        config = super().get_config()
        config.update({'smooth': self.smooth, 'gamma': self.gamma})
        return config

# Loss used for training. The custom loss defined above is the alternative:
# loss = DiceLoss()
loss = "sparse_categorical_crossentropy"

# Compile with a default-configured Adam optimizer.
model.compile(optimizer=tf.keras.optimizers.Adam(), loss=loss)

# Per-epoch training losses; the training cell appends to this list.
train_loss = []

In [None]:
# Train the model. No validation data is passed to fit(), so only the
# training loss is tracked.
epochs = 3
batch_size = 64

# Keras interprets `epochs` as the index of the FINAL epoch, not as a count.
# Offset it by the epochs already completed; otherwise re-running this cell
# trains for zero epochs and the history has no 'loss' entry.
completed_epochs = len(train_loss)
history = model.fit(
    x,
    y,
    batch_size=batch_size,
    epochs=completed_epochs + epochs,
    initial_epoch=completed_epochs,
)
train_loss += history.history['loss']

In [None]:
# Plot the training loss recorded for every completed epoch.
fig, ax = plt.subplots(figsize=(8, 5), dpi=100, facecolor=(1, 1, 1))
# The curve is the loss, so label it as such (it used to say 'Train Acc').
ax.plot(train_loss, label='Train Loss')
ax.grid()
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.legend()
ax.set_title('Network Train Loss')
plt.show()

In [None]:
# Save the model together with a plain-text record of how it was produced.
# NOTE: an existing model/config file with the same name is overwritten.
version_number = '0.1'
current_date = date.today()
model_name = f"unet_{model.input_shape[1]}px_v{version_number}_{current_date.isoformat()}"

# Make sure the output directory exists before anything is written.
os.makedirs('../../models', exist_ok=True)

# Only the positive input files (no 'neg' in the name) are listed in the config.
data_path = '../../data/training_data/segmentation_masks/data'
data_files = [f for f in sorted(os.listdir(data_path)) if '.pkl' in f and 'neg' not in f]

# The augmentation pipeline is created inside augment_data, so `aug` is
# normally not a notebook-level name. Resolve it before opening the file so a
# missing pipeline no longer aborts the cell halfway through writing the
# config (and before the model is saved).
try:
    aug_transforms = aug._to_dict()['transforms']
except NameError:
    aug_transforms = None

with open('../../models/' + model_name + '_config.txt', 'w') as f:
    f.write('Input Data:\n')
    for file in data_files:
        f.write('\t' + file + '\n')

    f.write('\n\nAugmentation Parameters:\n')
    if aug_transforms is None:
        f.write('\t(not recorded: no augmentation pipeline named `aug` is defined at notebook level)\n')
    else:
        for elem in aug_transforms:
            # One JSON document per transform, separated by a newline.
            f.write(json.dumps(elem, indent=True) + '\n')

    f.write(f"\nBatch Size: {batch_size}")
    f.write(f"\nLoss Function: {loss}")
    f.write(f"\nTraining Epochs: {len(train_loss)}")
    # train_loss is empty when the training cell has not been run yet.
    if train_loss:
        f.write(f"\nFinal Loss: {train_loss[-1]:.3f}")
    else:
        f.write("\nFinal Loss: n/a")

model.save(f'../../models/{model_name}.h5')
print(f"Saved to ../../models/{model_name}.h5")

In [None]:
# Visualize predictions for ten randomly drawn samples. The samples come from
# `x`, i.e. the augmented TRAINING set - this notebook has no held-out data.
samples = np.random.randint(0, len(x), size=10)

# One batched forward pass instead of one predict() call per figure.
# Channel 1 of the softmax output is the probability of label value 1.
preds = model.predict(x[samples])[:, :, :, 1]

for sample_num, pred in zip(samples, preds):
    # Channels 3, 2, 1 (in that order) are displayed as the colour planes.
    rgb = x[sample_num, :, :, 3:0:-1]

    # Overlay: add the predicted probability onto the first colour plane.
    combo = np.copy(rgb)
    combo[:, :, 0] = combo[:, :, 0] + pred

    fig, axes = plt.subplots(1, 4, figsize=(9, 3), dpi=75)

    axes[0].imshow(rgb)
    axes[0].set_title('Input')

    axes[1].imshow(np.clip(combo, 0, 1))
    axes[1].set_title('Input + prediction')

    axes[2].imshow(y[sample_num], cmap='RdBu_r')
    axes[2].set_title('Label')

    axes[3].imshow(pred, cmap='RdBu_r', vmin=0, vmax=1)
    axes[3].set_title('Prediction')

    for panel in axes:
        panel.axis('off')
    plt.show()