# MRI Brain Cancer Segmentation with UNet

In [1]:
import pandas as pd
import cv2
import os
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split

In [2]:
# utils
# Lower-case file extensions (with leading dot) that count as images.
image_types = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff")


def list_images(base_path, contains=None):
    """Return the image paths found under ``base_path``.

    Delegates to ``list_files`` restricted to the extensions listed in
    ``image_types``; ``contains`` is forwarded unchanged.
    """
    image_paths = list_files(base_path, contains=contains, valid_exts=image_types)
    return image_paths

def list_files(base_path, valid_exts=None, contains=None):
    """Recursively yield the paths of files below ``base_path``.

    Args:
        base_path: Directory that is walked with ``os.walk``.
        valid_exts: A single extension (e.g. ``".tif"``) or any iterable of
            extensions. It is compared against the lower-cased extension of
            each file name. ``None`` keeps every file.
        contains: Optional substring that the file name (not the directory)
            must contain.

    Yields:
        ``os.path.join(root_dir, filename)`` for every matching file.
    """
    if valid_exts is not None:
        # str.endswith only accepts a str or a tuple, so normalise lists,
        # sets and single strings into a tuple once, outside the loops.
        if isinstance(valid_exts, str):
            valid_exts = (valid_exts,)
        else:
            valid_exts = tuple(valid_exts)

    # loop over the directory structure
    for root_dir, _dir_names, filenames in os.walk(base_path):
        for filename in filenames:
            # skip files whose name does not contain the requested substring
            if contains is not None and contains not in filename:
                continue

            if valid_exts is not None:
                # A name without a dot has no extension at all; slicing from
                # rfind's -1 would otherwise return its last character.
                dot = filename.rfind(".")
                ext = filename[dot:].lower() if dot != -1 else ""
                if not ext.endswith(valid_exts):
                    continue

            yield os.path.join(root_dir, filename)

In [3]:
# Root directory of the dataset that is scanned for image files.
DATA_PATH = "/kaggle/input/lgg-mri-segmentation/kaggle_3m/"

# Fraction of the whole dataset that goes into the training portion.
TRAIN_SPLIT = 0.8
# Fraction of the training portion that is held out for validation.
VAL_SPLIT = 0.1

# Spatial size every image and mask is resized to.
IMG_SIZE = (256, 256)

# Exploring the dataset

In [4]:
# Gather every image path below DATA_PATH into a one-column DataFrame.
image_paths = list(list_images(DATA_PATH))
dataset = pd.DataFrame({'filepath': image_paths})

# Show two sample entries to check what the paths look like.
print(dataset["filepath"][0])
print(dataset["filepath"][100])

In [5]:
# Separate the MRI slices from their segmentation masks. The "mask" test is
# applied to the file name only, so a directory whose name happens to contain
# "mask" cannot misclassify the files inside it.
file_names = dataset["filepath"].map(os.path.basename)
is_mask = file_names.str.contains("mask", regex=False)
images = dataset[~is_mask]
masks = dataset[is_mask]


def _slice_sort_key(path):
    """Return ``(slice number, parent folder name)`` for an image or mask path.

    The slice number is the integer after the last underscore of the file
    name, once the extension and an optional ``_mask`` suffix are removed.
    Parsing the name (instead of slicing with a hard-coded prefix length)
    keeps the key valid for any DATA_PATH and any folder-name length.
    """
    stem = os.path.splitext(os.path.basename(path))[0]
    if stem.endswith("_mask"):
        stem = stem[:-len("_mask")]
    slice_number = int(stem.rsplit("_", 1)[-1])
    return slice_number, os.path.basename(os.path.dirname(path))


# Sort both lists with the same fully-determined key so that images[i] and
# masks[i] refer to the same slice regardless of the directory walk order.
images = sorted(images.filepath.values, key=_slice_sort_key)
masks = sorted(masks.filepath.values, key=_slice_sort_key)

# Fail loudly if an image ended up paired with a mask of a different slice.
mismatched = [
    (image, mask)
    for image, mask in zip(images, masks)
    if os.path.splitext(image)[0] + "_mask" != os.path.splitext(mask)[0]
]
if len(images) != len(masks) or mismatched:
    raise ValueError(
        "Images and masks do not pair up: {} images, {} masks, {} mismatched pairs".format(
            len(images), len(masks), len(mismatched)
        )
    )

print(images[150])
print(masks[150])


