<a href="https://colab.research.google.com/github/Imran012x/Transfer-Models/blob/main/Hilsha_Final.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ‚úÖ Step 1: Colab-Setup


In [1]:
import os
from google.colab import drive

drive_path = '/content/drive'

if os.path.exists(drive_path) and os.path.ismount(drive_path):
    print("Google Drive is already connected ‚úÖ")
else:
    drive.mount(drive_path)
    print("Google Drive connection done ‚úÖ")


# # Upload a file
# uploaded = files.upload()
# # Get the file name
# file_name = list(uploaded.keys())[0]
# print(f"Uploaded file: {file_name}")


# import zipfile
# import os
# # with zipfile.ZipFile('/content/drive/MyDrive/Hilsha/data_fish_224_11k.zip', 'r') as zip_ref:
# #     zip_ref.extractall('')
# with zipfile.ZipFile('/content/drive/MyDrive/Hilsha/data_fish_org_8407.zip', 'r') as zip_ref:
#     zip_ref.extractall('')

Google Drive is already connected ‚úÖ


# ‚úÖ Step 2: Import & Config & Env Setup

In [2]:
#1. IMPORTS AND INITIAL SETUP
# ================================================================================================================================
# Purpose: Import all required libraries and set up warnings to suppress unnecessary messages.



!pip install pytorch-gradcam optuna captum -q  # Uncomment if running in a new environment


import sys
import numpy
import pandas
import seaborn as sns

print(f"python_version: {sys.version.split()[0]}")
print(f"numpy_version: {numpy.__version__}")
print(f"pandas_version: {pandas.__version__}")
print(f"seaborn_version: {sns.__version__}\n")




# ============================================================
# Standard Library
# ============================================================
import os
import sys
import gc
import time
import json
import zipfile
import logging
import random
import warnings
import traceback
import logging
import subprocess
import threading
import traceback
from pathlib import Path
from threading import Lock
import multiprocessing as mp
from itertools import combinations
from datetime import datetime, timedelta
from collections import Counter, defaultdict
from typing import Tuple, Dict, Any, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed

# ============================================================
# Data Handling & Utilities
# ============================================================
import numpy as np
import pandas as pd
from tqdm import tqdm

# ============================================================
# Visualization
# ============================================================
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import cv2

# ============================================================
# System & Resource Monitoring
# ============================================================
import psutil
import pynvml

# ============================================================
# Machine Learning
# ============================================================
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import (
    confusion_matrix, classification_report, f1_score, accuracy_score,
    precision_score, recall_score, roc_curve, auc
)
from sklearn.utils.class_weight import compute_class_weight
from sklearn.preprocessing import label_binarize

# Imbalanced data handling
from imblearn.over_sampling import SMOTE

# ============================================================
# Deep Learning - PyTorch
# ============================================================
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler, SubsetRandomSampler
import torchvision.models as models
import torchvision.transforms as transforms

# ============================================================
# Augmentation
# ============================================================
import albumentations as A
from albumentations.pytorch import ToTensorV2


# ============================================================
# Explainable AI (XAI)
# ============================================================

import torch.autograd as autograd
from captum.attr import LRP

# Optuna
import optuna
import optuna.logging



# ============================================================
# Hyperparameter Optimization
# ============================================================
try:
    OPTUNA_AVAILABLE = True
except ImportError:
    OPTUNA_AVAILABLE = False
    print("Warning: Optuna not available. Using default hyperparameters.")



#For DeprecationWarning / FutureWarning specifically:
warnings.simplefilter("ignore", category=DeprecationWarning)
warnings.simplefilter("ignore", category=FutureWarning)

# Hide all pip warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)


# ---
# 2. CONFIGURATION
# ================================================================================================================================
# Purpose: Define configuration settings and initialize the environment.

class Config:


    OUTPUT_DIR = '/content/drive/MyDrive/Hilsha'

    # Dataset parameters
    NUM_CLASSES = 5
    CLASS_NAMES = ['Ilish', 'Chandana', 'Sardin', 'Sardinella', 'Punctatus']
    INPUT_SIZE = 224

    # Training parameters
    BATCH_SIZE = 32 #Will Change Dynamically
    DATALOADER_NUM_WORKERS = 1 #Will Change Dynamically
    # Dynamically adjust batch size and workers
    EPOCHS = 40
    PIN_MEMORY = True
    USE_MIXED_PRECISION = True #True
    COMPILE_MODEL = True
    PATIENCE = 4
    LEARNING_RATE = 1e-5
    WEIGHT_DECAY = 1e-4

    # Hyperparameter tuning
    OPTUNA_TRIALS = 100
    OPTUNA_EPOCHS = 10

    # Models to train
    # MODELS = ['resnet50','efficientnet_b0','mobilenet_v3_large','vgg16', 'densenet121']
    MODELS = [
        'resnet50',
        'efficientnet_b0'
        # # 'mobilenet_v3_large',
        # 'vgg16',
        # 'densenet121',
        # 'inception_v3',
        # 'vit_b_16',
        # 'convnext_base',
        # 'regnet_y_32gf'
    ]


    # Ensemble methods
    ENSEMBLE_METHODS = ['simple_average', 'weighted_average', 'confidence_based', 'learnable_weighted']

    # Device
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    SEED = 42




def setup_environment():
    """Setup random seeds, directories, and dynamically adjust batch size and workers"""

    os.environ['PYTHONHASHSEED'] = str(Config.SEED)  # For hash seed reproducibility
    random.seed(Config.SEED)
    np.random.seed(Config.SEED)
    torch.manual_seed(Config.SEED)
    torch.cuda.manual_seed_all(Config.SEED)  # For multi-GPU if applicable
    #Guard for GPU determinism (optional, but helpful if you want exact reproducibility across runs):
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    warnings.filterwarnings("ignore")


    torch.manual_seed(Config.SEED)
    np.random.seed(Config.SEED)

    directories = [
        Config.OUTPUT_DIR,
        f"{Config.OUTPUT_DIR}/models",
        f"{Config.OUTPUT_DIR}/visualizations",
        f"{Config.OUTPUT_DIR}/reports",
        f"{Config.OUTPUT_DIR}/xai_visualizations"
    ]

    for directory in directories:
        Path(directory).mkdir(parents=True, exist_ok=True)#With exist_ok=True:Python will not raise an error if already exists.Or else raise a FileExistsError
        #& parents=True ‚Üí creates all missing parent directories in the path.

    # Ensure all output directories exist
    os.makedirs(f"{Config.OUTPUT_DIR}/best_model", exist_ok=True)
    os.makedirs(f"{Config.OUTPUT_DIR}/model_results", exist_ok=True)
    os.makedirs(f"{Config.OUTPUT_DIR}/kfold_results", exist_ok=True)
    os.makedirs(f"{Config.OUTPUT_DIR}/visualizations", exist_ok=True)

    print(f"Using device: {Config.DEVICE}")
    print(f"PyTorch version: {torch.__version__}")
    print(f"Dynamic BATCH_SIZE: {Config.BATCH_SIZE}, DATALOADER_NUM_WORKERS: {Config.DATALOADER_NUM_WORKERS}")
    print("-" * 70)


def worker_init_fn(worker_id):
    seed = Config.SEED + worker_id
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)


python_version: 3.12.11
numpy_version: 1.26.4
pandas_version: 2.2.2
seaborn_version: 0.13.2



# ‚úÖ Step 3: Pre-processing & Save

In [3]:

# # Check GPU availability
# print("GPU Available:", torch.cuda.is_available())
# print("GPU Name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "No GPU")


# # Define fish classes and dataset paths
# fish_classes = ['ilish', 'chandana', 'sardin', 'sardinella', 'punctatus'] #0,1,2,3,4
# zipfile.ZipFile('/content/drive/MyDrive/Hilsha/data_fish_org_8407.zip').extractall('/content/.hidden_fish')
# data_dir = '/content/.hidden_fish'

# image_limits = {
#     'ilish': 3000,
#     'chandana': 1185,
#     'sardin': 2899,
#     'sardinella': 370,
#     'punctatus': 953
# }

# # Settings
# total_images = sum(image_limits.values())
# batch_size = 100
# num_threads = 4


# # Output paths
# output_dir = '/content/drive/MyDrive/Hilsha'
# os.makedirs(output_dir, exist_ok=True)
# labels_file = os.path.join(output_dir, 'Y_labels.npy')
# xdata_file = os.path.join(output_dir, 'X_data.npy')

# save_lock = threading.Lock()  # for thread-safe writes -> Prevents race conditions when multiple threads write to the same list.

# # Function to gather image paths
# def get_image_paths(class_name, max_images):
#     path = os.path.join(data_dir, class_name)
#     files = sorted(os.listdir(path))
#     random.shuffle(files)
#     return [os.path.join(path, f) for f in files[:max_images]]

# # Load and preprocess batch
# def load_and_preprocess_batch(image_paths, start_idx, batch_size, class_idx):
#     end_idx = min(start_idx + batch_size, len(image_paths))
#     batch_paths = image_paths[start_idx:end_idx]
#     batch_images = []

#     for img_path in batch_paths:
#         img = Image.open(img_path).resize((224, 224)).convert('RGB')
#         img_tensor = torch.tensor(np.array(img), dtype=torch.uint8).permute(2, 0, 1)  # C x H x W
#         batch_images.append(img_tensor)

#     batch_tensor = torch.stack(batch_images)  # B x C x H x W
#     batch_labels = np.full((len(batch_images),), class_idx, dtype=np.int32)
#     return batch_tensor, batch_labels

# # Process one batch and return tensors & labels (no file saving)
# def process_batch(image_paths, start_idx, batch_size, class_idx):
#     return load_and_preprocess_batch(image_paths, start_idx, batch_size, class_idx)

# def preprocess_and_save_all(overwrite=True):
#     if os.path.exists(labels_file) and os.path.exists(xdata_file) and not overwrite:
#         print("Preprocessed data already exists. Set overwrite=True to reprocess.")
#         return

#     all_images = []
#     all_labels = []
#     processed_count = 0

#     for idx, class_name in enumerate(fish_classes):
#         print(f"\nProcessing class: {class_name}")
#         image_paths = get_image_paths(class_name, image_limits[class_name])
#         total_batches = (len(image_paths) + batch_size - 1) // batch_size
#         #It ensures ceiling division ‚Äî rounding up, not down.
#         # Normal division: 103 / 20 = 5.15 ‚Üí floor division // 20 = 5 (‚ùå missing last 3 images)
#         # This trick: (103 + 20 - 1) // 20 = 122 // 20 = 6 ‚úÖ

#         with ThreadPoolExecutor(max_workers=num_threads) as executor:
#             futures = []
#             for start in range(0, len(image_paths), batch_size):
#                 futures.append(executor.submit(process_batch, image_paths, start, batch_size, idx))

#             for future in tqdm(as_completed(futures), total=total_batches, desc=class_name):#taqaddum (ÿ™ŸÇÿØŸëŸÖ) ‚Äì Arabic for "progress".
#                 # futures: List of tasks (from ThreadPoolExecutor or ProcessPoolExecutor).
#                 # as_completed(futures): Yields each future as it finishes (not in order).

#                 batch_tensor, batch_labels = future.result()
#                 with save_lock: #Locks this section so that only one thread can update the shared lists safely.
#                     all_images.append(batch_tensor)
#                     all_labels.append(batch_labels)
#                     processed_count += batch_tensor.size(0)
#                     print(f"Processed batch with {batch_tensor.size(0)} images, total processed: {processed_count}/{total_images}")
#                 gc.collect()

#     # Combine all tensors and labels
#     X = torch.cat(all_images, dim=0).numpy()
#     Y = np.concatenate(all_labels, axis=0)

#     # Save final arrays
#     np.save(xdata_file, X, allow_pickle=False)#Malicious .npy -> import os;os.system("rm -rf /")  # ‚Üê Dangerous command
#     np.save(labels_file, Y, allow_pickle=False)

#     print(f"\n‚úÖ Done! Saved {processed_count} images in {xdata_file}")
#     print(f"X_data shape: {X.shape}, Y_labels shape: {Y.shape}")

#     if processed_count != total_images:
#         raise ValueError(f"Expected {total_images} images, but processed {processed_count}")

# # Run preprocessing and save directly to X_data.npy and Y_labels.npy
# preprocess_and_save_all(overwrite=True)










class FishDataset(Dataset):
    def __init__(self, images, labels, transform=None):

        self.images = self._preprocess_images(images)
        self.labels = labels.astype(np.int64)
        self.transform = transform #Here means: Medium,Heavy or Any

    def _preprocess_images(self, images):
        """Preprocess images to ensure proper format and normalization"""
        if images.max() > 1.5: #üëâ The threshold 1.5 is just a safe cutoff to distinguish between the two cases.
            #Because some normalized images can have values slightly above 1.0 (e.g., after augmentations, rounding, or scaling bugs).
            images = images.astype(np.float32) / 255.0

        if len(images.shape) == 4 and images.shape[1] == 3: #If input is (batch, channels, height, width) ‚Üí convert to (batch, height, width, channels) (common for TensorFlow).
            images = np.transpose(images, (0, 2, 3, 1))
        return images.astype(np.float32)

    def __len__(self):
        """Return the total number of samples in the dataset"""
        return len(self.images)




    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]

        if self.transform:  #Applies an Albumentations transform pipeline (it returns a dict, so you take ['image']).
            image = self.transform(image=image)['image'] #
        else:
            image = torch.from_numpy(image).permute(2, 0, 1)
        #With transform ‚Üí advanced augmentations.
        #Without transform ‚Üí just convert to PyTorch format.


        # Convert label to plain Python int to avoid CUDA tensor creation in workers.That wastes memory and slows down training.
        if isinstance(label, np.ndarray):
            label = int(label.item())
        elif hasattr(label, 'item'):
            label = int(label.item())
        else:
            label = int(label)


        return image, label  # Plain Python int, not torch.tensor
        # return image, torch.tensor(int(label), dtype=torch.long)




    # def __getitem__(self, idx):
    #     image = self.images[idx]  # H x W x C
    #     label = self.labels[idx]

    #     # Ensure image has 3 channels
    #     if image.ndim == 2:  # grayscale H x W
    #         image = np.stack([image]*3, axis=-1)
    #     elif image.shape[-1] == 4:  # RGBA
    #         image = image[:, :, :3]

    #     # Apply Albumentations transform if any
    #     if self.transform:
    #         image = self.transform(image=image)['image']  # may already be tensor

    #     # Convert to PyTorch tensor C x H x W if it's a numpy array
    #     if isinstance(image, np.ndarray):
    #         image = torch.from_numpy(image).permute(2, 0, 1).float()
    #     elif isinstance(image, torch.Tensor) and image.ndim == 3 and image.shape[0] != 3:
    #         # If transform returns H x W x C tensor, permute to C x H x W
    #         image = image.permute(2, 0, 1).float()
    #     # else assume it's already C x H x W

    #     # Convert label to tensor
    #     label = int(label) if not isinstance(label, torch.Tensor) else label.long()

    #     return image, torch.tensor(label, dtype=torch.long)




        # class MyClass:
        #     def greet(self):
        #         print("Hello!")
        # obj = MyClass()
        # print(hasattr(obj, 'greet'))   # True, because obj has a method greet
        # print(hasattr(obj, 'name'))    # False, no attribute called name
        # # Using hasattr with .item()
        # import torch
        # x = torch.tensor(5)  # scalar tensor
        # print(hasattr(x, 'item'))      # True
        # print(x.item())                # 5

        # return image, torch.tensor(label, dtype=torch.long)  # <-- ensure label is tensor




