In [None]:
#Installation of segmentation-models library for the loss function
!pip install segmentation-models

In [9]:
#Standard imports
import os
import pathlib
import shutil
import matplotlib.pyplot as plt
import cv2
import numpy as np
import re
import random
import segmentation_models as sm

#Tensorflow imports
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.utils import image_dataset_from_directory, img_to_array, load_img, array_to_img, save_img, to_categorical

Segmentation Models: using `keras` framework.


## Experiment 1

This experiment will train the 4 different U-Nets on the retinal lesion data. The experiment will proceed in the following steps:

<ol>
<li>Downloading the images and annotations : since the networks will be trained on cloud platforms, the data will need to be downloaded from {my} computer</li>
<li>Seperate the data into train-val-test set with proportions 0.7:0.15:0.15</li>
<li>Create tf datasets for train, val, test data</li>
<li>Train a model
<ul>
<li>For 25-35 epochs (depending on the time taken for one epoch)</li>
<li>With the dice loss</li>
<li>With model checkpoint callback</li>
<li>With training log callback</li>
</ul>
</li>
<li>Save two versions of the model : 
<ul>
<li>The best version</li>
<li>Version after full epoch training</li>
</ul>
</li>
<li>Once all models are trained, produce evaluation output for all models.</li>
</ol>

### File Seperator

In [45]:
class SegmentationFileSeperator:

    """
    This object copies image files and their corresponding segmentation maps into train, test and validation directories
    target_image_path : The path to the directory containing the train, test, and validation directories for images
    original_image_dir : The path to the directory containing the original images
    target_segmap_path : The path to the directory containing the train, test, and validation directories for for the segmentation
    original_segmap_dir : The path to the directory containing the original segmentation maps

    """
    
    def __init__(self, target_image_path, original_image_dir, target_segmap_path, original_segmap_dir):
        self.target_image_path = target_image_path
        self.original_image_dir = original_image_dir 
        self.target_segmap_path = target_segmap_path
        self.original_segmap_dir = original_segmap_dir
        
    def dataset_segregate(self):
        
        #This function is used to create empty train, test, and validation directories within the main directory
        for typ in ["Train", "Test", "Valid"]:
            path_1 = os.path.join(self.target_image_path, typ)
            os.makedirs(path_1)

            path_2 = os.path.join(self.target_segmap_path, typ)
            os.makedirs(path_2)
    
    def class_maker(self):
        
        #This function makes a directories within the train, test, and validation directories for each class in the datset.
        for typ in os.listdir(self.target_image_path):
            for section in os.listdir(self.original_image_dir):
                path_1 = os.path.join(self.target_image_path, typ, section)
                path_2 = os.path.join(self.target_segmap_path, typ, section)
                os.makedirs(path_1)
                os.makedirs(path_2)
    
    def shuffle_together(self, x,  y):
        z = list(zip(x, y))
        random.Random(9).shuffle(z)
        a, b = zip(*z)
        return a, b

                
    def file_mover(self, train_pr, valid_pr):
        
        #This function is the most important function. This moves all the files from the origianal directory into the target directories with the proportions for train, test and validation data.
        #Section represents class
        #dir represents train,test or val
        for dir in os.listdir(self.target_image_path):
            print(f"Moving to dir: {dir}")

            d_path = self.original_image_dir
            m_path = self.original_segmap_dir
            if dir == "Train":
                start_point = 0
                cutoff = len(os.listdir(d_path)) - 1
                end_point = int(train_pr * cutoff)
            elif dir == "Valid":
                start_off = len(os.listdir(d_path)) -1 
                start_point = int(train_pr * start_off)

                cutoff = len(os.listdir(d_path)) - 1
                end_point = start_point + int(valid_pr * cutoff)
            else:
                test_pr = train_pr + valid_pr
                start_off = len(os.listdir(d_path)) -1
                start_point = int(test_pr * start_off)

                   
                end_point = len(os.listdir(d_path)) -1
        
            moveables_1 = sorted(os.listdir(d_path))
            moveables_2 = sorted(os.listdir(m_path))
            mv_1, mv_2 = self.shuffle_together(moveables_1, moveables_2)
            for i in range(start_point, end_point):
                try:
                    i_src_path = os.path.join(self.original_image_dir, mv_1[i])
                    i_des_path = os.path.join(self.target_image_path, dir, mv_1[i])

                    shutil.copy(i_src_path, i_des_path)

                    #Moving segmaps
                    m_src_path = os.path.join(self.original_segmap_dir, mv_2[i])
                    m_des_path = os.path.join(self.target_segmap_path, dir, mv_2[i])

                    shutil.copy(m_src_path, m_des_path)
                except:
                    continue
                    
    def print_statistics(self):
        
        for dir in os.listdir(self.target_image_path):
    
            num = len(os.listdir(f"{self.target_image_path}/{dir}"))
            print(f"Files in {dir} Directory -> : {num}")
            print(" ")

        
        for dir in os.listdir(self.target_segmap_path):
    

            num = len(os.listdir(f"{self.target_segmap_path}/{dir}"))
            print(f"Files in {dir} Directory -> : {num}")
            print(" ")
            
    
    def run(self, train_pr: float, valid_pr: float):
        
        self.dataset_segregate()
        #self.class_maker()

        if train_pr < 0 or train_pr > 1:
            print("Train proportion value not valid. Please enter a value greater than 0 and less than 1.")
        elif valid_pr < 0 or valid_pr > 1:
            print("The validation proportion is not valid. Please enter a value greater than 0 and less than 1.")
        elif valid_pr > train_pr:
            print("Validation proportion is greater than training data proportion, please enter a value less than the training proportion.")
        elif valid_pr + train_pr >= 1:
            print("The sum of the validation and training proportion is greater than or equal to one, this is not valid. Please enter values such that their sum is strictly less than 1.")

        
        self.file_mover(train_pr=train_pr, valid_pr=valid_pr)
    def test_proportions(self, train_pr: float, valid_pr: float):

        if train_pr < 0 or train_pr > 1:
            print("Train proportion value not valid. Please enter a value greater than 0 and less than 1.")
        elif valid_pr < 0 or valid_pr > 1:
            print("The validation proportion is not valid. Please enter a value greater than 0 and less than 1.")
        elif valid_pr > train_pr:
            print("Validation proportion is greater than training data proportion, please enter a value less than the training proportion.")
        elif valid_pr + train_pr >= 1:
            print("The sum of the validation and training proportion is greater than or equal to one, this is not valid. Please enter values such that their sum is strictly less than 1.")
        else:
            print("The proportions enterred are valid.")
        
        
        
        