In [6]:
# The patient ID is the name of the folder that contains the slice. Taking the
# parent directory name (rather than a fixed index into a '/'-split path)
# stays correct when DATA_PATH sits at a different directory depth.
patients = [os.path.basename(os.path.dirname(path)) for path in images]
print(patients[2])

In [7]:
# Load two masks (first one and the one at position 1500) and show them
# side by side.
mask1, mask2 = [cv2.imread(masks[position]) for position in (0, 1500)]

fig, ax = plt.subplots(1, 2)
left_axis, right_axis = ax
left_axis.imshow(mask1)
right_axis.imshow(mask2)

In [8]:

def diagnosis(path):
    """Classify a mask file as positive (1) or negative (0).

    A mask filled with zeros is negative; a mask that contains at least one
    non-zero pixel is positive.

    Args:
        path: Path of the mask image, read with ``cv2.imread``.

    Returns:
        ``1`` if the largest pixel value is greater than zero, else ``0``.

    Raises:
        OSError: If the mask cannot be read.
    """
    mask = cv2.imread(path)
    if mask is None:
        # cv2.imread reports a missing/unreadable file by returning None
        # instead of raising; surface that as a clear error here.
        raise OSError("Could not read mask image: {}".format(path))
    return int(np.max(mask) > 0)
    
# One row per slice: owning patient, image path and matching mask path.
dataset = pd.DataFrame(
    dict(patient=patients, image_path=images, mask_path=masks)
)

# Label each row from its mask (1 = positive, 0 = negative).
dataset["diagnosis"] = dataset["mask_path"].apply(diagnosis)

In [9]:
# Quick overview of the assembled dataset.
print(dataset.head())
print()
# DataFrame.info() writes its report itself and returns None, so it must not
# be wrapped in print() (that would add a stray "None" line).
dataset.info()
print()
print("Number of negative (0) and positive (1) cases:")
print(dataset.diagnosis.value_counts())

In [10]:
# Plot one positive and one negative example. Each row shows the image, its
# first stored channel as a heat map, and the mask.
positive = dataset[dataset["diagnosis"] == 1].iloc[100]
negative = dataset[dataset["diagnosis"] == 0].iloc[100]

fig, (ax1, ax2) = plt.subplots(2, 3, figsize=(25., 25.))

img_pos = cv2.imread(positive["image_path"])
mask_pos = cv2.imread(positive["mask_path"])
# cv2.imread returns channels in BGR order while matplotlib interprets a
# 3-channel array as RGB, so convert for display only.
ax1[0].imshow(cv2.cvtColor(img_pos, cv2.COLOR_BGR2RGB))
ax1[1].imshow(img_pos[:, :, 0], cmap='hot')
ax1[2].imshow(mask_pos)

img_neg = cv2.imread(negative["image_path"])
mask_neg = cv2.imread(negative["mask_path"])
ax2[0].imshow(cv2.cvtColor(img_neg, cv2.COLOR_BGR2RGB))
ax2[1].imshow(img_neg[:, :, 0], cmap='hot')
ax2[2].imshow(mask_neg)

print(img_pos.shape)

In [11]:
# Fixed seed so the train/validation/test membership is the same on every run.
SPLIT_SEED = 42

# NOTE(review): the split is done per row (per slice), so slices of one patient
# can end up in different subsets. Consider splitting on the "patient" column
# if patient-level separation is required.

# Compute the training and testing split, stratified on the diagnosis label.
train, test = train_test_split(
    dataset,
    train_size=TRAIN_SPLIT,
    stratify=dataset["diagnosis"],
    random_state=SPLIT_SEED,
)
train = train.reset_index(drop=True)
test = test.reset_index(drop=True)

# Carve the validation set out of the training portion.
train, val = train_test_split(
    train,
    test_size=VAL_SPLIT,
    stratify=train["diagnosis"],
    random_state=SPLIT_SEED,
)
train = train.reset_index(drop=True)
val = val.reset_index(drop=True)

TRAIN_SIZE = len(train)
VAL_SIZE = len(val)
TEST_SIZE = len(test)
print("Training dataset size: {}".format(TRAIN_SIZE))
print("Validation dataset size: {}".format(VAL_SIZE))
print("Testing dataset size: {}".format(TEST_SIZE))

