<a href="https://colab.research.google.com/github/mostakimjihad/EnFER-ViT/blob/main/EnFER_ViT(RAF_DB%2C_AfffectNet).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import kagglehub

# Download the RAF-DB dataset from Kaggle through kagglehub. The returned local
# directory is the root that the train/test folder paths are built from.
RAF_DB_DATASET_HANDLE = "shuvoalok/raf-db-dataset"
RAF_DB_DATASET_PATH = kagglehub.dataset_download(RAF_DB_DATASET_HANDLE)
print(RAF_DB_DATASET_PATH)

In [None]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator

train_path = RAF_DB_DATASET_PATH + '/DATASET//train'
test_path = RAF_DB_DATASET_PATH + '/DATASET//test'

# Settings shared by the train and test directory iterators: 100x100 RGB
# images, batches of 16, one-hot ('categorical') labels.
_FLOW_KWARGS = dict(
    target_size=(100, 100),
    color_mode='rgb',
    batch_size=16,
    class_mode='categorical',
)

# Training images are rescaled to [0, 1] and randomly augmented
# (rotation, shifts, brightness, shear, zoom, horizontal flips).
train_datagen = ImageDataGenerator(
    rotation_range=15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    brightness_range=[0.8, 1.2],
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    rescale=1/255.0,
)
train_generator = train_datagen.flow_from_directory(train_path, shuffle=True, **_FLOW_KWARGS)

# Test images are only rescaled; shuffle=False keeps their order fixed.
test_datagen = ImageDataGenerator(rescale=1/255.0)
test_generator = test_datagen.flow_from_directory(test_path, shuffle=False, **_FLOW_KWARGS)

In [None]:
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.regularizers import l2
import tensorflow as tf
from tensorflow.keras.applications import EfficientNetB4
from tensorflow.keras.layers import Input, Dense, Flatten, Dropout, BatchNormalization, Reshape, GlobalAveragePooling1D
from tensorflow.keras.models import Model
from tensorflow.keras.layers import LayerNormalization, MultiHeadAttention, Add


efficientnet = EfficientNetB4(weights='imagenet', include_top=False, input_shape=(100, 100, 3))

# Partial fine-tuning: only the last NUM_TRAINABLE_BACKBONE_LAYERS backbone
# layers are updated. Keras layers start out trainable=True, so only setting
# the tail to trainable=True has no effect and leaves the entire backbone
# trainable. The earlier layers must therefore be frozen explicitly.
# BatchNormalization layers stay frozen even in the tail, following the Keras
# EfficientNet fine-tuning recipe, so their ImageNet statistics are not
# disturbed by the small (16-image) batches.
# `efficientnet.trainable` itself is deliberately left True: setting it to
# False on this nested model would freeze the tail layers as well.
NUM_TRAINABLE_BACKBONE_LAYERS = 20
_first_trainable = len(efficientnet.layers) - NUM_TRAINABLE_BACKBONE_LAYERS
for _index, layer in enumerate(efficientnet.layers):
    layer.trainable = _index >= _first_trainable and not isinstance(layer, BatchNormalization)

def transformer_encoder(inputs, num_heads=4, key_dim=32, ff_dim=128, dropout=0.1):
    """Apply one post-norm Transformer encoder block to a token sequence.

    Args:
        inputs: Keras tensor whose last axis is the model width (assumed to be
            (batch, seq_len, width) - TODO confirm for callers other than this notebook).
        num_heads: number of attention heads.
        key_dim: per-head size of the query/key projections.
        ff_dim: hidden width of the position-wise feed-forward network.
        dropout: dropout rate applied to the attention and feed-forward outputs.

    Returns:
        Tensor with the same shape as ``inputs``; both residual additions
        require the sub-layer outputs to match the input shape.
    """
    model_width = inputs.shape[-1]

    # Sub-layer 1: self-attention -> dropout -> residual add -> LayerNorm.
    attended = MultiHeadAttention(num_heads=num_heads, key_dim=key_dim)(inputs, inputs)
    attended = Dropout(dropout)(attended)
    after_attention = LayerNormalization(epsilon=1e-6)(Add()([inputs, attended]))

    # Sub-layer 2: ReLU feed-forward projected back to the model width
    # -> dropout -> residual add -> LayerNorm.
    hidden = Dense(ff_dim, activation='relu')(after_attention)
    projected = Dropout(dropout)(Dense(model_width)(hidden))
    return LayerNormalization(epsilon=1e-6)(Add()([after_attention, projected]))