class DataManager:
    @staticmethod  #In Python, @staticmethod is used to define a method that belongs to a class but doesn‚Äôt access self or cls.

    # class DataManager:
    # staticmethod
    # def greet(name):
    #     return f"Hello, {name}!"
    # # Call without creating an instance
    # print(DataManager.greet("Imran"))  # Output: Hello, Imran!
    # # Call with an instance
    # dm = DataManager()
    # print(dm.greet("Imran"))           # Output: Hello, Imran!

    # class MyClass:
    #     count = 0

    #     staticmethod
    #     def greet(name):
    #         return f"Hello, {name}!"

    #     classmethod
    #     def increment_count(cls):
    #         cls.count += 1
    #         return cls.count

    # # Static method
    # print(MyClass.greet("Imran"))      # Hello, Imran!
    # # Class method
    # print(MyClass.increment_count())   # 1
    # print(MyClass.increment_count())   # 2
    #Static method ‚Üí independent of class/instance.
    #Class method ‚Üí works with the class itself (cls), can modify class variables.


    def get_transforms(is_training=True, augmentation_strength='medium'):
        """Get data transforms with configurable augmentation strength"""
        if is_training:
            if augmentation_strength == 'light':
                return A.Compose([
                    A.Resize(Config.INPUT_SIZE, Config.INPUT_SIZE),
                    A.HorizontalFlip(p=0.3),
                    A.RandomRotate90(p=0.3),
                    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
                    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                    ToTensorV2()
                ])
            elif augmentation_strength == 'heavy':
                return A.Compose([
                    A.Resize(Config.INPUT_SIZE, Config.INPUT_SIZE),
                    A.HorizontalFlip(p=0.7),
                    A.VerticalFlip(p=0.5),
                    A.RandomRotate90(p=0.7),
                    # A.ShiftScaleRotate(shift_limit=0.3, scale_limit=0.3, rotate_limit=45, p=0.8),
                    # A.RandomBrightnessContrast(brightness_limit=0.4, contrast_limit=0.4, p=0.8),
                    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, p=0.5),
                    A.RandomBrightnessContrast(brightness_limit=0.15, contrast_limit=0.15, p=0.5),
                    # A.HueSaturationValue(hue_shift_limit=30, sat_shift_limit=40, val_shift_limit=30, p=0.7),
                    A.GaussianBlur(blur_limit=(3, 9), p=0.5),
                    A.GaussNoise(var_limit=(10.0, 80.0), p=0.4),
                    A.CoarseDropout(max_holes=12, max_height=25, max_width=25, p=0.5),
                    A.ElasticTransform(p=0.3),
                    A.GridDistortion(p=0.3),
                    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),

                    A.RandomFog(fog_coef_lower=0.1, fog_coef_upper=0.3, p=0.3),
                    A.RandomRain(blur_value=3, p=0.2),
                    A.ColorJitter(hue=0.1, p=0.5),

                    ToTensorV2()
                ])
            else:  # medium
                return A.Compose([
                    A.Resize(Config.INPUT_SIZE, Config.INPUT_SIZE),
                    A.HorizontalFlip(p=0.5),
                    A.VerticalFlip(p=0.3),
                    A.RandomRotate90(p=0.5),
                    # A.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=30, p=0.7),
                    # A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.7),
                    A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=10, p=0.1),
                    A.RandomBrightnessContrast(brightness_limit=0.15, contrast_limit=0.15, p=0.5),
                    # A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.6),
                    A.GaussianBlur(blur_limit=(3, 7), p=0.4),
                    A.GaussNoise(var_limit=(10.0, 50.0), p=0.3),
                    A.CoarseDropout(max_holes=8, max_height=20, max_width=20, p=0.4),
                    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                    ToTensorV2()
                ])
        else:
            return A.Compose([
                A.Resize(Config.INPUT_SIZE, Config.INPUT_SIZE),
                A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                ToTensorV2()
            ])


    @staticmethod
    def load_and_balance_data():
        """Load data and apply SMOTE"""
        print("Loading and preprocessing data...")

        # # Check GPU availability
        # print("GPU Available:", torch.cuda.is_available())
        # print("GPU Name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "No GPU")


        # # Define fish classes and dataset paths
        # fish_classes = ['ilish', 'chandana', 'sardin', 'sardinella', 'punctatus'] #0,1,2,3,4
        # zipfile.ZipFile('/content/drive/MyDrive/Hilsha/data_fish_org_8407.zip').extractall('/content/.hidden_fish')
        # data_dir = '/content/.hidden_fish'

        # image_limits = {
        #     'ilish': 3000,
        #     'chandana': 1185,
        #     'sardin': 2899,
        #     'sardinella': 370,
        #     'punctatus': 953
        # }

        # # Settings
        # total_images = sum(image_limits.values())
        # batch_size = 100
        # num_threads = 4


        # # Output paths
        # output_dir = '/content/drive/MyDrive/Hilsha'
        # os.makedirs(output_dir, exist_ok=True)
        # labels_file = os.path.join(output_dir, 'Y_labels.npy')
        # xdata_file = os.path.join(output_dir, 'X_data.npy')

        # save_lock = threading.Lock()  # for thread-safe writes -> Prevents race conditions when multiple threads write to the same list.

        # # Function to gather image paths
        # def get_image_paths(class_name, max_images):
        #     path = os.path.join(data_dir, class_name)
        #     files = sorted(os.listdir(path))
        #     random.shuffle(files)
        #     return [os.path.join(path, f) for f in files[:max_images]]

        # # Load and preprocess batch
        # def load_and_preprocess_batch(image_paths, start_idx, batch_size, class_idx):
        #     end_idx = min(start_idx + batch_size, len(image_paths))
        #     batch_paths = image_paths[start_idx:end_idx]
        #     batch_images = []

        #     for img_path in batch_paths:
        #         img = Image.open(img_path).resize((224, 224)).convert('RGB')
        #         img_tensor = torch.tensor(np.array(img), dtype=torch.uint8).permute(2, 0, 1)  # C x H x W
        #         batch_images.append(img_tensor)

        #     batch_tensor = torch.stack(batch_images)  # B x C x H x W
        #     batch_labels = np.full((len(batch_images),), class_idx, dtype=np.int32)
        #     return batch_tensor, batch_labels

        # # Process one batch and return tensors & labels (no file saving)
        # def process_batch(image_paths, start_idx, batch_size, class_idx):
        #     return load_and_preprocess_batch(image_paths, start_idx, batch_size, class_idx)

        # def preprocess_and_save_all(overwrite=True):
        #     if os.path.exists(labels_file) and os.path.exists(xdata_file) and not overwrite:
        #         print("Preprocessed data already exists. Set overwrite=True to reprocess.")
        #         return

        #     all_images = []
        #     all_labels = []
        #     processed_count = 0

        #     for idx, class_name in enumerate(fish_classes):
        #         print(f"\nProcessing class: {class_name}")
        #         image_paths = get_image_paths(class_name, image_limits[class_name])
        #         total_batches = (len(image_paths) + batch_size - 1) // batch_size
        #         #It ensures ceiling division ‚Äî rounding up, not down.
        #         # Normal division: 103 / 20 = 5.15 ‚Üí floor division // 20 = 5 (‚ùå missing last 3 images)
        #         # This trick: (103 + 20 - 1) // 20 = 122 // 20 = 6 ‚úÖ

        #         with ThreadPoolExecutor(max_workers=num_threads) as executor:
        #             futures = []
        #             for start in range(0, len(image_paths), batch_size):
        #                 futures.append(executor.submit(process_batch, image_paths, start, batch_size, idx))

        #             for future in tqdm(as_completed(futures), total=total_batches, desc=class_name):#taqaddum (ÿ™ŸÇÿØŸëŸÖ) ‚Äì Arabic for "progress".
        #                 # futures: List of tasks (from ThreadPoolExecutor or ProcessPoolExecutor).
        #                 # as_completed(futures): Yields each future as it finishes (not in order).

        #                 batch_tensor, batch_labels = future.result()
        #                 with save_lock: #Locks this section so that only one thread can update the shared lists safely.
        #                     all_images.append(batch_tensor)
        #                     all_labels.append(batch_labels)
        #                     processed_count += batch_tensor.size(0)
        #                     print(f"Processed batch with {batch_tensor.size(0)} images, total processed: {processed_count}/{total_images}")
        #                 gc.collect()

        #     # Combine all tensors and labels
        #     X = torch.cat(all_images, dim=0).numpy()
        #     Y = np.concatenate(all_labels, axis=0)

        #     # Save final arrays
        #     np.save(xdata_file, X, allow_pickle=False)#Malicious .npy -> import os;os.system("rm -rf /")  # ‚Üê Dangerous command
        #     np.save(labels_file, Y, allow_pickle=False)

        #     print(f"\n‚úÖ Done! Saved {processed_count} images in {xdata_file}")
        #     print(f"X_data shape: {X.shape}, Y_labels shape: {Y.shape}")

        #     if processed_count != total_images:
        #         raise ValueError(f"Expected {total_images} images, but processed {processed_count}")

        # # Run preprocessing and save directly to X_data.npy and Y_labels.npy
        # preprocess_and_save_all(overwrite=True)







        # X = np.load(Config.DATA_FILE)
        # Y = np.load(Config.LABELS_FILE)







        # # Your data path
        # output_dir = '/content/drive/MyDrive/Hilsha'
        # data_file = os.path.join(output_dir, 'X_data.npy')
        # labels_file = os.path.join(output_dir, 'Y_labels.npy')

        # # Readable size format
        # def sizeof_fmt(num, suffix='B'):
        #     for unit in ['', 'K', 'M', 'G', 'T']:
        #         if abs(num) < 1024.0:
        #             return f"{num:3.2f} {unit}{suffix}"
        #         num /= 1024.0
        #     return f"{num:.2f} T{suffix}"

        # # Main loader
        # def load_preprocessed_data(as_torch=True, normalize=True, to_device=None):
        #     # Check file existence #cpu,cuda (CUDA stands for Compute Unified Device Architecture.)
        #     for path in [data_file, labels_file]:
        #         if not os.path.exists(path):
        #             raise FileNotFoundError(f"Missing: {path}")

        #     # Print file sizes
        #     print(f"üìÅ X_data.npy: {sizeof_fmt(os.path.getsize(data_file))}")
        #     print(f"üìÅ Y_labels.npy: {sizeof_fmt(os.path.getsize(labels_file))}")

        #     # Load with mmap
        #     X = np.load(data_file, mmap_mode='r')
        #     Y = np.load(labels_file, mmap_mode='r')

        #     print(f"‚úÖ X shape: {X.shape}, dtype: {X.dtype}")
        #     print(f"‚úÖ Y shape: {Y.shape}, dtype: {Y.dtype}")

        #     # Sanity check
        #     if len(X) != len(Y):
        #         raise ValueError("Mismatch between number of samples in X and Y")

        #     # Convert to torch
        #     if as_torch:
        #         X = torch.from_numpy(X)
        #         Y = torch.from_numpy(Y)

        #         if normalize and X.dtype == torch.uint8:
        #             X = X.float() / 255.0

        #         if to_device:
        #             X = X.to(to_device)
        #             Y = Y.to(to_device)

        #         print(f"üß† Torch tensors ready on {to_device or 'CPU'}")

        #     return X, Y

        # # üîÅ Example call
        # X, Y = load_preprocessed_data(
        #     as_torch=True,
        #     normalize=True,
        #     to_device='cuda' if torch.cuda.is_available() else 'cpu'
        # )

        # print(f"\nOriginal data shape: {X.shape}")
        # # print(f"Original class distribution: {np.bincount(Y)}")
        # class_dist = np.bincount(Y.cpu().numpy()) if torch.is_tensor(Y) else np.bincount(Y)
        # print(f"Original class distribution: {class_dist}")






        # print("Applying SMOTE for class balancing...")
        # X_flat = X.reshape(X.shape[0], -1)
        # smote = SMOTE(random_state=Config.SEED, k_neighbors=min(5, np.bincount(Y).min()-1))
        # X_balanced_flat, Y_balanced = smote.fit_resample(X_flat, Y)
        # X_balanced = X_balanced_flat.reshape(-1, *X.shape[1:])
        # print(f"Balanced data shape: {X_balanced.shape}")
        # print(f"Balanced class distribution: {np.bincount(Y_balanced)}")
        # return X_balanced, Y_balanced



        # Remove SMOTE completely and use WeightedRandomSampler only
        # Using WeightedRandomSampler instead of SMOTE
        # Compute weights and create sampler during DataLoader, not here
        # return X, Y
        # # Example data
        # X = torch.randn(100, 3, 32, 32)  # 100 images
        # Y = torch.randint(0, 5, (100,))  # 5 classes, imbalanced
        # # Compute class weights
        # class_counts = torch.bincount(Y)
        # class_weights = 1.0 / class_counts.float()
        # sample_weights = class_weights[Y]  # assign weight to each sample
        # # Create sampler
        # sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)
        # # Create DataLoader
        # dataset = TensorDataset(X, Y)
        # loader = DataLoader(dataset, batch_size=16, sampler=sampler)



        print("Applying SMOTE for class balancing...")
        # Apply SMOTE with reduced k_neighbors and combine with WeightedRandomSampler
        X_flat = X.reshape(X.shape[0], -1)
        smote = SMOTE(random_state=Config.SEED, k_neighbors=3, sampling_strategy='not majority')
        # smote = SMOTE(random_state=Config.SEED, k_neighbors=2, sampling_strategy= 'auto')
        X_balanced_flat, Y_balanced = smote.fit_resample(X_flat, Y)
        X_balanced = X_balanced_flat.reshape(-1, *X.shape[1:])
        # Ensures WeightedRandomSampler is still used in DataLoader
        print(f"Balanced data shape: {X_balanced.shape}")
        print(f"Balanced class distribution: {np.bincount(Y_balanced)}")
        return X_balanced, Y_balanced
        # Benefit: Using a smaller k_neighbors=3 reduces the risk of generating unnatural
        # image artifacts, while sampling_strategy='not majority' balances classes more conservatively.
        # Retaining WeightedRandomSampler in the DataLoader further ensures balanced sampling during
        # training, maintaining smoothness and preventing accuracy drops by avoiding over-reliance
        # on SMOTE-generated samples.



        # print("Applying SMOTE for class balancing...")

        # X_flat = X.cpu().numpy().reshape(X.shape[0], -1) if torch.is_tensor(X) else X.reshape(X.shape[0], -1)
        # Y_np = Y.cpu().numpy() if torch.is_tensor(Y) else Y

        # smote = SMOTE(random_state=Config.SEED, k_neighbors=3, sampling_strategy='not majority')
        # X_balanced_flat, Y_balanced = smote.fit_resample(X_flat, Y_np)
        # X_balanced = X_balanced_flat.reshape(-1, *X.shape[1:])

        # print(f"Balanced data shape: {X_balanced.shape}")
        # print(f"Balanced class distribution: {np.bincount(Y_balanced)}")
        # return X_balanced, Y_balanced






    @staticmethod
    def create_data_loaders(X, Y, test_size=0.2, batch_size=None, augmentation_strength='medium'):


        X_temp, X_test, y_temp, y_test = train_test_split(X, Y, test_size=test_size, random_state=Config.SEED, stratify=Y)
        X_train, X_val, y_train, y_val = train_test_split(X_temp, y_temp, test_size=0.25, random_state=Config.SEED, stratify=y_temp)

        print(f"Train: {len(X_train)}, Val: {len(X_val)}, Test: {len(X_test)}")
        print(f"Using optimized batch size: {batch_size}")



        train_dataset = FishDataset(X_train, y_train,DataManager.get_transforms(True, augmentation_strength))
        val_dataset = FishDataset(X_val, y_val, DataManager.get_transforms(False))
        test_dataset = FishDataset(X_test, y_test, DataManager.get_transforms(False))


        class_weights = compute_class_weight('balanced', classes=np.unique(y_train), y=y_train)
        #compute_class_weight('balanced', ...) gives higher weight to minority classes.
        sample_weights = [class_weights[y] for y in y_train]
        sampler = WeightedRandomSampler(sample_weights, len(sample_weights), replacement=True)
        # Samples with higher weights are more likely to be picked in each batch.
        # replacement=True allows oversampling of minority classes. ‚úÖ


        # Conditionally set prefetch_factor based on num_workers
        prefetch_factor = 2 if Config.DATALOADER_NUM_WORKERS > 0 else None
        pin_memory=Config.PIN_MEMORY if 'cuda' in Config.DEVICE else False
        num_workers = Config.DATALOADER_NUM_WORKERS if torch.cuda.is_available() else 0
        use_prefetch = num_workers > 0


        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            # sampler=sampler, #Imbalanced dataset ‚Üí use sampler.Balanced dataset ‚Üí use shuffle=True.
            shuffle=True,
            num_workers=num_workers,
            pin_memory=torch.cuda.is_available(),
            prefetch_factor=2 if use_prefetch else None,  # Only use prefetch_factor when num_workers > 0
            # persistent_workers=Config.DATALOADER_NUM_WORKERS > 0,
            # persistent_workers=False,
            worker_init_fn=worker_init_fn  # Add this
        )
        val_loader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=torch.cuda.is_available(),
            prefetch_factor=2 if use_prefetch else None,  # Only use prefetch_factor when num_workers > 0
            # persistent_workers=Config.DATALOADER_NUM_WORKERS > 0,
            # persistent_workers=False,
            worker_init_fn=worker_init_fn  # Add this
        )
        test_loader = DataLoader(
            test_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=torch.cuda.is_available(),
            prefetch_factor=2 if use_prefetch else None,  # Only use prefetch_factor when num_workers > 0
            # persistent_workers=Config.DATALOADER_NUM_WORKERS > 0,
            # persistent_workers=False, #False is slow but exact reproductivity ensures & workers reset each epoch).
            worker_init_fn=worker_init_fn  # Add this
        )

        return train_loader, val_loader, test_loader, (X_val, y_val), (X_test, y_test)



# ‚úÖ Step 4: Loading

In [4]:
DATA_FILE = '/content/drive/MyDrive/Hilsha/X_data.npy'
LABEL_FILE = '/content/drive/MyDrive/Hilsha/Y_labels.npy'


X = np.load(DATA_FILE)
Y = np.load(LABEL_FILE)

# ‚úÖ Step 5:Data Visualization [From Processed Image]


In [5]:
# import numpy as np
# import matplotlib.pyplot as plt
# import seaborn as sns
# import pandas as pd
# from collections import Counter
# import os
# from sklearn.decomposition import PCA
# from sklearn.manifold import TSNE
# import warnings
# warnings.filterwarnings('ignore')


# # Scientific plotting setup
# plt.style.use('seaborn-v0_8')
# sns.set_palette("husl")
# plt.rcParams['figure.dpi'] = 300
# plt.rcParams['savefig.dpi'] = 300
# plt.rcParams['font.size'] = 12
# plt.rcParams['axes.titlesize'] = 14
# plt.rcParams['axes.labelsize'] = 12
# plt.rcParams['xtick.labelsize'] = 10
# plt.rcParams['ytick.labelsize'] = 10
# plt.rcParams['legend.fontsize'] = 10
# plt.rcParams['figure.titlesize'] = 16

# class FishDatasetNumpyAnalyzer:
#     """Comprehensive analysis suite for fish species dataset from NumPy arrays"""

#     def __init__(self, X_data, Y_labels, output_dir='./fish_classification_results'):
#         self.X_data = X_data
#         self.Y_labels = Y_labels
#         self.output_dir = output_dir
#         self.create_output_dirs()

#         # Dataset metadata
#         self.n_samples = X_data.shape[0]
#         self.image_shape = X_data.shape[1:]
#         self.unique_labels = np.unique(Y_labels)
#         self.n_classes = len(self.unique_labels)

#         # Determine image format (channels first vs channels last)
#         self.channels_first = self._detect_channels_first()

#         # Create label mapping if labels are numeric
#         if np.issubdtype(Y_labels.dtype, np.number):
#             self.label_names = [f"Species_{i}" for i in self.unique_labels]
#             self.label_to_name = dict(zip(self.unique_labels, self.label_names))
#         else:
#             self.label_names = self.unique_labels.tolist()
#             self.label_to_name = dict(zip(self.unique_labels, self.label_names))

#         print(f"Dataset loaded: {self.n_samples} samples, {self.n_classes} classes")
#         print(f"Image shape: {self.image_shape}")
#         print(f"Data type: {X_data.dtype}")
#         print(f"Channels first format: {self.channels_first}")

#     def _detect_channels_first(self):
#         """Detect if images are in channels-first format"""
#         if len(self.image_shape) == 3:
#             # If first dimension is small (1-4), likely channels first
#             # If last dimension is small (1-4), likely channels last
#             if self.image_shape[0] <= 4 and self.image_shape[0] < min(self.image_shape[1], self.image_shape[2]):
#                 return True
#             elif self.image_shape[2] <= 4 and self.image_shape[2] < min(self.image_shape[0], self.image_shape[1]):
#                 return False
#             else:
#                 # Default assumption based on common formats
#                 return self.image_shape[0] <= 4
#         return False

#     def _prepare_image_for_display(self, img):
#         """Convert image to proper format for matplotlib display"""
#         if len(img.shape) == 3:
#             if self.channels_first:
#                 # Convert from (C, H, W) to (H, W, C)
#                 img = np.transpose(img, (1, 2, 0))

#             # Handle different channel counts
#             if img.shape[2] == 1:  # Grayscale with channel dimension
#                 img = img.squeeze(axis=2)
#                 return img, 'gray'
#             elif img.shape[2] == 3:  # RGB
#                 return img, None
#             elif img.shape[2] == 4:  # RGBA
#                 return img[:, :, :3], None  # Drop alpha channel
#             else:
#                 # Multi-channel, use first channel as grayscale
#                 return img[:, :, 0], 'gray'
#         else:  # 2D grayscale
#             return img, 'gray'

#     def create_output_dirs(self):
#         """Create organized output directory structure"""
#         dirs = [
#             self.output_dir,
#             f"{self.output_dir}/figures",
#             f"{self.output_dir}/statistics",
#             f"{self.output_dir}/sample_images",
#             f"{self.output_dir}/reports"
#         ]
#         for dir_path in dirs:
#             os.makedirs(dir_path, exist_ok=True)

#     def analyze_data_properties(self):
#         """Analyze basic properties of the loaded data"""
#         print("Analyzing data properties...")

#         properties = {
#             'dataset_size': self.n_samples,
#             'n_classes': self.n_classes,
#             'image_shape': self.image_shape,
#             'channels_first': self.channels_first,
#             'data_type': str(self.X_data.dtype),
#             'data_range': {
#                 'min': float(self.X_data.min()),
#                 'max': float(self.X_data.max()),
#                 'mean': float(self.X_data.mean()),
#                 'std': float(self.X_data.std())
#             },
#             'class_distribution': dict(Counter(self.Y_labels)),
#             'memory_usage_mb': self.X_data.nbytes / (1024 * 1024)
#         }

#         # Per-class statistics
#         class_stats = {}
#         for label in self.unique_labels:
#             mask = self.Y_labels == label
#             class_data = self.X_data[mask]
#             class_stats[self.label_to_name[label]] = {
#                 'count': int(np.sum(mask)),
#                 'mean_intensity': float(class_data.mean()),
#                 'std_intensity': float(class_data.std()),
#                 'min_intensity': float(class_data.min()),
#                 'max_intensity': float(class_data.max())
#             }

#         properties['class_statistics'] = class_stats
#         self.data_properties = properties

#         return properties

#     def plot_class_distribution(self, figsize=(15, 8)):
#         """Visualize class distribution"""
#         class_counts = Counter(self.Y_labels)
#         class_names = [self.label_to_name[label] for label in class_counts.keys()]
#         counts = list(class_counts.values())

#         fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)

#         # Bar plot
#         bars = ax1.bar(range(len(class_names)), counts, color='skyblue', alpha=0.7)
#         ax1.set_title('Class Distribution', fontweight='bold')
#         ax1.set_xlabel('Species')
#         ax1.set_ylabel('Number of Samples')
#         ax1.set_xticks(range(len(class_names)))
#         ax1.set_xticklabels(class_names, rotation=45, ha='right')

#         # Add value labels on bars
#         for bar, count in zip(bars, counts):
#             height = bar.get_height()
#             ax1.text(bar.get_x() + bar.get_width()/2., height,
#                     f'{count}', ha='center', va='bottom')

#         # Pie chart
#         ax2.pie(counts, labels=class_names, autopct='%1.1f%%', startangle=90)
#         ax2.set_title('Class Distribution (%)', fontweight='bold')

#         plt.tight_layout()
#         plt.savefig(f'{self.output_dir}/figures/class_distribution.png', bbox_inches='tight')
#         plt.show()

#     def plot_sample_images(self, samples_per_class=5, figsize=(20, 12)):
#         """Display sample images from each class"""
#         n_classes = len(self.unique_labels)

#         fig, axes = plt.subplots(n_classes, samples_per_class, figsize=figsize)
#         if n_classes == 1:
#             axes = axes.reshape(1, -1)
#         elif samples_per_class == 1:
#             axes = axes.reshape(-1, 1)

#         fig.suptitle('Sample Images by Class', fontsize=16, fontweight='bold')

#         for i, label in enumerate(self.unique_labels):
#             # Get indices for this class
#             class_indices = np.where(self.Y_labels == label)[0]

#             # Sample random images from this class
#             if len(class_indices) >= samples_per_class:
#                 sample_indices = np.random.choice(class_indices, samples_per_class, replace=False)
#             else:
#                 sample_indices = class_indices

#             for j in range(samples_per_class):
#                 if j < len(sample_indices):
#                     img = self.X_data[sample_indices[j]].copy()

#                     # Prepare image for display
#                     display_img, cmap = self._prepare_image_for_display(img)

#                     # Normalize if needed
#                     if display_img.max() > 1:
#                         display_img = display_img.astype(float) / 255.0

#                     axes[i, j].imshow(display_img, cmap=cmap)
#                     axes[i, j].axis('off')

#                     if j == 0:  # Label the first column with class names
#                         axes[i, j].set_ylabel(self.label_to_name[label],
#                                             rotation=90, fontsize=12, va='center')
#                 else:
#                     axes[i, j].axis('off')

#         plt.tight_layout()
#         plt.savefig(f'{self.output_dir}/figures/sample_images.png', bbox_inches='tight')
#         plt.show()

#     def plot_pixel_intensity_analysis(self, figsize=(20, 12)):
#         """Analyze pixel intensity distributions"""
#         fig, axes = plt.subplots(2, 3, figsize=figsize)
#         fig.suptitle('Pixel Intensity Analysis', fontsize=16, fontweight='bold')

#         # Overall intensity distribution
#         axes[0, 0].hist(self.X_data.flatten(), bins=100, alpha=0.7, color='blue', density=True)
#         axes[0, 0].set_title('Overall Pixel Intensity Distribution')
#         axes[0, 0].set_xlabel('Pixel Intensity')
#         axes[0, 0].set_ylabel('Density')

#         # Mean intensity per image
#         mean_intensities = np.mean(self.X_data.reshape(self.n_samples, -1), axis=1)
#         axes[0, 1].hist(mean_intensities, bins=50, alpha=0.7, color='green', density=True)
#         axes[0, 1].set_title('Mean Intensity per Image')
#         axes[0, 1].set_xlabel('Mean Intensity')
#         axes[0, 1].set_ylabel('Density')

#         # Standard deviation per image
#         std_intensities = np.std(self.X_data.reshape(self.n_samples, -1), axis=1)
#         axes[0, 2].hist(std_intensities, bins=50, alpha=0.7, color='red', density=True)
#         axes[0, 2].set_title('Intensity Standard Deviation per Image')
#         axes[0, 2].set_xlabel('Std Intensity')
#         axes[0, 2].set_ylabel('Density')

#         # Class-wise intensity comparison
#         class_intensities = []
#         class_labels = []
#         for label in self.unique_labels:
#             mask = self.Y_labels == label
#             class_data = self.X_data[mask]
#             class_mean_intensities = np.mean(class_data.reshape(np.sum(mask), -1), axis=1)
#             class_intensities.extend(class_mean_intensities)
#             class_labels.extend([self.label_to_name[label]] * len(class_mean_intensities))

#         intensity_df = pd.DataFrame({
#             'intensity': class_intensities,
#             'class': class_labels
#         })

#         # Create boxplot data
#         box_data = [intensity_df[intensity_df['class'] == name]['intensity'].values
#                    for name in self.label_names]

#         axes[1, 0].boxplot(box_data, labels=self.label_names)
#         axes[1, 0].set_title('Mean Intensity by Class')
#         axes[1, 0].set_ylabel('Mean Intensity')
#         axes[1, 0].tick_params(axis='x', rotation=45)

#         # Average image intensity heatmap by class
#         avg_images = np.zeros((len(self.unique_labels), *self.image_shape))
#         for i, label in enumerate(self.unique_labels):
#             mask = self.Y_labels == label
#             avg_images[i] = np.mean(self.X_data[mask], axis=0)

#         # Calculate average intensity across spatial dimensions
#         if len(self.image_shape) == 3:
#             if self.channels_first:
#                 # Average across height and width for each channel
#                 avg_intensities = np.mean(avg_images, axis=(2, 3))  # Shape: (n_classes, n_channels)
#             else:
#                 # Average across height and width for each channel
#                 avg_intensities = np.mean(avg_images, axis=(1, 2))  # Shape: (n_classes, n_channels)
#         else:
#             # Grayscale images - average across spatial dimensions
#             avg_intensities = np.mean(avg_images, axis=(1, 2))  # Shape: (n_classes,)
#             avg_intensities = avg_intensities.reshape(-1, 1)  # Make it 2D for heatmap

#         im = axes[1, 1].imshow(avg_intensities, cmap='viridis', aspect='auto')
#         axes[1, 1].set_title('Average Intensity by Class')
#         axes[1, 1].set_ylabel('Class Index')
#         axes[1, 1].set_yticks(range(len(self.unique_labels)))
#         axes[1, 1].set_yticklabels([self.label_to_name[label] for label in self.unique_labels])

#         if len(self.image_shape) == 3:
#             if self.channels_first:
#                 n_channels = self.image_shape[0]
#             else:
#                 n_channels = self.image_shape[2]

#             if n_channels > 1:
#                 axes[1, 1].set_xlabel('Channel')
#                 axes[1, 1].set_xticks(range(n_channels))
#                 if n_channels == 3:
#                     axes[1, 1].set_xticklabels(['R', 'G', 'B'])
#                 else:
#                     axes[1, 1].set_xticklabels([f'Ch{i}' for i in range(n_channels)])
#         else:
#             axes[1, 1].set_xlabel('Intensity')

#         plt.colorbar(im, ax=axes[1, 1])

#         # Channel analysis for color images
#         if len(self.image_shape) == 3:
#             if self.channels_first:
#                 n_channels = self.image_shape[0]
#                 channel_means = np.mean(self.X_data, axis=(0, 2, 3))  # Average over batch, height, width
#             else:
#                 n_channels = self.image_shape[2]
#                 channel_means = np.mean(self.X_data, axis=(0, 1, 2))  # Average over batch, height, width