### TF Dataset creator

In [29]:
class SegmentationDataSet:

  """
  This object returns a zipped tf Dataset containing images and their corresponding segmentation maps
  img_dir : The directory containing the images
  maps_dir : The directory containing the segmentation maps
  input_size : A tuple which specifies the height and width of the images and segmentation maps (h,w)
  n_classes : The number of classes (including the background)
  """

  def __init__(self, img_dir:str, maps_dir:str, input_size:tuple, n_classes:int):

    self.img_root = pathlib.Path(img_dir)
    self.map_root = pathlib.Path(maps_dir)
    self.h , self.w = input_size
    self.n = n_classes
    self.img_root = str(self.img_root/'*/*')
    self.map_root = str(self.map_root/'*/*')

  def process_image(self,file_path):
    file_name = tf.io.read_file(file_path)
    img = tf.io.decode_jpeg(file_name, channels=3)
    img = tf.image.resize(img, [self.h, self.w])
    
    return img

  def process_annotation(self,file_path):

    file_name = tf.io.read_file(file_path)
    img = tf.io.decode_png(file_name, dtype="uint8")
    #img = img - 1
    img = tf.image.resize(img, [self.h, self.w])
    img = tf.cast(img, dtype=tf.int32)
    img = tf.one_hot(img, self.n, axis=2)
    return img

  def create_segmentation_dataset(self):

    seed = 5

    image_ds = tf.data.Dataset.list_files(self.img_root, seed=seed)
    ann_ds = tf.data.Dataset.list_files(self.map_root, seed=seed)

    image_ds = image_ds.map(self.process_image)
    ann_ds = ann_ds.map(self.process_annotation)
    ds = tf.data.Dataset.zip((image_ds, ann_ds))
    ds = ds.shuffle(1000, seed=seed).batch(32)

    return ds
  











### Callback which plots the training loss curve after training

