In [None]:
import os, cv2, tensorflow as tf, numpy as np
from matplotlib import pyplot as plt
import imghdr

In [None]:
# Dataset root: one sub-directory per class, each holding that class's images.
data_dir = 'data'
# Image types kept by the cleaning pass (compared against imghdr.what's result).
image_exts = 'jpeg jpg png bmp'.split()

In [None]:
# Delete anything under data_dir/<class>/ that is not a usable image, so that
# image_dataset_from_directory does not trip over it later. A file is removed
# when imghdr does not identify it as one of `image_exts`, when OpenCV cannot
# decode it (cv2.imread returns None instead of raising), or when inspecting
# it raises. WARNING: this deletes files from disk in place.
# NOTE(review): imghdr is deprecated (PEP 594) and removed in Python 3.13.
# NOTE(review): cv2.imread is reported to return None for some non-ASCII
# paths on Windows; such files would be deleted here too — confirm.
for image_class in os.listdir(data_dir):
    class_dir = os.path.join(data_dir, image_class)
    # Only class sub-directories hold images; skip stray top-level files such
    # as .gitkeep or .DS_Store (os.listdir on a file would raise here).
    if not os.path.isdir(class_dir):
        continue
    for image in os.listdir(class_dir):
        image_path = os.path.join(class_dir, image)
        # Nested directories are not images, and os.remove cannot delete them.
        if not os.path.isfile(image_path):
            continue
        try:
            img = cv2.imread(image_path)
            tip = imghdr.what(image_path)
        except Exception as e:
            print('Issue with image {}: {}'.format(image_path, e))
            os.remove(image_path)
            continue
        # img is None means OpenCV could not decode the file at all.
        if img is None or tip not in image_exts:
            print('Removing {} (detected type: {})'.format(image_path, tip))
            os.remove(image_path)


In [None]:
# Build a batched tf.data pipeline of (image, label) pairs, with labels
# inferred from the class sub-directory names under data_dir.
# NOTE(review): with the default shuffle=True and no seed, the order is not
# reproducible and may change on each pass over `data`; the take/skip split
# built from it further down could then mix samples between train/val/test
# across epochs — confirm, and consider validation_split + subset + seed.
data = tf.keras.utils.image_dataset_from_directory(directory=data_dir)

In [None]:
# Scale pixel values by 1/255 (0-255 range -> 0-1); labels pass through as-is.
# Not idempotent: re-running this cell on its own divides by 255 again.
data = data.map(lambda images, labels: (images / 255, labels))

In [None]:
# Split sizes, counted in batches: roughly 70% train / 20% val / 10% test.
# Val and test each get at least one batch; train takes the exact remainder,
# so the three sizes always sum to len(data). (The previous "+1" padding could
# push the total past len(data), silently leaving `test` short or empty —
# e.g. 10 batches gave 7 + 3 + 2.)
num_batches = len(data)
val_size = max(1, int(num_batches * .2))
test_size = max(1, int(num_batches * .1))
train_size = num_batches - val_size - test_size
assert train_size > 0, 'need at least 3 batches to split, got {}'.format(num_batches)

In [None]:
# Carve consecutive ranges of batch positions out of `data`:
# [0, train_size) -> train, the next val_size batches -> val, the next
# test_size batches -> test.
# NOTE(review): if `data` reshuffles on every iteration (the loader's default
# shuffle), these slices are recomputed per pass and may not stay disjoint
# across epochs — confirm before trusting val/test metrics.
val_start = train_size
test_start = val_start + val_size
train = data.take(train_size)
val = data.skip(val_start).take(val_size)
test = data.skip(test_start).take(test_size)

In [None]:
# Small CNN binary classifier: three Conv2D + MaxPooling2D stages, then a dense
# head ending in a single sigmoid unit (probability of the class labelled 1).
# NOTE(review): (256, 256, 3) presumably matches image_dataset_from_directory's
# default image_size and RGB color_mode — keep in sync if those change.
# NOTE(review): sigmoid + binary_crossentropy assumes exactly two class
# sub-directories under data_dir (integer labels 0/1) — confirm.
model = tf.keras.models.Sequential([
    # Declare the input with an Input object instead of the layer-level
    # `input_shape` argument, which newer Keras versions warn about.
    tf.keras.Input(shape=(256, 256, 3)),

    tf.keras.layers.Conv2D(16, (3, 3), strides=1, activation='relu'),
    tf.keras.layers.MaxPooling2D(),

    tf.keras.layers.Conv2D(32, (3, 3), strides=1, activation='relu'),
    tf.keras.layers.MaxPooling2D(),

    tf.keras.layers.Conv2D(16, (3, 3), strides=1, activation='relu'),
    tf.keras.layers.MaxPooling2D(),

    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(256, activation='relu'),
    tf.keras.layers.Dense(1, activation='sigmoid'),
])

model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

In [None]:
# Train for a fixed 20 epochs, validating on `val`; the returned History
# (hist.history, keyed by metric name) feeds the plots below.
hist = model.fit(
    train,
    validation_data=val,
    epochs=20,
)

In [None]:
# Persist the trained model; the .h5 extension selects the HDF5 format.
# NOTE(review): 'nswf' looks like a typo for 'nsfw'; kept as-is because other
# code may load the model from exactly this filename.
model_path = 'sfw_nswf_identifier.h5'
model.save(model_path)

In [None]:
# Training vs. validation loss per epoch, read from the History of fit().
fig, ax = plt.subplots()
ax.plot(hist.history['loss'], color='teal', label='loss')
ax.plot(hist.history['val_loss'], color='orange', label='val_loss')
fig.suptitle('Loss', fontsize=20)
ax.set_xlabel('epoch')
ax.set_ylabel('loss')
ax.legend(loc='upper left')
plt.show()

In [None]:
# Training vs. validation accuracy per epoch, read from the History of fit().
fig, ax = plt.subplots()
ax.plot(hist.history['accuracy'], color='teal', label='accuracy')
ax.plot(hist.history['val_accuracy'], color='orange', label='val_accuracy')
fig.suptitle('Accuracy', fontsize=20)
ax.set_xlabel('epoch')
ax.set_ylabel('accuracy')
ax.legend(loc='upper left')
plt.show()