# Hyperspectral Images Segmentation

## Datasets:
- Indian Pines: 145×145, 200 wavelenghts, 16 classes
- PaviaU: 610×340, 103 wavelenghts, 9 classes
- PaviaC: 1096×715, 102 wavelenghts, 9 classes
- KSC: 512×614, 176 wavelenghts, 13 classes
- Salinas: 512×217, 204 wavelenghts, 16 classes


## Section 1: imports

In [2]:
import os
import urllib.request
import ssl
import scipy.io as sio
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.preprocessing import StandardScaler 
from torch.utils.data import Dataset, DataLoader, random_split
import matplotlib.pyplot as plt
import csv
from tqdm import tqdm
import pandas as pd
import warnings
warnings.filterwarnings('ignore')

try:
    ssl._create_default_https_context = ssl._create_unverified_context
except:
    pass

try:
    import requests
    requests.packages.urllib3.disable_warnings()
except ImportError:
    pass

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# Parametry
patch_size = 16
batch_size = 64
epochs = 50
learning_rate = 0.001
val_split = 0.2


Using device: cuda
GPU: NVIDIA GeForce RTX 3060 Laptop GPU


## Section 2: dataset configuration

In [None]:
DATASET_URLS = {
    'Indian': {
        'data': 'https://www.ehu.eus/ccwintco/uploads/6/67/Indian_pines_corrected.mat',
        'gt':   'https://www.ehu.eus/ccwintco/uploads/c/c4/Indian_pines_gt.mat',
    },
    'PaviaU': {
        'data': 'https://www.ehu.eus/ccwintco/uploads/e/ee/PaviaU.mat',
        'gt':   'https://www.ehu.eus/ccwintco/uploads/5/50/PaviaU_gt.mat',
    },
    'PaviaC': {
        'data': 'https://www.ehu.eus/ccwintco/uploads/f/f0/Pavia.mat',
        'gt':   'https://www.ehu.eus/ccwintco/uploads/5/53/Pavia_gt.mat',
    },
    'KSC': {
        'data': 'https://www.ehu.eus/ccwintco/uploads/2/26/KSC.mat',
        'gt':   'https://www.ehu.eus/ccwintco/uploads/a/a6/KSC_gt.mat',
    },
    'Salinas': {
        'data': 'https://www.ehu.eus/ccwintco/uploads/a/a3/Salinas_corrected.mat',
        'gt':   'https://www.ehu.eus/ccwintco/uploads/f/fa/Salinas_gt.mat',
    }
}

DATASET_KEYS = {
    'Indian': {
        'data': ['indian_pines_corrected', 'Indian_pines_corrected'],
        'gt': ['indian_pines_gt', 'Indian_pines_gt']
    },
    'PaviaU': {
        'data': ['paviaU', 'PaviaU', 'pavia_u'],
        'gt': ['paviaU_gt', 'PaviaU_gt', 'pavia_u_gt']
    },
    'PaviaC': {
        'data': ['pavia', 'Pavia', 'paviaC', 'PaviaC'],
        'gt': ['pavia_gt', 'Pavia_gt', 'paviaC_gt', 'PaviaC_gt']
    },
    'KSC': {
        'data': ['KSC', 'ksc'],
        'gt': ['KSC_gt', 'ksc_gt']
    },
    'Salinas': {
        'data': ['salinas_corrected', 'Salinas_corrected', 'salinas'],
        'gt': ['salinas_gt', 'Salinas_gt']
    }
}

DATASET_INFO = {
    'Indian': {'num_classes': 16},
    'PaviaU': {'num_classes': 9},
    'PaviaC': {'num_classes': 9},
    'KSC': {'num_classes': 13},
    'Salinas': {'num_classes': 16}
}


## Section 3: dataloader

In [4]:
def download_file(url, filename):
    if not os.path.exists(filename):
        print(f"Pobieranie {filename}...")
        try:
            urllib.request.urlretrieve(url, filename)
            file_size = os.path.getsize(filename) / 1024 / 1024
            print(f"Pobrano {filename} ({file_size:.1f} MB)")
        except Exception as e:
            try:
                import requests
                response = requests.get(url, verify=False, timeout=60)
                response.raise_for_status()
                with open(filename, 'wb') as f:
                    f.write(response.content)
                file_size = os.path.getsize(filename) / 1024 / 1024
                print(f"Pobrano {filename} ({file_size:.1f} MB) - użyto requests")
            except ImportError:
                print(f"Błąd: requests nie jest zainstalowane.")
                raise
            except Exception as e2:
                print(f"Błąd pobierania {filename}: {str(e2)}")
                raise
    else:
        file_size = os.path.getsize(filename) / 1024 / 1024
        print(f"{filename} już istnieje ({file_size:.1f} MB)")