In [None]:
class LossPlot(tf.keras.callbacks.Callback):

    def __init__(self, path):
        
        super(LossPlot, self).__init__()
        self.path = path
        _, self.ax = plt.subplots(1,1, figsize=(10,10))

    
    def on_train_begin(self, logs=None):

        self.epochs = 0
    
    def on_epoch_end(self, epoch,logs=None):

        self.epochs = self.epochs + 1

    def build_plots(self, epochs, ax, y, y2):
        ax.plot(range(epochs), y, "b--", label="Training Loss")
        ax.plot(range(epochs), y2, color="darkturquoise", label="Validation Loss")
        ax.set_ylabel("loss")
        ax.set_xlabel("epochs")
        ax.legend()
        plt.savefig(self.path, dpi=300)
    
    def on_train_end(self, logs):

        y = logs["loss"]
        y2 = logs["val_loss"]

        self.build_plots(self.epochs, self.ax,y, y2)  
        

### Evaluation Pipeline

In [None]:
import json
class EvalPipeline:
    
    def __init__(self, gt, img, n,model_dict, class_dict):
        
        self.gt = gt
        self.img = img
        self.model_dict = model_dict
        self.n = n
        self.class_dict = class_dict

        self.prediction_gen()
        
    def prediction_gen(self):
        
        predictions = {}
        
        for m in self.model_dict.keys():
            model = self.model_dict[m]
   
            pred = model.predict(self.img)
            pred = np.array(pred)
            pred = np.argmax(pred, axis=3)
            pred = pred.reshape(pred.shape[0], pred.shape[1], pred.shape[2], 1)
            
            predictions[m] = pred
            
        self.pred = predictions
            
    
    def stage_one(self, metrics=["SENS", "SPEC", "IoU", "DSC"], path="stage_1.csv"):
        
        scores = {}
        
        #First evaluate the same metric for all sets of predictions
        print(metrics)
        for metric in metrics:
            
            scores[metric] = []
            for p in self.pred.keys():
                
                
                current_pred = self.pred[p]
                #Evaluate for the current prediction
                score = evaluate(self.gt, current_pred, metric=metric, multi_class=True, n_classes=self.n)
                score = np.mean(score)
                scores[metric].append(score)
                
        with open(path, "w") as f:
            for m in metrics:
                f.write(m)
                f.write(",")
            f.write("\n") 
            for i in range(len(scores[m])):
                for m in metrics:
                
                    f.write(str(scores[m][i]))
                    f.write(",")
                f.write("\n")
            
            
        return scores
    
    def stage_two(self, metrics=["SENS", "SPEC", "IoU", "DSC"], path="stage_2.json"):
        
        scores = {}
        
        # Have to change all of done_pred instances into self.pred
        for p in self.pred.keys():
            scores[p] = []
            print(f"Working on : {p}...")
            for metric in metrics:
                
                current_pred = self.pred[p]
                #Evaluate for the current prediction
                score = evaluate(self.gt, current_pred, metric=metric, multi_class=True, n_classes=self.n)
                scores[p].append(score)
            
            scores[p] = np.array(scores[p]).T
        
        scores_2 = {}
        
        print("Creating final dict")
        for p in self.pred.keys():
            scores_2[p] = {}
            for i, c in enumerate(self.class_dict):
                scores_2[p][c] = {}
                for n, m in enumerate(metrics):
                    scores_2[p][c][m] = scores[p][i][n]
        
        

        with open(path, "w") as json_file:
            json.dump(scores_2, json_file)
        
        
        return scores_2

    
    def stage_four(self, img_dir, gt_dir, img_files, gt_files, path="stage_4.png"):
        
        #Assumes the images aren't 4 dimensional tensors
        
        #Number of images should be the same as G.T
        print(img_dir)
        print(gt_dir)
        print(img_files)
        print(gt_files)
        
        n_cols = len(img_files)
        n_rows = len(list(self.model_dict.keys())) +2
        
        #Defining the figure
        fig, ax = plt.subplots(n_rows, n_cols, figsize=(n_cols*3,n_rows*3))
        
        ax[0, 0].set_ylabel("images")
        ax[1, 0].set_ylabel("ground_truth")
        
        #Plotting the images first
        for i, img in enumerate(img_files):
            
            img_path = os.path.join(img_dir, img)
            image = load_img(img_path)
            ax[0, i].imshow(image)
            
            ax[0,i].set_xticks([])
            ax[0, i].set_yticks([])
        
        #Plotting the ground truth
        for i, ann in enumerate(gt_files):
            
            img_path = os.path.join(gt_dir, ann)
            image = tf.io.read_file(img_path)
            image = tf.io.decode_png(image)

            ax[1, i].imshow(image)
            
            ax[1,i].set_xticks([])
            ax[1, i].set_yticks([])
        
        for i, model_name in enumerate(self.model_dict.keys()):
            
            i = i+2
            model = self.model_dict[model_name]
            ax[i, 0].set_ylabel(model_name)
            
            for n, img in enumerate(img_files):
                
                img_path = os.path.join(img_dir, img)
                image = tf.io.read_file(img_path)
                image = tf.io.decode_jpeg(image, 3)
                image = tf.image.resize(image, [224, 224])
                image = np.expand_dims(image, axis=0)
                pred = model.predict(image)
                pred = np.squeeze(pred)
                
                ax[i, n].imshow(pred)
                ax[i,n].set_xticks([])
                ax[i,n].set_yticks([])
            
        plt.subplots_adjust(left=0.2,
                    right=0.9,
                    wspace=0.4,
                    hspace=0.4)
            
        plt.savefig(path, dpi=100)
        
    
    ### Stage 5 : F1 score plots per model per class
    
    def stage_five(self, path="stage_5.png"):
        

        
        markers = ["x", "+", ".", "1", "*", "d"]
        colors = ["lime", "fuchsia", "darkorange", "gold", "salmon", "indigo"]
        
        model_names = list(self.pred.keys())
        
        fig, ax = plt.subplots(1,1, figsize=(6,self.n))
        ys = range(self.n)
        
        for i, model_name in enumerate(model_names):
            
            pred = self.pred[model_name]
            
            prec = evaluate(self.gt, pred, metric="PREC", multi_class=True, n_classes=self.n)
            recall = evaluate(self.gt, pred, metric="Recall", multi_class=True, n_classes=self.n)
            
            f1 = (2 * prec * recall) / (prec + recall)
            
            some_num = np.random.uniform(0.1, 0.3, 1)
            
            ax.scatter(x=f1, y=ys, color=colors[i], marker=markers[i], label=model_name)
            ax.set_yticks(ticks=[0,1,2],labels=self.class_dict)
            ax.legend(loc="best")
            ax.grid()
        
        plt.savefig(path, dpi=100)
        
        
        
    

