In [None]:
import os
import numpy as np
from glob import glob
import matplotlib.pyplot as plt
import tensorflow as tf
import os
os.environ["SM_FRAMEWORK"] = "tf.keras"
import segmentation_models as sm
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
import cv2
import tensorflow as tf
from tensorflow.keras.applications import VGG16
from tensorflow.keras.layers import Dense, BatchNormalization, GlobalAveragePooling2D, Conv2D, UpSampling2D, Input, Dropout, Conv2DTranspose, MaxPooling2D, Flatten, Activation
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import train_test_split, StratifiedKFold
import albumentations as A
from sklearn.metrics import accuracy_score, precision_score, recall_score, confusion_matrix, roc_auc_score
import gc

In [None]:
image_size = (224, 224)
batch_size = 8
n_classes = 1
activation = 'sigmoid'

# Data

In [None]:
# Directory 0 holds the no-cancer images, directory 1 the cancer images.
# sorted(): glob order is filesystem-dependent, and the seeded train/test split that consumes
# these lists is only reproducible across machines if its input order is fixed.
all_no_cancer = sorted(glob("../data/0/imgs/*"))
all_cancer = sorted(glob("../data/1/imgs/*"))

In [None]:
# Class balance check before splitting.
print(f"Number of cancer: {len(all_cancer)}")
print(f"Number of no cancer: {len(all_no_cancer)}")

In [None]:
# Hold out 20% of each class as a fixed test set; splitting per class keeps the class ratio.
train_size = 0.8
cancer_train_paths, cancer_test_paths = train_test_split(all_cancer, train_size=train_size, random_state=42)
no_cancer_train_paths, no_cancer_test_paths = train_test_split(all_no_cancer, train_size=train_size, random_state=42)

# Label convention: 1 = cancer, 0 = no cancer (cancer samples listed first before shuffling).
train_paths = np.array(cancer_train_paths + no_cancer_train_paths)
train_labels = np.array([1] * len(cancer_train_paths) + [0] * len(no_cancer_train_paths))
test_paths = np.array(cancer_test_paths + no_cancer_test_paths)
test_labels = np.array([1] * len(cancer_test_paths) + [0] * len(no_cancer_test_paths))

# Shuffle the training pool once with a fixed global seed.
random_state = 1
np.random.seed(random_state)
indices = np.random.permutation(len(train_paths))
train_paths, train_labels = train_paths[indices], train_labels[indices]

# Stratified k-fold over the training pool; each fold's split is stored as plain Python lists.
k = 5
skf = StratifiedKFold(n_splits=k, shuffle=True, random_state=42)

folds_train_paths, folds_val_paths = [], []
folds_train_labels, folds_val_labels = [], []
for train_index, val_index in skf.split(train_paths, train_labels):
    folds_train_paths.append(list(train_paths[train_index]))
    folds_val_paths.append(list(train_paths[val_index]))
    folds_train_labels.append(list(train_labels[train_index]))
    folds_val_labels.append(list(train_labels[val_index]))

In [None]:
# Mask paths are derived from image paths by replacing every 'imgs' substring with 'masks'.
folds_train_mask_paths = [[path.replace('imgs', 'masks') for path in fold_paths] for fold_paths in folds_train_paths]
folds_val_mask_paths = [[path.replace('imgs', 'masks') for path in fold_paths] for fold_paths in folds_val_paths]
test_mask_paths = [path.replace('imgs', 'masks') for path in test_paths]

In [None]:
# Split sizes (train/validation shown for fold 0).
for caption, mask_paths in (
    ("Number of samples (train): ", folds_train_mask_paths[0]),
    ("Number of samples (validation): ", folds_val_mask_paths[0]),
    ("Number of samples (test): ", test_mask_paths),
):
    print(caption, len(mask_paths))

# Dataloader

