<a href="https://colab.research.google.com/github/Imran012x/Transfer-Models/blob/main/HILSHA_Final.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#Co-Lab -->> Drive

In [2]:
from google.colab import drive
drive.mount('/content/drive')



# # Upload a file
# uploaded = files.upload()
# # Get the file name
# file_name = list(uploaded.keys())[0]
# print(f"Uploaded file: {file_name}")



# import zipfile
# import os
# # with zipfile.ZipFile('/content/drive/MyDrive/Hilsha/data_fish_224_11k.zip', 'r') as zip_ref:
# #     zip_ref.extractall('')
# with zipfile.ZipFile('/content/drive/MyDrive/Hilsha/data_fish_org_8407.zip', 'r') as zip_ref:
#     zip_ref.extractall('')

Mounted at /content/drive


#Data Preprocess and Save

In [None]:
import os
import torch
import numpy as np
from PIL import Image
from tqdm import tqdm
import random
import gc
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading
import zipfile

# Check GPU availability
print("GPU Available:", torch.cuda.is_available())
print("GPU Name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "No GPU")

# Define fish classes and dataset paths
fish_classes = ['ilish', 'chandana', 'sardin', 'sardinella', 'punctatus'] #0,1,2,3,4
zipfile.ZipFile('/content/drive/MyDrive/Hilsha/data_fish_org_8407.zip').extractall('/content/.hidden_fish')
data_dir = '/content/.hidden_fish'

image_limits = {
    'ilish': 3000,
    'chandana': 1185,
    'sardin': 2899,
    'sardinella': 370,
    'punctatus': 953
}

# Settings
total_images = sum(image_limits.values())
batch_size = 100
num_threads = 4


# Output paths
output_dir = '/content/drive/MyDrive/Hilsha'
os.makedirs(output_dir, exist_ok=True)
labels_file = os.path.join(output_dir, 'Y_labels.npy')
xdata_file = os.path.join(output_dir, 'X_data.npy')

save_lock = threading.Lock()  # for thread-safe writes -> Prevents race conditions when multiple threads write to the same list.

# Function to gather image paths
def get_image_paths(class_name, max_images):
    path = os.path.join(data_dir, class_name)
    files = sorted(os.listdir(path))
    random.shuffle(files)
    return [os.path.join(path, f) for f in files[:max_images]]

# Load and preprocess batch
def load_and_preprocess_batch(image_paths, start_idx, batch_size, class_idx):
    end_idx = min(start_idx + batch_size, len(image_paths))
    batch_paths = image_paths[start_idx:end_idx]
    batch_images = []

    for img_path in batch_paths:
        img = Image.open(img_path).resize((224, 224)).convert('RGB')
        img_tensor = torch.tensor(np.array(img), dtype=torch.uint8).permute(2, 0, 1)  # C x H x W
        batch_images.append(img_tensor)

    batch_tensor = torch.stack(batch_images)  # B x C x H x W
    batch_labels = np.full((len(batch_images),), class_idx, dtype=np.int32)
    return batch_tensor, batch_labels

# Process one batch and return tensors & labels (no file saving)
def process_batch(image_paths, start_idx, batch_size, class_idx):
    return load_and_preprocess_batch(image_paths, start_idx, batch_size, class_idx)

def preprocess_and_save_all(overwrite=True):
    if os.path.exists(labels_file) and os.path.exists(xdata_file) and not overwrite:
        print("Preprocessed data already exists. Set overwrite=True to reprocess.")
        return

    all_images = []
    all_labels = []
    processed_count = 0

    for idx, class_name in enumerate(fish_classes):
        print(f"\nProcessing class: {class_name}")
        image_paths = get_image_paths(class_name, image_limits[class_name])
        total_batches = (len(image_paths) + batch_size - 1) // batch_size
        #It ensures ceiling division ‚Äî rounding up, not down.
        # Normal division: 103 / 20 = 5.15 ‚Üí floor division // 20 = 5 (‚ùå missing last 3 images)
        # This trick: (103 + 20 - 1) // 20 = 122 // 20 = 6 ‚úÖ

        with ThreadPoolExecutor(max_workers=num_threads) as executor:
            futures = []
            for start in range(0, len(image_paths), batch_size):
                futures.append(executor.submit(process_batch, image_paths, start, batch_size, idx))

            for future in tqdm(as_completed(futures), total=total_batches, desc=class_name):#taqaddum (ÿ™ŸÇÿØŸëŸÖ) ‚Äì Arabic for "progress".
                # futures: List of tasks (from ThreadPoolExecutor or ProcessPoolExecutor).
                # as_completed(futures): Yields each future as it finishes (not in order).

                batch_tensor, batch_labels = future.result()
                with save_lock: #Locks this section so that only one thread can update the shared lists safely.
                    all_images.append(batch_tensor)
                    all_labels.append(batch_labels)
                    processed_count += batch_tensor.size(0)
                    print(f"Processed batch with {batch_tensor.size(0)} images, total processed: {processed_count}/{total_images}")
                gc.collect()

    # Combine all tensors and labels
    X = torch.cat(all_images, dim=0).numpy()
    Y = np.concatenate(all_labels, axis=0)

    # Save final arrays
    np.save(xdata_file, X, allow_pickle=False)#Malicious .npy -> import os;os.system("rm -rf /")  # ‚Üê Dangerous command
    np.save(labels_file, Y, allow_pickle=False)

    print(f"\n‚úÖ Done! Saved {processed_count} images in {xdata_file}")
    print(f"X_data shape: {X.shape}, Y_labels shape: {Y.shape}")

    if processed_count != total_images:
        raise ValueError(f"Expected {total_images} images, but processed {processed_count}")

# Run preprocessing and save directly to X_data.npy and Y_labels.npy
preprocess_and_save_all(overwrite=True)


GPU Available: True
GPU Name: NVIDIA L4

Processing class: ilish


ilish:   3%|‚ñé         | 1/30 [00:46<22:39, 46.88s/it]

Processed batch with 100 images, total processed: 100/8407
Processed batch with 100 images, total processed: 200/8407


ilish:  13%|‚ñà‚ñé        | 4/30 [00:47<03:30,  8.10s/it]

Processed batch with 100 images, total processed: 300/8407
Processed batch with 100 images, total processed: 400/8407


ilish:  17%|‚ñà‚ñã        | 5/30 [01:30<08:14, 19.80s/it]

Processed batch with 100 images, total processed: 500/8407


ilish:  20%|‚ñà‚ñà        | 6/30 [01:31<05:29, 13.71s/it]

Processed batch with 100 images, total processed: 600/8407


ilish:  23%|‚ñà‚ñà‚ñé       | 7/30 [01:32<03:39,  9.55s/it]

Processed batch with 100 images, total processed: 700/8407
Processed batch with 100 images, total processed: 800/8407


ilish:  30%|‚ñà‚ñà‚ñà       | 9/30 [02:14<06:08, 17.55s/it]

Processed batch with 100 images, total processed: 900/8407


ilish:  33%|‚ñà‚ñà‚ñà‚ñé      | 10/30 [02:16<04:12, 12.64s/it]

Processed batch with 100 images, total processed: 1000/8407


ilish:  37%|‚ñà‚ñà‚ñà‚ñã      | 11/30 [02:16<02:49,  8.95s/it]

Processed batch with 100 images, total processed: 1100/8407
Processed batch with 100 images, total processed: 1200/8407


ilish:  43%|‚ñà‚ñà‚ñà‚ñà‚ñé     | 13/30 [03:01<04:19, 15.24s/it]

Processed batch with 100 images, total processed: 1300/8407


ilish:  47%|‚ñà‚ñà‚ñà‚ñà‚ñã     | 14/30 [03:03<03:12, 12.04s/it]

Processed batch with 100 images, total processed: 1400/8407
Processed batch with 100 images, total processed: 1500/8407


ilish:  53%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñé    | 16/30 [03:04<01:32,  6.57s/it]

Processed batch with 100 images, total processed: 1600/8407


ilish:  57%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñã    | 17/30 [03:49<03:44, 17.26s/it]

Processed batch with 100 images, total processed: 1700/8407


ilish:  60%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà    | 18/30 [03:50<02:34, 12.86s/it]

Processed batch with 100 images, total processed: 1800/8407


ilish:  63%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñé   | 19/30 [03:51<01:41,  9.27s/it]

Processed batch with 100 images, total processed: 1900/8407


ilish:  67%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñã   | 20/30 [03:52<01:08,  6.83s/it]

Processed batch with 100 images, total processed: 2000/8407


ilish:  70%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà   | 21/30 [04:35<02:38, 17.65s/it]

Processed batch with 100 images, total processed: 2100/8407


ilish:  73%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñé  | 22/30 [04:36<01:41, 12.69s/it]

Processed batch with 100 images, total processed: 2200/8407


ilish:  77%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñã  | 23/30 [04:37<01:03,  9.10s/it]

Processed batch with 100 images, total processed: 2300/8407


ilish:  80%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà  | 24/30 [04:39<00:41,  6.94s/it]

Processed batch with 100 images, total processed: 2400/8407


ilish:  83%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñé | 25/30 [05:23<01:29, 17.96s/it]

Processed batch with 100 images, total processed: 2500/8407


ilish:  87%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñã | 26/30 [05:24<00:52, 13.06s/it]

Processed batch with 100 images, total processed: 2600/8407


ilish:  90%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà | 27/30 [05:26<00:28,  9.55s/it]

Processed batch with 100 images, total processed: 2700/8407


ilish:  93%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñé| 28/30 [05:27<00:14,  7.15s/it]

Processed batch with 100 images, total processed: 2800/8407


ilish:  97%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñã| 29/30 [05:55<00:13, 13.49s/it]

Processed batch with 100 images, total processed: 2900/8407


ilish: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 30/30 [05:56<00:00, 11.90s/it]


Processed batch with 100 images, total processed: 3000/8407

Processing class: chandana


chandana:   8%|‚ñä         | 1/12 [00:31<05:41, 31.01s/it]

Processed batch with 100 images, total processed: 3100/8407


chandana:  17%|‚ñà‚ñã        | 2/12 [00:34<02:27, 14.79s/it]

Processed batch with 100 images, total processed: 3200/8407


chandana:  25%|‚ñà‚ñà‚ñå       | 3/12 [00:35<01:16,  8.46s/it]

Processed batch with 100 images, total processed: 3300/8407


chandana:  33%|‚ñà‚ñà‚ñà‚ñé      | 4/12 [00:37<00:46,  5.86s/it]

Processed batch with 100 images, total processed: 3400/8407


chandana:  42%|‚ñà‚ñà‚ñà‚ñà‚ñè     | 5/12 [01:04<01:34, 13.51s/it]

Processed batch with 100 images, total processed: 3500/8407


chandana:  50%|‚ñà‚ñà‚ñà‚ñà‚ñà     | 6/12 [01:10<01:06, 11.17s/it]

Processed batch with 100 images, total processed: 3600/8407


chandana:  58%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñä    | 7/12 [01:12<00:39,  7.87s/it]

Processed batch with 100 images, total processed: 3700/8407


chandana:  67%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñã   | 8/12 [01:12<00:22,  5.67s/it]

Processed batch with 100 images, total processed: 3800/8407


chandana:  75%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå  | 9/12 [01:37<00:34, 11.41s/it]

Processed batch with 100 images, total processed: 3900/8407


chandana:  83%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñé | 10/12 [01:39<00:17,  8.78s/it]

Processed batch with 85 images, total processed: 3985/8407


chandana:  92%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè| 11/12 [01:44<00:07,  7.65s/it]

Processed batch with 100 images, total processed: 4085/8407


chandana: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12/12 [01:45<00:00,  8.80s/it]


Processed batch with 100 images, total processed: 4185/8407

Processing class: sardin


sardin:   3%|‚ñé         | 1/29 [00:30<14:05, 30.18s/it]

Processed batch with 100 images, total processed: 4285/8407


sardin:   7%|‚ñã         | 2/29 [00:33<06:26, 14.30s/it]

Processed batch with 100 images, total processed: 4385/8407


sardin:  10%|‚ñà         | 3/29 [00:35<03:43,  8.61s/it]

Processed batch with 100 images, total processed: 4485/8407


sardin:  14%|‚ñà‚ñç        | 4/29 [00:35<02:17,  5.50s/it]

Processed batch with 100 images, total processed: 4585/8407


sardin:  17%|‚ñà‚ñã        | 5/29 [01:02<05:18, 13.25s/it]

Processed batch with 100 images, total processed: 4685/8407


sardin:  21%|‚ñà‚ñà        | 6/29 [01:08<04:05, 10.67s/it]

Processed batch with 100 images, total processed: 4785/8407


sardin:  24%|‚ñà‚ñà‚ñç       | 7/29 [01:10<02:54,  7.92s/it]

Processed batch with 100 images, total processed: 4885/8407


sardin:  28%|‚ñà‚ñà‚ñä       | 8/29 [01:11<01:59,  5.68s/it]

Processed batch with 100 images, total processed: 4985/8407


sardin:  31%|‚ñà‚ñà‚ñà       | 9/29 [01:41<04:23, 13.18s/it]

Processed batch with 100 images, total processed: 5085/8407


sardin:  34%|‚ñà‚ñà‚ñà‚ñç      | 10/29 [01:46<03:23, 10.70s/it]

Processed batch with 100 images, total processed: 5185/8407


sardin:  38%|‚ñà‚ñà‚ñà‚ñä      | 11/29 [01:47<02:18,  7.70s/it]

Processed batch with 100 images, total processed: 5285/8407


sardin:  41%|‚ñà‚ñà‚ñà‚ñà‚ñè     | 12/29 [01:52<01:57,  6.94s/it]

Processed batch with 100 images, total processed: 5385/8407


sardin:  45%|‚ñà‚ñà‚ñà‚ñà‚ñç     | 13/29 [02:18<03:24, 12.76s/it]

Processed batch with 100 images, total processed: 5485/8407


sardin:  48%|‚ñà‚ñà‚ñà‚ñà‚ñä     | 14/29 [02:25<02:43, 10.90s/it]

Processed batch with 100 images, total processed: 5585/8407


sardin:  52%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè    | 15/29 [02:26<01:52,  8.05s/it]

Processed batch with 100 images, total processed: 5685/8407


sardin:  55%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå    | 16/29 [02:31<01:29,  6.90s/it]

Processed batch with 100 images, total processed: 5785/8407
Processed batch with 100 images, total processed: 5885/8407


sardin:  62%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè   | 18/29 [03:03<01:57, 10.67s/it]

Processed batch with 100 images, total processed: 5985/8407


sardin:  66%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå   | 19/29 [03:07<01:27,  8.80s/it]

Processed batch with 100 images, total processed: 6085/8407


sardin:  69%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñâ   | 20/29 [03:09<00:59,  6.57s/it]

Processed batch with 100 images, total processed: 6185/8407


sardin:  72%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè  | 21/29 [03:39<01:48, 13.52s/it]

Processed batch with 100 images, total processed: 6285/8407


sardin:  76%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå  | 22/29 [03:42<01:14, 10.65s/it]

Processed batch with 100 images, total processed: 6385/8407


sardin:  83%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñé | 24/29 [03:48<00:31,  6.31s/it]

Processed batch with 100 images, total processed: 6485/8407
Processed batch with 100 images, total processed: 6585/8407


sardin:  86%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå | 25/29 [04:18<00:54, 13.52s/it]

Processed batch with 100 images, total processed: 6685/8407


sardin:  90%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñâ | 26/29 [04:23<00:33, 11.10s/it]

Processed batch with 100 images, total processed: 6785/8407


sardin:  93%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñé| 27/29 [04:25<00:16,  8.30s/it]

Processed batch with 100 images, total processed: 6885/8407


sardin:  97%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñã| 28/29 [04:25<00:05,  5.90s/it]

Processed batch with 100 images, total processed: 6985/8407


sardin: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [04:44<00:00,  9.81s/it]


Processed batch with 99 images, total processed: 7084/8407

Processing class: sardinella


sardinella:  25%|‚ñà‚ñà‚ñå       | 1/4 [00:30<01:31, 30.55s/it]

Processed batch with 70 images, total processed: 7154/8407


sardinella: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:42<00:00, 10.52s/it]


