# In [None]:
import tensorflow as tf
from tensorflow.keras.layers import (
    Input, Conv2D, Concatenate, GlobalAveragePooling2D, Dense, BatchNormalization
)
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# REEM stack: a configurable chain of REEM layers.
def REEMStack(input_tensor, num_reem_layers=4):
    """Apply ``num_reem_layers`` REEM layers to ``input_tensor`` in sequence.

    One REEM layer is, in order:

    * a 64-filter 3x3 convolution with ReLU,
    * a 128-filter 5x5 convolution with ReLU,
    * ``HyAttentionBlock`` applied to that result.

    All convolutions use ``padding='same'``.

    Args:
        input_tensor: Keras tensor fed into the first REEM layer.
        num_reem_layers: Number of REEM layers to chain. A value of 0 (or
            less) applies no layers and returns ``input_tensor`` unchanged.

    Returns:
        The output tensor of the last REEM layer.
    """
    def _reem_layer(features):
        # Layers are created in the same order on every call, so Keras
        # auto-generated layer names stay deterministic.
        features = Conv2D(64, (3, 3), activation='relu', padding='same')(features)
        features = Conv2D(128, (5, 5), activation='relu', padding='same')(features)
        return HyAttentionBlock(features)

    out = input_tensor
    for _ in range(num_reem_layers):
        out = _reem_layer(out)
    return out

def HyAttentionBlock(input_tensor):
    """Sum a 1x1 projection of the input with a two-convolution branch.

    Both paths read ``input_tensor``:

    * shortcut: a 128-filter 1x1 convolution with ReLU;
    * main branch: a 64-filter 3x3 convolution, then a 128-filter 5x5
      convolution, both with ReLU.

    The two 128-channel results are added element-wise. Despite the name, no
    attention weighting is computed here.

    Args:
        input_tensor: 4-D Keras feature map (presumably channels-last, the
            Keras default; TODO confirm).

    Returns:
        A 128-channel tensor with the same spatial size as the input. Every
        convolution uses ``padding='same'`` and no ``strides`` argument.
    """
    # The creation order (1x1, 3x3, 5x5, Add) matches the original, so layer
    # names and weight-initialization order are unchanged.
    shortcut = Conv2D(128, (1, 1), activation='relu', padding='same')(input_tensor)

    branch = Conv2D(64, (3, 3), activation='relu', padding='same')(input_tensor)
    branch = Conv2D(128, (5, 5), activation='relu', padding='same')(branch)

    merge = tf.keras.layers.Add()
    return merge([branch, shortcut])

def HyFiNet(input_shape, num_classes, num_stacks=4, num_reem_layers=4):
    """Build the (uncompiled) HyFiNet image classifier.

    Architecture, in order:

    1. ``num_stacks`` REEM stacks chained one after another, each containing
       ``num_reem_layers`` REEM layers;
    2. a 128-filter 1x1 convolution with ReLU;
    3. global average pooling;
    4. a 256-unit ReLU dense layer followed by batch normalization;
    5. a softmax dense layer with ``num_classes`` units.

    NOTE(review): every convolution in the REEM stacks uses ``padding='same'``
    and there is no pooling or striding before ``GlobalAveragePooling2D``. All
    feature maps therefore stay at full input resolution, and activation memory
    grows with image height x width and with depth. Keep this in mind when
    choosing batch size or increasing ``num_stacks`` / ``num_reem_layers``.

    Args:
        input_shape: Shape of one input image without the batch axis, passed
            straight to ``Input(shape=...)``, e.g. ``(224, 224, 3)``.
        num_classes: Number of output classes (units in the softmax layer).
        num_stacks: Number of REEM stacks to chain. Defaults to 4, the
            previously hard-coded depth.
        num_reem_layers: REEM layers per stack, forwarded to ``REEMStack``.
            Defaults to 4, matching ``REEMStack``'s own default.

    Returns:
        A ``tf.keras.Model`` mapping images to class probabilities.
    """
    inputs = Input(shape=input_shape)

    # Chain the REEM stacks; with the defaults this is four stacks of four layers.
    x = inputs
    for _ in range(num_stacks):
        x = REEMStack(x, num_reem_layers=num_reem_layers)

    # Classification head
    x = Conv2D(128, (1, 1), activation='relu', padding='same')(x)
    x = GlobalAveragePooling2D()(x)
    x = Dense(256, activation='relu')(x)
    x = BatchNormalization()(x)

    # Output layer: one probability per class.
    outputs = Dense(num_classes, activation='softmax')(x)

    return Model(inputs=inputs, outputs=outputs)

# Image geometry and batch size, shared by the data pipeline and the model so
# the two cannot drift apart.
img_height = 224
img_width = 224
batch_size = 5
input_shape = (img_height, img_width, 3)  # RGB images at the generator's target size

# Define paths to your training and validation data
train_data_dir = '/content/drive/MyDrive/Color'
valid_data_dir = '/content/drive/MyDrive/Color'

# If training and validation point at the same folder, hold out a fraction of
# it for validation. Otherwise the "validation" metrics would be computed on
# the very images the model trains on. Separate folders are each used in full.
shared_data_dir = train_data_dir == valid_data_dir
validation_split = 0.2 if shared_data_dir else 0.0
train_subset = 'training' if shared_data_dir else None
valid_subset = 'validation' if shared_data_dir else None

# Data generators: augmentation is applied to training images only.
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
    validation_split=validation_split,
)

# The same validation_split as train_datagen, so 'training' and 'validation'
# subsets of a shared folder are complementary.
valid_datagen = ImageDataGenerator(rescale=1./255, validation_split=validation_split)

train_generator = train_datagen.flow_from_directory(
    train_data_dir,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='categorical',
    subset=train_subset,
)

valid_generator = valid_datagen.flow_from_directory(
    valid_data_dir,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='categorical',
    subset=valid_subset,
)

# The one-hot label columns must mean the same class in both generators.
if train_generator.class_indices != valid_generator.class_indices:
    raise ValueError(
        "Training and validation directories have different class folders: "
        f"{train_generator.class_indices} vs {valid_generator.class_indices}"
    )

# Size the softmax layer from the data instead of a hard-coded constant, so the
# categorical labels and the model output always have the same width.
num_classes = train_generator.num_classes

# Create the HyFiNet model and display its summary.
model = HyFiNet(input_shape, num_classes)
model.summary()

# Compile the model
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Train the model
epochs = 10  # Adjust the number of epochs as needed
history = model.fit(
    train_generator,
    epochs=epochs,
    validation_data=valid_generator,
    verbose=1
)

# Save the trained model
model.save('custom_hyfinet_model.h5')
