<a href="https://colab.research.google.com/github/galnov/hello-world/blob/dataset/Binary_Image_Classifier.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In this notebook we'll build a binary image classifier on top of a frozen FaceNet model, train it, and use it to classify a given image.

In [None]:
import tensorflow as tf

# Report up front whether TensorFlow can see a GPU, so a CPU-only
# training run is obvious before it starts.
physical_devices = tf.config.list_physical_devices('GPU')
gpu_found = bool(physical_devices)
print("GPU is available. Training on GPU." if gpu_found
      else "No GPU devices available. Training on CPU.")

In [None]:
!pip install keras-facenet

Create the model (a frozen FaceNet backbone with a small binary classifier head)

In [None]:
import os

import numpy as np
import matplotlib.pyplot as plt
from keras.layers import Dense, Flatten
from keras.models import Model, load_model
from keras.optimizers import Adam
from keras.preprocessing.image import ImageDataGenerator
from keras.utils import img_to_array, load_img, set_random_seed
from keras_facenet import FaceNet


# Seed Python, NumPy and TensorFlow RNGs so the classifier head's weight
# initialisation is reproducible on Restart & Run All. This presumably also
# makes the data generators' shuffling reproducible, since they are not given
# their own seed — TODO confirm against the Keras version in use.
SEED = 42
set_random_seed(SEED)

# Load the FaceNet model
facenet = FaceNet()

# Freeze every FaceNet layer so that only the new classifier head is trained
for layer in facenet.model.layers:
    layer.trainable = False

# Add a binary classifier head on top of the FaceNet output.
# Flatten guarantees a (batch, features) input for the Dense layers.
x = Flatten()(facenet.model.output)
x = Dense(128, activation='relu')(x)
# A single sigmoid unit: the predicted probability of the class the data
# generator encodes as label 1 (see class_indices).
output = Dense(1, activation='sigmoid')(x)

# Create a new model spanning FaceNet's input to the new head's output
model = Model(inputs=facenet.model.input, outputs=output)

# Compile the model; binary cross-entropy matches the single sigmoid output
model.compile(optimizer=Adam(learning_rate=0.001), loss='binary_crossentropy',
              metrics=['accuracy'])

Fetch the training data (clone the repository that contains it)

In [None]:
!git clone https://github.com/galnov/hello-world
%cd hello-world
# In case a specific branch is checked out, uncomment the next line
#!git checkout <branch>

Prepare the training and validation data

In [None]:
# One generator rescales pixels to [0, 1]. Its validation_split carves the
# training and validation subsets out of the same 'data' directory.
train_datagen = ImageDataGenerator(rescale=1./255, validation_split=0.25)

# Settings shared by both subsets; only `subset` differs between them.
common_flow_args = dict(
    target_size=(160, 160),
    batch_size=4,
    class_mode='binary',
)
train_generator = train_datagen.flow_from_directory(
    'data', subset='training', **common_flow_args)
validation_generator = train_datagen.flow_from_directory(
    'data', subset='validation', **common_flow_args)

# Class-name -> numeric-label mapping the generators use for the 0/1 targets
class_indices = train_generator.class_indices
print("Class Indices:", class_indices)

(Optional) View one training batch as a sanity check

In [None]:
# Pull a single batch from the training generator to eyeball images and labels.
imgs, labels = next(train_generator)

def plots(ims, figsize=(12, 6), titles=None):
    """Draw a row of images side by side, optionally with a title on each.

    ims: sequence of images. If its elements are NumPy arrays, they are
        stacked, and when the last axis is not 3 they are treated as
        channels-first and moved to channels-last.
    figsize: matplotlib figure size.
    titles: optional per-image titles, indexed in step with ``ims``.
    """
    if type(ims[0]) is np.ndarray:
        ims = np.array(ims)
        if ims.shape[-1] != 3:
            # Presumably (N, C, H, W) -> (N, H, W, C)
            ims = ims.transpose((0, 2, 3, 1))
    n_images = len(ims)
    # Round the number of columns up to the next even number
    n_cols = n_images + (n_images % 2)
    fig = plt.figure(figsize=figsize)
    for idx, image in enumerate(ims):
        ax = fig.add_subplot(1, n_cols, idx + 1)
        ax.axis('Off')
        if titles is not None:
            ax.set_title(titles[idx], fontsize=16)
        plt.imshow(image)

plots(imgs, titles=labels)

Train the model

In [None]:
# Train the model (only the unfrozen classifier head is updated); validation
# loss/accuracy are computed on the validation subset after every epoch.
EPOCHS = 10
history = model.fit(
    train_generator,
    epochs=EPOCHS,
    validation_data=validation_generator,
)

# Plot training vs. validation accuracy per epoch on an explicit figure/axes
# rather than the implicit pyplot "current figure".
fig, ax = plt.subplots()
ax.plot(history.history['accuracy'], label='Training Accuracy')
ax.plot(history.history['val_accuracy'], label='Validation Accuracy')
ax.set_title('Classifier accuracy per epoch')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')
ax.legend()
plt.show()

# Save the trained model (uncomment to persist it for later reuse)
# model.save('image_classifier_model.h5')

Run inference: predict the emotion (happy or sad) in a new image

In [None]:
# Load the model for future use
#model = load_model('image_classifier_model.h5')

# Load and preprocess the new image the same way the generators prepare
# training images: resized to 160x160 and rescaled by 1/255.
# The path is relative to the cloned repository (the working directory after
# the `%cd` above), just like 'data', instead of a Colab-only absolute path.
img_path = 'Gal_Novik.png'
img = load_img(img_path, target_size=(160, 160))
img_array = img_to_array(img) / 255.0
plt.imshow(img_array)
# Add a leading batch dimension: the model predicts on batches of images
img_array = np.expand_dims(img_array, axis=0)

# Make predictions: one sigmoid probability per image in the batch
predictions = model.predict(img_array)
print(predictions)

# Translate the probability into a class name via the generator's own
# class_indices (name -> label), rather than hard-coding which class is 0 and
# which is 1. Per the Keras docs, flow_from_directory assigns these labels from
# the sorted sub-directory names, so they depend on how the data is organised.
index_to_class = {index: name for name, index in class_indices.items()}
predicted_index = int(predictions[0][0] >= 0.5)
print(f"Prediction: {index_to_class[predicted_index]}")