In [None]:
# ----- IMPORT OF USEFUL LIBRARIES

import os
import sys
import numpy as np
import tensorflow as tf
import random
import math
import warnings
import pandas as pd
import cv2
import matplotlib.pyplot as plt
from py_img_seg_eval.eval_segm import mean_IU

from tqdm import tqdm
from itertools import chain
import skimage
from skimage import io
from skimage.io import imread, imshow, imread_collection, concatenate_images
from skimage.transform import resize
from skimage.morphology import label

# Silence UserWarnings raised from skimage modules and seed the RNGs.
warnings.filterwarnings('ignore', category=UserWarning, module='skimage')
seed = 42
# The seeding functions must be *called*: assigning `random.seed = seed`
# replaces the function with the integer 42, which leaves the RNGs unseeded
# and makes any later random.seed()/np.random.seed() call fail.
random.seed(seed)
np.random.seed(seed)
# NOTE(review): TensorFlow keeps its own RNG; seed it as well
# (tf.random.set_seed in TF2, tf.set_random_seed in TF1) if training
# needs to be reproducible.

In [None]:
# Setting parameters
IMG_WIDTH = 128
IMG_HEIGHT = 128
IMG_CHANNELS = 3

# Root folder that holds stage1_train/ and stage1_test/. Set the
# DSB_DATA_DIR environment variable to point elsewhere instead of
# editing an absolute path here.
DATA_DIR = os.environ.get('DSB_DATA_DIR', '/Users/Gregoire/Desktop/ETHZ/Semester project')
# The trailing separator is kept on purpose: later code builds sample paths
# by string concatenation (TRAIN_PATH + id_).
TRAIN_PATH = os.path.join(DATA_DIR, 'stage1_train', '')
TEST_PATH = os.path.join(DATA_DIR, 'stage1_test', '')

# Warning filters and the random seed are already configured in the first
# cell; repeating them here is unnecessary.


def _list_subdirs(root):
    """Return the names of the immediate sub-directories of ``root``.

    Raises
    ------
    FileNotFoundError
        If ``root`` is not a directory. Without this check,
        ``next(os.walk(root))`` raises an uninformative StopIteration.
    """
    if not os.path.isdir(root):
        raise FileNotFoundError('Data directory not found: {}'.format(root))
    return next(os.walk(root))[1]


# Get train and test IDs (one sub-directory per sample)
train_ids = _list_subdirs(TRAIN_PATH)
test_ids = _list_subdirs(TEST_PATH)

In [None]:
# Get and resize train images and masks.
# The resized arrays are cached in an HDF5 file so that later runs can skip
# decoding and resizing every PNG.

import h5py
from pathlib import Path

my_file = Path("data.h5")  # HDF5 cache of the resized training set


def _load_train_data(ids, root, height, width, channels):
    """Read and resize every training image and merge its masks into one label map.

    Parameters
    ----------
    ids : list of str
        Sample directory names under ``root``. Each sample directory is read
        as ``<id>/images/<id>.png`` plus every file in ``<id>/masks/``.
    root : str
        Directory containing one sub-directory per sample.
    height, width, channels : int
        Target image size; only the first ``channels`` channels are kept.

    Returns
    -------
    images : np.ndarray of uint8, shape (len(ids), height, width, channels)
    labels : np.ndarray of bool, shape (len(ids), height, width, 1)
        Pixel-wise union of the resized masks of each sample.
    """
    images = np.zeros((len(ids), height, width, channels), dtype=np.uint8)
    # `np.bool` was removed in NumPy 1.24; the builtin `bool` is the same dtype.
    labels = np.zeros((len(ids), height, width, 1), dtype=bool)
    for n, id_ in tqdm(enumerate(ids), total=len(ids)):
        sample_dir = os.path.join(root, id_)
        # assumes the PNG decodes to (H, W, C) with C >= channels — TODO confirm
        # there are no single-channel images in the dataset.
        img = imread(os.path.join(sample_dir, 'images', id_ + '.png'))[:, :, :channels]
        images[n] = resize(img, (height, width), mode='constant', preserve_range=True)

        mask_dir = os.path.join(sample_dir, 'masks')
        mask = np.zeros((height, width, 1), dtype=bool)
        for mask_file in next(os.walk(mask_dir))[2]:
            mask_ = imread(os.path.join(mask_dir, mask_file))
            mask_ = np.expand_dims(resize(mask_, (height, width), mode='constant',
                                          preserve_range=True), axis=-1)
            # Union of the per-file masks. `mask` becomes float here; storing it
            # into the bool `labels` array maps every non-zero pixel (including
            # any interpolated edge values) to True.
            mask = np.maximum(mask, mask_)
        labels[n] = mask
    return images, labels