## Models : Encoders and decoders

The following section of the code contains the different types of U-Nets that will be trained. The networks will be trained across different cloud platforms, where each platform will only train one network

In [10]:
#Weighted dice_loss -> background class is given a smaller weight
dice_loss = sm.losses.DiceLoss(class_weights=[0.25, 1.5, 1, 1, 1])

### Decoder 

All the U-nets use the same decoder

In [13]:


class SpatialAttention(tf.keras.layers.Layer):

    def __init__(self):
        super(SpatialAttention, self).__init__()
        self.alpha = tf.Variable(initial_value=0.0, trainable=True)
    
    def build(self, input_shape):
        
        self.C = input_shape[-1]
        self.H = input_shape[1]
        self.W = input_shape[2]

        #Defining the convolutions
        self.conv1 = tf.keras.layers.Conv2D(self.C, 1)
        self.conv2 = tf.keras.layers.Conv2D(self.C, 1)
        self.conv3 = tf.keras.layers.Conv2D(self.C, 1)

    def call(self, inputs):

        n_shape = self.H * self.W

        a = inputs
        b = self.conv1(inputs)
        c = self.conv2(inputs)
        d = self.conv3(inputs)

        b = tf.transpose(tf.keras.layers.Reshape((n_shape, self.C))(b), perm=[0,2,1])
        c = tf.keras.layers.Reshape((n_shape, self.C))(c)
        d = tf.keras.layers.Reshape((n_shape, self.C))(d)

        c = tf.linalg.matmul(c, b)
        S = tf.keras.layers.Softmax()(c)
        S = tf.transpose(S, perm=[0,2,1])

        d = self.alpha * tf.linalg.matmul(S, d)
        d = tf.keras.layers.Reshape((self.H, self.W, self.C))(d)
        E = tf.keras.layers.Add()([a, d])        

        return E