In [None]:
def load_image(path, target_size=image_size):
    """Load an array with ``np.load``, resize it, min-max scale it to 0..255 and replicate it to 3 channels.

    Args:
        path: file readable by ``np.load``. NOTE(review): assumed to hold a single-channel 2-D
            array, since the result is stacked into 3 channels along a new last axis — confirm.
        target_size: ``dsize`` passed to ``cv2.resize``, i.e. ``(width, height)``.

    Returns:
        ``uint8`` array; for a 2-D input its shape is ``(target_size[1], target_size[0], 3)``.
    """
    img = cv2.resize(np.load(path), target_size)
    lo, hi = img.min(), img.max()
    if hi > lo:
        img = ((img - lo) / (hi - lo)) * 255
    else:
        # Constant image: min-max scaling would compute 0/0 = NaN, and casting NaN to uint8 is undefined.
        img = np.zeros(img.shape)
    img = np.stack((img,) * 3, axis=-1)

    return img.astype(np.uint8)


# All-zero mask at the default size, kept as a public constant. load_mask allocates a fresh array
# per call so no caller can mutate shared state. Arrays produced by cv2.resize(x, image_size) have
# shape (image_size[1], image_size[0]) because cv2 takes dsize as (width, height).
MASK_CONST_ZEROS = np.zeros((image_size[1], image_size[0], 1), dtype='float32')

def load_mask(path, label, target_size=image_size):
    """Load a binary segmentation mask of shape ``(height, width, 1)`` as float32.

    Args:
        path: file readable by ``np.load``; not read at all when ``label == 0``.
        label: class label of the sample; 0 (no cancer) yields an all-zero mask.
        target_size: ``dsize`` for ``cv2.resize``, i.e. ``(width, height)``.

    Returns:
        float32 array with values in {0, 1}.
    """
    if label == 0:
        # Honour target_size and return a fresh array (not a shared module-level object).
        return np.zeros((target_size[1], target_size[0], 1), dtype='float32')

    mask = cv2.resize(np.load(path), target_size)
    # Multi-channel masks: keep only the first channel. 2-D masks are used as-is.
    if mask.ndim == 3:
        mask = mask[:, :, 0]
    # Any non-zero pixel counts as foreground. float32 matches the all-zero branch, so mixed
    # batches don't get promoted to another dtype.
    # NOTE(review): cv2.resize interpolates bilinearly by default, so thresholding at > 0 may
    # slightly grow mask borders; consider cv2.INTER_NEAREST.
    mask = (mask > 0).astype('float32')

    return np.expand_dims(mask, axis=-1)

In [None]:
class Dataset:
    """Indexable collection of ``(image, mask, label)`` samples, loaded lazily from disk.

    Args:
        images_fps: sequence of image file paths.
        masks_fps: sequence of mask file paths, aligned with ``images_fps``.
        class_values: per-sample class label (1 = cancer, 0 = no cancer). It is only used to
            decide how ``load_mask`` builds the mask.
        augmentation: optional callable invoked as ``fn(image=..., mask=...)`` that returns a
            dict with ``'image'`` and ``'mask'`` keys.
        preprocessing: optional callable with the same contract, applied after augmentation.
    """

    CLASSES = ['no_cancer', 'cancer']

    def __init__(self, images_fps, masks_fps, class_values, augmentation=None, preprocessing=None):
        self.images_fps = images_fps
        self.masks_fps = masks_fps
        self.class_values = class_values
        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, i):
        """Return ``(image, mask, label)`` for sample ``i``.

        ``label`` is ``float(np.max(mask))`` computed after the transforms. It therefore reflects
        whether any foreground survives (e.g. a random crop may cut it out), not ``class_values[i]``.
        """
        image = load_image(self.images_fps[i])
        mask = load_mask(self.masks_fps[i], self.class_values[i])

        for transform in (self.augmentation, self.preprocessing):
            if transform:
                transformed = transform(image=image, mask=mask)
                image, mask = transformed['image'], transformed['mask']

        return image, mask, float(np.max(mask))

    def __len__(self):
        return len(self.images_fps)
    
    