def find_key_in_mat(mat_file, possible_keys):
    for key in possible_keys:
        if key in mat_file:
            return key

    keys = [k for k in mat_file.keys() if not k.startswith('__')]
    if keys:
        return keys[0]
    raise ValueError(f"Nie znaleziono klucza w pliku .mat. Możliwe: {possible_keys}")

def load_dataset(dataset_name):
    if dataset_name not in DATASET_URLS:
        raise ValueError(f"Unknown dataset: {dataset_name}. Available: {list(DATASET_URLS.keys())}")
    
    urls = DATASET_URLS[dataset_name]
    keys = DATASET_KEYS[dataset_name]
    info = DATASET_INFO[dataset_name].copy()
    
    # Pobierz pliki
    data_file = f"{dataset_name}_data.mat"
    gt_file = f"{dataset_name}_gt.mat"
    
    download_file(urls['data'], data_file)
    download_file(urls['gt'], gt_file)
    
    # Załaduj dane
    mat_data = sio.loadmat(data_file)
    mat_gt = sio.loadmat(gt_file)
    
    # Znajdź właściwe klucze
    data_key = find_key_in_mat(mat_data, keys['data'])
    gt_key = find_key_in_mat(mat_gt, keys['gt'])
    
    data = mat_data[data_key]
    labels = mat_gt[gt_key]
    
    info['num_bands'] = data.shape[2] if len(data.shape) == 3 else data.shape[-1]
    info['shape'] = data.shape
    
    print(f"Załadowano {dataset_name}: shape={data.shape}, bands={info['num_bands']}, classes={info['num_classes']}")
    
    return data, labels, info

def normalize(data):
    h, w, b = data.shape
    data = data.reshape(-1, b)
    scaler = StandardScaler()
    data = scaler.fit_transform(data)
    return data.reshape(h, w, b)

def pad_with_zeros(data, margin):
    return np.pad(data, ((margin, margin), (margin, margin), (0, 0)), mode='constant')


## Section 4: models


