# In [None]:
import os
import numpy as np
import matplotlib .pyplot as plt
from sklearn.model_selection import train_test_split
from skimage.io import imread
from skimage.transform import resize
from skimage.util import img_as_float
from keras.optimizers import Adam
from keras.callbacks import ModelCheckpoint
from keras.preprocessing.image import ImageDataGenerator
from keras import backend as K
from tensorflow.keras.callbacks import Callback

import tensorflow as tf
import tensorflow.keras.layers as L
from tensorflow.keras.models import Model
import time

# Wall-clock start, used at the end of the script to report total run time.
tic = time.time()

K.set_image_data_format('channels_last')  # TF dimension ordering in this code

# Select dataset_no from 1 to 5 to train the corresponding fold.
dataset_no = 1
i = str(dataset_no)

main_path = r"/path/to/dataset/"

# Build the fold directory with os.path.join instead of concatenating
# '\Fold' / '\Train': those literals contain invalid escape sequences and
# hard-code the Windows separator, which breaks on POSIX paths.
data_path = os.path.join(main_path, 'Fold' + i, 'Train')

# Images and masks are resized to patch_size x patch_size.
patch_size = 128
rows = patch_size
cols = patch_size

img_rows = patch_size
img_cols = patch_size

# Seed passed to the augmentation iterators so that repeated runs produce the
# same augmented images, which matters for reproducibility.
SEED = 42

# Smoothing constant added to the numerator and denominator of dice_coef.
smooth = 5.
#############################################################################

def dice_coef(y_true, y_pred):
    """Return the smoothed Dice coefficient of two tensors.

    Both tensors are flattened, so the score is computed over every element
    at once. The module-level ``smooth`` constant is added to the numerator
    and the denominator.

    :param y_true: Ground-truth tensor.
    :param y_pred: Predicted tensor.
    :return: Scalar tensor ``(2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)``.
    """
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    overlap = K.sum(truth * pred)
    numerator = 2. * overlap + smooth
    denominator = K.sum(truth) + K.sum(pred) + smooth
    return numerator / denominator

def dice_coef_loss(y_true, y_pred):
    """Return the negative natural log of ``dice_coef(y_true, y_pred)``.

    :param y_true: Ground-truth tensor.
    :param y_pred: Predicted tensor.
    :return: Scalar loss tensor.
    """
    dice = dice_coef(y_true, y_pred)
    return -K.log(dice)

def _load_resized_stack(folder, names, label, as_float):
    """Read every file in *names* from *folder* and resize it to (rows, cols).

    :param folder: Directory containing the files.
    :param names: File names to read, in the order they should be stacked.
    :param label: Word used in the progress message ('images' or 'masks').
    :param as_float: When True, pass each array through ``img_as_float``
        before resizing.
    :return: float32 array of shape (len(names), rows, cols).
    """
    total = len(names)
    stack = np.empty((total, rows, cols), dtype=np.float32)
    for idx, name in enumerate(names):
        arr = imread(os.path.join(folder, name))
        if as_float:
            arr = img_as_float(arr)
        # preserve_range=True keeps the value range of `arr` after resizing.
        stack[idx] = resize(arr, (rows, cols), preserve_range=True)

        if idx % 10 == 0:
            print('Done: {0}/{1} {2}'.format(idx, total, label))
    return stack


def load_train_data():
    """Load the training images and masks found under ``data_path``.

    Files are read from ``<data_path>/Images`` and ``<data_path>/Masks`` and
    resized to (rows, cols). Both listings are sorted so that the i-th image
    is paired with the i-th mask deterministically (``os.listdir`` gives no
    ordering guarantee).

    Images are converted with ``img_as_float``. Masks are only resized and
    keep their stored intensity range; the caller is responsible for scaling
    them.

    :return: Tuple ``(images, masks)`` of float32 arrays, each shaped
        (n, rows, cols).
    :raises ValueError: If the two folders hold a different number of files.
    """
    train_images_path = os.path.join(data_path, 'Images')
    train_masks_path = os.path.join(data_path, 'Masks')
    images = sorted(os.listdir(train_images_path))
    masks = sorted(os.listdir(train_masks_path))

    # A shorter mask list would otherwise leave uninitialised rows in the
    # mask array; a longer one would overflow it.
    if len(images) != len(masks):
        raise ValueError(
            'Found {0} images but {1} masks under {2}'.format(
                len(images), len(masks), data_path))

    print('Convert training images to arrays')
    print('------------------------------------------')
    rimgs = _load_resized_stack(train_images_path, images, 'images', True)
    print('Done.')
    print('------------------------------------------')

    print('Convert training masks to arrays')
    print('------------------------------------------')
    rmsks = _load_resized_stack(train_masks_path, masks, 'masks', False)
    print('Done.')
    print('------------------------------------------')

    return rimgs, rmsks

