# In [None]:
import os
import cv2
import numpy as np
import matplotlib .pyplot as plt
from keras.optimizers import Adam
from skimage.io import imsave, imread
from skimage.transform import resize
from skimage.util import img_as_float
from keras import backend as K
import tensorflow as tf
import tensorflow.keras.layers as L
from tensorflow.keras.models import Model
from timeit import default_timer as timer

# Use (batch, rows, cols, channels) tensor layout throughout this script.
K.set_image_data_format('channels_last')  # TF dimension ordering in this code

# Side length, in pixels, of the square patches the network consumes and emits.
patch_size = 128

# Aliases of patch_size; different parts of the script use different names.
rows = patch_size
cols = patch_size

img_rows = patch_size
img_cols = patch_size

ncols = patch_size
nrows = patch_size

# you can select dataset_no from 1 to 5 to predict the outer test partition in each fold
dataset_no = 1
i = str(dataset_no)  # fold number as text, used to build the fold folder name

main_path = r"/path/to/dataset/"
# os.path.join keeps these paths portable: no hard-coded '\' separators
# (which are not path separators on POSIX and form invalid string escapes).
fold_path = os.path.join(main_path, 'Fold' + i)
data_path = os.path.join(fold_path, 'Test')

# Target (width, height) given to cv2.resize when saving outputs.
dsize = (ncols, nrows)

# Additive smoothing term used by dice_coef; keeps the ratio finite for empty masks.
smooth = 5.

results_root = os.path.join(fold_path, 'Result_attUnet')
dir_path = os.path.join(results_root, 'Predictions')   # predicted maps (*_result.bmp)
dir_path2 = os.path.join(results_root, 'Results')      # input/mask/output figures
dir_path1 = os.path.join(results_root, 'Masks')        # resized ground-truth masks

# Create every output folder; exist_ok makes re-runs harmless.
for _out_dir in (dir_path, dir_path1, dir_path2):
    os.makedirs(_out_dir, exist_ok=True)

import re
def sorted_alphanumeric(data):
    """Return the names in *data* sorted in natural ("human") order.

    Runs of digits compare numerically and all other text compares
    case-insensitively, so ``img2.bmp`` sorts before ``img10.bmp``.
    """
    def _natural_key(name):
        # re.split with a capturing group keeps the digit runs as separate chunks.
        chunks = re.split('([0-9]+)', name)
        return [int(chunk) if chunk.isdigit() else chunk.lower() for chunk in chunks]

    return sorted(data, key=_natural_key)