Processed batch with 100 images, total processed: 7254/8407
Processed batch with 100 images, total processed: 7354/8407
Processed batch with 100 images, total processed: 7454/8407

Processing class: punctatus


punctatus:  10%|‚ñà         | 1/10 [00:34<05:09, 34.37s/it]

Processed batch with 100 images, total processed: 7554/8407
Processed batch with 100 images, total processed: 7654/8407


punctatus:  30%|‚ñà‚ñà‚ñà       | 3/10 [00:35<00:58,  8.37s/it]

Processed batch with 100 images, total processed: 7754/8407
Processed batch with 100 images, total processed: 7854/8407


punctatus:  50%|‚ñà‚ñà‚ñà‚ñà‚ñà     | 5/10 [01:11<01:19, 15.99s/it]

Processed batch with 100 images, total processed: 7954/8407


punctatus:  60%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà    | 6/10 [01:12<00:43, 10.82s/it]

Processed batch with 100 images, total processed: 8054/8407


punctatus:  70%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà   | 7/10 [01:12<00:22,  7.46s/it]

Processed batch with 100 images, total processed: 8154/8407


punctatus:  80%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà  | 8/10 [01:13<00:10,  5.26s/it]

Processed batch with 100 images, total processed: 8254/8407


punctatus:  90%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà | 9/10 [01:25<00:07,  7.50s/it]

