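**Initial Dataset Overview (High-Level)**

This block prints a quick summary of the dataset's structure. If `base_path` already contains 'train', 'val', and 'test' splits, it shows the number of classes and the total number of images in each split. It also reports per-class image counts for any class folders found directly under `base_path`.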

In [None]:
# --- Initial Dataset Overview (High-Level) ---
# Summarise whatever layout is found under base_path: first any existing
# 'train' / 'val' / 'test' split folders, then the class folders that sit
# directly under base_path.

print("\n--- Initial Dataset Overview (High-Level) ---")
split_totals = {}

for split in ['train', 'val', 'test']:
    split_path = os.path.join(base_path, split)
    if not os.path.isdir(split_path):
        continue  # this split folder does not exist; nothing to report

    # Every sub-directory of a split folder is treated as one class.
    classes = []
    for name in os.listdir(split_path):
        if os.path.isdir(os.path.join(split_path, name)):
            classes.append(name)

    # Count the image files (by extension, case-insensitive) of all classes.
    total_images = 0
    for class_name in classes:
        for f in os.listdir(os.path.join(split_path, class_name)):
            if f.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp')):
                total_images += 1

    split_totals[split] = {'classes': len(classes), 'images': total_images}

for split, data in split_totals.items():
    print(f"{split.upper()} Split:")
    print(f"  Number of classes: {data['classes']}")
    print(f"  Total images: {data['images']}")
    print()

# Second view: treat each directory directly under base_path as a class
# (the layout before any splitting has been done).
print("\n--- Overall Class Counts (if no splits yet) ---")
class_counts_overall = {}
if os.path.isdir(base_path):
    for class_name in os.listdir(base_path):
        class_path = os.path.join(base_path, class_name)
        if not os.path.isdir(class_path):
            continue
        image_count = sum(
            1
            for f in os.listdir(class_path)
            if f.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp'))
        )
        class_counts_overall[class_name] = image_count

if not class_counts_overall:
    print(f"No class folders found directly under '{base_path}'. Assuming splits will be created.")
else:
    min_count_overall = min(class_counts_overall.values())
    max_count_overall = max(class_counts_overall.values())
    # Re-order the mapping so the largest classes are listed first.
    class_counts_overall = dict(
        sorted(class_counts_overall.items(), key=lambda pair: pair[1], reverse=True)
    )
    print("Total number of classes (overall):", len(class_counts_overall))
    print("Total number of images (overall):", sum(class_counts_overall.values()))
    print("Class-wise image counts (overall):")
    for class_name, count in class_counts_overall.items():
        print(f"  {class_name}: {count} images")
    print(f"\nMinimum number of images in a class (overall): {min_count_overall}")
    print(f"Maximum number of images in a class (overall): {max_count_overall}")

**Dataset Splitting**

This block takes the raw images (organized by class in `base_path`) and splits them into training, validation, and test sets. It then copies the images into `output_dir` using the layout `output_dir/<split>/<class_name>/`.

In [None]:
# --- Dataset Splitting ---
# Split every class folder under base_path into train / validation / test
# subsets and copy the files to output_dir/<subset>/<class_name>/.

print(f"\n--- Starting Dataset Splitting ---")
print(f"Splitting data from '{base_path}' to '{output_dir}'...")

# Fractions of each class assigned to train, validation and test.
split_ratios = [0.7, 0.15, 0.15] # 70% train, 15% validation, 15% test

# os.listdir returns entries in arbitrary order, so both the class list and the
# per-class file list are sorted; otherwise random_state=42 would not yield the
# same split from one run (or machine) to the next.
for class_name in sorted(os.listdir(base_path)):
    class_path = os.path.join(base_path, class_name)
    if not os.path.isdir(class_path):
        continue

    # Create the target folders for every class directory, including ones that
    # turn out to hold no images.
    for subset in ['train', 'val', 'test']:
        os.makedirs(os.path.join(output_dir, subset, class_name), exist_ok=True)

    # Get all image files for the current class
    images = sorted(
        f for f in os.listdir(class_path)
        if f.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp'))
    )

    if not images:
        print(f"  Warning: No images found in class '{class_name}'. Skipping.")
        continue

    try:
        # Split into train+val and test sets
        train_val_images, test_images = train_test_split(
            images, test_size=split_ratios[2], random_state=42
        )
        # Split train+val into train and validation sets.
        # The test_size here is relative to train_val_images, not the original total.
        train_images, val_images = train_test_split(
            train_val_images,
            test_size=split_ratios[1] / (split_ratios[0] + split_ratios[1]),
            random_state=42,
        )
    except ValueError:
        # train_test_split raises ValueError when a subset would end up empty
        # (e.g. a class with only one or two images). Keep such a class in the
        # training set instead of aborting the whole run.
        print(f"  Warning: Class '{class_name}' has only {len(images)} image(s), too few to split. "
              f"Placing all of them in the training set.")
        train_images, val_images, test_images = images, [], []

    print(f"  Class '{class_name}': {len(train_images)} training, {len(val_images)} validation, {len(test_images)} test images.")

    # Copy images to their respective split directories
    for subset, subset_images in (('train', train_images), ('val', val_images), ('test', test_images)):
        for img in subset_images:
            src = os.path.join(class_path, img)
            dst = os.path.join(output_dir, subset, class_name, img)
            shutil.copy(src, dst)

print("Dataset splitting complete!")

**Image Counting After Splitting**

This block checks the result of the splitting step by counting the images in the newly created 'train', 'val', and 'test' folders inside `output_dir`.

In [None]:
# --- Image Counting After Splitting ---
# This section counts the images in the newly created split directories
# to confirm the splitting process was successful.

def count_images_in_split(split_path):
    """
    Return the number of image files found anywhere beneath ``split_path``.

    Files are matched case-insensitively on the extensions .png, .jpg, .jpeg,
    .gif and .bmp, and sub-directories are searched recursively. If
    ``split_path`` is not a directory, a warning is printed and 0 is returned.
    """
    if not os.path.isdir(split_path):
        print(f"  Warning: Directory '{split_path}' not found. Returning 0.")
        return 0

    recognised_suffixes = ('.png', '.jpg', '.jpeg', '.gif', '.bmp')
    return sum(
        1
        for _dirpath, _dirnames, filenames in os.walk(split_path)
        for filename in filenames
        if filename.lower().endswith(recognised_suffixes)
    )

print(f"\n--- Verifying Image Counts in Split Directories ---")

# Locate the three split folders under output_dir ...
train_folder_path, test_folder_path, val_folder_path = (
    os.path.join(output_dir, part) for part in ('train', 'test', 'val')
)

# ... and count the images in each one (train, then test, then val).
train_image_count, test_image_count, val_image_count = map(
    count_images_in_split,
    (train_folder_path, test_folder_path, val_folder_path),
)

total_images_in_splits = sum((train_image_count, test_image_count, val_image_count))

print(f"  Train folder ('{train_folder_path}') holds: {train_image_count} images")
print(f"  Test folder ('{test_folder_path}') holds: {test_image_count} images")
print(f"  Validation folder ('{val_folder_path}') holds: {val_image_count} images")
print(f"\n  Total images across all splits: {total_images_in_splits} images")

**DataFrame Creation and Initial Distribution**

This block builds Pandas DataFrames from the split dataset; they hold the image paths and labels used by the later steps. It also displays the class distribution of each split before any augmentation is applied for balancing.

In [None]:
# --- DataFrame Creation and Initial Distribution ---
# This section creates Pandas DataFrames from the split dataset,
# which are essential for managing image paths and labels.
# It also shows the initial class distribution before any augmentation.

def create_dataframe_from_folder(base_dir_for_df):
    """
    Build a DataFrame describing every image below ``base_dir_for_df``.

    The expected layout is ``base_dir_for_df/<split>/<class_name>/<image>``
    with ``<split>`` being 'train', 'test' or 'val'. A missing split folder is
    reported with a warning and skipped.

    The returned DataFrame has one row per image and four columns:
      - 'filepaths':  path relative to ``base_dir_for_df``
      - 'labels':     name of the class folder
      - 'image_path': ``base_dir_for_df`` joined with the relative path
      - 'data set':   name of the split folder
    """
    image_extensions = ('.png', '.jpg', '.jpeg', '.gif', '.bmp')
    records = []  # one (relative path, class label, split name) tuple per image

    for split_name in ['train', 'test', 'val']:
        split_path = os.path.join(base_dir_for_df, split_name)
        if not os.path.isdir(split_path):
            print(f"  Warning: Split folder '{split_path}' not found. Skipping this split for DataFrame creation.")
            continue

        for class_name in os.listdir(split_path):
            class_path = os.path.join(split_path, class_name)
            if not os.path.isdir(class_path):
                continue
            for file in os.listdir(class_path):
                if not file.lower().endswith(image_extensions):
                    continue
                # Stored relative to the base directory so the table stays
                # usable if the base directory is moved.
                relative_path = os.path.relpath(os.path.join(class_path, file), base_dir_for_df)
                records.append((relative_path, class_name, split_name))

    filepaths = [record[0] for record in records]
    return pd.DataFrame({
        'filepaths': filepaths,
        'labels': [record[1] for record in records],
        'image_path': [os.path.join(base_dir_for_df, fp) for fp in filepaths],
        'data set': [record[2] for record in records],
    })

print(f"\n--- Creating DataFrames from '{output_dir}' ---")
# Index every image currently stored in the split folders.
df = create_dataframe_from_folder(output_dir)

if df.empty:
    raise ValueError(f"No image files found in '{output_dir}'. Please check your output_dir and folder structure.")


def _largest_class_size(frame):
    """Return the size of the biggest class in ``frame`` (0 for an empty frame)."""
    if frame.empty:
        return 0
    return frame['labels'].value_counts().max()


# One independent copy per split, reflecting the counts before augmentation.
train_df_original = df[df['data set'] == 'train'].copy()
test_df_original = df[df['data set'] == 'test'].copy()
validation_df_original = df[df['data set'] == 'val'].copy()

print("\n--- Initial Data Distribution (Before Augmentation) ---")
print("Number of images per class in the training set (original):")
print(train_df_original['labels'].value_counts())

if train_df_original.empty:
    min_images_per_class_train = 0
else:
    min_images_per_class_train = train_df_original['labels'].value_counts().min()
print(f"\nMinimum number of images in any single class (original train): {min_images_per_class_train}")

# The largest class of each split is the size every other class is topped up to.
max_samples_per_class_train = _largest_class_size(train_df_original)
print(f"Target samples per class for balancing training set: {max_samples_per_class_train}")

max_samples_per_class_test = _largest_class_size(test_df_original)
print(f"Target samples per class for balancing test set: {max_samples_per_class_test}")

max_samples_per_class_val = _largest_class_size(validation_df_original)
print(f"Target samples per class for balancing validation set: {max_samples_per_class_val}")

**Custom Preprocessing Functions**

This block defines custom preprocessing functions that can be chained before the model-specific preprocessing. This allows for techniques such as random erasing and Gaussian blurring.

In [None]:
# --- Custom Preprocessing Functions ---
# These functions apply custom image manipulations (e.g., random erasing, blur)
# and can be chained before the model-specific preprocessing function.

def custom_image_preprocessing(image):
    """
    Apply random erasing and Gaussian blur to a single image.

    Args:
        image: array of shape (height, width, channels) or (height, width).
            uint8 input is used as-is. Any other dtype is converted to uint8:
            values are treated as [0, 1] and scaled by 255 only when the
            maximum is <= 1.0; otherwise they are taken to be in [0, 255]
            already and are only clipped.
            NOTE: the max <= 1.0 test is a heuristic. An almost black
            [0, 255] image whose brightest value is 1 would be mis-scaled.

    Returns:
        A new float32 array with values in [0, 255]. The input array is never
        modified.
    """
    image = np.asarray(image)
    if image.dtype == np.uint8:
        # Work on a copy so the erase step below cannot alter the caller's array.
        image = image.copy()
    else:
        # Keras' ImageDataGenerator hands float images in [0, 255] to its
        # preprocessing_function. Multiplying those by 255 unconditionally
        # (as was done before) overflows the uint8 cast and corrupts pixels.
        if image.size and image.max() <= 1.0:
            image = image * 255.0
        image = np.clip(image, 0, 255).astype(np.uint8)

    # 1. Random Erasing (with a probability)
    # This helps the model become more robust by forcing it to learn features
    # even when parts of the object are missing.
    if random.random() < 0.3: # 30% chance of applying random erasing
        img_height, img_width = image.shape[:2]
        # Erased rectangle spans between 1/10 and 1/3 of each dimension.
        erase_width = random.randint(img_width // 10, img_width // 3)
        erase_height = random.randint(img_height // 10, img_height // 3)
        x = random.randint(0, img_width - erase_width)
        y = random.randint(0, img_height - erase_height)
        # One random value per channel (however many channels there are),
        # broadcast over the whole rectangle.
        fill_colour = np.random.randint(0, 256, size=image.shape[2:], dtype=np.uint8)
        image[y:y + erase_height, x:x + erase_width] = fill_colour

    # 2. Gaussian Blur (with a probability)
    # Can help reduce high-frequency noise, but be careful not to remove useful details.
    if random.random() < 0.2: # 20% chance of applying Gaussian blur
        # Kernel size must be odd and positive
        ksize = random.choice([(3, 3), (5, 5)])
        # reshape: OpenCV may drop a trailing single-channel axis; restore the
        # input shape so callers always get back what they passed in.
        image = cv2.GaussianBlur(image, ksize, 0).reshape(image.shape)

    # 3. RGB to HSV Conversion (Optional - uncomment if needed)
    # If you convert to HSV, your model's input expectations will change.
    # You would typically NOT use MobileNetV2's preprocessing_function if you do this,
    # and your model's first layer should be configured for HSV input.
    # if random.random() < 0.1: # Example: 10% chance of converting to HSV
    #     image = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)

    return image.astype(np.float32) # Convert back to float for further processing

# Wrapper function to apply custom preprocessing THEN model-specific preprocessing
# This ensures custom augmentations happen before the model's expected input normalization.
def combined_preprocessing_function(image):
    """
    Run the custom augmentations, then MobileNetV2's input preprocessing.

    ``custom_image_preprocessing`` is applied first and returns a float32
    image; that result is handed straight to
    ``tf.keras.applications.mobilenet_v2.preprocess_input``, whose output is
    returned unchanged.
    """
    # NOTE(review): per the Keras docs, MobileNetV2's preprocess_input scales
    # pixel values to [-1, 1], so it should be fed [0, 255] data - confirm
    # that nothing rescales the image again after this function.
    return tf.keras.applications.mobilenet_v2.preprocess_input(
        custom_image_preprocessing(image)
    )

# Tell the user which helpers are now available and where they get used.
for _notice in (
    "\nCustom preprocessing functions defined: `custom_image_preprocessing` and `combined_preprocessing_function`.",
    "These will be integrated into the ImageDataGenerator's `preprocessing_function`.",
):
    print(_notice)

**Data Augmentation and Balancing (Saving to Disk)**

This block performs data augmentation and saves the generated images to disk so that every class in a split ends up with the same number of samples. This modifies your dataset in `output_dir`.

**WARNING:** Augmenting the validation and test sets can lead to an over-optimistic evaluation of your model's performance. For a robust evaluation, it is generally recommended not to heavily augment these sets. Consider removing the calls to `augment_and_save_split` for `validation_df_original` and `test_df_original`.

In [None]:
# --- Data Augmentation and Balancing (Saving to Disk) ---
# Under-represented classes are topped up with augmented copies that are
# written next to the originals, i.e. the folders in `output_dir` are modified.

for _banner in (
    "\nStarting data augmentation to balance the training, validation, and test sets by saving new images to existing folders...",
    "!!! WARNING: This will modify your image folders in 'output_dir' by adding new augmented images. !!!",
    "!!! Please ensure you have backed up your data before proceeding. !!!",
    "!!! Augmenting the validation and test sets may lead to an over-optimistic evaluation of your model's performance. !!!",
):
    print(_banner)

# Settings of the generator that is used only for writing augmented images to
# disk. No model-specific preprocessing_function is configured here; that step
# is left to the generators that later feed the model.
# To run custom_image_preprocessing before saving, add
#     preprocessing_function=custom_image_preprocessing
# to the options below.
_save_augmentation_options = dict(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    brightness_range=[0.8, 1.2],
    fill_mode='nearest',
)
save_datagen = ImageDataGenerator(**_save_augmentation_options)

def augment_and_save_split(dataframe_split, target_max_samples, split_name):
    """
    Top up every class of one split with augmented images saved to disk.

    Args:
        dataframe_split: DataFrame of one split; the columns 'labels',
            'filepaths' and 'image_path' are read.
        target_max_samples: number of images each class should reach.
        split_name: name of the split, used only in the progress messages.

    For each class holding fewer than ``target_max_samples`` rows, the missing
    number of images is drawn one at a time from ``save_datagen``; the
    generator writes each image (prefix 'aug') into the folder of the class's
    first listed file under ``output_dir``. Generation for a class stops at
    the first error, which is printed. Nothing is returned.
    """
    print(f"\n--- Augmenting {split_name} set ---")
    if dataframe_split.empty:
        print(f"  {split_name} DataFrame is empty. Skipping augmentation for this split.")
        return

    for class_name in dataframe_split['labels'].unique():
        class_subset = dataframe_split[dataframe_split['labels'] == class_name]
        current_count = len(class_subset)
        num_to_generate = target_max_samples - current_count

        if num_to_generate <= 0:
            print(f"  Class '{class_name}' ({split_name} set) is already balanced with {current_count} images. Skipping augmentation.")
            continue

        print(f"  Class '{class_name}' ({split_name} set): Current {current_count}, generating {num_to_generate} additional images.")

        if class_subset.empty:
            print(f"  No images found for class '{class_name}' in {split_name} set to augment.")
            continue

        # The new images go into the folder that holds the class's first
        # listed image (its relative path is resolved against output_dir).
        first_relative_path = class_subset['filepaths'].iloc[0]
        target_save_class_dir = os.path.join(output_dir, os.path.dirname(first_relative_path))
        os.makedirs(target_save_class_dir, exist_ok=True)

        # Single-image batches: every next() call below saves exactly one file.
        generator_for_saving = save_datagen.flow_from_dataframe(
            dataframe=class_subset,
            x_col='image_path',
            y_col='labels',
            target_size=(IMG_HEIGHT, IMG_WIDTH),
            batch_size=1,
            class_mode='categorical',
            shuffle=False,
            save_to_dir=target_save_class_dir,
            save_prefix='aug',
            seed=random.randint(0, 1000)  # a fresh seed is drawn for every class
        )

        generated_count = 0
        while generated_count < num_to_generate:
            try:
                next(generator_for_saving)
            except Exception as e:
                print(f"  Error generating image for class {class_name} in {split_name} set: {e}")
                break
            generated_count += 1
        print(f"  Generated {generated_count} images for class '{class_name}' in {split_name} set.")

    print(f"Finished generating augmented images for {split_name} set.")

# Perform augmentation for each split
# Balance the splits one after another: train, then validation, then test.
# Consider whether the validation and test sets should be augmented at all;
# for a robust evaluation they are often left unaugmented or only minimally
# augmented.
for _frame, _target, _split in (
    (train_df_original, max_samples_per_class_train, 'train'),
    (validation_df_original, max_samples_per_class_val, 'val'),
    (test_df_original, max_samples_per_class_test, 'test'),
):
    augment_and_save_split(_frame, _target, _split)

**Rebuilding DataFrames and Final Data Generators**

After the augmented images have been saved, this block re-scans `output_dir` to create updated DataFrames that include all new images. It then sets up the final `ImageDataGenerator` instances used during model training, which apply normalization and the custom preprocessing functions.

In [None]:
# --- Rebuilding DataFrames and Final Data Generators ---
# This section re-scans the dataset (now including augmented images)
# to create updated DataFrames and sets up the final ImageDataGenerator instances
# for efficient data feeding during model training.

print(f"\nRebuilding DataFrame by scanning '{output_dir}' to include all new augmented images (after augmentation)...")
df_balanced = create_dataframe_from_folder(output_dir)

# Separate DataFrames for each split after augmentation
train_df = df_balanced[df_balanced['data set'] == 'train'].copy()
test_df = df_balanced[df_balanced['data set'] == 'test'].copy()
validation_df = df_balanced[df_balanced['data set'] == 'val'].copy()

# Ensure 'image_path' points below output_dir for flow_from_dataframe
train_df['image_path'] = train_df['filepaths'].apply(lambda x: os.path.join(output_dir, x))
test_df['image_path'] = test_df['filepaths'].apply(lambda x: os.path.join(output_dir, x))
validation_df['image_path'] = validation_df['filepaths'].apply(lambda x: os.path.join(output_dir, x))

# Rename 'labels' column to 'label' for consistency with Keras flow_from_dataframe examples
train_df = train_df.rename(columns={'labels': 'label'})
test_df = test_df.rename(columns={'labels': 'label'})
validation_df = validation_df.rename(columns={'labels': 'label'})

print(f"\n--- Balanced Data Distribution (After Augmentation) ---")
print(f"Balanced Train DataFrame shape: {train_df.shape}")
print("Balanced Train DataFrame class distribution:")
print(train_df['label'].value_counts())

print(f"\nBalanced Test DataFrame shape: {test_df.shape}")
print("Balanced Test DataFrame class distribution:")
print(test_df['label'].value_counts())

print(f"\nBalanced Validation DataFrame shape: {validation_df.shape}")
print("Balanced Validation DataFrame class distribution:")
print(validation_df['label'].value_counts())

# --- Calculate Class Weights for Training (Optional but Recommended for Imbalance) ---
# This helps the model pay more attention to underrepresented classes during training.
# NOTE: this dictionary is keyed by class NAME. Keras' model.fit expects integer
# class indices; see `class_weights_by_index`, built once the training
# generator (and with it the name -> index mapping) exists.
if not train_df.empty:
    class_labels_unique = np.unique(train_df['label'])
    class_weights_array = class_weight.compute_class_weight(
        class_weight='balanced',
        classes=class_labels_unique,
        y=train_df['label']
    )
    class_weights_dict = dict(zip(class_labels_unique, class_weights_array))
    print("\nCalculated Class Weights (for model.fit):", class_weights_dict)
else:
    class_weights_dict = {}
    print("\nTrain DataFrame is empty, cannot calculate class weights.")


# --- Final ImageDataGenerators for Model Training ---
# These generators will feed data to your CNN model.
#
# Normalisation is done by MobileNetV2's preprocess_input alone, which maps
# [0, 255] pixels to [-1, 1]:
#   * no `rescale`: ImageDataGenerator applies `rescale` AFTER
#     `preprocessing_function`, so rescale=1./255 would shrink the already
#     normalised values to roughly [-1/255, 1/255];
#   * no `featurewise_center` / `featurewise_std_normalization`: they only take
#     effect after ImageDataGenerator.fit(...) has been called on sample data,
#     which this notebook never does (Keras merely warns and skips them).

# Training Data Generator (with full augmentation and preprocessing)
train_datagen = ImageDataGenerator(
    # Augmentations
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    vertical_flip=True,
    brightness_range=[0.8, 1.2],
    channel_shift_range=0.1,
    fill_mode='nearest',

    # Custom and Model-specific Preprocessing
    preprocessing_function=combined_preprocessing_function # Applies custom + MobileNetV2 preprocessing
)

# Validation Data Generator: model-specific preprocessing only.
# combined_preprocessing_function is deliberately NOT used here because it adds
# random erasing / blur, which would make evaluation noisy and non-repeatable.
val_datagen = ImageDataGenerator(
    preprocessing_function=tf.keras.applications.mobilenet_v2.preprocess_input
)

# Test Data Generator: model-specific preprocessing only (same reasoning).
test_datagen = ImageDataGenerator(
    preprocessing_function=tf.keras.applications.mobilenet_v2.preprocess_input
)

# Create data generators from DataFrames
try:
    train_generator = train_datagen.flow_from_dataframe(
        dataframe=train_df,
        x_col='image_path',
        y_col='label',
        target_size=(IMG_HEIGHT, IMG_WIDTH),
        batch_size=BATCH_SIZE,
        class_mode='categorical',
        shuffle=True
    )

    val_generator = val_datagen.flow_from_dataframe(
        dataframe=validation_df,
        x_col='image_path',
        y_col='label',
        target_size=(IMG_HEIGHT, IMG_WIDTH),
        batch_size=BATCH_SIZE,
        class_mode='categorical',
        shuffle=False # Do not shuffle validation data
    )

    test_generator = test_datagen.flow_from_dataframe(
        dataframe=test_df,
        x_col='image_path',
        y_col='label',
        target_size=(IMG_HEIGHT, IMG_WIDTH),
        batch_size=BATCH_SIZE,
        class_mode='categorical',
        shuffle=False # Do not shuffle test data
    )
except Exception as e:
    logging.error(f"Error loading data generators: {e}")
    raise

print("\nData Generators created successfully!")
print("Class indices from train_generator (useful for mapping labels to integers):", train_generator.class_indices)

# Same weights as class_weights_dict, re-keyed by the integer class index the
# generator assigned to each label; this is the form model.fit(class_weight=...) takes.
class_weights_by_index = {
    train_generator.class_indices[name]: weight
    for name, weight in class_weights_dict.items()
    if name in train_generator.class_indices
}
print("Class weights keyed by class index (pass as class_weight to model.fit):", class_weights_by_index)

**Training Callbacks**

This block defines the Keras callbacks that manage the training process (early stopping and learning-rate reduction) and a custom callback that plots the training metrics live, epoch by epoch.

In [None]:
# --- Training Callbacks ---
# Standard Keras callbacks that steer the training run. Both watch the
# validation loss ('val_loss', lower is better).

# Stop once val_loss has not improved for 8 epochs and roll the model back to
# the best weights seen.
early_stopping = EarlyStopping(
    monitor='val_loss',
    mode='min',
    patience=8,
    restore_best_weights=True,
    verbose=1,
)

# Halve the learning rate after 2 epochs without val_loss improvement, never
# going below 1e-7.
reduce_lr = ReduceLROnPlateau(
    monitor='val_loss',
    mode='min',
    factor=0.5,
    patience=2,
    min_lr=1e-7,
    verbose=1,
)

# --- Custom Callback for Live Plotting ---
class LivePlotCallback(tf.keras.callbacks.Callback):
    """
    Keras callback that plots loss, accuracy and learning rate live,
    updating a three-panel matplotlib figure at the end of every epoch.

    Args:
        epochs: planned number of epochs; used as the fixed upper limit of
            the x-axis of all three panels.

    Attributes:
        epoch_history: 1-based epoch numbers recorded so far.
        history: per-epoch values of 'loss', 'val_loss', 'accuracy',
            'val_accuracy' (taken from ``logs``; ``None`` when a key is
            missing) and 'lr' (read from the model's optimizer).
    """
    def __init__(self, epochs):
        super().__init__()
        self.epochs = epochs
        self.epoch_history = []
        self.history = {'loss': [], 'val_loss': [], 'accuracy': [], 'val_accuracy': [], 'lr': []}
        plt.ion() # Turn on interactive mode for live plotting

        # Setup the plot figures and axes
        self.fig, (self.ax1, self.ax2, self.ax3) = plt.subplots(3, 1, figsize=(10, 12))
        self.fig.suptitle('Real-time Training Metrics', fontsize=16)

        # Plot 1: Loss
        self.line_loss, = self.ax1.plot([], [], 'r-', label='Loss')
        self.line_val_loss, = self.ax1.plot([], [], 'b-', label='Val Loss')
        self.ax1.set_ylabel('Loss')
        self.ax1.legend()
        self.ax1.grid(True)
        self.ax1.set_xlim(0, epochs)
        self.ax1.set_ylim(0, 5) # Starting range only; rescaled to the data each epoch
        # set_ylim() switches y autoscaling off. Switch it back on, otherwise
        # the relim()/autoscale_view() calls in on_epoch_end have no effect and
        # the panel stays pinned to (0, 5).
        self.ax1.set_autoscaley_on(True)

        # Plot 2: Accuracy
        self.line_acc, = self.ax2.plot([], [], 'r-', label='Accuracy')
        self.line_val_acc, = self.ax2.plot([], [], 'b-', label='Val Accuracy')
        self.ax2.set_ylabel('Accuracy')
        self.ax2.legend()
        self.ax2.grid(True)
        self.ax2.set_xlim(0, epochs)
        self.ax2.set_ylim(0, 1) # Accuracy is always 0-1, so this range stays fixed

        # Plot 3: Learning Rate
        self.line_lr, = self.ax3.plot([], [], 'g-', label='Learning Rate')
        self.ax3.set_xlabel('Epoch')
        self.ax3.set_ylabel('Learning Rate')
        self.ax3.legend()
        self.ax3.grid(True)
        self.ax3.set_xlim(0, epochs)
        self.ax3.set_ylim(1e-7, 1e-3) # Starting range only; rescaled to the data each epoch
        self.ax3.set_yscale('log') # Log scale is often useful for LR plots
        self.ax3.set_autoscaley_on(True) # same reason as for the loss panel

        self.fig.tight_layout(rect=[0, 0.03, 1, 0.95]) # Adjust layout to prevent suptitle overlap
        self.fig.show()

    def on_epoch_end(self, epoch, logs=None):
        """Record this epoch's metrics and redraw all three panels."""
        logs = logs or {}
        self.epoch_history.append(epoch + 1)
        self.history['loss'].append(logs.get('loss'))
        self.history['val_loss'].append(logs.get('val_loss'))
        self.history['accuracy'].append(logs.get('accuracy'))
        self.history['val_accuracy'].append(logs.get('val_accuracy'))
        # Get current learning rate from the optimizer (stored as a plain float)
        self.history['lr'].append(float(self.model.optimizer.learning_rate.numpy()))

        # Push the new data into the existing line artists
        self.line_loss.set_data(self.epoch_history, self.history['loss'])
        self.line_val_loss.set_data(self.epoch_history, self.history['val_loss'])
        self.line_acc.set_data(self.epoch_history, self.history['accuracy'])
        self.line_val_acc.set_data(self.epoch_history, self.history['val_accuracy'])
        self.line_lr.set_data(self.epoch_history, self.history['lr'])

        # Fit the y-range of the loss and learning-rate panels to the data.
        # The x-range (0..epochs) and the accuracy range (0..1) stay fixed.
        for axis in (self.ax1, self.ax3):
            axis.relim()
            axis.autoscale_view(scalex=False, scaley=True)

        # Redraw the canvas and flush events to update the plot
        self.fig.canvas.draw()
        self.fig.canvas.flush_events()
        plt.pause(0.01) # Small pause to allow plot to update

    def on_train_end(self, logs=None):
        """Turns off interactive mode and keeps the final plot open at the end of training."""
        plt.ioff() # Turn off interactive mode
        plt.show() # Keep the final plot open

# List of callbacks to be used during model training
# Collect everything that is handed to model.fit(callbacks=...); the live plot
# is sized for the planned number of epochs.
live_plot_callback = LivePlotCallback(epochs=EPOCHS)
callbacks = [early_stopping, reduce_lr, live_plot_callback]

for _message in (
    "\nTraining callbacks defined: EarlyStopping, ReduceLROnPlateau, and LivePlotCallback.",
    "These will be passed to your model.fit() method.",
):
    print(_message)