# In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from collections import defaultdict
from itertools import combinations
import seaborn as sns
import cv2
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate, Dropout, BatchNormalization
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import to_categorical

# Root folder of the dataset
dataset_folder_path = 'data/Semantic segmentation dataset'

# Human-readable class names, ordered by class index
class_labels = ['Water', 'Land', 'Road', 'Building', 'Vegetation', 'Unlabeled']

# Mask RGB colour -> class index. The position of each colour in the list
# below is its class index.
color_to_class = {
    rgb: class_idx
    for class_idx, rgb in enumerate([
        (226, 169, 41),   # Water
        (132, 41, 246),   # Land
        (110, 193, 228),  # Road
        (60, 16, 152),    # Building
        (254, 221, 58),   # Vegetation
        (155, 155, 155),  # Unlabeled
    ])
}

# Accumulators for the basic statistics
total_images = total_masks = 0
class_counts = np.zeros(len(class_labels), dtype=int)
image_shapes = []

# Function to count class pixels in a mask
def count_class_pixels(mask_array):
    """Count how many pixels of ``mask_array`` belong to each class.

    A pixel is attributed to a class when all of its values along the last
    axis equal that class's colour in ``color_to_class``.

    Returns:
        Integer array with one entry per label in ``class_labels``, indexed
        by class index.
    """
    pixel_totals = np.zeros(len(class_labels), dtype=int)
    for colour, idx in color_to_class.items():
        is_class_pixel = np.all(mask_array == colour, axis=-1)
        pixel_totals[idx] += is_class_pixel.sum()
    return pixel_totals

# Function to load images and masks
def load_images_and_masks(dataset_folder_path, image_size=(256, 256)):
    """Load every image/mask pair found under the dataset's tile folders.

    Each sub-directory of ``dataset_folder_path`` is treated as a tile that
    contains an ``images`` folder and a ``masks`` folder. An image is paired
    with the mask that has the same base name and a ``.png`` extension.
    Images without such a mask are skipped, and so are tiles that lack
    either folder.

    Args:
        dataset_folder_path: Root directory holding the tile sub-directories.
        image_size: ``(width, height)`` every image and mask is resized to.

    Returns:
        Tuple ``(images, masks)`` of NumPy arrays in matching order.
        ``images`` holds the RGB images divided by 255; ``masks`` holds the
        RGB masks with their original colour values.
    """
    images = []
    masks = []
    # os.listdir order is unspecified; sorting keeps the sample order (and
    # therefore any seeded split done on the result) reproducible.
    for tile in sorted(os.listdir(dataset_folder_path)):
        tile_path = os.path.join(dataset_folder_path, tile)
        images_folder = os.path.join(tile_path, 'images')
        masks_folder = os.path.join(tile_path, 'masks')
        # Skip stray files and tiles that do not have both sub-folders.
        if not (os.path.isdir(images_folder) and os.path.isdir(masks_folder)):
            continue
        for image_file in sorted(os.listdir(images_folder)):
            img_path = os.path.join(images_folder, image_file)
            # Swap only the extension; str.replace would also rewrite a
            # '.jpg' that appears in the middle of the file name.
            mask_file = os.path.splitext(image_file)[0] + '.png'
            mask_path = os.path.join(masks_folder, mask_file)
            if not os.path.exists(mask_path):
                continue
            with Image.open(img_path) as image:
                # Force three channels so every array has the same shape.
                image_array = np.array(image.convert('RGB').resize(image_size))
            with Image.open(mask_path) as mask:
                # Nearest-neighbour keeps the exact class colours. An
                # interpolating filter blends colours at class borders and
                # produces values that match no class.
                mask_array = np.array(
                    mask.convert('RGB').resize(image_size, Image.NEAREST)
                )
            images.append(image_array / 255.0)  # Normalize images
            masks.append(mask_array)
    return np.array(images), np.array(masks)

# Function to preprocess masks to one-hot encoding
def preprocess_masks(masks, num_classes=6):
    """Turn colour-coded masks into one-hot encoded class masks.

    For each mask, every pixel whose values along the last axis equal a
    colour in ``color_to_class`` receives that colour's class index, and the
    resulting index map is one-hot encoded with ``to_categorical``.

    NOTE: the index map starts as zeros, so a pixel that matches no colour
    stays at class 0.

    Args:
        masks: Iterable of mask arrays.
        num_classes: Number of classes passed to ``to_categorical``.

    Returns:
        NumPy array stacking the one-hot masks in input order.
    """
    def _class_index_map(rgb_mask):
        # Per-pixel class indices for a single colour-coded mask.
        index_map = np.zeros(rgb_mask.shape[:2], dtype=int)
        for colour, idx in color_to_class.items():
            index_map[np.all(rgb_mask == colour, axis=-1)] = idx
        return index_map

    return np.array([
        to_categorical(_class_index_map(rgb_mask), num_classes=num_classes)
        for rgb_mask in masks
    ])

# Load the raw data, then convert the colour masks to one-hot encoding
print("Loading and preprocessing data...")
train_images_np, train_masks_np = load_images_and_masks(dataset_folder_path)
train_masks_np = preprocess_masks(train_masks_np, num_classes=6)

# Hold out 20% of the samples as a test set (fixed seed for the shuffle)
from sklearn.model_selection import train_test_split
(
    train_images_np,
    test_images_np,
    train_masks_np,
    test_masks_np,
) = train_test_split(
    train_images_np,
    train_masks_np,
    test_size=0.2,
    random_state=42,
)