Processed batch with 53 images, total processed: 8307/8407


punctatus: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 10/10 [01:35<00:00,  9.55s/it]

Processed batch with 100 images, total processed: 8407/8407






‚úÖ Done! Saved 8407 images in /content/drive/MyDrive/Hilsha/X_data.npy
X_data shape: (8407, 3, 224, 224), Y_labels shape: (8407,)


####DATA LOADING....

In [3]:
import os
import numpy as np
import torch

# Your data path
output_dir = '/content/drive/MyDrive/Hilsha'
data_file = os.path.join(output_dir, 'X_data.npy')
labels_file = os.path.join(output_dir, 'Y_labels.npy')

# Readable size format
def sizeof_fmt(num, suffix='B'):
    for unit in ['', 'K', 'M', 'G', 'T']:
        if abs(num) < 1024.0:
            return f"{num:3.2f} {unit}{suffix}"
        num /= 1024.0
    return f"{num:.2f} T{suffix}"

# Main loader
def load_preprocessed_data(as_torch=True, normalize=True, to_device=None):
    # Check file existence #cpu,cuda (CUDA stands for Compute Unified Device Architecture.)
    for path in [data_file, labels_file]:
        if not os.path.exists(path):
            raise FileNotFoundError(f"Missing: {path}")

    # Print file sizes
    print(f"üìÅ X_data.npy: {sizeof_fmt(os.path.getsize(data_file))}")
    print(f"üìÅ Y_labels.npy: {sizeof_fmt(os.path.getsize(labels_file))}")

    # Load with mmap
    X = np.load(data_file, mmap_mode='r')
    Y = np.load(labels_file, mmap_mode='r')

    print(f"‚úÖ X shape: {X.shape}, dtype: {X.dtype}")
    print(f"‚úÖ Y shape: {Y.shape}, dtype: {Y.dtype}")

    # Sanity check
    if len(X) != len(Y):
        raise ValueError("Mismatch between number of samples in X and Y")

    # Convert to torch
    if as_torch:
        X = torch.from_numpy(X)
        Y = torch.from_numpy(Y)

        if normalize and X.dtype == torch.uint8:
            X = X.float() / 255.0

        if to_device:
            X = X.to(to_device)
            Y = Y.to(to_device)

        print(f"üß† Torch tensors ready on {to_device or 'CPU'}")

    return X, Y

