In [None]:
import os
import cv2 
import numpy as np
import matplotlib.pyplot as plt

# Batch size; presumably meant as the `batchsize` argument of normalize()
# (whose default is also 8) — TODO confirm, it is not referenced in the cells shown.
BATCH_SIZE = 8
# Target image/mask size; presumably meant as normalize()'s `outshape`.
# NOTE: cv2.resize reads a size tuple as (width, height); the two values are
# equal here, so the order does not matter yet.
MODEL_SHAPE = (256, 256)

In [None]:
data_dir = ""
labeled_dir = os.path.join(data_dir, "labeled")
unlabeled_dir = os.path.join(data_dir, "unlabeled")

unlabeled_paths = []
for root, dirs, files in os.walk(unlabeled_dir):
    for name in files:
        unlabeled_paths.append(os.path.join(root, name))

labeled_paths = []
for root, dirs, files in os.walk(root, labeled_dir):
    for name in files:
        labeled_paths.append(os.path.join(root, name))

unlabeled_paths.sort()
labeled_paths.sort()

In [None]:

# Pair each input image with its mask by position in the two sorted lists.
# NOTE(review): this assumes matching image/mask files sort identically in
# their directories (e.g. shared base names) — confirm against the dataset.
if len(unlabeled_paths) != len(labeled_paths):
    raise ValueError(
        f"found {len(unlabeled_paths)} images but {len(labeled_paths)} masks; "
        "cannot pair them by position"
    )
input_pairs = dict(zip(unlabeled_paths, labeled_paths))

In [None]:
def _imread_checked(path, flags):
    """Read ``path`` with ``cv2.imread`` and raise if it cannot be decoded.

    ``cv2.imread`` returns ``None`` instead of raising for a missing or
    unreadable file; without this check the failure surfaces later as an
    opaque ``cv2.error`` from the next OpenCV call.
    """
    image = cv2.imread(path, flags)
    if image is None:
        raise OSError(f"could not read image file: {path!r}")
    return image


def normalize(paths, batchsize=8, outshape=(256, 256), drop_last=True):
    """Yield batches of resized, [0, 1]-scaled (image, mask) arrays.

    Args:
        paths: mapping of image path -> mask path; pairs are visited in the
            mapping's iteration order.
        batchsize: number of pairs per yielded batch; must be >= 1.
        outshape: target size handed to ``cv2.resize``, which interprets it as
            ``(width, height)``.
        drop_last: if True (the original behaviour), a final batch holding
            fewer than ``batchsize`` pairs is discarded; if False it is yielded.

    Yields:
        ``(images, masks)``: ``images`` has shape ``(n, height, width, 3)`` in
        RGB channel order and ``masks`` has shape ``(n, height, width, 1)``;
        both are the 8-bit pixel values divided by 255.

    Raises:
        ValueError: if ``batchsize`` < 1 (raised on the first ``next``, since
            this is a generator).
        OSError: if an image or mask file cannot be read.
    """
    if batchsize < 1:
        raise ValueError(f"batchsize must be >= 1, got {batchsize}")

    img_batch, mask_batch = [], []
    for img_path, mask_path in paths.items():
        # IMREAD_COLOR always yields 3-channel BGR; reorder the channels to RGB.
        img = _imread_checked(img_path, cv2.IMREAD_COLOR)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, outshape) / 255

        mask = _imread_checked(mask_path, cv2.IMREAD_GRAYSCALE)
        # NOTE(review): the default (bilinear) interpolation can create
        # fractional values along mask edges; consider cv2.INTER_NEAREST if
        # the masks are meant to stay binary.
        mask = cv2.resize(mask, outshape)
        mask = np.expand_dims(mask, -1) / 255  # (H, W) -> (H, W, 1)

        img_batch.append(img)
        mask_batch.append(mask)
        if len(img_batch) == batchsize:
            yield np.stack(img_batch), np.stack(mask_batch)
            img_batch, mask_batch = [], []

    if img_batch and not drop_last:
        yield np.stack(img_batch), np.stack(mask_batch)

In [None]:
# Training and Validation Split
# The first `train_size` pairs (input_pairs is built from sorted paths) go to
# training and the remainder to validation, so the two sets never overlap.
# NOTE(review): train_size is a count of pairs; if 80 was meant as a percentage,
# use int(0.8 * len(pair_items)). The split is unshuffled — consider shuffling
# pair_items with a fixed seed if neighbouring files are correlated.
train_size = 80
pair_items = list(input_pairs.items())
training_paths = dict(pair_items[:train_size])
validation_paths = dict(pair_items[train_size:])

In [None]:
# Visualization: show image/mask pairs from the first batch as a sanity check.
imgs, masks = next(normalize(input_pairs, batchsize=BATCH_SIZE, outshape=MODEL_SHAPE))
print("Batch size:", imgs.shape, masks.shape)

# Indices 4 and 2 need a batch of at least 5; fall back to the first pair otherwise.
sample_indices = [i for i in (4, 2) if i < len(imgs)] or [0]
fig, axes = plt.subplots(len(sample_indices), 2, figsize=(8, 8), squeeze=False)
for row, idx in enumerate(sample_indices):
    image_ax, mask_ax = axes[row]
    image_ax.imshow(imgs[idx])
    image_ax.set_title(f"image {idx}")
    # normalize() returns masks as (H, W, 1); drop the channel axis for a 2-D plot.
    mask_ax.imshow(masks[idx][..., 0], cmap='gray')
    mask_ax.set_title(f"mask {idx}")
    image_ax.axis('off')
    mask_ax.axis('off')
# tight_layout only has an effect once the axes exist, so call it after plotting.
fig.tight_layout()
plt.show()

In [None]:
# TODO: DATA AUGMENTATION

In [None]:
from nn import get_model
import tensorflow.keras as tf

# NOTE: `tf` here is the `tensorflow.keras` module (aliased in the import
# above), not the top-level tensorflow package.
model = get_model()
model.compile(
    optimizer=tf.optimizers.Adam(learning_rate=1e-3),
    # The Keras class is `BinaryCrossentropy` (lowercase "e" in "entropy").
    loss=tf.losses.BinaryCrossentropy(),
    metrics=[tf.metrics.BinaryAccuracy(), tf.metrics.FalseNegatives()],
)