def dice_coef(y_true, y_pred):
    """Soft Dice coefficient between ground-truth and predicted mask tensors.

    Both tensors are flattened, so the score is pooled over the whole batch
    rather than averaged per sample. The module-level ``smooth`` constant is
    added to numerator and denominator so empty masks still give a finite value.

    Args:
        y_true: ground-truth mask tensor; cast to ``y_pred``'s dtype.
        y_pred: predicted tensor with the same number of elements as ``y_true``.

    Returns:
        Scalar tensor ``(2*sum(t*p) + smooth) / (sum(t) + sum(p) + smooth)``.
    """
    # Plain TensorFlow ops: keras.backend.flatten/sum are not available in
    # Keras 3's ``keras.backend`` module.
    y_pred_f = tf.reshape(y_pred, [-1])
    # Cast so integer/bool masks can be multiplied with float predictions.
    y_true_f = tf.reshape(tf.cast(y_true, y_pred_f.dtype), [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)

def dice_coef_loss(y_true, y_pred):
    """Training loss: negative natural log of ``dice_coef(y_true, y_pred)``.

    It is 0 when the Dice score is 1 and increases as the score decreases.
    """
    # tf.math.log instead of keras.backend.log, which Keras 3 no longer provides.
    return -tf.math.log(dice_coef(y_true, y_pred))

def load_train_data():
    """Load every image and mask under ``data_path`` into fixed-size arrays.

    Files in ``data_path/Images`` and ``data_path/Masks`` are paired by their
    position after natural sorting, so the two folders should hold
    corresponding files in the same order.

    Images are converted with ``img_as_float`` before resizing. Masks are
    resized as read, keeping the value range they were stored with.
    Single-channel files are assumed: each resized array must fit a 2-D
    ``(rows, cols)`` slot.

    Returns:
        (rimgs, rmsks): two float32 arrays of shape ``(total, rows, cols)``.

    Raises:
        ValueError: if the Images and Masks folders hold different numbers of files.
    """
    train_images_path = os.path.join(data_path, 'Images')
    train_masks_path = os.path.join(data_path, 'Masks')
    images = sorted_alphanumeric(os.listdir(train_images_path))
    masks = sorted_alphanumeric(os.listdir(train_masks_path))

    total = len(images)
    # np.empty leaves unfilled rows as garbage and surplus masks would index
    # past the end, so refuse mismatched folders up front.
    if len(masks) != total:
        raise ValueError('Found {0} images but {1} masks under {2}'.format(
            total, len(masks), data_path))

    rimgs = np.empty((total, rows, cols), dtype=np.float32)
    rmsks = np.empty((total, rows, cols), dtype=np.float32)

    print('Convert training images to arrays')
    print('------------------------------------------')
    for i, image_name in enumerate(images):
        img = img_as_float(imread(os.path.join(train_images_path, image_name)))
        rimgs[i] = resize(img, (rows, cols), preserve_range=True)

        if i % 10 == 0:
            print('Done: {0}/{1} images'.format(i, total))
    print('Done.')
    print('------------------------------------------')

    print('Convert training masks to arrays')
    print('------------------------------------------')
    for i, mask_name in enumerate(masks):
        # No img_as_float here: masks deliberately keep their stored value range.
        msk = imread(os.path.join(train_masks_path, mask_name))
        rmsks[i] = resize(msk, (rows, cols), preserve_range=True)

        if i % 10 == 0:
            print('Done: {0}/{1} masks'.format(i, total))
    print('Done.')
    print('------------------------------------------')

    return rimgs, rmsks

def conv_block(x, num_filters):
    """Apply two stacked (3x3 Conv2D -> BatchNormalization -> ReLU) stages.

    "same" padding keeps the spatial size; the result has ``num_filters`` channels.
    """
    out = x
    for _ in range(2):
        out = L.Conv2D(num_filters, 3, padding="same")(out)
        out = L.BatchNormalization()(out)
        out = L.Activation("relu")(out)
    return out

def encoder_block(x, num_filters):
    """One encoder stage: conv_block followed by 2x2 max-pooling.

    Returns:
        (features, pooled): the pre-pooling features (kept as a skip
        connection) and the pooled tensor passed to the next stage.
    """
    features = conv_block(x, num_filters)
    pooled = L.MaxPool2D((2, 2))(features)
    return features, pooled

def attention_gate(g, s, num_filters):
    """Re-weight the skip tensor ``s`` using the gating tensor ``g``.

    Both inputs are projected to ``num_filters`` channels by 1x1 convolutions
    (each followed by batch norm). The projections are summed, passed through
    ReLU, another 1x1 convolution and a sigmoid. The resulting weights
    multiply ``s`` element-wise.

    ``g`` and ``s`` must share spatial size for the sum. The final product
    needs ``s`` to have ``num_filters`` channels or be broadcast-compatible
    (TODO confirm at call sites).
    """
    gating = L.Conv2D(num_filters, 1, padding="same")(g)
    gating = L.BatchNormalization()(gating)

    skip = L.Conv2D(num_filters, 1, padding="same")(s)
    skip = L.BatchNormalization()(skip)

    psi = L.Activation("relu")(gating + skip)
    psi = L.Conv2D(num_filters, 1, padding="same")(psi)
    weights = L.Activation("sigmoid")(psi)

    return weights * s

def decoder_block(x, s, num_filters):
    """One decoder stage with an attention-gated skip connection.

    ``x`` is bilinearly upsampled (UpSampling2D's default 2x2 factor). The
    upsampled tensor gates the skip tensor ``s``, the two are concatenated
    along channels, and the result is refined by conv_block.
    """
    upsampled = L.UpSampling2D(interpolation="bilinear")(x)
    gated_skip = attention_gate(upsampled, s, num_filters)
    merged = L.Concatenate()([upsampled, gated_skip])
    return conv_block(merged, num_filters)

def attention_unet():
    """Build and compile the Attention U-Net used for prediction.

    The network has:
    * three encoder stages (64, 128, 256 filters);
    * a 512-filter bottleneck;
    * three attention-gated decoder stages (256, 128, 64 filters);
    * a 1x1 sigmoid head producing a single-channel map.

    Input shape is ``(img_rows, img_cols, 1)``.

    Returns:
        A tf.keras Model compiled with Adam (learning rate 1e-4),
        ``dice_coef_loss`` as loss and ``dice_coef`` as metric.
    """
    inputs = L.Input((img_rows, img_cols, 1), name="input_first")

    # Encoder: keep each stage's pre-pooling features as a skip connection.
    x = inputs
    skips = []
    for filters in (64, 128, 256):
        skip, x = encoder_block(x, filters)
        skips.append(skip)

    # Bottleneck.
    x = conv_block(x, 512)

    # Decoder: consume the skips deepest-first, mirroring the encoder widths.
    for filters, skip in zip((256, 128, 64), reversed(skips)):
        x = decoder_block(x, skip, filters)

    # Output head: per-pixel probability.
    outputs = L.Conv2D(1, 1, padding="same", activation="sigmoid")(x)

    model = Model(inputs, outputs, name="Attention-UNET")
    # Take the optimizer from the same tf.keras package that builds the model.
    # The standalone ``keras`` optimizer can be rejected when ``keras`` and
    # ``tensorflow.keras`` resolve to different Keras versions.
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
                  loss=dice_coef_loss, metrics=[dice_coef])

    return model