# üîÅ Example call
X, Y = load_preprocessed_data(
    as_torch=True,
    normalize=True,
    to_device='cuda' if torch.cuda.is_available() else 'cpu'
)

üìÅ X_data.npy: 1.18 GB
üìÅ Y_labels.npy: 32.96 KB
‚úÖ X shape: (8407, 3, 224, 224), dtype: uint8
‚úÖ Y shape: (8407,), dtype: int32


  X = torch.from_numpy(X)


üß† Torch tensors ready on cuda


In [None]:
# =========================
# XAI ENSEMBLE (ONE CELL) ‚Äî TRAIN + HPO + EVALUATE + EXPLAIN
# Q1-ready visualizations + .keras saving
# =========================

# ---------- Install ----------
!pip -q install imbalanced-learn

# ---------- Imports ----------
import os, gc, cv2, random, warnings, itertools
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import tensorflow as tf
from tensorflow.keras import mixed_precision
from tensorflow.keras.applications import ResNet50, EfficientNetB0
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout, LeakyReLU, BatchNormalization
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import to_categorical

from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import (classification_report, accuracy_score, f1_score, confusion_matrix,
                             roc_curve, auc, precision_recall_curve)
from sklearn.calibration import calibration_curve

from imblearn.over_sampling import SMOTE
from collections import Counter

# ---------- Colab: Drive Mount ----------
from google.colab import drive
drive.mount('/content/drive')

# ---------- Reproducibility ----------
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
tf.random.set_seed(SEED)

# ---------- Mixed Precision & GPU Growth ----------
try:
    mixed_precision.set_global_policy('mixed_float16')