# Basic statistics: per-class pixel totals, accumulated one mask at a time
print("Performing basic statistics...")
for mask in train_masks_np:
    class_counts += mask.sum(axis=(0, 1)).astype(int)

# Report the number of images and the pixel total of every class
print(f"Total Images: {len(train_images_np)}")
print("Class Distribution:", dict(zip(class_labels, class_counts)))

# Bar chart of the per-class pixel totals
plt.figure(figsize=(10, 5))
distribution_ax = plt.gca()
distribution_ax.bar(class_labels, class_counts)
distribution_ax.set_title("Class Distribution in the Dataset")
distribution_ax.set_xlabel("Classes")
distribution_ax.set_ylabel("Pixel Count")
plt.xticks(rotation=45)
plt.show()

# Show the first five training images (top row) above their masks (bottom row)
print("Visualizing sample images and masks...")
plt.figure(figsize=(12, 10))
for i in range(5):
    image_ax = plt.subplot(2, 5, i + 1)
    image_ax.imshow(train_images_np[i])
    image_ax.set_title("Sample Image")
    image_ax.axis('off')

    # Collapse the one-hot mask to class indices so it can be colour-mapped
    mask_ax = plt.subplot(2, 5, i + 6)
    mask_ax.imshow(np.argmax(train_masks_np[i], axis=-1), cmap='tab20')
    mask_ax.set_title("Segmentation Mask")
    mask_ax.axis('off')
plt.tight_layout()
plt.show()

# Advanced Data Augmentation
# Geometric transforms shared by the image and mask generators
data_gen_args = dict(
    rotation_range=45,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    vertical_flip=True,
    fill_mode='reflect'
)

# ImageDataGenerator for images and masks
image_datagen = ImageDataGenerator(**data_gen_args)
# Masks are one-hot label maps: resample them with nearest-neighbour
# (order 0) so the transforms cannot blend channels into fractional,
# non-one-hot values at class borders.
mask_datagen = ImageDataGenerator(**data_gen_args, interpolation_order=0)

# Create data generators
# The same seed on both flows keeps the shuffling and the random transforms
# of each image aligned with those of its mask.
image_generator = image_datagen.flow(train_images_np, batch_size=4, seed=42)
mask_generator = mask_datagen.flow(train_masks_np, batch_size=4, seed=42)
train_generator = zip(image_generator, mask_generator)

# Draw one augmented batch: images on the top row, their masks underneath
print("Visualizing augmented images and masks...")
augmented_images, augmented_masks = next(train_generator)

plt.figure(figsize=(12, 10))
for i in range(4):
    image_ax = plt.subplot(2, 4, i + 1)
    image_ax.imshow(augmented_images[i])
    image_ax.set_title("Augmented Image")
    image_ax.axis('off')

    # Collapse the channel axis to class indices so the mask can be colour-mapped
    mask_ax = plt.subplot(2, 4, i + 5)
    mask_ax.imshow(np.argmax(augmented_masks[i], axis=-1), cmap='tab20')
    mask_ax.set_title("Augmented Mask")
    mask_ax.axis('off')
plt.tight_layout()
plt.show()

# Co-occurrence Analysis: count, for every pair of classes, the number of
# training masks in which both classes appear
print("Performing class co-occurrence analysis...")
co_occurrence_matrix = np.zeros((len(class_labels),) * 2, dtype=int)

for mask in train_masks_np:
    unique_classes = np.unique(np.argmax(mask, axis=-1))
    # Each unordered pair is recorded in both cells so the matrix stays symmetric
    for first_class, second_class in combinations(unique_classes, 2):
        co_occurrence_matrix[first_class, second_class] += 1
        co_occurrence_matrix[second_class, first_class] += 1

# Plot co-occurrence matrix
plt.figure(figsize=(8, 6))
sns.heatmap(
    co_occurrence_matrix,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=class_labels,
    yticklabels=class_labels,
)
plt.title("Class Co-occurrence Matrix")
plt.show()

# Object size and aspect ratio analysis
# For every class, collect the contour area and the bounding-box aspect
# ratio of each connected region found in the training masks.
print("Performing object size and aspect ratio analysis...")
object_sizes = {label: [] for label in class_labels}
aspect_ratios = {label: [] for label in class_labels}

for mask in train_masks_np:
    # The masks are one-hot encoded at this point, so comparing their
    # channels with RGB colours never matches. Recover the per-pixel class
    # index and build each class's binary mask from that instead.
    class_map = np.argmax(mask, axis=-1)
    for class_idx in color_to_class.values():
        label = class_labels[class_idx]
        binary_mask = (class_map == class_idx).astype(np.uint8)
        contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        for contour in contours:
            area = cv2.contourArea(contour)
            if area > 10:  # Ignore very small areas
                x, y, w, h = cv2.boundingRect(contour)
                object_sizes[label].append(area)
                aspect_ratios[label].append(w / h)

# Plot the object size distribution, then the aspect ratio distribution:
# one figure each, with an overlaid histogram per class
for per_class_values, plot_title, x_axis_label in (
    (object_sizes, "Object Size Distribution by Class", "Object Size (Pixels)"),
    (aspect_ratios, "Object Aspect Ratio Distribution by Class", "Aspect Ratio (Width/Height)"),
):
    plt.figure(figsize=(12, 6))
    for label in class_labels:
        plt.hist(per_class_values[label], bins=20, alpha=0.6, label=label)
    plt.title(plot_title)
    plt.xlabel(x_axis_label)
    plt.ylabel("Frequency")
    plt.legend()
    plt.show()