input_tensor = Input(shape=(100, 100, 3))
efficientnet_features = Flatten()(efficientnet(input_tensor))

# Project the flattened backbone feature map to a 512-d vector
# (L2-regularised), then BatchNorm + dropout.
combined_features = Dense(512, activation='relu', kernel_regularizer=l2(1e-4))(efficientnet_features)
combined_features = Dropout(0.2)(BatchNormalization()(combined_features))

# Split the vector into `sequence_length` pseudo-tokens of `embedding_dim`
# features each, so the Transformer encoder can attend across them. The width
# must divide evenly; otherwise Reshape would fail with a less clear error.
sequence_length = 16
feature_width = combined_features.shape[-1]
if feature_width % sequence_length != 0:
    raise ValueError(
        f"feature width {feature_width} is not divisible by sequence_length={sequence_length}"
    )
embedding_dim = feature_width // sequence_length
reshaped_features = Reshape((sequence_length, embedding_dim))(combined_features)


transformer_output = transformer_encoder(reshaped_features, num_heads=4, key_dim=embedding_dim, ff_dim=128, dropout=0.1)
pooled_output = GlobalAveragePooling1D()(transformer_output)


# Size the softmax head from the class sub-directories the training generator
# discovered, instead of a hard-coded class count, so the head always matches
# the dataset it is trained on.
num_classes = train_generator.num_classes
output_tensor = Dense(num_classes, activation='softmax')(pooled_output)


model = Model(inputs=input_tensor, outputs=output_tensor)

# Adam with default settings; label smoothing (0.1) softens the one-hot targets.
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1),
    metrics=['accuracy']
)

model.summary()


In [None]:
# Training with Callbacks
# Stop after 5 epochs without val_loss improvement (restoring the best
# weights), and halve the learning rate after 3 stagnant epochs.
early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
lr_scheduler = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3)

# NOTE(review): the test split is also used as validation data. Early stopping
# and LR scheduling are therefore tuned on the test set, so the reported test
# accuracy is optimistically biased. Consider holding out a validation split
# from the training directory instead.
# batch_size is intentionally not passed. The generator already yields its own
# batches, and Keras documents that batch_size must not be specified for
# generator inputs.
history = model.fit(
    train_generator,
    epochs=30,
    validation_data=test_generator,
    callbacks=[early_stopping, lr_scheduler],
)

In [None]:
import matplotlib.pyplot as plt


# Per-epoch curves recorded by model.fit.
accuracy = history.history['accuracy']
val_accuracy = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs = range(1, len(accuracy) + 1)

fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(12, 5))

# Left panel: accuracy.
acc_ax.plot(epochs, accuracy, label='Training Accuracy', marker='o')
acc_ax.plot(epochs, val_accuracy, label='Validation Accuracy', marker='o')
acc_ax.set_xlabel('Epochs')
acc_ax.set_ylabel('Accuracy')
acc_ax.legend()
acc_ax.grid(True)

# Right panel: loss.
loss_ax.plot(epochs, loss, label='Training Loss', marker='o', color='red')
loss_ax.plot(epochs, val_loss, label='Validation Loss', marker='o', color='orange')
loss_ax.set_xlabel('Epochs')
loss_ax.set_ylabel('Loss')
loss_ax.legend()
loss_ax.grid(True)

fig.tight_layout()

# Save only after both panels are drawn, so the PDF contains the full figure
# rather than just the accuracy panel.
fig.savefig('QoEOne.pdf')

# `files` is the Colab download helper. Import it here (it was never imported
# before) and download only when actually running on Colab.
try:
    from google.colab import files
except ImportError:
    print('google.colab is not available; QoEOne.pdf was saved to the working directory.')
else:
    files.download('QoEOne.pdf')

# Show the plots
plt.show()