except Exception as e:
    print("Mixed precision not set:", e)

gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    for gpu in gpus:
        try:
            tf.config.experimental.set_memory_growth(gpu, True)
        except:
            pass
    print(f"{len(gpus)} GPU(s) detected.")
else:
    print("No GPU detected. Running on CPU.")

# ---------- USER PATHS ----------
DATA_FILE   = '/content/drive/MyDrive/Hilsha/X_data.npy'
LABELS_FILE = '/content/drive/MyDrive/Hilsha/Y_labels.npy'

# ---------- Core Params ----------
input_shape  = (224, 224, 3)
num_classes  = 5
class_labels = ['Ilish', 'Chandana', 'Sardin', 'Sardinella', 'Punctatus']

# Training knobs (increase later for final runs)
epochs       = 10         # start small; increase to 30-100 for final paper runs
batch_size   = 16
k_folds      = 5

# Data strategy
USE_SMOTE                = False  # ‚ö†Ô∏è For images, SMOTE on flattened pixels is not ideal; prefer class weights
AUG_MIXUP_CUTOOUT        = False  # keep False for stability; can replace with tf.image mixup/cutout later

# ---------- Load Data ----------
if not (os.path.exists(DATA_FILE) and os.path.exists(LABELS_FILE)):
    raise FileNotFoundError("Preprocessed data files not found. Please check DATA_FILE and LABELS_FILE.")

X = np.load(DATA_FILE, mmap_mode='r')  # expect (N, H, W, C)
Y = np.load(LABELS_FILE, mmap_mode='r')  # integer labels 0..num_classes-1
print(f"Loaded X: {X.shape}, Y: {Y.shape}, dtype={X.dtype}")

# Fix CHW->HWC if needed
if len(X.shape) != 4:
    raise ValueError(f"X must be 4D (N,H,W,C), got {X.shape}")
if X.shape[1] == 3 and X.shape[-1] != 3:
    X = np.transpose(X, (0,2,3,1))
    print("Transposed to HWC:", X.shape)
elif X.shape[-1] != 3:
    raise ValueError(f"Expected 3 channels, got last dim: {X.shape[-1]}")

# Ensure float32
if X.dtype != np.float32:
    X = X.astype(np.float32)

# ---------- Train/Test Split ----------
X_train, X_test, Y_train, Y_test = train_test_split(
    X, Y, test_size=0.2, stratify=Y, random_state=SEED
)

# Normalize [0,1] if still 0-255
if X_train.max() > 1.5:
    X_train /= 255.0
    X_test  /= 255.0

# ---------- Class Weights (preferable to SMOTE for images) ----------
def compute_class_weights(y):
    counts = Counter(y)
    total  = sum(counts.values())
    weights = {cls: total/(num_classes*count) for cls, count in counts.items()}
    return weights

class_weights_full = compute_class_weights(Y_train)
print("Class weights:", class_weights_full)

# ---------- Optional SMOTE (Not recommended for images; keep off unless needed) ----------
if USE_SMOTE:
    print("Applying SMOTE (may harm image structure; use with caution)...")
    X_train_flat = X_train.reshape(X_train.shape[0], -1)
    X_train_resampled_flat, Y_train_resampled = SMOTE(random_state=SEED).fit_resample(X_train_flat, Y_train)
    X_train = X_train_resampled_flat.reshape(-1, *input_shape).astype(np.float32)
    Y_train = Y_train_resampled
    print("After SMOTE:", X_train.shape, Counter(Y_train))

# ---------- Data Augmentation ----------
datagen = ImageDataGenerator(
    rotation_range=20, width_shift_range=0.2, height_shift_range=0.2,
    shear_range=0.15, zoom_range=0.2, horizontal_flip=True, fill_mode="nearest"
)

def batch_generator(X, Y, batch_size=32):
    size = X.shape[0]
    while True:
        idx = np.random.permutation(size)
        for start in range(0, size, batch_size):
            batch_idx = idx[start:min(start+batch_size, size)]
            Xb, Yb = X[batch_idx], Y[batch_idx]
            Yb_oh = to_categorical(Yb, num_classes)
            Xb_aug = next(datagen.flow(Xb, batch_size=len(Xb), shuffle=False))
            yield Xb_aug, Yb_oh

