In [1]:
import os
import numpy as np
import cv2
import tifffile as tiff
from sklearn.model_selection import train_test_split

# Set image size
# Side length in pixels: every image and mask is resized to
# IMG_SIZE x IMG_SIZE by load_images_and_masks.
IMG_SIZE = 256

# Define your dataset paths
# NOTE: these are absolute, machine-specific Windows paths; edit them to point
# at the local copy of the dataset. For each image file in IMAGE_DIR, the mask
# is looked up in MASK_DIR under the same filename.
IMAGE_DIR = r"C:\Users\KIIT\Downloads\archive (1)\1_CLOUD_FREE_DATASET\2_SENTINEL2\IMAGE_16_GRID"
MASK_DIR  = r"C:\Users\KIIT\Downloads\archive (1)\3_TRAINING_MASKS\MASK_16_GRID"

def _as_three_channel(img):
    """Return ``img`` as an (H, W, 3) array, raising ``ValueError`` otherwise."""
    if img.ndim == 2:
        # Grayscale: replicate the single band into three channels.
        return np.stack([img] * 3, axis=-1)
    if img.ndim != 3:
        raise ValueError(f"unsupported image shape {img.shape}")

    # NOTE(review): assumes a channels-last (H, W, C) layout. A channels-first
    # TIFF would be sliced along its width here -- confirm against the data.
    channels = img.shape[-1]
    if channels == 1:
        return np.repeat(img, 3, axis=-1)
    if channels >= 3:
        # Multi-band image: keep only the first three bands.
        return img[..., :3]
    raise ValueError(f"unsupported channel count {channels}")


def _as_binary_mask(mask):
    """Resize a 2-D mask to IMG_SIZE and binarize it to float32 {0.0, 1.0}."""
    if mask.ndim == 3 and mask.shape[-1] == 1:
        mask = mask[..., 0]
    if mask.ndim != 2:
        # A multi-channel mask would yield a differently shaped sample and
        # break the final stacking, so reject it for this file only.
        raise ValueError(f"unsupported mask shape {mask.shape}")
    if mask.dtype == np.bool_:
        # cv2.resize does not accept boolean arrays.
        mask = mask.astype(np.uint8)

    # Masks stored as 0/1 are thresholded at 0.5; anything with a larger
    # range keeps the original 8-bit midpoint (mask / 255 > 0.5). Without
    # this, a 0/1 mask would silently become all background.
    threshold = 0.5 if mask.max() <= 1 else 127.5

    # Nearest-neighbour keeps label values intact; the default bilinear
    # interpolation would blend foreground and background along edges.
    mask = cv2.resize(
        np.ascontiguousarray(mask),
        (IMG_SIZE, IMG_SIZE),
        interpolation=cv2.INTER_NEAREST,
    )
    return (mask > threshold).astype(np.float32)


def load_images_and_masks(image_dir, mask_dir):
    """Load matching image/mask TIFF pairs as arrays ready for training.

    Every ``.tif``/``.tiff`` file in ``image_dir`` is paired with the file of
    the same name in ``mask_dir``. Files without a mask, and files that fail
    to load or have an unsupported shape, are reported on stdout and skipped.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the masks, named like their images.

    Returns:
        A tuple ``(images, masks)`` of NumPy arrays. ``images`` has shape
        ``(N, IMG_SIZE, IMG_SIZE, 3)`` with pixel values divided by 255, and
        ``masks`` has shape ``(N, IMG_SIZE, IMG_SIZE, 1)`` holding float32
        values 0.0 or 1.0. When no pair is loaded both arrays are empty with
        shape ``(0,)``.
    """
    images = []
    masks = []

    # os.listdir returns entries in arbitrary order; sorting keeps the sample
    # order (and therefore any seeded train/validation split) reproducible.
    for filename in sorted(os.listdir(image_dir)):
        if not filename.lower().endswith(('.tif', '.tiff')):
            continue

        image_path = os.path.join(image_dir, filename)
        mask_path = os.path.join(mask_dir, filename)

        if not os.path.exists(mask_path):
            print(f"❌ Mask not found for: {filename}")
            continue

        try:
            # Load image and mask using tifffile
            img = tiff.imread(image_path)
            mask = tiff.imread(mask_path)

            img = _as_three_channel(img)
            img = cv2.resize(np.ascontiguousarray(img), (IMG_SIZE, IMG_SIZE))
            # NOTE(review): dividing by 255 maps only 8-bit data to [0, 1];
            # 16-bit or float imagery would need a different scale -- confirm
            # the dtype of the source TIFFs.
            img = img / 255.0

            mask = _as_binary_mask(mask)
        except Exception as e:
            # Deliberate best-effort: tifffile and OpenCV raise a variety of
            # exception types, and one bad file should not abort the load.
            print(f"⚠ Failed to process {filename}: {e}")
            continue

        images.append(img)
        masks.append(np.expand_dims(mask, axis=-1))  # shape: (H, W, 1)

    print(f"\n✅ Loaded {len(images)} image-mask pairs.")
    return np.array(images), np.array(masks)

# Read every image/mask pair from disk into memory
X, y = load_images_and_masks(image_dir=IMAGE_DIR, mask_dir=MASK_DIR)

# Hold out 20% of the samples for validation, using a fixed random seed
X_train, X_val, y_train, y_val = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)

# Report the shapes of the resulting arrays
print("\n✅ Final Data Shapes:")
print(f"X_train: {X_train.shape} y_train: {y_train.shape}")
print(f"X_val:   {X_val.shape} y_val:   {y_val.shape}")


✅ Loaded 16 image-mask pairs.

✅ Final Data Shapes:
X_train: (12, 256, 256, 3) y_train: (12, 256, 256, 1)
X_val:   (4, 256, 256, 3) y_val:   (4, 256, 256, 1)
