In [24]:
# !pip install keras_resnet

In [27]:
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import ResNet50

# Data augmentation
train_datagen = ImageDataGenerator(
    rescale=1.0/255,
    horizontal_flip=True,
    rotation_range=15,
    width_shift_range=0.1,
    height_shift_range=0.1
)

val_datagen = ImageDataGenerator(rescale=1.0/255)

train_loader = train_datagen.flow_from_directory('data/train', target_size=(224, 224), batch_size=32, class_mode='categorical')
val_loader = val_datagen.flow_from_directory('data/test', target_size=(224, 224), batch_size=32, class_mode='categorical')

# Load ResNet50 model without the top layer
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))

# Add a global spatial average pooling layer
x = layers.GlobalAveragePooling2D()(base_model.output)

# Add a fully connected layer
x = layers.Dense(512, activation='relu')(x)

# Add the output layer with 2 classes
predictions = layers.Dense(2, activation='softmax')(x)

# Define the model
model = models.Model(inputs=base_model.input, outputs=predictions)

# Optionally freeze the layers of the base model to only train the new FC layers
for layer in base_model.layers:
    layer.trainable = False

# Compile the model with SGD optimizer
optimizer = optimizers.SGD(learning_rate=0.001, momentum=0.9)
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])

# Early stopping and model checkpoint
early_stopping = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)
checkpoint = ModelCheckpoint('best_model.keras', monitor='val_loss', save_best_only=True)

# Train the model
model.fit(
    train_loader,
    epochs=50,
    validation_data=val_loader,
    callbacks=[early_stopping, checkpoint]
)


Found 320 images belonging to 2 classes.
Found 82 images belonging to 2 classes.
Epoch 1/50
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m20s[0m 1s/step - accuracy: 0.6053 - loss: 0.6776 - val_accuracy: 0.7805 - val_loss: 0.6496
Epoch 2/50
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 1s/step - accuracy: 0.6007 - loss: 0.7027 - val_accuracy: 0.5854 - val_loss: 0.6488
Epoch 3/50
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 1s/step - accuracy: 0.4827 - loss: 0.7292 - val_accuracy: 0.5854 - val_loss: 0.6585
Epoch 4/50
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 1s/step - accuracy: 0.5669 - loss: 0.6556 - val_accuracy: 0.6585 - val_loss: 0.5882
Epoch 5/50
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 1s/step - accuracy: 0.6125 - loss: 0.6270 - val_accuracy: 0.5854 - val_loss: 0.6745
Epoch 6/50
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 1s/step - accuracy: 0.5657 - loss: 0.7032 -

<keras.src.callbacks.history.History at 0x3536a2e50>