In [None]:
import glob

# Sort the photo paths: glob returns them in filesystem order, which differs between
# machines, and later cells address samples by position (e.g. predict(0)).
files = sorted(glob.glob('img/*.jpg'))

In [None]:
import numpy as np
from keras.preprocessing.image import load_img, img_to_array

# Decode every training photo into a float array and stack them into one batch.
# Assumes all photos share the same height and width so they stack into a single tensor.
imgs = np.array([
    img_to_array(load_img(path))
    for path in files
])

imgs.shape

In [None]:
import os

th = 128

# Each photo NAME.jpg has its mask stored next to it as NAME.png; masks are read as
# single-channel grayscale, in the same order as `files`.
labels = np.array([
    img_to_array(load_img(os.path.splitext(path)[0] + '.png', color_mode = 'grayscale'))
    for path in files
])

# Binarise the masks: pixels at or above the threshold become 1, all others 0.
# The array keeps its original dtype.
labels = (labels >= th).astype(labels.dtype)

labels.shape


In [None]:
# Count background (0) and foreground (1) pixels over all binarised masks.
# np.count_nonzero counts the boolean mask directly instead of materialising a copy.
b_len = int(np.count_nonzero(labels == 0))
w_len = int(np.count_nonzero(labels == 1))

print(f"0 = {b_len}, 1 = {w_len}")

# The ratio below (and the training cell) divides by the foreground count.
assert w_len > 0, "masks contain no foreground (1) pixels; check the threshold and mask files"

# Background-to-foreground ratio, later used as the intended foreground weight.
b_len / w_len

In [None]:
from keras.models import Model
# Conv2D and BatchNormalization are imported from keras.layers: the old
# keras.layers.convolutional / keras.layers.normalization paths no longer exist in newer Keras.
from keras.layers import (Input, Dropout, GRU, Reshape, Bidirectional, Permute,
                          concatenate, Dense, Conv2D, BatchNormalization)

height, width = imgs.shape[1:-1]

# Each branch ends in a Bidirectional GRU (directions concatenated) with dim // 2 units,
# producing 2 * (dim // 2) features per step. That equals `dim` only for even dimensions;
# otherwise the Reshape back to (H, W, 1) fails with an opaque shape error.
assert height % 2 == 0 and width % 2 == 0, (
    f"image height and width must be even for this architecture, got {imgs.shape[1:-1]}")


def _bigru_stack(seq, out_units):
    """Two 128-unit BiGRUs, BatchNorm + Dropout(0.3), then a BiGRU with `out_units` per direction."""
    seq = Bidirectional(GRU(128, return_sequences = True, activation = 'relu'))(seq)
    seq = Bidirectional(GRU(128, return_sequences = True, activation = 'relu'))(seq)
    seq = BatchNormalization()(seq)
    seq = Dropout(0.3)(seq)
    return Bidirectional(GRU(out_units, return_sequences = True, activation = 'relu'))(seq)


# Named `inputs` so the builtin `input` is not shadowed.
inputs = Input(shape = imgs.shape[1:])

x = BatchNormalization()(inputs)

# Reduce the colour channels to one feature map, then drop the channel axis -> (H, W).
x = Conv2D(1, 3, padding='same', activation = 'relu')(x)
x = Reshape(imgs.shape[1:-1])(x)