In [12]:
def dataloader(dataframe, batch_size, target_size, aug_params):
    """Yield ``(image_batch, mask_batch)`` pairs built from ``dataframe``.

    Images are read from the ``image_path`` column as RGB, masks from the
    ``mask_path`` column as grayscale. Both streams use an
    ``ImageDataGenerator`` configured with ``aug_params`` and the same seed.
    Images are scaled by 1/255; masks are scaled by 1/255 and then binarised
    at 0.5.
    """
    SEED = 5

    def _flow(column, color_mode):
        # One generator per stream, both built from the same parameters.
        generator = ImageDataGenerator(**aug_params)
        return generator.flow_from_dataframe(
            dataframe,
            x_col=column,
            class_mode=None,
            color_mode=color_mode,
            target_size=target_size,
            batch_size=batch_size,
            save_to_dir=None,
            seed=SEED,
        )

    image_flow = _flow("image_path", 'rgb')
    mask_flow = _flow("mask_path", 'grayscale')

    for image_batch, mask_batch in zip(image_flow, mask_flow):
        scaled_images = image_batch / 255.
        scaled_masks = mask_batch / 255.
        # Values above 0.5 become 1, everything else 0 (dtype is kept).
        binary_masks = (scaled_masks > 0.5).astype(scaled_masks.dtype)
        yield (scaled_images, binary_masks)
        


In [13]:
# Number of samples per batch produced by the data loaders.
BATCH_SIZE = 32
# Maximum number of training epochs.
EPOCHS = 80

In [14]:
# Augmentation settings applied to the training stream only.
# NOTE(review): Keras documents rotation_range in degrees, so 0.1 is a very
# small rotation — confirm this is the intended magnitude.
aug_params = {
    'rotation_range': 0.1,
    'width_shift_range': 0.05,
    'height_shift_range': 0.05,
    'shear_range': 0.05,
    'zoom_range': 0.05,
    'horizontal_flip': True,
    'vertical_flip': True,
    'fill_mode': 'nearest',
}

# Validation and test streams get an empty parameter set (no augmentation).
train_dataloader = dataloader(train, BATCH_SIZE, IMG_SIZE, aug_params)
val_dataloader = dataloader(val, BATCH_SIZE, IMG_SIZE, {})
test_dataloader = dataloader(test, BATCH_SIZE, IMG_SIZE, {})

# Modeling

## UNet Model

In [15]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Activation, Conv2DTranspose
from tensorflow.keras.layers import concatenate, BatchNormalization
from tensorflow.keras import backend as K


class UNet():
    """Builder for a U-Net style encoder/decoder segmentation network.

    The network has four down-sampling stages (64, 128, 256, 512 filters), a
    1024-filter bottleneck, four up-sampling stages (512, 256, 128, 64
    filters) with skip connections, and a final 1x1 sigmoid convolution that
    outputs a single channel.
    """

    @staticmethod
    def _double_conv(x, filters, channels_dim):
        """Apply Conv(3x3) -> ReLU -> BatchNorm twice.

        Returns:
            A ``(last_conv, output)`` tuple: the raw output of the second
            convolution (before activation) and the batch-normalised output.
        """
        conv = None
        for _ in range(2):
            conv = Conv2D(filters, (3, 3), padding='same')(x)
            x = Activation('relu')(conv)
            x = BatchNormalization(axis=channels_dim)(x)
        return conv, x

    @staticmethod
    def _up_block(x, skip, filters, channels_dim):
        """Upsample ``x`` by 2, concatenate ``skip`` and apply a double conv."""
        upsampled = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(x)
        merged = concatenate([upsampled, skip], axis=channels_dim)
        _, out = UNet._double_conv(merged, filters, channels_dim)
        return out

    @staticmethod
    def build(width, height, depth):
        """Build and return the (uncompiled) Keras ``Model``.

        Args:
            width: Input width in pixels.
            height: Input height in pixels.
            depth: Number of input channels.

        Returns:
            A ``Model`` mapping the input to a one-channel sigmoid output.
        """
        input_shape = (height, width, depth)
        channels_dim = -1
        # image_data_format is a function: it has to be called. Comparing the
        # function object itself with a string is always False.
        if K.image_data_format() == 'channels_first':
            input_shape = (depth, height, width)
            channels_dim = 1

        inputs = Input(input_shape)

        # Encoder: remember one tensor per stage for the skip connections.
        # NOTE(review): as in the original code, the skip tensor is the output
        # of the stage's second convolution *before* ReLU/BatchNorm — confirm
        # this is intended rather than the normalised stage output.
        x = inputs
        skips = []
        for filters in (64, 128, 256, 512):
            skip, x = UNet._double_conv(x, filters, channels_dim)
            skips.append(skip)
            x = MaxPooling2D(pool_size=(2, 2))(x)

        # Bottleneck.
        _, x = UNet._double_conv(x, 1024, channels_dim)

        # Decoder: deepest skip first.
        for filters, skip in zip((512, 256, 128, 64), reversed(skips)):
            x = UNet._up_block(x, skip, filters, channels_dim)

        outputs = Conv2D(1, (1, 1), activation='sigmoid')(x)

        return Model(inputs=[inputs], outputs=[outputs])


