In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import scipy.io as sio
import os
import requests
from torch.cuda.amp import autocast, GradScaler
import random
from tqdm import tqdm
import kornia.morphology as morph
from einops import rearrange

class MultiScaleBlock(nn.Module):
    """Four parallel convolution branches fused back to ``in_channels``.

    The branches are a 1x1 convolution and three 3x3 convolutions with
    dilation 1, 2 and 4 (padding equals the dilation, so the spatial size is
    preserved). Each branch emits ``in_channels // 4`` channels and is followed
    by BatchNorm2d and GELU. The branch outputs are concatenated along the
    channel axis and passed through a 1x1 convolution + BatchNorm2d + GELU
    that restores ``in_channels`` channels.

    Args:
        in_channels: Number of input (and output) channels. Must be >= 4.

    Raises:
        ValueError: If ``in_channels`` is too small to give each branch at
            least one channel.
    """

    def __init__(self, in_channels):
        super(MultiScaleBlock, self).__init__()
        branch_channels = in_channels // 4
        if branch_channels < 1:
            raise ValueError(
                f"in_channels must be >= 4 so every branch gets at least one "
                f"channel, got {in_channels}"
            )

        self.branch1 = self._make_branch(in_channels, branch_channels, kernel_size=1, dilation=1)
        self.branch2 = self._make_branch(in_channels, branch_channels, kernel_size=3, dilation=1)
        self.branch3 = self._make_branch(in_channels, branch_channels, kernel_size=3, dilation=2)
        self.branch4 = self._make_branch(in_channels, branch_channels, kernel_size=3, dilation=4)
        # The fusion input is the real concatenated width. It equals
        # in_channels whenever in_channels is a multiple of 4 (so existing
        # checkpoints keep the same shapes) and stays consistent otherwise.
        self.fusion = nn.Sequential(
            nn.Conv2d(4 * branch_channels, in_channels, 1),
            nn.BatchNorm2d(in_channels),
            nn.GELU()
        )

    @staticmethod
    def _make_branch(in_channels, out_channels, kernel_size, dilation):
        """Build one Conv2d -> BatchNorm2d -> GELU branch that keeps H and W."""
        # "Same" padding for odd kernels: 0 for the 1x1 branch, `dilation`
        # for the 3x3 branches.
        padding = dilation * (kernel_size - 1) // 2
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation),
            nn.BatchNorm2d(out_channels),
            nn.GELU()
        )

    def forward(self, x):
        """Run all branches on ``x`` and fuse their channel-wise concatenation."""
        branch_outputs = [
            self.branch1(x),
            self.branch2(x),
            self.branch3(x),
            self.branch4(x),
        ]
        return self.fusion(torch.cat(branch_outputs, dim=1))