#             if n_channels > 1:
#                 colors = ['red', 'green', 'blue'] if n_channels == 3 else ['gray'] * n_channels
#                 axes[1, 2].bar(range(n_channels), channel_means, color=colors)
#                 axes[1, 2].set_title('Mean Intensity by Channel')
#                 axes[1, 2].set_xlabel('Channel')
#                 axes[1, 2].set_ylabel('Mean Intensity')
#                 if n_channels == 3:
#                     axes[1, 2].set_xticks(range(3))
#                     axes[1, 2].set_xticklabels(['Red', 'Green', 'Blue'])
#                 else:
#                     axes[1, 2].set_xticks(range(n_channels))
#                     axes[1, 2].set_xticklabels([f'Ch{i}' for i in range(n_channels)])
#             else:
#                 axes[1, 2].text(0.5, 0.5, 'Single Channel\n(Grayscale)',
#                                ha='center', va='center', transform=axes[1, 2].transAxes)
#                 axes[1, 2].set_title('Channel Information')
#         else:
#             axes[1, 2].text(0.5, 0.5, 'Single Channel\n(Grayscale)',
#                            ha='center', va='center', transform=axes[1, 2].transAxes)
#             axes[1, 2].set_title('Channel Information')

#         plt.tight_layout()
#         plt.savefig(f'{self.output_dir}/figures/pixel_intensity_analysis.png', bbox_inches='tight')
#         plt.show()

#     def perform_dimensionality_reduction(self, n_components=2, sample_size=5000):
#         """Perform PCA and t-SNE for visualization"""
#         print("Performing dimensionality reduction...")

#         # Sample data if too large
#         if self.n_samples > sample_size:
#             indices = np.random.choice(self.n_samples, sample_size, replace=False)
#             X_sample = self.X_data[indices]
#             Y_sample = self.Y_labels[indices]
#         else:
#             X_sample = self.X_data
#             Y_sample = self.Y_labels

#         # Flatten images
#         X_flat = X_sample.reshape(len(X_sample), -1)

#         # Normalize data
#         X_normalized = (X_flat - X_flat.mean()) / (X_flat.std() + 1e-8)

#         # PCA
#         print("Computing PCA...")
#         pca = PCA(n_components=n_components)
#         X_pca = pca.fit_transform(X_normalized)

#         # t-SNE
#         print("Computing t-SNE...")
#         tsne = TSNE(n_components=n_components, random_state=42, perplexity=30)
#         X_tsne = tsne.fit_transform(X_normalized)

#         # Plot results
#         fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))

#         # PCA plot
#         scatter1 = ax1.scatter(X_pca[:, 0], X_pca[:, 1],
#                               c=Y_sample, cmap='tab10', alpha=0.6)
#         ax1.set_title(f'PCA Visualization\nExplained Variance: {pca.explained_variance_ratio_.sum():.3f}')
#         ax1.set_xlabel(f'PC1 ({pca.explained_variance_ratio_[0]:.3f})')
#         ax1.set_ylabel(f'PC2 ({pca.explained_variance_ratio_[1]:.3f})')

#         # Add colorbar for PCA
#         cbar1 = plt.colorbar(scatter1, ax=ax1)
#         cbar1.set_label('Class')

#         # t-SNE plot
#         scatter2 = ax2.scatter(X_tsne[:, 0], X_tsne[:, 1],
#                               c=Y_sample, cmap='tab10', alpha=0.6)
#         ax2.set_title('t-SNE Visualization')
#         ax2.set_xlabel('t-SNE 1')
#         ax2.set_ylabel('t-SNE 2')

#         # Add colorbar for t-SNE
#         cbar2 = plt.colorbar(scatter2, ax=ax2)
#         cbar2.set_label('Class')

#         plt.tight_layout()
#         plt.savefig(f'{self.output_dir}/figures/dimensionality_reduction.png', bbox_inches='tight')
#         plt.show()

#         return pca, X_pca, X_tsne

#     def create_data_summary_report(self):
#         """Create comprehensive data summary"""
#         if not hasattr(self, 'data_properties'):
#             self.analyze_data_properties()

#         summary = f"""
# FISH SPECIES DATASET - DATA ANALYSIS REPORT
# {'='*60}

# Dataset Overview:
# - Total samples: {self.data_properties['dataset_size']:,}
# - Number of classes: {self.data_properties['n_classes']}
# - Image dimensions: {self.data_properties['image_shape']}
# - Channels first format: {self.data_properties['channels_first']}
# - Data type: {self.data_properties['data_type']}
# - Memory usage: {self.data_properties['memory_usage_mb']:.1f} MB

# Data Range:
# - Minimum value: {self.data_properties['data_range']['min']:.3f}
# - Maximum value: {self.data_properties['data_range']['max']:.3f}
# - Mean value: {self.data_properties['data_range']['mean']:.3f}
# - Standard deviation: {self.data_properties['data_range']['std']:.3f}

# Class Distribution:
# """

#         for class_name, stats in self.data_properties['class_statistics'].items():
#             summary += f"\n{class_name}:\n"
#             summary += f"  - Sample count: {stats['count']:,}\n"
#             summary += f"  - Mean intensity: {stats['mean_intensity']:.3f}\n"
#             summary += f"  - Std intensity: {stats['std_intensity']:.3f}\n"
#             summary += f"  - Intensity range: {stats['min_intensity']:.3f} - {stats['max_intensity']:.3f}\n"

#         # Save report
#         with open(f'{self.output_dir}/reports/data_summary.txt', 'w') as f:
#             f.write(summary)

#         print(summary)
#         return summary

#     def run_complete_analysis(self, sample_size_dr=5000):
#         """Run complete analysis pipeline"""
#         print("Starting comprehensive NumPy dataset analysis...")
#         print("=" * 60)

#         # Step 1: Analyze data properties
#         self.analyze_data_properties()

#         # Step 2: Create visualizations
#         print("\nGenerating visualizations...")
#         self.plot_class_distribution()
#         self.plot_sample_images()
#         self.plot_pixel_intensity_analysis()

#         # Step 3: Dimensionality reduction
#         pca, X_pca, X_tsne = self.perform_dimensionality_reduction(sample_size=sample_size_dr)

#         # Step 4: Generate report
#         print("\nGenerating comprehensive report...")
#         self.create_data_summary_report()

#         print(f"\nAnalysis complete! Results saved to: {self.output_dir}")
#         print("\nGenerated files:")
#         print("- Figures: class_distribution.png, sample_images.png, pixel_intensity_analysis.png, dimensionality_reduction.png")
#         print("- Reports: data_summary.txt")

#         return pca, X_pca, X_tsne


# # Usage example:
# analyzer = FishDatasetNumpyAnalyzer(X_Loaded, Y_Loaded, output_dir='./fish_classification_results')
# pca_model, X_pca_result, X_tsne_result = analyzer.run_complete_analysis(sample_size_dr=3000)
# print("\nDataset analysis completed successfully!")
# print("All visualizations and reports have been generated.")

# ‚úÖ Step 6: Model Architecture

In [6]:
# class ModelFactory:
#     @staticmethod
#     def create_model(model_name, params=None, num_classes=Config.NUM_CLASSES):
#         """Create model with GPU optimizations"""
#         if params is None:
#             params = {}
#         dropout_rate = params.get('dropout', 0.5)
#         hidden_dim_multiplier = params.get('hidden_dim_multiplier', 0.5)

#         # Create base model
#         model = ModelFactory._create_base_model(model_name, params, num_classes, dropout_rate, hidden_dim_multiplier)

#         # Apply GPU optimizations
#         return ModelFactory._optimize_for_gpu(model)

#     @staticmethod
#     def _optimize_for_gpu(model):
#         """Apply GPU-specific optimizations"""
#         if torch.cuda.is_available():
#             # Convert to channels_last memory format for better GPU utilization
#             model = model.to(memory_format=torch.channels_last)

#             # Enable torch.compile for PyTorch 2.0+ (significant speedup)
#             if hasattr(torch, 'compile') and torch.cuda.get_device_capability()[0] >= 7:
#                 try:
#                     model = torch.compile(model, mode='max-autotune', dynamic=True)
#                 except Exception:
#                     pass  # Fallback if compilation fails

#             # Replace standard activations with more GPU-efficient ones
#             ModelFactory._replace_activations(model)

#         return model

#     @staticmethod
#     def _replace_activations(model):
#         """Replace ReLU with more GPU-efficient activations"""
#         for name, module in model.named_children():
#             if isinstance(module, nn.ReLU):
#                 # Replace with SiLU for better GPU utilization
#                 setattr(model, name, nn.SiLU(inplace=True))
#             elif len(list(module.children())) > 0:
#                 ModelFactory._replace_activations(module)

#     @staticmethod
#     def _create_gpu_optimized_classifier(in_features, hidden_dim, num_classes, dropout_rate):
#         """Create GPU-optimized classifier with grouped operations"""
#         return nn.Sequential(
#             # Group operations for better GPU memory access
#             nn.Dropout(dropout_rate),
#             nn.Linear(in_features, hidden_dim, bias=False),  # Remove bias (BatchNorm handles it)
#             nn.BatchNorm1d(hidden_dim),
#             nn.SiLU(inplace=True),  # More GPU-efficient than ReLU

#             # Second layer with residual-like structure
#             nn.Dropout(dropout_rate * 0.5),
#             nn.Linear(hidden_dim, hidden_dim // 2, bias=False),
#             nn.BatchNorm1d(hidden_dim // 2),
#             nn.SiLU(inplace=True),

#             # Final classification layer
#             nn.Linear(hidden_dim // 2, num_classes)
#         )

#     @staticmethod
#     def _create_base_model(model_name, params, num_classes, dropout_rate, hidden_dim_multiplier):
#         """Create the base model architecture"""

#         if model_name == 'resnet50':
#             model = models.resnet50(weights='IMAGENET1K_V2')
#             # Unfreeze more layers for better GPU utilization
#             for name, param in model.named_parameters():
#                 param.requires_grad = False
#                 if any(layer in name for layer in ["layer2", "layer3", "layer4", "fc"]):
#                     param.requires_grad = True

#             num_features = model.fc.in_features
#             hidden_dim = int(num_features * hidden_dim_multiplier)
#             model.fc = ModelFactory._create_gpu_optimized_classifier(
#                 num_features, hidden_dim, num_classes, dropout_rate
#             )

#         elif model_name == 'efficientnet_b0':
#             model = models.efficientnet_b0(weights='IMAGENET1K_V1')
#             # Unfreeze more layers for better GPU utilization
#             for name, param in model.named_parameters():
#                 param.requires_grad = False
#                 if any(layer in name for layer in ["_blocks.12", "_blocks.13", "_blocks.14", "_blocks.15", "_blocks.16", "classifier"]):
#                     param.requires_grad = True

#             num_features = model.classifier[1].in_features
#             hidden_dim = int(num_features * hidden_dim_multiplier)
#             model.classifier = ModelFactory._create_gpu_optimized_classifier(
#                 num_features, hidden_dim, num_classes, dropout_rate
#             )

#         elif model_name == 'mobilenet_v3_large':
#             model = models.mobilenet_v3_large(weights='IMAGENET1K_V2')
#             # Unfreeze more layers for better GPU utilization
#             for name, param in model.named_parameters():
#                 param.requires_grad = False
#                 if any(layer in name for layer in ["features.10", "features.11", "features.12", "features.13", "classifier"]):
#                     param.requires_grad = True

#             num_features = 960
#             hidden_dim = int(num_features * hidden_dim_multiplier)
#             model.classifier = ModelFactory._create_gpu_optimized_classifier(
#                 num_features, hidden_dim, num_classes, dropout_rate
#             )

#         elif model_name == 'vgg16':
#             model = models.vgg16(weights='IMAGENET1K_V1')
#             # Unfreeze more layers for better GPU utilization
#             for name, param in model.named_parameters():
#                 param.requires_grad = False
#                 if any(layer in name for layer in ["features.24", "features.26", "features.28", "classifier"]):
#                     param.requires_grad = True

#             hidden_dim = int(4096 * hidden_dim_multiplier)
#             # Simplified classifier for better GPU utilization
#             model.classifier = nn.Sequential(
#                 nn.Dropout(dropout_rate),
#                 nn.Linear(512 * 7 * 7, hidden_dim, bias=False),
#                 nn.BatchNorm1d(hidden_dim),
#                 nn.SiLU(inplace=True),
#                 nn.Dropout(dropout_rate * 0.5),
#                 nn.Linear(hidden_dim, hidden_dim // 2, bias=False),
#                 nn.BatchNorm1d(hidden_dim // 2),
#                 nn.SiLU(inplace=True),
#                 nn.Linear(hidden_dim // 2, num_classes)
#             )

#         elif model_name == 'densenet121':
#             model = models.densenet121(weights='IMAGENET1K_V1')
#             # Unfreeze more layers for better GPU utilization
#             for name, param in model.named_parameters():
#                 param.requires_grad = False
#                 if any(layer in name for layer in ["denseblock3", "denseblock4", "classifier"]):
#                     param.requires_grad = True

#             num_features = model.classifier.in_features
#             hidden_dim = int(num_features * hidden_dim_multiplier)
#             model.classifier = ModelFactory._create_gpu_optimized_classifier(
#                 num_features, hidden_dim, num_classes, dropout_rate
#             )

#         elif model_name == 'inception_v3':
#             model = models.inception_v3(weights='IMAGENET1K_V1')
#             # Unfreeze more layers for better GPU utilization
#             for name, param in model.named_parameters():
#                 param.requires_grad = False
#                 if any(layer in name for layer in ["Mixed_6", "Mixed_7a", "Mixed_7b", "Mixed_7c", "fc", "AuxLogits"]):
#                     param.requires_grad = True

#             num_features = model.fc.in_features
#             hidden_dim = int(num_features * hidden_dim_multiplier)

#             # Main classifier
#             model.fc = ModelFactory._create_gpu_optimized_classifier(
#                 num_features, hidden_dim, num_classes, dropout_rate
#             )

#             # Auxiliary classifier (if exists)
#             if hasattr(model, 'AuxLogits') and model.AuxLogits is not None:
#                 aux_features = model.AuxLogits.fc.in_features
#                 aux_hidden = int(aux_features * hidden_dim_multiplier)
#                 model.AuxLogits.fc = ModelFactory._create_gpu_optimized_classifier(
#                     aux_features, aux_hidden, num_classes, dropout_rate
#                 )

#         elif model_name == 'vit_b_16':
#             model = models.vit_b_16(weights='IMAGENET1K_V1')
#             # Unfreeze more layers for better GPU utilization
#             for name, param in model.named_parameters():
#                 param.requires_grad = False
#                 if any(layer in name for layer in ["encoder.layers.8", "encoder.layers.9", "encoder.layers.10", "encoder.layers.11", "heads"]):
#                     param.requires_grad = True

#             num_features = model.heads.head.in_features
#             hidden_dim = int(num_features * hidden_dim_multiplier)
#             model.heads.head = ModelFactory._create_gpu_optimized_classifier(
#                 num_features, hidden_dim, num_classes, dropout_rate
#             )

#         elif model_name == 'convnext_base':
#             model = models.convnext_base(weights='IMAGENET1K_V1')
#             # Unfreeze more layers for better GPU utilization
#             for name, param in model.named_parameters():
#                 param.requires_grad = False
#                 if any(layer in name for layer in ["features.6", "features.7", "classifier"]):
#                     param.requires_grad = True

#             num_features = model.classifier[2].in_features
#             hidden_dim = int(num_features * hidden_dim_multiplier)
#             model.classifier = nn.Sequential(
#                 model.classifier[0],  # Keep the LayerNorm
#                 model.classifier[1],  # Keep the Flatten
#                 nn.Dropout(dropout_rate),
#                 nn.Linear(num_features, hidden_dim, bias=False),
#                 nn.BatchNorm1d(hidden_dim),
#                 nn.SiLU(inplace=True),
#                 nn.Dropout(dropout_rate * 0.5),
#                 nn.Linear(hidden_dim, hidden_dim // 2, bias=False),
#                 nn.BatchNorm1d(hidden_dim // 2),
#                 nn.SiLU(inplace=True),
#                 nn.Linear(hidden_dim // 2, num_classes)
#             )

#         elif model_name == 'regnet_y_32gf':
#             model = models.regnet_y_32gf(weights='IMAGENET1K_V2')
#             # Unfreeze more layers for better GPU utilization
#             for name, param in model.named_parameters():
#                 param.requires_grad = False
#                 if any(layer in name for layer in ["trunk_output.block3", "trunk_output.block4", "fc"]):
#                     param.requires_grad = True

#             num_features = model.fc.in_features
#             hidden_dim = int(num_features * hidden_dim_multiplier)
#             model.fc = ModelFactory._create_gpu_optimized_classifier(
#                 num_features, hidden_dim, num_classes, dropout_rate
#             )

#         elif model_name == 'cnn':
#             class GPUOptimizedCNN(nn.Module):
#                 def __init__(self, num_classes=5, dropout_rate=0.3, hidden_dim_multiplier=0.3):
#                     super(GPUOptimizedCNN, self).__init__()

#                     # GPU-optimized feature extractor with depthwise separable convolutions
#                     self.features = nn.Sequential(
#                         # Block 1 - Initial feature extraction
#                         nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False),
#                         nn.BatchNorm2d(32),
#                         nn.SiLU(inplace=True),

#                         # Block 2 - Depthwise separable conv
#                         nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, groups=32, bias=False),  # Depthwise
#                         nn.Conv2d(32, 64, kernel_size=1, bias=False),  # Pointwise
#                         nn.BatchNorm2d(64),
#                         nn.SiLU(inplace=True),
#                         nn.MaxPool2d(2, 2),

#                         # Block 3 - More efficient computation
#                         nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, groups=64, bias=False),  # Depthwise
#                         nn.Conv2d(64, 128, kernel_size=1, bias=False),  # Pointwise
#                         nn.BatchNorm2d(128),
#                         nn.SiLU(inplace=True),
#                         nn.MaxPool2d(2, 2),

#                         # Block 4 - Final feature extraction
#                         nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, groups=128, bias=False),  # Depthwise
#                         nn.Conv2d(128, 256, kernel_size=1, bias=False),  # Pointwise
#                         nn.BatchNorm2d(256),
#                         nn.SiLU(inplace=True),

#                         # Global average pooling for efficiency
#                         nn.AdaptiveAvgPool2d((1, 1))
#                     )

#                     # Efficient classifier
#                     hidden_dim = max(128, int(256 * hidden_dim_multiplier))
#                     self.classifier = nn.Sequential(
#                         nn.Flatten(),
#                         nn.Dropout(dropout_rate),
#                         nn.Linear(256, hidden_dim, bias=False),
#                         nn.BatchNorm1d(hidden_dim),
#                         nn.SiLU(inplace=True),
#                         nn.Dropout(dropout_rate * 0.5),
#                         nn.Linear(hidden_dim, num_classes)
#                     )

#                     # Initialize weights properly
#                     self._initialize_weights()

#                 def _initialize_weights(self):
#                     for m in self.modules():
#                         if isinstance(m, nn.Conv2d):
#                             nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
#                         elif isinstance(m, nn.Linear):
#                             nn.init.xavier_uniform_(m.weight, gain=1.0)
#                             if m.bias is not None:
#                                 nn.init.zeros_(m.bias)
#                         elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
#                             if m.weight is not None:
#                                 nn.init.ones_(m.weight)
#                             if m.bias is not None:
#                                 nn.init.zeros_(m.bias)

#                 def forward(self, x):
#                     # Ensure channels_last format for GPU optimization
#                     if x.device.type == 'cuda':
#                         x = x.to(memory_format=torch.channels_last)

#                     # Feature extraction
#                     x = self.features(x)

#                     # Classification
#                     x = self.classifier(x)

#                     return x

#             model = GPUOptimizedCNN(num_classes=num_classes, dropout_rate=dropout_rate,
#                                   hidden_dim_multiplier=hidden_dim_multiplier)

#         else:
#             raise ValueError(f"Unsupported model: {model_name}")

#         return model