def my_generator(x_train, y_train, batch_size):
    """Yield endlessly augmented ``(image_batch, mask_batch)`` pairs.

    Two ``ImageDataGenerator`` iterators are created with identical
    augmentation settings and the same ``SEED``, one over the images and one
    over the masks, and are advanced in lockstep.

    :param x_train: Image array passed to ``ImageDataGenerator.flow``.
    :param y_train: Mask array passed to ``ImageDataGenerator.flow``.
    :param batch_size: Number of samples per yielded batch.
    :return: Infinite generator of ``(x_batch, y_batch)`` tuples.
    """
    # Single source of truth for the augmentation settings, so the image and
    # mask pipelines cannot drift apart.
    augmentation = dict(
        horizontal_flip=True,
        vertical_flip=True,
        width_shift_range=0.1,
        height_shift_range=0.1,
        rotation_range=10,
        zoom_range=0.1,
        fill_mode='nearest',
    )

    # Each array is passed as both x and y; only the first element of every
    # returned pair is used below.
    data_generator = ImageDataGenerator(**augmentation).flow(
        x_train, x_train, batch_size, seed=SEED)
    mask_generator = ImageDataGenerator(**augmentation).flow(
        y_train, y_train, batch_size, seed=SEED)

    while True:
        # Use the builtin next() (iterator protocol) rather than the
        # iterator's .next() method, which is not available in every Keras
        # release.
        x_batch, _ = next(data_generator)
        y_batch, _ = next(mask_generator)
        yield x_batch, y_batch

def conv_block(x, num_filters):
    """Apply two consecutive Conv2D(3x3, same) -> BatchNorm -> ReLU stages.

    :param x: Input tensor.
    :param num_filters: Number of filters for both convolutions.
    :return: Output tensor of the second ReLU.
    """
    for _ in range(2):
        x = L.Conv2D(num_filters, 3, padding="same")(x)
        x = L.BatchNormalization()(x)
        x = L.Activation("relu")(x)
    return x

def encoder_block(x, num_filters):
    """Run ``conv_block`` and then 2x2 max pooling.

    :param x: Input tensor.
    :param num_filters: Number of filters for the convolution block.
    :return: Tuple ``(features, pooled)``: the conv_block output (used as the
        skip connection) and its max-pooled version.
    """
    features = conv_block(x, num_filters)
    pooled = L.MaxPool2D((2, 2))(features)
    return features, pooled

def attention_gate(g, s, num_filters):
    """Weight the skip tensor *s* with attention coefficients derived from *g* and *s*.

    Both inputs are projected with a 1x1 convolution plus batch
    normalisation, summed, passed through ReLU, a further 1x1 convolution and
    a sigmoid; the result multiplies *s* element-wise.

    :param g: Gating tensor.
    :param s: Skip-connection tensor.
    :param num_filters: Number of filters for every 1x1 convolution.
    :return: ``s`` scaled by the sigmoid attention map.
    """
    def project(tensor):
        # 1x1 convolution followed by batch normalisation.
        projected = L.Conv2D(num_filters, 1, padding="same")(tensor)
        return L.BatchNormalization()(projected)

    gating = project(g)
    skip = project(s)

    attention = L.Activation("relu")(gating + skip)
    attention = L.Conv2D(num_filters, 1, padding="same")(attention)
    attention = L.Activation("sigmoid")(attention)

    return attention * s

def decoder_block(x, s, num_filters):
    """Upsample *x*, gate the skip tensor *s*, concatenate and convolve.

    :param x: Tensor coming from the previous (coarser) stage.
    :param s: Skip-connection tensor from the encoder.
    :param num_filters: Number of filters for the attention gate and the
        convolution block.
    :return: Output tensor of the convolution block.
    """
    upsampled = L.UpSampling2D(interpolation="bilinear")(x)
    gated_skip = attention_gate(upsampled, s, num_filters)
    merged = L.Concatenate()([upsampled, gated_skip])
    return conv_block(merged, num_filters)


def attention_unet():
    """Build and compile the Attention U-Net model.

    The network takes a single-channel (img_rows, img_cols, 1) input, has
    three encoder stages (64, 128, 256 filters), a 512-filter bottleneck,
    three attention-gated decoder stages (256, 128, 64 filters) and a 1x1
    sigmoid output convolution. It is compiled with Adam (learning rate
    1e-4), ``dice_coef_loss`` as the loss and ``dice_coef`` as a metric.

    :return: The compiled Keras ``Model`` named "Attention-UNET".
    """
    inputs = L.Input((img_rows, img_cols, 1), name="input_first")

    # Encoder: keep each stage's pre-pooling output as a skip connection.
    skips = []
    x = inputs
    for filters in (64, 128, 256):
        skip, x = encoder_block(x, filters)
        skips.append(skip)

    # Bottleneck
    x = conv_block(x, 512)

    # Decoder: consume the skip connections from deepest to shallowest.
    for filters, skip in zip((256, 128, 64), reversed(skips)):
        x = decoder_block(x, skip, filters)

    # Output
    outputs = L.Conv2D(1, 1, padding="same", activation="sigmoid")(x)

    model = Model(inputs, outputs, name="Attention-UNET")
    model.compile(
        optimizer=Adam(learning_rate=1e-4),
        loss=dice_coef_loss,
        metrics=[dice_coef],
    )

    return model

