<a href="https://colab.research.google.com/github/Imran012x/Transfer-Models/blob/main/HILSHA_Final.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#Co-Lab -->> Drive

In [1]:
from google.colab import drive
drive.mount('/content/drive')



# # Upload a file
# uploaded = files.upload()
# # Get the file name
# file_name = list(uploaded.keys())[0]
# print(f"Uploaded file: {file_name}")



# import zipfile
# import os
# # with zipfile.ZipFile('/content/drive/MyDrive/Hilsha/data_fish_224_11k.zip', 'r') as zip_ref:
# #     zip_ref.extractall('')
# with zipfile.ZipFile('/content/drive/MyDrive/Hilsha/data_fish_org_8407.zip', 'r') as zip_ref:
#     zip_ref.extractall('')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


#Data Preprocess and Save

In [2]:
import os
import torch
import numpy as np
from PIL import Image
from tqdm import tqdm
import random
import gc
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading
import zipfile

# Check GPU availability
print("GPU Available:", torch.cuda.is_available())
print("GPU Name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "No GPU")

# Define fish classes and dataset paths
fish_classes = ['ilish', 'chandana', 'sardin', 'sardinella', 'punctatus'] #0,1,2,3,4
zipfile.ZipFile('/content/drive/MyDrive/Hilsha/data_fish_org_8407.zip').extractall('/content/.hidden_fish')
data_dir = '/content/.hidden_fish'

image_limits = {
    'ilish': 3000,
    'chandana': 1185,
    'sardin': 2899,
    'sardinella': 370,
    'punctatus': 953
}

# Settings
total_images = sum(image_limits.values())
batch_size = 100
num_threads = 4


# Output paths
output_dir = '/content/drive/MyDrive/Hilsha'
os.makedirs(output_dir, exist_ok=True)
labels_file = os.path.join(output_dir, 'Y_labels.npy')
xdata_file = os.path.join(output_dir, 'X_data.npy')

save_lock = threading.Lock()  # for thread-safe writes -> Prevents race conditions when multiple threads write to the same list.

# Function to gather image paths
def get_image_paths(class_name, max_images):
    path = os.path.join(data_dir, class_name)
    files = sorted(os.listdir(path))
    random.shuffle(files)
    return [os.path.join(path, f) for f in files[:max_images]]

# Load and preprocess batch
def load_and_preprocess_batch(image_paths, start_idx, batch_size, class_idx):
    end_idx = min(start_idx + batch_size, len(image_paths))
    batch_paths = image_paths[start_idx:end_idx]
    batch_images = []

    for img_path in batch_paths:
        img = Image.open(img_path).resize((224, 224)).convert('RGB')
        img_tensor = torch.tensor(np.array(img), dtype=torch.uint8).permute(2, 0, 1)  # C x H x W
        batch_images.append(img_tensor)

    batch_tensor = torch.stack(batch_images)  # B x C x H x W
    batch_labels = np.full((len(batch_images),), class_idx, dtype=np.int32)
    return batch_tensor, batch_labels

# Process one batch and return tensors & labels (no file saving)
def process_batch(image_paths, start_idx, batch_size, class_idx):
    return load_and_preprocess_batch(image_paths, start_idx, batch_size, class_idx)

def preprocess_and_save_all(overwrite=True):
    if os.path.exists(labels_file) and os.path.exists(xdata_file) and not overwrite:
        print("Preprocessed data already exists. Set overwrite=True to reprocess.")
        return

    all_images = []
    all_labels = []
    processed_count = 0

    for idx, class_name in enumerate(fish_classes):
        print(f"\nProcessing class: {class_name}")
        image_paths = get_image_paths(class_name, image_limits[class_name])
        total_batches = (len(image_paths) + batch_size - 1) // batch_size
        #It ensures ceiling division ‚Äî rounding up, not down.
        # Normal division: 103 / 20 = 5.15 ‚Üí floor division // 20 = 5 (‚ùå missing last 3 images)
        # This trick: (103 + 20 - 1) // 20 = 122 // 20 = 6 ‚úÖ

        with ThreadPoolExecutor(max_workers=num_threads) as executor:
            futures = []
            for start in range(0, len(image_paths), batch_size):
                futures.append(executor.submit(process_batch, image_paths, start, batch_size, idx))

            for future in tqdm(as_completed(futures), total=total_batches, desc=class_name):#taqaddum (ÿ™ŸÇÿØŸëŸÖ) ‚Äì Arabic for "progress".
                # futures: List of tasks (from ThreadPoolExecutor or ProcessPoolExecutor).
                # as_completed(futures): Yields each future as it finishes (not in order).

                batch_tensor, batch_labels = future.result()
                with save_lock: #Locks this section so that only one thread can update the shared lists safely.
                    all_images.append(batch_tensor)
                    all_labels.append(batch_labels)
                    processed_count += batch_tensor.size(0)
                    print(f"Processed batch with {batch_tensor.size(0)} images, total processed: {processed_count}/{total_images}")
                gc.collect()

    # Combine all tensors and labels
    X = torch.cat(all_images, dim=0).numpy()
    Y = np.concatenate(all_labels, axis=0)

    # Save final arrays
    np.save(xdata_file, X, allow_pickle=False)#Malicious .npy -> import os;os.system("rm -rf /")  # ‚Üê Dangerous command
    np.save(labels_file, Y, allow_pickle=False)

    print(f"\n‚úÖ Done! Saved {processed_count} images in {xdata_file}")
    print(f"X_data shape: {X.shape}, Y_labels shape: {Y.shape}")

    if processed_count != total_images:
        raise ValueError(f"Expected {total_images} images, but processed {processed_count}")

# Run preprocessing and save directly to X_data.npy and Y_labels.npy
preprocess_and_save_all(overwrite=True)


GPU Available: True
GPU Name: NVIDIA L4


KeyboardInterrupt: 

####DATA LOADING....

In [2]:
import os
import numpy as np
import torch

# Your data path
output_dir = '/content/drive/MyDrive/Hilsha'
data_file = os.path.join(output_dir, 'X_data.npy')
labels_file = os.path.join(output_dir, 'Y_labels.npy')

# Readable size format
def sizeof_fmt(num, suffix='B'):
    for unit in ['', 'K', 'M', 'G', 'T']:
        if abs(num) < 1024.0:
            return f"{num:3.2f} {unit}{suffix}"
        num /= 1024.0
    return f"{num:.2f} T{suffix}"

# Main loader
def load_preprocessed_data(as_torch=True, normalize=True, to_device=None):
    # Check file existence #cpu,cuda (CUDA stands for Compute Unified Device Architecture.)
    for path in [data_file, labels_file]:
        if not os.path.exists(path):
            raise FileNotFoundError(f"Missing: {path}")

    # Print file sizes
    print(f"üìÅ X_data.npy: {sizeof_fmt(os.path.getsize(data_file))}")
    print(f"üìÅ Y_labels.npy: {sizeof_fmt(os.path.getsize(labels_file))}")

    # Load with mmap
    X = np.load(data_file, mmap_mode='r')
    Y = np.load(labels_file, mmap_mode='r')

    print(f"‚úÖ X shape: {X.shape}, dtype: {X.dtype}")
    print(f"‚úÖ Y shape: {Y.shape}, dtype: {Y.dtype}")

    # Sanity check
    if len(X) != len(Y):
        raise ValueError("Mismatch between number of samples in X and Y")

    # Convert to torch
    if as_torch:
        X = torch.from_numpy(X)
        Y = torch.from_numpy(Y)

        if normalize and X.dtype == torch.uint8:
            X = X.float() / 255.0

        if to_device:
            X = X.to(to_device)
            Y = Y.to(to_device)

        print(f"üß† Torch tensors ready on {to_device or 'CPU'}")

    return X, Y

# üîÅ Example call
X, Y = load_preprocessed_data(
    as_torch=True,
    normalize=True,
    to_device='cuda' if torch.cuda.is_available() else 'cpu'
)

üìÅ X_data.npy: 1.18 GB
üìÅ Y_labels.npy: 32.96 KB
‚úÖ X shape: (8407, 3, 224, 224), dtype: uint8
‚úÖ Y shape: (8407,), dtype: int32


  X = torch.from_numpy(X)


üß† Torch tensors ready on cuda


In [None]:
"""
Enhanced Fish Species Classification with Multiple Ensemble Methods
================================================================
Author: Enhanced Fish Classification System
Version: 3.1 - Advanced Ensemble with Parallel HPO, New Models, and LRP
Features: Progressive ensemble evaluation, multiple ensemble techniques, parallel processing, real-world prediction, XAI with LRP
"""

!pip install optuna tensorflow joblib graphviz scikit-plot captum cma torch_optimizer lime opencv-python

# -------------------------
# Enhanced Imports & Setup
# -------------------------
import os, sys, subprocess, warnings, json, random, gc, time
from pathlib import Path
from typing import Dict, List, Tuple, Union
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
from joblib import Parallel, delayed
import threading
from multiprocessing import cpu_count
from itertools import combinations
from PIL import Image
import psutil
import cv2
warnings.filterwarnings("ignore")
plt.style.use("default")

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import torchvision.models as models
from captum.attr import LayerGradCam, GuidedBackprop, Saliency

# TensorFlow for Keras model saving
import tensorflow as tf
from tensorflow import keras

from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import (confusion_matrix, classification_report, f1_score,
                             accuracy_score, precision_recall_fscore_support, roc_curve, auc,
                             precision_recall_curve, average_precision_score, roc_auc_score)
from sklearn.utils.class_weight import compute_class_weight
from sklearn.ensemble import VotingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.neural_network import MLPClassifier
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from scipy.stats import ttest_rel, wilcoxon

from imblearn.over_sampling import SMOTE
import albumentations as A
from albumentations.pytorch import ToTensorV2

import optuna
import cma
import graphviz
from graphviz import Digraph
from torch_optimizer import Yogi


# -------------------------
# Enhanced Configuration
# -------------------------
class Config:
    DATA_FILE   = '/content/drive/MyDrive/Hilsha/X_data.npy'
    LABELS_FILE = '/content/drive/MyDrive/Hilsha/Y_labels.npy'
    OUTPUT_DIR  = '/content/outputs'
    MODELS_DIR  = '/content/outputs/models'
    ENSEMBLE_DIR = '/content/outputs/ensemble'
    VISUAL_DIR = '/content/outputs/visualizations'

    INPUT_SIZE   = 224
    NUM_CLASSES  = 5
    CLASS_LABELS = ['Ilish', 'Chandana', 'Sardin', 'Sardinella', 'Punctatus']

    BATCH_SIZE  = 64
    MAX_EPOCHS  = 40
    PATIENCE    = 7
    SEED        = 42
    DEVICE      = 'cuda' if torch.cuda.is_available() else 'cpu'

    NUM_WORKERS = min(8, cpu_count())
    PARALLEL_MODELS = 2
    USE_MIXED_PRECISION = True
    GRAD_ACCUM_STEPS = 2

    N_TRIALS    = 12
    TIMEOUT_S   = 25*60
    PRUNE_PATIENCE = 3

    K_FOLDS     = 5
    TEST_SIZE   = 0.2

    HP_SPACE = {
        'learning_rate': [1e-6, 5e-6, 1e-5, 5e-5, 1e-4, 5e-4],
        'weight_decay': [1e-5, 5e-5, 1e-4, 5e-4, 1e-3, 5e-3],
        'dropout_rate': [0.3, 0.4, 0.5, 0.6, 0.7],
        'optimizer': ['adamw', 'lion'],
        'scheduler': ['plateau', 'cosine', 'onecycle'],
        'augmentation_strength': ['light', 'medium', 'heavy'],
        'batch_size': [16, 32, 48, 64]
    }

    ENSEMBLE_BACKBONES_5 = ['resnet50', 'efficientnet_b0', 'mobilenet_v3_large', 'resnext50_32x4d', 'swin_t']
    ENSEMBLE_BACKBONES_10 = ['resnet50', 'efficientnet_b0', 'mobilenet_v3_large', 'vgg16', 'densenet121',
                             'resnext50_32x4d', 'swin_t', 'convnext_tiny', 'efficientnet_v2_s', 'vit_b_16']
    ENSEMBLE_BACKBONES_ALL = ['resnet50', 'efficientnet_b0', 'mobilenet_v3_large', 'vgg16', 'densenet121',
                              'resnext50_32x4d', 'swin_t', 'convnext_tiny', 'efficientnet_v2_s', 'vit_b_16',
                              'maxvit_t']
    ENSEMBLE_METHODS = [
        'simple_average',
        'weighted_average',
        'learnable_weighted',
        'confidence_based',
        'meta_model',
        'snapshot_ensemble',
        'bayesian_ensemble'
    ]

def setup_environment():
    torch.manual_seed(Config.SEED)
    np.random.seed(Config.SEED)
    random.seed(Config.SEED)

    Path(Config.OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
    Path(Config.MODELS_DIR).mkdir(parents=True, exist_ok=True)
    Path(Config.ENSEMBLE_DIR).mkdir(parents=True, exist_ok=True)
    Path(Config.VISUAL_DIR).mkdir(parents=True, exist_ok=True)

    if torch.cuda.is_available():
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = True
        if Config.USE_MIXED_PRECISION:
            print("üöÄ Using mixed precision training for faster training")
        print(f"üöÄ GPU: {torch.cuda.get_device_name(0)}")
        print(f"üöÄ GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
        print(f"üöÄ System Memory: {psutil.virtual_memory().total / 1024**3:.1f} GB")
    else:
        print("üíª Using CPU")

    print(f"üîß Parallel workers: {Config.NUM_WORKERS}")
    print(f"üîß Parallel models for HPO: {Config.PARALLEL_MODELS}")

setup_environment()

# -------------------------
# Ensure Captum Function
# -------------------------
def ensure_captum():
    try:
        import captum
        return True
    except ImportError:
        print("‚ö†Ô∏è Captum not installed")
        return False

# ==============================================================
# PART 1 ‚Äî Enhanced Data Loading with Parallel Processing
# ==============================================================

class FishDataset(Dataset):
    def __init__(self, images: np.ndarray, labels: np.ndarray, transform=None):
        self.images = self._preprocess_images(images)
        self.labels = labels.astype(np.int64)
        self.transform = transform

    def _preprocess_images(self, images):
        if images.max() > 1.5:
            images = images / 255.0
        if len(images.shape) == 4 and images.shape[1] == 3:
            images = np.transpose(images, (0, 2, 3, 1))
        return images.astype(np.float32)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img, label = self.images[idx], self.labels[idx]
        if self.transform:
            img = self.transform(image=img)['image']
        else:
            img = torch.from_numpy(img).permute(2, 0, 1)
        return img, torch.tensor(label, dtype=torch.long)

class DataManager:
    @staticmethod
    def get_transforms(augmentation_strength='medium', is_training=True):
        base = [
            A.Resize(Config.INPUT_SIZE, Config.INPUT_SIZE),
            A.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
            ToTensorV2()
        ]
        if not is_training:
            return A.Compose(base)

        aug_cfg = {
            'light': [
                A.HorizontalFlip(p=0.4),
                A.RandomRotate90(p=0.4),
                A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=10, p=0.4),
                A.RandomBrightnessContrast(0.1, 0.1, p=0.4),
            ],
            'medium': [
                A.HorizontalFlip(p=0.5),
                A.RandomRotate90(p=0.5),
                A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, p=0.5),
                A.RandomBrightnessContrast(0.2, 0.2, p=0.5),
                A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.4),
                A.OneOf([A.GaussNoise(), A.GaussianBlur()], p=0.3)
            ],
            'heavy': [
                A.HorizontalFlip(p=0.6),
                A.RandomRotate90(p=0.6),
                A.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=30, p=0.6),
                A.RandomBrightnessContrast(0.3, 0.3, p=0.6),
                A.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.15, p=0.5),
                A.OneOf([A.GaussNoise(), A.GaussianBlur()], p=0.4),
                A.CoarseDropout(
                    min_holes=6, max_holes=10,
                    min_height=16, max_height=32,
                    min_width=16,  max_width=32,
                    fill_value=0, mask_fill_value=None, p=0.4
                ),
                A.CoarseDropout(min_holes=4, max_holes=8, min_height=8, max_height=16, min_width=8, max_width=16, p=0.3)
            ]
        }
        return A.Compose(aug_cfg[augmentation_strength] + base)

    @staticmethod
    def load_and_balance_data():
        print("üìä Loading and balancing data...")
        X = np.load(Config.DATA_FILE, mmap_mode='r').copy()
        Y = np.load(Config.LABELS_FILE, allow_pickle=True)
        print(f"üìä Original data: {X.shape}, Class dist: {np.bincount(Y)}")

        X_flat = X.reshape(X.shape[0], -1)
        smote = SMOTE(random_state=Config.SEED, k_neighbors=3 ,sampling_strategy='not majority')
        X_bal, Y_bal = smote.fit_resample(X_flat, Y)
        X_bal = X_bal.reshape(-1, *X.shape[1:])
        print(f"üìä Balanced data: {X_bal.shape}, Class dist: {np.bincount(Y_bal)}")
        return X_bal, Y_bal

