In [None]:
#Import libraries
import os

import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.applications.resnet50 import preprocess_input as resnet50_preprocess_input
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator

In [None]:
#Base training directory path: flow_from_directory expects one sub-directory per class.
# Set the DATA_DIR environment variable to run the notebook without editing this cell.
base_dir = os.environ.get("DATA_DIR", "[INSERT PATH]")

In [None]:
#Load the whole image dataset from disk and split it into train/test arrays
def load_and_prepare_data(path, image_size=(64, 64), batch_size=10000,
                          test_size=0.3, random_state=42,
                          preprocessing_function=resnet50_preprocess_input):
    """Load every image under ``path`` and split it into train/test arrays.

    ``path`` must follow the Keras ``flow_from_directory`` layout: one
    sub-directory per class. Each image is resized to ``image_size`` and passed
    through ``preprocessing_function``. By default that is ResNet50's own
    ``preprocess_input``, the input normalisation the ImageNet weights expect.

    Args:
        path: Root directory of the image dataset.
        image_size: (height, width) that every image is resized to.
        batch_size: Number of images decoded per chunk while reading from disk.
            Every chunk is read, so this bounds the size of a single read, not
            the size of the dataset. The whole dataset is held in memory.
        test_size: Fraction of the images held out for the test split.
        random_state: Seed for both the directory shuffle and the split, so a
            fresh kernel reproduces the same split.
        preprocessing_function: Per-image transform applied by
            ``ImageDataGenerator``. Pass ``None`` to keep raw 0-255 pixels.

    Returns:
        Tuple ``(x_train, x_test, y_train, y_test)``. Labels use
        ``flow_from_directory``'s default 'categorical' (one-hot) encoding.

    Raises:
        ValueError: If no images are found under ``path``.
    """
    batches = ImageDataGenerator(preprocessing_function=preprocessing_function).flow_from_directory(
        directory=path,
        target_size=image_size,
        batch_size=batch_size,
        seed=random_state,  # shuffle is on by default; seed it for reproducibility
    )
    if batches.samples == 0:
        raise ValueError(f"No images found under {path!r}; expected one sub-directory per class.")

    # Read every chunk. A single next() call would return only the first
    # batch_size images and silently drop the rest of the dataset.
    chunks = [batches[i] for i in range(len(batches))]
    data = np.concatenate([images for images, _ in chunks])
    labels = np.concatenate([onehot for _, onehot in chunks])

    x_train, x_test, y_train, y_test = train_test_split(
        data, labels, test_size=test_size, random_state=random_state
    )
    return x_train, x_test, y_train, y_test

In [None]:
#Define ResNet50 model
def build_model(image_size, num_classes):
    """Build and compile a classifier on top of a frozen ImageNet ResNet50.

    Args:
        image_size: (height, width) of the input images. Three colour channels
            are assumed.
        num_classes: Number of units in the softmax output layer.

    Returns:
        A compiled ``Sequential`` model: frozen ResNet50 -> Flatten ->
        Dense(256, relu) -> Dense(num_classes, softmax). It is compiled with
        Adam, categorical cross-entropy and an accuracy metric.
    """
    height, width = image_size[0], image_size[1]
    backbone = ResNet50(weights='imagenet', include_top=False, input_shape=(height, width, 3))

    # Freeze every pretrained layer so only the new dense head is trained.
    for backbone_layer in backbone.layers:
        backbone_layer.trainable = False

    classifier = Sequential([
        backbone,
        Flatten(),
        Dense(256, activation='relu'),
        Dense(num_classes, activation='softmax'),
    ])
    classifier.compile(optimizer=Adam(), loss='categorical_crossentropy', metrics=['accuracy'])
    return classifier

In [None]:
#Train CNN; Keras records per-epoch metrics in the returned History
def train_model(model, x_train, y_train, x_test, y_test, epochs=20, **fit_kwargs):
    """Fit ``model`` on the training split, validating on the test split each epoch.

    Args:
        model: A compiled Keras model, or any object exposing a compatible ``fit``.
        x_train, y_train: Training inputs and targets.
        x_test, y_test: Held-out data passed to ``fit`` as ``validation_data``.
        epochs: Number of training epochs. The default of 20 preserves the
            original behaviour.
        **fit_kwargs: Extra keyword arguments forwarded unchanged to
            ``model.fit`` (e.g. ``batch_size``, ``callbacks``, ``verbose``).

    Returns:
        Whatever ``model.fit`` returns: for Keras, a ``History`` object whose
        ``history`` dict holds per-epoch metrics.
    """
    history = model.fit(
        x_train,
        y_train,
        validation_data=(x_test, y_test),
        epochs=epochs,
        **fit_kwargs,
    )
    return history

In [None]:
#Load the datasets: a single call yields both the train and the held-out test split
(x_train, x_test,
 y_train, y_test) = load_and_prepare_data(path=base_dir)

In [None]:

#Build the model
# Derive the input size and class count from the loaded arrays so the network
# always matches the data. y_train is one-hot, so its width is the total number
# of classes even when a rare class happens to be missing from the training split.
# Counting unique argmax values there would undercount and break fit().
model = build_model(image_size=x_train.shape[1:3], num_classes=y_train.shape[1])


In [None]:

#Train the model; note the validation data used during training is the held-out test split
history = train_model(model, x_train=x_train, y_train=y_train, x_test=x_test, y_test=y_test)

#Report the validation accuracy reached after the final epoch
accuracy = history.history['val_accuracy'][-1]
print("Validation Accuracy: {:.4f}".format(accuracy))

In [None]:
#Plot training history
import matplotlib.pyplot as plt

# Function to plot training history
def plot_history(history):
    """Show training/validation accuracy and loss curves in two side-by-side panels.

    Args:
        history: Object whose ``history`` attribute maps 'accuracy',
            'val_accuracy', 'loss' and 'val_loss' to per-epoch values, such as
            the ``History`` returned by ``model.fit``.
    """
    curves = history.history
    _fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(14, 5))

    # (axes, train key, train label, validation key, validation label, title, y-axis label)
    panels = [
        (acc_ax, 'accuracy', 'Train Accuracy', 'val_accuracy', 'Validation Accuracy',
         'Model Accuracy', 'Accuracy'),
        (loss_ax, 'loss', 'Train Loss', 'val_loss', 'Validation Loss',
         'Model Loss', 'Loss'),
    ]
    for ax, train_key, train_label, val_key, val_label, title, y_label in panels:
        ax.plot(curves[train_key], label=train_label)
        ax.plot(curves[val_key], label=val_label)
        ax.set_title(title)
        ax.set_xlabel('Epochs')
        ax.set_ylabel(y_label)
        ax.legend()

    # Tighten the spacing between the panels, then render the figure.
    plt.tight_layout()
    plt.show()

# Render the learning curves recorded while training the model above.
plot_history(history=history)