class ChannelAttention(tf.keras.layers.Layer):

    def __init__(self):
        super(ChannelAttention, self).__init__()
        self.beta = tf.Variable(initial_value=0.0, name="beta", trainable=True)
    
    def build(self, input_shape):
        self.C = input_shape[-1]
        self.H = input_shape[1]
        self.W = input_shape[2]
    
    def call(self, inputs):

        a1=a2=a3=a4= inputs
        n_shape = self.H * self.W
        a2 = tf.keras.layers.Reshape((n_shape, self.C))(a2)
        a3 = tf.keras.layers.Reshape((n_shape, self.C))(a3)
        a4 = tf.transpose(tf.keras.layers.Reshape((n_shape, self.C))(a4), perm=[0,2,1])


        #Creating X, the softmax on the matrix product of A_T_A
        a_T_a = tf.linalg.matmul(a4, a3)
        x = tf.keras.layers.Softmax()(a_T_a)
        x = tf.transpose(x, perm=[0,2,1])

        a2_pass = self.beta * tf.linalg.matmul(a2, x)
        a2_pass = tf.keras.layers.Reshape((self.H,self.W,self.C))(a2_pass)

        E = tf.keras.layers.Add()([a1, a2_pass])

        return E
            

class DualAttention(tf.keras.layers.Layer):

    def __init__(self):
        super(DualAttention, self).__init__()
    
    def build(self, input_shape):
        self.C = input_shape[-1]
        self.conv1 = tf.keras.layers.Conv2D(self.C, 1)
        self.conv2 = tf.keras.layers.Conv2D(self.C, 1)
        self.sam = SpatialAttention()
        self.cam = ChannelAttention()
    
    def call(self, inputs):

        e1 = self.sam(inputs)
        e2 = self.cam(inputs)

        e1 = self.conv1(e1)
        e2 = self.conv2(e2)

        F = tf.keras.layers.Add()([e1, e2])
        return F



def decoder_block(a, x, f, attention=False):

    if attention:
        a = DualAttention()(a)

    x = tf.keras.layers.Conv2DTranspose(filters=f, kernel_size=2, strides=2, padding="same", activation="relu")(x)
    if a is not  None:
        x = tf.concat([a, x], axis=-1)
    x = tf.keras.layers.Conv2D(f, 3, padding="same", activation="relu")(x)
    x = tf.keras.layers.Conv2D(f, 3, padding="same", activation="relu")(x) 

    return x

def decoder_full(activations, x, filters, num_classes, attention_indices):

    #Looping over the activations and filters from bottom to top
    #Activation are reversed for this effect
 
    
    ai = None
    #Flag to indicate whether the point of attention is found
    found = True
    #Flag to pass to the decoder block, whether dual attention should be applied
    att = False
    for i,(a,f) in enumerate(zip(activations[::-1],filters)):

        #Flag to indicate whether there is no need for attention
        there = len(attention_indices)
        if found and there:
            ai = attention_indices.pop()
        
        #Check if the current activation needs attention

        att=found = (i+1 == ai)
        print(att)
        x = decoder_block(a, x, f, att)
    
    output = tf.keras.layers.Conv2D(num_classes, 1, padding="same", activation="softmax")(x)

    return output



### Base U-Net

This the Base U-Net model

In [14]:
def encoder_block(inp,f):

    x = inp

    x = tf.keras.layers.Conv2D(f, 3, 1, activation="relu",padding="same")(x)
    x = tf.keras.layers.Conv2D(f, 3, 1, activation="relu",padding="same")(x)
    a = x
    x = tf.keras.layers.MaxPool2D(2, 2)(x)

    return x,a

def last_encoder(inp, f):

    x = inp

    x = tf.keras.layers.Conv2D(f, 3, 1, activation="relu", padding="same")(x)
    x = tf.keras.layers.Conv2D(f, 3, 1, activation="relu", padding="same")(x)

    return x

def encoder_unet(inp, filters):

    activations = []

    x = inp
    for f in filters[:-1]:
        x, a = encoder_block(x, f)
        activations.append(a)
    
    x = last_encoder(x, filters[-1])
    return x, activations

def unet(num_classes, input_size, input_dim, att_indices=[], last_attention=False):

    inp = tf.keras.layers.Input((input_size, input_size, input_dim))

    filters = [64, 128, 256, 512, 1024]

    x, a = encoder_unet(inp, filters=filters)

    if last_attention:
        x = DualAttention()(x)

    output = decoder_full(a, x, filters[:-1][::-1], num_classes, att_indices)

    model = tf.keras.Model(inp, output)

    return model