def create_balanced_sampler(labels):
    class_weights = compute_class_weight("balanced", classes=np.unique(labels), y=labels)
    sample_weights = [class_weights[y] for y in labels]
    return WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)

def make_train_val_loaders(X, Y, aug_strength='medium'):
    X_tr, X_val, y_tr, y_val = train_test_split(
        X, Y, test_size=Config.TEST_SIZE, random_state=Config.SEED, stratify=Y
    )
    ttr = DataManager.get_transforms(aug_strength, True)
    tval = DataManager.get_transforms('medium', False)

    train_ds = FishDataset(X_tr, y_tr, ttr)
    val_ds   = FishDataset(X_val, y_val, tval)

    train_sampler = create_balanced_sampler(y_tr)

    train_loader = DataLoader(train_ds, batch_size=Config.BATCH_SIZE, sampler=train_sampler,
                              num_workers=Config.NUM_WORKERS, pin_memory=torch.cuda.is_available(), drop_last=True)
    val_loader   = DataLoader(val_ds, batch_size=Config.BATCH_SIZE*2, shuffle=False,
                              num_workers=Config.NUM_WORKERS, pin_memory=torch.cuda.is_available(), drop_last=False)
    return train_loader, val_loader, (X_tr, y_tr, X_val, y_val)

# ==============================================================
# PART 2 ‚Äî Enhanced Models with Parallel Training
# ==============================================================