class Dataloder(tf.keras.utils.Sequence):
    """Batches samples from a ``Dataset`` for Keras ``fit``/``predict``.

    Each batch is ``({'input_image': images}, {'classification_output': labels, 'sigmoid': masks})``.
    These keys match the input/output names of the multitask model.

    Args:
        dataset: indexable object returning ``(image, mask, label)`` tuples.
        batch_size: samples per batch. A trailing partial batch is dropped (see ``__len__``).
        shuffle: if True, the sample order is re-permuted every time ``on_epoch_end`` runs.
    """

    def __init__(self, dataset, batch_size=1, shuffle=False):
        super().__init__()
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indexes = np.arange(len(dataset))

        self.on_epoch_end()

    def __getitem__(self, i):
        start = i * self.batch_size
        stop = (i + 1) * self.batch_size
        # Read through self.indexes so the permutation from on_epoch_end() actually takes effect.
        # Indexing the dataset by raw position would make shuffle=True a no-op.
        data = [self.dataset[j] for j in self.indexes[start:stop]]

        images, masks, labels = (np.stack(samples, axis=0) for samples in zip(*data))

        return {'input_image': images}, {'classification_output': labels, 'sigmoid': masks}

    def __len__(self):
        return len(self.indexes) // self.batch_size

    def on_epoch_end(self):
        if self.shuffle:
            self.indexes = np.random.permutation(self.indexes)

# Augmentations

In [None]:
def round_clip_0_1(x, **kwargs):
    """Binarize a soft mask: round each value to the nearest integer, then clamp into [0, 1].

    Extra keyword arguments are accepted and ignored, presumably so the function can be plugged
    into an albumentations ``A.Lambda`` (not used in this notebook).
    """
    rounded = x.round()
    return rounded.clip(0, 1)

def get_image_pattern_augmentation():
    """Intensity-only ("pattern") augmentations, each applied independently with probability 0.3.

    None of these move pixels, so masks are left unchanged. CLAHE expects a uint8 image, which
    is what ``load_image`` returns.
    """
    return A.Compose([
        A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.3),
        A.GaussianBlur(blur_limit=(3, 5), p=0.3),
        A.CLAHE(clip_limit=2.0, tile_grid_size=(8, 8), p=0.3),
        # NOTE(review): newer albumentations releases replace `var_limit` with `std_range`;
        # confirm against the pinned version.
        A.GaussNoise(var_limit=(10.0, 50.0), p=0.3),
    ])

def get_geometric_augmentation():
    """Spatial augmentations (applied to image and mask alike), each with probability 0.3.

    ``image_size`` is read as ``(width, height)``, the ``cv2.resize`` convention used by the
    loaders. The crop is therefore resized back to ``height=image_size[1]``, ``width=image_size[0]``.
    """
    width, height = image_size[0], image_size[1]
    geometric_transforms = [
        A.HorizontalFlip(p=0.3),
        # Crop a random window 50-100% of the image height, then resize it back to full size.
        A.RandomSizedCrop(min_max_height=(int(height * 0.5), int(height)), height=int(height), width=int(width), p=0.3),
        A.Rotate(limit=30, p=0.3),
        # scale_limit=0 and rotate_limit=0: pure translation of up to 10%.
        A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.0, rotate_limit=0, p=0.3),
    ]
    return A.Compose(geometric_transforms)

def get_training_augmentation():
    """Full training pipeline: intensity augmentations first, then geometric ones."""
    return A.Compose([get_image_pattern_augmentation(), get_geometric_augmentation()])

def get_preprocessing(preprocessing_fn):
    """Wrap an image-only preprocessing callable (e.g. a backbone's input normalizer) in an albumentations pipeline.

    Only an ``image`` function is given to ``A.Lambda``, so the mask is not transformed.
    """
    image_only = A.Lambda(image=preprocessing_fn)
    return A.Compose([image_only])

def custom_preprocessing(image, mask):
    """Scale an image from 0..255 to 0..1 and pass the mask through unchanged.

    Follows the call convention ``Dataset`` uses for preprocessing (``fn(image=..., mask=...)``
    returning a dict with ``'image'`` and ``'mask'`` keys). It can therefore be passed as
    ``get_sets(..., preprocessing_function=custom_preprocessing)``.
    """
    # String keys: Dataset.__getitem__ reads sample['image'] / sample['mask'].
    # Using the arrays themselves as keys raises TypeError (ndarrays are unhashable).
    return {
        'image': image / 255,
        'mask': mask,
    }

# Config Sets