In [16]:
# Smoothing constant used as the default ``epsilon`` of the metric functions.
epsilon = 100


def dice_coeff(y_true, y_pred, epsilon=epsilon):
    """Return the smoothed Dice coefficient of the flattened tensors.

    Computed as ``(2 * sum(y_true * y_pred) + epsilon) /
    (sum(y_true) + sum(y_pred) + epsilon)``.
    """
    truth = K.flatten(y_true)
    prediction = K.flatten(y_pred)

    overlap = K.sum(truth * prediction)
    numerator = 2 * overlap + epsilon
    denominator = K.sum(truth) + K.sum(prediction) + epsilon
    return numerator / denominator

def soft_dice_loss(y_true, y_pred):
    """Return ``1 - dice_coeff(y_true, y_pred)``."""
    score = dice_coeff(y_true, y_pred)
    return 1 - score

def iou(y_true, y_pred, epsilon=epsilon):
    """Return the smoothed intersection-over-union of the flattened tensors.

    Computed as ``(intersection + epsilon) / (union + epsilon)`` where
    ``union = sum(y_true) + sum(y_pred) - intersection``.
    """
    truth = K.flatten(y_true)
    prediction = K.flatten(y_pred)

    overlap = K.sum(truth * prediction)
    combined = K.sum(truth) + K.sum(prediction) - overlap

    return (overlap + epsilon) / (combined + epsilon)

def iou_loss(y_true, y_pred):
    """Return ``1 - iou(y_true, y_pred)``."""
    score = iou(y_true, y_pred)
    return 1 - score

In [17]:
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau

LEARNING_RATE = 1e-3

# Three-channel U-Net sized after IMG_SIZE.
model = UNet.build(*IMG_SIZE, 3)

# NOTE(review): newer Keras optimizers may reject the `decay` argument —
# confirm against the installed TensorFlow version.
opt = Adam(learning_rate=LEARNING_RATE, decay=LEARNING_RATE / EPOCHS)
model.compile(loss=iou_loss, optimizer=opt, metrics=[iou, dice_coeff])

# Keep only the weights with the lowest validation loss seen so far.
checkpoint = ModelCheckpoint(
    filepath='/kaggle/working/weights.unet_best.hdf5',
    monitor='val_loss',
    save_best_only=True,
    verbose=1,
)
# Stop after 15 epochs without val_loss improvement and restore the best weights.
stopper = EarlyStopping(monitor='val_loss', patience=15, restore_best_weights=True)
# Multiply the learning rate by 0.1 after 5 epochs without val_loss improvement.
reducer = ReduceLROnPlateau(monitor='val_loss', patience=5, factor=0.1, min_lr=1e-11, verbose=1)


In [18]:
# Train on the augmented stream. Step counts are whole batches only, so a
# trailing partial batch is not counted.
fit_options = dict(
    steps_per_epoch=TRAIN_SIZE // BATCH_SIZE,
    validation_data=val_dataloader,
    validation_steps=VAL_SIZE // BATCH_SIZE,
    epochs=EPOCHS,
    callbacks=[checkpoint, stopper, reducer],
)
history = model.fit(train_dataloader, **fit_options)

In [None]:
# Load the model with the best weights.
# The statements below are disabled: they are wrapped in a triple-quoted
# string literal, so nothing is executed. Remove the surrounding quotes to
# rebuild the network, load the checkpoint file and compile the model again.
''''
model = UNet.build(IMG_SIZE[0],IMG_SIZE[1],3)
model.load_weights('/kaggle/working/weights.unet_best.hdf5')
opt = Adam(learning_rate=LEARNING_RATE, decay=LEARNING_RATE / EPOCHS)
model.compile(optimizer=opt, loss=iou_loss, metrics=[iou,dice_coeff])
'''

In [48]:
from math import floor

