# Check for feasibility

In [1]:
import glob
import os
import tensorflow as tf
import cv2
from tensorflow.keras.layers import Conv2D, MaxPool2D, Flatten, Dense
from tensorflow.keras.models import Sequential

## Set hyperparameters and paths

In [13]:
# Number of full passes over the training set.
EPOCHS = 50

# Single root for every path below, so relocating the data means editing one line.
# NOTE(review): this is an absolute, user-specific path; consider making it
# relative to the notebook or configurable before sharing.
DATASET_ROOT = '/Users/shim/dl-python-ImageDetection/dataset/'

# The trailing '' keeps the trailing slash that the original literals had.
DATASET_PATH = os.path.join(DATASET_ROOT, '2', '')
# Despite the *_PATTERN names these are directories (they are passed to
# get_file_list as the directory to walk), not glob patterns.
DATASET_OK_PATTERN = os.path.join(DATASET_PATH, 'OK', '')
DATASET_FAIL_PATTERN = os.path.join(DATASET_PATH, 'FAIL', '')

# Directory that receives the per-outcome result images.
RESULT_SAVE_PATH = os.path.join(DATASET_ROOT, 'results', '')

## Define a simple CNN model

In [3]:
def Model():
    """Build the (uncompiled) binary-classification CNN.

    The network takes 256x256 single-channel images and stacks four
    Conv2D(3x3, relu) + MaxPool2D stages with 32, 64, 128 and 256 filters,
    then flattens into a single sigmoid unit.

    Returns:
        A Keras ``Sequential`` model.
    """
    stack = []
    for stage, n_filters in enumerate((32, 64, 128, 256)):
        # Only the very first layer declares the input shape.
        first_layer_args = {'input_shape': (256, 256, 1)} if stage == 0 else {}
        stack.append(Conv2D(n_filters, (3, 3), activation='relu', **first_layer_args))
        stack.append(MaxPool2D())
    stack.append(Flatten())
    stack.append(Dense(1, activation='sigmoid'))
    return Sequential(stack)

## Import dataset

In [5]:
def preprocess(file_name):
    """Load one PNG file as a single-channel float32 image tensor.

    Args:
        file_name: Path of the PNG file (string or scalar string tensor).

    Returns:
        The decoded image with one channel, converted to ``tf.float32`` by
        ``tf.image.convert_image_dtype``. No resizing is performed.
    """
    encoded = tf.io.read_file(file_name)
    grayscale = tf.image.decode_png(encoded, channels=1)
    as_float = tf.image.convert_image_dtype(grayscale, tf.float32)
    return as_float

In [23]:
def get_file_list(directory, extension):
    """Recursively collect files under ``directory`` whose name ends with ``extension``.

    Args:
        directory: Root directory to walk.
        extension: Required file-name suffix, e.g. ``'.png'`` (case-sensitive).

    Returns:
        A sorted list of matching paths (each is ``os.path.join(root, name)``).
        The list is empty if nothing matches or the directory does not exist.
    """
    matches = [
        os.path.join(root, name)
        for root, _, names in os.walk(directory)
        for name in names
        if name.endswith(extension)
    ]
    # os.walk yields entries in filesystem order, which varies between
    # machines and runs; sorting makes the result reproducible.
    matches.sort()
    return matches


In [14]:
def _labeled_image_dataset(files, label):
    """Return a dataset of (image, label) pairs for ``files``, all sharing ``label``."""
    # shuffle=False: list_files shuffles by default and reshuffles on every
    # iteration, which would make any later take()/skip() split unstable.
    images = tf.data.Dataset.list_files(files, shuffle=False).map(preprocess)
    labels = tf.data.Dataset.from_tensor_slices([label] * len(files))
    return tf.data.Dataset.zip((images, labels))


# Label convention: 0 = OK, 1 = FAIL.
ok_list = get_file_list(DATASET_OK_PATTERN, '.png')
ds_ok = _labeled_image_dataset(ok_list, 0)

