In [None]:
import os

import matplotlib.pyplot as plt
import pandas as pd
from imgaug import augmenters as iaa
from keras.callbacks import ModelCheckpoint, TensorBoard, ReduceLROnPlateau
from keras.models import Model
from keras.optimizers import Adam
from keras.utils.training_utils import multi_gpu_model
from segmentation_models import Unet, Linknet
from segmentation_models import get_preprocessing

# NOTE(review): a relative star import only resolves when this code runs as part
# of a package; names such as np, cv2, generator, load_val, bc_dice_loss are
# presumably provided by it — confirm.
from .utils import *

In [None]:
# Fold 1: model/input configuration followed by the three split tables.
backbone = "resnext50"
shape = 512  # forwarded to load_val() and generator() further down
preprocessing_fn = get_preprocessing(backbone)

train_df, val_df, test_df = (
    pd.read_csv(csv_path)
    for csv_path in (
        "jsrt/jsrt_fold1/train.csv",
        "jsrt/jsrt_fold1/val.csv",
        "jsrt/jsrt_fold1/test.csv",
    )
)

In [None]:
# Geometric augmentation: a 50% horizontal flip plus exactly one of two affine
# variants. random_order=True lets imgaug shuffle the order of the two stages.
affine_scale_rotate = iaa.Affine(
    scale={"x": (0.8, 1.2), "y": (0.8, 1.2)},
    rotate=(-15, 15),
    shear=(-10, 10),
)
affine_translate = iaa.Affine(
    translate_percent={"x": (-0.1, 0.1), "y": (-0.1, 0.1)},
    shear=(-10, 10),
)

seq = iaa.Sequential(
    [
        iaa.Fliplr(0.5),
        iaa.OneOf([affine_scale_rotate, affine_translate]),
    ],
    random_order=True,
)

In [None]:
# Visual sanity check of the augmentation pipeline on ten samples
# (the first argument is presumably a row index into train_df — confirm).
n_previews = 10
for sample_idx in range(n_previews):
    show_augm(sample_idx, train_df, seq, preprocessing_fn)

In [None]:
# Load the validation images and masks once; this pair is later handed to
# fit_generator as `validation_data`.
val_images, val_masks = load_val(val_df, shape, preprocessing_fn)

In [None]:
# Build the U-Net on the backbone named in the configuration cell rather than
# repeating the string, so the encoder always matches `preprocessing_fn`
# (which is derived from the same `backbone` value).
# classes=6: the network emits six output channels per pixel.
model = Unet(backbone_name=backbone, encoder_weights="imagenet", classes=6)
model.summary()

In [None]:
# Adam with explicit betas and a small learning-rate decay; the loss and the
# Dice metric are helper functions brought in by the utils import.
optimizer = Adam(
    lr=1e-3,
    beta_1=0.9,
    beta_2=0.999,
    decay=1e-7,
)
model.compile(
    loss=bc_dice_loss,
    optimizer=optimizer,
    metrics=[dice_coefficient, "binary_accuracy"],
)

In [None]:
batch_size = 4
epochs = 30
steps_per_epoch = 500  # generator batches drawn per epoch

# Keras logs a function metric under the function's __name__ (prefixed with
# "val_" for validation data). Deriving the key from the metric that was
# actually compiled keeps the callbacks pointed at a logged quantity; a
# mismatched key makes ReduceLROnPlateau warn and never lower the LR.
# Falls back to the historical key if the metric has no __name__.
val_dice_key = "val_" + getattr(dice_coefficient, "__name__", "dice_coef")

# Keras 2.x ModelCheckpoint does not create missing directories; without this
# the first checkpoint write fails only after a full epoch has been trained.
if not os.path.isdir("backup"):
    os.makedirs("backup")

callbacks = [
    # save_best_only=False: one weights file per epoch.
    ModelCheckpoint("backup/epoch_{epoch:02d}.hdf5", monitor=val_dice_key, mode="max", save_weights_only=True, save_best_only=False, verbose=1),
    TensorBoard(log_dir="logs", batch_size=batch_size),
    # Multiply the LR by 0.4 after 2 epochs without validation-Dice improvement.
    ReduceLROnPlateau(monitor=val_dice_key, factor=0.4, patience=2, verbose=1, mode="max", min_lr=1e-9),
]
model.fit_generator(
    generator(batch_size, shape, train_df, seq, preprocessing_fn),
    validation_data=(val_images, val_masks),
    steps_per_epoch=steps_per_epoch,
    epochs=epochs,
    callbacks=callbacks,
)

In [None]:
# Persist the weights of the model in its current state (weights only; the
# architecture must be rebuilt in code before loading them back).
model.save_weights("backup/final.hdf5")

In [None]:
# Reload the lexicographically last weights file from backup/. With the names
# written by this notebook ("epoch_NN.hdf5", "final.hdf5") that is
# "final.hdf5" when present, otherwise the highest-numbered epoch.
# Only .hdf5 files are considered so a stray file in the directory cannot be
# picked up by accident.
weight_files = sorted(
    file_name for file_name in os.listdir("backup") if file_name.endswith(".hdf5")
)
assert weight_files, "no .hdf5 weight files found in backup/"
model.load_weights("backup/" + weight_files[-1])

In [None]:
# Load the test images and their ground-truth masks.
# NOTE(review): unlike load_val, neither `shape` nor `preprocessing_fn` is
# passed here — confirm load_test prepares inputs the way the model expects.
test_images, test_masks = load_test(test_df)