def _to_uint8(arr):
    """Map a float array in [0, 1] to uint8 in [0, 255]; other dtypes pass through.

    Values are clipped first so anything outside [0, 1] saturates instead of
    wrapping around in the uint8 cast.
    """
    if np.issubdtype(arr.dtype, np.floating):
        return np.clip(arr * 255, 0, 255).astype(np.uint8)
    return arr


def _save_comparison(path, title, image, mask, prediction):
    """Save a side-by-side figure of input, ground-truth mask and prediction to *path*."""
    fig, ax = plt.subplots(1, 3, figsize=(12, 6))
    try:
        ax[0].imshow(image, cmap='gray')
        ax[0].set_title(title)
        ax[1].imshow(mask, cmap='gray')
        ax[1].set_title('Mask')
        ax[2].imshow(prediction, cmap='gray')
        ax[2].set_title('Output')
        fig.savefig(path)
    finally:
        # One figure is created per image; pyplot keeps every figure alive
        # until it is closed explicitly.
        plt.close(fig)


def predict():
    """Run this fold's saved Attention U-Net on every test image and save the outputs.

    Steps:
    1. Read ``data_path/Images`` and ``data_path/Masks`` via load_train_data.
    2. Standardise the images with this set's own mean and std.
    3. Load ``RAZ<dataset_no>-2025-JMI<patch_size>.h5`` (relative to the
       working directory) and predict all images in one call.
    4. For each image, write:
       * ``<name>_Input.png`` to ``dir_path2``: input / mask / output figure;
       * ``<name>_Mask.bmp`` to ``dir_path1``: ground-truth mask as uint8;
       * ``<name>_result.bmp`` to ``dir_path``: predicted map as uint8.

    ``<name>`` is the image file name with ``.bmp`` removed.
    """
    print('Load data into array')
    print('----------------------')

    images_path = os.path.join(data_path, 'Images')
    # Same listing and order that load_train_data uses; needed for output names.
    images = sorted_alphanumeric(os.listdir(images_path))
    if not images:
        print('No test images found in {0}'.format(images_path))
        return

    imgs_train, imgs_mask_train = load_train_data()

    # Standardise with this set's own statistics.
    # NOTE(review): confirm training used the same per-set normalisation.
    imgs_train = imgs_train.astype('float32')
    mean = np.mean(imgs_train)  # mean for data centering
    std = np.std(imgs_train)  # std for data normalization
    imgs_train -= mean
    imgs_train /= std

    imgs_mask_train = imgs_mask_train.astype('float32')
    imgs_mask_train /= 255.  # scale masks to [0, 1] (assumes 8-bit masks)

    # Add the trailing channel axis expected by the network's input layer.
    imgs_train = imgs_train.reshape(imgs_train.shape[0], img_rows, img_cols, 1)
    imgs_mask_train = imgs_mask_train.reshape(imgs_mask_train.shape[0], img_rows, img_cols, 1)

    print('Done loading')
    print('------------------------')
    print('Load weights')
    print('------------------------')

    model = attention_unet()
    model_name = 'RAZ' + str(dataset_no) + '-2025-JMI' + str(patch_size) + '.h5'
    model.load_weights(model_name)

    print('Done loading')
    print('------------------------')

    print('Predict all')
    print('------------------------')

    y_hat = model.predict(imgs_train)

    for idx, image_name in enumerate(images):
        stem = image_name.replace('.bmp', '')

        input_name = stem + '_Input.png'
        _save_comparison(os.path.join(dir_path2, input_name), input_name,
                         imgs_train[idx, :, :, 0],
                         imgs_mask_train[idx, :, :, 0],
                         y_hat[idx, :, :, 0])

        mask = cv2.resize(imgs_mask_train[idx].reshape(patch_size, patch_size), dsize)
        imsave(os.path.join(dir_path1, stem + '_Mask.bmp'), _to_uint8(mask))

        output = cv2.resize(y_hat[idx].reshape(patch_size, patch_size), dsize)
        imsave(os.path.join(dir_path, stem + '_result.bmp'), _to_uint8(output))

if __name__ == "__main__":
    # Run inference only when executed directly (script or notebook cell),
    # not when the module is imported for its model/metric definitions.
    predict()