def plot_model_performance(history, loss, metrics):
    """Plot training vs. validation curves for the loss and each metric.

    Args:
        history: Object whose ``history`` attribute maps a series name
            (``'loss'``, ``'val_loss'``, ``'<metric>'``, ``'val_<metric>'``)
            to its per-epoch values.
        loss: Name of the loss, used only in the title of the loss plot.
        metrics: Iterable of metric names to plot after the loss.

    The plots are laid out two per row; an unused last cell is removed.
    """
    plt.style.use("ggplot")
    n_epochs_trained = len(history.history['loss'])
    epochs_axis = np.arange(0, n_epochs_trained)

    # (train key, validation key) per plot, loss first.
    curves = [('loss', 'val_loss')] + [(m, 'val_' + m) for m in metrics]
    number_of_plots = len(curves)

    col = 2
    # Integer ceil-division: "/" would give a float, which range() rejects.
    row = (number_of_plots + col - 1) // col

    # squeeze=False keeps `axis` two-dimensional even for a single row, so
    # axis[r, c] indexing always works.
    figure, axis = plt.subplots(row, col, figsize=(16., 10.), squeeze=False)
    if number_of_plots % col != 0:
        figure.delaxes(axis[row - 1, col - 1])

    for plot_index, (train_key, val_key) in enumerate(curves):
        ax = axis[plot_index // col, plot_index % col]
        ax.plot(epochs_axis, history.history[train_key], label="train_" + train_key)
        ax.plot(epochs_axis, history.history[val_key], label=val_key)
        if train_key == 'loss':
            ax.set_title("Train vs Validation {} loss".format(loss), fontsize=15)
        else:
            ax.set_title("Train vs Validation {}".format(train_key), fontsize=15)

        ax.set_xlabel("Epoch #")
        ax.set_ylabel(train_key)
        ax.legend(loc="lower left")

    figure.tight_layout(pad=3.0)
    plt.show()

In [51]:
%matplotlib inline

# Evaluate on whole test batches; results holds the loss followed by the
# compiled metrics in order.
test_steps = TEST_SIZE // BATCH_SIZE
results = model.evaluate(test_dataloader, steps=test_steps)

print()
report_templates = ("Mean loss: {}", "Mean IOU score: {}", "Mean dice coeff: {}")
for position, template in enumerate(report_templates):
    print(template.format(results[position]))
print()

plot_model_performance(history, 'iou', ["iou", "dice_coeff"])

# Visualizing Results

In [50]:
NUMBER_OF_IMAGES = 10

# Show the image, its reference mask and the thresholded prediction for a
# few randomly chosen test rows.
for i in range(NUMBER_OF_IMAGES):
    # randint's upper bound is exclusive; starting at 0 lets the first row be
    # drawn too and also works for a one-row test set.
    index = np.random.randint(0, len(test.index))
    img = cv2.imread(test['image_path'].iloc[index])
    # The training batches are loaded with color_mode='rgb', whereas
    # cv2.imread returns BGR; convert so inference sees the same channel order.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, IMG_SIZE)
    img = img / 255
    img = img[np.newaxis, :, :, :]
    pred = model.predict(img)

    plt.figure(figsize=(12, 12))
    plt.subplot(1, 3, 1)
    plt.imshow(np.squeeze(img))
    plt.title('Original Image')
    plt.subplot(1, 3, 2)
    plt.imshow(np.squeeze(cv2.imread(test['mask_path'].iloc[index])))
    plt.title('Original Mask')
    plt.subplot(1, 3, 3)
    plt.imshow(np.squeeze(pred) > .5)
    plt.title('Prediction')
    plt.show()

# Using the model as a tool

In [91]:
# Pick one random test row, predict its mask and paint the predicted region
# red on top of the image.
# randint's upper bound is exclusive; start at 0 so the first row can be drawn.
index = np.random.randint(0, len(test.index))
img = cv2.imread(test["image_path"].iloc[index])
# The training batches are loaded with color_mode='rgb', whereas cv2.imread
# returns BGR; convert so inference sees the same channel order.
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, IMG_SIZE)
img = img / 255
img = img[np.newaxis, :, :, :]
pred = model.predict(img)
pred = np.squeeze(pred) > .5
# pred is already boolean, so it can be used directly as the pixel mask.
mask = pred
img_mask = np.array(img)
img_mask[0, mask, :] = (1, 0, 0)

figure = plt.figure(figsize=(12, 12))
plt.subplot(1, 4, 1)
plt.imshow(np.squeeze(img))
plt.title('Original Image')
plt.subplot(1, 4, 2)
plt.imshow(np.squeeze(cv2.imread(test['mask_path'].iloc[index])))
plt.title('Original Mask')
plt.subplot(1, 4, 3)
plt.imshow(pred)
plt.title('Prediction')
plt.subplot(1, 4, 4)
plt.imshow(np.squeeze(img_mask))
plt.title('Original with predicted mask')

plt.show()