# ---------- Model Factory with HPO ----------
def create_base_model(model_type='ResNet', lr=1e-3, dropout1=0.5, dropout2=0.5,
                      label_smoothing=0.0, unfreeze_last_n=10):
    if model_type == 'ResNet':
        base = ResNet50(weights='imagenet', include_top=False, input_shape=input_shape)
    else:
        base = EfficientNetB0(weights='imagenet', include_top=False, input_shape=input_shape)

    # Freeze all then unfreeze last N
    for layer in base.layers:
        layer.trainable = False
    if unfreeze_last_n > 0:
        for layer in base.layers[-unfreeze_last_n:]:
            layer.trainable = True

    x = GlobalAveragePooling2D()(base.output)
    x = Dense(256)(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(negative_slope=0.1)(x)
    x = Dropout(dropout1)(x)

    x = Dense(128)(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(negative_slope=0.1)(x)
    x = Dropout(dropout2)(x)

    # Important with mixed precision: final head outputs float32
    out = Dense(num_classes, activation='softmax', dtype='float32')(x)

    model = Model(inputs=base.input, outputs=out)
    # label smoothing
    loss = tf.keras.losses.CategoricalCrossentropy(label_smoothing=label_smoothing)
    opt  = tf.keras.optimizers.Adam(learning_rate=lr)
    model.compile(optimizer=opt, loss=loss, metrics=['accuracy'])
    return model

# Small, practical HPO grid per fold (expand for final paper runs)
HPO_GRID = {
    "lr":              [5e-4, 1e-3, 2e-3],
    "dropout1":        [0.4, 0.5],
    "dropout2":        [0.3, 0.5],
    "label_smoothing": [0.0, 0.05, 0.1],
    "unfreeze_last_n": [10, 20]
}
def hpo_param_combinations(grid):
    keys = list(grid.keys())
    for values in itertools.product(*[grid[k] for k in keys]):
        yield dict(zip(keys, values))

# ---------- Ensemble Logic ----------
def get_ensemble_predictions(resnet_probs, efficientnet_probs):
    r_conf = np.max(resnet_probs, axis=1)
    e_conf = np.max(efficientnet_probs, axis=1)
    total  = r_conf + e_conf + 1e-9
    rw = r_conf / total
    ew = e_conf / total
    ensemble = rw[:, None]*resnet_probs + ew[:, None]*efficientnet_probs
    return ensemble, np.argmax(ensemble, axis=1)

# ---------- K-Fold Training + HPO (picks best config per fold) ----------
def train_with_kfold_hpo(Xall, Yall, k=5, batch_size=16, epochs=10):
    skf = StratifiedKFold(n_splits=k, shuffle=True, random_state=SEED)
    best_fold = None
    best_fold_acc = -1
    best_fold_assets = None

    fold_summaries = []
    fold_idx = 0
    for tr_idx, val_idx in skf.split(Xall, Yall):
        fold_idx += 1
        print(f"\n===== Fold {fold_idx}/{k} =====")

        Xtr, Xval = Xall[tr_idx], Xall[val_idx]
        Ytr, Yval = Yall[tr_idx], Yall[val_idx]

        # Per-fold class weights
        class_weights = compute_class_weights(Ytr)

        train_gen = batch_generator(Xtr, Ytr, batch_size)
        val_gen   = batch_generator(Xval, Yval, batch_size)
        steps_per_epoch  = max(1, len(Xtr)//batch_size)
        val_steps        = max(1, len(Xval)//batch_size)

        best_cfg, best_acc, best_models, best_histories = None, -1, None, None

        # Hyperparameter search
        for cfg in hpo_param_combinations(HPO_GRID):
            print(f" HPO trial: {cfg}")
            resnet = create_base_model('ResNet', **cfg)
            effnet = create_base_model('EfficientNet', **cfg)

            callbacks = [
                ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=2, verbose=0),
                EarlyStopping(monitor='val_accuracy', patience=3, restore_best_weights=True, verbose=0)
            ]

            hist_r = resnet.fit(
                train_gen, validation_data=val_gen, epochs=epochs,
                steps_per_epoch=steps_per_epoch, validation_steps=val_steps,
                # class_weight=class_weights,
                verbose=0, callbacks=callbacks
            )
            hist_e = effnet.fit(
                train_gen, validation_data=val_gen, epochs=epochs,
                steps_per_epoch=steps_per_epoch, validation_steps=val_steps,
                # class_weight=class_weights,
                verbose=0, callbacks=callbacks
            )

            # Eval on raw val (no aug)
            r_probs = resnet.predict(Xval, batch_size=batch_size, verbose=0)
            e_probs = effnet.predict(Xval, batch_size=batch_size, verbose=0)
            ens_probs, ens_pred = get_ensemble_predictions(r_probs, e_probs)
            acc = accuracy_score(Yval, ens_pred)
            print(f"  ‚Üí Val ACC: {acc:.4f}")

            if acc > best_acc:
                best_acc = acc
                best_cfg = cfg
                best_models = (resnet, effnet)
                best_histories = (hist_r.history, hist_e.history)

        print(f"Best HPO for fold {fold_idx}: ACC={best_acc:.4f}, CFG={best_cfg}")
        fold_summaries.append((fold_idx, best_acc, best_cfg))

        # Track global best fold
        if best_acc > best_fold_acc:
            best_fold_acc = best_acc
            best_fold = fold_idx
            best_fold_assets = (best_models, (Xval, Yval), best_histories)

        # clear graphs between folds except the best fold assets (kept by reference)
        tf.keras.backend.clear_session()
        gc.collect()

    print(f"\n==> Best fold overall: {best_fold} with Val ACC={best_fold_acc:.4f}")

    # Save best fold models as .keras
    (resnet_best, effnet_best), (Xval_best, Yval_best), best_histories = best_fold_assets
    save_resnet_path = "/content/resnet_fish_model.keras"
    save_eff_path    = "/content/efficientnet_fish_model.keras"
    resnet_best.save(save_resnet_path)
    effnet_best.save(save_eff_path)
    print(f"Saved best fold models:\n  {save_resnet_path}\n  {save_eff_path}")

    return save_resnet_path, save_eff_path, (Xval_best, Yval_best), best_histories, fold_summaries

# ---------- Train + HPO + Save ----------
save_resnet_path, save_eff_path, (Xval_best, Yval_best), best_histories, fold_summaries = train_with_kfold_hpo(
    X_train, Y_train, k=k_folds, batch_size=batch_size, epochs=epochs
)

# ---------- Final Test Evaluation ----------
resnet_model = load_model(save_resnet_path)
eff_model    = load_model(save_eff_path)

r_probs = resnet_model.predict(X_test, batch_size=batch_size, verbose=0)
e_probs = eff_model.predict(X_test, batch_size=batch_size, verbose=0)
ens_probs, ens_pred = get_ensemble_predictions(r_probs, e_probs)

test_acc = accuracy_score(Y_test, ens_pred)
test_f1  = f1_score(Y_test, ens_pred, average='weighted')
print(f"\nTEST ‚Äî Ensemble Accuracy: {test_acc:.4f} | F1 (weighted): {test_f1:.4f}")
print("\nClassification Report (Ensemble on Test):\n",
      classification_report(Y_test, ens_pred, target_names=class_labels))

# =========================
# Q1 VISUALIZATIONS
# =========================

# --- 1) Training Curves from Best Fold ---
def plot_training_curves(hist_r, hist_e, title_suffix="(Best Fold)"):
    fig, axes = plt.subplots(1, 2, figsize=(14,5))
    # Accuracy
    axes[0].plot(hist_r.get('accuracy', []), label='ResNet Train')
    axes[0].plot(hist_r.get('val_accuracy', []), label='ResNet Val')
    axes[0].plot(hist_e.get('accuracy', []), label='EffNet Train', linestyle='--')
    axes[0].plot(hist_e.get('val_accuracy', []), label='EffNet Val', linestyle='--')
    axes[0].set_title(f'Accuracy {title_suffix}')
    axes[0].set_xlabel('Epoch'); axes[0].set_ylabel('Accuracy'); axes[0].grid(); axes[0].legend()

    # Loss
    axes[1].plot(hist_r.get('loss', []), label='ResNet Train')
    axes[1].plot(hist_r.get('val_loss', []), label='ResNet Val')
    axes[1].plot(hist_e.get('loss', []), label='EffNet Train', linestyle='--')
    axes[1].plot(hist_e.get('val_loss', []), label='EffNet Val', linestyle='--')
    axes[1].set_title(f'Loss {title_suffix}')
    axes[1].set_xlabel('Epoch'); axes[1].set_ylabel('Loss'); axes[1].grid(); axes[1].legend()
    plt.show()

plot_training_curves(best_histories[0], best_histories[1])

# --- 2) Confusion Matrix, Per-class metrics ---
cm = confusion_matrix(Y_test, ens_pred)
plt.figure(figsize=(7,6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_labels, yticklabels=class_labels)
plt.title('Confusion Matrix ‚Äî Test'); plt.xlabel('Predicted'); plt.ylabel('Actual'); plt.show()

report = classification_report(Y_test, ens_pred, target_names=class_labels, output_dict=True)
df_metrics = pd.DataFrame(report).transpose().loc[class_labels, ['precision','recall','f1-score']]
ax = df_metrics.plot(kind='bar', figsize=(8,5))
plt.title('Per-Class Metrics ‚Äî Test'); plt.ylabel('Score'); plt.grid(axis='y')
for p in ax.patches:
    ax.annotate(f'{p.get_height():.2f}', (p.get_x()+p.get_width()/2, p.get_height()), ha='center', va='bottom', fontsize=9)
plt.xticks(rotation=45); plt.show()

# --- 3) ROC + PR Curves (One-vs-Rest) ---
plt.figure(figsize=(7,6))
for i, cls in enumerate(class_labels):
    fpr, tpr, _ = roc_curve((Y_test==i).astype(int), ens_probs[:, i])
    plt.plot(fpr, tpr, label=f'{cls} (AUC={auc(fpr,tpr):.2f})')
plt.plot([0,1],[0,1],'k--')
plt.title('ROC Curves ‚Äî Test'); plt.xlabel('FPR'); plt.ylabel('TPR'); plt.legend(); plt.grid(); plt.show()

plt.figure(figsize=(7,6))
for i, cls in enumerate(class_labels):
    prec, rec, _ = precision_recall_curve((Y_test==i).astype(int), ens_probs[:, i])
    plt.plot(rec, prec, label=cls)
plt.title('Precision-Recall Curves ‚Äî Test'); plt.xlabel('Recall'); plt.ylabel('Precision'); plt.legend(); plt.grid(); plt.show()

# --- 4) Calibration / Reliability Diagram ---
def plot_reliability(probs, y_true, n_bins=10):
    conf = np.max(probs, axis=1)
    preds = np.argmax(probs, axis=1)
    correct = (preds == y_true).astype(int)
    prob_true, prob_pred = calibration_curve(correct, conf, n_bins=n_bins, strategy='uniform')
    plt.figure(figsize=(6,6))
    plt.plot([0,1],[0,1],'k--', label='Perfectly Calibrated')
    plt.plot(prob_pred, prob_true, marker='o', label='Model')
    plt.title('Reliability Diagram (Top-1 Confidence)')
    plt.xlabel('Predicted probability'); plt.ylabel('Empirical accuracy'); plt.legend(); plt.grid(); plt.show()

plot_reliability(ens_probs, Y_test)

# =========================
# EXPLAINABLE AI (GRAD-CAM)
# =========================
from tensorflow.keras.preprocessing import image

def _last_conv_name_resnet(model):
    try:
        model.get_layer("conv5_block3_out")
        return "conv5_block3_out"
    except:
        # fallback
        for layer in reversed(model.layers):
            try:
                if len(layer.output_shape) == 4:
                    return layer.name
            except: pass
        raise ValueError("No conv layer found in ResNet.")

def _last_conv_name_effnet(model):
    try:
        model.get_layer("top_conv")
        return "top_conv"
    except:
        for layer in reversed(model.layers):
            try:
                if len(layer.output_shape) == 4:
                    return layer.name
            except: pass
        raise ValueError("No conv layer found in EfficientNet.")

def get_gradcam_heatmap(model, img_array, last_conv_layer_name, pred_index=None):
    grad_model = tf.keras.models.Model(
        [model.inputs],
        [model.get_layer(last_conv_layer_name).output, model.output]
    )
    with tf.GradientTape() as tape:
        conv_out, preds = grad_model(img_array)
        if pred_index is None:
            pred_index = tf.argmax(preds[0])
        target = preds[:, pred_index]
    grads = tape.gradient(target, conv_out)
    pooled = tf.reduce_mean(grads, axis=(0,1,2))
    conv_out = conv_out[0]
    heatmap = conv_out @ pooled[..., tf.newaxis]
    heatmap = tf.squeeze(heatmap)
    heatmap = np.maximum(heatmap, 0)
    denom = np.max(heatmap) if np.max(heatmap)>0 else 1e-9
    heatmap = heatmap / denom
    return heatmap

def overlay_heatmap(heatmap, original_rgb, alpha=0.4):
    h,w = original_rgb.shape[:2]
    hm = cv2.resize(heatmap, (w,h))
    hm = np.uint8(255*hm)
    hm = cv2.applyColorMap(hm, cv2.COLORMAP_JET)
    out = cv2.addWeighted(original_rgb, 1-alpha, hm, alpha, 0)
    return out

def predict_with_explanations(img_path, resnet_path=save_resnet_path, eff_path=save_eff_path, show=True):
    # Load models
    resnet = load_model(resnet_path)
    effnet = load_model(eff_path)

    # Load & preprocess
    img = image.load_img(img_path, target_size=input_shape[:2])
    img_arr = image.img_to_array(img).astype(np.float32) / 255.0
    img_batch = np.expand_dims(img_arr, axis=0)

    # Predict
    r_probs = resnet.predict(img_batch, verbose=0)
    e_probs = effnet.predict(img_batch, verbose=0)
    ens_probs, _ = get_ensemble_predictions(r_probs, e_probs)

    pred_idx = int(np.argmax(ens_probs[0]))
    pred_label = class_labels[pred_idx]
    probs_dict = {class_labels[i]: float(ens_probs[0, i]) for i in range(num_classes)}

    # Grad-CAM
    res_last = _last_conv_name_resnet(resnet)
    eff_last = _last_conv_name_effnet(effnet)
    r_heat = get_gradcam_heatmap(resnet, img_batch, res_last, pred_index=pred_idx)
    e_heat = get_gradcam_heatmap(effnet, img_batch, eff_last, pred_index=pred_idx)

    # Original image RGB
    orig = cv2.imread(img_path)
    if orig is None:
        orig = (img_arr*255).astype(np.uint8)
    else:
        orig = cv2.cvtColor(orig, cv2.COLOR_BGR2RGB)

    r_overlay = overlay_heatmap(r_heat, orig)
    e_overlay = overlay_heatmap(e_heat, orig)

    if show:
        plt.figure(figsize=(16,6))
        plt.subplot(1,3,1); plt.imshow(orig); plt.axis('off'); plt.title("Original")
        plt.subplot(1,3,2); plt.imshow(r_overlay); plt.axis('off'); plt.title(f"ResNet Grad-CAM ‚Üí {pred_label}")
        plt.subplot(1,3,3); plt.imshow(e_overlay); plt.axis('off'); plt.title(f"EfficientNet Grad-CAM ‚Üí {pred_label}")
        plt.show()

    print("Predicted:", pred_label)
    print("Class Probabilities:")
    for k,v in probs_dict.items():
        print(f"  {k:12s}: {v:.4f}")

    # Simple textual rationale
    print("\nWhy this prediction? Grad-CAM highlights image regions that most influenced the decision.\n"
          "Bright (red/yellow) zones are the most discriminative features for the predicted class.\n"
          "Compare ResNet vs EfficientNet overlays to ensure consistency of highlighted anatomical cues "
          "(e.g., belly patterning, fin edges, head/eye region).")

    return {"predicted_label": pred_label,
            "probabilities": probs_dict,
            "resnet_gradcam": r_overlay,
            "efficientnet_gradcam": e_overlay}

# =========================
# USAGE EXAMPLE:
# test_img_path = "/content/drive/MyDrive/Hilsha/test_ilish.jpg"
# result = predict_with_explanations(test_img_path)
# =========================

# -------------------------
# Notes for Q1 submission:
# - Increase k_folds (5‚Üí10) and epochs (e.g., 50‚Äì100), widen HPO grid.
# - Report mean¬±std across folds on an untouched test split or via nested CV.
# - Consider external validation if available.
# - Optionally add temperature scaling on validation for improved calibration.
# - Archive checkpoints and training logs; ensure full reproducibility (seed, versions).
# -------------------------


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
1 GPU(s) detected.
Loaded X: (8407, 3, 224, 224), Y: (8407,), dtype=uint8
Transposed to HWC: (8407, 224, 224, 3)
Class weights: {np.int32(1): 1.4187763713080168, np.int32(0): 0.5604166666666667, np.int32(3): 4.543918918918919, np.int32(2): 0.579991375592928, np.int32(4): 1.7650918635170603}

===== Fold 1/5 =====
 HPO trial: {'lr': 0.0005, 'dropout1': 0.4, 'dropout2': 0.3, 'label_smoothing': 0.0, 'unfreeze_last_n': 10}


#End