In [15]:
unet_model = unet(5, 512, 3)
unet_model.compile(optimizer="adam", loss=dice_loss, metrics="accuracy")

False
False
False
False


### VGG U-Net

In [19]:
from tensorflow.keras.applications.vgg19 import VGG19, preprocess_input
def vgg_encoder_block(x, layers):
    """
    This function passes an input through a set of conv layers from VGG19, returning the downsampled and convolved activation
    """
    for layer in layers:
        x = layer(x)
    
    addition = x
    x = tf.keras.layers.MaxPooling2D((2,2), strides = 2)(x)
    return (x, addition)

def last_vgg_block(x, layers):

    for layer in layers:
        x = layer(x)
    
    return x

def vgg_encoder_full(input, layer_dict):

    """
    This function creates the full encoder given a dictionary of layers from the VGG network, it returns the final activation 
    and a list of intermediate activations
    """

    activations = []
    x = input
    for layer_name in list(layer_dict.keys())[:-1]:
        x, a = vgg_encoder_block(x, layer_dict[layer_name])
        activations.append(a)
    
    x = last_vgg_block(x, layer_dict[list(layer_dict.keys())[-1]])
    
    return x, activations



def vgg_unet(num_classes, input_size, input_dim, att_indices=[], last_attention=False):

    #Downloading the VGG network
    vgg19 = VGG19(weights="imagenet", include_top=False, input_shape=(input_size, input_size,input_dim))
    vgg19.trainable = False
    #Getting all the blocks from the VGG network
    vgg_blocks = {
        f"block{n}" : [layer for layer in vgg19.layers if f"block{n}_conv" in layer.name] for n in range(1, 6)
    }
    
    #Filters for the Decoder
    filters = [512, 256, 128, 64]

    l = len(att_indices)
    assert l >= 0, "Attention indices should be 0 or greater"
    assert l <= len(filters) + 1, "Number of layers for attetention can not exceed 5"

    #assert len(att_indices[att_indices > 5]) == 0, "Attention indices must be from 1 to 5"
    vgg_input = vgg19.input
    #Defining the encoder
    #First Preprocess the input
    x = preprocess_input(x=vgg_input)
    
    x, a = vgg_encoder_full(x, vgg_blocks)

    if last_attention:
        x = DualAttention()(x)

    output = decoder_full(a, x, filters, num_classes, att_indices)

    vgg_unet_model = tf.keras.Model(vgg_input, output)

    return vgg_unet_model

In [20]:
vgg_unet = vgg_unet(5, 512, 3)
vgg_unet.compile(optimizer ="adam", loss=dice_loss, metrics ="accuracy")

False
False
False
False


### Resnet U-Net

In [21]:
from tensorflow.keras.applications.resnet import ResNet50, preprocess_input

def resblock_stem(input, layers):

    x = input
    for layer in layers:

        x = layer(x)
    
    a = x
    x = tf.keras.layers.MaxPool2D(2, 2)(x)

    return x, a


def resblock_enc(inp, layers, res):

    enc = tf.keras.Model(inputs=res.get_layer(layers[0].name).input, outputs=res.get_layer(layers[-1].name).output)
    x = enc(inp)
    return x

def resnet_encoder(inp, layer_dict, res):

    keys = list(layer_dict.keys())

    x, a = resblock_stem(inp, layer_dict[keys[0]])

    activations = [a]
    for block in keys[1:]:
        x = resblock_enc(x, layer_dict[block], res) 
        activations.append(x)
    
    return x, [None] + activations[:-1]



def resnet_unet(num_classes, input_size, input_dim, att_indices=[], last_attention=False):

    #Downloading the ResNet
    resnet = ResNet50(weights="imagenet", include_top=False, input_shape=(input_size,input_size,input_dim))

    layer_dict = resnet_blocks = {
    f"block_{n}" : [layer for layer in resnet.layers if layer.name[:5] == f"conv{n}"] for n in range(1, 5)
    }

    #Freezing the layers of the ResNet
    resnet.trainable = False

    #Building the model
    inp = resnet.input
    x = preprocess_input(x=inp)
    x, a = resnet_encoder(inp, layer_dict, resnet)


    #Decoder
    filters = [512, 256, 128, 64]

    l = len(att_indices)
    assert l >= 0, "Attention indices should be 0 or greater"
    assert l <= len(filters) + 1, "Number of layers for attetention can not exceed 5"

    #assert len(att_indices[att_indices > 5]) == 0, "Attention indices must be from 1 to 5"

    if last_attention:
        x = DualAttention()(x)

    output = decoder_full(a, x, filters, num_classes, att_indices)

    model = tf.keras.Model(inp, output)

    return model