class EnhancedMorphologicalAttention(nn.Module):
    """Multi-head channel attention gated by morphological, spatial and channel cues.

    Forward pass, in order:

    1. BatchNorm the input and project it to q/k/v with a 1x1 convolution.
    2. Per head, compute a ``head_dim x head_dim`` attention map by contracting
       q and k over the flattened spatial axis, scale it by a learnable
       per-head temperature, softmax it and apply it to v.
    3. Project with a 1x1 convolution, then build three sigmoid gates from the
       projected features: one from the concatenated ``morph.dilation`` /
       ``morph.erosion`` outputs (3x3 all-ones kernel), one depthwise spatial
       gate, and one globally pooled channel gate.
    4. Add the gated features to the input, then add a second term
       ``gamma * norm2(x)`` where ``gamma`` is initialised to zero.

    Args:
        in_channels: Number of input/output channels; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.

    Raises:
        ValueError: If ``in_channels`` is not a positive multiple of
            ``num_heads``.
    """

    def __init__(self, in_channels, num_heads=8):
        super(EnhancedMorphologicalAttention, self).__init__()
        # Fail at construction time instead of with an opaque reshape error in
        # forward(): the q/k/v tensor is split into num_heads * head_dim
        # channels, which only works when the division is exact.
        if num_heads < 1 or in_channels < num_heads or in_channels % num_heads != 0:
            raise ValueError(
                f"in_channels ({in_channels}) must be a positive multiple of "
                f"num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = in_channels // num_heads
        self.scale = (self.head_dim) ** -0.5
        # One learnable temperature per head, broadcast over the attention map.
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        self.qkv = nn.Conv2d(in_channels, in_channels * 3, 1)
        self.proj = nn.Conv2d(in_channels, in_channels, 1)
        self.spatial_gate = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 3, padding=1, groups=in_channels),
            nn.BatchNorm2d(in_channels),
            nn.GELU(),
            nn.Conv2d(in_channels, in_channels, 1),
            nn.Sigmoid()
        )
        self.channel_mixer = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels // 4, 1),
            nn.GELU(),
            nn.Conv2d(in_channels // 4, in_channels, 1),
            nn.Sigmoid()
        )
        # Consumes dilation and erosion stacked on the channel axis (2 * C).
        self.morph_process = nn.Sequential(
            nn.Conv2d(in_channels * 2, in_channels, 1),
            nn.BatchNorm2d(in_channels),
            nn.GELU(),
            nn.Conv2d(in_channels, in_channels, 3, padding=1, groups=in_channels),
            nn.Conv2d(in_channels, in_channels, 1),
            nn.Sigmoid()
        )
        self.norm1 = nn.BatchNorm2d(in_channels)
        self.norm2 = nn.BatchNorm2d(in_channels)
        # Despite the attribute name this is element-wise nn.Dropout.
        self.drop_path = nn.Dropout(0.6)
        # Zero-initialised scale for the second residual term.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Apply gated attention to ``x`` of shape (B, C, H, W); output has the same shape."""
        B, C, H, W = x.shape
        shortcut = x
        x = self.norm1(x)

        # (B, 3*C, H, W) -> (B, 3, heads, head_dim, H*W)
        qkv = self.qkv(x).reshape(B, 3, self.num_heads, self.head_dim, H * W)
        q, k, v = qkv[:, 0], qkv[:, 1], qkv[:, 2]
        q = q * self.scale

        # Contract over the spatial axis: attention is (B, heads, head_dim, head_dim).
        attn = torch.einsum('bhnm,bhkm->bhnk', q, k)
        attn = attn * self.temperature
        attn = F.softmax(attn, dim=-1)
        x = torch.einsum('bhnk,bhkm->bhnm', attn, v)
        # reshape (not view): the einsum result is not guaranteed contiguous.
        x = x.reshape(B, self.num_heads * self.head_dim, H, W)
        x = self.proj(x)

        # Allocate the structuring element directly on the target device.
        kernel = torch.ones(3, 3, device=x.device)
        dilated = morph.dilation(x, kernel)
        eroded = morph.erosion(x, kernel)
        morph_features = torch.cat([dilated, eroded], dim=1)
        morph_gate = self.morph_process(morph_features)
        spatial_attn = self.spatial_gate(x)
        channel_attn = self.channel_mixer(x)

        x = x * morph_gate * spatial_attn * channel_attn
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.gamma * self.norm2(x))
        return x

class EnhancedResidualBlock(nn.Module):
    """Residual block: two 3x3 convs, morphological attention, multi-scale fusion.

    The channel count is unchanged, so the input can be added back to the
    block output before the final GELU.

    Args:
        in_channels: Number of input and output channels.
    """

    def __init__(self, in_channels):
        super().__init__()
        channels = in_channels
        # Registration order is kept as-is so parameter initialisation order
        # and state-dict layout stay the same.
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)
        self.attention = EnhancedMorphologicalAttention(channels)
        self.multiscale = MultiScaleBlock(channels)
        self.dropout = nn.Dropout(0.3)
        self.gelu = nn.GELU()

    def forward(self, x):
        """Return ``gelu(multiscale(attention(convs(x))) + x)``."""
        identity = x

        # First conv stage: conv -> BN -> GELU -> dropout.
        hidden = self.conv1(x)
        hidden = self.bn1(hidden)
        hidden = self.gelu(hidden)
        hidden = self.dropout(hidden)

        # Second conv stage has no activation before the attention module.
        hidden = self.conv2(hidden)
        hidden = self.bn2(hidden)

        hidden = self.attention(hidden)
        hidden = self.multiscale(hidden)

        # In-place skip connection, then the final non-linearity.
        hidden += identity
        return self.gelu(hidden)

