#### Import libraries

In [1]:
import re
import glob
import os

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

from skimage.transform import resize

from tensorflow.keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Dropout, concatenate, UpSampling2D

In [2]:
def rle_to_binary(rle, shape): 
    # create array of zeros representing a blank mask
    mask = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    
    # split rle string into list
    s = rle.split()
    
    # convert the starts and lengths to numpy arrays of integers
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0::2], s[1::2])]
    
    # subtract 1 from the starts since RLE encoding is 1-index and python uses 0-indexing
    starts -= 1
    
    # calculate the end positions of each continuous sequence by adding the lengths to their respective starts (ends = starts + lengths)
    for start, end in zip(starts, ends):
        mask[start:end] = 1
      
    # reshape the 1D array to the desired 2D output shape
    return mask.reshape(shape, order='F')

In [3]:
def create_multiclass_mask(rle_list, original_shape, target_shape):
    masks = []
    
    for rle in rle_list:
        # decode he rle mask into a binary mask
        mask = rle_to_binary(rle, original_shape)
        
        # reshape the mask to so that it matches the target
        resized = resize(mask, target_shape, preserve_range=True, mode='reflect', anti_aliasing=True)
        masks.append(resized_mask)
        
    return np.stack(masks, axis=-1)

In [5]:
def custom_datagen(df, dir, batch_size, target_size):
    # create the image data generator, for now only normalizing as a preprocessing step
    datagen = ImageDataGenerator(rescale=1/255)
    
    # create a list of unique image ids
    ids = np.array(list(df.groups.keys())) ####
    
    while True:
        sample = df.sample(batch_size)
            
        images = []
        masks = []
        
        for _, row in sample.iterrows():
            # get the dataframe rows for the current image
            img_rows = df.get_group(row['id'])
            
            # extract the rle masks from the rows
            rle_list = img_rows['segmentation'].tolist()
            
            # search the dataset directory for an image matching the image id
            pattern = os.path.join(dir, f"{row['id']}*.png")
            file = glob.glob(pattern)[0] 
                        
            # load the image and correct its size 
            original_shape = Image.open(file).size[::-1]
            image = load_img(file, target_size=target_size, color_mode='grayscale')
            image_array = img_to_array(image)
            
            # load and decode the multiclass mask and resize it accordingly
            mask = create_multiclass_mask(rle_list, original_shape, target_size)
            
            images.append(image_array)
            masks.append(mask)
            
        x = np.array(images)
        y = np.array(masks)
        
        # yield statement allows us to return some data to the caller and if
        # called again resume execution from within the while loop
        yield datagen.flow(x, y, batch_size=batch_size)

In [9]:
def make_unet(input_size, num_classes=3):
    input = Input(input_size)
    
    # contracting
    conv1 = Conv2D(64, (3, 3), activation='relu', padding='same')(input)
    conv1 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    
    conv2 = Conv2D(128, (3, 3), activation='relu', padding='same')(pool1)
    conv2 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    
    # bottleneck
    convB = Conv2D(256, (3, 3), activation='relu', padding='same')(pool2)
    convB = Conv2D(256, (3, 3), activation='relu', padding='same')(convB)
    
    # expanding
    up3 = concatenate([UpSampling2D(size=(2, 2))(convB), conv2], axis=-1)
    conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(up3)
    conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv3)
    
    up4 = concatenate([UpSampling2D(size=(2, 2))(conv3), conv1], axis=-1)
    conv4 = Conv2D(64, (3, 3), activation='relu', padding='same')(up4)
    conv4 = Conv2D(64, (3, 3), activation='relu', padding='same')(up4)
    
    output = Conv2D(num_classes, (1, 1), activation='softmax')(conv4)
    
    return Model(inputs=input, outputs=output)

In [10]:
# read the csv file into a dataframe and convert it into a groupby object to keep our images grouped by id
df = pd.read_csv('./Dataset/train.csv')
grouped_df = df.groupby('id')

# dataset directory
dir = './Dataset/train'

batch_size = 32
target_size = 256

# create the training data generator
train_generator = custom_datagen(grouped_df, dir, batch_size, (target_size, target_size))

In [11]:
# create the U-net neural network and compile it
model = make_unet(input_size=(target_size, target_size, 1), num_classes=3)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['acciracy'])

In [12]:
model.fit(train_generator, batch_size=batch_size, epochs=10)

IndexError: list index out of range

#### Extract id and image size info from each image filename using regular expressions

In [None]:
slice_names = []

for filename in generator.filenames:

    # substitute characters so that the slice name matches the csv file
    new_fn = re.sub("case([0-9]+)\\\\", "", filename)
    new_fn = re.sub("\\\\", "_", new_fn)
    new_fn = re.sub("_[0-9]{3}_[0-9]{3}_1.50_1.50.png", "", new_fn)

    slice_names.append(new_fn)

#### Display some images from the dataset

In [None]:
fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(20, 20))

images = next(generator)

itr = 0
for i in range(2):
    for j in range(2):
        ax[i, j].imshow(images[itr])
        ax[i, j].title.set_text(slice_names[itr])
        itr += 1
        
plt.show()

# Issues
### If we reshape an image in the datagen then the mask probably is no longer accurate