In [None]:
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.models import Model
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout, Input
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from sklearn.utils.class_weight import compute_class_weight
import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf

# Data paths.
# NOTE(review): the "validation" directory below is the FER2013 *test* split.
# It is also monitored by EarlyStopping / ReduceLROnPlateau during training,
# so the accuracy reported on it afterwards is optimistically biased.
# Consider holding out a validation subset of the training data instead
# (e.g. ImageDataGenerator(validation_split=...) with subset='training' /
# subset='validation').
train_data_dir = '../../objects/datasets/fer2013/train'
validation_data_dir = '../../objects/datasets/fer2013/test'

# Preprocessing (augmentation + resize to 224x224).
# MobileNetV2's ImageNet weights expect pixels scaled to [-1, 1] by
# mobilenet_v2.preprocess_input. A plain rescale=1./255 would feed [0, 1]
# inputs, a range the pretrained filters were not trained on.
# ImageDataGenerator applies `preprocessing_function` after resizing and
# augmentation. Any inference code that loads the saved model must apply the
# same function to its inputs.
mobilenet_preprocess = tf.keras.applications.mobilenet_v2.preprocess_input

train_datagen = ImageDataGenerator(
    preprocessing_function=mobilenet_preprocess,
    rotation_range=30,
    shear_range=0.3,
    zoom_range=0.3,
    horizontal_flip=True,
    fill_mode='nearest',
)

# No augmentation for evaluation data, only the same input scaling.
validation_datagen = ImageDataGenerator(preprocessing_function=mobilenet_preprocess)

train_generator = train_datagen.flow_from_directory(
    train_data_dir,
    color_mode='rgb',
    target_size=(224, 224),
    batch_size=32,
    class_mode='categorical',
    shuffle=True,
)

validation_generator = validation_datagen.flow_from_directory(
    validation_data_dir,
    color_mode='rgb',
    target_size=(224, 224),
    batch_size=32,
    class_mode='categorical',
    # Fixed order: evaluation is reproducible, and model.predict() output
    # lines up with validation_generator.classes (e.g. for a confusion matrix).
    shuffle=False,
)

# Class weights: 'balanced' gives each class n_samples / (n_classes * count),
# so rarer classes contribute more to the loss.
y_train_labels = train_generator.classes
classes_present = np.unique(y_train_labels)
class_weights = compute_class_weight(
    class_weight='balanced',
    classes=classes_present,
    y=y_train_labels,
)
# Plain Python int/float keys and values: Keras accepts them, and the
# printout stays readable regardless of the NumPy scalar repr.
class_weight_dict = {int(c): float(w) for c, w in zip(classes_present, class_weights)}
print("Class Weights:", class_weight_dict)

# Base MobileNetV2: ImageNet weights, no classification head, 224x224 RGB input.
input_tensor = Input(shape=(224, 224, 3))
# input_shape is passed alongside input_tensor so Keras knows the square
# 224x224 size up front. It then loads the matching weights without falling
# back to its "input_shape is undefined" default path.
base_model = MobileNetV2(
    include_top=False,
    weights='imagenet',
    input_shape=(224, 224, 3),
    input_tensor=input_tensor,
)
base_model.trainable = True

# Fine-tuning: only the last `n_finetune_layers` layers of the backbone train.
n_finetune_layers = 40
for layer in base_model.layers[:-n_finetune_layers]:
    layer.trainable = False

# Keep BatchNormalization layers frozen even inside the unfrozen tail.
# In TF2 Keras a BN layer with trainable=False runs in inference mode: it uses
# its pretrained moving mean/variance instead of re-estimating them from small
# fine-tuning batches. The Keras transfer-learning guide recommends this for
# fine-tuning.
for layer in base_model.layers[-n_finetune_layers:]:
    if isinstance(layer, tf.keras.layers.BatchNormalization):
        layer.trainable = False