# Row branch: the GRUs step over the H rows with W features per step -> (H, W, 1).
x1 = _bigru_stack(x, width // 2)
x1 = Reshape(imgs.shape[1:-1] + (1,))(x1)

# Column branch: transpose to (W, H), step over the W columns, transpose back -> (H, W, 1).
x2 = Permute((2, 1))(x)
x2 = _bigru_stack(x2, height // 2)
x2 = Permute((2, 1))(x2)
x2 = Reshape(imgs.shape[1:-1] + (1,))(x2)

# Fuse both branches per pixel and emit one sigmoid probability per pixel.
x = concatenate([x1, x2])

output = Dense(1, activation = 'sigmoid')(x)

model = Model(inputs = inputs, outputs = output)

model.compile(loss = 'binary_crossentropy', optimizer = 'adam', metrics = ['acc'])

model.summary()

In [None]:
# Intended per-class weights: background (0) -> 1, foreground (1) -> b_len / w_len.
wg = [1, b_len / w_len]

# `wg` is deliberately NOT passed as `class_weight`. Keras expects a dict there, and it
# rejects class weights for 3+-dimensional targets such as these (N, H, W, 1) masks.
# Standalone Keras 2.x silently ignores a non-dict value, so training was effectively
# unweighted; tf.keras raises instead.
# NOTE(review): confirm against the Keras version in use.
# TODO: if class balancing is wanted, apply wg[1] through a weighted binary
# cross-entropy loss when compiling the model.
hist = model.fit(imgs, labels, initial_epoch = 0, epochs = 40, batch_size = 10)

hist

In [None]:
%matplotlib inline

import matplotlib.pyplot as plt

plt.rcParams['figure.figsize'] = (8, 4)

plt.subplot(1, 2, 1)
plt.plot(hist.history['loss'])

plt.subplot(1, 2, 2)
plt.plot(hist.history['acc'])

In [None]:
import os

MODEL_PATH = 'model/r1-1_0.h5'

# Writing the HDF5 file fails when the target directory is missing, so create it first.
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
model.save(MODEL_PATH)

In [None]:
def predict(index, s = 6.0):
    """Show training sample `index` next to its ground-truth mask and the model's prediction.

    Reads the notebook globals `imgs`, `labels` and `model`.

    Args:
        index: position of the sample in `imgs` / `labels` (negative indices allowed).
        s: width and height of the figure, in inches.
    """
    sh = imgs.shape[1:-1]

    # Add a leading batch axis for model.predict, then take the single result back out.
    pred = model.predict(imgs[index][np.newaxis])[0]
    pred *= 255

    # Explicit figure instead of mutating the global rcParams.
    fig = plt.figure(figsize = (s, s))

    ax = fig.add_subplot(1, 3, 1)
    ax.imshow(imgs[index].astype(int))
    ax.set_title('input')

    # Labels are binarised to 0/1; pin the grey range so the colours are absolute.
    ax = fig.add_subplot(1, 3, 2)
    ax.imshow(labels[index].reshape(sh), cmap = 'gray', vmin = 0, vmax = 1)
    ax.set_title('label')

    # Pin the range to 0..255 so a faint prediction is not auto-stretched to full contrast.
    ax = fig.add_subplot(1, 3, 3)
    ax.imshow(pred.reshape(sh).astype(int), cmap = 'gray', vmin = 0, vmax = 255)
    ax.set_title('prediction')
    

In [None]:
def predict_eval(file, s = 4.0):
    """Run the model on an arbitrary image file and show the photo beside the predicted mask.

    Reads the notebook global `model`.

    Args:
        file: path of the image to evaluate.
        s: width and height of the figure, in inches.
    """
    # The network was built for a fixed (H, W), so resize the image to it when known.
    # load_img leaves images that already have that size untouched.
    target_size = tuple(model.input_shape[1:3])
    if None in target_size:
        target_size = None

    img = img_to_array(load_img(file, target_size = target_size))

    # Add a leading batch axis for model.predict, then take the single result back out.
    pred = model.predict(img[np.newaxis])[0]
    pred *= 255

    # Explicit figure instead of mutating the global rcParams.
    fig = plt.figure(figsize = (s, s))

    ax = fig.add_subplot(1, 2, 1)
    ax.imshow(img.astype(int))
    ax.set_title('input')

    # Pin the range to 0..255 so a faint prediction is not auto-stretched to full contrast.
    ax = fig.add_subplot(1, 2, 2)
    ax.imshow(pred.reshape(pred.shape[:-1]).astype(int), cmap = 'gray', vmin = 0, vmax = 255)
    ax.set_title('prediction')

In [None]:
# Training sample 0: photo, ground-truth mask and prediction side by side.
predict(index = 0)

In [None]:
# Training sample 1: photo, ground-truth mask and prediction side by side.
predict(index = 1)

In [None]:
# Training sample 2: photo, ground-truth mask and prediction side by side.
predict(index = 2)

In [None]:
# Held-out image t01: photo and predicted mask.
predict_eval(file = 'img_eval/t01.jpg')

In [None]:
# Held-out image t02: photo and predicted mask.
predict_eval(file = 'img_eval/t02.jpg')