fail_list = get_file_list(DATASET_FAIL_PATTERN, '.png')
ds_fail = _labeled_image_dataset(fail_list, 1)

# All OK samples followed by all FAIL samples; shuffling happens at split time.
ds = ds_ok.concatenate(ds_fail)

## Split into training and validation sets

In [15]:
# Total number of labelled images across both classes.
ds_size = len(ok_list) + len(fail_list)
# 70% of the samples go to training, the remainder to validation.
train_size = int(ds_size * 0.7)

# Shuffle once with a buffer covering the whole dataset so both classes are mixed.
# reshuffle_each_iteration=False is essential: with the default (True) the
# order changes every time the dataset is iterated, so take()/skip() would
# select different elements each epoch and validation samples would leak
# into training.
# NOTE: the split is only stable if the upstream `ds` also yields its elements
# in the same order on every iteration (e.g. list_files(..., shuffle=False)).
ds = ds.shuffle(ds_size, reshuffle_each_iteration=False)
# Only the training subset is reshuffled between epochs.
ds_train = ds.take(train_size).shuffle(1024, reshuffle_each_iteration=True).batch(32)
ds_validation = ds.skip(train_size).batch(32)

In [17]:
# Show how many samples ended up in the training split.
train_size

662

## Build and Train the Model

In [18]:
# Instantiate the CNN and configure it for binary classification:
# Adam optimizer, binary cross-entropy loss, accuracy reported as a metric.
model = Model()
compile_options = dict(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy'],
)
model.compile(**compile_options)

In [19]:
# Keep the returned History object so the per-epoch loss/accuracy values can
# be inspected or plotted afterwards, instead of only echoing its repr.
history = model.fit(ds_train, validation_data=ds_validation, epochs=EPOCHS)

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 43/50
Epoch 44/50
Epoch 45/50
Epoch 46/50
Epoch 47/50
Epoch 48/50
Epoch 49/50
Epoch 50/50


<keras.callbacks.History at 0x7fe4704b7fd0>

## Save results as images

In [21]:
def mkdir(path):
    """Create directory ``path`` if it does not already exist.

    Missing parent directories are created as well, and an existing
    directory is left untouched.

    Args:
        path: Directory to create.
    """
    # makedirs(exist_ok=True) avoids the check-then-create race and also
    # works when an intermediate directory is missing.
    os.makedirs(path, exist_ok=True)

# One output folder per confusion-matrix outcome (positive class = label 1).
# The root is created first so this works even if mkdir cannot create parents.
mkdir(RESULT_SAVE_PATH)
for outcome in ('TP', 'TN', 'FP', 'FN'):
    mkdir(os.path.join(RESULT_SAVE_PATH, outcome))

# Running counter that keeps file names unique across batches.
index = 0
for imgs, labels in ds_validation:
    preds = model.predict(imgs)
    for idx in range(imgs.shape[0]):
        gt = int(labels[idx].numpy())
        # The model ends in a single sigmoid unit, so each prediction row holds
        # one value. Extract it as a plain float instead of relying on implicit
        # conversion of a 1-element array.
        score = float(preds[idx][0])
        predicted_positive = score > 0.5
        if gt == 1:
            outcome = 'TP' if predicted_positive else 'FN'
        else:
            outcome = 'FP' if predicted_positive else 'TN'

        # Convert explicitly to 8-bit pixels rather than handing a float32
        # array to cv2.imwrite.
        pixels = (imgs[idx].numpy() * 255).clip(0, 255).round().astype('uint8')

        # NOTE(review): '%.f' formats with zero decimals, so the prefix is the
        # rounded prediction (0 or 1) — confirm this is intended rather than
        # e.g. '%.2f'.
        out_file = os.path.join(RESULT_SAVE_PATH, outcome,
                                '%.f_%04d.png' % (score, index))
        # cv2.imwrite reports failure through its return value only.
        if not cv2.imwrite(out_file, pixels):
            raise IOError('Failed to write result image: %s' % out_file)
        index += 1