In [None]:
# Score the whole test set, then binarise the scores in place at 0.5:
# everything >= 0.5 becomes 1, everything < 0.5 becomes 0.
test_batch = np.array(test_images)
test_results = model.predict(test_batch)
np.putmask(test_results, test_results >= 0.5, 1)
np.putmask(test_results, test_results < 0.5, 0)

In [None]:
# Dice score per anatomical structure on the thresholded test predictions.
# Each prediction channel is scored against the SAME channel of the ground
# truth (previously every structure was compared with mask channel 0).
# NOTE(review): assumes test_masks uses the same channel layout as the model
# output — confirm against load_test.
for structure_name, channel_idx in (("Lungs", 0), ("Heart", 2), ("Clavicles", 3)):
    dice_score = hard_dice(
        test_results[:, :, :, channel_idx],
        test_masks[:, :, :, channel_idx],
    )
    print(structure_name + ": " + str(dice_score))

In [None]:
# IoU per anatomical structure on the thresholded test predictions.
# Each prediction channel is scored against the SAME channel of the ground
# truth (previously every structure was compared with mask channel 0).
# NOTE(review): assumes test_masks uses the same channel layout as the model
# output — confirm against load_test.
for structure_name, channel_idx in (("Lungs", 0), ("Heart", 2), ("Clavicles", 3)):
    iou_score = iou(
        test_results[:, :, :, channel_idx],
        test_masks[:, :, :, channel_idx],
    )
    print(structure_name + ": " + str(iou_score))

In [None]:
# Secondary model that shares the trained network's input but returns the
# outputs of its last 16 layers, so intermediate activations can be inspected.
layer_outputs = []
for tail_layer in model.layers[-16:]:
    layer_outputs.append(tail_layer.output)
activation_model = Model(inputs=model.input, outputs=layer_outputs)

# Feed the first test image only; a leading batch axis is added for predict().
first_image_batch = np.array(test_images[0])[np.newaxis, ...]
activations = activation_model.predict(first_image_batch)
 
def display_activation(activations, col_size, row_size, act_index):
    """Show the channels of one captured activation on a grid of grayscale images.

    Args:
        activations: Sequence of activation arrays; each is indexed as
            ``[0, :, :, channel]`` (first batch element, channels last).
        col_size: Number of grid columns.
        row_size: Number of grid rows.
        act_index: Which entry of ``activations`` to display.

    Channels are laid out row by row starting at channel 0. If the grid has
    more cells than the activation has channels, the surplus cells are left
    blank instead of raising an IndexError.
    """
    activation = activations[act_index]
    n_channels = activation.shape[-1]
    # figsize is (width, height): width follows the column count and height
    # the row count. squeeze=False keeps `ax` two-dimensional even for a
    # single row or column.
    fig, ax = plt.subplots(
        row_size,
        col_size,
        figsize=(col_size * 2.5, row_size * 1.5),
        squeeze=False,
    )
    # ravel() walks the axes grid in row-major order, matching row/col nesting.
    for channel_idx, axis in enumerate(ax.ravel()):
        if channel_idx < n_channels:
            axis.imshow(activation[0, :, :, channel_idx], cmap='gray')
        else:
            axis.axis("off")

In [None]:
# 4x4 heat-map grid: the first 16 channels of the third-from-last captured
# activation, for the single image that was fed to activation_model.
rows, cols = 4, 4
plt.figure(figsize=(30, 30))

feature_maps = activations[-3][0]
for i in range(rows * cols):
    panel = plt.subplot(rows, cols, i + 1)
    panel.axis("off")
    panel.imshow(feature_maps[..., i], cmap="hot")

plt.show()

In [None]:
# One row of six heat maps: the first six channels of the second-from-last
# captured activation, for the single image that was fed to activation_model.
rows, cols = 1, 6
plt.figure(figsize=(30, 30))

channel_maps = activations[-2][0]
for i in range(rows * cols):
    panel = plt.subplot(rows, cols, i + 1)
    panel.axis("off")
    panel.imshow(channel_maps[..., i], cmap="hot")

plt.show()

In [None]:
# Blend a red overlay of the predicted regions onto every test image.
# `image` is both an input and the destination of addWeighted, so the test
# images are modified in place; re-running this cell blends a second time.
for image, mask in zip(test_images, test_results):
    # Work on a fresh boolean array so `test_results` keeps its 0/1 values.
    predicted = mask >= 0.5
    if predicted.ndim == 3:
        # The model emits several channels per pixel, which cv2's GRAY2RGB
        # conversion rejects; collapse them into the union of all channels.
        # NOTE(review): assumes no channel encodes "background" — TODO confirm.
        predicted = predicted.any(axis=-1)
    # Only the first colour channel is set, i.e. red when shown as RGB.
    overlay = np.zeros(predicted.shape + (3,), dtype=np.uint8)
    overlay[:, :, 0] = np.where(predicted, 255, 0)
    cv2.addWeighted(overlay, 0.4, image, 0.6, 0, image)

In [None]:
# Show up to the first 24 test images on a 6x4 grid, axes hidden.
rows, cols = 6, 4
plt.figure(figsize=(30, 30))

for i, shown_image in enumerate(test_images[:24]):
    panel = plt.subplot(rows, cols, i + 1)
    panel.axis("off")
    panel.imshow(shown_image)

plt.show()