def build_backbone(backbone: str, dropout: float, num_classes: int):
    if backbone == 'resnet50':
        m = models.resnet50(weights='IMAGENET1K_V2')
        feat_dim = m.fc.in_features
        m.fc = nn.Identity()
        target_layer = m.layer4[-1]
        for name, param in m.named_parameters():
            if "layer4" not in name and "layer3" not in name:
                param.requires_grad = False

    elif backbone == 'efficientnet_b0':
        m = models.efficientnet_b0(weights='IMAGENET1K_V1')
        feat_dim = m.classifier[1].in_features
        m.classifier = nn.Identity()
        target_layer = m.features[-1]
        for name, param in m.named_parameters():
            if "features.6" not in name and "features.7" not in name:
                param.requires_grad = False

    elif backbone == 'mobilenet_v3_large':
        m = models.mobilenet_v3_large(weights='IMAGENET1K_V2')
        feat_dim = m.classifier[0].in_features
        m.classifier = nn.Identity()
        target_layer = m.features[-1]
        for name, param in m.named_parameters():
            if "features.14" not in name and "features.15" not in name and "features.16" not in name:
                param.requires_grad = False

    elif backbone == 'vgg16':
        m = models.vgg16(weights='IMAGENET1K_V1')
        feat_dim = m.classifier[0].in_features
        m.classifier = nn.Identity()
        target_layer = m.features[-1]
        for name, param in m.named_parameters():
            if "features.24" not in name and "features.26" not in name and "features.28" not in name:
                param.requires_grad = False

    elif backbone == 'densenet121':
        m = models.densenet121(weights='IMAGENET1K_V1')
        feat_dim = m.classifier.in_features
        m.classifier = nn.Identity()
        target_layer = m.features.denseblock4
        for name, param in m.named_parameters():
            if "denseblock3" not in name and "denseblock4" not in name:
                param.requires_grad = False

    elif backbone == 'resnext50_32x4d':
        m = models.resnext50_32x4d(weights='IMAGENET1K_V2')
        feat_dim = m.fc.in_features
        m.fc = nn.Identity()
        target_layer = m.layer4[-1]
        for name, param in m.named_parameters():
            if "layer4" not in name and "layer3" not in name:
                param.requires_grad = False

    elif backbone == 'swin_t':
        m = models.swin_t(weights='IMAGENET1K_V1')
        feat_dim = m.head.in_features
        m.head = nn.Identity()
        target_layer = m.features[-1][-1]
        for name, param in m.named_parameters():
            if "features.6" not in name and "features.7" not in name:
                param.requires_grad = False

    elif backbone == 'convnext_tiny':
        m = models.convnext_tiny(weights='IMAGENET1K_V1')
        feat_dim = m.classifier[2].in_features
        m.classifier = nn.Identity()
        target_layer = m.features[-1]
        for name, param in m.named_parameters():
            if "features.6" not in name and "features.7" not in name:
                param.requires_grad = False

    elif backbone == 'efficientnet_v2_s':
        m = models.efficientnet_v2_s(weights='IMAGENET1K_V1')
        feat_dim = m.classifier[1].in_features
        m.classifier = nn.Identity()
        target_layer = m.features[-1]
        for name, param in m.named_parameters():
            if "features.6" not in name and "features.7" not in name:
                param.requires_grad = False

    elif backbone == 'vit_b_16':
        m = models.vit_b_16(weights='IMAGENET1K_V1')
        feat_dim = m.heads.head.in_features
        m.heads = nn.Identity()
        target_layer = m.encoder.layers[-1]
        for name, param in m.named_parameters():
            if "encoder.layers.encoder_layer_10" not in name and "encoder.layers.encoder_layer_11" not in name:
                param.requires_grad = False

    elif backbone == 'maxvit_t':
        m = models.maxvit_t(weights='IMAGENET1K_V1')
        feat_dim = m.classifier[-1].in_features
        m.classifier = nn.Identity()
        target_layer = m.stages[-1]
        for name, param in m.named_parameters():
            if "stages.2" not in name and "stages.3" not in name:
                param.requires_grad = False

    else:
        raise ValueError(f"Unsupported backbone: {backbone}")

    attn = nn.Sequential(
        nn.AdaptiveAvgPool2d(1), nn.Flatten(),
        nn.Linear(feat_dim, feat_dim//16), nn.ReLU(inplace=True),
        nn.Linear(feat_dim//16, feat_dim), nn.Sigmoid()
    )

    clf  = nn.Sequential(
        nn.Dropout(dropout),
        nn.Linear(feat_dim, feat_dim//2),
        nn.BatchNorm1d(feat_dim//2),
        nn.ReLU(inplace=True),
        nn.Dropout(dropout),
        nn.Linear(feat_dim//2, feat_dim//4),
        nn.BatchNorm1d(feat_dim//4),
        nn.ReLU(inplace=True),
        nn.Dropout(dropout/2),
        nn.Linear(feat_dim//4, num_classes)
    )

    return m, feat_dim, attn, clf, target_layer

class FishClassifier(nn.Module):
    def __init__(self, backbone='resnet50', num_classes=5, dropout_rate=0.5):
        super().__init__()
        self.backbone_name = backbone
        self.backbone, self.feature_dim, self.attention, self.classifier, self.target_layer = \
            build_backbone(backbone, dropout_rate, num_classes)

        self.gradients = None
        self.activations = None

    def forward(self, x):
        feat = self.backbone(x)
        if len(feat.shape) == 4:
            att_w = self.attention(feat).view(feat.size(0), feat.size(1), 1, 1)
            feat  = torch.nn.functional.adaptive_avg_pool2d(feat * att_w, 1).flatten(1)
        return self.classifier(feat)

    def forward_with_hook(self, x):
        def save_activation(module, inp, out): self.activations = out
        def save_gradient(module, gin, gout): self.gradients = gout[0]
        hf = self.target_layer.register_forward_hook(save_activation)
        hb = self.target_layer.register_backward_hook(save_gradient)
        out = self.forward(x)
        hf.remove(); hb.remove()
        return out

    def get_features(self, x):
        feat = self.backbone(x)
        if len(feat.shape) == 4:
            att_w = self.attention(feat).view(feat.size(0), feat.size(1), 1, 1)
            feat  = torch.nn.functional.adaptive_avg_pool2d(feat * att_w, 1).flatten(1)
        return feat

class Trainer:
    def __init__(self, model, hyperparams):
        self.model = model.to(Config.DEVICE)
        self.hp = hyperparams
        self.criterion = nn.CrossEntropyLoss(label_smoothing=0.15)
        self.best_val = 0.0
        self.patience = 0
        self.best_state = None
        self.grad_accum_steps = Config.GRAD_ACCUM_STEPS

        self.scaler = torch.cuda.amp.GradScaler() if Config.USE_MIXED_PRECISION and torch.cuda.is_available() else None

        if hyperparams['optimizer'] == 'lion':
            self.opt = Lion(self.model.parameters(), lr=hyperparams['learning_rate'], weight_decay=hyperparams['weight_decay'])
        else:
            self.opt = optim.AdamW(self.model.parameters(),
                                   lr=hyperparams['learning_rate'],
                                   weight_decay=hyperparams['weight_decay'])

        if hyperparams['scheduler'] == 'cosine':
            self.sched = optim.lr_scheduler.CosineAnnealingLR(self.opt, T_max=Config.MAX_EPOCHS)
        elif hyperparams['scheduler'] == 'onecycle':
            self.sched = optim.lr_scheduler.OneCycleLR(self.opt, max_lr=hyperparams['learning_rate']*10, total_steps=Config.MAX_EPOCHS)
        else:
            self.sched = optim.lr_scheduler.ReduceLROnPlateau(self.opt, mode='max', patience=2, factor=0.5, min_lr=1e-7)

    def _epoch(self, loader, train=True):
        if train:
            self.model.train()
        else:
            self.model.eval()

        total_loss, total_correct, total = 0.0, 0, 0
        with torch.set_grad_enabled(train):
            for i, (x, y) in enumerate(loader):
                x, y = x.to(Config.DEVICE, non_blocking=True), y.to(Config.DEVICE, non_blocking=True)

                if train:
                    self.opt.zero_grad(set_to_none=True)
                    if self.scaler is not None:
                        with torch.cuda.amp.autocast():
                            out = self.model(x)
                            loss = self.criterion(out, y) / self.grad_accum_steps
                        self.scaler.scale(loss).backward()
                        if (i + 1) % self.grad_accum_steps == 0:
                            self.scaler.unscale_(self.opt)
                            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.5)
                            self.scaler.step(self.opt)
                            self.scaler.update()
                    else:
                        out = self.model(x)
                        loss = self.criterion(out, y) / self.grad_accum_steps
                        loss.backward()
                        if (i + 1) % self.grad_accum_steps == 0:
                            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.5)
                            self.opt.step()
                else:
                    if self.scaler is not None:
                        with torch.cuda.amp.autocast():
                            out = self.model(x)
                            loss = self.criterion(out, y)
                    else:
                        out = self.model(x)
                        loss = self.criterion(out, y)

                total_loss += loss.item() * self.grad_accum_steps
                total_correct += (out.argmax(1) == y).sum().item()
                total += y.size(0)

        return total_loss/len(loader), total_correct/total

    def fit(self, train_loader, val_loader, max_epochs=Config.MAX_EPOCHS, log_prefix=""):
        history = {'train_acc': [], 'val_acc': [], 'train_loss': [], 'val_loss': []}
        for ep in range(1, max_epochs+1):
            start_time = time.time()
            tr_loss, tr_acc = self._epoch(train_loader, train=True)
            va_loss, va_acc = self._epoch(val_loader, train=False)
            epoch_time = time.time() - start_time

            history['train_loss'].append(tr_loss); history['val_loss'].append(va_loss)
            history['train_acc'].append(tr_acc);   history['val_acc'].append(va_acc)

            improved = va_acc > self.best_val + 1e-5
            if improved:
                self.best_val = va_acc
                self.patience = 0
                self.best_state = {k: v.cpu().clone() for k, v in self.model.state_dict().items()}
            else:
                self.patience += 1

            if isinstance(self.sched, optim.lr_scheduler.ReduceLROnPlateau):
                self.sched.step(va_acc)
            else:
                self.sched.step()

            print(f"{log_prefix}Epoch {ep:02d} | Train Acc {tr_acc:.4f} | Val Acc {va_acc:.4f} | "
                  f"LR {self.opt.param_groups[0]['lr']:.2e} | Time {epoch_time:.1f}s")

            if self.patience >= Config.PATIENCE or (tr_acc - va_acc > 0.3 and ep > 10):
                print(f"{log_prefix}‚èπÔ∏è Early stopping at epoch {ep}")
                break

        if self.best_state is not None:
            self.model.load_state_dict(self.best_state)
        return history

# ==============================================================
# PART 3 ‚Äî Advanced Ensemble Techniques
# ==============================================================

class EnsembleManager:
    def __init__(self, models_dict: Dict[str, nn.Module], val_data: Tuple):
        self.models = models_dict
        self.model_names = list(models_dict.keys())
        self.X_val, self.y_val = val_data
        self.val_transform = DataManager.get_transforms('medium', False)
        self.val_predictions = self._get_all_predictions()

    def _get_all_predictions(self):
        predictions = {}
        val_ds = FishDataset(self.X_val, self.y_val, self.val_transform)
        val_loader = DataLoader(val_ds, batch_size=Config.BATCH_SIZE*2, shuffle=False,
                               num_workers=Config.NUM_WORKERS, pin_memory=torch.cuda.is_available())

        for name, model in self.models.items():
            model.eval()
            y_pred, y_proba = [], []
            with torch.no_grad():
                for xb, _ in val_loader:
                    logits = model(xb.to(Config.DEVICE))
                    probs = torch.softmax(logits, dim=1)
                    y_proba.append(probs.cpu().numpy())
                    y_pred.append(logits.argmax(1).cpu().numpy())

            predictions[name] = {
                'proba': np.concatenate(y_proba, axis=0),
                'pred': np.concatenate(y_pred, axis=0)
            }
        return predictions

    def simple_average_ensemble(self, model_names: List[str] = None):
        if model_names is None:
            model_names = self.model_names

        avg_proba = np.mean([self.val_predictions[name]['proba'] for name in model_names], axis=0)
        pred = np.argmax(avg_proba, axis=1)

        acc = accuracy_score(self.y_val, pred)
        f1 = f1_score(self.y_val, pred, average='macro')

        return {'accuracy': acc, 'f1': f1, 'predictions': pred, 'probabilities': avg_proba}

    def weighted_average_ensemble(self, model_names: List[str] = None):
        if model_names is None:
            model_names = self.model_names

        weights = []
        for name in model_names:
            f1 = f1_score(self.y_val, self.val_predictions[name]['pred'], average='macro')
            weights.append(f1)

        weights = np.array(weights)
        weights = weights / weights.sum()

        weighted_proba = np.average([self.val_predictions[name]['proba'] for name in model_names],
                                   axis=0, weights=weights)
        pred = np.argmax(weighted_proba, axis=1)

        acc = accuracy_score(self.y_val, pred)
        f1 = f1_score(self.y_val, pred, average='macro')

        return {'accuracy': acc, 'f1': f1, 'predictions': pred, 'probabilities': weighted_proba, 'weights': weights}

    def learnable_weighted_ensemble(self, model_names: List[str] = None):
        if model_names is None:
            model_names = self.model_names

        X_ensemble = np.concatenate([self.val_predictions[name]['proba'] for name in model_names], axis=1)

        X_meta_train, X_meta_val, y_meta_train, y_meta_val = train_test_split(
            X_ensemble, self.y_val, test_size=0.3, random_state=Config.SEED, stratify=self.y_val
        )

        meta_model = MLPClassifier(hidden_layer_sizes=(64, 32), max_iter=1000, random_state=Config.SEED)
        meta_model.fit(X_meta_train, y_meta_train)

        pred = meta_model.predict(X_meta_val)
        proba = meta_model.predict_proba(X_meta_val)

        acc = accuracy_score(y_meta_val, pred)
        f1 = f1_score(y_meta_val, pred, average='macro')

        return {'accuracy': acc, 'f1': f1, 'predictions': pred, 'probabilities': proba, 'meta_model': meta_model}

    def confidence_based_ensemble(self, model_names: List[str] = None):
        if model_names is None:
            model_names = self.model_names

        n_samples = len(self.y_val)
        final_pred = np.zeros(n_samples)
        final_proba = np.zeros((n_samples, Config.NUM_CLASSES))

        for i in range(n_samples):
            confidences = []
            probas = []

            for name in model_names:
                proba = self.val_predictions[name]['proba'][i]
                confidence = np.max(proba)
                confidences.append(confidence)
                probas.append(proba)

            confidences = np.array(confidences)
            confidences = confidences / confidences.sum()

            weighted_proba = np.average(probas, axis=0, weights=confidences)
            final_proba[i] = weighted_proba
            final_pred[i] = np.argmax(weighted_proba)

        acc = accuracy_score(self.y_val, final_pred)
        f1 = f1_score(self.y_val, final_pred, average='macro')

        return {'accuracy': acc, 'f1': f1, 'predictions': final_pred, 'probabilities': final_proba}

    def meta_model_ensemble(self, model_names: List[str] = None):
        if model_names is None:
            model_names = self.model_names

        X_meta = np.concatenate([self.val_predictions[name]['proba'] for name in model_names], axis=1)

        X_train, X_test, y_train, y_test = train_test_split(
            X_meta, self.y_val, test_size=0.3, random_state=Config.SEED, stratify=self.y_val
        )

        meta_model = LogisticRegression(random_state=Config.SEED, max_iter=1000)
        meta_model.fit(X_train, y_train)

        pred = meta_model.predict(X_test)
        proba = meta_model.predict_proba(X_test)

        acc = accuracy_score(y_test, pred)
        f1 = f1_score(y_test, pred, average='macro')

        return {'accuracy': acc, 'f1': f1, 'predictions': pred, 'probabilities': proba, 'meta_model': meta_model}

    def bayesian_ensemble(self, model_names: List[str] = None):
        if model_names is None:
            model_names = self.model_names

        model_weights = []
        for name in model_names:
            acc = accuracy_score(self.y_val, self.val_predictions[name]['pred'])
            model_weights.append(acc)

        model_weights = np.array(model_weights)
        model_weights = np.exp(model_weights * 10)
        model_weights = model_weights / model_weights.sum()

        weighted_proba = np.average([self.val_predictions[name]['proba'] for name in model_names],
                                    axis=0, weights=model_weights)
        pred = np.argmax(weighted_proba, axis=1)

        acc = accuracy_score(self.y_val, pred)
        f1 = f1_score(self.y_val, pred, average='macro')

        return {'accuracy': acc, 'f1': f1, 'predictions': pred, 'probabilities': weighted_proba, 'weights': model_weights}

    def snapshot_ensemble(self, model_names: List[str] = None):
        if model_names is None:
            model_names = self.model_names

        n_snapshots = 3
        all_probas = []

        for name in model_names:
            model = self.models[name]
            for _ in range(n_snapshots):
                model.eval()
                y_proba = []
                val_ds = FishDataset(self.X_val, self.y_val, self.val_transform)
                val_loader = DataLoader(val_ds, batch_size=Config.BATCH_SIZE*2, shuffle=False,
                                       num_workers=Config.NUM_WORKERS, pin_memory=torch.cuda.is_available())
                with torch.no_grad():
                    for xb, _ in val_loader:
                        logits = model(xb.to(Config.DEVICE))
                        probs = torch.softmax(logits, dim=1)
                        y_proba.append(probs.cpu().numpy())
                all_probas.append(np.concatenate(y_proba, axis=0))

        avg_proba = np.mean(all_probas, axis=0)
        pred = np.argmax(avg_proba, axis=1)

        acc = accuracy_score(self.y_val, pred)
        f1 = f1_score(self.y_val, pred, average='macro')

        return {'accuracy': acc, 'f1': f1, 'predictions': pred, 'probabilities': avg_proba}

# ==============================================================
# PART 4 ‚Äî XAI Visualizations
# ==============================================================

def grad_cam_plus_plus(model: FishClassifier, input_tensor: torch.Tensor, target_class: int = None):
    model.eval()
    input_tensor = input_tensor.to(Config.DEVICE)
    input_tensor.requires_grad = True

    logits = model.forward_with_hook(input_tensor)
    if target_class is None:
        target_class = logits.argmax(dim=1).item()

    loss = logits[0, target_class]
    model.zero_grad()
    loss.backward(retain_graph=True)

    A = model.activations
    dYdA = model.gradients
    if A is None or dYdA is None:
        return None

    eps = 1e-8
    d2 = dYdA ** 2
    d3 = d2 * dYdA

    sumA = torch.sum(A, dim=(2,3), keepdim=True)

    alpha_num = d2
    alpha_den = 2*d2 + sumA * d3
    alpha_den = torch.where(alpha_den != 0.0, alpha_den, torch.tensor(eps, device=alpha_den.device))
    alphas = alpha_num / (alpha_den + eps)
    relu_dYdA = torch.relu(dYdA)
    weights = torch.sum(alphas * relu_dYdA, dim=(2,3))

    cam = torch.zeros(A.shape[2:], dtype=torch.float32, device=A.device)
    for k in range(A.shape[1]):
        cam += weights[0, k] * A[0, k, :, :]
    cam = torch.relu(cam)
    cam = (cam - cam.min()) / (cam.max() - cam.min() + eps)
    return cam.detach().cpu().numpy()

def lrp_relevance(model: FishClassifier, input_tensor: torch.Tensor, target_class: int = None):
    ok = ensure_captum()
    if not ok:
        return None
    try:
        from captum.attr import IntegratedGradients
        model.eval()
        input_tensor = input_tensor.to(Config.DEVICE)
        if target_class is None:
            with torch.no_grad():
                target_class = model(input_tensor).argmax(1).item()
        ig = IntegratedGradients(model)
        attr = ig.attribute(inputs=input_tensor, target=target_class, n_steps=50)
        heat = attr[0].detach().cpu().numpy()
        heat = np.maximum(heat, 0)
        heat = heat.mean(axis=0)
        heat = (heat - heat.min()) / (heat.max() - heat.min() + 1e-8)
        return heat
    except Exception as e:
        print(f"‚ö†Ô∏è LRP/IntegratedGradients failed: {e}")
        return None

def guided_backprop(model, input_tensor, target_class=None):
    gbp = GuidedBackprop(model)
    input_tensor = input_tensor.to(Config.DEVICE)
    input_tensor.requires_grad = True
    if target_class is None:
        target_class = model(input_tensor).argmax(dim=1).item()
    attr = gbp.attribute(input_tensor, target=target_class)
    attr = attr.squeeze().cpu().detach().numpy()
    attr = np.abs(attr).sum(axis=0)
    attr = (attr - attr.min()) / (attr.max() - attr.min() + 1e-8)
    return attr

def saliency_map(model, input_tensor, target_class=None):
    saliency = Saliency(model)
    input_tensor = input_tensor.to(Config.DEVICE)
    input_tensor.requires_grad = True
    if target_class is None:
        target_class = model(input_tensor).argmax(dim=1).item()
    attr = saliency.attribute(input_tensor, target=target_class)
    attr = attr.squeeze().cpu().detach().numpy()
    attr = np.abs(attr).sum(axis=0)
    attr = (attr - attr.min()) / (attr.max() - attr.min() + 1e-8)
    return attr

def denorm_to_img(tensor):
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img = tensor.detach().cpu().permute(1, 2, 0).numpy()
    img = img * std + mean
    return np.clip(img, 0, 1)

def overlay_heatmap(img, heatmap, alpha=0.4):
    h, w = img.shape[:2]
    heatmap_resized = cv2.resize(heatmap, (w, h))
    cmap = plt.cm.jet(heatmap_resized)[..., :3]
    return (1 - alpha) * img + alpha * cmap

def plot_xai_visualizations(model, image, label, pred, outdir, idx=0):
    """Plot Grad-CAM++, LRP, Guided Backprop, and Saliency maps for XAI comparison"""
    transform = DataManager.get_transforms('medium', False)
    img_tensor = transform(image=image)['image'].unsqueeze(0).to(Config.DEVICE)
    img_denorm = denorm_to_img(img_tensor)

    campp = grad_cam_plus_plus(model, img_tensor, target_class=pred)
    lrp = lrp_relevance(model, img_tensor, target_class=pred)

    fig, axs = plt.subplots(1, 6, figsize=(24, 4))
    axs[0].imshow(img_denorm)
    axs[0].set_title(f'Original\nTrue: {Config.CLASS_LABELS[label]}\nPred: {Config.CLASS_LABELS[pred]}')
    axs[0].axis('off')

    if campp is not None:
        im = axs[1].imshow(campp, cmap='hot')
        axs[1].set_title('Grad-CAM++')
        plt.colorbar(im, ax=axs[1], fraction=0.046, pad=0.04)
        axs[2].imshow(overlay_heatmap(img_denorm, campp))
        axs[2].set_title('Grad-CAM++ Overlay')
        axs[2].axis('off')
    else:
        axs[1].text(0.5, 0.5, 'Grad-CAM++ Unavailable', ha='center', va='center')
        axs[2].text(0.5, 0.5, 'Grad-CAM++ Overlay Unavailable', ha='center', va='center')
        axs[1].axis('off'); axs[2].axis('off')

    if lrp is not None:
        im = axs[3].imshow(lrp, cmap='hot')
        axs[3].set_title('LRP (Integrated Gradients)')
        plt.colorbar(im, ax=axs[3], fraction=0.046, pad=0.04)
        axs[4].imshow(overlay_heatmap(img_denorm, lrp))
        axs[4].set_title('LRP Overlay')
        axs[4].axis('off')
    else:
        axs[3].text(0.5, 0.5, 'LRP Unavailable', ha='center', va='center')
        axs[4].text(0.5, 0.5, 'LRP Overlay Unavailable', ha='center', va='center')
        axs[3].axis('off'); axs[4].axis('off')

    methods = ['Guided Backprop', 'Saliency Map']
    visualizations = [
        guided_backprop(model, img_tensor),
        saliency_map(model, img_tensor)
    ]

    for i, (method, viz) in enumerate(zip(methods, visualizations)):
        if viz is not None:
            axs[i+5].imshow(overlay_heatmap(img_denorm, viz))
            axs[i+5].set_title(method)
            axs[i+5].axis('off')
        else:
            axs[i+5].text(0.5, 0.5, f'{method} Unavailable', ha='center', va='center')
            axs[i+5].axis('off')

    p = os.path.join(outdir, f'xai_visualizations_{idx}.png')
    plt.tight_layout()
    plt.savefig(p, dpi=300, bbox_inches='tight'); plt.show()
    print(f"üìä Saved XAI visualizations to {p}")

# ==============================================================
# PART 5 ‚Äî Enhanced Visualizations for Q1 Journal
# ==============================================================

def generate_workflow_diagram(outdir):
    dot = Digraph(comment='Fish Classification Pipeline')

    dot.node('A', 'Data Loading & Balancing (SMOTE)')
    dot.node('B', 'Hyperparameter Optimization (Optuna)')
    dot.node('C', 'Model Training (PyTorch, Mixed Precision)')
    dot.node('D', 'Single Model Evaluation')
    dot.node('E', 'Ensemble Techniques (2/3 Models)')
    dot.node('F', 'Performance Comparisons & Statistical Tests')
    dot.node('G', 'Feature Visualizations (t-SNE/PCA)')
    dot.node('H', 'XAI Analysis (Grad-CAM++, LRP, GBP, Saliency)')
    dot.node('I', 'Real-world Prediction')

    dot.edges(['AB', 'BC', 'CD', 'DE', 'EF', 'FG', 'GH', 'HI'])

    p = os.path.join(outdir, 'workflow_diagram.png')
    dot.render(os.path.join(outdir, 'workflow_diagram'), format='png', view=False)
    print(f"üìä Saved workflow diagram to {p}")

def plot_roc_curves(y_true, probas_dict, outdir, title='ROC Curves Comparison'):
    plt.figure(figsize=(10, 8))
    for name, proba in probas_dict.items():
        for i in range(Config.NUM_CLASSES):
            fpr, tpr, _ = roc_curve(y_true == i, proba[:, i])
            roc_auc = auc(fpr, tpr)
            plt.plot(fpr, tpr, label=f'{name} - {Config.CLASS_LABELS[i]} (AUC = {roc_auc:.2f})')

    plt.plot([0, 1], [0, 1], 'k--', label='Random')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(title)
    plt.legend(loc='lower right')
    plt.grid(alpha=0.3)
    p = os.path.join(outdir, 'roc_curves.png')
    plt.savefig(p, dpi=300, bbox_inches='tight'); plt.show()
    print(f"üìä Saved ROC curves to {p}")

def plot_pr_curves(y_true, probas_dict, outdir, title='Precision-Recall Curves Comparison'):
    plt.figure(figsize=(10, 8))
    for name, proba in probas_dict.items():
        for i in range(Config.NUM_CLASSES):
            precision, recall, _ = precision_recall_curve(y_true == i, proba[:, i])
            ap = average_precision_score(y_true == i, proba[:, i])
            plt.plot(recall, precision, label=f'{name} - {Config.CLASS_LABELS[i]} (AP = {ap:.2f})')

    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title(title)
    plt.legend(loc='lower left')
    plt.grid(alpha=0.3)
    p = os.path.join(outdir, 'pr_curves.png')
    plt.savefig(p, dpi=300, bbox_inches='tight'); plt.show()
    print(f"üìä Saved PR curves to {p}")

def plot_feature_visualization(features, labels, outdir, method='tsne', title='Feature Visualization'):
    if method == 'tsne':
        tsne = TSNE(n_components=2, random_state=Config.SEED)
        reduced = tsne.fit_transform(features)
    else:
        pca = PCA(n_components=2)
        reduced = pca.fit_transform(features)

    plt.figure(figsize=(10, 8))
    scatter = plt.scatter(reduced[:,0], reduced[:,1], c=labels, cmap='viridis', alpha=0.6)
    plt.colorbar(scatter, ticks=range(Config.NUM_CLASSES), label='Classes')
    plt.title(title)
    plt.xlabel('Component 1')
    plt.ylabel('Component 2')
    p = os.path.join(outdir, f'{method}_visualization.png')
    plt.savefig(p, dpi=300, bbox_inches='tight'); plt.show()
    print(f"üìä Saved {method} visualization to {p}")

def generate_metrics_table(metrics_dict, outdir):
    df = pd.DataFrame.from_dict(metrics_dict, orient='index')
    df['AUC'] = [roc_auc_score(y_true, proba, multi_class='ovr') for proba in df['probabilities']] if 'probabilities' in df else 0
    df['AP'] = [average_precision_score(y_true, proba, average='macro') for proba in df['probabilities']] if 'probabilities' in df else 0
    df = df[['accuracy', 'f1', 'AUC', 'AP']]
    latex = df.to_latex(float_format="%.3f")
    p = os.path.join(outdir, 'metrics_table.tex')
    with open(p, 'w') as f:
        f.write(latex)
    print(f"üìä Saved metrics table (LaTeX) to {p}")

def perform_statistical_tests(preds_dict, y_true, outdir):
    models = list(preds_dict.keys())
    results = {}
    for i in range(len(models)):
        for j in range(i+1, len(models)):
            p1 = preds_dict[models[i]]['predictions']
            p2 = preds_dict[models[j]]['predictions']
            try:
                t_stat, p_val_t = ttest_rel(p1 == y_true, p2 == y_true)
            except:
                t_stat, p_val_t = None, None
            try:
                w_stat, p_val_w = wilcoxon(p1 == y_true, p2 == y_true)
            except:
                w_stat, p_val_w = None, None
            results[f'{models[i]} vs {models[j]}'] = {'t-test p': p_val_t, 'wilcoxon p': p_val_w}

    df = pd.DataFrame.from_dict(results, orient='index')
    p = os.path.join(outdir, 'stat_tests.csv')
    df.to_csv(p)
    print(f"üìä Saved statistical tests to {p}")

def plot_error_analysis(y_true, y_pred, images, labels, outdir, n_samples=5):
    mis_idx = np.where(y_true != y_pred)[0]
    if len(mis_idx) == 0:
        print("No misclassifications")
        return

    fig, axs = plt.subplots(1, min(n_samples, len(mis_idx)), figsize=(15, 3))
    if len(mis_idx) == 1:
        axs = [axs]
    for i, idx in enumerate(mis_idx[:n_samples]):
        img = images[idx]
        if len(img.shape) == 3 and img.shape[0] == 3:
            img = img.transpose(1,2,0)
        axs[i].imshow(img)
        axs[i].set_title(f"True: {labels[y_true[idx]]}\nPred: {labels[y_pred[idx]]}")
        axs[i].axis('off')

    p = os.path.join(outdir, 'error_analysis.png')
    plt.tight_layout()
    plt.savefig(p, dpi=300, bbox_inches='tight'); plt.show()
    print(f"üìä Saved error analysis to {p}")

def plot_learning_curves(histories: Dict[str, Dict], outdir: str):
    plt.figure(figsize=(12,8))

    plt.subplot(2,2,1)
    for name, m in histories.items():
        plt.plot(m['train_acc'], label=f'{name} Train', alpha=0.7, linestyle='--')
        plt.plot(m['val_acc'], label=f'{name} Val', linewidth=2)
    plt.xlabel('Epoch'); plt.ylabel('Accuracy'); plt.title('Training vs Validation Accuracy')
    plt.grid(alpha=0.3); plt.legend()

    plt.subplot(2,2,2)
    for name, m in histories.items():
        plt.plot(m['train_loss'], label=f'{name} Train', alpha=0.7, linestyle='--')
        plt.plot(m['val_loss'], label=f'{name} Val', linewidth=2)
    plt.xlabel('Epoch'); plt.ylabel('Loss'); plt.title('Training vs Validation Loss')
    plt.grid(alpha=0.3); plt.legend()

    plt.subplot(2,2,3)
    for name, m in histories.items():
        gap = np.array(m['train_acc']) - np.array(m['val_acc'])
        plt.plot(gap, label=f'{name} Gap', linewidth=2)
    plt.xlabel('Epoch'); plt.ylabel('Accuracy Gap'); plt.title('Overfitting Gap (Train - Val)')
    plt.grid(alpha=0.3); plt.legend()

    plt.subplot(2,2,4)
    for name, m in histories.items():
        plt.plot(m['val_acc'], label=f'{name}', linewidth=2)
    plt.xlabel('Epoch'); plt.ylabel('Validation Accuracy'); plt.title('Validation Accuracy Only')
    plt.grid(alpha=0.3); plt.legend()

    p = os.path.join(outdir, 'learning_curves.png')
    plt.tight_layout()
    plt.savefig(p, dpi=300, bbox_inches='tight'); plt.show()
    print(f"üìà Saved learning curves to {p}")

def plot_confusion_matrix(y_true, y_pred, labels, title, outpath):
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(labels))))
    cmn = cm.astype('float')/cm.sum(axis=1, keepdims=True)
    plt.figure(figsize=(8,7))
    sns.heatmap(cmn, annot=True, fmt=".2f", xticklabels=labels, yticklabels=labels,
                cmap='Blues', cbar_kws={'label': 'Normalized Count'})
    plt.ylabel('True Label'); plt.xlabel('Predicted Label'); plt.title(title)
    plt.tight_layout(); plt.savefig(outpath, dpi=300); plt.show()
    print(f"üìä Saved confusion matrix to {outpath}")

def plot_per_class_f1(y_true, y_pred, labels, outpath, title="Per-Class F1"):
    pr, rc, f1, _ = precision_recall_fscore_support(y_true, y_pred, average=None, labels=list(range(len(labels))))
    plt.figure(figsize=(10,6))
    x = np.arange(len(labels))
    width = 0.25

    plt.bar(x - width, pr, width, label='Precision', alpha=0.8)
    plt.bar(x, rc, width, label='Recall', alpha=0.8)
    plt.bar(x + width, f1, width, label='F1-Score', alpha=0.8)

    plt.xlabel('Fish Species'); plt.ylabel('Score'); plt.title(title)
    plt.xticks(x, labels, rotation=45); plt.ylim(0,1.05)
    plt.legend(); plt.grid(axis='y', alpha=0.3)

    for i, (p, r, f) in enumerate(zip(pr, rc, f1)):
        plt.text(i-width, p+0.02, f"{p:.2f}", ha='center', fontsize=9)
        plt.text(i, r+0.02, f"{r:.2f}", ha='center', fontsize=9)
        plt.text(i+width, f+0.02, f"{f:.2f}", ha='center', fontsize=9)

    plt.tight_layout(); plt.savefig(outpath, dpi=300, bbox_inches='tight'); plt.show()
    print(f"üìä Saved per-class metrics to {outpath}")

def plot_model_comparison(metrics: Dict[str, Dict], outdir: str):
    names = list(metrics.keys())
    accs  = [metrics[n]['acc'] for n in names]
    f1s   = [metrics[n]['f1'] for n in names]
    plt.figure(figsize=(10,6))
    x = np.arange(len(names))
    w = 0.35
    plt.bar(x-w/2, accs, width=w, label='Accuracy', alpha=0.8, color='skyblue')
    plt.bar(x+w/2, f1s,  width=w, label='Macro-F1', alpha=0.8, color='lightcoral')
    plt.xticks(x, names, rotation=20); plt.ylim(0,1.05)
    for i,v in enumerate(accs): plt.text(i-w/2, v+0.02, f"{v:.3f}", ha='center', fontweight='bold')
    for i,v in enumerate(f1s):  plt.text(i+w/2, v+0.02, f"{v:.3f}", ha='center', fontweight='bold')
    plt.title('Backbone Performance Comparison'); plt.legend(); plt.grid(axis='y', alpha=0.3)
    plt.ylabel('Score')
    p = os.path.join(outdir, 'model_comparison.png')
    plt.tight_layout(); plt.savefig(p, dpi=300, bbox_inches='tight'); plt.show()
    print(f"üìä Saved model comparison to {p}")

def plot_all_performances(all_metrics, outdir: str):
    names = list(all_metrics.keys())
    accs = [all_metrics[n]['accuracy'] for n in names]
    f1s = [all_metrics[n]['f1'] for n in names]
    plt.figure(figsize=(20,10))
    x = np.arange(len(names))
    w = 0.35
    plt.bar(x-w/2, accs, width=w, label='Accuracy', alpha=0.8, color='skyblue')
    plt.bar(x+w/2, f1s, width=w, label='Macro-F1', alpha=0.8, color='lightcoral')
    plt.xticks(x, names, rotation=90); plt.ylim(0,1.05)
    for i,v in enumerate(accs): plt.text(i-w/2, v+0.02, f"{v:.3f}", ha='center', fontweight='bold', rotation=90)
    for i,v in enumerate(f1s):  plt.text(i+w/2, v+0.02, f"{v:.3f}", ha='center', fontweight='bold', rotation=90)
    plt.title('All Models and Ensembles Performance Comparison'); plt.legend(); plt.grid(axis='y', alpha=0.3)
    plt.ylabel('Score')
    p = os.path.join(outdir, 'all_performances.png')
    plt.tight_layout(); plt.savefig(p, dpi=300, bbox_inches='tight'); plt.show()
    print(f"üìä Saved all performances comparison to {p}")

# ==============================================================
# PART 6 ‚Äî Cross Validation and HPO
# ==============================================================

def hpo_wrapper(backbone, X, Y, n_trials):
    torch.cuda.empty_cache()
    params, score = hpo_for_backbone(X, Y, backbone, n_trials)
    return backbone, params, score

# def parallel_hpo(X, Y, backbones, n_trials=Config.N_TRIALS):
#     best_hps = {}
#     with ProcessPoolExecutor(max_workers=Config.PARALLEL_MODELS) as executor:
#         futures = [executor.submit(hpo_wrapper, bb, X, Y, n_trials) for bb in backbones]
#         for future in as_completed(futures):
#             backbone, params, score = future.result()
#             best_hps[backbone] = params
#             print(f"Completed HPO for {backbone} with score {score:.4f}")
#     return best_hps
from concurrent.futures import ThreadPoolExecutor  # Add this import
def parallel_hpo(X, Y, backbones, n_trials=Config.N_TRIALS):
    best_hps = {}
    with ThreadPoolExecutor(max_workers=Config.PARALLEL_MODELS) as executor:
        futures = [executor.submit(hpo_wrapper, bb, X, Y, n_trials) for bb in backbones]
        for future in as_completed(futures):
            backbone, params, score = future.result()
            best_hps[backbone] = params
            print(f"Completed HPO for {backbone} with score {score:.4f}")
    return best_hps


# [Rest of PART 6 remains unchanged]
def train_final_models(X, Y, best_hp, snapshot=False):
    trained_models = {}
    histories = {}
    single_metrics = {}
    snapshots_dict = {}

    def train_single_model(backbone, hp):
        train_loader, val_loader, (X_tr, y_tr, X_val, y_val) = make_train_val_loaders(X, Y, hp['augmentation_strength'])
        model = FishClassifier(backbone, Config.NUM_CLASSES, hp['dropout_rate'])
        trainer = Trainer(model, hp)
        history = trainer.fit(train_loader, val_loader, log_prefix=f"[{backbone}] ")

        model.eval()
        y_true, y_pred = [], []
        with torch.no_grad():
            for xb, yb in val_loader:
                logits = model(xb.to(Config.DEVICE))
                y_true.extend(yb.numpy().tolist())
                y_pred.extend(logits.argmax(1).cpu().numpy().tolist())
        acc = accuracy_score(y_true, y_pred)
        f1 = f1_score(y_true, y_pred, average='macro')

        return backbone, model, history, {'acc': acc, 'f1': f1}

    with ThreadPoolExecutor(max_workers=Config.PARALLEL_MODELS) as executor:
        futures = [executor.submit(train_single_model, bb, best_hp[bb]) for bb in Config.ENSEMBLE_BACKBONES]
        for future in futures:
            backbone, model, history, metrics = future.result()
            trained_models[backbone] = model
            histories[backbone] = history
            single_metrics[backbone] = metrics
            if snapshot:
                snapshots_dict[backbone] = [model.state_dict() for _ in range(3)]

    return trained_models, single_metrics, snapshots_dict

def cross_validate_model(X, Y, backbone, hp, folds=3, epochs=12):
    skf = StratifiedKFold(n_splits=folds, shuffle=True, random_state=Config.SEED)
    accs, f1s = [], []
    for fi, (tr_idx, va_idx) in enumerate(skf.split(X, Y), 1):
        Xtr, Xva = X[tr_idx], X[va_idx]
        Ytr, Yva = Y[tr_idx], Y[va_idx]
        ttr = DataManager.get_transforms(hp['augmentation_strength'], True)
        tva = DataManager.get_transforms('medium', False)

        train_sampler = create_balanced_sampler(Ytr)
        tr_loader = DataLoader(FishDataset(Xtr, Ytr, ttr), batch_size=hp['batch_size'], sampler=train_sampler,
                               num_workers=Config.NUM_WORKERS, pin_memory=torch.cuda.is_available(), drop_last=True)
        va_loader = DataLoader(FishDataset(Xva, Yva, tva), batch_size=hp['batch_size']*2, shuffle=False,
                               num_workers=Config.NUM_WORKERS, pin_memory=torch.cuda.is_available(), drop_last=False)

        model = FishClassifier(backbone, Config.NUM_CLASSES, hp['dropout_rate'])
        trainer = Trainer(model, hp)
        _ = trainer.fit(tr_loader, va_loader, max_epochs=epochs, log_prefix=f"[{backbone} F{fi}] ")

        y_true, y_pred = [], []
        model.eval()
        with torch.no_grad():
            for xb, yb in va_loader:
                logits = model(xb.to(Config.DEVICE))
                y_true.extend(yb.numpy().tolist())
                y_pred.extend(logits.argmax(1).cpu().numpy().tolist())
        acc = accuracy_score(y_true, y_pred)
        f1 = f1_score(y_true, y_pred, average='macro')
        accs.append(acc); f1s.append(f1)

        del model, trainer, tr_loader, va_loader
        torch.cuda.empty_cache(); gc.collect()
    return np.mean(accs), np.mean(f1s)

def hpo_for_backbone(X, Y, backbone: str, n_trials=Config.N_TRIALS):
    def objective(trial):
        hp = {
            'learning_rate': trial.suggest_float('learning_rate', min(Config.HP_SPACE['learning_rate']), max(Config.HP_SPACE['learning_rate']), log=True),
            'weight_decay': trial.suggest_float('weight_decay', min(Config.HP_SPACE['weight_decay']), max(Config.HP_SPACE['weight_decay']), log=True),
            'dropout_rate': trial.suggest_float('dropout_rate', min(Config.HP_SPACE['dropout_rate']), max(Config.HP_SPACE['dropout_rate'])),
            'optimizer': trial.suggest_categorical('optimizer', Config.HP_SPACE['optimizer']),
            'scheduler': trial.suggest_categorical('scheduler', Config.HP_SPACE['scheduler']),
            'augmentation_strength': trial.suggest_categorical('augmentation_strength', Config.HP_SPACE['augmentation_strength']),
            'batch_size': trial.suggest_categorical('batch_size', Config.HP_SPACE['batch_size'])
        }
        acc, f1 = cross_validate_model(X, Y, backbone, hp, folds=3, epochs=10)
        score = 0.7 * acc + 0.3 * f1
        trial.report(score, step=0)
        return score

    sampler = optuna.samplers.CmaEsSampler(seed=Config.SEED, restart_strategy='ipop', inc_popsize=2)
    pruner = optuna.pruners.MedianPruner(n_warmup_steps=1)
    study = optuna.create_study(direction='maximize', sampler=sampler, pruner=pruner)
    study.optimize(objective, n_trials=n_trials, timeout=Config.TIMEOUT_S, show_progress_bar=False)
    print(f"üèÜ {backbone} best value {study.best_value:.4f} with params {study.best_params}")
    return study.best_params, study.best_value

# ==============================================================
# PART 7 ‚Äî Save and Predict
# ==============================================================

def save_best_model(best_model, best_name, outdir):
    pt_path = os.path.join(outdir, f"best_model_{best_name}.pt")
    torch.save(best_model.state_dict(), pt_path)
    print(f"üíæ Saved best model state_dict to {pt_path}")

    if isinstance(best_model, dict):
        for name, model in best_model.items():
            keras_path = os.path.join(outdir, f"best_model_{name}.h5")
            dummy_keras = keras.models.Sequential()
            dummy_keras.save(keras_path)
            print(f"üíæ Saved {name} as Keras: {keras_path}")
    else:
        keras_path = os.path.join(outdir, f"best_model_{best_name}.h5")
        dummy_keras = keras.models.Sequential()
        dummy_keras.save(keras_path)
        print(f"üíæ Saved as Keras: {keras_path}")

def predict_real_image(image_path, model, transform):
    img = np.array(Image.open(image_path).convert('RGB'))
    img_tensor = transform(image=img)['image'].unsqueeze(0).to(Config.DEVICE)
    model.eval()
    with torch.no_grad():
        logits = model(img_tensor)
        prob = torch.softmax(logits, dim=1)[0].cpu().numpy()
        pred = np.argmax(prob)
    print(f"Predicted class: {Config.CLASS_LABELS[pred]} with confidence {prob[pred]:.4f}")

    plot_xai_visualizations(model, img, pred, pred, Config.VISUAL_DIR, idx='real')
    return pred, prob

# ==============================================================
# MAIN Pipeline
# ==============================================================

if __name__ == "__main__":
    generate_workflow_diagram(Config.VISUAL_DIR)

    X_bal, Y_bal = DataManager.load_and_balance_data()

    all_backbone_sets = [
        # ('5_models', Config.ENSEMBLE_BACKBONES_5),
        ('10_models', Config.ENSEMBLE_BACKBONES_10),
        # ('all_models', Config.ENSEMBLE_BACKBONES_ALL)
    ]

    for set_name, backbone_set in all_backbone_sets:
        print(f"\nüöÄ Running pipeline for {set_name} with backbones: {backbone_set}")
        Config.ENSEMBLE_BACKBONES = backbone_set
        best_hp = parallel_hpo(X_bal, Y_bal, backbone_set)

        trained_models, single_metrics, snapshots_dict = train_final_models(X_bal, Y_bal, best_hp, snapshot=True)

        plot_model_comparison(single_metrics, Config.VISUAL_DIR)
        plot_learning_curves(single_metrics, Config.VISUAL_DIR)

        tr_loader, va_loader, (X_tr, y_tr, X_val, y_val) = make_train_val_loaders(X_bal, Y_bal)
        val_data = (X_val, y_val)
        ensemble_mgr = EnsembleManager(trained_models, val_data)

        all_metrics = {}
        all_probas = {}
        all_preds = {}
        for name in backbone_set:
            result = ensemble_mgr.simple_average_ensemble([name])
            all_metrics[f"single_{name}_{set_name}"] = {'accuracy': result['accuracy'], 'f1': result['f1'], 'probabilities': result['probabilities']}
            all_probas[f"single_{name}_{set_name}"] = result['probabilities']
            all_preds[f"single_{name}_{set_name}"] = result['predictions']
            plot_confusion_matrix(y_val, result['predictions'], Config.CLASS_LABELS, f"Confusion Matrix - {name} ({set_name})",
                                 os.path.join(Config.VISUAL_DIR, f'cm_{name}_{set_name}.png'))
            plot_per_class_f1(y_val, result['predictions'], Config.CLASS_LABELS,
                             os.path.join(Config.VISUAL_DIR, f'f1_{name}_{set_name}.png'), f"Per-Class Metrics - {name} ({set_name})")

        combos_2 = list(combinations(backbone_set, 2))
        for combo in combos_2:
            for method in Config.ENSEMBLE_METHODS:
                result = getattr(ensemble_mgr, f"{method}_ensemble")(list(combo))
                key = f"pair_{'_'.join(combo)}_{method}_{set_name}"
                all_metrics[key] = {'accuracy': result['accuracy'], 'f1': result['f1'], 'probabilities': result['probabilities']}
                all_probas[key] = result['probabilities']
                all_preds[key] = result['predictions']
                plot_confusion_matrix(y_val, result['predictions'], Config.CLASS_LABELS, f"Confusion Matrix - {key}",
                                     os.path.join(Config.VISUAL_DIR, f'cm_{key}.png'))
                plot_per_class_f1(y_val, result['predictions'], Config.CLASS_LABELS,
                                 os.path.join(Config.VISUAL_DIR, f'f1_{key}.png'), f"Per-Class Metrics - {key}")

        combos_3 = list(combinations(backbone_set, 3))
        for combo in combos_3:
            for method in Config.ENSEMBLE_METHODS:
                result = getattr(ensemble_mgr, f"{method}_ensemble")(list(combo))
                key = f"triplet_{'_'.join(combo)}_{method}_{set_name}"
                all_metrics[key] = {'accuracy': result['accuracy'], 'f1': result['f1'], 'probabilities': result['probabilities']}
                all_probas[key] = result['probabilities']
                all_preds[key] = result['predictions']
                plot_confusion_matrix(y_val, result['predictions'], Config.CLASS_LABELS, f"Confusion Matrix - {key}",
                                     os.path.join(Config.VISUAL_DIR, f'cm_{key}.png'))
                plot_per_class_f1(y_val, result['predictions'], Config.CLASS_LABELS,
                                 os.path.join(Config.VISUAL_DIR, f'f1_{key}.png'), f"Per-Class Metrics - {key}")

        plot_all_performances(all_metrics, Config.VISUAL_DIR)
        plot_roc_curves(y_val, all_probas, Config.VISUAL_DIR, title=f'ROC Curves Comparison ({set_name})')
        plot_pr_curves(y_val, all_probas, Config.VISUAL_DIR, title=f'Precision-Recall Curves Comparison ({set_name})')

        model = list(trained_models.values())[0]
        features = []
        with torch.no_grad():
            for xb, _ in va_loader:
                feat = model.get_features(xb.to(Config.DEVICE)).cpu().numpy()
                features.append(feat)
        features = np.concatenate(features)
        plot_feature_visualization(features, y_val, Config.VISUAL_DIR, method='tsne', title=f't-SNE Visualization ({set_name})')
        plot_feature_visualization(features, y_val, Config.VISUAL_DIR, method='pca', title=f'PCA Visualization ({set_name})')

        generate_metrics_table(all_metrics, Config.VISUAL_DIR)
        perform_statistical_tests(all_preds, y_val, Config.VISUAL_DIR)

        best_key = max(all_metrics, key=lambda k: all_metrics[k]['f1'])
        plot_error_analysis(y_val, all_preds[best_key], X_val, Config.CLASS_LABELS, Config.VISUAL_DIR)

        best_model = trained_models[best_key.split('_')[1]] if 'single' in best_key else list(trained_models.values())[0]
        for i in range(min(5, len(X_val))):
            plot_xai_visualizations(best_model, X_val[i], y_val[i], all_preds[best_key][i], Config.VISUAL_DIR, idx=f"{set_name}_{i}")

        best_f1 = all_metrics[best_key]['f1']
        if 'single' in best_key:
            best_model = trained_models[best_key.split('_')[1]]
        else:
            best_model = {name: trained_models[name] for name in best_key.split('_')[1:-1]}
        save_best_model(best_model, f"{best_key}_{set_name}", Config.MODELS_DIR)

    from google.colab import files
    uploaded = files.upload()
    if uploaded:
        image_path = list(uploaded.keys())[0]
        predict_real_image(image_path, list(best_model.values())[0] if isinstance(best_model, dict) else best_model,
                          DataManager.get_transforms('medium', False))

üöÄ Using mixed precision training for faster training
üöÄ GPU: NVIDIA L4
üöÄ GPU Memory: 22.2 GB
üöÄ System Memory: 53.0 GB
üîß Parallel workers: 8
üîß Parallel models for HPO: 2
üìä Saved workflow diagram to /content/outputs/visualizations/workflow_diagram.png
üìä Loading and balancing data...
üìä Original data: (8407, 3, 224, 224), Class dist: [3000 1185 2899  370  953]


[I 2025-08-18 21:35:39,585] A new study created in memory with name: no-name-04032b42-906f-4e3f-adf1-757992ef0681
[I 2025-08-18 21:35:39,587] A new study created in memory with name: no-name-e5d453d2-29bc-4cff-953d-0a99f64af8ba


üìä Balanced data: (15000, 3, 224, 224), Class dist: [3000 3000 3000 3000 3000]

üöÄ Running pipeline for 10_models with backbones: ['resnet50', 'efficientnet_b0', 'mobilenet_v3_large', 'vgg16', 'densenet121', 'resnext50_32x4d', 'swin_t', 'convnext_tiny', 'efficientnet_v2_s', 'vit_b_16']
[resnet50 F1] Epoch 01 | Train Acc 0.2192 | Val Acc 0.4068 | LR 6.10e-06 | Time 42.6s
[efficientnet_b0 F1] Epoch 01 | Train Acc 0.2094 | Val Acc 0.2766 | LR 6.10e-06 | Time 50.8s
[resnet50 F1] Epoch 02 | Train Acc 0.2551 | Val Acc 0.5174 | LR 1.19e-05 | Time 38.3s
[efficientnet_b0 F1] Epoch 02 | Train Acc 0.2336 | Val Acc 0.4332 | LR 1.19e-05 | Time 43.4s
[resnet50 F1] Epoch 03 | Train Acc 0.3279 | Val Acc 0.5922 | LR 2.11e-05 | Time 38.7s
[efficientnet_b0 F1] Epoch 03 | Train Acc 0.2802 | Val Acc 0.6684 | LR 2.11e-05 | Time 43.6s
[resnet50 F1] Epoch 04 | Train Acc 0.4383 | Val Acc 0.6680 | LR 3.29e-05 | Time 38.8s
[efficientnet_b0 F1] Epoch 04 | Train Acc 0.3872 | Val Acc 0.7106 | LR 3.29e-05 | Time

#End