class RobustEnergyFunction(nn.Module):
    """Convolutional backbone with a main head and an auxiliary head.

    Both heads emit ``num_classes`` values per sample from the same encoder
    features. In training mode ``forward`` returns ``(energy, aux)``; in eval
    mode it returns only ``energy``.

    Args:
        input_channels: Number of channels of the input patches.
        num_classes: Size of each head's output vector.
    """

    def __init__(self, input_channels, num_classes):
        super().__init__()

        stem_width, wide_width = 128, 256

        # Stem: lift the input to 128 feature maps.
        self.initial_conv = nn.Sequential(
            nn.Conv2d(input_channels, stem_width, 3, padding=1),
            nn.BatchNorm2d(stem_width),
            nn.GELU(),
            nn.Dropout(0.3)
        )

        # Encoder: residual block at 128 channels, a widening stage with 2x2
        # max pooling, then two residual blocks at 256 channels.
        widen_and_pool = [
            nn.Conv2d(stem_width, wide_width, 3, padding=1),
            nn.BatchNorm2d(wide_width),
            nn.GELU(),
            nn.MaxPool2d(2),
        ]
        self.encoder = nn.ModuleList([
            EnhancedResidualBlock(stem_width),
            nn.Sequential(*widen_and_pool),
            EnhancedResidualBlock(wide_width),
            EnhancedResidualBlock(wide_width)
        ])

        # Main head: global average pool followed by a two-layer MLP.
        self.energy_head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(wide_width, 512),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

        # Auxiliary head: 1x1 conv bottleneck, global pool, linear layer.
        self.auxiliary_head = nn.Sequential(
            nn.Conv2d(wide_width, stem_width, 1),
            nn.BatchNorm2d(stem_width),
            nn.GELU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(stem_width, num_classes)
        )

    def forward(self, x):
        """Return ``(energy, aux)`` when training, otherwise ``energy`` only."""
        features = self.initial_conv(x)
        for stage in self.encoder:
            features = stage(features)

        main_out = self.energy_head(features)
        aux_out = self.auxiliary_head(features)

        if not self.training:
            return main_out
        return main_out, aux_out

class RobustHyperspectralEBM(nn.Module):
    """Wrapper around ``RobustEnergyFunction`` adding sampling and a training loss.

    Args:
        input_channels: Number of channels of the input patches.
        num_classes: Number of output classes.
        device: Stored on the instance as ``self.device``; not used to move
            the module (callers do that with ``.to(device)``).
    """

    def __init__(self, input_channels, num_classes, device):
        super(RobustHyperspectralEBM, self).__init__()
        self.energy_net = RobustEnergyFunction(input_channels, num_classes)
        self.device = device
        self.num_classes = num_classes

    def forward(self, x):
        """Delegate to the energy network (tuple in training mode, tensor in eval)."""
        return self.energy_net(x)

    def _energy(self, x):
        """Return only the main-head output, whatever mode the network is in."""
        out = self.energy_net(x)
        return out[0] if isinstance(out, tuple) else out

    def sample(self, x, num_steps=15, step_size=0.05):
        """Run ``num_steps`` noisy gradient steps starting from ``x``.

        Each step adds ``step_size * d(sum(energy))/dx`` plus Gaussian noise
        with standard deviation ``sqrt(2 * step_size)``.

        Args:
            x: Starting tensor; it is not modified.
            num_steps: Number of update steps.
            step_size: Step size of the gradient update.

        Returns:
            A detached tensor with the same shape as ``x``.
        """
        # NOTE(review): the update moves *along* the gradient of the summed
        # head output. Whether that is the intended direction depends on the
        # sign convention used for "energy" here — confirm.
        x_k = x.detach().clone().requires_grad_(True)
        noise_scale = float(np.sqrt(step_size * 2))
        for _ in range(num_steps):
            # Force graph construction even when the caller is inside
            # torch.no_grad(); otherwise autograd.grad would raise.
            with torch.enable_grad():
                energy = self._energy(x_k)
                grad = torch.autograd.grad(energy.sum(), x_k)[0]
            # Update the leaf in place without recording the update itself.
            with torch.no_grad():
                x_k.add_(step_size * grad)
                x_k.add_(torch.randn_like(x_k) * noise_scale)
        return x_k.detach()

    def compute_loss(self, x_pos, x_neg, y):
        """Combine classification, auxiliary, regularisation and margin terms.

        ``loss = CE(pos) + 0.5 * CE(aux) + 0.1 * reg + 0.1 * hinge`` where
        both cross-entropies use label smoothing 0.1, ``reg`` is half the sum
        of the mean squared outputs for ``x_pos`` and ``x_neg``, and ``hinge``
        is ``mean(relu(1 - (pos - neg)))``.

        In eval mode the energy network returns no auxiliary output, so the
        auxiliary term is omitted instead of failing on tuple unpacking.

        Args:
            x_pos: Real input batch.
            x_neg: Negative batch (e.g. produced by ``sample``).
            y: Integer class targets for ``x_pos``.

        Returns:
            Scalar loss tensor.
        """
        pos_out = self.energy_net(x_pos)
        if isinstance(pos_out, tuple):
            pos_energy, aux = pos_out
        else:
            pos_energy, aux = pos_out, None
        neg_energy = self._energy(x_neg)

        loss = F.cross_entropy(pos_energy, y, label_smoothing=0.1)
        if aux is not None:
            loss = loss + 0.5 * F.cross_entropy(aux, y, label_smoothing=0.1)

        reg_loss = 0.5 * (pos_energy.pow(2).mean() + neg_energy.pow(2).mean())
        contrastive_loss = torch.mean(torch.relu(1.0 - (pos_energy - neg_energy)))
        return loss + 0.1 * reg_loss + 0.1 * contrastive_loss