In [22]:
resnet_unet = resnet_unet(5, 512, 3)
resnet_unet.compile(optimizer="adam", loss=dice_loss, metrics="accuracy")

False
False
False
False


### Efficientnet U-Net

In [25]:
from tensorflow.keras.applications.efficientnet import EfficientNetB4, preprocess_input

def effnet_stem(input, layers):

    x = input

    for layer in layers:
        x = layer(x)
    
    return x
    
def effblock_enc(inp, layers, model):

    enc = tf.keras.Model(inputs=model.get_layer(layers[0].name).input, outputs=model.get_layer(layers[-1].name).output)
    x = enc(inp)
    return x

def effnet_encoder(inp, layer_dict, model):

    keys = list(layer_dict.keys())

    x = effnet_stem(inp, layer_dict[keys[0]])

    activations = [x]
    for block in keys[1:]:

        x = effblock_enc(x, layer_dict[block], model) 
        if block == "block_1" or block ==  "block_4" or block == "block_6":
            continue
        activations.append(x)
    
    return x, [None] + activations[:-1]

def effnet_unet(num_classes, input_size, input_dim, att_indices=[], last_attention=False):

    b4 = EfficientNetB4(weights="imagenet", include_top=False, input_shape=(input_size,input_size,input_dim))

    effnet_blocks_dict = {}
    stem_start = [b4.layers[0], b4.layers[1], b4.layers[2]]
    stem_start.extend([layer for layer in b4.layers if layer.name[:4] == "stem"])
    effnet_blocks_dict["stem"] = stem_start
    effnet_blocks_dict = {**effnet_blocks_dict, **{ 
        f"block_{n}" : [layer for layer in b4.layers if layer.name[:6] == f"block{n}"] for n in range(1, 8)
    }}

    #Freeszing the weights of the model
    b4.trainable = False

    #Encoder of the network
    inp = tf.keras.layers.Input((input_size, input_size, input_dim))
    x = preprocess_input(inp)
    #a not being reversed here because it's reversed in the decoder function
    x, a = effnet_encoder(x, effnet_blocks_dict, b4)

    #Decoder of the network
    filters = [160, 56, 32, 48, 64]

    l = len(att_indices)
    assert l >= 0, "Attention indices should be 0 or greater"
    assert l <= len(filters) + 1, "Number of layers for attetention can not exceed 5"

    #assert len(att_indices[att_indices > 5]) == 0, "Attention indices must be from 1 to 5"


    if last_attention:
        x = DualAttention()(x)
    
    output = decoder_full(a, x, filters, num_classes, att_indices)

    model  = tf.keras.Model(inp, output)

    return model

In [26]:
eff_unet = effnet_unet(5, 512, 3)
eff_unet.compile(optimizer="adam", loss=dice_loss, metrics="accuracy")

False
False
False
False
False


## Data paths, and dataset definitions

In [47]:
#The directory where the images and the segmentation maps are orignally stored - the zip files
og_img_dir = ""
og_map_dir = ""

#The directory where the images and segmentation maps are stored they need to be completed based on where the model is trained, this directory is where the train, test, and validation directories are stored
image_dir = ""
segmaps_dir = ""

#The directory for the training images and segmentation maps
train_img_dir = ""
train_map_dir = ""

#The directory for the test images and segmentation maps
test_img_dir = ""
test_map_dir = ""

#The directory for the validation images and segmentation maps
val_img_dir = ""
val_map_dir = ""


In [48]:
#File seperation into train, test, and validation sets
fs = SegmentationFileSeperator(image_dir, og_img_dir, segmaps_dir, og_map_dir)
fs.run(0.7, 15)

: 

In [31]:
train_ds = SegmentationDataSet(train_img_dir, train_map_dir, (512, 512), 5)
test_ds = SegmentationDataSet(test_img_dir, test_map_dir, (512, 512), 5)
valid_ds = SegmentationDataSet(val_img_dir, val_map_dir, (512, 512), 5)