In [5]:
#model1
class InceptionHSINet(nn.Module):
    def __init__(self, in_channels=1, num_classes=16):
        super(InceptionHSINet, self).__init__()
        self.entry = nn.Sequential(
            nn.Conv3d(in_channels, 8, kernel_size=3, padding=1),
            nn.Dropout3d(0.3),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=2)
        )
        self.branch1 = nn.Sequential(
            nn.Conv3d(8, 16, kernel_size=1),
            nn.Dropout3d(0.3),
            nn.ReLU(),
            nn.Conv3d(16, 16, kernel_size=3, padding=1),
            nn.Dropout3d(0.3),
            nn.ReLU()
        )
        self.branch2 = nn.Sequential(
            nn.Conv3d(8, 16, kernel_size=3, padding=1),
            nn.Dropout3d(0.3),
            nn.ReLU(),
            nn.Conv3d(16, 16, kernel_size=5, padding=2),
            nn.Dropout3d(0.3),
            nn.ReLU()
        )
        self.branch3 = nn.Sequential(
            nn.Conv3d(8, 16, kernel_size=5, padding=2),
            nn.Dropout3d(0.3),
            nn.ReLU(),
            nn.Conv3d(16, 16, kernel_size=3, padding=1),
            nn.Dropout3d(0.3),
            nn.ReLU()
        )
        self.pool = nn.AdaptiveAvgPool3d(1)
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.ReLU(),
            nn.Linear(16 * 3, num_classes)
        )

    def forward(self, x):
        x = self.entry(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        x = torch.cat([x1, x2, x3], dim=1)
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        return self.classifier(x)

#model2
class SimpleHSINet(nn.Module):
    def __init__(self, input_channels=30, num_classes=16):
        super(SimpleHSINet, self).__init__()
        self.conv1 = nn.Conv2d(input_channels, 90, kernel_size=1)
        self.conv2 = nn.Conv2d(90, 270, kernel_size=3)
        self.dropout1 = nn.Dropout2d(0.3)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(270, 180)
        self.dropout2 = nn.Dropout(0.3)
        self.fc2 = nn.Linear(180, num_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.dropout1(x)
        x = F.adaptive_avg_pool2d(x, 1)
        x = self.flatten(x)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        return x

#model3
class CNNFromDiagram(nn.Module):
    def __init__(self, input_channels=200, num_classes=16, patch_size=16):
        super(CNNFromDiagram, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=input_channels, out_channels=100, kernel_size=3, padding=0)
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(in_channels=100, out_channels=100, kernel_size=3, padding=0)
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        
        # Oblicz rozmiar wyjścia po konwolucjach
        dummy_input = torch.zeros(1, input_channels, patch_size, patch_size)
        x = self.pool1(F.relu(self.conv1(dummy_input)))
        x = self.pool2(F.relu(self.conv2(x)))
        flatten_dim = x.view(1, -1).shape[1]
        
        self.fc1 = nn.Linear(flatten_dim, 84)
        self.fc2 = nn.Linear(84, num_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

print("Załadowano modele")


Załadowano modele


## Section 5: dataset classes


In [6]:
class HSI_Dataset(Dataset):
    def __init__(self, dataset_name, patch_size=16, model_type='2d'):
        self.dataset_name = dataset_name
        self.patch_size = patch_size
        self.model_type = model_type
        
        data, labels, self.info = load_dataset(dataset_name)
        
        # StandardScaler
        data = normalize(data)
        
        # Padding
        margin = patch_size // 2
        padded_data = pad_with_zeros(data, margin)
        
        # patch extraction
        h, w, _ = data.shape
        self.patches = []
        self.targets = []
        
        for i in range(h):
            for j in range(w):
                label = labels[i, j]
                if label == 0:
                    continue
                patch = padded_data[i:i+patch_size, j:j+patch_size, :]
                self.patches.append(patch)
                self.targets.append(label - 1) 
        
        self.patches = np.array(self.patches)
        self.targets = np.array(self.targets)
        
        
        if model_type == '3d':
            self.patches = np.transpose(self.patches, (0, 3, 1, 2))  # (N, B, H, W)
            self.patches = np.expand_dims(self.patches, axis=1)  # (N, 1, B, H, W)
        else:
            self.patches = np.transpose(self.patches, (0, 3, 1, 2))  # (N, B, H, W)
        
        print(f"Dataset {dataset_name}: {len(self)} samples, shape={self.patches.shape}")
    
    def __len__(self):
        return len(self.patches)
    
    def __getitem__(self, idx):
        return torch.tensor(self.patches[idx], dtype=torch.float32), torch.tensor(self.targets[idx], dtype=torch.long)


## Section 6: trening

In [None]:
def get_loaders(dataset_name, patch_size=16, batch_size=64, val_split=0.2, model_type='2d'):
    dataset = HSI_Dataset(dataset_name, patch_size, model_type)
    val_len = int(len(dataset) * val_split)
    train_len = len(dataset) - val_len
    train_set, val_set = random_split(dataset, [train_len, val_len])
    return DataLoader(train_set, batch_size=batch_size, shuffle=True), DataLoader(val_set, batch_size=batch_size), dataset.info

def create_model(model_name, num_bands, num_classes, patch_size=16):
    if model_name == 'InceptionHSINet':
        # Model 3D - in_channels=1, num_classes
        return InceptionHSINet(in_channels=1, num_classes=num_classes)
    elif model_name == 'SimpleHSINet':
        # Model 2D - input_channels=num_bands, num_classes
        return SimpleHSINet(input_channels=num_bands, num_classes=num_classes)
    elif model_name == 'CNNFromDiagram':
        # Model 2D - input_channels=num_bands, num_classes, patch_size
        return CNNFromDiagram(input_channels=num_bands, num_classes=num_classes, patch_size=patch_size)
    else:
        raise ValueError(f"Unknown model: {model_name}")

def train(model, train_loader, val_loader, epochs=50, lr=0.001, device=None, 
          model_name="", dataset_name=""):
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    log_file = f"training_log_{model_name}_{dataset_name}.csv"
    
    best_val_acc = 0.0
    train_history = []
    
    with open(log_file, mode="w", newline="") as file:
        writer = csv.writer(file)
        writer.writerow(["Epoch", "Train Loss", "Train Accuracy", "Validation Accuracy"])
        
        for epoch in range(epochs):
            # Training
            model.train()
            total_loss, correct = 0, 0
            
            for inputs, labels in train_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
                correct += (outputs.argmax(1) == labels).sum().item()
            
            train_acc = 100.0 * correct / len(train_loader.dataset)
            
            # Validation
            model.eval()
            val_correct = 0
            with torch.no_grad():
                for val_inputs, val_labels in val_loader:
                    val_inputs, val_labels = val_inputs.to(device), val_labels.to(device)
                    val_outputs = model(val_inputs)
                    val_correct += (val_outputs.argmax(1) == val_labels).sum().item()
            
            val_acc = 100.0 * val_correct / len(val_loader.dataset)
            
            # Save results
            writer.writerow([epoch + 1, total_loss, train_acc, val_acc])
            train_history.append({
                'epoch': epoch + 1,
                'train_loss': total_loss,
                'train_acc': train_acc,
                'val_acc': val_acc
            })
            
            # Best model
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save(model.state_dict(), f"best_model_{model_name}_{dataset_name}.pth")
            
            if (epoch + 1) % 10 == 0 or epoch == 0:
                print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss:.4f}, Train Acc: {train_acc:.2f}%, Val Acc: {val_acc:.2f}%")
    
    print(f"Trening zakończony. Best Val Acc: {best_val_acc:.2f}%")
    return train_history, best_val_acc


## Section 7: main loop for experiments

In [None]:
MODELS = ['InceptionHSINet', 'SimpleHSINet', 'CNNFromDiagram']
DATASETS = ['Indian', 'PaviaU', 'PaviaC', 'KSC', 'Salinas']

# dictionary to store results
results = {}

# Model type mapping (which models use 3D, which 2D)
MODEL_TYPES = {
    'InceptionHSINet': '3d',
    'SimpleHSINet': '2d',
    'CNNFromDiagram': '2d'
}

total_experiments = len(MODELS) * len(DATASETS)
experiment_num = 0

print(f"Starting {total_experiments} experiments: {len(MODELS)} models x {len(DATASETS)} datasets")
print("=" * 80)

for model_name in MODELS:
    for dataset_name in DATASETS:
        experiment_num += 1
        print(f"\n{'='*80}")
        print(f"EXPERIMENT {experiment_num}/{total_experiments}: {model_name} on {dataset_name}")
        print(f"{'='*80}")
        
        try:
            # Get model type
            model_type = MODEL_TYPES[model_name]
            
            # Load data and create loaders
            train_loader, val_loader, dataset_info = get_loaders(
                dataset_name=dataset_name,
                patch_size=patch_size,
                batch_size=batch_size,
                val_split=val_split,
                model_type=model_type
            )
            
            # Create model with corresponding parameters
            num_bands = dataset_info['num_bands']
            num_classes = dataset_info['num_classes']
            
            model = create_model(model_name, num_bands, num_classes, patch_size)
            print(f"Model: {model_name}, Input: {num_bands} bands, {num_classes} classes")
            
            # Train
            history, best_val_acc = train(
                model, train_loader, val_loader, 
                epochs=epochs, lr=learning_rate, device=device,
                model_name=model_name, dataset_name=dataset_name
            )
            
            # Save results
            results[f"{model_name}_{dataset_name}"] = {
                'model': model_name,
                'dataset': dataset_name,
                'best_val_acc': best_val_acc,
                'num_bands': num_bands,
                'num_classes': num_classes,
                'num_samples': len(train_loader.dataset) + len(val_loader.dataset)
            }
            
            print(f"Finished: {model_name} on {dataset_name} - Val Acc: {best_val_acc:.2f}%")
            
        except Exception as e:
            print(f"Error in experiment {model_name} on {dataset_name}: {str(e)}")
            results[f"{model_name}_{dataset_name}"] = {
                'model': model_name,
                'dataset': dataset_name,
                'error': str(e)
            }
        
        print()

print(f"\n{'='*80}")
print("ALL EXPERIMENTS FINISHED")
print(f"{'='*80}")


Starting 15 experiments: 3 models x 5 datasets

EXPERIMENT 1/15: InceptionHSINet on Indian
Indian_data.mat już istnieje (5.7 MB)
Indian_gt.mat już istnieje (0.0 MB)
Załadowano Indian: shape=(145, 145, 200), bands=200, classes=16
Dataset Indian: 10249 samples, shape=(10249, 1, 200, 16, 16)
Model: InceptionHSINet, Input: 200 bands, 16 classes
Epoch 1/50, Loss: 299.6764, Train Acc: 23.84%, Val Acc: 35.29%
Epoch 10/50, Loss: 214.0621, Train Acc: 42.12%, Val Acc: 54.12%
Epoch 20/50, Loss: 194.5109, Train Acc: 47.62%, Val Acc: 57.00%
Epoch 30/50, Loss: 178.6502, Train Acc: 50.66%, Val Acc: 62.27%


## Section 8: summary

In [None]:
#dataframe for results
results_list = []
for key, value in results.items():
    if 'error' not in value:
        results_list.append(value)

if results_list:
    df_results = pd.DataFrame(results_list)
    
    pivot_results = df_results.pivot(index='model', columns='dataset', values='best_val_acc')
    
    print("\n" + "="*80)
    print("Validation Accuracy (%)")
    print("="*80)
    print(pivot_results.round(2))
    
    # Save to CSV
    df_results.to_csv('all_results.csv', index=False)
    pivot_results.to_csv('results_pivot.csv')
    print(f"\nResults saved to all_results.csv and results_pivot.csv")
    
    # Find best results
    print("\n" + "="*80)
    print("BEST RESULTS:")
    print("="*80)
    best_overall = df_results.loc[df_results['best_val_acc'].idxmax()]
    print(f"Best result: {best_overall['model']} on {best_overall['dataset']}: {best_overall['best_val_acc']:.2f}%")
    
    # Average per model
    print("\nAverage accuracy per model:")
    print(df_results.groupby('model')['best_val_acc'].mean().round(2).sort_values(ascending=False))
    
    # Average per dataset
    print("\nAverage accuracy per dataset:")
    print(df_results.groupby('dataset')['best_val_acc'].mean().round(2).sort_values(ascending=False))
else:
    print("No results to display")


## Data visualization


In [None]:
if 'results_list' not in locals():
    results_list = []
    for key, value in results.items():
        if 'error' not in value:
            results_list.append(value)

if results_list:
    df_results = pd.DataFrame(results_list)
    
    # Heatmap of results
    pivot_results = df_results.pivot(index='model', columns='dataset', values='best_val_acc')
    
    plt.figure(figsize=(12, 6))
    plt.imshow(pivot_results.values, aspect='auto', cmap='YlOrRd')
    plt.colorbar(label='Validation Accuracy (%)')
    plt.xticks(range(len(pivot_results.columns)), pivot_results.columns, rotation=45)
    plt.yticks(range(len(pivot_results.index)), pivot_results.index)
    plt.xlabel('Dataset')
    plt.ylabel('Model')
    plt.title('Validation Accuracy per Model-Dataset Combination')
    
    # Add values to heatmap
    for i in range(len(pivot_results.index)):
        for j in range(len(pivot_results.columns)):
            text = plt.text(j, i, f'{pivot_results.iloc[i, j]:.1f}',
                           ha="center", va="center", color="black", fontsize=9)
    
    plt.tight_layout()
    plt.savefig('results_heatmap.png', dpi=150)
    plt.show()
    print("Heatmap saved to results_heatmap.png")
    
    # Bar plot - średnie per model
    plt.figure(figsize=(10, 6))
    mean_per_model = df_results.groupby('model')['best_val_acc'].mean().sort_values(ascending=False)
    mean_per_model.plot(kind='bar', color=['#1f77b4', '#ff7f0e', '#2ca02c'])
    plt.ylabel('Average Validation Accuracy (%)')
    plt.xlabel('Model')
    plt.title('Average Performance per Model Across All Datasets')
    plt.xticks(rotation=45)
    plt.grid(axis='y', alpha=0.3)
    plt.tight_layout()
    plt.savefig('results_per_model.png', dpi=150)
    plt.show()
    print("Plot of results per model saved to results_per_model.png")