In [None]:
def get_sets(backbone_name, FOLD_I = 0, preprocessing_function=None):
    """Build the datasets and dataloaders for one cross-validation fold.

    Args:
        backbone_name: segmentation_models backbone name. It is only used to build the default
            preprocessing when ``preprocessing_function`` is None.
        FOLD_I: index into the per-fold path/label lists (``folds_*``).
        preprocessing_function: optional callable applied to every split, invoked as
            ``fn(image=..., mask=...)`` and returning a dict with ``'image'``/``'mask'``.

    Returns:
        ``(train_dataset, train_dataloader, valid_dataset, valid_dataloader, test_dataset)``.
        The test dataset is built from the global hold-out split, identical for every fold.
    """
    fold_train_paths = folds_train_paths[FOLD_I]
    fold_train_mask_paths = folds_train_mask_paths[FOLD_I]
    fold_train_labels = folds_train_labels[FOLD_I]

    fold_val_paths = folds_val_paths[FOLD_I]
    fold_val_mask_paths = folds_val_mask_paths[FOLD_I]
    fold_val_labels = folds_val_labels[FOLD_I]

    # Identity check: `== None` can be hijacked by a custom __eq__ (or be ambiguous for arrays).
    if preprocessing_function is None:
        preprocessing_function = get_preprocessing(sm.get_preprocessing(backbone_name))

    # Only the training split is augmented; all splits share the same preprocessing.
    train_dataset = Dataset(fold_train_paths,
                            fold_train_mask_paths,
                            fold_train_labels,
                            augmentation=get_training_augmentation(),
                            preprocessing=preprocessing_function)

    valid_dataset = Dataset(fold_val_paths,
                            fold_val_mask_paths,
                            fold_val_labels,
                            augmentation=None,
                            preprocessing=preprocessing_function)

    test_dataset = Dataset(test_paths,
                           test_mask_paths,
                           test_labels,
                           augmentation=None,
                           preprocessing=preprocessing_function)

    train_dataloader = Dataloder(train_dataset, batch_size=batch_size, shuffle=True)
    valid_dataloader = Dataloder(valid_dataset, batch_size=batch_size, shuffle=False)

    return train_dataset, train_dataloader, valid_dataset, valid_dataloader, test_dataset

# Visualization and Metrics