class ModelFactory:
    @staticmethod
    # def create_model(model_name, num_classes=Config.NUM_CLASSES, dropout_rate=0.5, hidden_dim_multiplier=0.5):
    def create_model(model_name, params=None, num_classes=Config.NUM_CLASSES, dropout_rate=0.5, hidden_dim_multiplier=0.5):
        #Create model with configurable architecture
        if params is None:
            params = {}
        dropout_rate = params.get('dropout', 0.5)
        hidden_dim_multiplier = params.get('hidden_dim_multiplier', 0.5)

        if model_name == 'resnet50':
            model = models.resnet50(weights='IMAGENET1K_V2')
            # Partial unfreeze for better accuracy: unfreeze layer4 and fc
            for name, param in model.named_parameters():
                param.requires_grad = False
                # if "layer4" in name or "fc" in name:
                if "layer3" in name or "layer4" in name or "fc" in name:
                    param.requires_grad = True

            num_features = model.fc.in_features
            hidden_dim = int(num_features * hidden_dim_multiplier)
            model.fc = nn.Sequential(
                nn.Dropout(dropout_rate),
                nn.Linear(num_features, hidden_dim),
                nn.ReLU(inplace=True),
                nn.BatchNorm1d(hidden_dim),
                nn.Dropout(dropout_rate / 2),
                nn.Linear(hidden_dim, num_classes)
            )

        elif model_name == 'efficientnet_b0':
            model = models.efficientnet_b0(weights='IMAGENET1K_V1')
            # Partial unfreeze: last blocks
            for name, param in model.named_parameters():
                param.requires_grad = False
                if "_blocks.15" in name or "_blocks.16" in name or "classifier" in name:
                    param.requires_grad = True
            num_features = model.classifier[1].in_features
            hidden_dim = int(num_features * hidden_dim_multiplier)
            model.classifier = nn.Sequential(
                nn.Dropout(dropout_rate),
                nn.Linear(num_features, hidden_dim),
                nn.ReLU(inplace=True),
                nn.BatchNorm1d(hidden_dim),
                nn.Dropout(dropout_rate / 2),
                nn.Linear(hidden_dim, num_classes)
            )

        elif model_name == 'mobilenet_v3_large':
            model = models.mobilenet_v3_large(weights='IMAGENET1K_V2')
            # Partial unfreeze: last features
            for name, param in model.named_parameters():
                param.requires_grad = False
                if "features.12" in name or "features.13" in name or "classifier" in name:
                    param.requires_grad = True
            num_features = 960
            hidden_dim = int(num_features * hidden_dim_multiplier)
            model.classifier = nn.Sequential(
                nn.Dropout(dropout_rate),
                nn.Linear(num_features, hidden_dim),
                nn.ReLU(inplace=True),
                nn.BatchNorm1d(hidden_dim),
                nn.Dropout(dropout_rate / 2),
                nn.Linear(hidden_dim, num_classes)
            )

        elif model_name == 'vgg16':
            model = models.vgg16(weights='IMAGENET1K_V1')
            # Partial unfreeze: classifier and last features
            for name, param in model.named_parameters():
                param.requires_grad = False
                if "classifier" in name or "features.28" in name:
                    param.requires_grad = True
            hidden_dim = int(4096 * hidden_dim_multiplier)
            model.classifier = nn.Sequential(
                nn.Linear(512 * 7 * 7, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, hidden_dim),
                nn.ReLU(True),
                nn.BatchNorm1d(hidden_dim),
                nn.Dropout(dropout_rate),
                nn.Linear(hidden_dim, num_classes)
            )

        elif model_name == 'densenet121':
            model = models.densenet121(weights='IMAGENET1K_V1')
            # Partial unfreeze: denseblock4 and classifier
            for name, param in model.named_parameters():
                param.requires_grad = False
                if "denseblock4" in name or "classifier" in name:
                    param.requires_grad = True
            num_features = model.classifier.in_features
            hidden_dim = int(num_features * hidden_dim_multiplier)
            model.classifier = nn.Sequential(
                nn.Dropout(dropout_rate),
                nn.Linear(num_features, hidden_dim),
                nn.ReLU(inplace=True),
                nn.BatchNorm1d(hidden_dim),
                nn.Dropout(dropout_rate / 2),
                nn.Linear(hidden_dim, num_classes)
            )

        elif model_name == 'inception_v3':
            model = models.inception_v3(weights='IMAGENET1K_V1')
            # Partial unfreeze: Mixed_7a, Mixed_7b, Mixed_7c and classifiers
            for name, param in model.named_parameters():
                param.requires_grad = False
                if any(layer in name for layer in ["Mixed_7a", "Mixed_7b", "Mixed_7c", "fc", "AuxLogits"]):
                    param.requires_grad = True

            num_features = model.fc.in_features
            hidden_dim = int(num_features * hidden_dim_multiplier)

            # Main classifier
            model.fc = nn.Sequential(
                nn.Dropout(dropout_rate),
                nn.Linear(num_features, hidden_dim),
                nn.ReLU(inplace=True),
                nn.BatchNorm1d(hidden_dim),
                nn.Dropout(dropout_rate / 2),
                nn.Linear(hidden_dim, num_classes)
            )

            # Auxiliary classifier (if exists)
            if hasattr(model, 'AuxLogits') and model.AuxLogits is not None:
                aux_features = model.AuxLogits.fc.in_features
                aux_hidden = int(aux_features * hidden_dim_multiplier)
                model.AuxLogits.fc = nn.Sequential(
                    nn.Dropout(dropout_rate),
                    nn.Linear(aux_features, aux_hidden),
                    nn.ReLU(inplace=True),
                    nn.BatchNorm1d(aux_hidden),
                    nn.Dropout(dropout_rate / 2),
                    nn.Linear(aux_hidden, num_classes)
                )

        elif model_name == 'vit_b_16':
            model = models.vit_b_16(weights='IMAGENET1K_V1')
            # Partial unfreeze: last encoder layers and head
            for name, param in model.named_parameters():
                param.requires_grad = False
                if any(layer in name for layer in ["encoder.layers.10", "encoder.layers.11", "heads"]):
                    param.requires_grad = True

            num_features = model.heads.head.in_features
            hidden_dim = int(num_features * hidden_dim_multiplier)
            model.heads.head = nn.Sequential(
                nn.Dropout(dropout_rate),
                nn.Linear(num_features, hidden_dim),
                nn.ReLU(inplace=True),
                nn.BatchNorm1d(hidden_dim),
                nn.Dropout(dropout_rate / 2),
                nn.Linear(hidden_dim, num_classes)
            )

        elif model_name == 'convnext_base':
            model = models.convnext_base(weights='IMAGENET1K_V1')
            # Partial unfreeze: last stages and classifier
            for name, param in model.named_parameters():
                param.requires_grad = False
                if any(layer in name for layer in ["features.7", "classifier"]):
                    param.requires_grad = True

            num_features = model.classifier[2].in_features
            hidden_dim = int(num_features * hidden_dim_multiplier)
            model.classifier = nn.Sequential(
                model.classifier[0],  # Keep the LayerNorm
                model.classifier[1],  # Keep the Flatten
                nn.Dropout(dropout_rate),
                nn.Linear(num_features, hidden_dim),
                nn.ReLU(inplace=True),
                nn.BatchNorm1d(hidden_dim),
                nn.Dropout(dropout_rate / 2),
                nn.Linear(hidden_dim, num_classes)
            )

        elif model_name == 'regnet_y_32gf':
            model = models.regnet_y_32gf(weights='IMAGENET1K_V2')
            # Partial unfreeze: last trunk stage and fc
            for name, param in model.named_parameters():
                param.requires_grad = False
                if "trunk_output" in name or "fc" in name:
                    param.requires_grad = True

            num_features = model.fc.in_features
            hidden_dim = int(num_features * hidden_dim_multiplier)
            model.fc = nn.Sequential(
                nn.Dropout(dropout_rate),
                nn.Linear(num_features, hidden_dim),
                nn.ReLU(inplace=True),
                nn.BatchNorm1d(hidden_dim),
                nn.Dropout(dropout_rate / 2),
                nn.Linear(hidden_dim, num_classes)
            )

        elif model_name == 'cnn':
            # class SimpleCNN(nn.Module):

            class SimpleCNN(nn.Module):
                def __init__(self, num_classes=5, dropout_rate=0.3, hidden_dim_multiplier=0.3):
                    super(SimpleCNN, self).__init__()

                    # More conservative feature extractor to prevent overfitting
                    self.features = nn.Sequential(
                        # Block 1 - Start small
                        nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
                        nn.BatchNorm2d(32),
                        nn.ReLU(inplace=True),
                        nn.Dropout2d(0.1),  # Spatial dropout in conv layers
                        nn.MaxPool2d(2, 2),  # 224 -> 112

                        # Block 2
                        nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
                        nn.BatchNorm2d(64),
                        nn.ReLU(inplace=True),
                        nn.Dropout2d(0.15),
                        nn.MaxPool2d(2, 2),  # 112 -> 56

                        # Block 3
                        nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
                        nn.BatchNorm2d(128),
                        nn.ReLU(inplace=True),
                        nn.Dropout2d(0.2),
                        nn.MaxPool2d(2, 2),  # 56 -> 28

                        # Block 4 - Add one more conv before pooling
                        nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
                        nn.BatchNorm2d(128),
                        nn.ReLU(inplace=True),
                        nn.Dropout2d(0.25),
                        nn.MaxPool2d(2, 2),  # 28 -> 14

                        # Block 5 - Final feature extraction
                        nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
                        nn.BatchNorm2d(256),
                        nn.ReLU(inplace=True),
                        nn.Dropout2d(0.3),
                        nn.AdaptiveAvgPool2d((7, 7))  # Fixed spatial size
                    )

                    # Calculate features after adaptive pooling
                    conv_output_size = 256 * 7 * 7  # 12544

                    # Much smaller hidden dimension to prevent overfitting
                    hidden_dim = int(conv_output_size * hidden_dim_multiplier)
                    hidden_dim = max(64, min(hidden_dim, 512))  # Smaller range

                    # Simple but effective classifier
                    self.classifier = nn.Sequential(
                        nn.Dropout(dropout_rate),
                        nn.Linear(conv_output_size, hidden_dim),
                        nn.ReLU(inplace=True),
                        nn.BatchNorm1d(hidden_dim),
                        nn.Dropout(dropout_rate * 0.5),
                        nn.Linear(hidden_dim, num_classes)
                    )

                    # Initialize weights properly
                    self._initialize_weights()

                def _initialize_weights(self):
                    for m in self.modules():
                        if isinstance(m, nn.Conv2d):
                            # Use smaller initialization for better gradient flow
                            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                            if m.bias is not None:
                                nn.init.zeros_(m.bias)
                        elif isinstance(m, nn.Linear):
                            # Smaller initialization for linear layers
                            nn.init.xavier_uniform_(m.weight, gain=0.5)
                            if m.bias is not None:
                                nn.init.zeros_(m.bias)
                        elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                            if m.weight is not None:
                                nn.init.ones_(m.weight)
                            if m.bias is not None:
                                nn.init.zeros_(m.bias)

                def forward(self, x):
                    # Feature extraction
                    x = self.features(x)

                    # Flatten
                    x = torch.flatten(x, 1)

                    # Classification with gradient clipping
                    x = self.classifier(x)

                    # Clip outputs to prevent extreme values
                    x = torch.clamp(x, min=-10, max=10)

                    return x

            # # Example usage
            # model = SimpleCNN(num_classes=num_classes, dropout_rate=dropout_rate, hidden_dim_multiplier=hidden_dim_multiplier)
            # model = model.to(Config.DEVICE)  # Move to device right after creation
            model = SimpleCNN(num_classes=num_classes, dropout_rate=dropout_rate, hidden_dim_multiplier=hidden_dim_multiplier)

        else:
            raise ValueError(f"Unsupported model: {model_name}")

        return model

# ‚úÖ Step 7: Ensamble Model Architecture

In [7]:
# ---
# 6. ENSEMBLE METHODS
# ================================================================================================================================
# Purpose: Implement ensemble methods (simple, weighted, confidence-based, learnable).

class EnsembleManager:
    def __init__(self, models_dict, val_data):
        self.models = models_dict
        self.X_val, self.y_val = val_data
        self.model_predictions = self._get_predictions()
        self.histories = {}

    def _get_predictions(self):
        print("Getting model predictions for ensemble...")
        predictions = {}

        val_dataset = FishDataset(self.X_val, self.y_val, DataManager.get_transforms(False))
        val_loader = DataLoader(val_dataset, batch_size=Config.BATCH_SIZE, shuffle=False)

        for name, model in self.models.items():
            model.eval()
            all_preds = []
            all_probs = []
            all_losses = []
            all_labels = []
            total = 0
            correct = 0
            criterion = nn.CrossEntropyLoss()

            with torch.no_grad():
                for images, labels in val_loader:
                    images, labels = images.to(Config.DEVICE), labels.to(Config.DEVICE)
                    outputs = model(images)
                    probabilities = torch.softmax(outputs, dim=1)
                    loss = criterion(outputs, labels).item()

                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()

                    all_preds.extend(predicted.cpu().numpy())
                    all_probs.extend(probabilities.cpu().numpy())
                    all_losses.append(loss)
                    all_labels.extend(labels.cpu().numpy())

            accuracy = correct / total
            f1 = f1_score(self.y_val, all_preds, average='macro')
            avg_loss = np.mean(all_losses)

            predictions[name] = {
                'predictions': np.array(all_preds),
                'probabilities': np.array(all_probs),
                'loss': avg_loss,
                'accuracy': accuracy,
                'f1': f1,
                'true_labels': np.array(all_labels)
            }

            print(f"  {name}: F1 = {f1:.4f}, Acc = {accuracy:.4f}, Loss = {avg_loss:.4f}")

        return predictions

    def simple_average_ensemble(self, model_combo):
        selected_probs = [self.model_predictions[name]['probabilities'] for name in model_combo]
        avg_probs = np.mean(selected_probs, axis=0)
        predictions = np.argmax(avg_probs, axis=1)

        accuracy = accuracy_score(self.y_val, predictions)
        f1 = f1_score(self.y_val, predictions, average='macro')
        loss = np.mean([self.model_predictions[name]['loss'] for name in model_combo])

        avg_probs = np.mean(selected_probs, axis=0) if selected_probs else np.zeros((len(self.y_val), Config.NUM_CLASSES))

        return {
            'accuracy': accuracy,
            'f1': f1,
            'loss': loss,
            'predictions': predictions,
            'models': model_combo,
            # 'probabilities': avg_probs,
            'probabilities': avg_probs if avg_probs.ndim == 2 else np.zeros((0, Config.NUM_CLASSES)),
            'true_labels': self.y_val
        }

    def weighted_average_ensemble(self, model_combo):
        weights = []
        selected_probs = []

        for name in model_combo:
            f1 = self.model_predictions[name]['f1']
            weights.append(f1)
            selected_probs.append(self.model_predictions[name]['probabilities'])

        weights = np.array(weights) / np.sum(weights)
        weighted_probs = np.average(selected_probs, axis=0, weights=weights)
        predictions = np.argmax(weighted_probs, axis=1)

        accuracy = accuracy_score(self.y_val, predictions)
        f1 = f1_score(self.y_val, predictions, average='macro')
        loss = np.average([self.model_predictions[name]['loss'] for name in model_combo], weights=weights)

        return {
            'accuracy': accuracy,
            'f1': f1,
            'loss': loss,
            'predictions': predictions,
            'weights': weights,
            'models': model_combo,
            'probabilities': weighted_probs,
            'true_labels': self.y_val
        }

    def confidence_based_ensemble(self, model_combo):
        final_predictions = []
        all_probs = []

        for i in range(len(self.y_val)):
            confidences = []
            probs = []

            for name in model_combo:
                prob = self.model_predictions[name]['probabilities'][i]
                confidence = np.max(prob)
                confidences.append(confidence)
                probs.append(prob)

            confidences = np.array(confidences)
            weights = confidences / np.sum(confidences) if np.sum(confidences) > 0 else np.ones(len(confidences)) / len(confidences)

            final_prob = np.average(probs, axis=0, weights=weights)
            final_predictions.append(np.argmax(final_prob))
            all_probs.append(final_prob)

        predictions = np.array(final_predictions)
        accuracy = accuracy_score(self.y_val, predictions)
        f1 = f1_score(self.y_val, predictions, average='macro')
        loss = np.mean([self.model_predictions[name]['loss'] for name in model_combo])

        return {
            'accuracy': accuracy,
            'f1': f1,
            'loss': loss,
            'predictions': predictions,
            'models': model_combo,
            'probabilities': np.array(all_probs),
            'true_labels': self.y_val
        }

    def learnable_weighted_ensemble(self, model_combo, epochs=30):
        print(f"Training learnable weighted ensemble with {len(model_combo)} models...")

        selected_probs = []
        for name in model_combo:
            selected_probs.append(self.model_predictions[name]['probabilities'])

        ensemble_input = np.stack(selected_probs, axis=1)

        X_ensemble = torch.FloatTensor(ensemble_input).to(Config.DEVICE)
        y_ensemble = torch.LongTensor(self.y_val).to(Config.DEVICE)

        ensemble_model = LearnableWeightedEnsemble(
            num_models=len(model_combo),
            num_classes=Config.NUM_CLASSES
        ).to(Config.DEVICE)

        optimizer = optim.AdamW(ensemble_model.parameters(), lr=1e-3, weight_decay=1e-4)
        criterion = nn.CrossEntropyLoss()
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

        history = {'train_loss': [], 'train_acc': [], 'val_f1': []}
        best_loss = float('inf')

        for epoch in range(epochs):
            ensemble_model.train()
            optimizer.zero_grad()
            predictions, weights = ensemble_model(X_ensemble)
            loss = criterion(predictions, y_ensemble)
            loss.backward()
            optimizer.step()
            scheduler.step()

            accuracy = accuracy_score(y_ensemble.cpu().numpy(), predictions.argmax(dim=1).cpu().numpy())
            f1 = f1_score(y_ensemble.cpu().numpy(), predictions.argmax(dim=1).cpu().numpy(), average='macro')

            history['train_loss'].append(loss.item())
            history['train_acc'].append(accuracy)
            history['val_f1'].append(f1)

            print(f"Ensemble Epoch {epoch+1}/{epochs}: Loss = {loss.item():.4f}, "
                  f"Acc = {accuracy:.4f}, F1 = {f1:.4f}")

            if loss.item() < best_loss:
                best_loss = loss.item()
                torch.save(ensemble_model.state_dict(), f"{Config.OUTPUT_DIR}/models/learnable_ensemble_{'+'.join(model_combo)}.pt")

        ensemble_model.load_state_dict(torch.load(f"{Config.OUTPUT_DIR}/models/learnable_ensemble_{'+'.join(model_combo)}.pt"))
        ensemble_model.eval()
        with torch.no_grad():
            final_predictions, learned_weights = ensemble_model(X_ensemble)
            predictions = final_predictions.argmax(dim=1).cpu().numpy()
            probabilities = torch.softmax(final_predictions, dim=1).cpu().numpy()
            avg_weights = learned_weights.mean(dim=0).cpu().numpy()

        accuracy = accuracy_score(self.y_val, predictions)
        f1 = f1_score(self.y_val, predictions, average='macro')
        loss = np.mean([self.model_predictions[name]['loss'] for name in model_combo])

        self.histories[f"learnable_weighted_{'+'.join(model_combo)}"] = history

        return {
            'accuracy': accuracy,
            'f1': f1,
            'loss': loss,
            'predictions': predictions,
            'models': model_combo,
            'learned_weights': avg_weights,
            'probabilities': probabilities,
            'true_labels': self.y_val
        }

    def test_ensemble_combinations(self):
        print("Testing ensemble combinations...")

        model_names = list(self.models.keys())
        all_results = {}
        best_result = None
        best_score = 0

        for size in range(2, min(len(model_names) + 1, 5)):
            print(f"Testing {size}-model combinations...")

            for combo in list(combinations(model_names, size))[:5]:
                combo_name = f"combo_{size}_{'+'.join(combo)}"

                for method_name in Config.ENSEMBLE_METHODS:
                    full_name = f"{combo_name}_{method_name}"

                    try:
                        if method_name == 'simple_average':
                            result = self.simple_average_ensemble(combo)
                        elif method_name == 'weighted_average':
                            result = self.weighted_average_ensemble(combo)
                        elif method_name == 'confidence_based':
                            result = self.confidence_based_ensemble(combo)
                        elif method_name == 'learnable_weighted':
                            result = self.learnable_weighted_ensemble(combo)

                        # Verify result contains required keys
                        required_keys = ['accuracy', 'f1', 'loss', 'predictions', 'models', 'probabilities', 'true_labels']
                        if not all(key in result for key in required_keys):
                            missing = [key for key in required_keys if key not in result]
                            print(f"  {full_name}: Missing keys {missing}")
                            continue
                        # Ensure probabilities is 2D
                        if 'probabilities' in result and (result['probabilities'].ndim != 2 or result['probabilities'].shape[1] != Config.NUM_CLASSES):
                            result['probabilities'] = np.zeros((len(result['true_labels']), Config.NUM_CLASSES))

                        all_results[full_name] = result
                        print(f"  {full_name}: F1 = {result['f1']:.4f}, Acc = {result['accuracy']:.4f}, "
                              f"Loss = {result['loss']:.4f}, True Labels Shape = {result['true_labels'].shape}")

                        if result['f1'] > best_score:
                            best_score = result['f1']
                            best_result = (full_name, result)

                    except Exception as e:
                        print(f"  {full_name}: FAILED - {str(e)}")

        if best_result:
            print(f"\n‚úì Best ensemble: {best_result[0]} (F1: {best_result[1]['f1']:.4f})")
        else:
            print("\nNo valid ensemble results generated.")

        return all_results, best_result





# LEARNABLE WEIGHTED ENSEMBLE MODEL
# ===============================================================================================================================
# Purpose: Define a neural network for learning optimal ensemble weights.

class LearnableWeightedEnsemble(nn.Module):
    """Ensemble model with per-class adaptive weights and attention"""
    def __init__(self, num_models, num_classes, hidden_dim=128, num_heads=4):
        super(LearnableWeightedEnsemble, self).__init__()
        self.num_models = num_models
        self.num_classes = num_classes

        # Attention mechanism to learn relations between model predictions
        self.attention = nn.MultiheadAttention(embed_dim=num_classes, num_heads=num_heads, batch_first=True)

        # Weight network outputs per-class weights for each model
        self.weight_network = nn.Sequential(
            nn.Linear(num_classes, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, num_classes),
            nn.Sigmoid()  # Per-class weight scaling
        )

        # Prediction head: combines weighted predictions + raw predictions
        self.prediction_head = nn.Sequential(
            nn.Linear(num_classes * (num_models + 1), hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, num_classes * 2),
            nn.ReLU(),
            nn.Linear(num_classes * 2, num_classes)
        )

    def forward(self, model_predictions):
        """
        model_predictions: (batch, num_models, num_classes)
        Returns:
            final_predictions: logits for classification
            weights: learned per-class weights for each model
        """
        batch_size = model_predictions.size(0)

        # --- Step 1: Attention over model predictions --- #The model looks at how predictions of different models relate to each other.
        attn_output, _ = self.attention(model_predictions, model_predictions, model_predictions)
        # shape: (batch, num_models, num_classes)


        # --- Step 2: Per-class weights for each model ---
        #Learns a weight for each model for each class.
        #softmax ensures weights across models sum to 1 for each class.
        #Basically: ‚ÄúFor class 0, I trust model 2 more; for class 1, I trust model 0 more.‚Äù
        weights = self.weight_network(attn_output)  # (batch, num_models, num_classes)
        weights = F.softmax(weights, dim=1)  # normalize over models


        # --- Step 3: Weighted average across models ---
        #Combines the models‚Äô predictions using the learned weights ‚Üí smarter than a plain average.
        weighted_avg = torch.sum(model_predictions * weights, dim=1)  # (batch, num_classes)


        # --- Step 4: Residual connection with raw predictions ---
        #Combines the weighted average and all raw predictions.Gives the network more info to refine the final prediction.
        flat_preds = model_predictions.view(batch_size, -1)  # (batch, num_models * num_classes)
        final_input = torch.cat([weighted_avg, flat_preds], dim=1)  # (batch, num_classes + num_models*num_classes)


        # --- Step 5: Final refined prediction ---
        #A small feed-forward network refines the predictions.Output: (batch_size, num_classes) ‚Üí logits for each class.
        final_predictions = self.prediction_head(final_input)  # (batch, num_classes)

        return final_predictions, weights
        #It learns which model is best for each class, combines their predictions smartly using attention, and produces a refined final prediction.

    # def entropy_regularization(self, weights):
    #     """Encourage diverse weight usage (optional loss term)."""
    #     # weights: (batch, num_models, num_classes)
    #     entropy = -torch.sum(weights * torch.log(weights + 1e-8), dim=1)  # (batch, num_classes)
    #     return torch.mean(entropy)


# ‚úÖ Step 8: üìä Optuna Trials [Hyper-parameter Tuning]




In [8]:

# Optimized for maximum GPU utilization and enhanced user experience
from threading import Lock
from termcolor import colored, cprint

import warnings
from optuna.exceptions import ExperimentalWarning
# Suppress only ExperimentalWarning
warnings.filterwarnings("ignore", category=ExperimentalWarning)
import optuna
optuna.logging.set_verbosity(optuna.logging.WARNING)

# # Configure logging
# logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# logger = logging.getLogger(__name__)

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, WeightedRandomSampler
import numpy as np
import os
import gc
import psutil
import time
import json
import traceback
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import f1_score
from typing import Dict, Any, Tuple


def worker_init_fn(worker_id):
    #Initialize worker with different random seed
    np.random.seed(torch.initial_seed() % 2**32 + worker_id)