# Classification head.
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dropout(0.3)(x)
x = Dense(128, activation='relu')(x)
x = Dropout(0.2)(x)
# One softmax unit per class directory found by the training generator
# (7 for FER2013), so the head always matches the labels it is trained on.
output = Dense(train_generator.num_classes, activation='softmax')(x)

# Model
model = Model(inputs=input_tensor, outputs=output)

# Compile: low learning rate, since pretrained layers are being updated.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

model.summary()

# Callbacks: stop after 5 epochs without val_loss improvement; halve the
# learning rate after 3 such epochs.
early_stop = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
reduce_lr = ReduceLROnPlateau(monitor='val_loss', patience=3, factor=0.5, verbose=1)

# Training
history = model.fit(
    train_generator,
    epochs=30,
    validation_data=validation_generator,
    class_weight=class_weight_dict,
    callbacks=[early_stop, reduce_lr],
    verbose=1
)

# Depending on the Keras version, restore_best_weights only takes effect when
# early stopping actually fires. If all epochs run, the last epoch's weights
# remain in the model. Restore the best-val_loss weights explicitly (harmless if
# they were already restored), so the saved and evaluated model is always the
# best one.
if early_stop.best_weights is not None:
    model.set_weights(early_stop.best_weights)

# Save.
# NOTE(review): Keras 3 treats HDF5 (.h5) as a legacy format; '.keras' is the
# native one. The filename is kept here in case other code loads this path.
model.save('mobilenetv2_finetune.h5')

# Evaluation (on the same directory that was monitored during training; see
# the note where the data paths are defined).
val_loss, val_accuracy = model.evaluate(validation_generator)
print(f"Validation Accuracy: {val_accuracy * 100:.2f}%")

# Training curves: accuracy in the left panel, loss in the right one. Each
# panel overlays the per-epoch train and validation values stored in `history`.
plt.figure(figsize=(12, 4))

curve_panels = (
    # (train key, val key, train label, val label, title, y-axis label)
    ('accuracy', 'val_accuracy', 'Train Accuracy', 'Validation Accuracy', 'Model Accuracy', 'Accuracy'),
    ('loss', 'val_loss', 'Train Loss', 'Validation Loss', 'Model Loss', 'Loss'),
)

for panel_no, (train_key, val_key, train_lbl, val_lbl, panel_title, y_lbl) in enumerate(curve_panels, start=1):
    plt.subplot(1, 2, panel_no)
    plt.plot(history.history[train_key], label=train_lbl)
    plt.plot(history.history[val_key], label=val_lbl)
    plt.title(panel_title)
    plt.xlabel('Epoch')
    plt.ylabel(y_lbl)
    plt.legend()

plt.tight_layout()
plt.show()


Found 47300 images belonging to 7 classes.
Found 7178 images belonging to 7 classes.
Class Weights: {0: 0.4834818873170333, 1: 0.746975774612299, 2: 1.649290421562816, 3: 0.9365409365409365, 4: 1.360955258236225, 5: 1.39899438036084, 6: 2.1309185926026037}


  base_model = MobileNetV2(include_top=False, weights='imagenet', input_tensor=input_tensor)


  self._warn_if_super_not_called()


Epoch 1/30
1479/1479 ━━━━━━━━━━━━━━━━━━━━ 3404s 2s/step - accuracy: 0.3213 - loss: 1.8478 - val_accuracy: 0.3108 - val_loss: 1.8262 - learning_rate: 1.0000e-05
Epoch 2/30
1479/1479 ━━━━━━━━━━━━━━━━━━━━ 4689s 3s/step - accuracy: 0.5429 - loss: 1.3611 - val_accuracy: 0.3957 - val_loss: 1.6326 - learning_rate: 1.0000e-05
Epoch 3/30
1479/1479 ━━━━━━━━━━━━━━━━━━━━ 0s 3s/step - accuracy: 0.6045 - loss: 1.2180