In [None]:
# helper function for data visualization
def visualize(**images):
    """PLot images in one row."""
    n = len(images)
    plt.figure(figsize=(16, 5))
    for i, (name, image) in enumerate(images.items()):
        plt.subplot(1, n, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.title(' '.join(name.split('_')).title())
        plt.imshow(image)
    plt.show()
    
def denormalize(x):
    """Rescale an image to 0..1 for display using a robust 2nd-98th percentile window.

    Values outside the window are clipped. A flat image (zero-width window) maps to all zeros
    instead of NaN.
    """
    x_max = np.percentile(x, 98)
    x_min = np.percentile(x, 2)
    shifted = x - x_min
    span = x_max - x_min
    if span == 0:
        # Avoid 0/0 -> NaN, which would render as garbage.
        return np.zeros_like(shifted)
    return (shifted / span).clip(0, 1)

In [None]:
def plot_some_images(model, test_dataset, ids=(0, 50, 150, 250)):
    """Show image / ground-truth mask / predicted mask side by side for a few test samples.

    Args:
        model: multitask model whose ``predict`` returns ``[classification, mask]``.
        test_dataset: indexable dataset yielding ``(image, mask, label)``.
        ids: sample indices to display. Indices at or beyond ``len(test_dataset)`` are skipped
            with a message instead of raising IndexError.
    """
    n_samples = len(test_dataset)
    for i in ids:
        if i >= n_samples:
            print(f"Skipping sample {i}: test set has {n_samples} samples")
            continue

        image, gt_mask, label = test_dataset[i]
        image = np.expand_dims(image, axis=0)  # add the batch axis expected by predict()
        predicted = model.predict(image)
        pr_mask = predicted[1][0]

        print("label: " + str(label))
        print("predicted: " + str(predicted[0]))

        fig, axs = plt.subplots(1, 3, figsize=(18, 6))

        axs[0].imshow(denormalize(image.squeeze()), cmap='gray')
        axs[0].set_title('Image')
        axs[0].axis('off')

        axs[1].imshow(gt_mask[..., 0].squeeze(), cmap='gray')
        axs[1].set_title('Ground Truth Mask')
        axs[1].axis('off')

        cax = axs[2].imshow(pr_mask[..., 0].squeeze(), cmap='hot', vmin=0, vmax=1)
        axs[2].set_title('Predicted Mask')
        axs[2].axis('off')
        fig.colorbar(cax, ax=axs[2])

        bbox = axs[2].get_position()
        axs[2].set_position([bbox.x0, bbox.y0, bbox.width * 1.5, bbox.height * 1.5])

        plt.tight_layout()
        plt.show()

def print_test_metrics(model, test_dataset):
    """Run ``model`` on the whole test set and print classification and segmentation metrics.

    Prints accuracy, ROC AUC and the confusion matrix for the classification head
    (threshold 0.5), then the IoU score (threshold 0.5) for the segmentation head.
    The entire test set is materialized as one batch, so it must fit in memory.
    """
    test_dataloader = Dataloder(test_dataset, batch_size=len(test_dataset), shuffle=False)

    # Fetch the single batch once: every Dataloder[0] call re-loads and re-preprocesses every file.
    test_inputs, test_targets = test_dataloader[0]
    test_preds = model.predict(test_inputs['input_image'])
    test_true_labels = test_targets['classification_output']

    # test_preds[0] holds classification probabilities, presumably shaped (N, 1) by the
    # single-unit Dense head. Flatten them to align with the (N,) ground-truth labels.
    test_probs = np.ravel(test_preds[0])
    test_pred_labels = (test_probs > 0.5).astype(int)

    accuracy = accuracy_score(test_true_labels, test_pred_labels)
    confusion = confusion_matrix(test_true_labels, test_pred_labels)
    auc = roc_auc_score(test_true_labels, test_probs)

    print("accuracy: "+str(accuracy))
    print("auc: "+str(auc))
    print("cm: ")
    print(confusion)

    test_true_masks = test_targets['sigmoid']
    test_pred_masks = test_preds[1]

    print(sm.metrics.IOUScore(threshold=0.5)(test_true_masks, test_pred_masks))

# Model

In [None]:
# Candidate encoders for sm.Unet. 'layer' names the layer of the built U-Net whose output feeds
# the classification head in create_model().
# NOTE(review): auto-generated names such as 'activation_80' depend on Keras layer-creation order
# and may differ across library versions or sessions; confirm against model.summary().
BACKBONES_TO_TEST = [
    {'name': name, 'layer': layer}
    for name, layer in (
        ('resnet152', 'relu1'),
        ('seresnext50', 'activation_80'),
        ('seresnet152', 'activation_250'),
        ('resnext101', 'stage4_unit3_relu'),
        ('seresnext101', 'activation_165'),
        ('senet154', 'activation_252'),
        ('densenet201', 'relu'),
        ('inceptionresnetv2', 'conv_7b_ac'),
        ('mobilenetv2', 'out_relu'),
        ('efficientnetb7', 'top_activation'),
        ('vgg19', 'center_block2_relu'),
    )
]

def get_backbone_by_name(name, backbones=None):
    """Return the backbone config dict whose ``'name'`` equals ``name``.

    Args:
        name: backbone name to look up.
        backbones: iterable of ``{'name': ..., 'layer': ...}`` dicts; defaults to BACKBONES_TO_TEST.

    Returns:
        The first matching dict (the same object, not a copy).

    Raises:
        ValueError: if no entry matches. Returning None would only fail later with an opaque
            ``TypeError`` at ``backbone['name']``.
    """
    if backbones is None:
        backbones = BACKBONES_TO_TEST
    for backbone in backbones:
        if backbone['name'] == name:
            return backbone
    known = ', '.join(b['name'] for b in backbones)
    raise ValueError(f"Unknown backbone {name!r}; expected one of: {known}")

def create_model(backbone):
    """Build a multitask U-Net: a segmentation mask plus an image-level cancer probability.

    Args:
        backbone: dict with ``'name'`` (segmentation_models backbone) and ``'layer'`` (name of the
            U-Net layer whose output feeds the classification head), e.g. an entry of BACKBONES_TO_TEST.

    Returns:
        Keras Model with outputs ``[classification_output, segmentation_output]``.
    """
    # The loaders resize with cv2.resize(x, image_size), and cv2 takes dsize as (width, height).
    # Arrays therefore come out as (image_size[1], image_size[0]); identical for square sizes.
    input_shape = (image_size[1], image_size[0], 3)
    model_segmentation = sm.Unet(backbone['name'], classes=n_classes, activation=activation, input_shape=input_shape, encoder_freeze=False)
    model_segmentation_partial = Model(inputs=model_segmentation.input, outputs=model_segmentation.get_layer(backbone['layer']).output)

    # Classification head on the selected feature map: GAP -> Dense(128) -> Dropout -> Dense(1, sigmoid).
    x = GlobalAveragePooling2D(name="branch_class_1")(model_segmentation_partial.output)
    x = Dense(128, activation='relu')(x)
    x = Dropout(0.2)(x)
    classification_output = Dense(1, activation='sigmoid', name='classification_output')(x)

    multitask_model = Model(inputs=model_segmentation_partial.input, outputs=[classification_output, model_segmentation.output])

    # Rename the input layer so batches keyed {'input_image': ...} are routed to it.
    # NOTE(review): relies on the private `_name` attribute; verify with the installed Keras version.
    multitask_model.layers[0]._name = 'input_image'

    return multitask_model

# Train

In [None]:
# K-fold training: build a fresh multitask model per fold and keep each fold's History in `historys`.
historys = []
backbone = get_backbone_by_name('seresnext101')  # same backbone for every fold, so look it up once
for fold_i in range(len(folds_train_paths)):
    print(f"============ INIT FOLD {fold_i} ===============")
    train_dataset, train_dataloader, valid_dataset, valid_dataloader, test_dataset = get_sets(backbone['name'], fold_i)

    # Sanity check of the value range after preprocessing (load the sample once, not twice).
    sample_image = test_dataset[0][0]
    print(f"Range: [{np.min(sample_image)} , {np.max(sample_image)}]")

    # Free the previous fold's model and Keras graph state before building the next one.
    try:
        del multitask_model
    except NameError:  # first fold on a fresh kernel: nothing to delete yet
        pass
    tf.keras.backend.clear_session()
    gc.collect()

    multitask_model = create_model(backbone)

    # Best-effort removal of a stale artifact from earlier runs; a missing file is expected.
    try:
        os.remove('/kaggle/working/multi_task_all.h5')
    except OSError:
        pass

    optim_zero = tf.keras.optimizers.Adam()
    # Segmentation loss: Dice + binary focal, equally weighted.
    dice_loss = sm.losses.DiceLoss()
    focal_loss = sm.losses.BinaryFocalLoss()
    total_loss = dice_loss + (1 * focal_loss)
    classification_loss = sm.losses.BinaryFocalLoss(alpha=0.6, gamma=2.0)

    # Keys must match the model's output names: the Dense head 'classification_output' and the
    # U-Net's final activation layer (presumably named after `activation`, i.e. 'sigmoid').
    metrics = {'classification_output': ['accuracy'], 'sigmoid': [sm.metrics.IOUScore(threshold=0.5)]}
    loss_weights = {'classification_output': 1, 'sigmoid': 1}
    loss = {'classification_output': classification_loss, 'sigmoid': total_loss}

    multitask_model.compile(optim_zero, loss_weights=loss_weights, loss=loss, metrics=metrics)

    # NOTE(review): every fold writes the same './best_model.h5', so after the loop it holds only
    # the last fold's best weights. ReduceLROnPlateau's default patience (10) equals EarlyStopping's,
    # so the LR reduction and the early stop may fire at the same epoch.
    callbacks = [
        tf.keras.callbacks.ModelCheckpoint('./best_model.h5', save_weights_only=True, save_best_only=True, mode='min'),
        tf.keras.callbacks.ReduceLROnPlateau(),
        tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10, verbose=1, mode='min')
    ]

    history = multitask_model.fit(
            train_dataloader,
            steps_per_epoch=len(train_dataloader),
            epochs=500,
            validation_data=valid_dataloader,
            validation_steps=len(valid_dataloader),
            callbacks=callbacks
    )
    historys.append(history)
    

# Test

In [None]:
# Restore the checkpoint written by the training loop's ModelCheckpoint('./best_model.h5').
# Using the same relative path keeps this working outside /kaggle/working.
# NOTE(review): every fold writes this same file, so these are the last fold's best weights.
multitask_model.load_weights('./best_model.h5')

In [None]:
# Qualitative check on a few test samples, then aggregate test-set metrics.
for report in (plot_some_images, print_test_metrics):
    report(multitask_model, test_dataset)