# Hyperspectral Image Segmentation - Full Flow (TensorFlow/TPU)

This notebook contains a complete pipeline for segmenting hyperspectral images: **TensorFlow/Keras** models (with TPU support) produce per-pixel features, and **DBSCAN clustering** groups those features into segments.

## How to use in Google Colab:

1. **Enable TPU**: Runtime → Change runtime type → TPU (or GPU if no TPU is available)
2. **Run all cells**: Runtime → Run all
3. **Parameters can be changed** in the cell containing `TARGET_BANDS = [10, 20]`
4. **⚠ Saving RAM**: memory use is controlled by the flags in the parameters cell (`LOAD_DATASETS_ON_DEMAND`, `CLEAR_MEMORY_BETWEEN_SPLITS`); models and training arrays are freed after each run

## Flow:
1. **Load data** (5 datasets: Indian, PaviaU, PaviaC, KSC, Salinas)
2. **Preprocessing** - band reduction with a Gaussian filter (10/20/30 bands; the default `TARGET_BANDS` is `[10, 20]`)
3. **Build the test set** - N splits (Train: 3, Test: 1, Validation: 1)
4. **Test 3 models** with DBSCAN clustering for segmentation
5. **Validate and summarize** the results
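
Step 3 of the flow can be sketched with the standard library alone. The snippet below is a minimal illustration of how the 3/1/1 assignments are enumerated; the helper name `enumerate_splits` is an assumption made for this example and does not appear in the notebook.

```python
import itertools

DATASET_NAMES = ['Indian', 'PaviaU', 'PaviaC', 'KSC', 'Salinas']

def enumerate_splits(names):
    """Yield every (train, test, validation) assignment: 3 train, 1 test, 1 validation."""
    for train in itertools.combinations(names, 3):
        remaining = [n for n in names if n not in train]
        # The two leftover datasets can take the test/validation roles in either order.
        for test, validation in itertools.permutations(remaining, 2):
            yield list(train), test, validation

all_splits = list(enumerate_splits(DATASET_NAMES))
print(len(all_splits))   # 10 train combinations x 2 role orders = 20
print(all_splits[0])     # (['Indian', 'PaviaU', 'PaviaC'], 'KSC', 'Salinas')
```

With five datasets there are 20 possible assignments, so `N_SPLITS` selects a small prefix of that list.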

## Models (TensorFlow/Keras):
- InceptionHSINet (3D CNN)
- SimpleHSINet (2D CNN)
- CNNFromDiagram (2D CNN)
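
The practical difference between the 3D and 2D models is the input layout. The sketch below uses a made-up batch of patches (the sizes are illustrative assumptions) to show the extra trailing axis that the 3D CNN expects.

```python
import numpy as np

# Hypothetical batch: 4 patches of 16x16 pixels with 10 spectral bands.
patches = np.zeros((4, 16, 16, 10), dtype=np.float32)

# 2D CNNs treat the bands as image channels: (N, H, W, B).
input_2d = patches

# A 3D CNN also convolves along the spectral axis, so it needs one extra
# trailing channel axis: (N, H, W, B, 1).
input_3d = np.expand_dims(patches, axis=-1)

print(input_2d.shape)  # (4, 16, 16, 10)
print(input_3d.shape)  # (4, 16, 16, 10, 1)
```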

## Segmentation:
Uses **DBSCAN clustering**, which finds the number of segments automatically, without knowing the classes in advance.
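
To make the "no class count needed" property concrete, here is a small sketch on synthetic 2D points rather than real embeddings. The point layout is an assumption chosen so the outcome is easy to verify by hand; `eps=0.5` and `min_samples=5` match the values used later in the notebook.

```python
import numpy as np
from sklearn.cluster import DBSCAN

# Two tight, well-separated groups of 25 points each, plus one isolated point.
grid = np.array([(0.01 * i, 0.01 * j) for i in range(5) for j in range(5)])
points = np.vstack([grid, grid + 5.0, [[20.0, 20.0]]])

labels = DBSCAN(eps=0.5, min_samples=5).fit_predict(points)

# DBSCAN marks noise with -1; every other label is a discovered cluster.
n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
n_noise = int(np.sum(labels == -1))
print(n_clusters, n_noise)  # 2 1
```

The cluster count follows from the density parameters, so on real embeddings `eps` usually needs tuning per dataset.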


In [None]:
# Install packages
%pip install tensorflow scikit-learn scipy matplotlib requests


In [None]:
# Import libraries
import os
import urllib.request
import ssl
import scipy.io as sio
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Model
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import DBSCAN
from scipy import ndimage
from scipy.stats import mode
import itertools
import json
import matplotlib.pyplot as plt
import time

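# Install psutil if it is missing
try:
    import psutil
except ImportError:
    print("Installing psutil...")
    import subprocess
    import sys
    # Use the current interpreter's pip so the package lands in this environment
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'psutil'])
    import psutil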

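# SSL handling: disables certificate verification so the dataset downloads work.
# This is insecure; use it only for these public files.
try:
    ssl._create_default_https_context = ssl._create_unverified_context
except AttributeError:
    pass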

try:
    import requests
    requests.packages.urllib3.disable_warnings()
except ImportError:
    pass

# Check TPU availability
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Running on TPU ', tpu.master())
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.TPUStrategy(tpu)
    print("✓ TPU found - using TPU")
    USE_TPU = True
except ValueError:
    print("⚠ TPU not found - checking for GPU...")
    USE_TPU = False
    strategy = tf.distribute.get_strategy()  # default strategy for GPU/CPU
    gpus = tf.config.list_physical_devices('GPU')
    if gpus:
        print(f"✓ GPU found: {len(gpus)} GPU(s)")
    else:
        print("⚠ No GPU - using CPU")

print("✓ Libraries imported")


## STEP 1: Model and helper function definitions


In [None]:
# ========== TENSORFLOW/KERAS MODELS ==========

def create_InceptionHSINet(input_shape, num_classes=16):
    """InceptionHSINet - 3D CNN for TensorFlow"""
    inputs = keras.Input(shape=input_shape)
    
    # Entry
    x = layers.Conv3D(8, 3, padding='same')(inputs)
    x = layers.SpatialDropout3D(0.3)(x)
    x = layers.ReLU()(x)
    x = layers.MaxPool3D(2)(x)
    
    # Branch 1
    x1 = layers.Conv3D(16, 1, padding='same')(x)
    x1 = layers.SpatialDropout3D(0.3)(x1)
    x1 = layers.ReLU()(x1)
    x1 = layers.Conv3D(16, 3, padding='same')(x1)
    x1 = layers.SpatialDropout3D(0.3)(x1)
    x1 = layers.ReLU()(x1)
    
    # Branch 2
    x2 = layers.Conv3D(16, 3, padding='same')(x)
    x2 = layers.SpatialDropout3D(0.3)(x2)
    x2 = layers.ReLU()(x2)
    x2 = layers.Conv3D(16, 5, padding='same')(x2)
    x2 = layers.SpatialDropout3D(0.3)(x2)
    x2 = layers.ReLU()(x2)
    
    # Branch 3
    x3 = layers.Conv3D(16, 5, padding='same')(x)
    x3 = layers.SpatialDropout3D(0.3)(x3)
    x3 = layers.ReLU()(x3)
    x3 = layers.Conv3D(16, 3, padding='same')(x3)
    x3 = layers.SpatialDropout3D(0.3)(x3)
    x3 = layers.ReLU()(x3)
    
    # Concatenate
    x = layers.Concatenate()([x1, x2, x3])
    x = layers.GlobalAveragePooling3D()(x)
    
    # Feature extraction (without the classifier)
    features = x
    
    # Classifier
    x = layers.Dropout(0.5)(x)
    x = layers.ReLU()(x)
    outputs = layers.Dense(num_classes)(x)
    
    model = Model(inputs=inputs, outputs=outputs)
    feature_model = Model(inputs=inputs, outputs=features)
    
    return model, feature_model


def create_SimpleHSINet(input_shape, num_classes=16):
    """SimpleHSINet - 2D CNN for TensorFlow"""
    inputs = keras.Input(shape=input_shape)
    
    x = layers.Conv2D(90, 1, padding='same')(inputs)
    x = layers.ReLU()(x)
    x = layers.Conv2D(270, 3, padding='same')(x)
    x = layers.ReLU()(x)
    x = layers.SpatialDropout2D(0.3)(x)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Flatten()(x)
    
    # Feature extraction
    features = layers.Dense(180, activation='relu')(x)
    
    # Classifier
    x = layers.Dropout(0.3)(features)
    outputs = layers.Dense(num_classes)(x)
    
    model = Model(inputs=inputs, outputs=outputs)
    feature_model = Model(inputs=inputs, outputs=features)
    
    return model, feature_model


def create_CNNFromDiagram(input_shape, num_classes=16):
    """CNNFromDiagram - 2D CNN for TensorFlow"""
    inputs = keras.Input(shape=input_shape)
    
    x = layers.Conv2D(100, 3, padding='same')(inputs)
    x = layers.ReLU()(x)
    x = layers.MaxPooling2D(2)(x)
    x = layers.Conv2D(100, 3, padding='same')(x)
    x = layers.ReLU()(x)
    x = layers.MaxPooling2D(2)(x)
    x = layers.Flatten()(x)
    
    # Feature extraction
    features = layers.Dense(84, activation='relu')(x)
    
    # Classifier
    outputs = layers.Dense(num_classes)(features)
    
    model = Model(inputs=inputs, outputs=outputs)
    feature_model = Model(inputs=inputs, outputs=features)
    
    return model, feature_model


# Model factory
def create_model(model_name, input_shape, num_classes):
    """Create the model and its feature extractor"""
    if model_name == 'InceptionHSINet':
        return create_InceptionHSINet(input_shape, num_classes)
    elif model_name == 'SimpleHSINet':
        return create_SimpleHSINet(input_shape, num_classes)
    elif model_name == 'CNNFromDiagram':
        return create_CNNFromDiagram(input_shape, num_classes)
    else:
        raise ValueError(f"Unknown model: {model_name}")

MODELS = ['InceptionHSINet', 'SimpleHSINet', 'CNNFromDiagram']

print("✓ TensorFlow/Keras models defined")


In [None]:
# ========== HELPER FUNCTIONS ==========

DATASET_URLS = {
    'Indian': {
        'data': 'https://www.ehu.eus/ccwintco/uploads/6/67/Indian_pines_corrected.mat',
        'gt':   'https://www.ehu.eus/ccwintco/uploads/c/c4/Indian_pines_gt.mat',
    },
    'PaviaU': {
        'data': 'https://www.ehu.eus/ccwintco/uploads/e/ee/PaviaU.mat',
        'gt':   'https://www.ehu.eus/ccwintco/uploads/5/50/PaviaU_gt.mat',
    },
    'PaviaC': {
        'data': 'https://www.ehu.eus/ccwintco/uploads/e/e3/Pavia.mat',
        'gt':   'https://www.ehu.eus/ccwintco/uploads/5/53/Pavia_gt.mat',
    },
    'KSC': {
        'data': 'http://www.ehu.es/ccwintco/uploads/2/26/KSC.mat',
        'gt':   'http://www.ehu.es/ccwintco/uploads/a/a6/KSC_gt.mat',
    },
    'Salinas': {
        'data': 'https://www.ehu.eus/ccwintco/uploads/a/a3/Salinas_corrected.mat',
        'gt':   'https://www.ehu.eus/ccwintco/uploads/f/fa/Salinas_gt.mat',
    }
}

DATASET_KEYS = {
    'Indian': {'data': ['indian_pines_corrected', 'Indian_pines_corrected'], 'gt': ['indian_pines_gt', 'Indian_pines_gt']},
    'PaviaU': {'data': ['paviaU', 'PaviaU', 'pavia_u'], 'gt': ['paviaU_gt', 'PaviaU_gt', 'pavia_u_gt']},
    'PaviaC': {'data': ['pavia', 'Pavia', 'paviaC', 'PaviaC'], 'gt': ['pavia_gt', 'Pavia_gt', 'paviaC_gt', 'PaviaC_gt']},
    'KSC': {'data': ['KSC', 'ksc'], 'gt': ['KSC_gt', 'ksc_gt']},
    'Salinas': {'data': ['salinas_corrected', 'Salinas_corrected', 'salinas'], 'gt': ['salinas_gt', 'Salinas_gt']}
}

DATASET_INFO = {
    'Indian': {'num_classes': 16, 'num_bands': 200},
    'PaviaU': {'num_classes': 9, 'num_bands': 103},
    'PaviaC': {'num_classes': 9, 'num_bands': 102},
    'KSC': {'num_classes': 13, 'num_bands': 176},
    'Salinas': {'num_classes': 16, 'num_bands': 204}
}

def download_file(url, filename):
    if not os.path.exists(filename):
        print(f"Downloading {filename}...")
        urllib.request.urlretrieve(url, filename)
        print(f"✓ Downloaded {filename}")

def find_key_in_mat(mat_file, possible_keys):
    if isinstance(possible_keys, str):
        possible_keys = [possible_keys]
    for key in possible_keys:
        if key in mat_file:
            return key
    keys = [k for k in mat_file.keys() if not k.startswith('__')]
    if keys:
        return keys[0]
    raise ValueError("No data key found in the .mat file")

def load_data(dataset_name):
    urls = DATASET_URLS[dataset_name]
    keys = DATASET_KEYS[dataset_name]
    data_file = f"{dataset_name}_data.mat"
    gt_file = f"{dataset_name}_gt.mat"
    download_file(urls['data'], data_file)
    download_file(urls['gt'], gt_file)
    mat_data = sio.loadmat(data_file)
    mat_gt = sio.loadmat(gt_file)
    data_key = find_key_in_mat(mat_data, keys['data'])
    gt_key = find_key_in_mat(mat_gt, keys['gt'])
    data = mat_data[data_key]
    labels = mat_gt[gt_key]
    print(f"✓ Loaded {dataset_name}: shape={data.shape}, bands={data.shape[2]}")
    return data, labels

def normalize(data):
    h, w, b = data.shape
    data = data.reshape(-1, b)
    data = StandardScaler().fit_transform(data)
    return data.reshape(h, w, b)

def pad_with_zeros(data, margin):
    return np.pad(data, ((margin, margin), (margin, margin), (0, 0)), mode='constant')

print("✓ Helper functions defined")
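
The `normalize` helper standardizes each spectral band independently by flattening the cube to one row per pixel. The sketch below reproduces that idea with plain NumPy on a synthetic cube; `standardize_bands` is a hypothetical name used only here, and the random data stands in for a real scene.

```python
import numpy as np

def standardize_bands(cube):
    """Standardize each spectral band of an (H, W, B) cube to zero mean, unit variance."""
    h, w, b = cube.shape
    flat = cube.reshape(-1, b).astype(np.float64)   # one row per pixel
    mean = flat.mean(axis=0)
    std = flat.std(axis=0)
    std[std == 0] = 1.0                              # constant bands become all zeros
    return ((flat - mean) / std).reshape(h, w, b)

rng = np.random.default_rng(0)
cube = rng.uniform(0, 4000, size=(6, 5, 3))          # synthetic stand-in for a real scene
out = standardize_bands(cube)

band_means = out.reshape(-1, 3).mean(axis=0)         # approximately 0 for every band
band_stds = out.reshape(-1, 3).std(axis=0)           # approximately 1 for every band
print(np.allclose(band_means, 0.0), np.allclose(band_stds, 1.0))
```

Because the statistics are computed per scene, each dataset is scaled on its own, which is worth keeping in mind when a model trained on one scene is applied to another.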


In [None]:
# Parameters - TUNED FOR MEMORY USE
TARGET_BANDS = [10, 20]  # Reduced from [10, 20, 30] to use less memory
PATCH_SIZE = 16
BATCH_SIZE = 64  # Reduced from 128 to use less memory per batch
EPOCHS = 30  # Reduced from 50 for faster training
LR = 0.001
N_SPLITS = 1  # Number of splits: 1=fast (1 test), 3=medium, 5=full tests (slower)

# MEMORY OPTIMIZATION OPTIONS
LOAD_DATASETS_ON_DEMAND = True  # Load datasets only when needed (saves RAM)
CLEAR_MEMORY_BETWEEN_SPLITS = True  # Free memory between splits
PROCESS_ONE_DATASET_AT_TIME = True  # Process one dataset at a time

DATASET_NAMES = ['Indian', 'PaviaU', 'PaviaC', 'KSC', 'Salinas']

import gc  # Garbage collection

# Load all datasets
print("=" * 80)
print("STEP 1: Load data")
print("=" * 80)

datasets = {}
for dataset_name in DATASET_NAMES:
    print(f"\nLoading {dataset_name}...")
    data, labels = load_data(dataset_name)
    datasets[dataset_name] = {
        'data': data,
        'labels': labels,
        'info': DATASET_INFO[dataset_name]
    }

print(f"\n✓ Loaded all {len(datasets)} datasets")


In [None]:
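# Preprocessing - band reduction with a Gaussian filter
def gaussian_band_reduction(data, target_bands, sigma=1.0):
    """Keep evenly spaced bands, then smooth each kept band spatially."""
    H, W, B = data.shape
    if B <= target_bands:
        return data
    step = B / target_bands
    indices = np.round(np.arange(0, B, step)).astype(int)
    indices = indices[:target_bands]
    selected_bands = data[:, :, indices]
    filtered_bands = np.zeros_like(selected_bands)
    for i in range(target_bands):
        filtered_bands[:, :, i] = ndimage.gaussian_filter(selected_bands[:, :, i], sigma=sigma)
    return filtered_bands

def load_and_preprocess_dataset(dataset_name, target_bands):
    """Load one dataset and build its preprocessed entry (used by the on-demand path)."""
    data, labels = load_data(dataset_name)
    original_bands = data.shape[2]
    data_reduced = gaussian_band_reduction(normalize(data), target_bands, sigma=1.0)
    info = {**DATASET_INFO[dataset_name], 'num_bands': data_reduced.shape[2], 'original_bands': original_bands}
    return {'data': data_reduced, 'labels': labels, 'info': info}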

print("=" * 80)
print("STEP 2: Preprocessing - band reduction (Gaussian)")
print("=" * 80)

preprocessed_data = {}
for dataset_name, dataset_data in datasets.items():
    data = dataset_data['data']
    labels = dataset_data['labels']
    info = dataset_data['info']
    print(f"\nPreprocessing {dataset_name}...")
    data_normalized = normalize(data)
    original_bands = data.shape[2]
    
    for target_bands in TARGET_BANDS:
        if target_bands >= original_bands:
            data_reduced = data_normalized
            print(f"  {target_bands} bands: {original_bands} (original)")
        else:
            data_reduced = gaussian_band_reduction(data_normalized, target_bands, sigma=1.0)
            print(f"  {target_bands} bands: {original_bands} -> {target_bands} (Gaussian)")
        
        key = (dataset_name, target_bands)
        preprocessed_data[key] = {
            'data': data_reduced,
            'labels': labels,
            'info': {**info, 'num_bands': data_reduced.shape[2], 'original_bands': original_bands}
        }

print("\n✓ Preprocessing finished")
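
The band reduction in this cell works in two stages: it keeps evenly spaced bands and then smooths each kept band spatially with a Gaussian filter. The index arithmetic of the first stage can be checked in isolation. This is a minimal sketch; `select_band_indices` is a hypothetical helper that mirrors the arithmetic in `gaussian_band_reduction`.

```python
import numpy as np

def select_band_indices(num_bands, target_bands):
    """Return evenly spaced band indices, mirroring the arithmetic in the cell above."""
    step = num_bands / target_bands
    indices = np.round(np.arange(0, num_bands, step)).astype(int)
    return indices[:target_bands]   # guard against one extra index from float rounding

indian = select_band_indices(200, 10)
print(indian.tolist())  # [0, 20, 40, 60, 80, 100, 120, 140, 160, 180]
```

For band counts that do not divide evenly (for example 103 bands in PaviaU), the step is fractional and the indices are rounded, so the spacing between kept bands varies by one.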


In [None]:
# Generate dataset splits
print("=" * 80)
print("STEP 3: Build the test set")
print("=" * 80)
print(f"Generating {N_SPLITS} splits (change N_SPLITS in the parameters cell)")

train_combinations = list(itertools.combinations(DATASET_NAMES, 3))
splits = []

for split_id, train_datasets in enumerate(train_combinations, 1):
    train_list = list(train_datasets)
    remaining = [d for d in DATASET_NAMES if d not in train_list]
    test_combinations = list(itertools.permutations(remaining, 2))
    
    for test_idx, (test_dataset, validation_dataset) in enumerate(test_combinations):
        split = {
            'split_id': split_id * 10 + test_idx + 1,
            'train_datasets': train_list,
            'test_dataset': test_dataset,
            'validation_dataset': validation_dataset
        }
        splits.append(split)

# Keep the first N_SPLITS splits
if len(splits) > N_SPLITS:
    selected_splits = []
    seen_train_combos = set()
    for split in splits:
        train_key = tuple(sorted(split['train_datasets']))
        if train_key not in seen_train_combos or len(selected_splits) < N_SPLITS:
            selected_splits.append(split)
            seen_train_combos.add(train_key)
            if len(selected_splits) >= N_SPLITS:
                break
    splits = selected_splits[:N_SPLITS]

print(f"\nGenerated {len(splits)} splits:")
for split in splits:
    print(f"  Split {split['split_id']}: Train={'+'.join(split['train_datasets'])}, Test={split['test_dataset']}, Val={split['validation_dataset']}")


## STEP 3: Training and clustering functions


In [None]:
# Data preparation function for TensorFlow
def prepare_dataset(data_dict, dataset_names, patch_size=16, model_type='2d'):
    """Prepare the data as NumPy arrays for TensorFlow"""
    patches_list = []
    targets_list = []
    all_num_classes = set()
    
    target_bands_list = [data_dict[k]['data'].shape[2] for k in data_dict.keys() if k[0] in dataset_names]
    if not target_bands_list:
        raise ValueError(f"No data for datasets: {dataset_names}")
    
    target_bands = target_bands_list[0]
    
    for dataset_name in dataset_names:
        dataset_keys = [k for k in data_dict.keys() if k[0] == dataset_name and data_dict[k]['data'].shape[2] == target_bands]
        if not dataset_keys:
            dataset_keys = [k for k in data_dict.keys() if k[0] == dataset_name]
            if not dataset_keys:
                continue
        
        key = dataset_keys[0]
        data = data_dict[key]['data']
        labels = data_dict[key]['labels']
        info = data_dict[key]['info']
        all_num_classes.add(info['num_classes'])
        
        margin = patch_size // 2
        padded_data = pad_with_zeros(data, margin)
        
        h, w, _ = data.shape
        for i in range(h):
            for j in range(w):
                label = labels[i, j]
                if label == 0:
                    continue
                patch = padded_data[i:i+patch_size, j:j+patch_size, :]
                patches_list.append(patch)
                targets_list.append(label - 1)
    
    if len(patches_list) == 0:
        raise ValueError(f"No data for datasets: {dataset_names}")
    
    patches = np.array(patches_list)
    targets = np.array(targets_list)
    
    if len(patches.shape) == 4:
        current_bands = patches.shape[-1]
        if current_bands != target_bands:
            if current_bands < target_bands:
                padding = np.zeros((patches.shape[0], patches.shape[1], patches.shape[2], target_bands - current_bands))
                patches = np.concatenate([patches, padding], axis=-1)
            else:
                patches = patches[:, :, :, :target_bands]
    
    if model_type == '3d':
        # Conv3D: (N, H, W, B) -> (N, H, W, B, 1) for TensorFlow
        patches = np.expand_dims(patches, axis=-1)  # (N, H, W, B, 1)
    # Conv2D: (N, H, W, B) is already the expected layout
    
    num_bands = patches.shape[-2] if model_type == '3d' else patches.shape[-1]
    num_classes = max(all_num_classes) if all_num_classes else 16
    
    print(f"Dataset {dataset_names}: {len(patches)} samples, bands={num_bands}, classes={num_classes}")
    
    return patches, targets, num_bands, num_classes

print("✓ Data preparation functions defined")
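
The patch extraction inside `prepare_dataset` relies on zero padding so that border pixels still get a full window. The sketch below replays that indexing on a tiny synthetic scene; the 5x5x3 size and `patch_size = 4` are illustrative assumptions, not values from the notebook.

```python
import numpy as np

patch_size = 4                      # small value for illustration; the notebook uses 16
margin = patch_size // 2

# Synthetic 5x5 scene with 3 bands, with values chosen so each pixel is identifiable.
data = np.arange(5 * 5 * 3, dtype=np.float32).reshape(5, 5, 3)
padded = np.pad(data, ((margin, margin), (margin, margin), (0, 0)), mode='constant')

i, j = 0, 0                         # a corner pixel, where padding matters most
patch = padded[i:i + patch_size, j:j + patch_size, :]

print(patch.shape)                  # (4, 4, 3)
# Pixel (i, j) of the original image sits at offset (margin, margin) in its patch.
print(np.array_equal(patch[margin, margin], data[i, j]))  # True
```

With an even `patch_size` the window is not symmetric: it covers `margin` pixels before the target pixel and `margin - 1` after it.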


In [None]:
# TensorFlow training function
def train_model(model, x_train, y_train, x_val, y_val, epochs=50, lr=0.001, batch_size=128, model_name="model"):
    """Train a TensorFlow/Keras model"""
    
    # Compile the model
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=lr),
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy']
    )
    
    # Callbacks
    callbacks = [
        keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=10, restore_best_weights=True),
        keras.callbacks.ModelCheckpoint(
            f'best_model_{model_name}.h5',
            monitor='val_accuracy',
            save_best_only=True,
            verbose=0
        )
    ]
    
    # Training
    history = model.fit(
        x_train, y_train,
        validation_data=(x_val, y_val),
        epochs=epochs,
        batch_size=batch_size,
        callbacks=callbacks,
        verbose=1 if epochs <= 20 else 2
    )
    
    best_val_acc = max(history.history['val_accuracy']) * 100
    print(f"Best validation accuracy: {best_val_acc:.2f}%")
    
    return model

print("✓ TensorFlow training function defined")


In [None]:
# Clustering functions for TensorFlow
def extract_features_for_image(feature_model, data, labels, patch_size=16, model_type='2d', batch_size=64):
    """Extract embeddings using a TensorFlow model"""
    margin = patch_size // 2
    padded_data = pad_with_zeros(data, margin)
    
    h, w, _ = data.shape
    patches = []
    pixel_coords = []
    
    for i in range(h):
        for j in range(w):
            if labels[i, j] > 0:
                patch = padded_data[i:i+patch_size, j:j+patch_size, :]
                patches.append(patch)
                pixel_coords.append((i, j))
    
    if len(patches) == 0:
        raise ValueError("No labeled pixels found")
    
    patches = np.array(patches)
    
    if model_type == '3d':
        # Conv3D: (N, H, W, B) -> (N, H, W, B, 1) for TensorFlow
        patches = np.expand_dims(patches, axis=-1)
    # Conv2D: (N, H, W, B) is already the expected layout
    
    # Extract embeddings in batches
    embeddings_list = []
    for i in range(0, len(patches), batch_size):
        batch = patches[i:i+batch_size]
        batch_embeddings = feature_model.predict(batch, verbose=0)
        embeddings_list.append(batch_embeddings)
    
    embeddings = np.concatenate(embeddings_list, axis=0)
    
    feature_dim = embeddings.shape[1]
    embedding_map = np.zeros((h, w, feature_dim))
    
    for idx, (i, j) in enumerate(pixel_coords):
        embedding_map[i, j] = embeddings[idx]
    
    return embedding_map, pixel_coords

def segment_with_dbscan(feature_model, data, labels, patch_size=16, model_type='2d', eps=0.5, min_samples=5, batch_size=64):
    print(f"DBSCAN segmentation (eps={eps}, min_samples={min_samples})...")
    
    embedding_map, pixel_coords = extract_features_for_image(feature_model, data, labels, patch_size, model_type, batch_size)
    
    embeddings_list = []
    coords_list = []
    for i, j in pixel_coords:
        embeddings_list.append(embedding_map[i, j])
        coords_list.append((i, j))
    
    embeddings_array = np.array(embeddings_list)
    scaler = StandardScaler()
    embeddings_normalized = scaler.fit_transform(embeddings_array)
    
    dbscan = DBSCAN(eps=eps, min_samples=min_samples, metric='euclidean')
    cluster_labels = dbscan.fit_predict(embeddings_normalized)
    
    h, w = data.shape[:2]
    segmentation_map = np.zeros((h, w), dtype=np.int32)
    
    for idx, (i, j) in enumerate(coords_list):
        cluster_id = cluster_labels[idx]
        if cluster_id >= 0:
            segmentation_map[i, j] = cluster_id + 1
        else:
            segmentation_map[i, j] = 0
    
    n_clusters = len(set(cluster_labels)) - (1 if -1 in cluster_labels else 0)
    n_outliers = np.sum(cluster_labels == -1)
    print(f"  Found {n_clusters} clusters, {n_outliers} outliers")
    
    return segmentation_map, n_clusters

def evaluate_clustering(segmentation_map, ground_truth):
    mask = ground_truth > 0
    if np.sum(mask) == 0:
        return 0.0
    
    segments = segmentation_map[mask]
    classes = ground_truth[mask]
    
    unique_segments = np.unique(segments)
    unique_segments = unique_segments[unique_segments > 0]
    
    if len(unique_segments) == 0:
        return 0.0
    
    cluster_to_class = {}
    for seg in unique_segments:
        seg_mask = segments == seg
        if np.sum(seg_mask) > 0:
            most_common_class = mode(classes[seg_mask], keepdims=True)[0][0]
            cluster_to_class[seg] = most_common_class
    
    mapped_segments = np.zeros_like(segments)
    for seg, cls in cluster_to_class.items():
        mapped_segments[segments == seg] = cls
    
    correct = np.sum(mapped_segments == classes)
    total = len(classes)
    
    return 100.0 * correct / total if total > 0 else 0.0

print("✓ TensorFlow clustering functions defined")
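
`evaluate_clustering` scores a segmentation by mapping every cluster to its most frequent ground-truth class and counting matching pixels. A tiny worked example makes the arithmetic visible. This is a sketch with made-up labels; `majority_vote_accuracy` is a hypothetical stand-in that follows the same logic on 1D arrays.

```python
import numpy as np

def majority_vote_accuracy(segments, classes):
    """Map each cluster id (> 0) to its most frequent class and score the result."""
    mapped = np.zeros_like(classes)
    for seg in np.unique(segments[segments > 0]):
        values, counts = np.unique(classes[segments == seg], return_counts=True)
        mapped[segments == seg] = values[np.argmax(counts)]
    # Noise pixels (segment 0) keep the value 0 and therefore count as errors.
    return 100.0 * np.mean(mapped == classes)

segments = np.array([1, 1, 1, 2, 2, 0])   # cluster ids, 0 = DBSCAN noise
classes = np.array([3, 3, 5, 7, 7, 7])    # ground-truth labels (all > 0)
accuracy = majority_vote_accuracy(segments, classes)
print(round(accuracy, 2))  # 66.67
```

Cluster 1 maps to class 3 and cluster 2 to class 7, so four of six pixels match. Note that this kind of score tends to rise as the number of clusters grows, so it is best read together with the reported cluster count.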


## STEP 4: Training and testing the models


In [None]:
# Main TensorFlow training and testing loop
print(f"Device: {'TPU' if USE_TPU else 'GPU/CPU'}")

all_results = {}

for target_bands in TARGET_BANDS:
    print(f"\n{'#'*80}")
    print(f"# Testing with {target_bands} bands")
    print(f"{'#'*80}")
    
    results = []
    
    for split in splits:
        print(f"\n{'='*60}")
        print(f"Split {split['split_id']}:")
        print(f"  Train: {', '.join(split['train_datasets'])}")
        print(f"  Test: {split['test_dataset']}")
        print(f"  Validation: {split['validation_dataset']}")
        print(f"{'='*60}")
        
        # Load data on demand
        train_data_dict = {}
        for name in split['train_datasets']:
            key = (name, target_bands)
            if LOAD_DATASETS_ON_DEMAND:
                if key not in preprocessed_data:
                    # Load and preprocess
                    preprocessed_data[key] = load_and_preprocess_dataset(name, target_bands)
                train_data_dict[key] = preprocessed_data[key]
            else:
                if key in preprocessed_data:
                    train_data_dict[key] = preprocessed_data[key]
        
        if not train_data_dict:
            print("  ⚠ No data for the train datasets")
            continue
        
        for model_name in MODELS:
            print(f"\n  Model: {model_name}")
            model_type = '3d' if model_name == 'InceptionHSINet' else '2d'
            
            try:
                train_keys_filtered = {k: v for k, v in train_data_dict.items() if k[1] == target_bands}
                if not train_keys_filtered:
                    continue
                
                # Prepare the data
                x_train_full, y_train_full, num_bands, num_classes = prepare_dataset(
                    train_keys_filtered, split['train_datasets'], patch_size=PATCH_SIZE, model_type=model_type
                )
                
                # Train/val split
                n_train = int(len(x_train_full) * 0.8)
                indices = np.random.permutation(len(x_train_full))
                train_indices = indices[:n_train]
                val_indices = indices[n_train:]
                
                x_train = x_train_full[train_indices]
                y_train = y_train_full[train_indices]
                x_val = x_train_full[val_indices]
                y_val = y_train_full[val_indices]
                
                # Determine input_shape
                if model_type == '3d':
                    input_shape = (PATCH_SIZE, PATCH_SIZE, num_bands, 1)
                else:
                    input_shape = (PATCH_SIZE, PATCH_SIZE, num_bands)
                
                # Create the model inside the TPU strategy scope if available
                if USE_TPU:
                    with strategy.scope():
                        model, feature_model = create_model(model_name, input_shape, num_classes)
                else:
                    model, feature_model = create_model(model_name, input_shape, num_classes)
                
                print(f"    Training... (bands={num_bands}, classes={num_classes})")
                trained_model = train_model(model, x_train, y_train, x_val, y_val, epochs=EPOCHS, lr=LR, batch_size=BATCH_SIZE, model_name=f"{model_name}_{split['split_id']}")
                
                # Load the best checkpoint
                try:
                    trained_model.load_weights(f'best_model_{model_name}_{split["split_id"]}.h5')
                except Exception:
                    pass
                
                # Copy weights from trained_model into feature_model.
                # Both models are built from the same layer objects, so this is mostly a safeguard.
                for i, layer in enumerate(trained_model.layers):
                    if i < len(feature_model.layers):
                        try:
                            feature_model.layers[i].set_weights(layer.get_weights())
                        except:
                            pass
                
                # Test on the test dataset - DBSCAN (on-demand loading)
                test_key = (split['test_dataset'], target_bands)
                test_acc = 0.0
                test_n_clusters = 0
                test_n_samples = 0
                
                if LOAD_DATASETS_ON_DEMAND:
                    if test_key not in preprocessed_data:
                        preprocessed_data[test_key] = load_and_preprocess_dataset(split['test_dataset'], target_bands)
                
                if test_key in preprocessed_data:
                    test_data = preprocessed_data[test_key]['data']
                    test_labels = preprocessed_data[test_key]['labels']
                    segmentation_map, n_clusters = segment_with_dbscan(feature_model, test_data, test_labels, patch_size=PATCH_SIZE, model_type=model_type, eps=0.5, min_samples=5, batch_size=BATCH_SIZE)
                    test_acc = evaluate_clustering(segmentation_map, test_labels)
                    test_n_clusters = n_clusters
                    test_n_samples = np.sum(test_labels > 0)
                    print(f"    Test DBSCAN: accuracy={test_acc:.2f}%, clusters={n_clusters}")
                    # Drop the test data from memory when it is no longer needed
                    if CLEAR_MEMORY_BETWEEN_SPLITS and PROCESS_ONE_DATASET_AT_TIME:
                        del test_data, test_labels, segmentation_map
                
                # Test on the validation dataset - DBSCAN (on-demand loading)
                validation_dataset_name = split['validation_dataset']
                final_test_key = (validation_dataset_name, target_bands)
                final_test_acc = 0.0
                final_test_n_clusters = 0
                final_test_n_samples = 0
                
                if LOAD_DATASETS_ON_DEMAND:
                    if final_test_key not in preprocessed_data:
                        preprocessed_data[final_test_key] = load_and_preprocess_dataset(validation_dataset_name, target_bands)
                
                if final_test_key in preprocessed_data:
                    validation_data = preprocessed_data[final_test_key]['data']
                    validation_labels = preprocessed_data[final_test_key]['labels']
                    segmentation_map, n_clusters = segment_with_dbscan(feature_model, validation_data, validation_labels, patch_size=PATCH_SIZE, model_type=model_type, eps=0.5, min_samples=5, batch_size=BATCH_SIZE)
                    final_test_acc = evaluate_clustering(segmentation_map, validation_labels)
                    final_test_n_clusters = n_clusters
                    final_test_n_samples = np.sum(validation_labels > 0)
                    print(f"    Validation DBSCAN: accuracy={final_test_acc:.2f}%, clusters={n_clusters}")
                    # Drop the validation data from memory when it is no longer needed
                    if CLEAR_MEMORY_BETWEEN_SPLITS and PROCESS_ONE_DATASET_AT_TIME:
                        del validation_data, validation_labels, segmentation_map
                
                result = {
                    'split_id': split['split_id'],
                    'model_name': model_name,
                    'target_bands': target_bands,
                    'train_datasets': split['train_datasets'],
                    'test_dataset': split['test_dataset'],
                    'validation_dataset': validation_dataset_name,
                    'test_accuracy': test_acc,
                    'test_n_clusters': test_n_clusters,
                    'validation_accuracy': final_test_acc,
                    'validation_n_clusters': final_test_n_clusters,
                    'test_n_samples': test_n_samples,
                    'validation_n_samples': final_test_n_samples
                }
                results.append(result)
                
                # Free memory
                del model, feature_model, trained_model, x_train, y_train, x_val, y_val, x_train_full, y_train_full
                tf.keras.backend.clear_session()
                gc.collect()
                
                # Keep the preprocessed training arrays: train_data_dict holds the
                # same dict objects, so setting 'data' to None here would break
                # prepare_dataset for the remaining models and splits.
                if CLEAR_MEMORY_BETWEEN_SPLITS and PROCESS_ONE_DATASET_AT_TIME:
                    gc.collect()
                
            except Exception as e:
                print(f"    ✗ Error: {e}")
                import traceback
                traceback.print_exc()
                # Free memory even after an error
                tf.keras.backend.clear_session()
                gc.collect()
                continue
        
        # Free memory between splits
        if CLEAR_MEMORY_BETWEEN_SPLITS:
            print(f"  🧹 Freeing memory after split {split['split_id']}...")
            tf.keras.backend.clear_session()
            gc.collect()
    
    all_results[target_bands] = results
    print(f"\n✓ Stored results for {target_bands} bands")

print("\n✓ All tests finished")


In [None]:
# Results summary
print("=" * 80)
print("RESULTS SUMMARY")
print("=" * 80)

for target_bands, results in all_results.items():
    print(f"\n{target_bands} bands:")
    for result in results:
        print(f"  {result['model_name']} - Split {result['split_id']}:")
        print(f"    Test: {result['test_accuracy']:.2f}% ({result['test_n_clusters']} clusters)")
        print(f"    Validation: {result['validation_accuracy']:.2f}% ({result['validation_n_clusters']} clusters)")

# Find the best model
best_model = None
best_score = -1
for target_bands, results in all_results.items():
    for result in results:
        avg_score = (result['test_accuracy'] + result['validation_accuracy']) / 2
        if avg_score > best_score:
            best_score = avg_score
            best_model = result

if best_model:
    print(f"\n{'='*80}")
    print("BEST MODEL:")
    print(f"  Model: {best_model['model_name']}")
    print(f"  Bands: {best_model['target_bands']}")
    print(f"  Split: {best_model['split_id']}")
    print(f"  Mean accuracy: {best_score:.2f}%")
    print(f"{'='*80}")

# Save results
print("\nSaving results to file...")
with open('results.json', 'w') as f:
    json.dump(all_results, f, indent=2, default=str)
print("✓ Results saved to results.json")
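
One detail of the save step: `default=str` turns any value that `json` cannot encode into a string, which includes NumPy integers such as the sample counts produced by `np.sum`. If numeric types should survive the round trip, a small converter is one option. This is a sketch; `to_native` is a hypothetical helper and the record values are made up for illustration.

```python
import json
import numpy as np

def to_native(value):
    """Convert NumPy scalars and arrays to plain Python objects for json.dump."""
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    raise TypeError(f"Not JSON serializable: {type(value).__name__}")

# Illustrative record, not a real result.
record = {'test_n_samples': np.int64(1000), 'test_accuracy': np.float64(61.5)}
encoded = json.dumps(record, default=to_native)
print(encoded)  # {"test_n_samples": 1000, "test_accuracy": 61.5}
```

Passing `default=to_native` instead of `default=str` keeps counts as JSON numbers, at the cost of raising an error for types the converter does not know.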


## STEP 6: Visualizing the results (optional)


In [None]:
# Visualize segmentation results (example for the first result)
if all_results:
    target_bands = list(all_results.keys())[0]
    result = all_results[target_bands][0] if all_results[target_bands] else None
    
    if result:
        # Look up the model for the first result
        model_name = result['model_name']
        if model_name not in MODELS:  # MODELS is a list of names, not a dict
            raise ValueError(f"Unknown model: {model_name}")
        model_type = '3d' if model_name == 'InceptionHSINet' else '2d'
        
        # Add visualization code here
        # (requires retraining or a saved model)
        print(f"Visualization for {model_name} - requires retraining or a saved model")
        print("You can add visualization code here")
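
As a starting point for the visualization, the sketch below plots a ground-truth map next to a cluster map with matplotlib. The arrays are random placeholders and `plot_segmentation` is a hypothetical helper; in the notebook the inputs would be a dataset's label map and the `segmentation_map` returned by `segment_with_dbscan`, which must be kept around rather than deleted in the main loop.

```python
import matplotlib
matplotlib.use("Agg")               # headless backend; drop this line in Colab
import matplotlib.pyplot as plt
import numpy as np

def plot_segmentation(ground_truth, segmentation_map):
    """Show ground truth and a cluster map side by side."""
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    axes[0].imshow(ground_truth, cmap='nipy_spectral')
    axes[0].set_title('Ground truth')
    axes[1].imshow(segmentation_map, cmap='nipy_spectral')
    axes[1].set_title('DBSCAN segments (0 = noise)')
    for ax in axes:
        ax.axis('off')
    fig.tight_layout()
    return fig

# Placeholder arrays standing in for real label and segmentation maps.
rng = np.random.default_rng(0)
ground_truth = rng.integers(0, 5, size=(20, 20))
segmentation_map = rng.integers(0, 4, size=(20, 20))
fig = plot_segmentation(ground_truth, segmentation_map)
```

Cluster ids are arbitrary, so colors in the two panels will not correspond unless the clusters are first mapped to classes, for example with the majority vote used in `evaluate_clustering`.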