class StopOnDiceChange(Callback):
    """Stop training once ``val_dice_coef`` has plateaued."""

    def __init__(self, min_delta=0.001, patience=5):
        """
        Stops training if the change in val_dice_coef is below min_delta for 'patience' epochs.

        :param min_delta: Minimum absolute epoch-to-epoch change in val_dice_coef
            for the metric to count as still moving (in either direction).
        :param patience: Number of consecutive epochs with small change before stopping.
        """
        super().__init__()
        self.min_delta = min_delta
        self.patience = patience
        self.wait = 0  # consecutive epochs whose change was below min_delta
        self.prev_val_dice = None  # val_dice_coef seen at the previous epoch

    def on_epoch_end(self, epoch, logs=None):
        """Compare this epoch's val_dice_coef with the previous one.

        :param epoch: Zero-based epoch index.
        :param logs: Metric dict for the epoch; may be None, in which case
            nothing is done.
        """
        # `logs` defaults to None, so guard before calling .get on it.
        logs = logs or {}
        val_dice = logs.get('val_dice_coef')
        if val_dice is None:
            return

        if self.prev_val_dice is not None:
            dice_change = abs(val_dice - self.prev_val_dice)

            if dice_change < self.min_delta:
                self.wait += 1
                print(f"Epoch {epoch+1}: val_dice_coef change {dice_change:.6f} < min_delta {self.min_delta} ({self.wait}/{self.patience})")

                if self.wait >= self.patience:
                    print(f"\nStopping training: val_dice_coef change below {self.min_delta} for {self.patience} consecutive epochs.")
                    self.model.stop_training = True
            else:
                # The metric moved by at least min_delta (up or down):
                # restart the plateau counter.
                self.wait = 0

        self.prev_val_dice = val_dice  # Store for comparison at the next epoch
#############################################################################

print('Load and process train data')
print('------------------------------')
imgs_train, imgs_mask_train = load_train_data()

# Standardise the images with the mean/std of the whole loaded set.
imgs_train = imgs_train.astype('float32')
mean = np.mean(imgs_train)  # mean for data centering
std = np.std(imgs_train)  # std for data normalization

imgs_train -= mean
imgs_train /= std

imgs_mask_train = imgs_mask_train.astype('float32')
imgs_mask_train /= 255.  # scale masks to [0, 1] (assumes 8-bit masks)

# Add the trailing channel axis expected by the model input.
imgs_train = imgs_train.reshape(imgs_train.shape[0], rows, cols, 1)
imgs_mask_train = imgs_mask_train.reshape(imgs_mask_train.shape[0], rows, cols, 1)

# Quick visual sanity check of the first image/mask pair.
fig, ax = plt.subplots(1, 2, figsize=(12, 6))
ax[0].imshow(imgs_train[0, :, :, 0], cmap='gray')
ax[1].imshow(imgs_mask_train[0, :, :, 0], cmap='gray')
plt.show()

# Seed the split so the train/validation partition is the same on every run,
# in line with the seeded augmentation.
x_train, x_val, y_train, y_val = train_test_split(
    imgs_train, imgs_mask_train, test_size=0.2, random_state=SEED)

print('Create and compile model')
print('------------------------')
model = attention_unet()

model_no = dataset_no
model_name = 'RAZ' + str(model_no) + '-2025-JMI' + str(patch_size) + '.h5'
# NOTE(review): this checkpoint is created but never passed to model.fit, so
# it has no effect; only the final model is written by model.save below.
# Registering it would change which weights end up in model_name - confirm
# the intended behaviour before wiring it in.
model_checkpoint = ModelCheckpoint(model_name, monitor='val_loss', save_best_only=True)

model.summary()

print('------------------------')
batch_size = 10
# Steps per epoch: total number of loaded images (training + validation)
# divided by the batch size, never less than one step.
s_p_e = max(1, round(imgs_train.shape[0] / batch_size))
history = model.fit(
    my_generator(x_train, y_train, batch_size),
    steps_per_epoch=s_p_e,
    validation_data=(x_val, y_val),
    epochs=999999,  # effectively unbounded; StopOnDiceChange ends training
    verbose=2,
    callbacks=[StopOnDiceChange(min_delta=0.001, patience=5)],
)

model.save(model_name)

##########################################################

print(history.history.keys())

toc = time.time()

print('Processing Time in Minutes:', (toc - tic) / 60)

# Summarize the Dice coefficient history (the curves are Dice, not accuracy,
# and the second curve is the validation split).
plt.plot(history.history['dice_coef'])
plt.plot(history.history['val_dice_coef'])
plt.title('model Dice coefficient')
plt.ylabel('Dice coefficient')
plt.xlabel('epoch')
plt.legend(['train', 'validation'], loc='upper left')
plt.show()