if my_file.is_file():
    # Cache exists: load it. The context manager closes the file even on error.
    with h5py.File(str(my_file), 'r') as h5f:
        images = h5f['images'][:]
        labels = h5f['labels'][:]
else:
    # No cache yet: read and resize the raw data, then write the cache.
    print('Getting and resizing train images and masks ... ')
    sys.stdout.flush()
    images, labels = _load_train_data(train_ids, TRAIN_PATH,
                                      IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS)
    with h5py.File(str(my_file), 'w') as h5f:
        h5f.create_dataset('images', data=images)
        h5f.create_dataset('labels', data=labels)

# A cache written with different IMG_* settings would otherwise be used silently.
assert images.shape[1:] == (IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS), \
    'data.h5 does not match the IMG_* settings; delete it to rebuild the cache'
assert labels.shape == images.shape[:3] + (1,), 'images and labels are misaligned'

# Same names in both branches (previously only set when the cache was rebuilt).
X_train = images
Y_train = labels

# Get and resize test images
# X_test = np.zeros((len(test_ids), IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS), dtype=np.uint8)
# sizes_test = []
# print('Getting and resizing test images ... ')
# sys.stdout.flush()
# for n, id_ in tqdm(enumerate(test_ids), total=len(test_ids)):
#     path = TEST_PATH + id_
#     img = imread(path + '/images/' + id_ + '.png')[:,:,:IMG_CHANNELS]
#     sizes_test.append([img.shape[0], img.shape[1]])
#     img = resize(img, (IMG_HEIGHT, IMG_WIDTH), mode='constant', preserve_range=True)
#     X_test[n] = img

print('Done!')

In [None]:
# Split the full training data into training and validation subsets.
from sklearn.model_selection import train_test_split

val_split = 0.1  # fraction of the samples held out for validation

# Split the loaded arrays directly; `random_state=seed` keeps the split
# identical from run to run.
X_train, X_val, Y_train, Y_val = train_test_split(
    images,
    labels,
    test_size=val_split,
    random_state=seed,
)

In [None]:
# Display one random training sample and one random validation sample,
# each next to its mask.

n_train = X_train.shape[0]
n_val = X_val.shape[0]


def _mask_to_image(sample_mask):
    """Convert an (H, W, 1) mask into an (H, W) uint8 image with values 0/255."""
    return np.squeeze(sample_mask.astype(np.uint8) * 255, axis=2)


# Draw the training index first, then the validation index (same order as before).
n = np.random.randint(n_train)
n2 = np.random.randint(n_val)
mask = _mask_to_image(Y_train[n])
mask2 = _mask_to_image(Y_val[n2])

# plt.subplots already creates the four axes; calling plt.subplot(1, 4, k)
# again only re-fetches (or, depending on the matplotlib version, re-creates)
# them, so the axes are used directly.
fig, axes = plt.subplots(ncols=4, figsize=(10, 8))
ax = axes.ravel()

panels = [
    (X_train[n], 'Original'),
    (mask, 'Mask'),
    (X_val[n2], 'Val_original'),
    (mask2, 'Val_mask'),
]
for panel_ax, (picture, title) in zip(ax, panels):
    panel_ax.imshow(picture)
    panel_ax.set_title(title)

plt.show()