In [None]:
import tensorflow as tf
import os
import matplotlib.pyplot as plt

# Define Data Paths
# The dataset root can be overridden with the EYE_DATA_DIR environment variable so the
# notebook is not tied to one machine; the original absolute path remains the default.
base_dir = os.environ.get(
    "EYE_DATA_DIR",
    r"C:\Users\janak\OneDrive\Desktop\Projects\Ambulence_Project-Charindu\Ambulence_Project-Charindu\archive",
)

train_dir = os.path.join(base_dir, 'Train')
val_dir = os.path.join(base_dir, 'Val')

# Set Up Model Parameters
IMG_SIZE = (224, 224)  # MobileNetV2 was trained on 224x224 images
BATCH_SIZE = 32        # How many images to process at once

# Load the Eye Data
# Labels are inferred from the sub-folder names of each split (presumably 'Closed_Eyes'
# and 'Open_Eyes' — TODO confirm). With label_mode='binary' each split must contain
# exactly two class folders; label indices follow the alphanumeric order of the folders.

print("Loading training data...")
train_eye_dataset = tf.keras.utils.image_dataset_from_directory(
    train_dir,
    label_mode='binary',
    image_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    seed=123,  # makes the training shuffle order reproducible
)

print("Loading validation data...")
val_eye_dataset = tf.keras.utils.image_dataset_from_directory(
    val_dir,
    label_mode='binary',
    image_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    # Validation metrics do not depend on sample order; a fixed order keeps any later
    # per-sample predictions (e.g. a confusion matrix) reproducible.
    shuffle=False,
)

# Verify the Classes
class_names = train_eye_dataset.class_names
# Each call above infers its own label mapping, so differently named class folders in
# Train and Val would silently assign mismatched labels. Fail loudly instead.
if val_eye_dataset.class_names != class_names:
    raise ValueError(
        f"Train classes {class_names} do not match "
        f"Val classes {val_eye_dataset.class_names}"
    )
print("\nSuccessfully loaded classes for Eye-State Detector:")
print(class_names)

# Optimize Data Loading
# prefetch overlaps input preparation with model execution; AUTOTUNE lets tf.data
# choose the buffer size. (class_names is read above because prefetch drops it.)
AUTOTUNE = tf.data.AUTOTUNE

train_eye_dataset = train_eye_dataset.prefetch(buffer_size=AUTOTUNE)
val_eye_dataset = val_eye_dataset.prefetch(buffer_size=AUTOTUNE)

print("\n Data loading setup is complete.")

Loading training data...


TypeError: image_dataset_from_directory() got an unexpected keyword argument 'classes'

In [None]:
from tensorflow.keras import layers, models
from tensorflow.keras.applications import MobileNetV2

# Create a Data Augmentation Layer
# Random horizontal flips, rotations of up to 10% of a full turn, and zooms of up to 10%.
# Keras applies these random layers only in training mode; at inference they pass images through.
data_augmentation = models.Sequential()
data_augmentation.add(layers.RandomFlip("horizontal"))
data_augmentation.add(layers.RandomRotation(0.1))
data_augmentation.add(layers.RandomZoom(0.1))

# Load the Pre-trained Base Model (MobileNetV2)
# ImageNet weights, with the original classifier removed (include_top=False) so the
# network outputs feature maps instead of class scores.
base_model = MobileNetV2(
    weights='imagenet',
    include_top=False,
    input_shape=IMG_SIZE + (3,),
)

# Freeze the Base Model
# No pre-trained weight is updated during training; only the new head below learns.
base_model.trainable = False

# Add Own Classification "Head"
# Functional-API graph: input -> augmentation -> MobileNetV2 preprocessing -> frozen base
# -> spatial average pooling -> dropout -> single sigmoid unit.
inputs = tf.keras.Input(shape=IMG_SIZE + (3,))
x = data_augmentation(inputs)
# Rescale pixel values from [0, 255] to [-1, 1], the input range MobileNetV2 expects.
x = tf.keras.applications.mobilenet_v2.preprocess_input(x)
# training=False runs the base in inference mode (BatchNormalization keeps using its
# stored statistics), which stays important if the base is unfrozen later for fine-tuning.
x = base_model(x, training=False)
# Average every feature map over height and width -> one feature vector per image
# (this is pooling, not flattening).
x = layers.GlobalAveragePooling2D()(x)
# Randomly zero 20% of the features during training to reduce overfitting.
x = layers.Dropout(0.2)(x)
# One sigmoid output: the predicted probability of label 1 (class_names[1]).
outputs = layers.Dense(1, activation='sigmoid')(x)

# Create and Compile the Final Model
eye_model = models.Model(inputs, outputs)
eye_model.compile(
    loss='binary_crossentropy',
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    metrics=['accuracy'],
)

# Print a Summary
eye_model.summary()

print("\n Eye-State Detector model is built, compiled, and ready for training.")

In [None]:
# Set Training Parameters

EPOCHS = 10

# Train the Model

print(f"Starting training for {EPOCHS} epochs...")

history = eye_model.fit(
    train_eye_dataset,
    epochs=EPOCHS,
    validation_data=val_eye_dataset,
)

print("\n Training complete!")

# Plot the Results (Accuracy and Loss)

# Per-epoch metric lists recorded by fit(); the 'val_' keys exist because
# validation_data was passed above.
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

# Size the x-axis from the recorded history instead of EPOCHS so the plot stays valid
# if training stops early (e.g. an EarlyStopping callback is added later).
# Epochs are numbered from 1 to match Keras' "Epoch k/N" progress lines.
epochs_range = range(1, len(acc) + 1)

fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(12, 5))

# Training & Validation Accuracy
acc_ax.plot(epochs_range, acc, label='Training Accuracy')
acc_ax.plot(epochs_range, val_acc, label='Validation Accuracy')
acc_ax.legend(loc='lower right')
acc_ax.set(title='Training and Validation Accuracy', xlabel='Epoch', ylabel='Accuracy')

# Training & Validation Loss
loss_ax.plot(epochs_range, loss, label='Training Loss')
loss_ax.plot(epochs_range, val_loss, label='Validation Loss')
loss_ax.legend(loc='upper right')
loss_ax.set(title='Training and Validation Loss', xlabel='Epoch', ylabel='Loss')

# Show the plots
plt.show()