def evaluate_model(model, data_loader, device):
    """Return top-1 accuracy (in percent) of ``model`` over ``data_loader``.

    The model is switched to eval mode and left there. Mixed precision is
    enabled only when ``device`` is a CUDA device.

    Args:
        model: Module whose eval-mode output has classes on dim 1.
        data_loader: Iterable of ``(data, target)`` batches.
        device: ``torch.device`` or device string to run on.

    Returns:
        Accuracy as a float in ``[0, 100]``; ``0.0`` if the loader yields no
        samples (instead of raising ``ZeroDivisionError``).
    """
    model.eval()
    # torch.cuda.amp.autocast is deprecated; torch.autocast with an explicit
    # enable flag keeps the same "CUDA only" behaviour without warnings.
    use_amp = torch.device(device).type == 'cuda'
    amp_device = 'cuda' if use_amp else 'cpu'

    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            with torch.autocast(device_type=amp_device, enabled=use_amp):
                output = model(data)
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total += target.size(0)

    if total == 0:
        return 0.0
    return 100. * correct / total

class EnhancedDataProcessor:
    """Download the Indian Pines scene and cut it into labelled patches.

    Args:
        patch_size: Nominal side length of the square patches. The extracted
            window is ``2 * (patch_size // 2) + 1`` pixels wide, i.e. equal to
            ``patch_size`` for odd values.
    """

    def __init__(self, patch_size=11):
        self.patch_size = patch_size
        self.scaler = StandardScaler()

    def download_dataset(self):
        """Fetch the two ``.mat`` files into the working directory if missing.

        Raises:
            requests.HTTPError: If the server answers with an error status.
            requests.RequestException: On timeouts or connection failures.
        """
        base_url = "https://www.ehu.eus/ccwintco/uploads"
        files = {
            'Indian_pines.mat': f"{base_url}/2/22/Indian_pines.mat",
            'Indian_pines_gt.mat': f"{base_url}/c/c4/Indian_pines_gt.mat"
        }
        for filename, url in files.items():
            if os.path.exists(filename):
                continue
            response = requests.get(url, allow_redirects=True, timeout=60)
            # Without this an HTTP error page would be saved as the .mat file
            # and every later run would skip the download and fail in loadmat.
            response.raise_for_status()
            # Write to a temporary name first so an interrupted download never
            # leaves a truncated file under the final name.
            partial_name = filename + '.part'
            with open(partial_name, 'wb') as f:
                f.write(response.content)
            os.replace(partial_name, filename)

    def load_indian_pines(self):
        """Return ``(data, labels)`` read from the two downloaded ``.mat`` files."""
        self.download_dataset()
        data = sio.loadmat('Indian_pines.mat')['indian_pines']
        labels = sio.loadmat('Indian_pines_gt.mat')['indian_pines_gt']
        return data, labels

    def create_patches(self, data, labels):
        """Extract one patch per labelled pixel and standardise the result.

        Pixels whose label is 0 are skipped. Remaining labels are shifted down
        by one so they start at 0. Patches are returned in row-major order of
        their centre pixel.

        Args:
            data: Array of shape ``(H, W, C)``.
            labels: Array of shape ``(H, W)``; 0 marks unlabelled pixels.

        Returns:
            ``(patches, patch_labels)`` where ``patches`` is float32 with
            shape ``(N, C, window, window)``.

        Raises:
            ValueError: If ``labels`` does not match the spatial shape of
                ``data`` or contains no labelled pixel.
        """
        if labels.shape != data.shape[:2]:
            raise ValueError(
                f"labels shape {labels.shape} does not match data spatial "
                f"shape {data.shape[:2]}"
            )

        padding = self.patch_size // 2
        window = 2 * padding + 1
        padded_data = np.pad(data, ((padding, padding), (padding, padding), (0, 0)), mode='reflect')

        patches = []
        patch_labels = []
        # argwhere yields labelled coordinates in row-major order, i.e. the
        # same order as a nested row/column scan. A pixel at (row, col) in
        # `data` sits at (row + padding, col + padding) in `padded_data`, so
        # its window starts at (row, col) there.
        for row, col in np.argwhere(labels != 0):
            patches.append(padded_data[row:row + window, col:col + window, :])
            patch_labels.append(labels[row, col] - 1)

        if not patches:
            raise ValueError("labels contains no labelled (non-zero) pixels")

        patches = np.array(patches)
        patch_labels = np.array(patch_labels)
        patches = np.transpose(patches, (0, 3, 1, 2))
        # NOTE(review): reshape(-1, C) runs on the already transposed
        # (N, C, H, W) array, so the rows handed to the scaler are consecutive
        # runs of C values in that layout rather than per-pixel spectra, and
        # the scaler is fitted on every patch (before any train/test split).
        # Left unchanged because saved checkpoints presumably rely on exactly
        # this preprocessing — confirm before changing.
        patches = self.scaler.fit_transform(patches.reshape(-1, patches.shape[1])).reshape(patches.shape)
        return patches.astype(np.float32), patch_labels