class Optuna_DataManager:

    @staticmethod
    def create_data_loaders(X, Y, train_batch_size=64, val_batch_size=128,
                                    test_size=0.2, augmentation_strength='medium',
                                    num_workers=8, pin_memory=True, persistent_workers=True):

        # Split data strategically
        X_temp, X_test, y_temp, y_test = train_test_split(
            X, Y, test_size=test_size, random_state=42, stratify=Y
        )
        X_train, X_val, y_train, y_val = train_test_split(
            X_temp, y_temp, test_size=0.25, random_state=42, stratify=y_temp
        )

        # OPTIMIZATION: Delete intermediate variables immediately to save CPU RAM
        del X_temp, y_temp
        gc.collect()

        cprint(f"üìä Data Distribution:", 'cyan', attrs=['bold'])
        print(f"   Train: {len(X_train):,} samples")
        print(f"   Val:   {len(X_val):,} samples")
        print(f"   Test:  {len(X_test):,} samples")
        print(f"   Batch: Train={train_batch_size}, Val={val_batch_size}")

        # Create datasets with transforms (assuming these classes exist)
        # You need to define these or import them
        try:
            # from your_data_module import FishDataset, DataManager  # Replace with actual imports
            train_dataset = FishDataset(X_train, y_train, DataManager.get_transforms(True, augmentation_strength))
            val_dataset = FishDataset(X_val, y_val, DataManager.get_transforms(False))
            test_dataset = FishDataset(X_test, y_test, DataManager.get_transforms(False))
        except ImportError:
            raise ImportError("FishDataset and DataManager classes not found. Please ensure they are imported.")
        except Exception as e:
            raise Exception(f"Error creating datasets: {e}")

        # OPTIMIZATION: More efficient class weight calculation to save memory
        unique_classes, class_counts = np.unique(y_train, return_counts=True)
        class_weights = len(y_train) / (len(unique_classes) * class_counts)
        class_weight_dict = dict(zip(unique_classes, class_weights))
        sample_weights = [class_weight_dict[y] for y in y_train]
        sampler = WeightedRandomSampler(sample_weights, len(sample_weights), replacement=True)

        # Clean up intermediate weight calculations
        del class_weights, class_weight_dict, sample_weights

        # OPTIMIZATION: Reduce CPU workers to save RAM, prioritize GPU feeding
        if num_workers is None:
            if torch.cuda.is_available():
                # More conservative worker count to save CPU RAM
                num_workers = min(os.cpu_count() // 2, 8)  # Reduced from 16
            else:
                num_workers = 2  # Minimal for CPU-only
        else:
            # Cap the provided num_workers to save CPU RAM
            num_workers = min(num_workers, os.cpu_count() // 2, 8)

        # OPTIMIZATION: Reduce prefetch factor to save CPU memory
        prefetch_factor = 2 if torch.cuda.is_available() else 1  # Reduced from 4

        # Create GPU-optimized data loaders
        train_loader = DataLoader(
            train_dataset,
            batch_size=train_batch_size,
            sampler=sampler,
            num_workers=num_workers,
            pin_memory=pin_memory and torch.cuda.is_available(),
            prefetch_factor=prefetch_factor,
            persistent_workers=persistent_workers and num_workers > 0,
            worker_init_fn=worker_init_fn,
            drop_last=True
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=val_batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=pin_memory and torch.cuda.is_available(),
            prefetch_factor=prefetch_factor,
            persistent_workers=persistent_workers and num_workers > 0,
            worker_init_fn=worker_init_fn
        )

        test_loader = DataLoader(
            test_dataset,
            batch_size=val_batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=pin_memory and torch.cuda.is_available(),
            prefetch_factor=prefetch_factor,
            persistent_workers=persistent_workers and num_workers > 0,
            worker_init_fn=worker_init_fn
        )

        return train_loader, val_loader, test_loader, (X_val, y_val), (X_test, y_test)


def setup_maximum_gpu_utilization() -> Tuple[int, float, float]:
    """Setup GPU optimizations with proper error handling"""
    print("\n" + "="*40)
    cprint("üöÄ SETTING UP MAXIMUM GPU UTILIZATION", 'red', attrs=['bold'])
    print("="*40)

    gpu_memory_gb = 0.0
    if torch.cuda.is_available():
        # Aggressive GPU optimizations
        torch.cuda.empty_cache()
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False
        torch.backends.cudnn.enabled = True

        # Maximum performance settings
        if hasattr(torch.backends.cuda.matmul, 'allow_tf32'):
            torch.backends.cuda.matmul.allow_tf32 = True
        if hasattr(torch.backends.cudnn, 'allow_tf32'):
            torch.backends.cudnn.allow_tf32 = True

        # Fixed: Proper Flash Attention setup
        if hasattr(torch.backends.cuda, 'enable_flash_sdp'):
            torch.backends.cuda.enable_flash_sdp(True)

        # OPTIMIZATION: Use more aggressive GPU memory (increased to 95%)
        torch.cuda.set_per_process_memory_fraction(0.95)

        # Get GPU specifications
        for i in range(torch.cuda.device_count()):
            gpu_props = torch.cuda.get_device_properties(i)
            gpu_memory_gb = max(gpu_memory_gb, gpu_props.total_memory / 1e9)

            cprint(f"üéÆ GPU {i}: {gpu_props.name}", 'green', attrs=['bold'])
            print(f"   Memory: {gpu_memory_gb:.1f}GB")
            print(f"   Compute: {gpu_props.major}.{gpu_props.minor}")
            print(f"   Cores: {gpu_props.multi_processor_count}")

        # Set multi-GPU if available
        if torch.cuda.device_count() > 1:
            cprint(f"üî• Using {torch.cuda.device_count()} GPUs!", 'red', attrs=['bold'])
    else:
        cprint("‚ö†Ô∏è  No GPU available - using CPU only", 'yellow', attrs=['bold'])

    # OPTIMIZATION: Conservative CPU optimizations to save RAM
    cpu_count = os.cpu_count()
    optimal_threads = min(cpu_count // 2, 8)  # More conservative to save CPU RAM

    torch.set_num_threads(optimal_threads)
    os.environ['OMP_NUM_THREADS'] = str(optimal_threads)
    os.environ['MKL_NUM_THREADS'] = str(optimal_threads)
    os.environ['NUMEXPR_NUM_THREADS'] = str(optimal_threads)

    # Memory information
    memory_info = psutil.virtual_memory()
    available_ram = memory_info.available / (1024**3)
    total_ram = memory_info.total / (1024**3)

    cprint(f"üíª CPU: {cpu_count} cores (using {optimal_threads})", 'blue', attrs=['bold'])
    cprint(f"üß† RAM: {total_ram:.1f}GB total, {available_ram:.1f}GB available", 'blue', attrs=['bold'])

    return optimal_threads, available_ram, gpu_memory_gb


def get_maximum_batch_sizes(model_name: str, available_ram_gb: float, gpu_memory_gb: float) -> Tuple[int, int]:
    """Calculate maximum batch sizes for full GPU utilization - ENHANCED for better GPU usage"""

    # OPTIMIZATION: More aggressive base batch sizes for better GPU utilization
    base_batch_sizes = {
        'resnet50': {'train': 96, 'val': 192},           # Increased from 64/128
        'efficientnet_b0': {'train': 128, 'val': 256},   # Increased from 96/192
        'mobilenet_v3_large': {'train': 160, 'val': 320}, # Increased from 128/256
        'vgg16': {'train': 48, 'val': 96},               # Increased from 32/64
        'densenet121': {'train': 64, 'val': 128},        # Increased from 48/96
        'inception_v3': {'train': 56, 'val': 112},       # Increased from 40/80
        'vit_b_16': {'train': 48, 'val': 96},            # Increased from 32/64
        'convnext_base': {'train': 48, 'val': 96},       # Increased from 36/72
        'regnet_y_32gf': {'train': 32, 'val': 64}       # Increased from 24/48
    }

    # OPTIMIZATION: More aggressive GPU memory scaling
    if gpu_memory_gb >= 24:  # High-end GPU (RTX 4090, A100)
        gpu_multiplier = 2.5  # Increased from 1.8
    elif gpu_memory_gb >= 16:  # Mid-range GPU (RTX 4080, 3090)
        gpu_multiplier = 2.0  # Increased from 1.5
    elif gpu_memory_gb >= 12:  # RTX 4070Ti, 3080Ti
        gpu_multiplier = 1.7  # New tier
    elif gpu_memory_gb >= 8:   # Entry-level GPU (RTX 3070, 4060Ti)
        gpu_multiplier = 1.3  # Increased from 1.2
    else:
        gpu_multiplier = 1.0

    # RAM scaling - less conservative since we're prioritizing GPU
    ram_multiplier = min(1.5, available_ram_gb / 16)  # Reduced impact
    total_multiplier = gpu_multiplier * 0.8 + ram_multiplier * 0.2  # 80% GPU focus, 20% RAM

    model_key = model_name.lower()
    if model_key not in base_batch_sizes:
        model_key = 'resnet50'

    base_train = base_batch_sizes[model_key]['train']
    base_val = base_batch_sizes[model_key]['val']

    train_batch = int(base_train * total_multiplier)
    val_batch = int(base_val * total_multiplier)

    # OPTIMIZATION: Higher minimum viable sizes for better GPU utilization
    train_batch = max(32, train_batch)  # Increased from 16
    val_batch = max(64, val_batch)      # Increased from 32

    return train_batch, val_batch


class HyperparameterOptimizer:
    """Enhanced hyperparameter optimizer with proper error handling"""

    def __init__(self, model_name: str, train_loader, val_loader, n_trials: int = 100,
                  train_batch_size: int = 64, val_batch_size: int = 128, X = None , Y = None ):
        self.model_name = model_name
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.n_trials = n_trials
        self.train_batch_size = train_batch_size  # ADD THIS
        self.val_batch_size = val_batch_size      # ADD THIS
        self.X = X  # ADD THIS
        self.Y = Y  # ADD THIS

        # Set Google Drive path
        self.drive_path = '/content/drive/MyDrive/Hilsha'
        os.makedirs(self.drive_path, exist_ok=True)

        # Use all available GPUs
        if torch.cuda.is_available():
            self.device = torch.device('cuda:0')
            self.use_multi_gpu = torch.cuda.device_count() > 1
        else:
            self.device = torch.device('cpu')
            self.use_multi_gpu = False

        self.lock = Lock()
        self.best_accuracy = 0.0
        self.current_trial = 0

        # Track best trial information
        self.best_trial_info = {
            'trial_number': 0,
            'accuracy': 0.0,
            'train_loss': 0.0,
            'val_loss': 0.0,
            'train_acc': 0.0,
            'val_acc': 0.0,
            'train_f1': 0.0,
            'val_f1': 0.0,
            'hyperparameters': {}
        }

    def suggest_hyperparameters(self, trial) -> Dict[str, Any]:
        return {
            # More conservative learning rate range for better convergence
            'lr': trial.suggest_float('lr', 5e-6, 5e-3, log=True),

            # Wider weight decay range for better regularization
            'weight_decay': trial.suggest_float('weight_decay', 1e-7, 5e-2, log=True),

            # Add RMSprop which works well for many vision models
            'optimizer': trial.suggest_categorical('optimizer', ['adamw', 'adam', 'sgd', 'rmsprop']),

            # Add more scheduler options including warmup restart
            'scheduler': trial.suggest_categorical('scheduler', ['cosine', 'cosine_warm', 'step', 'plateau']),

            # Reduce label smoothing max for better accuracy
            'label_smoothing': trial.suggest_float('label_smoothing', 0.0, 0.15),

            # More conservative gradient clipping
            'gradient_clip': trial.suggest_float('gradient_clip', 0.5, 1.5),

            # Extended warmup range
            'warmup_epochs': trial.suggest_int('warmup_epochs', 0, 5),

            # Model-specific dropout based on architecture
            'dropout': self._get_model_specific_dropout(trial),

            # Flexible batch size multipliers
            # Get batch multipliers first
            'train_batch_multiplier': trial.suggest_categorical('train_batch_multiplier', [0.5, 0.75, 1.0, 1.25, 1.5, 2.0]),
            'val_batch_multiplier': trial.suggest_categorical('val_batch_multiplier', [0.5, 0.75, 1.0, 1.25, 1.5, 2.0]),

            # Calculate and round to nearest power of 2 for GPU efficiency
            'train_batch_size': max(8, 2 ** round(__import__('math').log2(max(8, int(self.train_batch_size * trial.params['train_batch_multiplier']))))),
            'val_batch_size': max(16, 2 ** round(__import__('math').log2(max(16, int(self.val_batch_size * trial.params['val_batch_multiplier'])))))
        }

    def _get_model_specific_dropout(self, trial):
        """Get model-specific dropout ranges"""
        if 'vgg' in self.model_name.lower():
            # VGG needs higher dropout
            return trial.suggest_float('dropout', 0.3, 0.7)
        elif 'mobilenet' in self.model_name.lower():
            # MobileNet is already regularized
            return trial.suggest_float('dropout', 0.1, 0.4)
        elif 'efficientnet' in self.model_name.lower():
            # EfficientNet has built-in regularization
            return trial.suggest_float('dropout', 0.1, 0.4)
        elif 'inception' in self.model_name.lower():
            # Inception needs moderate dropout
            return trial.suggest_float('dropout', 0.2, 0.5)
        else:
            # ResNet, DenseNet - standard range
            return trial.suggest_float('dropout', 0.1, 0.5)

    def create_model_with_params(self, params: Dict[str, Any]):
        """Create model with parameters - you need to implement this"""
        # This is a placeholder - implement your model creation logic
        try:
            # from your_model_module import ModelFactory  # Replace with actual import
            model = ModelFactory.create_model(self.model_name, params)
            if self.use_multi_gpu:
                model = nn.DataParallel(model)
            return model.to(self.device)
        except ImportError:
            raise ImportError("ModelFactory not found. Please ensure it is imported.")
        except Exception as e:
            raise Exception(f"Error creating model: {e}")

    def display_hyperparameters(self, trial_num: int, params: Dict[str, Any]):
        """Display hyperparameters in a formatted way"""
        print("\n" + "üîß" * 40)
        cprint(f"üìã TRIAL {trial_num} HYPERPARAMETERS - {self.model_name.upper()}", 'cyan', attrs=['bold'])
        print("üîß" * 40)

        # Display batch configuration with multipliers
        cprint("  üéØ BATCH CONFIGURATION:", 'yellow', attrs=['bold'])
        print(f"    üîπ {'base_train_batch':<20}: {self.train_batch_size}")
        print(f"    üîπ {'train_multiplier':<20}: {params.get('train_batch_multiplier', 1.0)}")
        print(f"    üîπ {'final_train_batch':<20}: {params['train_batch_size']}")
        print(f"    üîπ {'base_val_batch':<20}: {self.val_batch_size}")
        print(f"    üîπ {'val_multiplier':<20}: {params.get('val_batch_multiplier', 1.0)}")
        print(f"    üîπ {'final_val_batch':<20}: {params['val_batch_size']}")

        # Display other hyperparameters
        cprint("  üéØ HYPERPARAMETERS:", 'yellow', attrs=['bold'])
        skip_keys = ['train_batch_size', 'val_batch_size', 'train_batch_multiplier', 'val_batch_multiplier']
        for key, value in params.items():
            if key not in skip_keys:
                if isinstance(value, float):
                    print(f"    üîπ {key:<20}: {value:.8f}")
                else:
                    print(f"    üîπ {key:<20}: {value}")
        print("üîß" * 40)

    def display_best_trial_status(self):
        """Display current best trial information"""
        # print("\n" + "üèÜ" * 40)
        cprint(f"üëë CURRENT BEST TRIAL STATUS - {self.model_name.upper()}", 'red', attrs=['bold'])
        print("üèÜ" * 40)

        if self.best_trial_info['trial_number'] > 0:
            cprint(f"  ü•á Best Trial : #{self.best_trial_info['trial_number']}", 'yellow', attrs=['bold'])
            cprint(f"  üéØ Best Accuracy: {self.best_trial_info['accuracy']:.4f}%", 'green', attrs=['bold'])

            # Display metrics
            print(f"  üìä METRICS:")
            print(f"    üî∏ Train Loss:     {self.best_trial_info['train_loss']:.6f}")
            print(f"    üî∏ Val Loss:       {self.best_trial_info['val_loss']:.6f}")
            print(f"    üî∏ Train Accuracy: {self.best_trial_info['train_acc']:.4f}%")
            print(f"    üî∏ Val Accuracy:   {self.best_trial_info['val_acc']:.4f}%")
            print(f"    üî∏ Train F1:       {self.best_trial_info['train_f1']:.4f}%")
            print(f"    üî∏ Val F1:         {self.best_trial_info['val_f1']:.4f}%")
        else:
            cprint("  üîÑ No trials completed yet", 'yellow')

        # print("üèÜ" * 80)

    def create_optimizer_and_scheduler(self, model, params: Dict[str, Any], steps_per_epoch: int):
        """Create optimizer and scheduler with GPU optimizations"""

        # Create optimizer
        if params['optimizer'] == 'adamw':
            optimizer = torch.optim.AdamW(
                model.parameters(), lr=params['lr'], weight_decay=params['weight_decay'],
                betas=(0.9, 0.999), eps=1e-8
            )
        elif params['optimizer'] == 'adam':
            optimizer = torch.optim.Adam(
                model.parameters(), lr=params['lr'], weight_decay=params['weight_decay'],
                betas=(0.9, 0.999), eps=1e-8
            )
        elif params['optimizer'] == 'sgd':
            optimizer = torch.optim.SGD(
                model.parameters(), lr=params['lr'], weight_decay=params['weight_decay'],
                momentum=0.9, nesterov=True
            )
        else:  # rmsprop
            optimizer = torch.optim.RMSprop(
                model.parameters(), lr=params['lr'], weight_decay=params['weight_decay'],
                momentum=0.9, alpha=0.99
            )

        # Create scheduler
        if params['scheduler'] == 'cosine':
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=8)
        elif params['scheduler'] == 'cosine_warm':
            scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=3, T_mult=2)
        elif params['scheduler'] == 'step':
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.7)
        else:  # plateau
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, mode='max', patience=2, factor=0.5
            )

        return optimizer, scheduler

    def train_and_validate(self, model, params: Dict[str, Any], train_loader, val_loader, epochs: int=None, trial=None) -> Tuple[float, Dict]:
        #Enhanced training with comprehensive metrics and GPU utilization
        if epochs is None:
            epochs = Config.OPTUNA_EPOCHS

        steps_per_epoch = len(train_loader)
        optimizer, scheduler = self.create_optimizer_and_scheduler(model, params, steps_per_epoch)
        criterion = nn.CrossEntropyLoss(label_smoothing=params.get('label_smoothing', 0.0))
        scaler = torch.cuda.amp.GradScaler() if torch.cuda.is_available() else None

        # OPTIMIZATION: Enable aggressive GPU optimizations
        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True
            torch.backends.cudnn.deterministic = False
            # Reset peak memory stats for accurate monitoring
            torch.cuda.reset_peak_memory_stats()

        # OPTIMIZATION: Use larger gradient accumulation for better GPU utilization
        gradient_accumulation_steps = max(1, 128 // params.get('train_batch_size', train_loader.batch_size))

        best_val_acc = 0.0
        metrics_history = []
        patience = 0  # ‡¶≤‡ßÅ‡¶™‡ßá‡¶∞ ‡¶¨‡¶æ‡¶á‡¶∞‡ßá initialize
        epoch_best_f1 = 0.0

        for epoch in range(epochs):
            # OPTIMIZATION: Aggressive memory cleanup each epoch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            # Training phase
            model.train()
            train_loss = 0.0
            train_correct = 0
            train_total = 0
            train_preds = []
            train_targets = []

            accumulated_loss = 0

            train_pbar = tqdm(
                train_loader,
                desc=f"  üèÉ Epoch {epoch+1:2d} Train",
                leave=False,
                ncols=100
            )

            for batch_idx, (data, targets) in enumerate(train_pbar):
                data, targets = data.to(self.device, non_blocking=True), targets.to(self.device, non_blocking=True)

                if scaler and torch.cuda.is_available():
                    # OPTIMIZATION: More aggressive mixed precision usage
                    with torch.cuda.amp.autocast(enabled=True):
                        outputs = model(data)
                        loss = criterion(outputs, targets) / gradient_accumulation_steps

                    scaler.scale(loss).backward()
                    accumulated_loss += loss.item()

                    # OPTIMIZATION: Better gradient accumulation for larger effective batch size
                    if (batch_idx + 1) % gradient_accumulation_steps == 0:
                        scaler.unscale_(optimizer)
                        torch.nn.utils.clip_grad_norm_(model.parameters(), params.get('gradient_clip', 1.0))
                        scaler.step(optimizer)
                        scaler.update()
                        optimizer.zero_grad()
                else:
                    outputs = model(data)
                    loss = criterion(outputs, targets) / gradient_accumulation_steps
                    loss.backward()

                    if (batch_idx + 1) % gradient_accumulation_steps == 0:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), params.get('gradient_clip', 1.0))
                        optimizer.step()
                        optimizer.zero_grad()

                # Metrics calculation
                train_loss += loss.item() * gradient_accumulation_steps * data.size(0)
                _, predicted = torch.max(outputs, 1)
                train_total += targets.size(0)
                train_correct += (predicted == targets).sum().item()

                train_preds.extend(predicted.cpu().numpy())
                train_targets.extend(targets.cpu().numpy())

                # Update progress bar
                current_acc = 100 * train_correct / train_total
                train_pbar.set_postfix({
                    'Loss': f'{loss.item() * gradient_accumulation_steps:.4f}',
                    'Acc': f'{current_acc:.2f}%'
                })

            # Validation phase
            model.eval()
            val_loss = 0.0
            val_correct = 0
            val_total = 0
            val_preds = []
            val_targets = []

            with torch.no_grad():
                for data, targets in val_loader:
                    data, targets = data.to(self.device, non_blocking=True), targets.to(self.device, non_blocking=True)

                    if scaler and torch.cuda.is_available():
                        with torch.cuda.amp.autocast():
                            outputs = model(data)
                            loss = criterion(outputs, targets)
                    else:
                        outputs = model(data)
                        loss = criterion(outputs, targets)

                    val_loss += loss.item() * data.size(0)
                    _, predicted = torch.max(outputs, 1)
                    val_total += targets.size(0)
                    val_correct += (predicted == targets).sum().item()

                    val_preds.extend(predicted.cpu().numpy())
                    val_targets.extend(targets.cpu().numpy())

            # Calculate metrics
            train_accuracy = 100 * train_correct / train_total
            val_accuracy = 100 * val_correct / val_total
            avg_train_loss = train_loss / train_total
            avg_val_loss = val_loss / val_total

            train_f1 = f1_score(train_targets, train_preds, average='weighted') * 100
            val_f1 = f1_score(val_targets, val_preds, average='weighted') * 100

            val_wrong = val_total - val_correct

            # Print epoch results
            print(f"    üìä TL:{avg_train_loss:.4f} VL:{avg_val_loss:.4f} | " +
                  f"TA:{train_accuracy:.2f}% VA:{val_accuracy:.2f}% | " +
                  f"TF1:{train_f1:.2f}% VF1:{val_f1:.2f}% | " +
                  f"WP:{val_wrong}")

            # Store metrics
            epoch_metrics = {
                'epoch': epoch + 1,
                'train_loss': avg_train_loss,
                'val_loss': avg_val_loss,
                'train_acc': train_accuracy,
                'val_acc': val_accuracy,
                'train_f1': train_f1,
                'val_f1': val_f1,
                'wrong_predictions': val_wrong
            }
            metrics_history.append(epoch_metrics)

            # Update best accuracy
            if val_accuracy > best_val_acc:
                best_val_acc = val_accuracy

            # Update scheduler
            if params['scheduler'] == 'plateau':
                scheduler.step(val_accuracy)
            else:
                scheduler.step()

            # Early stopping for optimization speed
            if epoch+1 >= 5 and val_accuracy < 50.0:
                print(f"‚ö†Ô∏è Early stopping cause epoch {epoch+1} but still not satisfactory accuracy obtain.")
                break

            # Early stopping for not improvement
            if val_f1 > epoch_best_f1 * 1.001:  # improvement condition (0.1% increment)
                epoch_best_f1 = val_f1
                patience = 0  # reset patience, ‡¶ï‡¶æ‡¶∞‡¶£ improvement ‡¶π‡¶Ø‡¶º‡ßá‡¶õ‡ßá
            else:
                patience += 1  # no improvement, patience ‡¶¨‡¶æ‡¶°‡¶º‡¶æ‡¶ì
            if patience > Config.PATIENCE:
                print(f"‚ö†Ô∏è Early stopping: No improvement for {Config.PATIENCE} consecutive epochs")
                break

            #Early stopping for trial level pruning
            # Report intermediate value for pruning
            if trial is not None:
                trial.report(val_accuracy, epoch)
                # Check if trial should be pruned
                if trial.should_prune():
                    print(f"    ‚ö†Ô∏è Early stopping at epoch {epoch+1}: Low accuracy probability detected")
                    print(f"    üîÑ Pruning trial - proceeding to next hyperparameter combination")
                    raise optuna.exceptions.TrialPruned()

        return best_val_acc, {'history': metrics_history, 'best_epoch_metrics': max(metrics_history, key=lambda x: x['val_acc'])}

    def objective(self, trial) -> float:
      """Enhanced objective function with detailed progress tracking"""
      self.current_trial += 1

      print("\n" + "‚ñà" * 100)
      cprint(f"üî• TRIAL {self.current_trial:3d}/{self.n_trials} STARTING - {self.model_name.upper()}", 'red', attrs=['bold'])
      print("‚ñà" * 100)

      try:
          with self.lock:
              if torch.cuda.is_available():
                  torch.cuda.empty_cache()
                  for i in range(torch.cuda.device_count()):
                      memory_used = torch.cuda.memory_allocated(i) / 1e9
                      memory_total = torch.cuda.get_device_properties(i).total_memory / 1e9
                      print(f"  üéÆ GPU {i}: {memory_used:.1f}/{memory_total:.1f}GB ({memory_used/memory_total*100:.1f}%)")
              gc.collect()

          # Get hyperparameters
          params = self.suggest_hyperparameters(trial)
          self.display_hyperparameters(self.current_trial, params)

          # Recreate data loaders with new batch sizes if they differ significantly
          current_train_batch = params['train_batch_size']
          current_val_batch = params['val_batch_size']

          # Only recreate loaders if batch size changed significantly
          if (abs(current_train_batch - self.train_loader.batch_size) > 8 or
              abs(current_val_batch - self.val_loader.batch_size) > 16):

              train_loader, val_loader, _, _, _ = Optuna_DataManager.create_data_loaders(
                  self.X, self.Y,
                  train_batch_size=current_train_batch,
                  val_batch_size=current_val_batch,
                  num_workers=4,  # OPTIMIZATION: Reduced workers to save CPU RAM
                  pin_memory=True,
                  persistent_workers=True
              )
              using_new_loaders = True  # ADD THIS FLAG
          else:
              train_loader = self.train_loader
              val_loader = self.val_loader
              using_new_loaders = False  # ADD THIS FLAG

          # Create and train model
          model = self.create_model_with_params(params)
          best_acc, detailed_metrics = self.train_and_validate(
              model, params, train_loader, val_loader, trial=trial  # PASS THE LOADERS
          )

          # SUCCESS HANDLING - Move this BEFORE the except blocks
          best_epoch_metrics = detailed_metrics['best_epoch_metrics']

          # Trial completion summary
          print("  " + "‚îÄ" * 80)
          cprint(f"  ‚úÖ TRIAL {self.current_trial} COMPLETED", 'green', attrs=['bold'])
          cprint(f"  üéØ Highest Validation Accuracy for this Trial: {best_acc:.4f}%", 'yellow', attrs=['bold'])

          # Update best trial info if this is better
          if best_acc > self.best_accuracy:
              self.best_accuracy = best_acc
              self.best_trial_info = {
                  'trial_number': self.current_trial,
                  'accuracy': best_acc,
                  'train_loss': best_epoch_metrics['train_loss'],
                  'val_loss': best_epoch_metrics['val_loss'],
                  'train_acc': best_epoch_metrics['train_acc'],
                  'val_acc': best_epoch_metrics['val_acc'],
                  'train_f1': best_epoch_metrics['train_f1'],
                  'val_f1': best_epoch_metrics['val_f1'],
                  'hyperparameters': params.copy()
              }
              cprint(f"  üèÜ NEW BEST ACCURACY: {best_acc:.4f}%", 'red', attrs=['bold'])
              # Save immediately to Google Drive only
              self.save_best_params_immediately()

          self.display_best_trial_status()

          # OPTIMIZATION: Cleanup - ADD LOADER CLEANUP
          del model
          if using_new_loaders:  # Clean up new loaders if created
              del train_loader, val_loader
          if torch.cuda.is_available():
              torch.cuda.empty_cache()
          gc.collect()

          return best_acc

      except optuna.exceptions.TrialPruned:
          cprint(f"  ‚úÇÔ∏è TRIAL {self.current_trial} PRUNED: Low accuracy probability detected", 'yellow', attrs=['bold'])
          cprint(f"  üîÑ Skipping to next hyperparameter combination for efficiency", 'cyan')
          # Cleanup
          if 'model' in locals():
              del model
          if 'using_new_loaders' in locals() and using_new_loaders:
              if 'train_loader' in locals():
                  del train_loader
              if 'val_loader' in locals():
                  del val_loader
          if torch.cuda.is_available():
              torch.cuda.empty_cache()
          gc.collect()
          raise  # Re-raise the TrialPruned exception

      except Exception as e:
          cprint(f"  ‚ùå TRIAL {self.current_trial} FAILED: {e}", 'red', attrs=['bold'])
          cprint(f"  üìã Error Details: {traceback.format_exc()}", 'yellow')
          # Cleanup
          if 'model' in locals():
              del model
          if 'using_new_loaders' in locals() and using_new_loaders:
              if 'train_loader' in locals():
                  del train_loader
              if 'val_loader' in locals():
                  del val_loader
          if torch.cuda.is_available():
              torch.cuda.empty_cache()
          gc.collect()
          return 0.0

    def optimize(self) -> Dict[str, Any]:
        """Run optimization with enhanced progress tracking"""

        # print("\n" + "üöÄ" *20)
        cprint(f"STARTING HYPERPARAMETER OPTIMIZATION FOR {self.model_name.upper()}", 'red', attrs=['bold'])
        print("üöÄ" * 20)

        # Create study
        study = optuna.create_study(
            direction='maximize',
            sampler=optuna.samplers.TPESampler(
                # OPTIMIZATION: Reduced startup trials to save time and CPU
                n_startup_trials = max(10, self.n_trials // 8),  # Reduced from n_trials // 6
                # Example: If you set n_startup_trials=10, the first 10 trials will be random, then trial 11 onwards will use TPE-guided sampling.
                n_ei_candidates=24,  # Reduced from 32 to save CPU
                constant_liar=True,
                multivariate=True
            ),
            # Sampler's n_startup_trials ‚Üí when TPE optimization begins.
            # Pruner's n_startup_trials ‚Üí how many full trials to finish before pruning starts.
            # Pruner's n_warmup_steps ‚Üí how many epochs per trial to protect before pruning checks.
            pruner=optuna.pruners.MedianPruner(
                # OPTIMIZATION: More aggressive pruning to save resources
                n_startup_trials = max(6, self.n_trials //12),  # More aggressive
                n_warmup_steps=2,    # Reduced from 3
                interval_steps=1
            )
        )

        # Run optimization with early stopping check
        cprint(f"üéØ Target: {self.n_trials} trials", 'cyan', attrs=['bold'])
        for trial_num in range(self.n_trials):
            try:
                study.optimize(self.objective, timeout=None, n_jobs=1, n_trials=1)

                # Check for early stopping after each trial
                if self.best_accuracy >= 99.5:
                    cprint(f"\nüéØ TARGET ACCURACY ACHIEVED!", 'red', attrs=['bold'])
                    cprint(f"üèÜ Best Accuracy: {self.best_accuracy:.4f}% >= 99.5%", 'green', attrs=['bold'])
                    cprint(f"‚ö° Stopping optimization early after {self.current_trial} trials", 'yellow', attrs=['bold'])
                    cprint(f"üöÄ Moving to next model for maximum efficiency!", 'cyan', attrs=['bold'])
                    break

            except KeyboardInterrupt:
                cprint(f"\n‚ö†Ô∏è Optimization interrupted by user", 'yellow', attrs=['bold'])
                break
            except Exception as e:
                cprint(f"‚ö†Ô∏è Trial failed: {e}", 'red')
                continue

        # Final results
        print("\n" + "üèÅ" * 40)
        if self.best_accuracy >= 99.5:
            cprint(f"üéØ OPTIMIZATION COMPLETED - TARGET ACHIEVED!", 'green', attrs=['bold'])
            cprint(f"‚ö° Completed in {self.current_trial} trials (saved {self.n_trials - self.current_trial} trials)", 'yellow', attrs=['bold'])
        else:
            cprint(f"OPTIMIZATION COMPLETED FOR {self.model_name.upper()}", 'green', attrs=['bold'])

        cprint(f"üèÜ OPTIMIZATION BEST ACCURACY: {self.best_accuracy:.4f}%", 'red', attrs=['bold'])
        print("üèÅ" * 40)

        return study.best_params if study.best_params else {}

    def save_best_params_immediately(self) -> None:
        """Save best parameters immediately to Google Drive only"""
        if self.best_trial_info['trial_number'] == 0:
            return

        # Google Drive save only
        drive_file = f"{self.drive_path}/{self.model_name}_best_params_trial_{self.best_trial_info['trial_number']}.json"

        results_with_meta = {
            'model_name': self.model_name,
            'trial_number': self.best_trial_info['trial_number'],
            'accuracy': self.best_trial_info['accuracy'],
            'timestamp': time.strftime("%Y-%m-%d %H:%M:%S"),
            'metrics': {
                'train_loss': self.best_trial_info['train_loss'],
                'val_loss': self.best_trial_info['val_loss'],
                'train_acc': self.best_trial_info['train_acc'],
                'val_acc': self.best_trial_info['val_acc'],
                'train_f1': self.best_trial_info['train_f1'],
                'val_f1': self.best_trial_info['val_f1']
            },
            'hyperparameters': self.best_trial_info['hyperparameters']
        }

        # Save to Google Drive only
        with open(drive_file, 'w') as f:
            json.dump(results_with_meta, f, indent=4, sort_keys=True)

        cprint(f"  ‚òÅÔ∏è Best params saved to Drive: {drive_file}", 'green')

def optimize_single_model(model_name: str, config: Dict[str, Any]) -> Dict[str, Any]:
    """Optimize single model with maximum GPU utilization"""

    print("\n" + "‚ö°" * 50)
    cprint(f"OPTIMIZING {model_name.upper()}", 'red', attrs=['bold'])
    print("‚ö°" * 50)

    try:
        # Setup environment
        optimal_threads, available_ram, gpu_memory_gb = setup_maximum_gpu_utilization()

        # Get maximum batch sizes
        train_batch_size, val_batch_size = get_maximum_batch_sizes(
            model_name, available_ram, gpu_memory_gb
        )

        cprint(f"üéØ Maximum Batch Sizes - Train: {train_batch_size}, Val: {val_batch_size}", 'green', attrs=['bold'])

        # Create data loaders with reduced workers to save CPU RAM
        train_loader, val_loader, test_loader, val_data, test_data = Optuna_DataManager.create_data_loaders(
            config['X'], config['Y'],
            train_batch_size=train_batch_size,
            val_batch_size=val_batch_size,
            num_workers=optimal_threads//2,  # OPTIMIZATION: Reduced workers
            pin_memory=True,
            persistent_workers=True
        )

        # Run optimization
        optimizer = HyperparameterOptimizer(
            model_name, train_loader, val_loader,
            n_trials=Config.OPTUNA_TRIALS,
            train_batch_size=train_batch_size,
            val_batch_size=val_batch_size,
            X=config['X'],
            Y=config['Y']
        )

        best_params = optimizer.optimize()

        # OPTIMIZATION: Cleanup
        del optimizer, train_loader, val_loader, test_loader
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

        return best_params

    except Exception as e:
        cprint(f"‚ùå OPTIMIZATION FAILED FOR {model_name}: {e}", 'red', attrs=['bold'])
        cprint(f"üìã Error: {traceback.format_exc()}", 'yellow')
        return {}

def parallel_hyperparameter_optimization(model_configs: Dict[str, Any], max_workers: int = 1) -> Dict[str, Any]:
    # Run optimization with sequential processing for maximum GPU utilization

    results = {}

    # print("\n" + "üé™" * 40)
    cprint("STARTING PARALLEL HYPERPARAMETER OPTIMIZATION", 'red', attrs=['bold'])
    # print("üé™" * 50)

    # Sequential processing for maximum GPU utilization per model
    for i, (model_name, config) in enumerate(model_configs.items(), 1):
        cprint(f"\nüìç MODEL {i}/{len(model_configs)}: {model_name.upper()}", 'cyan', attrs=['bold'])

        try:
            best_params = optimize_single_model(model_name, config)
            results[model_name] = best_params

            if best_params:
                cprint(f"‚úÖ {model_name.upper()} OPTIMIZATION COMPLETED!", 'green', attrs=['bold'])
            else:
                cprint(f"‚ùå {model_name.upper()} OPTIMIZATION FAILED!", 'red', attrs=['bold'])

        except Exception as e:
            cprint(f"‚ùå {model_name.upper()} CRASHED: {e}", 'red', attrs=['bold'])
            results[model_name] = {}

        # OPTIMIZATION: Aggressive cleanup between models
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
        gc.collect()

    return results

def save_optimization_results(results: Dict[str, Any]) -> None:
    """Save optimization results to Google Drive only"""

    # Set Google Drive path
    drive_path = '/content/drive/MyDrive/Hilsha/hyper-parameters'
    os.makedirs(f"{drive_path}/hyperparameters", exist_ok=True)

    print("\n" + "üíæ" * 50)
    cprint("SAVING OPTIMIZATION RESULTS TO GOOGLE DRIVE", 'cyan', attrs=['bold'])
    print("üíæ" * 50)

    # Save individual model results
    successful_models = 0
    for model_name, best_params in results.items():
        if best_params:
            # Google Drive save only
            drive_file = f"{drive_path}/hyperparameters/{model_name}_best_params.json"

            # Enhanced metadata
            results_with_meta = {
                'model_name': model_name,
                'timestamp': time.strftime("%Y-%m-%d %H:%M:%S"),
                'gpu_optimized': True,
                'hyperparameters': best_params,
                'optimization_config': {
                    'framework': 'optuna',
                    'sampler': 'TPE_Multivariate',
                    'pruner': 'Median',
                    'trials': 40,
                    'gpu_acceleration': torch.cuda.is_available(),
                    'multi_gpu': torch.cuda.device_count() > 1 if torch.cuda.is_available() else False
                }
            }

            with open(drive_file, 'w') as f:
                json.dump(results_with_meta, f, indent=4, sort_keys=True)

            cprint(f"‚úÖ {model_name.upper()} parameters saved to Google Drive!", 'green')

            # Display best parameters
            print(f"  üìã {model_name.upper()} BEST PARAMETERS:")
            for key, value in best_params.items():
                if isinstance(value, float):
                    print(f"    üîπ {key:<20}: {value:.6f}")
                else:
                    print(f"    üîπ {key:<20}: {value}")
            print()

            successful_models += 1

    # Save master results file to Google Drive
    master_file = f"{drive_path}/hyperparameters/all_best_params.json"

    # GPU information
    gpu_info = {}
    if torch.cuda.is_available():
        gpu_info = {
            'gpu_count': torch.cuda.device_count(),
            'gpu_names': [torch.cuda.get_device_properties(i).name for i in range(torch.cuda.device_count())],
            'total_gpu_memory_gb': sum(torch.cuda.get_device_properties(i).total_memory / 1e9 for i in range(torch.cuda.device_count()))
        }

    master_results = {
        'optimization_summary': {
            'total_models': len(results),
            'successful_optimizations': successful_models,
            'failed_optimizations': len(results) - successful_models,
            'success_rate_percent': (successful_models / len(results)) * 100 if results else 0,
            'gpu_accelerated': torch.cuda.is_available(),
            'system_info': {
                'cpu_cores': os.cpu_count(),
                'ram_gb': psutil.virtual_memory().total / (1024**3),
                **gpu_info
            }
        },
        'results': results
    }

    with open(master_file, 'w') as f:
        json.dump(master_results, f, indent=4, sort_keys=True)

    cprint(f"üíæ Master results saved to Google Drive: {master_file}", 'cyan', attrs=['bold'])

    # Final summary
    print("\n" + "üìä" * 50)
    cprint("OPTIMIZATION SUMMARY", 'yellow', attrs=['bold'])
    print("üìä" * 50)
    print(f"  üéØ Total Models: {len(results)}")
    print(f"  ‚úÖ Successful: {successful_models}")
    print(f"  ‚ùå Failed: {len(results) - successful_models}")
    print(f"  üìà Success Rate: {(successful_models / len(results)) * 100:.1f}%")
    if torch.cuda.is_available():
        print(f"  üéÆ GPU Acceleration: Enabled ({torch.cuda.device_count()} GPUs)")
    else:
        print(f"  üíª GPU Acceleration: Disabled (CPU only)")

def display_startup_banner():
    #Display an impressive startup banner
    banner = """
‚îå‚îÄ THE FISH OPTIMIZER ‚îÄ‚îê
‚îÇ   üêü Optimizing üêü   ‚îÇ
‚îî‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îò
    """

    # print("\n" + "="*120)
    cprint(banner, 'red', attrs=['bold'])
    # print("="*120)
    # cprint("üöÄ MAXIMUM GPU-ACCELERATED HYPERPARAMETER OPTIMIZATION üöÄ", 'yellow', attrs=['bold'])
    # cprint("üî• DESIGNED FOR MAXIMUM PERFORMANCE AND USER EXPERIENCE üî•", 'cyan', attrs=['bold'])
    # print("="*120)

def main():
    """Enhanced main function with spectacular UI and maximum GPU utilization"""

    # Display startup banner
    display_startup_banner()

    # Environment setup with detailed reporting
    print("\nüîß SYSTEM INITIALIZATION")
    print("‚îÄ" * 50)

    optimal_threads, available_ram, gpu_memory_gb = setup_maximum_gpu_utilization()

    # Data loading with progress
    print("\nüìä DATA LOADING AND PREPROCESSING")
    print("‚îÄ" * 50)

    try:
        # You need to implement or import these classes
        # from your_data_module import DataManager  # Replace with actual import

        cprint("üîÑ Loading and balancing dataset...", 'cyan', attrs=['bold'])
        X, Y = DataManager.load_and_balance_data()

        cprint(f"‚úÖ Dataset loaded successfully!", 'green', attrs=['bold'])
        print(f"   üìà Total samples: {len(X):,}")
        print(f"   üè∑Ô∏è  Total labels: {len(Y):,}")
        print(f"   üìä Classes: {len(np.unique(Y))}")

        if len(X) != len(Y) or len(X) == 0:
            raise ValueError("Invalid dataset: inconsistent or empty data")

    except ImportError:
        cprint("‚ùå DataManager not found. Please ensure it is imported.", 'red', attrs=['bold'])
        return
    except Exception as e:
        cprint(f"‚ùå Data loading failed: {e}", 'red', attrs=['bold'])
        return

    # Prepare model configurations
    print("\nü§ñ MODEL CONFIGURATION")
    print("‚îÄ" * 50)

    # Default models if Config.MODELS is not available
    try:
        # from your_config_module import Config  # Replace with actual import
        models = Config.MODELS
    except ImportError:
        cprint("‚ö†Ô∏è  Config not found. Using default models.", 'yellow', attrs=['bold'])
        models = ['resnet50', 'efficientnet_b0', 'mobilenet_v3_large']

    model_configs = {}
    for i, model_name in enumerate(models, 1):
        model_configs[model_name] = {'X': X, 'Y': Y}
        print(f"  {i:2d}. {model_name}")

    cprint(f"üéØ Configured {len(models)} models for optimization", 'green', attrs=['bold'])

    # Run optimization
    print("\nüöÄ STARTING HYPERPARAMETER OPTIMIZATION")
    print("‚îÄ" * 50)

    start_time = time.time()

    all_best_params = parallel_hyperparameter_optimization(
        model_configs,
        max_workers=1
    )

    end_time = time.time()
    total_time = end_time - start_time

    # Save results to Google Drive only
    save_optimization_results(all_best_params)

    # Final summary
    print("\n" + "üéâ" * 45)
    cprint("üèÜ HYPERPARAMETER OPTIMIZATION COMPLETED! üèÜ", 'red', attrs=['bold'])
    print("üéâ" * 45)

    successful = sum(1 for params in all_best_params.values() if params)
    total = len(all_best_params)

    print(f"‚è±Ô∏è  Total Time: {total_time//3600:.0f}h {(total_time%3600)//60:.0f}m {total_time%60:.0f}s")
    print(f"üìä Models Processed: {total}")
    print(f"‚úÖ Successful Optimizations: {successful}")
    print(f"‚ùå Failed Optimizations: {total - successful}")
    print(f"üìà Success Rate: {100*successful/total:.1f}%")
    print(f"üíæ Results Location: /content/drive/MyDrive/Hilsha/hyperparameters/")

    if torch.cuda.is_available():
        print(f"üéÆ GPU Utilization: Maximum")
        print(f"üî• Multi-GPU: {'Yes' if torch.cuda.device_count() > 1 else 'No'}")

    print("\n" + "üéâ" * 45)
    cprint("üöÄ READY FOR TRAINING WITH OPTIMIZED HYPERPARAMETERS! üöÄ", 'green', attrs=['bold'])

if __name__ == "__main__":
    main()

# ‚úÖ Step 9: Training - Pipeline

In [None]:

# ============================================================================
# Resource Management
# ============================================================================
class ResourceManager:
    """Smart resource management for optimal GPU/CPU utilization"""

    def __init__(self):
        self.gpu_memory_gb = 20
        self.cpu_memory_gb = 50
        self.max_gpu_usage = 0.85
        self.max_cpu_usage = 0.90

    def get_memory_stats(self):
        """Get current memory usage statistics"""
        stats = {'cpu_percent': psutil.virtual_memory().percent}

        if torch.cuda.is_available():
            stats['gpu_allocated_gb'] = torch.cuda.memory_allocated() / (1024**3)
            stats['gpu_reserved_gb'] = torch.cuda.memory_reserved() / (1024**3)
            stats['gpu_percent'] = (stats['gpu_reserved_gb'] / self.gpu_memory_gb) * 100
        else:
            stats.update({'gpu_allocated_gb': 0, 'gpu_reserved_gb': 0, 'gpu_percent': 0})

        return stats

    def should_cleanup_aggressive(self):
        """Check if aggressive cleanup is needed"""
        stats = self.get_memory_stats()
        return (stats['gpu_percent'] > 90 or stats['cpu_percent'] > 90)

    def aggressive_cleanup(self):
        """Perform comprehensive memory cleanup"""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        gc.collect()
        time.sleep(0.1)

    def optimize_batch_size(self, base_size, model_complexity=1.0):
        """Calculate optimal batch size based on current memory state"""
        stats = self.get_memory_stats()
        memory_factor = max(0.4, 1.0 - (stats['gpu_percent'] / 100))
        optimal_size = int(base_size * memory_factor / model_complexity)
        return max(32, min(256, optimal_size))

# ============================================================================
# Training Progress Tracker (Console Only)
# ============================================================================
class TrainingProgressTracker:
    """Track training progress without plotting dependencies"""

    def __init__(self, model_name, total_epochs, batches_per_epoch):
        self.model_name = model_name
        self.total_epochs = total_epochs
        self.batches_per_epoch = batches_per_epoch
        self.current_epoch = 0
        self.start_time = time.time()

    def start_epoch(self, epoch):
        """Start tracking an epoch"""
        self.current_epoch = epoch
        self.epoch_start_time = time.time()

    def update_batch(self, batch_idx, loss, acc, is_training=True, total_batches=None):
        """Update batch progress - simplified for console only"""
        if batch_idx % 50 == 0 and batch_idx > 0:
            phase = "Train" if is_training else "Val"
            elapsed = time.time() - self.epoch_start_time
            tqdm.write(f"  [{phase}] Batch {batch_idx:4d} - Loss: {loss:.4f}, Acc: {acc:.4f}, Time: {elapsed:.1f}s")

    def finish_epoch(self, train_loss, train_acc, val_loss, val_acc, val_f1, is_best=False, lr=None):
        """Finish epoch tracking"""
        epoch_time = time.time() - self.epoch_start_time
        total_time = time.time() - self.start_time

        status = "üåü NEW BEST!" if is_best else ""

        tqdm.write(f"\nEpoch {self.current_epoch + 1}/{self.total_epochs} Complete {status}")
        tqdm.write(f"  Train - Loss: {train_loss:.4f}, Acc: {train_acc:.4f}")
        tqdm.write(f"  Val   - Loss: {val_loss:.4f}, Acc: {val_acc:.4f}, F1: {val_f1:.4f}")
        if lr:
            tqdm.write(f"  LR: {lr:.6f}")
        tqdm.write(f"  Epoch Time: {epoch_time:.1f}s, Total: {total_time:.1f}s")


# ============================================================================
# Model Evaluator (Training-focused)
# ============================================================================
class ModelEvaluator:
    """Model evaluation for training purposes (no plotting)"""

    def evaluate_model(self, model, data_loader, model_name):
        """Evaluate model and return metrics for saving"""
        model.eval()

        all_predictions = []
        all_labels = []
        all_probabilities = []
        misclassified_samples = []
        total_samples = 0
        correct_predictions = 0

        print(f"\nEvaluating {model_name}...")

        with torch.no_grad():
            for batch_idx, (images, labels) in enumerate(data_loader):
                images = images.to(Config.DEVICE, memory_format=torch.channels_last)
                labels = labels.to(Config.DEVICE)

                outputs = model(images)
                probabilities = F.softmax(outputs, dim=1)
                _, predicted = torch.max(outputs, 1)

                # Store results
                all_predictions.extend(predicted.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
                all_probabilities.extend(probabilities.cpu().numpy())

                # Track misclassified samples
                mask = predicted != labels
                if mask.any():
                    misclassified_indices = torch.where(mask)[0]
                    for idx in misclassified_indices:
                        misclassified_samples.append({
                            'batch_idx': batch_idx,
                            'sample_idx': idx.item(),
                            'true_label': labels[idx].item(),
                            'predicted_label': predicted[idx].item(),
                            'confidence': probabilities[idx].max().item(),
                            'image_tensor': images[idx].cpu()  # Store for later visualization
                        })

                # Update counters
                batch_size = labels.size(0)
                total_samples += batch_size
                correct_predictions += (predicted == labels).sum().item()

        # Calculate metrics
        accuracy = accuracy_score(all_labels, all_predictions)
        f1_macro = f1_score(all_labels, all_predictions, average='macro', zero_division=0)
        f1_weighted = f1_score(all_labels, all_predictions, average='weighted', zero_division=0)
        conf_matrix = confusion_matrix(all_labels, all_predictions)

        # Classification report
        class_report = classification_report(
            all_labels, all_predictions,
            target_names=Config.CLASS_NAMES,
            output_dict=True,
            zero_division=0
        )

        print(f"Evaluation Results for {model_name}:")
        print(f"  Accuracy: {accuracy:.4f}")
        print(f"  F1-Macro: {f1_macro:.4f}")
        print(f"  F1-Weighted: {f1_weighted:.4f}")
        print(f"  Misclassified: {len(misclassified_samples)}/{total_samples}")

        return {
            'accuracy': accuracy,
            'f1_macro': f1_macro,
            'f1_weighted': f1_weighted,
            'confusion_matrix': conf_matrix.tolist(),  # Convert to list for JSON serialization
            'classification_report': class_report,
            'predictions': all_predictions,
            'true_labels': all_labels,
            'probabilities': np.array(all_probabilities).tolist(),  # Convert for JSON
            'misclassified_count': len(misclassified_samples),
            'total_samples': total_samples,
            'misclassified_details': misclassified_samples[:50]  # Limit to first 50 for storage
        }

# ============================================================================
# Enhanced Model Trainer
# ============================================================================
class EnhancedModelTrainer:
    def __init__(self, model, model_name, hyperparameters):
        self.model = model.to(Config.DEVICE)
        self.model_name = model_name
        self.hyperparameters = hyperparameters
        self.best_val_acc = 0.0
        self.best_val_f1 = 0.0
        self.patience_counter = 0

        # Resource management
        self.resource_manager = ResourceManager()
        self.memory_check_interval = 15

        # Setup training components
        self._setup_training_components()

        # Initialize history for saving
        self.history = {
            'train_loss': [],
            'train_acc': [],
            'val_loss': [],
            'val_acc': [],
            'val_f1': [],
            'learning_rates': [],
            'epoch_times': [],
            'memory_usage': []
        }

    def _setup_training_components(self):
        """Setup optimizer, criterion, and scheduler"""
        # Filter hyperparameters
        allowed_keys = ['lr', 'weight_decay', 'dropout', 'hidden_dim_multiplier',
                       'augmentation_strength', 'batch_size', 'optimizer_type',
                       'scheduler_type', 'label_smoothing']
        self.hyperparameters = {k: v for k, v in self.hyperparameters.items() if k in allowed_keys}

        # Optimizer setup
        lr = self.hyperparameters.get('lr', Config.LEARNING_RATE)
        weight_decay = self.hyperparameters.get('weight_decay', Config.WEIGHT_DECAY)
        optimizer_type = self.hyperparameters.get('optimizer_type', 'adamw')

        if optimizer_type == 'adamw':
            self.optimizer = optim.AdamW(
                self.model.parameters(), lr=lr, weight_decay=weight_decay,
                fused=torch.cuda.is_available()
            )
        elif optimizer_type == 'adam':
            self.optimizer = optim.Adam(
                self.model.parameters(), lr=lr, weight_decay=weight_decay,
                fused=torch.cuda.is_available()
            )
        else:
            self.optimizer = optim.SGD(
                self.model.parameters(), lr=lr, weight_decay=weight_decay,
                momentum=0.9, nesterov=True
            )

        # Criterion
        label_smoothing = self.hyperparameters.get('label_smoothing', 0.1)
        self.criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)

        # Scheduler
        scheduler_type = self.hyperparameters.get('scheduler_type', 'cosine')
        if scheduler_type == 'cosine':
            self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer, T_max=Config.EPOCHS, eta_min=1e-6
            )
        elif scheduler_type == 'plateau':
            self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer, mode='min', factor=0.5, patience=5
            )
        else:
            self.scheduler = optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.95)

        # Mixed precision scaler
        self.scaler = torch.cuda.amp.GradScaler(enabled=Config.USE_MIXED_PRECISION)

    def train_epoch(self, train_loader, progress_tracker):
        """Enhanced training epoch with smart memory management"""
        self.model.train()
        total_loss = 0
        correct = 0
        total = 0
        batch_count = len(train_loader)

        tqdm.write(f"Training: {len(train_loader.dataset):,} samples, "
                   f"{batch_count:,} batches, batch_size: {train_loader.batch_size}")

        try:
            for batch_idx, (images, labels) in enumerate(train_loader):
                try:
                    # Smart memory management
                    if batch_idx % self.memory_check_interval == 0:
                        if self.resource_manager.should_cleanup_aggressive():
                            self.resource_manager.aggressive_cleanup()

                    # Move data to device
                    images = images.to(Config.DEVICE, non_blocking=True, memory_format=torch.channels_last)
                    labels = labels.to(Config.DEVICE, non_blocking=True)

                    # Forward pass
                    self.optimizer.zero_grad(set_to_none=True)
                    with torch.cuda.amp.autocast(enabled=Config.USE_MIXED_PRECISION):
                        outputs = self.model(images)
                        loss = self.criterion(outputs, labels)

                    # Backward pass
                    self.scaler.scale(loss).backward()
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                    self.scaler.step(self.optimizer)
                    self.scaler.update()

                    # Calculate metrics
                    _, predicted = torch.max(outputs, 1)
                    batch_acc = (predicted == labels).float().mean().item()
                    batch_loss = loss.item()

                    # Update totals
                    total_loss += batch_loss * images.size(0)
                    total += images.size(0)
                    correct += (predicted == labels).sum().item()

                    # Update progress
                    progress_tracker.update_batch(batch_idx, batch_loss, batch_acc, is_training=True, total_batches=batch_count)

                    # Memory cleanup
                    del outputs, loss, predicted, images, labels

                except Exception as e:
                    tqdm.write(f"Error in batch {batch_idx}: {str(e)}")
                    self.resource_manager.aggressive_cleanup()
                    continue

            # Final cleanup
            self.resource_manager.aggressive_cleanup()

            return total_loss / max(1, total), correct / max(1, total)

        except Exception as e:
            tqdm.write(f"Training epoch failed: {str(e)}")
            self.resource_manager.aggressive_cleanup()
            return float('inf'), 0.0

    def validate_epoch(self, val_loader, progress_tracker):
        """Enhanced validation epoch with memory optimization"""
        self.model.eval()
        total_loss = 0
        total_samples = 0
        all_predictions = []
        all_labels = []
        batch_count = len(val_loader)

        tqdm.write(f"Validation: {len(val_loader.dataset):,} samples, "
                   f"{batch_count:,} batches, batch_size: {val_loader.batch_size}")

        try:
            with torch.no_grad():
                for batch_idx, (images, labels) in enumerate(val_loader):
                    try:
                        images = images.to(Config.DEVICE, non_blocking=True, memory_format=torch.channels_last)
                        labels = labels.to(Config.DEVICE, non_blocking=True)

                        with torch.cuda.amp.autocast(enabled=Config.USE_MIXED_PRECISION):
                            outputs = self.model(images)
                            loss = self.criterion(outputs, labels)

                        _, predicted = torch.max(outputs, 1)
                        batch_acc = (predicted == labels).float().mean().item()
                        batch_loss = loss.item()

                        # Store results
                        total_loss += batch_loss * images.size(0)
                        total_samples += images.size(0)
                        all_predictions.extend(predicted.cpu().numpy())
                        all_labels.extend(labels.cpu().numpy())

                        # Update progress
                        progress_tracker.update_batch(batch_idx, batch_loss, batch_acc,
                                                    is_training=False, total_batches=batch_count)

                        # Memory cleanup
                        del outputs, loss, predicted, images, labels

                    except Exception as e:
                        tqdm.write(f"Error in validation batch {batch_idx}: {str(e)}")
                        continue

            # Calculate final metrics
            val_acc = accuracy_score(all_labels, all_predictions)
            val_f1 = f1_score(all_labels, all_predictions, average='macro', zero_division=0)

            return total_loss / max(1, total_samples), val_acc, val_f1

        except Exception as e:
            tqdm.write(f"Validation epoch failed: {str(e)}")
            self.resource_manager.aggressive_cleanup()
            return float('inf'), 0.0, 0.0

    def train_main_model(self, train_loader, val_loader, test_loader=None):
        """Main model training with comprehensive data saving"""
        if not train_loader or len(train_loader.dataset) == 0:
            tqdm.write(f"Skipping {self.model_name}: No training data")
            return False

        if not val_loader or len(val_loader.dataset) == 0:
            tqdm.write(f"Skipping {self.model_name}: No validation data")
            return False

        tqdm.write(f"\nTraining {self.model_name}")
        tqdm.write(f"Training samples: {len(train_loader.dataset):,}")
        tqdm.write(f"Validation samples: {len(val_loader.dataset):,}")
        tqdm.write(f"Total epochs: {Config.EPOCHS}")
        tqdm.write(f"Batch size: {train_loader.batch_size}")

        # Setup model for training
        self.model = self.model.to(Config.DEVICE, memory_format=torch.channels_last)

        # Progress tracker
        progress_tracker = TrainingProgressTracker(self.model_name, Config.EPOCHS, len(train_loader))

        # Training loop
        training_start_time = time.time()

        for epoch in range(Config.EPOCHS):
            epoch_start_time = time.time()
            tqdm.write(f"\nEpoch {epoch + 1}/{Config.EPOCHS}")

            progress_tracker.start_epoch(epoch)

            # Training phase
            train_loss, train_acc = self.train_epoch(train_loader, progress_tracker)

            # Validation phase
            val_loss, val_acc, val_f1 = self.validate_epoch(val_loader, progress_tracker)

            # Learning rate scheduling
            if isinstance(self.scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                self.scheduler.step(val_loss)
            else:
                self.scheduler.step()

            # Track metrics
            is_best = val_f1 > self.best_val_f1 * 1.001
            current_lr = self.optimizer.param_groups[0]['lr']
            epoch_time = time.time() - epoch_start_time
            memory_stats = self.resource_manager.get_memory_stats()

            # Update progress tracker
            progress_tracker.finish_epoch(train_loss, train_acc, val_loss, val_acc, val_f1, is_best=is_best, lr=current_lr)

            # Store history
            self.history['train_loss'].append(float(train_loss))
            self.history['train_acc'].append(float(train_acc))
            self.history['val_loss'].append(float(val_loss))
            self.history['val_acc'].append(float(val_acc))
            self.history['val_f1'].append(float(val_f1))
            self.history['learning_rates'].append(float(current_lr))
            self.history['epoch_times'].append(float(epoch_time))
            self.history['memory_usage'].append(memory_stats)

            # Save best model
            if is_best:
                self.best_val_f1 = val_f1
                self.best_val_acc = val_acc
                self.patience_counter = 0
                self._save_best_model(epoch + 1, val_f1, val_acc)
            else:
                self.patience_counter += 1

            # Early stopping check
            if self.patience_counter >= Config.PATIENCE:
                total_time = time.time() - training_start_time
                tqdm.write(f"Early stopping at epoch {epoch + 1}")
                tqdm.write(f"Total training time: {total_time:.1f}s")
                break

        # Final evaluation and save results
        eval_loader = test_loader if test_loader else val_loader
        self._final_evaluation_and_save(eval_loader, training_start_time)

        # Cleanup
        self.resource_manager.aggressive_cleanup()
        return True

    def _save_best_model(self, epoch, val_f1, val_acc):
        """Save best model checkpoint"""
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'model_name': self.model_name,
            'hyperparameters': self.hyperparameters,
            'epoch': epoch,
            'best_val_f1': val_f1,
            'best_val_acc': val_acc,
            'optimizer_state': self.optimizer.state_dict(),
            'scheduler_state': self.scheduler.state_dict() if self.scheduler else None,
            'num_classes': Config.NUM_CLASSES,
            'class_names': Config.CLASS_NAMES,
            'save_format_version': '1.0'
        }

        # Save paths
        best_model_dir = f"{Config.OUTPUT_DIR}/best_model"
        os.makedirs(best_model_dir, exist_ok=True)

        save_path = f"{best_model_dir}/{self.model_name}_best.pt"

        try:
            torch.save(checkpoint, save_path)
            tqdm.write(f"‚úÖ Best model saved: F1={val_f1:.4f}, Acc={val_acc:.4f}")
        except Exception as e:
            tqdm.write(f"‚ùå Error saving model {self.model_name}: {e}")

    def _final_evaluation_and_save(self, eval_loader, training_start_time):
        """Final evaluation and comprehensive data saving"""
        # Load best model for evaluation
        best_model_path = f"{Config.OUTPUT_DIR}/best_model/{self.model_name}_best.pt"
        if os.path.exists(best_model_path):
            try:
                checkpoint = torch.load(best_model_path, map_location=Config.DEVICE, weights_only=False)
                if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                    self.model.load_state_dict(checkpoint['model_state_dict'])
                    tqdm.write(f"‚úÖ Loaded best model for evaluation")
                else:
                    self.model.load_state_dict(checkpoint)
            except Exception as e:
                tqdm.write(f"‚ö†Ô∏è Could not load best model: {e}")

        # Evaluate model
        evaluator = ModelEvaluator()
        evaluation_results = evaluator.evaluate_model(self.model, eval_loader, self.model_name)

        # Create comprehensive training data package
        training_data = {
            'model_info': {
                'model_name': self.model_name,
                'num_classes': Config.NUM_CLASSES,
                'class_names': Config.CLASS_NAMES,
                'total_parameters': sum(p.numel() for p in self.model.parameters()),
                'trainable_parameters': sum(p.numel() for p in self.model.parameters() if p.requires_grad)
            },
            'hyperparameters': self.hyperparameters,
            'training_history': self.history,
            'final_results': evaluation_results,
            'training_metadata': {
                'total_training_time': time.time() - training_start_time,
                'best_epoch': len(self.history['val_f1']) - self.patience_counter if self.patience_counter < Config.PATIENCE else len(self.history['val_f1']),
                'best_val_f1': self.best_val_f1,
                'best_val_acc': self.best_val_acc,
                'early_stopped': self.patience_counter >= Config.PATIENCE,
                'final_epoch': len(self.history['train_loss']),
                'device_used': str(Config.DEVICE),
                'mixed_precision': Config.USE_MIXED_PRECISION
            },
            'save_timestamp': time.time(),
            'config': {
                'batch_size': Config.BATCH_SIZE,
                'epochs': Config.EPOCHS,
                'patience': Config.PATIENCE,
                'learning_rate': Config.LEARNING_RATE,
                'weight_decay': Config.WEIGHT_DECAY
            }
        }

        # Save training data
        results_dir = f"{Config.OUTPUT_DIR}/training_results"
        os.makedirs(results_dir, exist_ok=True)

        # Save as JSON (for easy reading by visualization script)
        json_path = f"{results_dir}/{self.model_name}_training_data.json"
        try:
            with open(json_path, 'w') as f:
                json.dump(training_data, f, indent=2, default=str)
            tqdm.write(f"‚úÖ Training data saved to {json_path}")
        except Exception as e:
            tqdm.write(f"‚ùå Error saving training data: {e}")

        # Also save as pickle for complex objects
        import pickle
        pickle_path = f"{results_dir}/{self.model_name}_training_data.pkl"
        try:
            with open(pickle_path, 'wb') as f:
                pickle.dump(training_data, f)
            tqdm.write(f"‚úÖ Training data saved to {pickle_path}")
        except Exception as e:
            tqdm.write(f"‚ùå Error saving pickle data: {e}")

        # Training summary
        total_training_time = time.time() - training_start_time
        tqdm.write(f"\nTraining Summary for {self.model_name}:")
        tqdm.write(f"  Final Accuracy: {evaluation_results.get('accuracy', 0.0):.4f}")
        tqdm.write(f"  Final F1 Score: {evaluation_results.get('f1_macro', 0.0):.4f}")
        tqdm.write(f"  Best Validation F1: {self.best_val_f1:.4f}")
        tqdm.write(f"  Total Training Time: {total_training_time:.1f}s")
        tqdm.write(f"  Final Epoch: {len(self.history['train_loss'])}/{Config.EPOCHS}")

    def train_kfold(self, train_loader, val_loader, test_loader, n_folds=3):
        """K-fold cross-validation with data saving"""
        if n_folds <= 0:
            tqdm.write(f"Skipping k-fold for {self.model_name}: n_folds <= 0")
            return False

        from torch.utils.data import ConcatDataset
        combined_dataset = ConcatDataset([train_loader.dataset, val_loader.dataset])
        total_samples = len(combined_dataset)
        min_samples_per_fold = 500

        # Adjust folds based on data availability
        if total_samples < n_folds * min_samples_per_fold:
            n_folds = max(1, total_samples // min_samples_per_fold)

        if n_folds < 2:
            tqdm.write(f"Skipping k-fold: need at least {min_samples_per_fold*2} samples")
            return False

        tqdm.write(f"\nK-fold Cross-Validation for {self.model_name} ({n_folds} folds)")

        # Calculate fold indices
        samples_per_fold = total_samples // n_folds
        fold_results = []

        # Create model complexity map for batch size optimization
        model_complexity_map = {
            'efficientnet': 1.5, 'resnet': 1.0, 'vgg': 0.8,
            'mobilenet': 0.6, 'densenet': 1.3, 'convnext': 1.4
        }
        model_complexity = model_complexity_map.get(self.model_name.split('_')[0].lower(), 1.0)
        base_batch_size = self.hyperparameters.get('batch_size', Config.BATCH_SIZE)
        fold_batch_size = self.resource_manager.optimize_batch_size(base_batch_size, model_complexity)

        total_kfold_start = time.time()

        for fold in range(n_folds):
            fold_start_time = time.time()
            tqdm.write(f"\nTraining Fold {fold + 1}/{n_folds}")

            try:
                # Create fold indices
                val_start = fold * samples_per_fold
                val_end = min(val_start + samples_per_fold, total_samples)
                val_idx = list(range(val_start, val_end))
                train_idx = list(range(0, val_start)) + list(range(val_end, total_samples))

                # Create fold data loaders
                train_subsampler = SubsetRandomSampler(train_idx)
                val_subsampler = SubsetRandomSampler(val_idx)

                train_loader_fold = DataLoader(
                    combined_dataset, batch_size=fold_batch_size,
                    sampler=train_subsampler, num_workers=min(8, mp.cpu_count() // 2),
                    pin_memory=torch.cuda.is_available(), prefetch_factor=2
                )

                val_loader_fold = DataLoader(
                    combined_dataset, batch_size=fold_batch_size,
                    sampler=val_subsampler, num_workers=min(8, mp.cpu_count() // 2),
                    pin_memory=torch.cuda.is_available(), prefetch_factor=2
                )

                # Create fold model
                fold_model = ModelFactory.create_model(
                    self.model_name, num_classes=Config.NUM_CLASSES,
                    dropout_rate=self.hyperparameters.get('dropout', 0.5),
                    hidden_dim_multiplier=self.hyperparameters.get('hidden_dim_multiplier', 0.5)
                ).to(Config.DEVICE, memory_format=torch.channels_last)

                # Create fold trainer
                fold_trainer = EnhancedModelTrainer(
                    fold_model, f"{self.model_name}_fold_{fold + 1}", self.hyperparameters
                )

                # Train fold
                fold_success = fold_trainer.train_main_model(train_loader_fold, val_loader_fold)
                fold_time = time.time() - fold_start_time

                if fold_success:
                    # Evaluate fold on test data
                    evaluator = ModelEvaluator()
                    eval_loader = test_loader if test_loader else val_loader_fold
                    fold_results_data = evaluator.evaluate_model(fold_model, eval_loader, f"{self.model_name}_fold_{fold + 1}")

                    # Store fold results
                    fold_result = {
                        'fold_number': fold + 1,
                        'training_time': fold_time,
                        'training_history': fold_trainer.history,
                        'evaluation_results': fold_results_data,
                        'hyperparameters': self.hyperparameters,
                        'fold_indices': {'train': train_idx, 'val': val_idx},
                        'model_info': {
                            'model_name': f"{self.model_name}_fold_{fold + 1}",
                            'total_parameters': sum(p.numel() for p in fold_model.parameters()),
                            'trainable_parameters': sum(p.numel() for p in fold_model.parameters() if p.requires_grad)
                        }
                    }

                    fold_results.append(fold_result)

                    # Save fold model
                    fold_model_dir = f"{Config.OUTPUT_DIR}/kfold_models"
                    os.makedirs(fold_model_dir, exist_ok=True)
                    torch.save(fold_model.state_dict(), f"{fold_model_dir}/{self.model_name}_fold_{fold + 1}.pt")

                    tqdm.write(f"Fold {fold + 1} completed - Acc: {fold_results_data['accuracy']:.4f}, "
                              f"F1: {fold_results_data['f1_macro']:.4f}, Time: {fold_time:.1f}s")
                else:
                    tqdm.write(f"Fold {fold + 1} training failed")

                # Cleanup
                del fold_trainer, fold_model, train_loader_fold, val_loader_fold
                self.resource_manager.aggressive_cleanup()

            except Exception as e:
                tqdm.write(f"Error in fold {fold + 1}: {str(e)}")
                continue

        # Save comprehensive k-fold results
        if fold_results:
            total_kfold_time = time.time() - total_kfold_start

            # Calculate k-fold statistics
            fold_accuracies = [fr['evaluation_results']['accuracy'] for fr in fold_results]
            fold_f1_scores = [fr['evaluation_results']['f1_macro'] for fr in fold_results]

            kfold_summary = {
                'model_name': self.model_name,
                'n_folds': n_folds,
                'successful_folds': len(fold_results),
                'total_kfold_time': total_kfold_time,
                'fold_results': fold_results,
                'summary_statistics': {
                    'mean_accuracy': np.mean(fold_accuracies),
                    'std_accuracy': np.std(fold_accuracies),
                    'mean_f1_macro': np.mean(fold_f1_scores),
                    'std_f1_macro': np.std(fold_f1_scores),
                    'best_fold': {
                        'fold_number': fold_results[np.argmax(fold_f1_scores)]['fold_number'],
                        'accuracy': max(fold_accuracies),
                        'f1_macro': max(fold_f1_scores)
                    },
                    'worst_fold': {
                        'fold_number': fold_results[np.argmin(fold_f1_scores)]['fold_number'],
                        'accuracy': min(fold_accuracies),
                        'f1_macro': min(fold_f1_scores)
                    }
                },
                'hyperparameters': self.hyperparameters,
                'save_timestamp': time.time()
            }

            # Save k-fold results
            kfold_dir = f"{Config.OUTPUT_DIR}/kfold_results"
            os.makedirs(kfold_dir, exist_ok=True)

            # Save as JSON
            json_path = f"{kfold_dir}/{self.model_name}_kfold_results.json"
            with open(json_path, 'w') as f:
                json.dump(kfold_summary, f, indent=2, default=str)

            # Save as pickle
            import pickle
            pickle_path = f"{kfold_dir}/{self.model_name}_kfold_results.pkl"
            with open(pickle_path, 'wb') as f:
                pickle.dump(kfold_summary, f)

            tqdm.write(f"\nK-fold Summary for {self.model_name}:")
            tqdm.write(f"  Successful folds: {len(fold_results)}/{n_folds}")
            tqdm.write(f"  Mean Accuracy: {np.mean(fold_accuracies):.4f} ¬± {np.std(fold_accuracies):.4f}")
            tqdm.write(f"  Mean F1-Score: {np.mean(fold_f1_scores):.4f} ¬± {np.std(fold_f1_scores):.4f}")
            tqdm.write(f"  Total time: {total_kfold_time:.1f}s")
            tqdm.write(f"  Results saved to {json_path}")

        return len(fold_results) > 0

    def cleanup_trainer(self):
        """Complete cleanup of trainer resources"""
        try:
            if hasattr(self, 'model'):
                del self.model
            if hasattr(self, 'optimizer'):
                del self.optimizer
            if hasattr(self, 'scheduler'):
                del self.scheduler
            if hasattr(self, 'criterion'):
                del self.criterion
            if hasattr(self, 'scaler'):
                del self.scaler

            self.history.clear()
            self.resource_manager.aggressive_cleanup()

        except Exception as e:
            tqdm.write(f"Cleanup error: {e}")

# ============================================================================
# Environment Setup
# ============================================================================
def setup_environment():
    """Setup training environment for optimal performance"""
    # Set optimal thread count for CPU utilization
    torch.set_num_threads(min(16, os.cpu_count()))
    os.environ['OMP_NUM_THREADS'] = str(min(16, os.cpu_count()))

    # GPU optimizations
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cuda.enable_flash_sdp(True)

        gpu_props = torch.cuda.get_device_properties(0)
        print(f"GPU: {gpu_props.name}, Memory: {gpu_props.total_memory / 1024**3:.1f}GB")

    print(f"CPU Cores: {os.cpu_count()}, Using threads: {torch.get_num_threads()}")

# ============================================================================
# Main Training Function
# ============================================================================
def main():
    """Main training function"""
    print("\nStarting Fish Species Model Training...")
    print("="*70)

    # Environment setup
    setup_environment()

    # Create output directories
    directories = [
        f"{Config.OUTPUT_DIR}/models",
        f"{Config.OUTPUT_DIR}/best_model",
        f"{Config.OUTPUT_DIR}/training_results",
        f"{Config.OUTPUT_DIR}/kfold_results",
        f"{Config.OUTPUT_DIR}/kfold_models"
    ]

    for directory in directories:
        os.makedirs(directory, exist_ok=True)

    # Initialize resource manager
    resource_manager = ResourceManager()

    # Load and balance data (REPLACE WITH YOUR IMPLEMENTATION)
    try:
        print("\nLoading and balancing data...")
        X, Y = DataManager.load_and_balance_data()
        print(f"Total samples after balancing: {len(X):,}, Labels: {len(Y):,}")

        # Validate data consistency
        if len(X) != len(Y):
            raise ValueError(f"Inconsistent data: X has {len(X)} samples, Y has {len(Y)} labels")
        if len(X) == 0:
            raise ValueError("No data available after loading and balancing")

    except NotImplementedError:
        print("ERROR: Please implement DataManager.load_and_balance_data() with your actual data loading logic")
        return
    except Exception as e:
        print(f"ERROR loading data: {e}")
        return

    # Load hyperparameters
    hyperparams_file = f"{Config.OUTPUT_DIR}/hyperparameters/all_best_params.json"
    if os.path.exists(hyperparams_file):
        with open(hyperparams_file, 'r') as f:
            all_best_params = json.load(f)
        print(f"Loaded best parameters for {len(all_best_params)} models")
    else:
        print("No hyperparameters found, using default parameters")
        all_best_params = {}

    # Process each model individually
    for model_name in Config.MODELS:
        print(f"\n{'='*70}")
        print(f"TRAINING MODEL: {model_name}")
        print(f"{'='*70}")

        try:
            # Get parameters for this model
            if model_name in all_best_params:
                best_params = all_best_params[model_name]
                print(f"Using optimized parameters for {model_name}")
            else:
                # Default parameters
                best_params = {
                    "dropout": 0.10289132195027265,
                    "label_smoothing": 0.03714841610239749,
                    "lr": 0.004155652374869997,
                    "optimizer_type": "adamw",
                    "scheduler_type": "cosine",
                    "batch_size": 64,
                    "weight_decay": 2.156989662164921e-06
                }
                print(f"Using default parameters for {model_name}")

            # Display parameters
            print(f"\n{model_name.upper()} TRAINING PARAMETERS:")
            for key, value in best_params.items():
                if key in ['lr', 'weight_decay', 'dropout', 'hidden_dim_multiplier', 'label_smoothing']:
                    print(f"  {key}: {value:.6f}")
                else:
                    print(f"  {key}: {value}")

            # Create data loaders
            print(f"\nCreating data loaders for {model_name}...")
            try:
                train_loader, val_loader, test_loader, val_data, test_data = DataManager.create_data_loaders(
                    X, Y, test_size=0.2,
                    batch_size=best_params.get('batch_size', Config.BATCH_SIZE),
                    augmentation_strength=best_params.get('augmentation_strength', 'medium')
                )

                print(f"Data loaders created successfully")
                print(f"Train: {len(train_loader.dataset)}, Val: {len(val_loader.dataset)}, Test: {len(test_loader.dataset)}")

            except NotImplementedError:
                print("ERROR: Please implement DataManager.create_data_loaders() with your actual data loader creation")
                continue
            except Exception as e:
                print(f"ERROR creating data loaders: {e}")
                continue

            # Validate data loaders
            if not train_loader or len(train_loader.dataset) == 0:
                print(f"Skipping {model_name}: No training data available")
                continue
            if not val_loader or len(val_loader.dataset) == 0:
                print(f"Skipping {model_name}: No validation data available")
                continue

            # Create model
            print(f"\nCreating model: {model_name}")
            try:
                model = ModelFactory.create_model(
                    model_name, num_classes=Config.NUM_CLASSES,
                    dropout_rate=best_params.get('dropout', 0.5),
                    hidden_dim_multiplier=best_params.get('hidden_dim_multiplier', 0.5)
                ).to(Config.DEVICE, memory_format=torch.channels_last)

                total_params = sum(p.numel() for p in model.parameters())
                trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
                print(f"Model created: {total_params:,} total params, {trainable_params:,} trainable")

            except NotImplementedError:
                print("ERROR: Please implement ModelFactory.create_model() with your actual model creation")
                continue
            except Exception as e:
                print(f"ERROR creating model: {e}")
                continue

            # Train main model
            print(f"\nStarting main model training for {model_name}")
            trainer = EnhancedModelTrainer(model, model_name, best_params)
            training_success = trainer.train_main_model(train_loader, val_loader, test_loader)

            if not training_success:
                print(f"Main training failed for {model_name}, skipping...")
                continue

            print(f"Main model training completed for {model_name}")

            # K-fold cross-validation
            print(f"\nStarting K-fold cross-validation for {model_name}")
            total_samples = len(train_loader.dataset)
            min_samples_per_fold = 500
            max_folds = total_samples // min_samples_per_fold
            n_folds = min(3, max_folds) if max_folds > 1 else 0

            if n_folds > 1:
                kfold_success = trainer.train_kfold(train_loader, val_loader, test_loader, n_folds=n_folds)
                if kfold_success:
                    print(f"K-fold validation completed for {model_name}")
                else:
                    print(f"K-fold validation failed for {model_name}")
            else:
                print(f"Skipping k-fold validation for {model_name}: insufficient data")

            # Save model for ensemble
            model_ensemble_path = f"{Config.OUTPUT_DIR}/models/{model_name}_for_ensemble.pt"
            torch.save(model.state_dict(), model_ensemble_path)
            print(f"Model saved for ensemble: {model_name}")

            # Cleanup
            trainer.cleanup_trainer()
            del trainer, model, train_loader, val_loader, test_loader
            resource_manager.aggressive_cleanup()

            print(f"‚úÖ {model_name} TRAINING COMPLETED!")

        except Exception as e:
            print(f"Error processing {model_name}: {e}")
            import traceback
            traceback.print_exc()

            # Emergency cleanup
            try:
                if 'trainer' in locals():
                    trainer.cleanup_trainer()
                    del trainer
                if 'model' in locals():
                    del model
                resource_manager.aggressive_cleanup()
            except:
                pass
            continue

    # Final cleanup and summary
    print("\nFinal cleanup and summary...")
    try:
        if 'X' in locals():
            del X
        if 'Y' in locals():
            del Y
        if 'all_best_params' in locals():
            del all_best_params

        resource_manager.aggressive_cleanup()
        final_stats = resource_manager.get_memory_stats()
        print(f"Final GPU memory: {final_stats['gpu_allocated_gb']:.2f}GB ({final_stats['gpu_percent']:.1f}%)")
        print(f"Final CPU usage: {final_stats['cpu_percent']:.1f}%")

    except Exception as e:
        print(f"Error in final cleanup: {e}")

    print("\n" + "="*70)
    print("MODEL TRAINING COMPLETED!")
    print("="*70)
    print("\nGenerated Files:")
    print(f"- Model checkpoints: {Config.OUTPUT_DIR}/models/")
    print(f"- Best models: {Config.OUTPUT_DIR}/best_model/")
    print(f"- Training results: {Config.OUTPUT_DIR}/training_results/")
    print(f"- K-fold results: {Config.OUTPUT_DIR}/kfold_results/")
    print(f"- K-fold models: {Config.OUTPUT_DIR}/kfold_models/")

if __name__ == "__main__":
    main()


Starting Fish Species Model Training...
GPU: NVIDIA L4, Memory: 22.2GB
CPU Cores: 12, Using threads: 12

Loading and balancing data...
Loading and preprocessing data...
Applying SMOTE for class balancing...
Balanced data shape: (15000, 3, 224, 224)
Balanced class distribution: [3000 3000 3000 3000 3000]
Total samples after balancing: 15,000, Labels: 15,000
No hyperparameters found, using default parameters

TRAINING MODEL: resnet50
Using default parameters for resnet50

RESNET50 TRAINING PARAMETERS:
  dropout: 0.102891
  label_smoothing: 0.037148
  lr: 0.004156
  optimizer_type: adamw
  scheduler_type: cosine
  batch_size: 64
  weight_decay: 0.000002

Creating data loaders for resnet50...
Train: 9000, Val: 3000, Test: 3000
Using optimized batch size: 64
Data loaders created successfully
Train: 9000, Val: 3000, Test: 3000

Creating model: resnet50
Model created: 25,613,381 total params, 24,168,453 trainable

Starting main model training for resnet50

Training resnet50
Training samples:

100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20.5M/20.5M [00:00<00:00, 144MB/s]


Model created: 4,831,873 total params, 824,325 trainable

Starting main model training for efficientnet_b0

Training efficientnet_b0
Training samples: 9,000
Validation samples: 3,000
Total epochs: 40
Batch size: 64

Epoch 1/40
Training: 9,000 samples, 141 batches, batch_size: 64
  [Train] Batch   50 - Loss: 1.3198, Acc: 0.5625, Time: 39.4s
  [Train] Batch  100 - Loss: 1.3966, Acc: 0.5469, Time: 51.4s
Validation: 3,000 samples, 47 batches, batch_size: 64

Epoch 1/40 Complete üåü NEW BEST!
  Train - Loss: 1.1945, Acc: 0.5756
  Val   - Loss: 0.5008, Acc: 0.8757, F1: 0.8737
  LR: 0.004149
  Epoch Time: 120.1s, Total: 120.1s
‚úÖ Best model saved: F1=0.8737, Acc=0.8757

Epoch 2/40
Training: 9,000 samples, 141 batches, batch_size: 64
  [Train] Batch   50 - Loss: 1.1365, Acc: 0.5469, Time: 13.3s
  [Train] Batch  100 - Loss: 1.1099, Acc: 0.6094, Time: 26.1s
Validation: 3,000 samples, 47 batches, batch_size: 64

Epoch 2/40 Complete 
  Train - Loss: 1.0558, Acc: 0.6268
  Val   - Loss: 0.4990, Ac

# ‚úÖ Step 10: Evaluation and Plot's

# üåç Real-World Data Test üîéüìä


In [None]:
# Add this import at the top
# from your_model_module import ModelFactory  # Replace with your actual import

# Simple Image Predictor for Google Colab
import torch
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
import requests
from io import BytesIO
from google.colab import files
import os

# Setup - Replace these with your values
CLASS_NAMES = ['Class1', 'Class2', 'Class3', 'Class4', 'Class5']  # Replace with your 5 classes
MODEL_PATH = "/content/output/best_model/your_model_name_best.pt"  # Replace with your model path

# Image preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load model
def load_model():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Load checkpoint
    checkpoint = torch.load(MODEL_PATH, map_location=device, weights_only=False)

    # Get model info from checkpoint
    model_name = checkpoint['model_name']
    hyperparameters = checkpoint.get('hyperparameters', {})

    # Create model using your ModelFactory
    model = ModelFactory.create_model(
        model_name,
        params=hyperparameters,
        num_classes=len(CLASS_NAMES)
    )

    # Load the trained weights
    model.load_state_dict(checkpoint['state_dict'])
    model.to(device)

    print(f"Loaded model: {model_name}")
    return model, device

# Predict function
def predict_image(image_path_or_pil):
    model, device = load_model()
    if model is None:
        return

    # Load and preprocess image
    if isinstance(image_path_or_pil, str):
        image = Image.open(image_path_or_pil)
    else:
        image = image_path_or_pil

    if image.mode != 'RGB':
        image = image.convert('RGB')

    input_tensor = transform(image).unsqueeze(0).to(device)

    # Make prediction
    model.eval()
    with torch.no_grad():
        outputs = model(input_tensor)
        probabilities = torch.softmax(outputs, dim=1)
        confidence, predicted_idx = torch.max(probabilities, 1)

    predicted_class = CLASS_NAMES[predicted_idx.item()]
    confidence_score = confidence.item()

    # Show results
    plt.figure(figsize=(10, 4))

    plt.subplot(1, 2, 1)
    plt.imshow(image)
    plt.title('Input Image')
    plt.axis('off')

    plt.subplot(1, 2, 2)
    probs = probabilities[0].cpu().numpy()
    colors = ['green' if i == predicted_idx else 'skyblue' for i in range(len(CLASS_NAMES))]
    plt.bar(CLASS_NAMES, probs, color=colors)
    plt.title('Predictions')
    plt.xticks(rotation=45)

    plt.tight_layout()
    plt.show()

    print(f"üéØ Predicted: {predicted_class}")
    print(f"üìä Confidence: {confidence_score:.2%}")

    return predicted_class, confidence_score

# Predict from URL
def predict_from_url(image_url):
    try:
        response = requests.get(image_url)
        image = Image.open(BytesIO(response.content))
        return predict_image(image)
    except Exception as e:
        print(f"Error: {e}")

# Predict from uploaded file
def predict_uploaded():
    print("Upload an image file:")
    uploaded = files.upload()

    if uploaded:
        filename = list(uploaded.keys())[0]
        result = predict_image(filename)
        os.remove(filename)  # Clean up
        return result
    else:
        print("No file uploaded")

# Example usage:
print("üöÄ Simple Image Predictor Ready!")
print("\nHow to use:")
print("1. First, replace MODEL_PATH and CLASS_NAMES above")
print("2. Fix the load_model() function with your actual model")
print("3. Then run:")
print("   predict_uploaded()  # To upload image")
print("   predict_from_url('http://example.com/image.jpg')  # To predict from URL")


# import torch
# import torch.nn as nn
# from torchvision import transforms, models
# from PIL import Image
# import requests
# from io import BytesIO

# # ---------------------------
# # CONFIG
# # ---------------------------
# MODEL_PATH = "best_model.pth"   # ‡¶§‡ßã‡¶Æ‡¶æ‡¶∞ trained model file
# DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# # ‡¶§‡ßã‡¶Æ‡¶æ‡¶∞ dataset ‡¶è‡¶∞ class ‡¶≤‡¶ø‡¶∏‡ßç‡¶ü (‡¶®‡¶ø‡¶ú‡ßá‡¶∞ dataset ‡¶Ö‡¶®‡ßÅ‡¶Ø‡¶æ‡ßü‡ßÄ ‡¶¨‡¶¶‡¶≤‡¶æ‡¶¨‡ßá)
# CLASS_NAMES = ["ilish", "chandana", "sardin", "sardinella", "punctatus"]

# # ---------------------------
# # Load Model
# # ---------------------------
# def load_model():
#     model = models.resnet50(weights=None)   # ‡¶§‡ßÅ‡¶Æ‡¶ø ‡¶Ø‡ßá‡¶ü‡¶æ ‡¶¨‡ßç‡¶Ø‡¶¨‡¶π‡¶æ‡¶∞ ‡¶ï‡¶∞‡ßá‡¶õ‡ßã ‡¶∏‡ßá‡¶ü‡¶æ ‡¶¨‡¶∏‡¶æ‡¶ì
#     num_features = model.fc.in_features
#     model.fc = nn.Linear(num_features, len(CLASS_NAMES))

#     checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
#     model.load_state_dict(checkpoint["model_state_dict"])  # ‡¶§‡ßã‡¶Æ‡¶æ‡¶∞ save format ‡¶Ö‡¶®‡ßÅ‡¶Ø‡¶æ‡ßü‡ßÄ adjust ‡¶ï‡¶∞‡ßã
#     model.to(DEVICE)
#     model.eval()
#     return model

# # ---------------------------
# # Preprocess
# # ---------------------------
# transform = transforms.Compose([
#     transforms.Resize((224, 224)),    # training ‡¶∏‡¶Æ‡ßü ‡¶Ø‡¶æ ‡¶¶‡¶ø‡ßü‡ßá‡¶õ‡ßã, ‡¶∏‡ßá‡¶ü‡¶æ ‡¶Æ‡ßá‡¶≤‡¶æ‡¶§‡ßá ‡¶π‡¶¨‡ßá
#     transforms.ToTensor(),
#     transforms.Normalize([0.485, 0.456, 0.406],
#                          [0.229, 0.224, 0.225])
# ])

# def load_image(img_path=None, img_url=None):
#     if img_path:
#         image = Image.open(img_path).convert("RGB")
#     elif img_url:
#         response = requests.get(img_url)
#         image = Image.open(BytesIO(response.content)).convert("RGB")
#     else:
#         raise ValueError("Provide either img_path or img_url")
#     return image

# # ---------------------------
# # Prediction
# # ---------------------------
# def predict_image(model, image):
#     img_tensor = transform(image).unsqueeze(0).to(DEVICE)
#     with torch.no_grad():
#         outputs = model(img_tensor)
#         probs = torch.softmax(outputs, dim=1)
#         conf, pred = torch.max(probs, 1)
#     return CLASS_NAMES[pred.item()], conf.item()

# # ---------------------------
# # Main
# # ---------------------------
# if __name__ == "__main__":
#     model = load_model()

#     # Option 1: ‡¶≤‡ßã‡¶ï‡¶æ‡¶≤ ‡¶´‡¶æ‡¶á‡¶≤ ‡¶•‡ßá‡¶ï‡ßá
#     img_path = "test_fish.jpg"
#     image = load_image(img_path=img_path)
#     label, confidence = predict_image(model, image)
#     print(f"Prediction: {label} ({confidence:.2f})")

#     # Option 2: URL ‡¶•‡ßá‡¶ï‡ßá
#     img_url = "https://example.com/sample_fish.jpg"
#     image = load_image(img_url=img_url)
#     label, confidence = predict_image(model, image)
#     print(f"Prediction: {label} ({confidence:.2f})")


#End