In [2]:
import torch
import numpy as np
from sklearn.metrics import confusion_matrix, cohen_kappa_score
from torch.cuda.amp import autocast
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

def evaluate_model_comprehensive(model, test_loader, device, num_classes):
    """Compute OA, AA, Cohen's kappa, per-class accuracy and the confusion matrix.

    The model is switched to eval mode. Mixed precision is enabled only when
    ``device`` is a CUDA device.

    Args:
        model: Module whose eval-mode output has classes on dim 1.
        test_loader: Iterable of ``(data, target)`` batches.
        device: ``torch.device`` or device string to run on.
        num_classes: Number of classes; fixes the confusion matrix to
            ``num_classes x num_classes`` so row ``i`` always means class ``i``.

    Returns:
        Dict with keys ``overall_accuracy`` and ``average_accuracy`` (percent),
        ``kappa``, ``per_class_accuracy`` (percent, ``nan`` for classes with
        no samples) and ``confusion_matrix``.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    model.eval()
    # torch.cuda.amp.autocast is deprecated; keep the "CUDA only" behaviour.
    use_amp = torch.device(device).type == 'cuda'
    amp_device = 'cuda' if use_amp else 'cpu'

    pred_batches = []
    target_batches = []
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            with torch.autocast(device_type=amp_device, enabled=use_amp):
                output = model(data)
            pred = output.argmax(dim=1)
            pred_batches.append(pred.cpu().numpy())
            target_batches.append(target.cpu().numpy())

    if not target_batches:
        raise ValueError("test_loader yielded no samples to evaluate")

    all_preds = np.concatenate(pred_batches)
    all_targets = np.concatenate(target_batches)

    # Pin the label set: without it the matrix only covers labels that occur
    # in the evaluated samples, silently shifting per-class indices when a
    # class is absent.
    conf_matrix = confusion_matrix(all_targets, all_preds, labels=np.arange(num_classes))

    # Overall Accuracy (OA): correctly classified samples / all samples.
    overall_accuracy = np.trace(conf_matrix) / np.sum(conf_matrix)

    # Per-class accuracy: diagonal / row sum. Classes without any sample get
    # nan instead of triggering a divide-by-zero warning.
    support = conf_matrix.sum(axis=1)
    per_class_accuracy = np.full(conf_matrix.shape[0], np.nan, dtype=np.float64)
    np.divide(np.diag(conf_matrix), support, out=per_class_accuracy, where=support > 0)

    # Average Accuracy (AA): mean over the classes that actually occur.
    average_accuracy = np.nanmean(per_class_accuracy)

    # Cohen's kappa coefficient.
    kappa = cohen_kappa_score(all_targets, all_preds)

    return {
        'overall_accuracy': overall_accuracy * 100,
        'average_accuracy': average_accuracy * 100,
        'kappa': kappa,
        'per_class_accuracy': per_class_accuracy * 100,
        'confusion_matrix': conf_matrix
    }

def load_and_evaluate(checkpoint_path='ip_best_model.pth', test_size=0.9,
                      random_state=42, batch_size=16, num_workers=4):
    """Load a saved model, rebuild the held-out split and print its metrics.

    Called without arguments it uses ``'ip_best_model.pth'``, a stratified
    split with 90% of the patches held out (seed 42), batch size 16 and four
    loader workers.

    Args:
        checkpoint_path: File containing a dict with a ``'model_state_dict'``
            entry.
        test_size: Fraction of patches used for evaluation.
        random_state: Seed of the stratified split.
        batch_size: Evaluation batch size.
        num_workers: Number of DataLoader worker processes.

    Returns:
        The metrics dict produced by ``evaluate_model_comprehensive``.
    """
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the saved model.
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source (or pass weights_only=True if the checkpoint
    # contains nothing but tensors and primitives — verify first).
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Initialize data processor and load data
    processor = EnhancedDataProcessor()
    data, labels = processor.load_indian_pines()
    patches, patch_labels = processor.create_patches(data, labels)

    # Create test dataset. Only the test part is used here.
    # NOTE(review): presumably test_size/random_state must match the split
    # used when the checkpoint was trained — confirm.
    _, X_test, _, y_test = train_test_split(
        patches, patch_labels, test_size=test_size, random_state=random_state,
        stratify=patch_labels
    )

    test_dataset = TensorDataset(torch.FloatTensor(X_test), torch.LongTensor(y_test))
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        # Pinned host memory only helps host-to-GPU copies.
        pin_memory=(device.type == 'cuda')
    )

    # Initialize model
    num_classes = len(np.unique(patch_labels))
    model = RobustHyperspectralEBM(
        input_channels=patches.shape[1],
        num_classes=num_classes,
        device=device
    ).to(device)

    # Load the saved state dict
    model.load_state_dict(checkpoint['model_state_dict'])

    # Evaluate the model
    metrics = evaluate_model_comprehensive(model, test_loader, device, num_classes)

    # Print results
    print("\nEvaluation Results:")
    print("-" * 50)
    print(f"Overall Accuracy: {metrics['overall_accuracy']:.2f}%")
    print(f"Average Accuracy: {metrics['average_accuracy']:.2f}%")
    print(f"Kappa Coefficient: {metrics['kappa']:.4f}")
    print("\nPer-Class Accuracy:")
    for i, acc in enumerate(metrics['per_class_accuracy']):
        print(f"Class {i+1}: {acc:.2f}%")

    print("\nConfusion Matrix:")
    print(metrics['confusion_matrix'])

    return metrics

# Entry point: run the evaluation when the cell/script is executed directly
# and keep the resulting metrics dict in the module namespace.
if __name__ == "__main__":
    metrics = load_and_evaluate()

  checkpoint = torch.load('ip_best_model.pth', map_location=device)
  with autocast():



Evaluation Results:
--------------------------------------------------
Overall Accuracy: 98.62%
Average Accuracy: 98.27%
Kappa Coefficient: 0.9843

Per-Class Accuracy:
Class 1: 92.68%
Class 2: 97.74%
Class 3: 98.80%
Class 4: 93.90%
Class 5: 97.93%
Class 6: 100.00%
Class 7: 100.00%
Class 8: 100.00%
Class 9: 100.00%
Class 10: 99.66%
Class 11: 98.05%
Class 12: 97.19%
Class 13: 100.00%
Class 14: 100.00%
Class 15: 100.00%
Class 16: 96.43%

Confusion Matrix:
[[  38    1    0    0    0    0    0    0    0    2    0    0    0    0
     0    0]
 [   0 1256    0    0    0    7    0    0    0    9   13    0    0    0
     0    0]
 [   0    9  738    0    0    0    0    0    0    0    0    0    0    0
     0    0]
 [   0    0    6  200    0    0    0    0    0    0    0    7    0    0
     0    0]
 [   0    0    3    0  426    0    1    0    0    0    0    0    5    0
     0    0]
 [   0    0    0    0    0  657    0    0    0    0    0    0    0    0
     0    0]
 [   0    0    0    0    0    0 