In [1]:
#Trial 13
# fixes applied
# Random val split
# increases training and val pathes

In [None]:
# Complete Memory-Optimized Enhanced HSI Denoising Pipeline for PSNR > 40 dB
# Comprehensive training with PSNR, SSIM, SAM metrics and visualization

import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import glob
import matplotlib.pyplot as plt
from torch.utils.checkpoint import checkpoint
import json
from skimage.metrics import structural_similarity as compare_ssim
from torch.optim.lr_scheduler import OneCycleLR
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch_ema import ExponentialMovingAverage
from torchvision.models import vgg19
from torchvision.models import VGG19_Weights
import torch.fft as fft
import h5py
try:
    import scipy.io as sio
except ImportError:
    sio = None
import scipy.ndimage
# ----------------------------
# Memory-Efficient Utilities
# ----------------------------
class LayerNormChannel3d(nn.Module):
    """Lightweight channel normalization"""
    def __init__(self, num_channels: int = None, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.num_channels = num_channels
        self.gn = None
        if num_channels is not None:
            self.gn = nn.GroupNorm(1, num_channels, eps=eps, affine=True)

    def forward(self, x):
        C = x.shape[1]
        if self.gn is None or C != self.num_channels:
            self.num_channels = C
            self.gn = nn.GroupNorm(1, C, eps=self.eps, affine=True).to(x.device)
        return self.gn(x)

def depthwise_conv3d(channels: int, kernel_size: int = 3, stride: int = 1, padding: int = 1, dilation: int = 1):
    return nn.Conv3d(channels, channels, kernel_size=kernel_size, stride=stride,
                     padding=padding, groups=channels, bias=True, dilation=dilation)

class EfficientChannelAttention(nn.Module):
    """Memory-efficient channel attention"""
    def __init__(self, channels, reduction=16):
        super().__init__()
        hidden = max(4, channels // reduction)
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        B, C, D, H, W = x.shape
        y = self.avg_pool(x).view(B, C)
        y = self.fc(y).view(B, C, 1, 1, 1)
        return x * y.expand_as(x)

class AdaptiveDropout3d(nn.Module):
    """Adaptive dropout that adjusts rate based on a factor (e.g., layer depth)"""
    def __init__(self, base_drop=0.25, factor=1.0):
        super().__init__()
        self.dropout = nn.Dropout3d(base_drop * factor)

    def forward(self, x):
        return self.dropout(x)

# ----------------------------
# Memory-Optimized Blocks
# ----------------------------
class PatchMerging3D(nn.Module):
    """
    3D Patch Merging Layer (downsampling)
    """
    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(8 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(8 * dim)

    def forward(self, x):
        """
        Args:
            x: (B, C, D, H, W)
        """
        B, C, D, H, W = x.shape

        # Pad if needed
        pad_d = (2 - D % 2) % 2
        pad_h = (2 - H % 2) % 2
        pad_w = (2 - W % 2) % 2

        if pad_d > 0 or pad_h > 0 or pad_w > 0:
            x = F.pad(x, (0, pad_w, 0, pad_h, 0, pad_d))
            D, H, W = x.shape[2:]

        # Convert to (B, D, H, W, C)
        x = x.permute(0, 2, 3, 4, 1).contiguous()

        # Downsample by merging 2x2x2 patches
        x0 = x[:, 0::2, 0::2, 0::2, :]  # (B, D/2, H/2, W/2, C)
        x1 = x[:, 1::2, 0::2, 0::2, :]
        x2 = x[:, 0::2, 1::2, 0::2, :]
        x3 = x[:, 1::2, 1::2, 0::2, :]
        x4 = x[:, 0::2, 0::2, 1::2, :]
        x5 = x[:, 1::2, 0::2, 1::2, :]
        x6 = x[:, 0::2, 1::2, 1::2, :]
        x7 = x[:, 1::2, 1::2, 1::2, :]

        x = torch.cat([x0, x1, x2, x3, x4, x5, x6, x7], -1)  # (B, D/2, H/2, W/2, 8*C)
        x = self.norm(x)
        x = self.reduction(x)  # (B, D/2, H/2, W/2, 2*C)

        # Convert back to (B, C, D, H, W)
        x = x.permute(0, 4, 1, 2, 3).contiguous()
        return x

class PatchMerging3D(nn.Module):
    """3D Patch Merging with better handling"""
    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(8 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(8 * dim)

    def forward(self, x):
        B, C, D, H, W = x.shape
        
        # Pad to multiples of 2
        pad_d = (2 - D % 2) % 2
        pad_h = (2 - H % 2) % 2
        pad_w = (2 - W % 2) % 2
        
        if pad_d > 0 or pad_h > 0 or pad_w > 0:
            x = F.pad(x, (0, pad_w, 0, pad_h, 0, pad_d))
            D, H, W = x.shape[2:]

        x = x.permute(0, 2, 3, 4, 1).contiguous()
        
        x0 = x[:, 0::2, 0::2, 0::2, :]
        x1 = x[:, 1::2, 0::2, 0::2, :]
        x2 = x[:, 0::2, 1::2, 0::2, :]
        x3 = x[:, 1::2, 1::2, 0::2, :]
        x4 = x[:, 0::2, 0::2, 1::2, :]
        x5 = x[:, 1::2, 0::2, 1::2, :]
        x6 = x[:, 0::2, 1::2, 1::2, :]
        x7 = x[:, 1::2, 1::2, 1::2, :]
        
        x = torch.cat([x0, x1, x2, x3, x4, x5, x6, x7], -1)
        x = self.norm(x)
        x = self.reduction(x)
        
        x = x.permute(0, 4, 1, 2, 3).contiguous()
        return x

class PatchExpanding3D(nn.Module):
    """3D Patch Expanding for decoder"""
    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.expand = nn.Linear(dim, 4 * dim, bias=False)
        self.norm = norm_layer(dim)

    def forward(self, x):
        B, C, D, H, W = x.shape
        
        x = x.permute(0, 2, 3, 4, 1).contiguous()
        x = self.norm(x)
        x = self.expand(x)
        
        x = rearrange(x, 'b d h w (p1 p2 p3 c) -> b (d p1) (h p2) (w p3) c', 
                     p1=2, p2=2, p3=2, c=C//2)
        
        x = x.permute(0, 4, 1, 2, 3).contiguous()
        return x

class SpectralAttentionModule(nn.Module):
    """Dedicated spectral attention for bottleneck - HYBRID ATTENTION"""
    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        
        self.qkv = nn.Conv3d(dim, dim * 3, 1)
        self.proj = nn.Conv3d(dim, dim, 1)
        self.norm = nn.GroupNorm(1, dim)
        
    def forward(self, x):
        B, C, D, H, W = x.shape
        x_norm = self.norm(x)
        
        qkv = self.qkv(x_norm)
        q, k, v = torch.chunk(qkv, 3, dim=1)
        
        # Reshape for spectral attention: (B*H*W, num_heads, D, head_dim)
        q = q.permute(0, 3, 4, 2, 1).reshape(B * H * W, D, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        k = k.permute(0, 3, 4, 2, 1).reshape(B * H * W, D, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        v = v.permute(0, 3, 4, 2, 1).reshape(B * H * W, D, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        
        # Spectral attention
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = F.softmax(attn, dim=-1)
        
        out = (attn @ v).permute(0, 2, 1, 3).reshape(B, H, W, D, C)
        out = out.permute(0, 4, 3, 1, 2).contiguous()
        
        out = self.proj(out)
        return x + out

# ----------------------------
# SST Blocks
# ----------------------------

class SpectralSelfAttention(nn.Module):
    """
    Spectral Self-Attention: Attends along the spectral dimension (bands)
    Treats each spatial position independently and attends across all bands
    """
    def __init__(self, dim, num_bands, num_heads=8, qkv_bias=True, attn_drop=0., proj_drop=0.):
            super().__init__()
            self.dim = dim
            self.num_bands = num_bands  # Expected bands from config
            self.num_heads = num_heads
            self.head_dim = dim // num_heads
            self.scale = self.head_dim ** -0.5
            
            assert dim % num_heads == 0, f"dim {dim} must be divisible by num_heads {num_heads}"
            
            # Linear projections for Q, K, V
            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
            self.attn_drop = nn.Dropout(attn_drop)
            self.proj = nn.Linear(dim, dim)
            self.proj_drop = nn.Dropout(proj_drop)
            
            # FIXED: Initialize buffer with expected bands, will expand automatically if needed
            spectral_pos_embed = torch.zeros(1, num_bands, dim)
            nn.init.trunc_normal_(spectral_pos_embed, std=0.02)
            self.register_buffer('spectral_pos_embed', spectral_pos_embed)
        
    def forward(self, x):
        """
        Args:
            x: (B, C, D, H, W) - Input feature map
        Returns:
            x: (B, C, D, H, W) - Output feature map
        """
        B, C, D, H, W = x.shape
        
        # Reshape to (B, H, W, D, C) for spectral attention
        x = x.permute(0, 3, 4, 2, 1).contiguous()  # (B, H, W, D, C)
        
        # Flatten spatial dimensions: (B*H*W, D, C)
        x_flat = x.reshape(B * H * W, D, C)
        
        # FIXED: Dynamically expand buffer if we encounter more bands than initialized
        if D > self.spectral_pos_embed.shape[1]:
            # This happens rarely (only when encountering new max bands)
            old_size = self.spectral_pos_embed.shape[1]
            new_size = D
            
            # Create expanded buffer on same device
            expanded = torch.zeros(1, new_size, self.dim, 
                                  device=self.spectral_pos_embed.device,
                                  dtype=self.spectral_pos_embed.dtype)
            
            # Copy existing learned embeddings
            expanded[:, :old_size, :] = self.spectral_pos_embed
            
            # Initialize new bands with small random values
            nn.init.trunc_normal_(expanded[:, old_size:, :], std=0.02)
            
            # Update the buffer in-place
            self.spectral_pos_embed.resize_(expanded.shape)
            self.spectral_pos_embed.copy_(expanded)
            
            #print(f"[SpectralPosEmbed] Auto-expanded from {old_size} to {new_size} bands")
        
        # Use only the bands we need (safe slicing - buffer is always >= D now)
        pos_embed = self.spectral_pos_embed[:, :D, :]  # (1, D, dim)
        
        # Add spectral positional encoding
        x_flat = x_flat + pos_embed
        
        # Generate Q, K, V
        qkv = self.qkv(x_flat).reshape(B * H * W, D, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # Each: (B*H*W, num_heads, D, head_dim)
        
        # Scaled dot-product attention across spectral dimension
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))  # (B*H*W, num_heads, D, D)
        attn = F.softmax(attn, dim=-1)
        attn = self.attn_drop(attn)
        
        # Apply attention to values
        x_attn = (attn @ v).transpose(1, 2).reshape(B * H * W, D, C)  # (B*H*W, D, C)
        
        # Project and reshape back
        x_attn = self.proj(x_attn)
        x_attn = self.proj_drop(x_attn)
        
        # Reshape back to (B, H, W, D, C)
        x_attn = x_attn.reshape(B, H, W, D, C)
        
        # Reshape to original format (B, C, D, H, W)
        x_attn = x_attn.permute(0, 4, 3, 1, 2).contiguous()
        
        return x_attn

class SpatialSelfAttention(nn.Module):
    """
    FIXED Spatial Self-Attention - SAME NAME, IMPROVED IMPLEMENTATION
    Drop-in replacement - no API changes, just better internals
    """
    def __init__(self, dim, num_heads=8, window_size=8, qkv_bias=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.window_size = window_size
        
        assert dim % num_heads == 0, f"dim {dim} must be divisible by num_heads {num_heads}"
        
        # FIXED: Full 3D convolution instead of depthwise
        # OLD: self.dwconv = depthwise_conv3d(dim, kernel_size=(1, 3, 3), padding=(0, 1, 1))
        # NEW: Full conv for cross-spectral information flow
        self.dwconv = nn.Conv3d(dim, dim, kernel_size=(1, 3, 3), padding=(0, 1, 1), bias=True)
        
        # Keep same API as before
        self.qkv = nn.Conv3d(dim, dim * 3, kernel_size=1, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Conv3d(dim, dim, kernel_size=1)
        self.proj_drop = nn.Dropout(proj_drop)
        
        # NEW: Relative position bias (optional, improves performance)
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads)
        )
        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)
        
        coords_h = torch.arange(window_size)
        coords_w = torch.arange(window_size)
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij'))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += window_size - 1
        relative_coords[:, :, 1] += window_size - 1
        relative_coords[:, :, 0] *= 2 * window_size - 1
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index)
        
    def forward(self, x):
        """
        SAME API: (B, C, D, H, W) -> (B, C, D, H, W)
        Just improved internals
        """
        B, C, D, H, W = x.shape
        
        # Apply full 3D conv (cross-spectral enabled)
        x_local = self.dwconv(x)
        
        qkv = self.qkv(x_local)
        
        if H * W > self.window_size ** 2:
            # Efficient attention for large spatial dims
            qkv = rearrange(qkv, 'b (three head c) d h w -> three b d head c (h w)', 
                           three=3, head=self.num_heads)
            q, k, v = qkv[0], qkv[1], qkv[2]
            
            q_flat = q.reshape(B * D * self.num_heads, self.head_dim, H * W)
            k_flat = k.reshape(B * D * self.num_heads, self.head_dim, H * W)
            v_flat = v.reshape(B * D * self.num_heads, self.head_dim, H * W)
            
            k_global = k_flat.mean(dim=-1, keepdim=True)
            v_global = v_flat.mean(dim=-1, keepdim=True)
            
            q_flat = q_flat * self.scale
            attn = torch.bmm(q_flat.transpose(1, 2), k_global)
            attn = F.softmax(attn, dim=1)
            attn = self.attn_drop(attn)
            
            x_attn = v_global * attn.transpose(1, 2)
            x_attn = x_attn.reshape(B, D, self.num_heads, self.head_dim, H * W)
            
        else:
            # Full attention for small spatial dims
            qkv = rearrange(qkv, 'b (three head c) d h w -> three b d head c (h w)', 
                           three=3, head=self.num_heads)
            q, k, v = qkv[0], qkv[1], qkv[2]
            
            q_flat = q.reshape(B * D * self.num_heads, self.head_dim, H * W)
            k_flat = k.reshape(B * D * self.num_heads, self.head_dim, H * W)
            v_flat = v.reshape(B * D * self.num_heads, self.head_dim, H * W)
            
            q_flat = q_flat * self.scale
            attn = torch.bmm(q_flat.transpose(1, 2), k_flat)
            
            if H == self.window_size and W == self.window_size:
                relative_position_bias = self.relative_position_bias_table[
                    self.relative_position_index.view(-1)
                ].view(self.window_size * self.window_size, self.window_size * self.window_size, -1)
                relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
                attn = attn.view(B * D, self.num_heads, H * W, H * W) + relative_position_bias.unsqueeze(0)
                attn = attn.view(B * D * self.num_heads, H * W, H * W)
            
            attn = F.softmax(attn, dim=-1)
            attn = self.attn_drop(attn)
            
            x_attn = torch.bmm(attn, v_flat.transpose(1, 2))
            x_attn = x_attn.transpose(1, 2)
            x_attn = x_attn.reshape(B, D, self.num_heads, self.head_dim, H * W)
        
        x_attn = rearrange(x_attn, 'b d head c (h w) -> b (head c) d h w', 
                          head=self.num_heads, h=H, w=W)
        
        x_attn = self.proj(x_attn)
        x_attn = self.proj_drop(x_attn)
        
        return x_attn


class SSTBlock(nn.Module):
    """
    Spectral-Spatial Transformer Block
    Combines spectral and spatial self-attention with feed-forward network
    """
    def __init__(self, dim, num_bands, num_heads=8, window_size=8, 
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., 
                 drop_path=0., norm_layer=None):
        super().__init__()
        norm_layer = norm_layer or LayerNormChannel3d
        
        self.dim = dim
        self.num_bands = num_bands
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        
        # Normalization layers
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)
        self.norm3 = norm_layer(dim)
        
        # Spectral attention
        self.spectral_attn = SpectralSelfAttention(
            dim=dim,
            num_bands=num_bands,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop
        )
        
        # Spatial attention
        self.spatial_attn = SpatialSelfAttention(
            dim=dim,
            num_heads=num_heads,
            window_size=window_size,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop
        )
        
        # Drop path for stochastic depth
        self.drop_path = nn.Dropout(drop_path) if drop_path > 0. else nn.Identity()
        
        # Feed-forward network (reuse existing GDFN)
        self.ffn = GDFN(dim, ffn_expansion_factor=mlp_ratio, bias=False)
        
    def forward(self, x):
        """
        Args:
            x: (B, C, D, H, W)
        Returns:
            x: (B, C, D, H, W)
        """
        # Spectral attention with residual
        x = x + self.drop_path(self.spectral_attn(self.norm1(x)))
        
        # Spatial attention with residual
        x = x + self.drop_path(self.spatial_attn(self.norm2(x)))
        
        # Feed-forward with residual
        x = x + self.drop_path(self.ffn(self.norm3(x)))
        
        return x


class SSTStage(nn.Module):
    """
    SST Stage: Multiple SST blocks with optional downsampling
    Replaces SwinTransformerStage3D
    """
    def __init__(self, dim, num_bands, depth, num_heads=8, window_size=8,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
                 drop_path_rate=0., norm_layer=None, downsample=None):
        super().__init__()
        norm_layer = norm_layer or LayerNormChannel3d
        
        self.dim = dim
        self.depth = depth
        
        # Stochastic depth decay rule
        dpr = [drop_path_rate * (i / (depth - 1)) if depth > 1 else drop_path_rate 
               for i in range(depth)]
        
        # Build SST blocks
        self.blocks = nn.ModuleList([
            SSTBlock(
                dim=dim,
                num_bands=num_bands,
                num_heads=num_heads,
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=dpr[i],
                norm_layer=norm_layer
            )
            for i in range(depth)
        ])
        
        # Downsampling layer
        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=nn.LayerNorm)
        else:
            self.downsample = None
    
    def forward(self, x):
        """
        Args:
            x: (B, C, D, H, W)
        Returns:
            x: (B, C, D, H, W) - features before downsampling
            x_down: (B, 2*C, D/2, H/2, W/2) - downsampled features (if downsample exists)
        """
        # Pass through SST blocks
        for blk in self.blocks:
            if self.training:
                x = checkpoint(blk, x, use_reentrant=False)
            else:
                x = blk(x)
        
        # Downsample if needed
        if self.downsample is not None:
            x_down = self.downsample(x)
            return x, x_down
        else:
            return x, x

# ----------------------------
# SST Blocks
# ----------------------------

class SpectralSelfModulatingResidualBlock(nn.Module):
    """Spectral Self-Modulating Residual Block (SSMRB) for adaptive feature transformation."""
    def __init__(self, dim, ffn_expand=2, drop=0.25, drop_factor=1.0):
        super().__init__()
        hidden = dim * ffn_expand

        # Main FFN path (vanilla-like)
        self.ffn_pw1 = nn.Conv3d(dim, hidden * 2, kernel_size=1, bias=True)
        self.ffn_dw = depthwise_conv3d(hidden * 2, kernel_size=3, padding=1)
        self.act = nn.GELU()
        self.ffn_pw2 = nn.Conv3d(hidden, dim, kernel_size=1, bias=True)

        # Self-modulation branch: Generate gamma (scale) and beta (shift) using spectral-adjacent conv
        # Kernel (3,1,1) captures adjacent spectral bands; depthwise for efficiency
        self.mod_gamma = nn.Sequential(
            depthwise_conv3d(dim, kernel_size=(3,1,1), padding=(1,0,0)),
            nn.Sigmoid()  # For scaling (0 to 1)
        )
        self.mod_beta = nn.Sequential(
            depthwise_conv3d(dim, kernel_size=(3,1,1), padding=(1,0,0)),
            nn.Tanh()  # For shifting (-1 to 1)
        )

        # Normalization and dropout
        self.norm = LayerNormChannel3d(dim)
        self.dropout = AdaptiveDropout3d(drop, factor=drop_factor)  # Adaptive dropout

        # Layer scale for residual contribution
        self.gamma_res = nn.Parameter(torch.ones(1, dim, 1, 1, 1) * 1e-4)

    def _ffn_path(self, x):
        """Vanilla FFN computation."""
        x2 = self.ffn_pw1(x)
        x2 = self.ffn_dw(x2)
        a, b = torch.chunk(x2, 2, dim=1)
        x2 = self.act(a) * b
        x2 = self.ffn_pw2(x2)
        return x2

    def forward(self, x):
        # Normalize input
        x_norm = self.norm(x)

        # Compute main FFN path
        ffn_out = self._ffn_path(x_norm)

        # Compute self-modulation parameters from input (using adjacent spectral info)
        gamma = self.mod_gamma(x_norm)  # Scale
        beta = self.mod_beta(x_norm)    # Shift

        # Apply modulation: gamma * ffn_out + beta
        modulated = gamma * ffn_out + beta

        # Apply dropout and scale
        modulated = self.dropout(modulated) * self.gamma_res

        # Residual connection: x + modulated
        return x + modulated


# ----------------------------
# New Classes Added (for Restormer Integration in Architecture)
# ----------------------------
class LayerNorm3d(nn.Module):
    """3D LayerNorm for channel normalization"""
    def __init__(self, num_channels, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x):
        mu = x.mean(dim=1, keepdim=True)
        var = x.var(dim=1, keepdim=True, unbiased=False)
        inv_std = torch.rsqrt(var + self.eps)
        out = (x - mu) * inv_std * self.weight.view(1, -1, 1, 1, 1) + self.bias.view(1, -1, 1, 1, 1)
        return out

class GDFN(nn.Module):
    """Gated-Dconv Feed-Forward Network adapted to 3D, focusing on spatial"""
    def __init__(self, dim, ffn_expansion_factor=2.66, bias=False):
        super().__init__()
        hidden = int(dim * ffn_expansion_factor)
        self.project_in = nn.Conv3d(dim, hidden * 2, kernel_size=1, bias=bias)
        self.dwconv = nn.Conv3d(hidden * 2, hidden * 2, kernel_size=(1, 3, 3), stride=1, padding=(0, 1, 1), groups=hidden * 2, bias=bias)  # Spatial focus
        self.project_out = nn.Conv3d(hidden, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        x = self.project_in(x)
        x1, x2 = self.dwconv(x).chunk(2, dim=1)
        x = F.gelu(x1) * x2
        x = self.project_out(x)
        return x

class MDTA(nn.Module):
    """Multi-Dconv Head Transposed Attention adapted to 3D, focusing on spatial"""
    def __init__(self, dim, num_heads, bias=False):
        super().__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        self.qkv = nn.Conv3d(dim, dim * 3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv3d(dim * 3, dim * 3, kernel_size=(1, 3, 3), stride=1, padding=(0, 1, 1), groups=dim * 3, bias=bias)  # Spatial focus
        self.project_out = nn.Conv3d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        qkv = self.qkv_dwconv(self.qkv(x))
        q, k, v = qkv.chunk(3, dim=1)
        b, c, d, h, w = q.shape
        q = rearrange(q, 'b (head cc) d h w -> b head cc (d h w)', head=self.num_heads, cc=c // self.num_heads)
        k = rearrange(k, 'b (head cc) d h w -> b head cc (d h w)', head=self.num_heads)
        v = rearrange(v, 'b (head cc) d h w -> b head cc (d h w)', head=self.num_heads)
        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)
        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)
        out = (attn @ v)
        out = rearrange(out, 'b head cc (d h w) -> b (head cc) d h w', head=self.num_heads, d=d, h=h, w=w)
        out = self.project_out(out)
        return out

class RestormerBlock(nn.Module):
    """Restormer Transformer Block adapted to 3D"""
    def __init__(self, dim, num_heads=4, ffn_expansion_factor=2.66, bias=False):
        super().__init__()
        self.norm1 = LayerNorm3d(dim)
        self.attn = MDTA(dim, num_heads, bias)
        self.norm2 = LayerNorm3d(dim)
        self.ffn = GDFN(dim, ffn_expansion_factor, bias)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x
# ----------------------------
# New Classes Added (for Restormer Integration in Architecture)
# ----------------------------

# ----------------------------
# Modified FusedBottleneck (for Restormer Integration in Architecture)
# ----------------------------
# ============================================================
# INSERT THESE NEW CLASSES BEFORE FusedBottleneck (around line ~680)
# These are ADDITIONS, not replacements
# ============================================================

class PositionalEncoding3D(nn.Module):
    """
    FIXED: Lightweight 3D Positional Encoding
    
    Instead of storing full (C, D, H, W) tensor, we use:
    1. Separate 1D embeddings for each dimension (much smaller)
    2. Broadcast and combine at runtime
    
    Memory: O(C*D + C*H + C*W) instead of O(C*D*H*W)
    Example: 64*128 + 64*256 + 64*256 = 41K params vs 537M params!
    """
    def __init__(self, channels, max_d=128, max_h=256, max_w=256):
        super().__init__()
        self.channels = channels
        self.max_d = max_d
        self.max_h = max_h
        self.max_w = max_w
        
        # Separate 1D positional embeddings for each dimension
        # These will be broadcast and combined
        self.pos_embed_d = nn.Parameter(torch.zeros(1, channels, max_d, 1, 1))
        self.pos_embed_h = nn.Parameter(torch.zeros(1, channels, 1, max_h, 1))
        self.pos_embed_w = nn.Parameter(torch.zeros(1, channels, 1, 1, max_w))
        
        # Initialize with small random values
        nn.init.trunc_normal_(self.pos_embed_d, std=0.02)
        nn.init.trunc_normal_(self.pos_embed_h, std=0.02)
        nn.init.trunc_normal_(self.pos_embed_w, std=0.02)
        
        # Learnable scaling factors for each dimension
        self.scale_d = nn.Parameter(torch.ones(1))
        self.scale_h = nn.Parameter(torch.ones(1))
        self.scale_w = nn.Parameter(torch.ones(1))
        
    def forward(self, x):
        """
        Args:
            x: (B, C, D, H, W)
        Returns:
            x + positional encoding: (B, C, D, H, W)
        """
        B, C, D, H, W = x.shape
        
        # Slice and broadcast each dimension
        pe_d = self.pos_embed_d[:, :, :D, :, :] * self.scale_d
        pe_h = self.pos_embed_h[:, :, :, :H, :] * self.scale_h
        pe_w = self.pos_embed_w[:, :, :, :, :W] * self.scale_w
        
        # Combine positional encodings (broadcasting happens automatically)
        pe = pe_d + pe_h + pe_w  # Shape: (1, C, D, H, W)
        
        return x + pe


class CrossSpectralSpatialAttention(nn.Module):
    """
    NEW CLASS: Cross-attention between spectral and spatial features
    Enables joint spectral-spatial modeling instead of separate processing
    """
    def __init__(self, dim, num_heads=8, qkv_bias=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        
        # Separate projections for cross-attention paths
        self.q_spectral = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv_spatial = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.q_spatial = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv_spectral = nn.Linear(dim, dim * 2, bias=qkv_bias)
        
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_spectral = nn.Linear(dim, dim)
        self.proj_spatial = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        
        # Gated fusion
        self.gate = nn.Sequential(
            nn.Linear(dim * 2, dim),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        """(B, C, D, H, W) -> (B, C, D, H, W)"""
        B, C, D, H, W = x.shape
        
        # Path 1: Spectral features with spatial context
        x_spectral = x.permute(0, 3, 4, 2, 1).reshape(B * H * W, D, C)
        q_spec = self.q_spectral(x_spectral)
        q_spec = q_spec.reshape(B * H * W, D, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        
        q_spec_scaled = q_spec * self.scale
        attn_spec = torch.matmul(q_spec_scaled, q_spec_scaled.transpose(-2, -1))
        attn_spec = F.softmax(attn_spec, dim=-1)
        attn_spec = self.attn_drop(attn_spec)
        
        out_spec = torch.matmul(attn_spec, q_spec)
        out_spec = out_spec.transpose(1, 2).reshape(B * H * W, D, C)
        out_spec = self.proj_spectral(out_spec)
        out_spec = self.proj_drop(out_spec)
        out_spectral = out_spec.reshape(B, H, W, D, C).permute(0, 4, 3, 1, 2)
        
        # Path 2: Spatial features with spectral context
        x_spatial_q = x.permute(0, 2, 3, 4, 1).reshape(B * D, H * W, C)
        q_spat = self.q_spatial(x_spatial_q)
        q_spat = q_spat.reshape(B * D, H * W, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        
        q_spat_scaled = q_spat * self.scale
        attn_spat = torch.matmul(q_spat_scaled, q_spat_scaled.transpose(-2, -1))
        attn_spat = F.softmax(attn_spat, dim=-1)
        attn_spat = self.attn_drop(attn_spat)
        
        out_spat = torch.matmul(attn_spat, q_spat)
        out_spat = out_spat.transpose(1, 2).reshape(B * D, H * W, C)
        out_spat = self.proj_spatial(out_spat)
        out_spat = self.proj_drop(out_spat)
        out_spatial = out_spat.reshape(B, D, H, W, C).permute(0, 4, 1, 2, 3)
        
        # Gated fusion
        concat_features = torch.cat([out_spectral, out_spatial], dim=1)
        gate_input = F.adaptive_avg_pool3d(concat_features, 1).squeeze(-1).squeeze(-1).squeeze(-1)
        gate = self.gate(gate_input).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        
        fused = gate * out_spectral + (1 - gate) * out_spatial
        return fused


class EnhancedBottleneck(nn.Module):
    """
    NEW CLASS: SOTA-level bottleneck with cross-attention and joint modeling
    Will be used in the main architecture
    """
    def __init__(self, dim, num_heads=8, mlp_ratio=4.):
        super().__init__()
        self.dim = dim
        
        self.norm1 = LayerNormChannel3d(dim)
        self.norm2 = LayerNormChannel3d(dim)
        self.norm3 = LayerNormChannel3d(dim)
        
        # Cross-attention for spectral-spatial joint modeling
        self.cross_attn = CrossSpectralSpatialAttention(
            dim=dim,
            num_heads=num_heads,
            qkv_bias=True,
            attn_drop=0.1,
            proj_drop=0.1
        )
        
        # Non-local attention for global context
        self.non_local = nn.Sequential(
            nn.Conv3d(dim, dim // 2, 1),
            nn.GELU(),
            nn.Conv3d(dim // 2, dim, 1),
            nn.Sigmoid()
        )
        
        # Keep using existing GDFN and SSMRB
        self.ffn = GDFN(dim, ffn_expansion_factor=mlp_ratio, bias=False)
        self.spectral_refine = SpectralSelfModulatingResidualBlock(
            dim, ffn_expand=2, drop=0.1, drop_factor=1.0
        )
        
        # Gated fusion
        self.fusion_gate = nn.Sequential(
            nn.Conv3d(dim * 3, dim, 1),
            nn.GELU(),
            nn.Conv3d(dim, dim, 1),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        """(B, C, D, H, W) -> (B, C, D, H, W)"""
        identity = x
        
        # Path 1: Cross-attention
        cross_out = self.cross_attn(self.norm1(x))
        x = x + cross_out
        
        # Path 2: Non-local
        non_local_weight = self.non_local(self.norm2(x))
        non_local_out = x * non_local_weight
        x = x + non_local_out
        
        # Path 3: FFN + Spectral refinement
        ffn_out = self.ffn(self.norm3(x))
        spectral_out = self.spectral_refine(x + ffn_out)
        
        # Gated fusion
        fusion_input = torch.cat([cross_out, non_local_out, spectral_out], dim=1)
        fusion_weight = self.fusion_gate(fusion_input)
        
        out = identity + fusion_weight * (cross_out + non_local_out + spectral_out) / 3.0
        
        return out

# ----------------------------
# Modified FusedBottleneck (for Restormer Integration in Architecture)
# ----------------------------
        
class FusedBottleneck(nn.Module):
    """
    IMPROVED FusedBottleneck - SAME NAME, better implementation
    Now uses stacked EnhancedBottleneck blocks for SOTA performance
    Drop-in replacement - same API, just calls different internals
    """
    def __init__(self, base_dim, window_sizes=[2, 4]):
        super().__init__()
        # Calculate actual dim from base_dim (maintains compatibility)
        # Your original: dim = base_dim * 4
        # But you call it with base_dim * 2, so actual dim is base_dim * 8
        dim = base_dim * 4  # This gives base_dim * 8 when called with base_dim * 2
        
        # Use stacked enhanced bottleneck blocks instead of old approach
        self.blocks = nn.ModuleList([
            EnhancedBottleneck(dim, num_heads=8, mlp_ratio=4.),
            #EnhancedBottleneck(dim, num_heads=8, mlp_ratio=4.)
        ])
    
    def forward(self, x):
        """
        SAME API: (B, C, D, H, W) -> (B, C, D, H, W)
        """
        for block in self.blocks:
            x = block(x)
        return x

# ----------------------------
# Modified FusedBottleneck (for Restormer Integration in Architecture)
# ----------------------------

# ----------------------------
# Efficient Loss Function
# ----------------------------

# ----------------------------
# Modified MemoryEfficientLoss (for Spatial-Focused Loss Improvements)
# ----------------------------
class MemoryEfficientLoss(nn.Module):
    """Lightweight but effective loss function with FIXED weights and tensor handling"""
    def __init__(self, device='cuda', mse_weight=1.0, l1_weight=1.0, sam_weight=0.5, edge_weight=0.2):
        super().__init__()
        self.mse = nn.MSELoss()
        self.l1 = nn.L1Loss()
        self.device = device
        self.mse_weight = mse_weight
        self.l1_weight = l1_weight
        self.sam_weight = sam_weight
        self.edge_weight = edge_weight

    def forward(self, pred, target, epoch=None):
        # FIXED: Ensure both tensors have same shape
        if pred.shape != target.shape:
            # If shapes don't match, interpolate pred to match target
            if pred.dim() == 5 and target.dim() == 5:
                pred = F.interpolate(pred, size=target.shape[2:], mode='trilinear', align_corners=False)
            elif pred.dim() == 4 and target.dim() == 4:
                pred = F.interpolate(pred, size=target.shape[2:], mode='bilinear', align_corners=False)

        # Main losses
        mse_loss = self.mse(pred, target)
        l1_loss = self.l1(pred, target)

        # FIXED: SAM calculation with proper tensor handling
        eps = 1e-8

        # Handle both 4D and 5D tensors
        if pred.dim() == 5:  # (B, C, D, H, W)
            B, C, D, H, W = pred.shape
            pred_flat = pred.reshape(B, C, D * H * W)  # (B, C, D*H*W)
            target_flat = target.reshape(B, C, D * H * W)  # (B, C, D*H*W)

            # Normalize along channel dimension (spectral bands)
            pred_norm = F.normalize(pred_flat, dim=1, eps=eps)
            target_norm = F.normalize(target_flat, dim=1, eps=eps)

            # Compute cosine similarity along spectral dimension
            cos_sim = torch.sum(pred_norm * target_norm, dim=1)  # (B, D*H*W)

        elif pred.dim() == 4:  # (B, D, H, W) - spectral first
            B, D, H, W = pred.shape
            pred_flat = pred.reshape(B, D, H * W)  # (B, D, H*W)
            target_flat = target.reshape(B, D, H * W)  # (B, D, H*W)

            # Normalize along spectral dimension
            pred_norm = F.normalize(pred_flat, dim=1, eps=eps)
            target_norm = F.normalize(target_flat, dim=1, eps=eps)

            # Compute cosine similarity along spectral dimension
            cos_sim = torch.sum(pred_norm * target_norm, dim=1)  # (B, H*W)

        else:
            # Fallback for other dimensions
            pred_flat = pred.flatten(start_dim=1)
            target_flat = target.flatten(start_dim=1)
            pred_norm = F.normalize(pred_flat, dim=1, eps=eps)
            target_norm = F.normalize(target_flat, dim=1, eps=eps)
            cos_sim = torch.sum(pred_norm * target_norm, dim=1)

        cos_sim = torch.clamp(cos_sim, -1 + eps, 1 - eps)
        sam_loss = torch.mean(1 - cos_sim)

        # FIXED: Edge loss with proper spatial dimension handling
        def spatial_gradient(x):
            if x.dim() == 5:  # (B, C, D, H, W)
                grad_h = torch.abs(x[:, :, :, 1:, :] - x[:, :, :, :-1, :])
                grad_w = torch.abs(x[:, :, :, :, 1:] - x[:, :, :, :, :-1])
            elif x.dim() == 4:  # (B, D, H, W)
                grad_h = torch.abs(x[:, :, 1:, :] - x[:, :, :-1, :])
                grad_w = torch.abs(x[:, :, :, 1:] - x[:, :, :, :-1])
            else:
                return 0, 0
            return grad_h.mean(), grad_w.mean()

        pred_grad_h, pred_grad_w = spatial_gradient(pred)
        target_grad_h, target_grad_w = spatial_gradient(target)
        edge_loss = abs(pred_grad_h - target_grad_h) + abs(pred_grad_w - target_grad_w)

        # Static combination
        total_loss = (
            self.mse_weight * mse_loss +
            self.l1_weight * l1_loss +
            self.sam_weight * sam_loss +
            self.edge_weight * edge_loss
        )

        return total_loss
# ----------------------------
# Modified MemoryEfficientLoss (for Spatial-Focused Loss Improvements)
# ----------------------------

# ----------------------------
# Memory-Efficient U-Net
# ----------------------------


# ----------------------------
# Memory-Efficient U-Net
# ----------------------------
class MemoryOptimizedUNet(nn.Module):
    """
    SST-based U-Net for HSI Denoising (MODIFIED)
    - 4 hierarchical stages with SST blocks instead of Swin
    - Spectral-aware attention at all levels
    - Deep supervision at each decoder stage
    """
    def __init__(self, in_channels=1, base_dim=48, window_sizes=[4, 8, 16], num_bands=64):
        super().__init__()
        self.base_dim = base_dim
        self.in_channels = in_channels
        self.num_bands = num_bands
        
        # Initial projection
        self.patch_embed = nn.Conv3d(in_channels, base_dim, kernel_size=3, padding=1)
        self.pos_embed_init = PositionalEncoding3D(base_dim, 128, 256, 256)
        
        # ENCODER: 4 SST stages with [2, 2, 6, 2] depth
        # Stage 1: base_dim, shallow features
        self.enc_stage1 = SSTStage(
            dim=base_dim,
            num_bands=num_bands,
            depth=2,
            num_heads=8,
            window_size=8,
            mlp_ratio=4.,
            drop=0.0,
            attn_drop=0.0,
            drop_path_rate=0.05,
            downsample=PatchMerging3D
        )
        
        # Stage 2: base_dim*2, intermediate features
        self.enc_stage2 = SSTStage(
            dim=base_dim * 2,
            num_bands=num_bands // 2,  # Bands halved after merging
            depth=2,
            num_heads=8,
            window_size=8,
            mlp_ratio=4.,
            drop=0.0,
            attn_drop=0.0,
            drop_path_rate=0.1,
            downsample=PatchMerging3D
        )
        
        # Stage 3: base_dim*4, deep features (6 blocks)
        self.enc_stage3 = SSTStage(
            dim=base_dim * 4,
            num_bands=num_bands // 4,
            depth=6,
            num_heads=16,
            window_size=4,
            mlp_ratio=4.,
            drop=0.0,
            attn_drop=0.0,
            drop_path_rate=0.15,
            downsample=PatchMerging3D
        )
        
        # Stage 4 (deepest): base_dim*8, bottleneck
        self.enc_stage4 = SSTStage(
            dim=base_dim * 8,
            num_bands=num_bands // 8,
            depth=2,
            num_heads=16,
            window_size=2,
            mlp_ratio=4.,
            drop=0.0,
            attn_drop=0.0,
            drop_path_rate=0.2,
            downsample=None
        )

        self.pe_enc1 = PositionalEncoding3D(base_dim, 128, 256, 256)
        self.pe_enc2 = PositionalEncoding3D(base_dim * 2, 64, 128, 128)
        self.pe_enc3 = PositionalEncoding3D(base_dim * 4, 32, 64, 64)
        self.pe_enc4 = PositionalEncoding3D(base_dim * 8, 16, 32, 32)
        
        # BOTTLENECK: Keep your custom FusedBottleneck (works well with SST)
        self.spectral_attention = SpectralAttentionModule(base_dim * 8, num_heads=8)
        self.bottleneck_fusion = FusedBottleneck(base_dim * 2, window_sizes=window_sizes)
        
        # DECODER: 4 SST stages matching encoder
        # Stage 3 decoder
        self.dec_stage3_up = PatchExpanding3D(dim=base_dim * 8)
        self.dec_stage3 = SSTStage(
            dim=base_dim * 4,
            num_bands=num_bands // 4,
            depth=6,
            num_heads=16,
            window_size=4,
            mlp_ratio=4.,
            drop=0.0,
            attn_drop=0.0,
            drop_path_rate=0.1,
            downsample=None
        )
        
        # Stage 2 decoder
        self.dec_stage2_up = PatchExpanding3D(dim=base_dim * 4)
        self.dec_stage2 = SSTStage(
            dim=base_dim * 2,
            num_bands=num_bands // 2,
            depth=2,
            num_heads=8,
            window_size=8,
            mlp_ratio=4.,
            drop=0.0,
            attn_drop=0.0,
            drop_path_rate=0.08,
            downsample=None
        )
        
        # Stage 1 decoder
        self.dec_stage1_up = PatchExpanding3D(dim=base_dim * 2)
        self.dec_stage1 = SSTStage(
            dim=base_dim,
            num_bands=num_bands,
            depth=2,
            num_heads=8,
            window_size=8,
            mlp_ratio=4.,
            drop=0.0,
            attn_drop=0.0,
            drop_path_rate=0.02,
            downsample=None
        )
        
        # DEEP SUPERVISION: Auxiliary outputs at each decoder stage
        self.deep_sup3 = nn.Conv3d(base_dim * 4, in_channels, 1)
        self.deep_sup2 = nn.Conv3d(base_dim * 2, in_channels, 1)
        self.deep_sup1 = nn.Conv3d(base_dim, in_channels, 1)
        
        # Final reconstruction
        self.final_conv = nn.Sequential(
            nn.Conv3d(base_dim, base_dim // 2, 3, padding=1),
            nn.GELU(),
            nn.Conv3d(base_dim // 2, in_channels, 1),
        )
        
        # Global residual
        self.global_residual = nn.Conv3d(in_channels, in_channels, 1)
        
        # Deep supervision flag
        self.use_deep_supervision = True

    def _align_tensors(self, x, target_size):
        if x.shape[2:] != target_size:
            x = F.interpolate(x, size=target_size, mode='trilinear', align_corners=False)
        return x

    def forward(self, x, return_deep_sup=False):
        # Handle input shape
        original_was_4d = False
        if x.dim() == 4:
            original_was_4d = True
            x = x.unsqueeze(1)
        elif x.dim() == 5 and x.shape[1] != 1:
            if x.shape[2] == 1:
                x = x.transpose(1, 2)
        
        original_size = x.shape[2:]
        input_residual = self.global_residual(x)
        
        # Initial embedding
        x = self.patch_embed(x)
        x = self.pos_embed_init(x)
        
        # ENCODER (4 SST stages)
        e1, e1_down = self.enc_stage1(self.pe_enc1(x))       # Skip 1
        e2, e2_down = self.enc_stage2(self.pe_enc2(e1_down))  # Skip 2
        e3, e3_down = self.enc_stage3(self.pe_enc3(e2_down))  # Skip 3
        e4, _ = self.enc_stage4(self.pe_enc4(e3_down))        # Deepest features
        
        # BOTTLENECK: Hybrid attention
        b = self.spectral_attention(e4)  # Add spectral attention
        b = self.bottleneck_fusion(b)   # Your custom fusion
        
        # DECODER with deep supervision
        deep_outputs = []
        
        # Decoder stage 3
        d3 = self.dec_stage3_up(b)
        d3 = self._align_tensors(d3, e3.shape[2:])
        d3 = d3 + e3  # Skip connection
        d3, _ = self.dec_stage3(d3)
        if self.training and self.use_deep_supervision:
            sup3 = self.deep_sup3(d3)
            sup3 = self._align_tensors(sup3, original_size)
            deep_outputs.append(sup3)
        
        # Decoder stage 2
        d2 = self.dec_stage2_up(d3)
        d2 = self._align_tensors(d2, e2.shape[2:])
        d2 = d2 + e2
        d2, _ = self.dec_stage2(d2)
        if self.training and self.use_deep_supervision:
            sup2 = self.deep_sup2(d2)
            sup2 = self._align_tensors(sup2, original_size)
            deep_outputs.append(sup2)
        
        # Decoder stage 1
        d1 = self.dec_stage1_up(d2)
        d1 = self._align_tensors(d1, e1.shape[2:])
        d1 = d1 + e1
        d1, _ = self.dec_stage1(d1)
        if self.training and self.use_deep_supervision:
            sup1 = self.deep_sup1(d1)
            sup1 = self._align_tensors(sup1, original_size)
            deep_outputs.append(sup1)
        
        # Final reconstruction
        out = self.final_conv(d1)
        out = self._align_tensors(out, original_size)
        input_residual = self._align_tensors(input_residual, original_size)
        out = out + input_residual
        
        # Return format handling
        if original_was_4d and out.shape[1] == 1:
            out = out.squeeze(1)
            if self.training and self.use_deep_supervision:
                deep_outputs = [o.squeeze(1) for o in deep_outputs]
        
        if return_deep_sup and self.training:
            return out, deep_outputs
        return out
# ----------------------------
# Efficient Data Loading
# ----------------------------
# ----------------------------
# Modified domain_shift_augment (for Spatial Augmentations in Data)
# ----------------------------
def domain_shift_augment(cube):
    """Simulate sensor differences for better generalization"""
    # Simulate sensor differences: random spectral scaling and offset
    if random.random() < 0.5:
        scale = random.uniform(0.9, 1.1)
        offset = random.uniform(-0.05, 0.05)
        cube = np.clip(cube * scale + offset, 0, 1)
    # Simulate band misalignment: slight band shuffling with probability
    if random.random() < 0.3:
        perm = np.arange(cube.shape[0])
        np.random.shuffle(perm)
        shift = random.randint(1, 3)
        perm = np.roll(perm, shift)
        cube = cube[perm]
    # Added spatial-heavy: rotate on spatial axes
    if random.random() < 0.4:
        angle = random.uniform(-15, 15)
        for i in range(cube.shape[0]):
            cube[i] = scipy.ndimage.rotate(cube[i], angle, reshape=False, mode='reflect')
    return cube

# ----------------------------
# Modified domain_shift_augment (for Spatial Augmentations in Data)
# ----------------------------

# ----------------------------
# Modified MemoryEfficientHSIDataset (for Spatial Augmentations and Curriculum in Data)
# ----------------------------
class MemoryEfficientHSIDataset(Dataset):
    def __init__(self, files, patch_size, noise_level=30,
                 patches_per_file=200, target_bands=None, augment=True, 
                 dataset_type="unknown", train_crop_size=1024, scales=None):
        super().__init__()
        self.files = files
        self.patch_size = patch_size
        self.noise_level = noise_level / 255.0
        self.patches_per_file = patches_per_file
        self.augment = augment
        self.target_bands = target_bands if target_bands else 31  # ICVL default: 31 bands
        self.dataset_type = dataset_type
        self.synthetic_count = 0
        self.real_data_count = 0
        self.train_crop_size = train_crop_size  # ICVL: center crop size
        self.scales = scales if scales else [64, 32, 32]  # ICVL: multi-scale strides

        # Keep existing file loading logic
        self.file_info = []
        print(f"\n=== Initializing {dataset_type.upper()} Dataset ===")
        print(f"Noise Level (σ): {noise_level} -> {self.noise_level:.4f} (0-1 scale)")
        print(f"Center crop size: {train_crop_size}x{train_crop_size}")
        print(f"Patch size: {patch_size}x{patch_size}")
        print(f"Multi-scale strides: {self.scales}")

        if len(files) > 0 and files[0] != 'synthetic':
            print(f"Attempting to load {len(files)} files...")
            successful_files = []
            failed_files = []

            for i, file_path in enumerate(files):
                if os.path.exists(file_path):
                    try:
                        print(f"  [{i+1}/{len(files)}] Loading: {os.path.basename(file_path)}")

                        # Try h5py first (for MATLAB v7.3)
                        h5py_success = False
                        try:
                            with h5py.File(file_path, 'r') as f:
                                # FIXED: Look for 'rad' key specifically for ICVL data
                                if 'rad' in f:
                                    key = 'rad'
                                    data = f[key]
                                    cube_shape = data.shape
                                    print(f"    ✓ Shape: {cube_shape}, Key: '{key}' (HDF5)")
                                    self.file_info.append({'path': file_path, 'key': key, 'shape': cube_shape, 'format': 'h5py'})
                                    successful_files.append(os.path.basename(file_path))
                                    h5py_success = True
                                else:
                                    raise Exception("No 'rad' key found in file")
                        except Exception as h5_error:
                            # Fallback to scipy.io for older MAT files
                            if sio and not h5py_success:
                                try:
                                    mat = sio.loadmat(file_path)
                                    key = [k for k in mat.keys() if not k.startswith('__')][0]
                                    cube_shape = mat[key].shape
                                    print(f"    ✓ Shape: {cube_shape}, Key: '{key}' (scipy)")
                                    self.file_info.append({'path': file_path, 'key': key, 'shape': cube_shape, 'format': 'scipy'})
                                    successful_files.append(os.path.basename(file_path))
                                except Exception as scipy_error:
                                    raise Exception(f"Both h5py ({str(h5_error)[:50]}) and scipy ({str(scipy_error)[:50]}) failed")
                    except Exception as e:
                        print(f"    ✗ Failed: {e}")
                        failed_files.append(os.path.basename(file_path))

            print(f"\n{dataset_type.upper()} Dataset Summary:")
            print(f"  ✓ Successfully loaded: {len(successful_files)} files")
            print(f"  → Patches per file: {patches_per_file}")
            print(f"  → Total iterations per epoch: {len(successful_files) * patches_per_file}")
            if not self.file_info:
                print(f"  → Will use SYNTHETIC data for {dataset_type}")

    def __len__(self):
        return max(1, len(self.file_info)) * self.patches_per_file

    def _generate_synthetic_data(self):
        """Generate BETTER synthetic HSI data with proper spectral correlation"""
        self.synthetic_count += 1
    
        D, H, W = self.target_bands, self.patch_size, self.patch_size
    
        clean = torch.zeros(D, H, W)
    
        # Create different material signatures
        n_materials = 3
        for mat in range(n_materials):
            signature = torch.randn(D)
            signature = torch.softmax(signature, dim=0)
    
            center_h, center_w = random.randint(H//4, 3*H//4), random.randint(W//4, 3*W//4)
            for h in range(H):
                for w in range(W):
                    dist = ((h - center_h)**2 + (w - center_w)**2) ** 0.5
                    weight = torch.exp(torch.tensor(-dist / (H/4)))  # FIX: Convert to tensor
                    clean[:, h, w] += weight * signature * random.uniform(0.3, 1.0)
    
        # Normalize to [0, 1]
        clean = (clean - clean.min()) / (clean.max() - clean.min() + 1e-8)
    
        # Add noise in [0,1] range
        noise = torch.randn_like(clean) * self.noise_level
        noisy = torch.clamp(clean + noise, 0, 1)

        return noisy, clean

    def __getitem__(self, idx):
        # REMOVED: Synthetic data fallback - force real data only
        if not self.file_info:
            raise RuntimeError(
                f"CRITICAL ERROR: No files loaded in {self.dataset_type} dataset!\n"
                f"file_info is empty. Check __init__ file loading."
            )
    
        file_idx = idx // self.patches_per_file
        file_info = self.file_info[file_idx % len(self.file_info)]
    
        # Load based on file format
        if file_info.get('format') == 'h5py':
            with h5py.File(file_info['path'], 'r') as f:
                cube = np.array(f[file_info['key']]).astype(np.float32)
        else:
            mat = sio.loadmat(file_info['path'])
            cube = mat[file_info['key']].astype(np.float32)
        
        # Ensure proper shape (H, W, D)
        if cube.ndim == 3:
            # ICVL data is in (D, H, W) format, need to transpose to (H, W, D)
            if cube.shape[0] < min(cube.shape[1:]):  # Shape is (D, H, W)
                cube = cube.transpose(1, 2, 0)  # Convert to (H, W, D)
            elif cube.shape[2] > min(cube.shape[:2]):  # Shape is likely (H, W, D) already
                pass
            else:
                raise ValueError(
                    f"ERROR: Unexpected cube shape {cube.shape} from file {file_info['path']}\n"
                    f"Cannot determine if format is (D,H,W) or (H,W,D)"
                )
        else:
            raise ValueError(
                f"ERROR: cube has {cube.ndim} dimensions, expected 3\n"
                f"Shape: {cube.shape}, File: {file_info['path']}"
            )
                
        H, W, D = cube.shape
    
        # ICVL: Center crop to train_crop_size (1024x1024)
        if H > self.train_crop_size or W > self.train_crop_size:
            start_h = (H - self.train_crop_size) // 2
            start_w = (W - self.train_crop_size) // 2
            cube = cube[start_h:start_h+self.train_crop_size,
                       start_w:start_w+self.train_crop_size, :]
            H, W = self.train_crop_size, self.train_crop_size
    
        # Check minimum size
        if H < self.patch_size or W < self.patch_size:
            raise ValueError(
                f"ERROR: Image too small after cropping\n"
                f"Size: {H}x{W}, Required: {self.patch_size}x{self.patch_size}\n"
                f"File: {file_info['path']}"
            )
        
        if D < 4:
            raise ValueError(
                f"ERROR: Too few spectral bands: {D}\n"
                f"File: {file_info['path']}"
            )
    
        # ICVL: Randomly choose stride from multi-scale strides
        stride = random.choice(self.scales)
        
        # Random patch extraction
        start_h = random.randint(0, H - self.patch_size)
        start_w = random.randint(0, W - self.patch_size)
        cube = cube[start_h:start_h+self.patch_size,
                   start_w:start_w+self.patch_size,
                   :min(self.target_bands, D)]
    
        # ICVL: Normalize to [0, 1]
        cube_min = cube.min()
        cube_max = cube.max()
        if cube_max > cube_min:
            cube = (cube - cube_min) / (cube_max - cube_min)
        else:
            cube = np.clip(cube / (np.max(cube) + 1e-8), 0, 1)
    
        # Convert to (D, H, W) format for model
        cube = cube.transpose(2, 0, 1)
    
        # ============================================================
        # AUGMENTATION PIPELINE (single coherent block)
        # ============================================================
        if self.augment:
            # Geometric augmentations (on numpy arrays)
            # Random rotation (90° increments)
            if random.random() < 0.5:
                k = random.choice([1, 2, 3])
                cube = np.rot90(cube, k, axes=(1, 2)).copy()
            
            # Horizontal flip
            if random.random() < 0.5:
                cube = np.flip(cube, axis=1).copy()
            
            # Vertical flip
            if random.random() < 0.5:
                cube = np.flip(cube, axis=2).copy()
            
            # Gaussian blur (spatial only) - applied to numpy before tensor conversion
            if random.random() < 0.3:
                sigma = random.uniform(0.3, 0.8)
                for i in range(cube.shape[0]):
                    cube[i] = scipy.ndimage.gaussian_filter(cube[i], sigma=sigma)
    
        # Convert to tensor AFTER all numpy augmentations
        clean_tensor = torch.from_numpy(cube.copy())  # Extra safety: ensure contiguous
    
        # Add noise in [0,1] range
        noise = torch.randn_like(clean_tensor) * self.noise_level
        noisy_tensor = torch.clamp(clean_tensor + noise, 0, 1)
    
        # Spectral augmentation (on tensors, applied to both clean and noisy)
        if self.augment and random.random() < 0.4:
            scale = torch.tensor(
                np.random.uniform(0.95, 1.05, size=(clean_tensor.shape[0], 1, 1)), 
                dtype=torch.float32
            )
            clean_tensor = torch.clamp(clean_tensor * scale, 0, 1)
            noisy_tensor = torch.clamp(noisy_tensor * scale, 0, 1)
    
        self.real_data_count += 1
        return noisy_tensor.float(), clean_tensor.float()

    def get_usage_stats(self):
        """Return statistics about data usage"""
        total = self.real_data_count + self.synthetic_count
        if total == 0:
            return "No data accessed yet"
        real_percent = (self.real_data_count / total) * 100
        synthetic_percent = (self.synthetic_count / total) * 100
        return f"Real data: {self.real_data_count} ({real_percent:.1f}%), Synthetic: {self.synthetic_count} ({synthetic_percent:.1f}%)"

# ----------------------------
# Metric Calculation Functions
# ----------------------------
def calculate_psnr(pred, target):
    mse = torch.mean((pred - target) ** 2)
    return 10 * torch.log10(1.0 / (mse + 1e-8))

def calculate_ssim(pred_np, target_np):
    """Calculate SSIM across all spectral bands"""
    if pred_np.ndim == 3:  # (D, H, W)
        D, H, W = pred_np.shape
        ssim_vals = []
        for d in range(D):
            try:
                ssim_val = compare_ssim(pred_np[d], target_np[d], data_range=1.0)
                ssim_vals.append(ssim_val)
            except Exception:
                ssim_vals.append(0.5)  # Fallback
        return np.mean(ssim_vals)
    else:
        return compare_ssim(pred_np, target_np, data_range=1.0)

def calculate_sam(pred_np, target_np):
    """Calculate Spectral Angle Mapper"""
    eps = 1e-8
    if pred_np.ndim == 3:  # (D, H, W)
        pred_flat = pred_np.reshape(pred_np.shape[0], -1)  # (D, H*W)
        target_flat = target_np.reshape(target_np.shape[0], -1)

        dot = np.sum(pred_flat * target_flat, axis=0)
        norm_pred = np.linalg.norm(pred_flat, axis=0) + eps
        norm_target = np.linalg.norm(target_flat, axis=0) + eps
        cos_angle = np.clip(dot / (norm_pred * norm_target), -1, 1)
        angles = np.arccos(cos_angle)
        return np.mean(angles)
    else:
        dot = np.sum(pred_np * target_np)
        norm_pred = np.linalg.norm(pred_np) + eps
        norm_target = np.linalg.norm(target_np) + eps
        cos_angle = np.clip(dot / (norm_pred * norm_target), -1, 1)
        return np.arccos(cos_angle)

# ----------------------------
# Visualizations Functions
# ----------------------------
def create_training_visualizations(train_losses, val_losses, val_psnrs, val_ssims, val_sams, learning_rates, save_dir, best_psnr):
    """Create comprehensive training visualizations and save them"""
    import matplotlib.pyplot as plt
    plt.style.use('default')

    # Create figure with subplots
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    fig.suptitle(f'HSI Denoising Training Progress (Best PSNR: {best_psnr:.4f} dB)', fontsize=16)

    # Calculate validation epochs
    val_epochs = [i*5 for i in range(1, len(val_losses)+1)] if val_losses else []

    # 1. Training and Validation Loss
    axes[0, 0].plot(range(1, len(train_losses)+1), train_losses, 'b-', label='Training Loss', alpha=0.8)
    if val_losses:
        axes[0, 0].plot(val_epochs, val_losses, 'r-', label='Validation Loss', marker='o', markersize=3)
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].set_title('Training and Validation Loss')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)
    axes[0, 0].set_yscale('log')

    # 2. PSNR over Epochs
    if val_psnrs:
        axes[0, 1].plot(val_epochs, val_psnrs, 'r-', label='PSNR', marker='o', markersize=4)
        axes[0, 1].axhline(y=40, color='purple', linestyle='--', label='Target (40 dB)', alpha=0.7)
        axes[0, 1].axhline(y=best_psnr, color='orange', linestyle=':', label=f'Best ({best_psnr:.2f} dB)', alpha=0.7)
        axes[0, 1].legend()
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('PSNR (dB)')
    axes[0, 1].set_title('PSNR over Epochs')
    axes[0, 1].grid(True, alpha=0.3)

    # 3. SSIM over Epochs
    if val_ssims:
        axes[0, 2].plot(val_epochs, val_ssims, 'r-', label='SSIM', marker='o', markersize=4)
        axes[0, 2].axhline(y=1.0, color='purple', linestyle='--', label='Perfect (1.0)', alpha=0.7)
        axes[0, 2].legend()
    axes[0, 2].set_xlabel('Epoch')
    axes[0, 2].set_ylabel('SSIM')
    axes[0, 2].set_title('SSIM over Epochs')
    axes[0, 2].grid(True, alpha=0.3)
    axes[0, 2].set_ylim([0, 1.1])

    # 4. SAM over Epochs
    if val_sams:
        axes[1, 0].plot(val_epochs, val_sams, 'r-', label='SAM', marker='o', markersize=4)
        axes[1, 0].axhline(y=0, color='purple', linestyle='--', label='Perfect (0.0)', alpha=0.7)
        axes[1, 0].legend()
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylabel('SAM (radians)')
    axes[1, 0].set_title('SAM over Epochs')
    axes[1, 0].grid(True, alpha=0.3)

    # 5. Learning Rate Changes
    axes[1, 1].plot(range(1, len(learning_rates)+1), learning_rates, 'purple', linewidth=2)
    axes[1, 1].set_xlabel('Epoch')
    axes[1, 1].set_ylabel('Learning Rate')
    axes[1, 1].set_title('Learning Rate Schedule')
    axes[1, 1].grid(True, alpha=0.3)
    axes[1, 1].set_yscale('log')

    # 6. Combined Metrics Summary
    if val_psnrs and val_ssims and val_sams:
        # Normalize metrics for combined view
        norm_psnr = np.array(val_psnrs) / 50.0  # Normalize to ~1
        norm_ssim = np.array(val_ssims)  # Already 0-1
        norm_sam = 1 - np.array(val_sams)  # Invert so higher is better

        axes[1, 2].plot(val_epochs, norm_psnr, 'b-', label='PSNR (normalized)', alpha=0.7, linewidth=2)
        axes[1, 2].plot(val_epochs, norm_ssim, 'r-', label='SSIM', alpha=0.7, linewidth=2)
        axes[1, 2].plot(val_epochs, norm_sam, 'g-', label='1-SAM', alpha=0.7, linewidth=2)
        axes[1, 2].legend()

    axes[1, 2].set_xlabel('Epoch')
    axes[1, 2].set_ylabel('Normalized Metrics')
    axes[1, 2].set_title('Combined Metrics Overview')
    axes[1, 2].grid(True, alpha=0.3)
    axes[1, 2].set_ylim([0, 1.1])

    plt.tight_layout()

    # Save the plot
    plot_path = os.path.join(save_dir, 'training_progress.png')
    plt.savefig(plot_path, dpi=300, bbox_inches='tight', facecolor='white')
    print(f"Training plots saved to: {plot_path}")

    # Also save as PDF for high quality
    plot_path_pdf = os.path.join(save_dir, 'training_progress.pdf')
    plt.savefig(plot_path_pdf, bbox_inches='tight', facecolor='white')

    plt.show()
    plt.close()

def print_final_scores(val_psnrs, val_ssims, val_sams, best_psnr, train_files, val_files, save_dir):
    """Print comprehensive final scores"""
    print("\n" + "="*80)
    print(f"Best Validation PSNR: {best_psnr:.4f} dB")
    print("Training finished. Models saved with comprehensive metadata.")
    print(f"Results saved to: {save_dir}")

    print("\n=== FINAL SCORES ===")
    if val_psnrs and val_ssims and val_sams:
        # Handle single validation set
        final_psnr = val_psnrs[-1]
        final_ssim = val_ssims[-1]
        final_sam = val_sams[-1]

        print(f"Final Validation PSNR: {final_psnr:.4f} dB")
        print(f"Final Validation SSIM: {final_ssim:.4f}")
        print(f"Final Validation SAM: {final_sam:.4f} radians")
        
        # Additional statistics
        print(f"\nValidation History:")
        print(f"  Mean PSNR: {np.mean(val_psnrs):.4f} ± {np.std(val_psnrs):.4f} dB")
        print(f"  Mean SSIM: {np.mean(val_ssims):.4f} ± {np.std(val_ssims):.4f}")
        print(f"  Mean SAM:  {np.mean(val_sams):.4f} ± {np.std(val_sams):.4f} radians")
        print(f"  Best PSNR: {max(val_psnrs):.4f} dB")
        print(f"  Worst PSNR: {min(val_psnrs):.4f} dB")

    # Count actual files (filter out 'synthetic')
    real_train_files = [f for f in train_files if f != 'synthetic']
    real_val_files = [f for f in val_files if f != 'synthetic']

    print(f"\nDataset Information:")
    print(f"  Training files: {len(real_train_files)}")
    print(f"  Validation files: {len(real_val_files)}")
    print("="*80)

#WarmupCosineScheduler
class WarmupCosineScheduler:
    """
    IMPROVED: LR scheduler with warmup + peak plateau + cosine annealing
    Stays at peak LR longer for better exploration before decay
    """
    def __init__(self, optimizer, warmup_epochs, peak_epochs, total_epochs, lr_max, lr_min):
        self.optimizer = optimizer
        self.warmup_epochs = warmup_epochs      # Warmup period
        self.peak_epochs = peak_epochs          # NEW: How long to stay at peak
        self.total_epochs = total_epochs
        self.lr_max = lr_max
        self.lr_min = lr_min
        self.current_epoch = 0
    
    def step(self):
        """Update learning rate and return current LR"""
        self.current_epoch += 1
        
        if self.current_epoch <= self.warmup_epochs:
            # Phase 1: Warmup (linear increase from 0 to lr_max)
            lr = self.lr_max * (self.current_epoch / self.warmup_epochs)
            
        elif self.current_epoch <= self.peak_epochs:
            # Phase 2: Peak plateau (stay at lr_max for exploration)
            lr = self.lr_max
            
        else:
            # Phase 3: Cosine annealing (smooth decay from lr_max to lr_min)
            progress = (self.current_epoch - self.peak_epochs) / (self.total_epochs - self.peak_epochs)
            lr = self.lr_min + (self.lr_max - self.lr_min) * 0.5 * (1 + math.cos(math.pi * progress))
        
        # Update optimizer learning rate
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        
        return lr
    
    def state_dict(self):
        """For checkpoint saving compatibility"""
        return {
            'current_epoch': self.current_epoch,
            'warmup_epochs': self.warmup_epochs,
            'peak_epochs': self.peak_epochs,  # Save peak_epochs
            'total_epochs': self.total_epochs,
            'lr_max': self.lr_max,
            'lr_min': self.lr_min
        }
    
    def load_state_dict(self, state_dict):
        """For checkpoint loading compatibility"""
        self.current_epoch = state_dict['current_epoch']
        self.warmup_epochs = state_dict['warmup_epochs']
        self.peak_epochs = state_dict.get('peak_epochs', self.warmup_epochs)  # Backward compatibility
        self.total_epochs = state_dict['total_epochs']
        self.lr_max = state_dict['lr_max']
        self.lr_min = state_dict['lr_min']

# ----------------------------
# Memory-Optimized Main Function with Comprehensive Training
# ----------------------------
# Complete main() function with all modifications integrated
# Keeps all original components, adds deep supervision and gradient accumulation

def main():
    print("=== SOTA-Enhanced Memory-Optimized HSI Denoising ===")
    print("4-stage SST Transformer with hybrid attention and deep supervision")
    print("Target: PSNR > 40 dB with efficient memory usage")

    # Enable TF32 for RTX 40-series (faster matrix operations)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    
    # Enable cuDNN benchmarking for consistent input sizes
    torch.backends.cudnn.benchmark = True
    
    # UPDATED CONFIG for 4-stage
    config = {
        'patch_size': 64,           
        'batch_size': 2,             
        'base_dim': 64,              
        'noise_level': 30,           
        'lr_max': 1.5e-4,          
        'lr_min': 1e-6,              
        'total_epochs': 300,         
        'patience': 20,             
        'weight_decay': 2e-4,       
        'val_split': 0.1,            
        'seed': 42,                  
        'target_bands': 31,
        'gradient_accumulation': 8,   # Simulate batch_size= batch_size * gradient_accumulation
        'train_crop_size': 1024,
        'scales': [64, 32, 32], 
        'warmup_epochs': 10,
    }

    # Set seeds
    torch.manual_seed(config['seed'])
    random.seed(config['seed'])
    np.random.seed(config['seed'])

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Device: {device}")

    # Clear GPU memory
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    data_dir = '/workspace/icvl_part/train'
    save_dir = './HSI_denoising_ICVL_resultsV15_noise30'
    os.makedirs(save_dir, exist_ok=True)

    print(f"\n=== DATA DISCOVERY ===")
    print(f"Looking for HSI data in: {data_dir}")

    # Dataset preparation with detailed logging
    if os.path.exists(data_dir):
        try:
            mat_files = glob.glob(os.path.join(data_dir, '*.mat'))
            print(f"Found {len(mat_files)} .mat files in directory")

            if len(mat_files) >= 10:
                n_val = 10
                train_files = mat_files[:-n_val]  # First 90 for training
                val_files = mat_files[-n_val:]     # Last 10 for validation
                print(f"ICVL Split: {len(train_files)} training files, {len(val_files)} validation files")
            else:
                print(f"Warning: Found only {len(mat_files)} files. Need at least 10 for ICVL setup.")
                print("Using what's available with proportional split...")
                n_val = max(1, len(mat_files) // 10)
                train_files = mat_files[:-n_val]
                val_files = mat_files[-n_val:]
        except Exception as e:
            print(f"Error accessing data directory: {e}")
            train_files, val_files = ['synthetic'], ['synthetic']
    else:
        print(f"Data directory not found: {data_dir}")
        train_files, val_files = ['synthetic'], ['synthetic']

    # Create datasets with enhanced logging
    train_dataset = MemoryEfficientHSIDataset(
        train_files,
        patch_size=config['patch_size'],
        noise_level=config['noise_level'],
        patches_per_file=15,
        target_bands=config['target_bands'],
        augment=True,
        dataset_type="training",
        train_crop_size=config['train_crop_size'],
        scales=config['scales']
    )

    val_dataset_synthetic = MemoryEfficientHSIDataset(
        val_files,
        patch_size=config['patch_size'],
        noise_level=config['noise_level'],
        patches_per_file=10,
        target_bands=config['target_bands'],
        augment=False,
        dataset_type="validation_synthetic",
        train_crop_size=config['train_crop_size'],
        scales=[config['patch_size']]  # No multi-scale for validation
    )


    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=config['batch_size'],
                             shuffle=True, num_workers=4, pin_memory=True)
    val_loader_synthetic = DataLoader(val_dataset_synthetic, batch_size=1,
                           shuffle=False, num_workers=2, pin_memory=True)

    # Pre-training data verification
    print("\n=== PRE-TRAINING DATA VERIFICATION ===")
    print("Testing first few training samples...")

    test_count = 3
    samples_verified = 0
    
    for i, (noisy, clean) in enumerate(train_loader):
        if i >= test_count:
            break
        print(f"Sample {i+1}: Noisy shape: {noisy.shape}, Clean shape: {clean.shape}")
        samples_verified += noisy.shape[0]
    
    if train_dataset.file_info:
        print(f"✓ Successfully verified {samples_verified} samples from {len(train_dataset.file_info)} real data files")
    else:
        print(f"⚠ Using synthetic data (no real files loaded)")
    
    print("=" * 50)

    # Model setup - NEW: SOTA 4-stage Swin Transformer
    model = MemoryOptimizedUNet(
        in_channels=1,
        base_dim=config['base_dim'],
        window_sizes=[4,8,16],
        num_bands=config['target_bands']
    ).to(device)

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {total_params / 1e6:.2f}M")
    #print(f"Architecture: 4-stage SST Transformer + Hybrid Attention + Deep Supervision")
    print(f"Architecture: 4-stage SST Transformer + Hybrid Attention")

    # Initialize EMA
    ema = ExponentialMovingAverage(model.parameters(), decay=0.999)
    print("EMA initialized with decay=0.999")

    # Loss and optimizer
    criterion = MemoryEfficientLoss(device=device)
    optimizer = optim.AdamW(model.parameters(), lr=config['lr_max'],
                            betas=(0.9, 0.999), eps=1e-8,
                           weight_decay=config['weight_decay'])

    # Scheduler
    #from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
    #scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=50, T_mult=2, eta_min=config['lr_min'])
    #scheduler = CosineAnnealingLR(optimizer, T_max=config['total_epochs'], eta_min=config['lr_min'])
    scheduler = WarmupCosineScheduler(
        optimizer=optimizer,
        warmup_epochs=config['warmup_epochs'],
        peak_epochs=50,
        total_epochs=config['total_epochs'],
        lr_max=config['lr_max'],
        lr_min=config['lr_min']
    )
    print(f"Scheduler: Warmup ({config['warmup_epochs']} epochs) + Cosine Annealing")
    # Mixed precision
    scaler = torch.amp.GradScaler('cuda')

    # Training tracking
    best_psnr = 0
    patience_counter = 0
    train_losses, val_losses, val_psnrs, val_ssims, val_sams = [], [], [], [], []
    learning_rates = []

    # NEW: Gradient accumulation setup
    accumulation_steps = config.get('gradient_accumulation', 1)
    print(f"Using gradient accumulation: {accumulation_steps} steps (effective batch size: {config['batch_size'] * accumulation_steps})")

    print(f"Starting training for {config['total_epochs']} epochs...")
    print("-" * 80)

    for epoch in range(1, config['total_epochs'] + 1):
        # ============================================================
        # TRAINING PHASE - WITH DEEP SUPERVISION & GRADIENT ACCUMULATION
        # ============================================================
        model.train()
        epoch_loss = 0
        epoch_main_loss = 0
        epoch_deep_loss = 0
        num_batches = 0

        optimizer.zero_grad()  # Initialize outside loop

        for batch_idx, (noisy, clean) in enumerate(train_loader):
            noisy = noisy.float().to(device, non_blocking=True)
            clean = clean.float().to(device, non_blocking=True)

            with torch.amp.autocast('cuda'):
                # NEW: Get main output and deep supervision outputs
                output, deep_outputs = model(noisy, return_deep_sup=True)
                
                # Main loss
                main_loss = criterion(output, clean)
                
                # NEW: Deep supervision losses (weighted progressively)
                deep_loss = 0
                if len(deep_outputs) > 0:
                    weights = [0.4, 0.3, 0.3]  # Weights for 3 auxiliary outputs
                    for i, deep_out in enumerate(deep_outputs):
                        deep_loss += weights[i] * criterion(deep_out, clean)
                
                # Combined loss with gradient accumulation scaling
                #deep supervision loss disabled
                #loss = (main_loss + 0.3 * deep_loss) / accumulation_steps
                loss = main_loss / accumulation_steps

            scaler.scale(loss).backward()

            # NEW: Step optimizer only every accumulation_steps batches
            if (batch_idx + 1) % accumulation_steps == 0:
                # Gradient clipping for stability
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
                
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                #EMA
                ema.update()

            # Logging (unscale loss for accurate reporting)
            epoch_loss += loss.item() * accumulation_steps
            epoch_main_loss += main_loss.item()
            epoch_deep_loss += deep_loss.item() if isinstance(deep_loss, torch.Tensor) else deep_loss
            num_batches += 1

            # Memory management
            if num_batches % 10 == 0:
                torch.cuda.empty_cache()



        current_lr = scheduler.step()
        learning_rates.append(current_lr)
        train_loss = epoch_loss / num_batches
        train_main_loss = epoch_main_loss / num_batches
        train_deep_loss = epoch_deep_loss / num_batches
        train_losses.append(train_loss)

        # Data usage statistics every 50 epochs
        #if epoch % 50 == 0:
            #print(f"  Data usage after {epoch} epochs: {train_dataset.get_usage_stats()}")
            #print(f"  Loss breakdown - Main: {train_main_loss:.6f}, Deep: {train_deep_loss:.6f}")

        # ============================================================
        # VALIDATION PHASE - WITHOUT DEEP SUPERVISION
        # ============================================================
        if epoch % 5 == 0:
            model.eval()

            #EMA
            with ema.average_parameters():
                val_loss = 0
                total_psnr = 0
                total_ssim = 0
                total_sam = 0
                num_val_batches = 0
            
                with torch.no_grad():
                    for noisy, clean in val_loader_synthetic:
                        noisy, clean = noisy.to(device, non_blocking=True), clean.to(device, non_blocking=True)
            
                        with torch.amp.autocast('cuda'):
                            output = model(noisy, return_deep_sup=False)
                            loss = criterion(output, clean)
            
                        val_loss += loss.item()
                        psnr = calculate_psnr(output, clean)
                        total_psnr += psnr.item()
            
                        output_np = output.squeeze(0).cpu().numpy() if output.dim() == 4 else output.squeeze(0).squeeze(0).cpu().numpy()
                        clean_np = clean.squeeze(0).cpu().numpy() if clean.dim() == 4 else clean.squeeze(0).squeeze(0).cpu().numpy()
            
                        ssim_val = calculate_ssim(output_np, clean_np)
                        sam_val = calculate_sam(output_np, clean_np)
            
                        total_ssim += ssim_val
                        total_sam += sam_val
                        num_val_batches += 1
            
                val_loss /= num_val_batches
                val_psnr = total_psnr / num_val_batches
                val_ssim = total_ssim / num_val_batches
                val_sam = total_sam / num_val_batches
            
                # **FIX: Append metrics to tracking lists**
                val_losses.append(val_loss)
                val_psnrs.append(val_psnr)
                val_ssims.append(val_ssim)
                val_sams.append(val_sam)
            
                # Check improvement
                if val_psnr > best_psnr:
                    best_psnr = val_psnr
                    patience_counter = 0
                    torch.save({
                        'model_state_dict': model.state_dict(),
                        #EMA
                        'ema_state_dict': ema.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict(),
                        'epoch': epoch,
                        'best_psnr': best_psnr,
                        'config': config
                    }, os.path.join(save_dir, 'best_model.pth'))
                    print(f"  *** New best model saved! PSNR: {best_psnr:.4f} dB ***")
                else:
                    patience_counter += 1
            
                print(f"Epoch {epoch:4d} | LR: {current_lr:.2e} | Train Loss: {train_loss:.6f} | "
                      f"Val Loss: {val_loss:.6f} | Val PSNR: {val_psnr:.4f} | "
                      f"Val SSIM: {val_ssim:.4f} | Val SAM: {val_sam:.4f} | Best: {best_psnr:.4f}")
            
            # Early stopping
            if patience_counter >= config['patience']:
                print(f"Early stopping at epoch {epoch}. Best PSNR: {best_psnr:.4f} dB")
                break
        else:
            print(f"Epoch {epoch:4d} | LR: {current_lr:.2e} | Train Loss: {train_loss:.6f}")

        # Clear memory
        torch.cuda.empty_cache()


    # Save complete model with all metadata
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'epoch': epoch,
        'best_psnr': best_psnr,
        'config': config,
        'train_losses': train_losses,
        'val_losses': val_losses,
        'val_psnrs': val_psnrs,
        'val_ssims': val_ssims,
        'val_sams': val_sams,
        'learning_rates': learning_rates,
        'architecture': '4-stage SST Transformer + Hybrid Attention + Deep Supervision',
        'total_params': total_params
    }, os.path.join(save_dir, 'enhanced_denoising_pipeline_full.pth'))

    # Create comprehensive visualizations
    print("\nCreating training visualizations...")
    create_training_visualizations(
        train_losses=train_losses,
        val_losses=val_losses,
        val_psnrs=val_psnrs,
        val_ssims=val_ssims,
        val_sams=val_sams,
        learning_rates=learning_rates,
        save_dir=save_dir,
        best_psnr=best_psnr
    )

    # Print final comprehensive scores
    print_final_scores(
        val_psnrs=val_psnrs,
        val_ssims=val_ssims,
        val_sams=val_sams,
        best_psnr=best_psnr,
        train_files=train_files,
        val_files=val_files,
        save_dir=save_dir
    )

    print("\n" + "="*80)
    print("FINAL ARCHITECTURE SUMMARY")
    print("="*80)
    print(f"Model: 4-stage SST Transformer U-Net")
    print(f"Hybrid attention: SST + Spectral + FusedBottleneck")
    print(f"Deep supervision: 3 auxiliary losses")
    print(f"Total parameters: {total_params / 1e6:.2f}M")
    print(f"Best validation PSNR: {best_psnr:.4f} dB")
    print(f"Training epochs: {epoch}")
    print(f"Results directory: {save_dir}")
    print("="*80)

    return model, best_psnr

if __name__ == "__main__":
    main()

=== SOTA-Enhanced Memory-Optimized HSI Denoising ===
4-stage SST Transformer with hybrid attention and deep supervision
Target: PSNR > 40 dB with efficient memory usage
Device: cuda
GPU Memory: 12.5 GB

=== DATA DISCOVERY ===
Looking for HSI data in: /home/habib/Documents/workspace/icvl_part/train
Found 100 .mat files in directory
ICVL Split: 90 training files, 10 validation files

=== Initializing TRAINING Dataset ===
Noise Level (σ): 10 -> 0.0392 (0-1 scale)
Center crop size: 1024x1024
Patch size: 64x64
Multi-scale strides: [64, 32, 32]
Attempting to load 90 files...
  [1/90] Loading: prk_0328-1034.mat
    ✓ Shape: (31, 1392, 1300), Key: 'rad' (HDF5)
  [2/90] Loading: lst_0408-1012.mat
    ✓ Shape: (31, 1392, 1300), Key: 'rad' (HDF5)
  [3/90] Loading: eve_0331-1601.mat
    ✓ Shape: (31, 1392, 1300), Key: 'rad' (HDF5)
  [4/90] Loading: mor_0328-1209-2.mat
    ✓ Shape: (31, 1392, 1027), Key: 'rad' (HDF5)
  [5/90] Loading: bgu_0403-1439.mat
    ✓ Shape: (31, 1392, 1300), Key: 'rad' (HDF

In [None]:
#Testing

In [2]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import glob
import matplotlib.pyplot as plt
import seaborn as sns
from skimage.metrics import structural_similarity as compare_ssim
from torch.utils.checkpoint import checkpoint
from tqdm import tqdm
import warnings
from einops import rearrange
warnings.filterwarnings('ignore')
import h5py
try:
    import scipy.io as sio
except ImportError:
    sio = None

# ----------------------------
# Memory-Efficient Utilities
# ----------------------------
# ----------------------------
# Memory-Efficient Utilities
# ----------------------------
class LayerNormChannel3d(nn.Module):
    """Lightweight channel normalization"""
    def __init__(self, num_channels: int = None, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.num_channels = num_channels
        self.gn = None
        if num_channels is not None:
            self.gn = nn.GroupNorm(1, num_channels, eps=eps, affine=True)

    def forward(self, x):
        C = x.shape[1]
        if self.gn is None or C != self.num_channels:
            self.num_channels = C
            self.gn = nn.GroupNorm(1, C, eps=self.eps, affine=True).to(x.device)
        return self.gn(x)

def depthwise_conv3d(channels: int, kernel_size: int = 3, stride: int = 1, padding: int = 1, dilation: int = 1):
    return nn.Conv3d(channels, channels, kernel_size=kernel_size, stride=stride,
                     padding=padding, groups=channels, bias=True, dilation=dilation)

class EfficientChannelAttention(nn.Module):
    """Memory-efficient channel attention"""
    def __init__(self, channels, reduction=16):
        super().__init__()
        hidden = max(4, channels // reduction)
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        B, C, D, H, W = x.shape
        y = self.avg_pool(x).view(B, C)
        y = self.fc(y).view(B, C, 1, 1, 1)
        return x * y.expand_as(x)

class AdaptiveDropout3d(nn.Module):
    """Adaptive dropout that adjusts rate based on a factor (e.g., layer depth)"""
    def __init__(self, base_drop=0.25, factor=1.0):
        super().__init__()
        self.dropout = nn.Dropout3d(base_drop * factor)

    def forward(self, x):
        return self.dropout(x)

# ----------------------------
# Memory-Optimized Blocks
# ----------------------------
class PatchMerging3D(nn.Module):
    """
    3D Patch Merging Layer (downsampling)
    """
    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(8 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(8 * dim)

    def forward(self, x):
        """
        Args:
            x: (B, C, D, H, W)
        """
        B, C, D, H, W = x.shape

        # Pad if needed
        pad_d = (2 - D % 2) % 2
        pad_h = (2 - H % 2) % 2
        pad_w = (2 - W % 2) % 2

        if pad_d > 0 or pad_h > 0 or pad_w > 0:
            x = F.pad(x, (0, pad_w, 0, pad_h, 0, pad_d))
            D, H, W = x.shape[2:]

        # Convert to (B, D, H, W, C)
        x = x.permute(0, 2, 3, 4, 1).contiguous()

        # Downsample by merging 2x2x2 patches
        x0 = x[:, 0::2, 0::2, 0::2, :]  # (B, D/2, H/2, W/2, C)
        x1 = x[:, 1::2, 0::2, 0::2, :]
        x2 = x[:, 0::2, 1::2, 0::2, :]
        x3 = x[:, 1::2, 1::2, 0::2, :]
        x4 = x[:, 0::2, 0::2, 1::2, :]
        x5 = x[:, 1::2, 0::2, 1::2, :]
        x6 = x[:, 0::2, 1::2, 1::2, :]
        x7 = x[:, 1::2, 1::2, 1::2, :]

        x = torch.cat([x0, x1, x2, x3, x4, x5, x6, x7], -1)  # (B, D/2, H/2, W/2, 8*C)
        x = self.norm(x)
        x = self.reduction(x)  # (B, D/2, H/2, W/2, 2*C)

        # Convert back to (B, C, D, H, W)
        x = x.permute(0, 4, 1, 2, 3).contiguous()
        return x

class PatchMerging3D(nn.Module):
    """3D Patch Merging with better handling"""
    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(8 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(8 * dim)

    def forward(self, x):
        B, C, D, H, W = x.shape
        
        # Pad to multiples of 2
        pad_d = (2 - D % 2) % 2
        pad_h = (2 - H % 2) % 2
        pad_w = (2 - W % 2) % 2
        
        if pad_d > 0 or pad_h > 0 or pad_w > 0:
            x = F.pad(x, (0, pad_w, 0, pad_h, 0, pad_d))
            D, H, W = x.shape[2:]

        x = x.permute(0, 2, 3, 4, 1).contiguous()
        
        x0 = x[:, 0::2, 0::2, 0::2, :]
        x1 = x[:, 1::2, 0::2, 0::2, :]
        x2 = x[:, 0::2, 1::2, 0::2, :]
        x3 = x[:, 1::2, 1::2, 0::2, :]
        x4 = x[:, 0::2, 0::2, 1::2, :]
        x5 = x[:, 1::2, 0::2, 1::2, :]
        x6 = x[:, 0::2, 1::2, 1::2, :]
        x7 = x[:, 1::2, 1::2, 1::2, :]
        
        x = torch.cat([x0, x1, x2, x3, x4, x5, x6, x7], -1)
        x = self.norm(x)
        x = self.reduction(x)
        
        x = x.permute(0, 4, 1, 2, 3).contiguous()
        return x

class PatchExpanding3D(nn.Module):
    """3D Patch Expanding for decoder"""
    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.expand = nn.Linear(dim, 4 * dim, bias=False)
        self.norm = norm_layer(dim)

    def forward(self, x):
        B, C, D, H, W = x.shape
        
        x = x.permute(0, 2, 3, 4, 1).contiguous()
        x = self.norm(x)
        x = self.expand(x)
        
        x = rearrange(x, 'b d h w (p1 p2 p3 c) -> b (d p1) (h p2) (w p3) c', 
                     p1=2, p2=2, p3=2, c=C//2)
        
        x = x.permute(0, 4, 1, 2, 3).contiguous()
        return x

class SpectralAttentionModule(nn.Module):
    """Dedicated spectral attention for bottleneck - HYBRID ATTENTION"""
    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        
        self.qkv = nn.Conv3d(dim, dim * 3, 1)
        self.proj = nn.Conv3d(dim, dim, 1)
        self.norm = nn.GroupNorm(1, dim)
        
    def forward(self, x):
        B, C, D, H, W = x.shape
        x_norm = self.norm(x)
        
        qkv = self.qkv(x_norm)
        q, k, v = torch.chunk(qkv, 3, dim=1)
        
        # Reshape for spectral attention: (B*H*W, num_heads, D, head_dim)
        q = q.permute(0, 3, 4, 2, 1).reshape(B * H * W, D, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        k = k.permute(0, 3, 4, 2, 1).reshape(B * H * W, D, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        v = v.permute(0, 3, 4, 2, 1).reshape(B * H * W, D, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        
        # Spectral attention
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = F.softmax(attn, dim=-1)
        
        out = (attn @ v).permute(0, 2, 1, 3).reshape(B, H, W, D, C)
        out = out.permute(0, 4, 3, 1, 2).contiguous()
        
        out = self.proj(out)
        return x + out

# ----------------------------
# SST Blocks
# ----------------------------

class SpectralSelfAttention(nn.Module):
    """
    Spectral Self-Attention: Attends along the spectral dimension (bands)
    Treats each spatial position independently and attends across all bands
    """
    def __init__(self, dim, num_bands, num_heads=8, qkv_bias=True, attn_drop=0., proj_drop=0.):
            super().__init__()
            self.dim = dim
            self.num_bands = num_bands  # Expected bands from config
            self.num_heads = num_heads
            self.head_dim = dim // num_heads
            self.scale = self.head_dim ** -0.5
            
            assert dim % num_heads == 0, f"dim {dim} must be divisible by num_heads {num_heads}"
            
            # Linear projections for Q, K, V
            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
            self.attn_drop = nn.Dropout(attn_drop)
            self.proj = nn.Linear(dim, dim)
            self.proj_drop = nn.Dropout(proj_drop)
            
            # FIXED: Initialize buffer with expected bands, will expand automatically if needed
            spectral_pos_embed = torch.zeros(1, num_bands, dim)
            nn.init.trunc_normal_(spectral_pos_embed, std=0.02)
            self.register_buffer('spectral_pos_embed', spectral_pos_embed)
        
    def forward(self, x):
        """
        Args:
            x: (B, C, D, H, W) - Input feature map
        Returns:
            x: (B, C, D, H, W) - Output feature map
        """
        B, C, D, H, W = x.shape
        
        # Reshape to (B, H, W, D, C) for spectral attention
        x = x.permute(0, 3, 4, 2, 1).contiguous()  # (B, H, W, D, C)
        
        # Flatten spatial dimensions: (B*H*W, D, C)
        x_flat = x.reshape(B * H * W, D, C)
        
        # FIXED: Dynamically expand buffer if we encounter more bands than initialized
        if D > self.spectral_pos_embed.shape[1]:
            # This happens rarely (only when encountering new max bands)
            old_size = self.spectral_pos_embed.shape[1]
            new_size = D
            
            # Create expanded buffer on same device
            expanded = torch.zeros(1, new_size, self.dim, 
                                  device=self.spectral_pos_embed.device,
                                  dtype=self.spectral_pos_embed.dtype)
            
            # Copy existing learned embeddings
            expanded[:, :old_size, :] = self.spectral_pos_embed
            
            # Initialize new bands with small random values
            nn.init.trunc_normal_(expanded[:, old_size:, :], std=0.02)
            
            # Update the buffer in-place
            self.spectral_pos_embed.resize_(expanded.shape)
            self.spectral_pos_embed.copy_(expanded)
            
            #print(f"[SpectralPosEmbed] Auto-expanded from {old_size} to {new_size} bands")
        
        # Use only the bands we need (safe slicing - buffer is always >= D now)
        pos_embed = self.spectral_pos_embed[:, :D, :]  # (1, D, dim)
        
        # Add spectral positional encoding
        x_flat = x_flat + pos_embed
        
        # Generate Q, K, V
        qkv = self.qkv(x_flat).reshape(B * H * W, D, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # Each: (B*H*W, num_heads, D, head_dim)
        
        # Scaled dot-product attention across spectral dimension
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))  # (B*H*W, num_heads, D, D)
        attn = F.softmax(attn, dim=-1)
        attn = self.attn_drop(attn)
        
        # Apply attention to values
        x_attn = (attn @ v).transpose(1, 2).reshape(B * H * W, D, C)  # (B*H*W, D, C)
        
        # Project and reshape back
        x_attn = self.proj(x_attn)
        x_attn = self.proj_drop(x_attn)
        
        # Reshape back to (B, H, W, D, C)
        x_attn = x_attn.reshape(B, H, W, D, C)
        
        # Reshape to original format (B, C, D, H, W)
        x_attn = x_attn.permute(0, 4, 3, 1, 2).contiguous()
        
        return x_attn

class SpatialSelfAttention(nn.Module):
    """
    FIXED Spatial Self-Attention - SAME NAME, IMPROVED IMPLEMENTATION
    Drop-in replacement - no API changes, just better internals
    """
    def __init__(self, dim, num_heads=8, window_size=8, qkv_bias=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.window_size = window_size
        
        assert dim % num_heads == 0, f"dim {dim} must be divisible by num_heads {num_heads}"
        
        # FIXED: Full 3D convolution instead of depthwise
        # OLD: self.dwconv = depthwise_conv3d(dim, kernel_size=(1, 3, 3), padding=(0, 1, 1))
        # NEW: Full conv for cross-spectral information flow
        self.dwconv = nn.Conv3d(dim, dim, kernel_size=(1, 3, 3), padding=(0, 1, 1), bias=True)
        
        # Keep same API as before
        self.qkv = nn.Conv3d(dim, dim * 3, kernel_size=1, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Conv3d(dim, dim, kernel_size=1)
        self.proj_drop = nn.Dropout(proj_drop)
        
        # NEW: Relative position bias (optional, improves performance)
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads)
        )
        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)
        
        coords_h = torch.arange(window_size)
        coords_w = torch.arange(window_size)
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij'))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += window_size - 1
        relative_coords[:, :, 1] += window_size - 1
        relative_coords[:, :, 0] *= 2 * window_size - 1
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index)
        
    def forward(self, x):
        """
        SAME API: (B, C, D, H, W) -> (B, C, D, H, W)
        Just improved internals
        """
        B, C, D, H, W = x.shape
        
        # Apply full 3D conv (cross-spectral enabled)
        x_local = self.dwconv(x)
        
        qkv = self.qkv(x_local)
        
        if H * W > self.window_size ** 2:
            # Efficient attention for large spatial dims
            qkv = rearrange(qkv, 'b (three head c) d h w -> three b d head c (h w)', 
                           three=3, head=self.num_heads)
            q, k, v = qkv[0], qkv[1], qkv[2]
            
            q_flat = q.reshape(B * D * self.num_heads, self.head_dim, H * W)
            k_flat = k.reshape(B * D * self.num_heads, self.head_dim, H * W)
            v_flat = v.reshape(B * D * self.num_heads, self.head_dim, H * W)
            
            k_global = k_flat.mean(dim=-1, keepdim=True)
            v_global = v_flat.mean(dim=-1, keepdim=True)
            
            q_flat = q_flat * self.scale
            attn = torch.bmm(q_flat.transpose(1, 2), k_global)
            attn = F.softmax(attn, dim=1)
            attn = self.attn_drop(attn)
            
            x_attn = v_global * attn.transpose(1, 2)
            x_attn = x_attn.reshape(B, D, self.num_heads, self.head_dim, H * W)
            
        else:
            # Full attention for small spatial dims
            qkv = rearrange(qkv, 'b (three head c) d h w -> three b d head c (h w)', 
                           three=3, head=self.num_heads)
            q, k, v = qkv[0], qkv[1], qkv[2]
            
            q_flat = q.reshape(B * D * self.num_heads, self.head_dim, H * W)
            k_flat = k.reshape(B * D * self.num_heads, self.head_dim, H * W)
            v_flat = v.reshape(B * D * self.num_heads, self.head_dim, H * W)
            
            q_flat = q_flat * self.scale
            attn = torch.bmm(q_flat.transpose(1, 2), k_flat)
            
            if H == self.window_size and W == self.window_size:
                relative_position_bias = self.relative_position_bias_table[
                    self.relative_position_index.view(-1)
                ].view(self.window_size * self.window_size, self.window_size * self.window_size, -1)
                relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
                attn = attn.view(B * D, self.num_heads, H * W, H * W) + relative_position_bias.unsqueeze(0)
                attn = attn.view(B * D * self.num_heads, H * W, H * W)
            
            attn = F.softmax(attn, dim=-1)
            attn = self.attn_drop(attn)
            
            x_attn = torch.bmm(attn, v_flat.transpose(1, 2))
            x_attn = x_attn.transpose(1, 2)
            x_attn = x_attn.reshape(B, D, self.num_heads, self.head_dim, H * W)
        
        x_attn = rearrange(x_attn, 'b d head c (h w) -> b (head c) d h w', 
                          head=self.num_heads, h=H, w=W)
        
        x_attn = self.proj(x_attn)
        x_attn = self.proj_drop(x_attn)
        
        return x_attn


class SSTBlock(nn.Module):
    """
    Spectral-Spatial Transformer Block
    Combines spectral and spatial self-attention with feed-forward network
    """
    def __init__(self, dim, num_bands, num_heads=8, window_size=8, 
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., 
                 drop_path=0., norm_layer=None):
        super().__init__()
        norm_layer = norm_layer or LayerNormChannel3d
        
        self.dim = dim
        self.num_bands = num_bands
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        
        # Normalization layers
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)
        self.norm3 = norm_layer(dim)
        
        # Spectral attention
        self.spectral_attn = SpectralSelfAttention(
            dim=dim,
            num_bands=num_bands,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop
        )
        
        # Spatial attention
        self.spatial_attn = SpatialSelfAttention(
            dim=dim,
            num_heads=num_heads,
            window_size=window_size,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop
        )
        
        # Drop path for stochastic depth
        self.drop_path = nn.Dropout(drop_path) if drop_path > 0. else nn.Identity()
        
        # Feed-forward network (reuse existing GDFN)
        self.ffn = GDFN(dim, ffn_expansion_factor=mlp_ratio, bias=False)
        
    def forward(self, x):
        """
        Args:
            x: (B, C, D, H, W)
        Returns:
            x: (B, C, D, H, W)
        """
        # Spectral attention with residual
        x = x + self.drop_path(self.spectral_attn(self.norm1(x)))
        
        # Spatial attention with residual
        x = x + self.drop_path(self.spatial_attn(self.norm2(x)))
        
        # Feed-forward with residual
        x = x + self.drop_path(self.ffn(self.norm3(x)))
        
        return x


class SSTStage(nn.Module):
    """
    SST Stage: Multiple SST blocks with optional downsampling
    Replaces SwinTransformerStage3D
    """
    def __init__(self, dim, num_bands, depth, num_heads=8, window_size=8,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
                 drop_path_rate=0., norm_layer=None, downsample=None):
        super().__init__()
        norm_layer = norm_layer or LayerNormChannel3d
        
        self.dim = dim
        self.depth = depth
        
        # Stochastic depth decay rule
        dpr = [drop_path_rate * (i / (depth - 1)) if depth > 1 else drop_path_rate 
               for i in range(depth)]
        
        # Build SST blocks
        self.blocks = nn.ModuleList([
            SSTBlock(
                dim=dim,
                num_bands=num_bands,
                num_heads=num_heads,
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=dpr[i],
                norm_layer=norm_layer
            )
            for i in range(depth)
        ])
        
        # Downsampling layer
        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=nn.LayerNorm)
        else:
            self.downsample = None
    
    def forward(self, x):
        """
        Args:
            x: (B, C, D, H, W)
        Returns:
            x: (B, C, D, H, W) - features before downsampling
            x_down: (B, 2*C, D/2, H/2, W/2) - downsampled features (if downsample exists)
        """
        # Pass through SST blocks
        for blk in self.blocks:
            if self.training:
                x = checkpoint(blk, x, use_reentrant=False)
            else:
                x = blk(x)
        
        # Downsample if needed
        if self.downsample is not None:
            x_down = self.downsample(x)
            return x, x_down
        else:
            return x, x

# ----------------------------
# SST Blocks
# ----------------------------

class SpectralSelfModulatingResidualBlock(nn.Module):
    """Spectral Self-Modulating Residual Block (SSMRB) for adaptive feature transformation."""
    def __init__(self, dim, ffn_expand=2, drop=0.25, drop_factor=1.0):
        super().__init__()
        hidden = dim * ffn_expand

        # Main FFN path (vanilla-like)
        self.ffn_pw1 = nn.Conv3d(dim, hidden * 2, kernel_size=1, bias=True)
        self.ffn_dw = depthwise_conv3d(hidden * 2, kernel_size=3, padding=1)
        self.act = nn.GELU()
        self.ffn_pw2 = nn.Conv3d(hidden, dim, kernel_size=1, bias=True)

        # Self-modulation branch: Generate gamma (scale) and beta (shift) using spectral-adjacent conv
        # Kernel (3,1,1) captures adjacent spectral bands; depthwise for efficiency
        self.mod_gamma = nn.Sequential(
            depthwise_conv3d(dim, kernel_size=(3,1,1), padding=(1,0,0)),
            nn.Sigmoid()  # For scaling (0 to 1)
        )
        self.mod_beta = nn.Sequential(
            depthwise_conv3d(dim, kernel_size=(3,1,1), padding=(1,0,0)),
            nn.Tanh()  # For shifting (-1 to 1)
        )

        # Normalization and dropout
        self.norm = LayerNormChannel3d(dim)
        self.dropout = AdaptiveDropout3d(drop, factor=drop_factor)  # Adaptive dropout

        # Layer scale for residual contribution
        self.gamma_res = nn.Parameter(torch.ones(1, dim, 1, 1, 1) * 1e-4)

    def _ffn_path(self, x):
        """Vanilla FFN computation."""
        x2 = self.ffn_pw1(x)
        x2 = self.ffn_dw(x2)
        a, b = torch.chunk(x2, 2, dim=1)
        x2 = self.act(a) * b
        x2 = self.ffn_pw2(x2)
        return x2

    def forward(self, x):
        # Normalize input
        x_norm = self.norm(x)

        # Compute main FFN path
        ffn_out = self._ffn_path(x_norm)

        # Compute self-modulation parameters from input (using adjacent spectral info)
        gamma = self.mod_gamma(x_norm)  # Scale
        beta = self.mod_beta(x_norm)    # Shift

        # Apply modulation: gamma * ffn_out + beta
        modulated = gamma * ffn_out + beta

        # Apply dropout and scale
        modulated = self.dropout(modulated) * self.gamma_res

        # Residual connection: x + modulated
        return x + modulated


# ----------------------------
# New Classes Added (for Restormer Integration in Architecture)
# ----------------------------
class LayerNorm3d(nn.Module):
    """3D LayerNorm for channel normalization"""
    def __init__(self, num_channels, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x):
        mu = x.mean(dim=1, keepdim=True)
        var = x.var(dim=1, keepdim=True, unbiased=False)
        inv_std = torch.rsqrt(var + self.eps)
        out = (x - mu) * inv_std * self.weight.view(1, -1, 1, 1, 1) + self.bias.view(1, -1, 1, 1, 1)
        return out

class GDFN(nn.Module):
    """Gated-Dconv Feed-Forward Network adapted to 3D, focusing on spatial"""
    def __init__(self, dim, ffn_expansion_factor=2.66, bias=False):
        super().__init__()
        hidden = int(dim * ffn_expansion_factor)
        self.project_in = nn.Conv3d(dim, hidden * 2, kernel_size=1, bias=bias)
        self.dwconv = nn.Conv3d(hidden * 2, hidden * 2, kernel_size=(1, 3, 3), stride=1, padding=(0, 1, 1), groups=hidden * 2, bias=bias)  # Spatial focus
        self.project_out = nn.Conv3d(hidden, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        x = self.project_in(x)
        x1, x2 = self.dwconv(x).chunk(2, dim=1)
        x = F.gelu(x1) * x2
        x = self.project_out(x)
        return x

class MDTA(nn.Module):
    """Multi-Dconv Head Transposed Attention adapted to 3D, focusing on spatial"""
    def __init__(self, dim, num_heads, bias=False):
        super().__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        self.qkv = nn.Conv3d(dim, dim * 3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv3d(dim * 3, dim * 3, kernel_size=(1, 3, 3), stride=1, padding=(0, 1, 1), groups=dim * 3, bias=bias)  # Spatial focus
        self.project_out = nn.Conv3d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        qkv = self.qkv_dwconv(self.qkv(x))
        q, k, v = qkv.chunk(3, dim=1)
        b, c, d, h, w = q.shape
        q = rearrange(q, 'b (head cc) d h w -> b head cc (d h w)', head=self.num_heads, cc=c // self.num_heads)
        k = rearrange(k, 'b (head cc) d h w -> b head cc (d h w)', head=self.num_heads)
        v = rearrange(v, 'b (head cc) d h w -> b head cc (d h w)', head=self.num_heads)
        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)
        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)
        out = (attn @ v)
        out = rearrange(out, 'b head cc (d h w) -> b (head cc) d h w', head=self.num_heads, d=d, h=h, w=w)
        out = self.project_out(out)
        return out

class RestormerBlock(nn.Module):
    """Restormer Transformer Block adapted to 3D"""
    def __init__(self, dim, num_heads=4, ffn_expansion_factor=2.66, bias=False):
        super().__init__()
        self.norm1 = LayerNorm3d(dim)
        self.attn = MDTA(dim, num_heads, bias)
        self.norm2 = LayerNorm3d(dim)
        self.ffn = GDFN(dim, ffn_expansion_factor, bias)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x
# ----------------------------
# New Classes Added (for Restormer Integration in Architecture)
# ----------------------------

# ----------------------------
# Modified FusedBottleneck (for Restormer Integration in Architecture)
# ----------------------------
# ============================================================
# INSERT THESE NEW CLASSES BEFORE FusedBottleneck (around line ~680)
# These are ADDITIONS, not replacements
# ============================================================

class PositionalEncoding3D(nn.Module):
    """
    FIXED: Lightweight 3D Positional Encoding
    
    Instead of storing full (C, D, H, W) tensor, we use:
    1. Separate 1D embeddings for each dimension (much smaller)
    2. Broadcast and combine at runtime
    
    Memory: O(C*D + C*H + C*W) instead of O(C*D*H*W)
    Example: 64*128 + 64*256 + 64*256 = 41K params vs 537M params!
    """
    def __init__(self, channels, max_d=128, max_h=256, max_w=256):
        super().__init__()
        self.channels = channels
        self.max_d = max_d
        self.max_h = max_h
        self.max_w = max_w
        
        # Separate 1D positional embeddings for each dimension
        # These will be broadcast and combined
        self.pos_embed_d = nn.Parameter(torch.zeros(1, channels, max_d, 1, 1))
        self.pos_embed_h = nn.Parameter(torch.zeros(1, channels, 1, max_h, 1))
        self.pos_embed_w = nn.Parameter(torch.zeros(1, channels, 1, 1, max_w))
        
        # Initialize with small random values
        nn.init.trunc_normal_(self.pos_embed_d, std=0.02)
        nn.init.trunc_normal_(self.pos_embed_h, std=0.02)
        nn.init.trunc_normal_(self.pos_embed_w, std=0.02)
        
        # Learnable scaling factors for each dimension
        self.scale_d = nn.Parameter(torch.ones(1))
        self.scale_h = nn.Parameter(torch.ones(1))
        self.scale_w = nn.Parameter(torch.ones(1))
        
    def forward(self, x):
        """
        Args:
            x: (B, C, D, H, W)
        Returns:
            x + positional encoding: (B, C, D, H, W)
        """
        B, C, D, H, W = x.shape
        
        # Slice and broadcast each dimension
        pe_d = self.pos_embed_d[:, :, :D, :, :] * self.scale_d
        pe_h = self.pos_embed_h[:, :, :, :H, :] * self.scale_h
        pe_w = self.pos_embed_w[:, :, :, :, :W] * self.scale_w
        
        # Combine positional encodings (broadcasting happens automatically)
        pe = pe_d + pe_h + pe_w  # Shape: (1, C, D, H, W)
        
        return x + pe


class CrossSpectralSpatialAttention(nn.Module):
    """
    NEW CLASS: Cross-attention between spectral and spatial features
    Enables joint spectral-spatial modeling instead of separate processing
    """
    def __init__(self, dim, num_heads=8, qkv_bias=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        
        # Separate projections for cross-attention paths
        self.q_spectral = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv_spatial = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.q_spatial = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv_spectral = nn.Linear(dim, dim * 2, bias=qkv_bias)
        
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_spectral = nn.Linear(dim, dim)
        self.proj_spatial = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        
        # Gated fusion
        self.gate = nn.Sequential(
            nn.Linear(dim * 2, dim),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        """(B, C, D, H, W) -> (B, C, D, H, W)"""
        B, C, D, H, W = x.shape
        
        # Path 1: Spectral features with spatial context
        x_spectral = x.permute(0, 3, 4, 2, 1).reshape(B * H * W, D, C)
        q_spec = self.q_spectral(x_spectral)
        q_spec = q_spec.reshape(B * H * W, D, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        
        q_spec_scaled = q_spec * self.scale
        attn_spec = torch.matmul(q_spec_scaled, q_spec_scaled.transpose(-2, -1))
        attn_spec = F.softmax(attn_spec, dim=-1)
        attn_spec = self.attn_drop(attn_spec)
        
        out_spec = torch.matmul(attn_spec, q_spec)
        out_spec = out_spec.transpose(1, 2).reshape(B * H * W, D, C)
        out_spec = self.proj_spectral(out_spec)
        out_spec = self.proj_drop(out_spec)
        out_spectral = out_spec.reshape(B, H, W, D, C).permute(0, 4, 3, 1, 2)
        
        # Path 2: Spatial features with spectral context
        x_spatial_q = x.permute(0, 2, 3, 4, 1).reshape(B * D, H * W, C)
        q_spat = self.q_spatial(x_spatial_q)
        q_spat = q_spat.reshape(B * D, H * W, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        
        q_spat_scaled = q_spat * self.scale
        attn_spat = torch.matmul(q_spat_scaled, q_spat_scaled.transpose(-2, -1))
        attn_spat = F.softmax(attn_spat, dim=-1)
        attn_spat = self.attn_drop(attn_spat)
        
        out_spat = torch.matmul(attn_spat, q_spat)
        out_spat = out_spat.transpose(1, 2).reshape(B * D, H * W, C)
        out_spat = self.proj_spatial(out_spat)
        out_spat = self.proj_drop(out_spat)
        out_spatial = out_spat.reshape(B, D, H, W, C).permute(0, 4, 1, 2, 3)
        
        # Gated fusion
        concat_features = torch.cat([out_spectral, out_spatial], dim=1)
        gate_input = F.adaptive_avg_pool3d(concat_features, 1).squeeze(-1).squeeze(-1).squeeze(-1)
        gate = self.gate(gate_input).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        
        fused = gate * out_spectral + (1 - gate) * out_spatial
        return fused


class EnhancedBottleneck(nn.Module):
    """
    NEW CLASS: SOTA-level bottleneck with cross-attention and joint modeling
    Will be used in the main architecture
    """
    def __init__(self, dim, num_heads=8, mlp_ratio=4.):
        super().__init__()
        self.dim = dim
        
        self.norm1 = LayerNormChannel3d(dim)
        self.norm2 = LayerNormChannel3d(dim)
        self.norm3 = LayerNormChannel3d(dim)
        
        # Cross-attention for spectral-spatial joint modeling
        self.cross_attn = CrossSpectralSpatialAttention(
            dim=dim,
            num_heads=num_heads,
            qkv_bias=True,
            attn_drop=0.1,
            proj_drop=0.1
        )
        
        # Non-local attention for global context
        self.non_local = nn.Sequential(
            nn.Conv3d(dim, dim // 2, 1),
            nn.GELU(),
            nn.Conv3d(dim // 2, dim, 1),
            nn.Sigmoid()
        )
        
        # Keep using existing GDFN and SSMRB
        self.ffn = GDFN(dim, ffn_expansion_factor=mlp_ratio, bias=False)
        self.spectral_refine = SpectralSelfModulatingResidualBlock(
            dim, ffn_expand=2, drop=0.1, drop_factor=1.0
        )
        
        # Gated fusion
        self.fusion_gate = nn.Sequential(
            nn.Conv3d(dim * 3, dim, 1),
            nn.GELU(),
            nn.Conv3d(dim, dim, 1),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        """(B, C, D, H, W) -> (B, C, D, H, W)"""
        identity = x
        
        # Path 1: Cross-attention
        cross_out = self.cross_attn(self.norm1(x))
        x = x + cross_out
        
        # Path 2: Non-local
        non_local_weight = self.non_local(self.norm2(x))
        non_local_out = x * non_local_weight
        x = x + non_local_out
        
        # Path 3: FFN + Spectral refinement
        ffn_out = self.ffn(self.norm3(x))
        spectral_out = self.spectral_refine(x + ffn_out)
        
        # Gated fusion
        fusion_input = torch.cat([cross_out, non_local_out, spectral_out], dim=1)
        fusion_weight = self.fusion_gate(fusion_input)
        
        out = identity + fusion_weight * (cross_out + non_local_out + spectral_out) / 3.0
        
        return out

# ----------------------------
# Modified FusedBottleneck (for Restormer Integration in Architecture)
# ----------------------------
        
class FusedBottleneck(nn.Module):
    """
    IMPROVED FusedBottleneck - SAME NAME, better implementation
    Now uses stacked EnhancedBottleneck blocks for SOTA performance
    Drop-in replacement - same API, just calls different internals
    """
    def __init__(self, base_dim, window_sizes=[2, 4]):
        super().__init__()
        # Calculate actual dim from base_dim (maintains compatibility)
        # Your original: dim = base_dim * 4
        # But you call it with base_dim * 2, so actual dim is base_dim * 8
        dim = base_dim * 4  # This gives base_dim * 8 when called with base_dim * 2
        
        # Use stacked enhanced bottleneck blocks instead of old approach
        self.blocks = nn.ModuleList([
            EnhancedBottleneck(dim, num_heads=8, mlp_ratio=4.),
            #EnhancedBottleneck(dim, num_heads=8, mlp_ratio=4.)
        ])
    
    def forward(self, x):
        """
        SAME API: (B, C, D, H, W) -> (B, C, D, H, W)
        """
        for block in self.blocks:
            x = block(x)
        return x

# ----------------------------
# Modified FusedBottleneck (for Restormer Integration in Architecture)
# ----------------------------

# ----------------------------
# Efficient Loss Function
# ----------------------------

# ----------------------------
# Modified MemoryEfficientLoss (for Spatial-Focused Loss Improvements)
# ----------------------------
class MemoryEfficientLoss(nn.Module):
    """Lightweight but effective loss function with FIXED weights and tensor handling"""
    def __init__(self, device='cuda', mse_weight=1.0, l1_weight=1.0, sam_weight=0.5, edge_weight=0.2):
        super().__init__()
        self.mse = nn.MSELoss()
        self.l1 = nn.L1Loss()
        self.device = device
        self.mse_weight = mse_weight
        self.l1_weight = l1_weight
        self.sam_weight = sam_weight
        self.edge_weight = edge_weight

    def forward(self, pred, target, epoch=None):
        # FIXED: Ensure both tensors have same shape
        if pred.shape != target.shape:
            # If shapes don't match, interpolate pred to match target
            if pred.dim() == 5 and target.dim() == 5:
                pred = F.interpolate(pred, size=target.shape[2:], mode='trilinear', align_corners=False)
            elif pred.dim() == 4 and target.dim() == 4:
                pred = F.interpolate(pred, size=target.shape[2:], mode='bilinear', align_corners=False)

        # Main losses
        mse_loss = self.mse(pred, target)
        l1_loss = self.l1(pred, target)

        # FIXED: SAM calculation with proper tensor handling
        eps = 1e-8

        # Handle both 4D and 5D tensors
        if pred.dim() == 5:  # (B, C, D, H, W)
            B, C, D, H, W = pred.shape
            pred_flat = pred.reshape(B, C, D * H * W)  # (B, C, D*H*W)
            target_flat = target.reshape(B, C, D * H * W)  # (B, C, D*H*W)

            # Normalize along channel dimension (spectral bands)
            pred_norm = F.normalize(pred_flat, dim=1, eps=eps)
            target_norm = F.normalize(target_flat, dim=1, eps=eps)

            # Compute cosine similarity along spectral dimension
            cos_sim = torch.sum(pred_norm * target_norm, dim=1)  # (B, D*H*W)

        elif pred.dim() == 4:  # (B, D, H, W) - spectral first
            B, D, H, W = pred.shape
            pred_flat = pred.reshape(B, D, H * W)  # (B, D, H*W)
            target_flat = target.reshape(B, D, H * W)  # (B, D, H*W)

            # Normalize along spectral dimension
            pred_norm = F.normalize(pred_flat, dim=1, eps=eps)
            target_norm = F.normalize(target_flat, dim=1, eps=eps)

            # Compute cosine similarity along spectral dimension
            cos_sim = torch.sum(pred_norm * target_norm, dim=1)  # (B, H*W)

        else:
            # Fallback for other dimensions
            pred_flat = pred.flatten(start_dim=1)
            target_flat = target.flatten(start_dim=1)
            pred_norm = F.normalize(pred_flat, dim=1, eps=eps)
            target_norm = F.normalize(target_flat, dim=1, eps=eps)
            cos_sim = torch.sum(pred_norm * target_norm, dim=1)

        cos_sim = torch.clamp(cos_sim, -1 + eps, 1 - eps)
        sam_loss = torch.mean(1 - cos_sim)

        # FIXED: Edge loss with proper spatial dimension handling
        def spatial_gradient(x):
            if x.dim() == 5:  # (B, C, D, H, W)
                grad_h = torch.abs(x[:, :, :, 1:, :] - x[:, :, :, :-1, :])
                grad_w = torch.abs(x[:, :, :, :, 1:] - x[:, :, :, :, :-1])
            elif x.dim() == 4:  # (B, D, H, W)
                grad_h = torch.abs(x[:, :, 1:, :] - x[:, :, :-1, :])
                grad_w = torch.abs(x[:, :, :, 1:] - x[:, :, :, :-1])
            else:
                return 0, 0
            return grad_h.mean(), grad_w.mean()

        pred_grad_h, pred_grad_w = spatial_gradient(pred)
        target_grad_h, target_grad_w = spatial_gradient(target)
        edge_loss = abs(pred_grad_h - target_grad_h) + abs(pred_grad_w - target_grad_w)

        # Static combination
        total_loss = (
            self.mse_weight * mse_loss +
            self.l1_weight * l1_loss +
            self.sam_weight * sam_loss +
            self.edge_weight * edge_loss
        )

        return total_loss
# ----------------------------
# Modified MemoryEfficientLoss (for Spatial-Focused Loss Improvements)
# ----------------------------

# ----------------------------
# Memory-Efficient U-Net
# ----------------------------


# ----------------------------
# Memory-Efficient U-Net
# ----------------------------
class MemoryOptimizedUNet(nn.Module):
    """
    SST-based U-Net for HSI Denoising (MODIFIED)
    - 4 hierarchical stages with SST blocks instead of Swin
    - Spectral-aware attention at all levels
    - Deep supervision at each decoder stage
    """
    def __init__(self, in_channels=1, base_dim=48, window_sizes=[4, 8, 16], num_bands=64):
        super().__init__()
        self.base_dim = base_dim
        self.in_channels = in_channels
        self.num_bands = num_bands
        
        # Initial projection
        self.patch_embed = nn.Conv3d(in_channels, base_dim, kernel_size=3, padding=1)
        self.pos_embed_init = PositionalEncoding3D(base_dim, 128, 256, 256)
        
        # ENCODER: 4 SST stages with [2, 2, 6, 2] depth
        # Stage 1: base_dim, shallow features
        self.enc_stage1 = SSTStage(
            dim=base_dim,
            num_bands=num_bands,
            depth=2,
            num_heads=8,
            window_size=8,
            mlp_ratio=4.,
            drop=0.0,
            attn_drop=0.0,
            drop_path_rate=0.05,
            downsample=PatchMerging3D
        )
        
        # Stage 2: base_dim*2, intermediate features
        self.enc_stage2 = SSTStage(
            dim=base_dim * 2,
            num_bands=num_bands // 2,  # Bands halved after merging
            depth=2,
            num_heads=8,
            window_size=8,
            mlp_ratio=4.,
            drop=0.0,
            attn_drop=0.0,
            drop_path_rate=0.1,
            downsample=PatchMerging3D
        )
        
        # Stage 3: base_dim*4, deep features (6 blocks)
        self.enc_stage3 = SSTStage(
            dim=base_dim * 4,
            num_bands=num_bands // 4,
            depth=6,
            num_heads=16,
            window_size=4,
            mlp_ratio=4.,
            drop=0.0,
            attn_drop=0.0,
            drop_path_rate=0.15,
            downsample=PatchMerging3D
        )
        
        # Stage 4 (deepest): base_dim*8, bottleneck
        self.enc_stage4 = SSTStage(
            dim=base_dim * 8,
            num_bands=num_bands // 8,
            depth=2,
            num_heads=16,
            window_size=2,
            mlp_ratio=4.,
            drop=0.0,
            attn_drop=0.0,
            drop_path_rate=0.2,
            downsample=None
        )

        self.pe_enc1 = PositionalEncoding3D(base_dim, 128, 256, 256)
        self.pe_enc2 = PositionalEncoding3D(base_dim * 2, 64, 128, 128)
        self.pe_enc3 = PositionalEncoding3D(base_dim * 4, 32, 64, 64)
        self.pe_enc4 = PositionalEncoding3D(base_dim * 8, 16, 32, 32)
        
        # BOTTLENECK: Keep your custom FusedBottleneck (works well with SST)
        self.spectral_attention = SpectralAttentionModule(base_dim * 8, num_heads=8)
        self.bottleneck_fusion = FusedBottleneck(base_dim * 2, window_sizes=window_sizes)
        
        # DECODER: 4 SST stages matching encoder
        # Stage 3 decoder
        self.dec_stage3_up = PatchExpanding3D(dim=base_dim * 8)
        self.dec_stage3 = SSTStage(
            dim=base_dim * 4,
            num_bands=num_bands // 4,
            depth=6,
            num_heads=16,
            window_size=4,
            mlp_ratio=4.,
            drop=0.0,
            attn_drop=0.0,
            drop_path_rate=0.1,
            downsample=None
        )
        
        # Stage 2 decoder
        self.dec_stage2_up = PatchExpanding3D(dim=base_dim * 4)
        self.dec_stage2 = SSTStage(
            dim=base_dim * 2,
            num_bands=num_bands // 2,
            depth=2,
            num_heads=8,
            window_size=8,
            mlp_ratio=4.,
            drop=0.0,
            attn_drop=0.0,
            drop_path_rate=0.08,
            downsample=None
        )
        
        # Stage 1 decoder
        self.dec_stage1_up = PatchExpanding3D(dim=base_dim * 2)
        self.dec_stage1 = SSTStage(
            dim=base_dim,
            num_bands=num_bands,
            depth=2,
            num_heads=8,
            window_size=8,
            mlp_ratio=4.,
            drop=0.0,
            attn_drop=0.0,
            drop_path_rate=0.02,
            downsample=None
        )
        
        # DEEP SUPERVISION: Auxiliary outputs at each decoder stage
        self.deep_sup3 = nn.Conv3d(base_dim * 4, in_channels, 1)
        self.deep_sup2 = nn.Conv3d(base_dim * 2, in_channels, 1)
        self.deep_sup1 = nn.Conv3d(base_dim, in_channels, 1)
        
        # Final reconstruction
        self.final_conv = nn.Sequential(
            nn.Conv3d(base_dim, base_dim // 2, 3, padding=1),
            nn.GELU(),
            nn.Conv3d(base_dim // 2, in_channels, 1),
        )
        
        # Global residual
        self.global_residual = nn.Conv3d(in_channels, in_channels, 1)
        
        # Deep supervision flag
        self.use_deep_supervision = True

    def _align_tensors(self, x, target_size):
        if x.shape[2:] != target_size:
            x = F.interpolate(x, size=target_size, mode='trilinear', align_corners=False)
        return x

    def forward(self, x, return_deep_sup=False):
        # Handle input shape
        original_was_4d = False
        if x.dim() == 4:
            original_was_4d = True
            x = x.unsqueeze(1)
        elif x.dim() == 5 and x.shape[1] != 1:
            if x.shape[2] == 1:
                x = x.transpose(1, 2)
        
        original_size = x.shape[2:]
        input_residual = self.global_residual(x)
        
        # Initial embedding
        x = self.patch_embed(x)
        x = self.pos_embed_init(x)
        
        # ENCODER (4 SST stages)
        e1, e1_down = self.enc_stage1(self.pe_enc1(x))       # Skip 1
        e2, e2_down = self.enc_stage2(self.pe_enc2(e1_down))  # Skip 2
        e3, e3_down = self.enc_stage3(self.pe_enc3(e2_down))  # Skip 3
        e4, _ = self.enc_stage4(self.pe_enc4(e3_down))        # Deepest features
        
        # BOTTLENECK: Hybrid attention
        b = self.spectral_attention(e4)  # Add spectral attention
        b = self.bottleneck_fusion(b)   # Your custom fusion
        
        # DECODER with deep supervision
        deep_outputs = []
        
        # Decoder stage 3
        d3 = self.dec_stage3_up(b)
        d3 = self._align_tensors(d3, e3.shape[2:])
        d3 = d3 + e3  # Skip connection
        d3, _ = self.dec_stage3(d3)
        if self.training and self.use_deep_supervision:
            sup3 = self.deep_sup3(d3)
            sup3 = self._align_tensors(sup3, original_size)
            deep_outputs.append(sup3)
        
        # Decoder stage 2
        d2 = self.dec_stage2_up(d3)
        d2 = self._align_tensors(d2, e2.shape[2:])
        d2 = d2 + e2
        d2, _ = self.dec_stage2(d2)
        if self.training and self.use_deep_supervision:
            sup2 = self.deep_sup2(d2)
            sup2 = self._align_tensors(sup2, original_size)
            deep_outputs.append(sup2)
        
        # Decoder stage 1
        d1 = self.dec_stage1_up(d2)
        d1 = self._align_tensors(d1, e1.shape[2:])
        d1 = d1 + e1
        d1, _ = self.dec_stage1(d1)
        if self.training and self.use_deep_supervision:
            sup1 = self.deep_sup1(d1)
            sup1 = self._align_tensors(sup1, original_size)
            deep_outputs.append(sup1)
        
        # Final reconstruction
        out = self.final_conv(d1)
        out = self._align_tensors(out, original_size)
        input_residual = self._align_tensors(input_residual, original_size)
        out = out + input_residual
        
        # Return format handling
        if original_was_4d and out.shape[1] == 1:
            out = out.squeeze(1)
            if self.training and self.use_deep_supervision:
                deep_outputs = [o.squeeze(1) for o in deep_outputs]
        
        if return_deep_sup and self.training:
            return out, deep_outputs
        return out

# ----------------------------
# Memory-Efficient U-Net
# ----------------------------

# ----------------------------
# Metric Calculation Functions
# ----------------------------
def calculate_psnr(pred, target):
    """Calculate Peak Signal-to-Noise Ratio"""
    mse = torch.mean((pred - target) ** 2)
    return 10 * torch.log10(1.0 / (mse + 1e-8)).item()

def calculate_ssim(pred_np, target_np):
    """Calculate SSIM across all spectral bands"""
    if pred_np.ndim == 3:  # (D, H, W)
        D, H, W = pred_np.shape
        ssim_vals = []
        for d in range(D):
            try:
                ssim_val = compare_ssim(pred_np[d], target_np[d], data_range=1.0)
                ssim_vals.append(ssim_val)
            except Exception:
                ssim_vals.append(0.5)  # Fallback
        return np.mean(ssim_vals)
    return compare_ssim(pred_np, target_np, data_range=1.0)

def calculate_sam(pred_np, target_np):
    """Calculate Spectral Angle Mapper"""
    eps = 1e-8
    if pred_np.ndim == 3:  # (D, H, W)
        pred_flat = pred_np.reshape(pred_np.shape[0], -1)
        target_flat = target_np.reshape(target_np.shape[0], -1)
        dot = np.sum(pred_flat * target_flat, axis=0)
        norm_pred = np.linalg.norm(pred_flat, axis=0) + eps
        norm_target = np.linalg.norm(target_flat, axis=0) + eps
        cos_angle = np.clip(dot / (norm_pred * norm_target), -1, 1)
        angles = np.arccos(cos_angle)
        return np.mean(angles)
    dot = np.sum(pred_np * target_np)
    norm_pred = np.linalg.norm(pred_np) + eps
    norm_target = np.linalg.norm(target_np) + eps
    cos_angle = np.clip(dot / (norm_pred * norm_target), -1, 1)
    return np.arccos(cos_angle)

# ----------------------------
# Test Data Loading
# ----------------------------
def diagnose_dataset_structure(file_path):
    """Diagnose the structure of a .mat file to understand data organization"""
    if not sio:
        print("scipy not available")
        return None
        
    try:
        mat = sio.loadmat(file_path)
        print(f"\n=== Diagnosing: {os.path.basename(file_path)} ===")
        
        # Print all keys
        keys = [k for k in mat.keys() if not k.startswith('__')]
        print(f"Available keys: {keys}")
        
        for key in keys:
            data = mat[key]
            print(f"Key '{key}': shape={data.shape}, dtype={data.dtype}")
            print(f"  Value range: [{data.min():.4f}, {data.max():.4f}]")
            print(f"  Mean: {data.mean():.4f}, Std: {data.std():.4f}")
            
            # Check if it looks like spectral data
            if len(data.shape) == 3:
                print(f"  3D data - possible interpretations:")
                print(f"    As (H, W, Bands): {data.shape}")
                print(f"    As (Bands, H, W): {data.shape}")
                
                # Check which dimension might be spectral
                dim_variances = [data.var(axis=i).mean() for i in range(3)]
                spectral_dim = np.argmax(dim_variances)
                print(f"  Dimension variances: {dim_variances}")
                print(f"  Likely spectral dimension: {spectral_dim} (highest variance)")
                
        return mat, keys
    except Exception as e:
        print(f"Error diagnosing {file_path}: {e}")
        return None, []
        
def load_test_data(test_dir, target_bands=31, max_files=None, test_crop_size=512):
    """Load test data with ICVL-specific preprocessing: 512×512×31 crops"""
    print(f"Loading test data from: {test_dir}")

    if not os.path.exists(test_dir):
        print(f"Test directory not found: {test_dir}")
        return []

    mat_files = glob.glob(os.path.join(test_dir, '*.mat'))
    if max_files:
        mat_files = mat_files[:max_files]

    print(f"Found {len(mat_files)} test files\n")

    test_data = []
    failed_files = []
    
    for file_path in tqdm(mat_files, desc="Loading test files", ncols=80):
        try:
            # Try loading with h5py first
            with h5py.File(file_path, 'r') as f:
                if 'rad' in f:
                    data_key = 'rad'
                    cube_raw = np.array(f[data_key]).astype(np.float32)
                else:
                    keys = [k for k in f.keys() if not k.startswith('#')]
                    if keys:
                        data_key = keys[0]
                        cube_raw = np.array(f[data_key]).astype(np.float32)
                    else:
                        failed_files.append((os.path.basename(file_path), "No valid keys found"))
                        continue
                        
        except Exception as h5_error:
            if sio:
                try:
                    mat = sio.loadmat(file_path)
                    keys = [k for k in mat.keys() if not k.startswith('__')]
                    data_key = keys[0] if keys else None
                    if data_key:
                        cube_raw = mat[data_key].astype(np.float32)
                    else:
                        failed_files.append((os.path.basename(file_path), "No valid keys in scipy load"))
                        continue
                except Exception as scipy_error:
                    failed_files.append((os.path.basename(file_path), f"Both loaders failed"))
                    continue
            else:
                failed_files.append((os.path.basename(file_path), "h5py failed, scipy unavailable"))
                continue

        try:
            cube_original = cube_raw.copy()

            # ICVL format handling: (D, H, W) -> transpose to (H, W, D)
            if cube_raw.ndim == 3:
                if cube_raw.shape[0] < min(cube_raw.shape[1:]):  # (D, H, W)
                    cube_raw = cube_raw.transpose(1, 2, 0)  # -> (H, W, D)
                    cube_original = cube_original.transpose(1, 2, 0)
            elif cube_raw.ndim == 2:
                cube_raw = cube_raw[np.newaxis, np.newaxis, ...]
                cube_original = cube_original[np.newaxis, np.newaxis, ...]

            if cube_raw.ndim != 3:
                failed_files.append((os.path.basename(file_path), f"Invalid dimensions: {cube_raw.ndim}"))
                continue

            H, W, D = cube_raw.shape

            # CRITICAL: Center crop to 512×512 for testing (ICVL protocol)
            if H > test_crop_size or W > test_crop_size:
                start_h = (H - test_crop_size) // 2
                start_w = (W - test_crop_size) // 2
                cube_raw = cube_raw[start_h:start_h+test_crop_size,
                                   start_w:start_w+test_crop_size, :]
                cube_original = cube_original[start_h:start_h+test_crop_size,
                                             start_w:start_w+test_crop_size, :]
                H, W = test_crop_size, test_crop_size
                #print(f"  Cropped {os.path.basename(file_path)} to {test_crop_size}×{test_crop_size}")
            elif H < test_crop_size or W < test_crop_size:
                # Pad if smaller than 512
                pad_h = max(0, test_crop_size - H)
                pad_w = max(0, test_crop_size - W)
                cube_raw = np.pad(cube_raw, ((0, pad_h), (0, pad_w), (0, 0)), mode='reflect')
                cube_original = np.pad(cube_original, ((0, pad_h), (0, pad_w), (0, 0)), mode='reflect')
                H, W = test_crop_size, test_crop_size
                print(f"  Padded {os.path.basename(file_path)} to {test_crop_size}×{test_crop_size}")

            # Select exactly 31 bands (ICVL standard)
            if D >= target_bands:
                cube_raw = cube_raw[:, :, :target_bands]
                cube_original = cube_original[:, :, :target_bands]
            else:
                failed_files.append((os.path.basename(file_path), f"Insufficient bands: {D} < {target_bands}"))
                continue

            # Store original range
            original_min = cube_original.min()
            original_max = cube_original.max()

            # MATCH TRAINING: Global min-max normalization to [0,1]
            cube = (cube_raw - cube_raw.min()) / (cube_raw.max() - cube_raw.min() + 1e-8)

            # Convert to (D, H, W) format for model
            cube = cube.transpose(2, 0, 1)

            test_data.append({
                'clean': cube,
                'original_range': (original_min, original_max),
                'filename': os.path.basename(file_path),
                'shape': cube.shape,
                'dataset_type': 'icvl'
            })
            
        except Exception as e:
            failed_files.append((os.path.basename(file_path), f"Processing error: {str(e)[:50]}"))
            continue

    # Summary
    print(f"\n{'='*60}")
    print(f"DATA LOADING SUMMARY (ICVL Protocol: {test_crop_size}×{test_crop_size}×{target_bands})")
    print(f"{'='*60}")
    print(f"✓ Successfully loaded: {len(test_data)}/{len(mat_files)} files")
    
    if failed_files:
        print(f"✗ Failed to load: {len(failed_files)} files")
        for fname, reason in failed_files[:5]:
            print(f"    - {fname}: {reason}")
        if len(failed_files) > 5:
            print(f"    ... and {len(failed_files) - 5} more")
    
    print(f"{'='*60}\n")

    return test_data
    
def add_static_noise_to_data(clean_data, noise_level=50, batch_size=5):
    """Add Gaussian noise matching TRAINING protocol exactly"""
    import gc
    
    test_samples = []
    total_samples = len(clean_data)
    num_batches = (total_samples + batch_size - 1) // batch_size
    failed_samples = []
    
    print(f"Adding Gaussian noise (σ={noise_level}) to {total_samples} samples...")
    print(f"Noise protocol: MATCH TRAINING (scale [0,1] -> [0,255], add noise, clip, rescale)")
    
    for batch_idx in range(num_batches):
        batch_start = batch_idx * batch_size
        batch_end = min(batch_start + batch_size, total_samples)
        batch_data = clean_data[batch_start:batch_end]
        
        for local_idx, data_item in enumerate(batch_data):
            global_idx = batch_start + local_idx
            filename = data_item['filename']
            
            try:
                clean_raw = data_item['clean']
                
                if np.isnan(clean_raw).any() or np.isinf(clean_raw).any():
                    raise ValueError(f"Invalid values in {filename}")

                with torch.no_grad():
                    clean_tensor = torch.from_numpy(clean_raw).float().unsqueeze(0)
                    
                    # EXACT MATCH WITH TRAINING: Paper-standard Gaussian noise
                    clean_255 = clean_tensor * 255.0
                    noise = torch.randn_like(clean_255) * noise_level
                    noisy_255 = torch.clamp(clean_255 + noise, 0, 255)
                    noisy_normalized = noisy_255 / 255.0

                    clean_np = clean_tensor[0].cpu().numpy()
                    noisy_np = noisy_normalized[0].cpu().numpy()
                    
                    del clean_tensor, clean_255, noise, noisy_255, noisy_normalized

                test_samples.append({
                    'clean': clean_np,
                    'noisy': noisy_np,
                    'noise_level': noise_level / 255.0,
                    'original_range': data_item['original_range'],
                    'filename': filename,
                    'shape': clean_np.shape
                })
                
            except Exception as e:
                failed_samples.append((global_idx + 1, filename, str(e)))
                continue
        
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print(f"✓ Successfully processed: {len(test_samples)}/{total_samples} samples\n")
    
    return test_samples

    
# ----------------------------
# Testing and Visualization Functions
# ----------------------------
def test_model_comprehensive(model, test_samples, device, patch_size=64):
    """Comprehensive model testing with proper patch handling - SILENT VERSION"""
    model.eval()
    results = []

    print("\n" + "="*60)
    print("TESTING MODEL ON ALL SAMPLES")
    print("="*60)

    with torch.no_grad():
        for i, sample in enumerate(tqdm(test_samples, desc="Testing samples", ncols=80)):
            try:
                clean_np = sample['clean']
                noisy_np = sample['noisy']

                # Create 4D tensors (B, D, H, W)
                clean_tensor = torch.from_numpy(clean_np).unsqueeze(0).float().to(device)
                noisy_tensor = torch.from_numpy(noisy_np).unsqueeze(0).float().to(device)

                B, D, H, W = clean_tensor.shape

                # Process based on image size
                if H <= patch_size and W <= patch_size:
                    with torch.cuda.amp.autocast():
                        output_tensor = model(noisy_tensor)
                else:
                    # Patch-based processing for large images
                    output_tensor = torch.zeros_like(clean_tensor)
                    
                    for h_start in range(0, H, patch_size):
                        for w_start in range(0, W, patch_size):
                            h_end = min(h_start + patch_size, H)
                            w_end = min(w_start + patch_size, W)
                            
                            patch_noisy = noisy_tensor[:, :, h_start:h_end, w_start:w_end]
                            
                            # Pad to patch_size if needed
                            pad_h = patch_size - (h_end - h_start)
                            pad_w = patch_size - (w_end - w_start)
                            if pad_h > 0 or pad_w > 0:
                                patch_noisy = F.pad(patch_noisy, (0, pad_w, 0, pad_h))
                            
                            with torch.cuda.amp.autocast():
                                patch_output = model(patch_noisy)
                            
                            # Remove padding
                            if pad_h > 0 or pad_w > 0:
                                patch_output = patch_output[:, :, :h_end-h_start, :w_end-w_start]
                            
                            output_tensor[:, :, h_start:h_end, w_start:w_end] = patch_output

                output_np = output_tensor[0].cpu().numpy()

                # Calculate metrics
                loss = torch.mean((output_tensor - clean_tensor) ** 2).item()
                psnr = calculate_psnr(output_tensor, clean_tensor)
                ssim = calculate_ssim(output_np, clean_np)
                sam = calculate_sam(output_np, clean_np)
                
                input_psnr = calculate_psnr(noisy_tensor, clean_tensor)
                input_ssim = calculate_ssim(noisy_np, clean_np)
                input_sam = calculate_sam(noisy_np, clean_np)

                results.append({
                    'filename': sample['filename'],
                    'noise_level': sample['noise_level'],
                    'shape': sample['shape'],
                    'input_psnr': input_psnr,
                    'output_psnr': psnr,
                    'psnr_improvement': psnr - input_psnr,
                    'input_ssim': input_ssim,
                    'ssim': ssim,
                    'ssim_improvement': ssim - input_ssim,
                    'input_sam': input_sam,
                    'sam': sam,
                    'sam_improvement': sam - input_sam     
                    'loss': loss,
                    'clean': clean_np,
                    'noisy': noisy_np,
                    'denoised': output_np,
                    'original_range': sample['original_range']
                })

                torch.cuda.empty_cache()

            except Exception as e:
                print(f"\nError processing sample {i+1} ({sample['filename']}): {e}")
                continue

    # Print summary statistics
    if results:
        print(f"\n{'='*60}")
        print(f"TESTING COMPLETED - AVERAGE SCORES")
        print(f"{'='*60}")
        print(f"Total samples tested: {len(results)}/{len(test_samples)}")
        print(f"\nAverage Metrics:")
        print(f"  • Loss:            {np.mean([r['loss'] for r in results]):.6f} ± {np.std([r['loss'] for r in results]):.6f}")
        print(f"  • Input PSNR:      {np.mean([r['input_psnr'] for r in results]):.2f} ± {np.std([r['input_psnr'] for r in results]):.2f} dB")
        print(f"  • Output PSNR:     {np.mean([r['output_psnr'] for r in results]):.2f} ± {np.std([r['output_psnr'] for r in results]):.2f} dB")
        print(f"  • PSNR Improvement: {np.mean([r['psnr_improvement'] for r in results]):.2f} ± {np.std([r['psnr_improvement'] for r in results]):.2f} dB")
        print(f"  • Input SSIM:            {np.mean([r['input_ssim'] for r in results]):.4f} ± {np.std([r['input_ssim'] for r in results]):.4f}")
        print(f"  • Output SSIM:            {np.mean([r['ssim_improvement'] for r in results]):.4f} ± {np.std([r['ssim_improvement'] for r in results]):.4f}")
        print(f"  • SSIM Improvement:            {np.mean([r['ssim'] for r in results]):.4f} ± {np.std([r['ssim'] for r in results]):.4f}")
        print(f"  • Input SAM:             {np.mean([r['input_sam'] for r in results]):.4f} ± {np.std([r['input_sam'] for r in results]):.4f}")
        print(f"  • Output SAM:             {np.mean([r['sam'] for r in results]):.4f} ± {np.std([r['sam'] for r in results]):.4f}")
        print(f"  • SAM Improvement:             {np.mean([r['sam_improvement'] for r in results]):.4f} ± {np.std([r['sam_improvement'] for r in results]):.4f}")
        
        # Best and worst cases
        sorted_by_psnr = sorted(results, key=lambda x: x['output_psnr'])
        print(f"\nPerformance Range:")
        print(f"  • Best PSNR:  {sorted_by_psnr[-1]['output_psnr']:.2f} dB ({sorted_by_psnr[-1]['filename']})")
        print(f"  • Worst PSNR: {sorted_by_psnr[0]['output_psnr']:.2f} dB ({sorted_by_psnr[0]['filename']})")
        print(f"{'='*60}\n")
    else:
        print("\nNo results generated!")

    return results
    
def create_comprehensive_visualizations(results, save_dir):
    """Create comprehensive test result visualizations"""
    print("Creating comprehensive visualizations...")

    plt.style.use('default')
    sns.set_palette("husl")

    fig, axes = plt.subplots(2, 3, figsize=(20, 12))

    noise_level = results[0]['noise_level']
    input_psnrs = [r['input_psnr'] for r in results]
    output_psnrs = [r['output_psnr'] for r in results]
    psnr_improvements = [r['psnr_improvement'] for r in results]
    ssims = [r['ssim'] for r in results]
    sams = [r['sam'] for r in results]

    # PSNR comparison
    axes[0, 0].scatter(input_psnrs, output_psnrs, alpha=0.7, s=60, color='blue')
    axes[0, 0].plot([min(input_psnrs), max(output_psnrs)], [min(input_psnrs), max(output_psnrs)], 'r--', alpha=0.5)
    axes[0, 0].axhline(y=40, color='red', linestyle='-', alpha=0.7, label='Target (40 dB)')
    axes[0, 0].set_xlabel('Input PSNR (dB)')
    axes[0, 0].set_ylabel('Output PSNR (dB)')
    axes[0, 0].set_title(f'PSNR: Input vs Output (σ={noise_level})')
    axes[0, 0].grid(True, alpha=0.3)
    axes[0, 0].legend()

    # PSNR improvement distribution
    axes[0, 1].hist(psnr_improvements, bins=15, alpha=0.7, color='green', edgecolor='black')
    axes[0, 1].axvline(x=np.mean(psnr_improvements), color='red', linestyle='--',
                      label=f'Mean: {np.mean(psnr_improvements):.2f} dB')
    axes[0, 1].set_xlabel('PSNR Improvement (dB)')
    axes[0, 1].set_ylabel('Frequency')
    axes[0, 1].set_title('PSNR Improvement Distribution')
    axes[0, 1].grid(True, alpha=0.3)
    axes[0, 1].legend()

    # SSIM distribution
    axes[0, 2].hist(ssims, bins=20, alpha=0.7, color='green', edgecolor='black')
    axes[0, 2].axvline(x=np.mean(ssims), color='red', linestyle='--',
                      label=f'Mean: {np.mean(ssims):.3f}')
    axes[0, 2].set_xlabel('SSIM')
    axes[0, 2].set_ylabel('Frequency')
    axes[0, 2].set_title('SSIM Distribution')
    axes[0, 2].grid(True, alpha=0.3)
    axes[0, 2].legend()

    # SAM distribution
    axes[1, 0].hist(sams, bins=20, alpha=0.7, color='orange', edgecolor='black')
    axes[1, 0].axvline(x=np.mean(sams), color='red', linestyle='--',
                      label=f'Mean: {np.mean(sams):.3f}')
    axes[1, 0].set_xlabel('SAM (radians)')
    axes[1, 0].set_ylabel('Frequency')
    axes[1, 0].set_title('SAM Distribution')
    axes[1, 0].grid(True, alpha=0.3)
    axes[1, 0].legend()

    # Performance vs Image Size
    image_sizes = [r['shape'][1] * r['shape'][2] for r in results]
    axes[1, 1].scatter(image_sizes, output_psnrs, alpha=0.7, s=60, color='purple')
    axes[1, 1].set_xlabel('Image Size (pixels)')
    axes[1, 1].set_ylabel('Output PSNR (dB)')
    axes[1, 1].set_title('Performance vs Image Size')
    axes[1, 1].grid(True, alpha=0.3)
    axes[1, 1].axhline(y=40, color='red', linestyle='--', alpha=0.7, label='Target (40 dB)')
    axes[1, 1].legend()

    # Summary statistics
    target_achieved = sum(1 for p in output_psnrs if p > 40)
    success_rate = target_achieved / len(output_psnrs) * 100
    axes[1, 2].axis('off')
    stats_text = f"""
TEST RESULTS SUMMARY
{'='*25}
Static Noise Testing (σ={noise_level})
Total Samples: {len(results)}
Unique Images: {len(set([r['filename'] for r in results]))}

PSNR STATISTICS:
- Mean Output PSNR: {np.mean(output_psnrs):.2f} ± {np.std(output_psnrs):.2f} dB
- Mean Input PSNR: {np.mean(input_psnrs):.2f} ± {np.std(input_psnrs):.2f} dB
- Mean Improvement: {np.mean(psnr_improvements):.2f} ± {np.std(psnr_improvements):.2f} dB
- Target Achievement: {target_achieved} / {len(output_psnrs)} ({success_rate:.1f}%)
- Max PSNR: {max(output_psnrs):.2f} dB
- Min PSNR: {min(output_psnrs):.2f} dB

OTHER METRICS:
- Mean SSIM: {np.mean(ssims):.4f} ± {np.std(ssims):.4f}
- Mean SAM: {np.mean(sams):.4f} ± {np.std(sams):.4f}

CONSISTENCY:
- PSNR Std Dev: {np.std(output_psnrs):.2f} dB
- Performance: {'Excellent' if success_rate > 80 else 'Good' if success_rate > 50 else 'Needs Improvement'}
    """

    axes[1, 2].text(0.05, 0.95, stats_text, transform=axes[1, 2].transAxes,
                    fontsize=11, verticalalignment='top', fontfamily='monospace',
                    bbox=dict(boxstyle="round,pad=0.5", facecolor="lightblue", alpha=0.8))

    plt.suptitle(f'Comprehensive HSI Denoising Test Results (Static noise σ={noise_level})',
                 fontsize=16, y=0.98)
    plt.tight_layout()
    plt.subplots_adjust(top=0.94)
    plt.savefig(os.path.join(save_dir, 'comprehensive_test_results.png'), dpi=150, bbox_inches='tight')
    plt.close()

    return {
        'mean_psnr': np.mean(output_psnrs),
        'std_psnr': np.std(output_psnrs),
        'mean_improvement': np.mean(psnr_improvements),
        'std_improvement': np.std(psnr_improvements),
        'mean_ssim': np.mean(ssims),
        'std_ssim': np.std(ssims),
        'mean_sam': np.mean(sams),
        'std_sam': np.std(sams),
        'mean_loss': np.mean([r['loss'] for r in results]),
        'std_loss': np.std([r['loss'] for r in results]),
        'target_achievement_rate': success_rate / 100,
        'max_psnr': max(output_psnrs),
        'min_psnr': min(output_psnrs),
        'noise_level': noise_level
    }

def create_sample_visualization(result, title, save_dir):
    """Create detailed visualization with dataset-specific RGB band selection"""
    clean = result['clean']
    noisy = result['noisy'] 
    denoised = result['denoised']
    D, H, W = clean.shape

    # Dataset-specific RGB band selection
    filename_lower = result['filename'].lower()
    
    print(f"Creating visualization for {result['filename']}")
    
    # FIXED: Correct RGB band selections for common datasets
    if 'indian' in filename_lower:
        # Indian Pines (220 bands): Standard RGB approximation
        if D >= 150:
            rgb_bands = [49, 26, 16]  # Red: ~630nm, Green: ~550nm, Blue: ~470nm
            print(f"  Using Indian Pines standard RGB bands: {[b+1 for b in rgb_bands]}")
        elif D >= 100:
            rgb_bands = [int(D*0.22), int(D*0.12), int(D*0.07)]
            print(f"  Using scaled Indian bands: {[b+1 for b in rgb_bands]}")
        else:
            rgb_bands = [min(D-1, 22), min(D-1, 12), min(D-1, 7)]
            print(f"  Using minimal Indian bands: {[b+1 for b in rgb_bands]}")
    
    elif 'pavia' in filename_lower:
        # Pavia University/Centre (103 bands)
        if D >= 80:
            rgb_bands = [55, 41, 12]
        else:
            rgb_bands = [int(D*0.7), int(D*0.5), int(D*0.15)]
        print(f"  Using Pavia RGB bands: {[b+1 for b in rgb_bands]}")
    
    elif 'washington' in filename_lower or 'dc' in filename_lower:
        # Washington DC Mall (191 bands)
        if D >= 150:
            rgb_bands = [56, 26, 16]
        else:
            rgb_bands = [int(D*0.35), int(D*0.15), int(D*0.08)]
        print(f"  Using Washington RGB bands: {[b+1 for b in rgb_bands]}")
    
    else:
        # Generic approach: use spread across spectrum
        rgb_bands = [
            min(D-1, int(D*0.7)),   # Near-infrared/Red
            min(D-1, int(D*0.4)),   # Green/Yellow
            min(D-1, int(D*0.15))   # Blue
        ]
        print(f"  Using generic RGB bands: {[b+1 for b in rgb_bands]}")

    # Ensure bands are valid
    # rgb_bands = [min(max(0, b), D-1) for b in rgb_bands]
    rgb_bands = [22,13,5]
    
    # Create RGB composites with enhanced contrast
    rgb_images = {}
    for data_type, data in [('clean', clean), ('noisy', noisy), ('denoised', denoised)]:
        rgb_image = np.zeros((H, W, 3))
        for i, band in enumerate(rgb_bands):
            band_data = data[band]
            
            rgb_image[:, :, i] = np.clip(band_data, 0, 1)
        
        # Slight gamma correction for visual appeal
        rgb_image = np.power(rgb_image, 0.95)
        rgb_images[data_type] = rgb_image

    # Create figure
    fig, axes = plt.subplots(1, 3, figsize=(19, 8))

    axes[0].imshow(rgb_images['clean'], interpolation='nearest')
    axes[0].set_title(f'Clean (Bands {rgb_bands[0]+1}, {rgb_bands[1]+1}, {rgb_bands[2]+1})')
    axes[0].axis('off')

    axes[1].imshow(rgb_images['noisy'], interpolation='nearest')
    axes[1].set_title(f'Noisy (Bands {rgb_bands[0]+1}, {rgb_bands[1]+1}, {rgb_bands[2]+1})')
    axes[1].axis('off')

    axes[2].imshow(rgb_images['denoised'], interpolation='nearest')
    axes[2].set_title(f'Denoised (Bands {rgb_bands[0]+1}, {rgb_bands[1]+1}, {rgb_bands[2]+1})')
    axes[2].axis('off')

    orig_min, orig_max = result['original_range']
    fig.suptitle(f'{title}\n'
                f'File: {result["filename"]}, Noise: σ={result["noise_level"]:.2f}, '
                f'Original Range: [{orig_min:.2f}, {orig_max:.2f}]\n'
                f'PSNR: {result["input_psnr"]:.2f}→{result["output_psnr"]:.2f} dB (+{result["psnr_improvement"]:.2f}), '
                f'SSIM: {result["ssim"]:.3f}, SAM: {result["sam"]:.3f}',
                fontsize=14, y=0.98)
    
    plt.tight_layout()
    plt.subplots_adjust(top=0.88)
    safe_title = title.replace(' ', '_').lower()
    safe_filename = result['filename'].replace('.mat', '').replace(' ', '_')
    plt.savefig(os.path.join(save_dir, f'sample_{safe_title}_{safe_filename}.png'),
                dpi=150, bbox_inches='tight')
    plt.close()

    print(f"  Saved visualization: sample_{safe_title}_{safe_filename}.png")

def create_detailed_analysis_plots(results, save_dir):
    """Create detailed analysis plots"""
    print("Creating detailed analysis plots...")

    fig, axes = plt.subplots(2, 2, figsize=(16, 12))

    input_psnrs = [r['input_psnr'] for r in results]
    output_psnrs = [r['output_psnr'] for r in results]
    psnr_improvements = [r['psnr_improvement'] for r in results]
    ssims = [r['ssim'] for r in results]
    sams = [r['sam'] for r in results]
    losses = [r['loss'] for r in results]
    noise_level = results[0]['noise_level']
    image_sizes = [r['shape'][1] * r['shape'][2] for r in results]

    # Correlation matrix
    try:
        import pandas as pd
        metrics_data = {
            'Input_PSNR': input_psnrs,
            'Output_PSNR': output_psnrs,
            'PSNR_Improvement': psnr_improvements,
            'SSIM': ssims,
            'SAM': sams,
            'Loss': losses,
            'Image_Size': image_sizes
        }
        df = pd.DataFrame(metrics_data)
        correlation_matrix = df.corr()

        im = axes[0, 0].imshow(correlation_matrix.values, cmap='RdYlBu', vmin=-1, vmax=1)
        axes[0, 0].set_xticks(range(len(correlation_matrix.columns)))
        axes[0, 0].set_yticks(range(len(correlation_matrix.columns)))
        axes[0, 0].set_xticklabels(correlation_matrix.columns, rotation=45, ha='right')
        axes[0, 0].set_yticklabels(correlation_matrix.columns)
        axes[0, 0].set_title('Metrics Correlation Matrix')

        for i in range(len(correlation_matrix.columns)):
            for j in range(len(correlation_matrix.columns)):
                axes[0, 0].text(j, i, f'{correlation_matrix.iloc[i, j]:.2f}',
                               ha="center", va="center", color="black", fontsize=9)

        plt.colorbar(im, ax=axes[0, 0], shrink=0.8)
    except ImportError:
        axes[0, 0].text(0.5, 0.5, 'pandas not available\nfor correlation matrix',
                       ha='center', va='center', transform=axes[0, 0].transAxes)

    # PSNR improvement distribution
    axes[0, 1].hist(psnr_improvements, bins=15, alpha=0.7, color='green', edgecolor='black')
    axes[0, 1].axvline(x=np.mean(psnr_improvements), color='red', linestyle='--', linewidth=2,
                      label=f'Mean: {np.mean(psnr_improvements):.2f} dB')
    axes[0, 1].axvline(x=np.median(psnr_improvements), color='blue', linestyle='--', linewidth=2,
                      label=f'Median: {np.median(psnr_improvements):.2f} dB')
    axes[0, 1].set_xlabel('PSNR Improvement (dB)')
    axes[0, 1].set_ylabel('Frequency')
    axes[0, 1].set_title(f'PSNR Improvement Distribution (σ={noise_level})')
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)

    # Performance vs Image Size
    scatter = axes[1, 0].scatter(image_sizes, output_psnrs, alpha=0.7, s=60,
                                c=psnr_improvements, cmap='RdYlGn')
    axes[1, 0].set_xlabel('Image Size (pixels)')
    axes[1, 0].set_ylabel('Output PSNR (dB)')
    axes[1, 0].set_title('PSNR vs Image Size (colored by improvement)')
    axes[1, 0].grid(True, alpha=0.3)
    axes[1, 0].axhline(y=40, color='red', linestyle='--', alpha=0.7, label='Target (40 dB)')
    axes[1, 0].legend()
    plt.colorbar(scatter, ax=axes[1, 0], label='PSNR Improvement (dB)', shrink=0.8)

    # Performance summary
    target_achieved = sum(1 for p in output_psnrs if p > 40)
    success_rate = target_achieved / len(output_psnrs) * 100
    consistency_metric = np.std(psnr_improvements) / np.mean(psnr_improvements) if np.mean(psnr_improvements) > 0 else 0
    stats_text = f"""
DETAILED PERFORMANCE ANALYSIS
{'='*35}
Static Noise Testing (σ={noise_level})

PSNR PERFORMANCE:
- Mean Output PSNR: {np.mean(output_psnrs):.2f} ± {np.std(output_psnrs):.2f} dB
- Mean Input PSNR: {np.mean(input_psnrs):.2f} ± {np.std(input_psnrs):.2f} dB
- Mean Improvement: {np.mean(psnr_improvements):.2f} ± {np.std(psnr_improvements):.2f} dB
- Max Improvement: {max(psnr_improvements):.2f} dB
- Min Improvement: {min(psnr_improvements):.2f} dB

TARGET ACHIEVEMENT:
- Success Rate: {success_rate:.1f}% ({target_achieved}/{len(results)})
- Above 40 dB: {target_achieved} samples
- Above 35 dB: {sum(1 for p in output_psnrs if p > 35)} samples
- Above 30 dB: {sum(1 for p in output_psnrs if p > 30)} samples

CONSISTENCY ANALYSIS:
- Coefficient of Variation: {consistency_metric:.3f}
- Performance Stability: {'High' if consistency_metric < 0.2 else 'Moderate' if consistency_metric < 0.5 else 'Low'}

QUALITY METRICS:
- SSIM: {np.mean(ssims):.4f} ± {np.std(ssims):.4f}
- SAM: {np.mean(sams):.4f} ± {np.std(sams):.4f}

IMAGE SIZE ANALYSIS:
- Min Size: {min(image_sizes)} pixels
- Max Size: {max(image_sizes)} pixels
- Mean Size: {np.mean(image_sizes):.0f} pixels
    """

    axes[1, 1].axis('off')
    axes[1, 1].text(0.05, 0.95, stats_text, transform=axes[1, 1].transAxes,
                    fontsize=10, verticalalignment='top', fontfamily='monospace',
                    bbox=dict(boxstyle="round,pad=0.4", facecolor="lightblue", alpha=0.8))

    plt.suptitle(f'Detailed Analysis - Static Noise Level (σ={noise_level})',
                 fontsize=16, y=0.98)
    plt.tight_layout()
    plt.subplots_adjust(top=0.94)
    plt.savefig(os.path.join(save_dir, 'detailed_analysis_plots.png'), dpi=150, bbox_inches='tight')
    plt.close()

    return {
        'mean_psnr': np.mean(output_psnrs),
        'std_psnr': np.std(output_psnrs),
        'mean_improvement': np.mean(psnr_improvements),
        'std_improvement': np.std(psnr_improvements),
        'mean_ssim': np.mean(ssims),
        'std_ssim': np.std(ssims),
        'mean_sam': np.mean(sams),
        'std_sam': np.std(sams),
        'mean_loss': np.mean([r['loss'] for r in results]),
        'std_loss': np.std([r['loss'] for r in results]),
        'target_achievement_rate': success_rate / 100,
        'max_psnr': max(output_psnrs),
        'min_psnr': min(output_psnrs),
        'noise_level': noise_level
    }

def create_spectral_analysis(results, save_dir, num_samples=3):
    """Create spectral signature analysis"""
    print("Creating spectral analysis...")

    sorted_results = sorted(results, key=lambda x: x['output_psnr'])
    sample_indices = [0, len(sorted_results)//2, len(sorted_results)-1]
    selected_results = [sorted_results[i] for i in sample_indices]

    fig, axes = plt.subplots(len(selected_results), 1, figsize=(15, 4*len(selected_results)))
    if len(selected_results) == 1:
        axes = [axes]

    for idx, result in enumerate(selected_results):
        clean = result['clean']
        noisy = result['noisy']
        denoised = result['denoised']

        D, H, W = clean.shape
        center_h, center_w = H//2, W//2

        clean_spectrum = clean[:, center_h, center_w]
        noisy_spectrum = noisy[:, center_h, center_w]
        denoised_spectrum = denoised[:, center_h, center_w]

        bands = np.arange(D)

        axes[idx].plot(bands, clean_spectrum, 'g-', linewidth=2, label='Clean', alpha=0.8)
        axes[idx].plot(bands, noisy_spectrum, 'r--', linewidth=1.5, label='Noisy', alpha=0.7)
        axes[idx].plot(bands, denoised_spectrum, 'b-', linewidth=2, label='Denoised', alpha=0.8)

        axes[idx].set_xlabel('Spectral Band')
        axes[idx].set_ylabel('Reflectance')
        axes[idx].set_title(f'Spectral Signature - {result["filename"]}\n'
                           f'PSNR: {result["output_psnr"]:.2f} dB, SSIM: {result["ssim"]:.3f}, '
                           f'SAM: {result["sam"]:.4f}')
        axes[idx].legend()
        axes[idx].grid(True, alpha=0.3)

        ax2 = axes[idx].twinx()
        denoising_improvement = np.abs(denoised_spectrum - clean_spectrum) - np.abs(noisy_spectrum - clean_spectrum)
        ax2.fill_between(bands, 0, denoising_improvement, alpha=0.3, color='purple', label='Denoising Effect')
        ax2.set_ylabel('Denoising Effect', color='purple')
        ax2.tick_params(axis='y', labelcolor='purple')

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, 'spectral_analysis.png'), dpi=150, bbox_inches='tight')
    plt.close()

def save_detailed_results(results, summary_stats, save_dir, model_info=None):
    """Save detailed results as visualizations"""
    print("Saving visualization results...")

    print(f"Results saved to {save_dir}")
    print(f"• Comprehensive plots: comprehensive_test_results.png")
    print(f"• Detailed analysis: detailed_analysis_plots.png")
    print(f"• Spectral analysis: spectral_analysis.png")
    print(f"• Sample visualizations: sample_*.png")
    print(f"• Testing Mode: Static noise (σ={summary_stats['noise_level']})")


def main():
    """Main testing function with ICVL protocol"""
    print("="*80)
    print("HSI DENOISING MODEL TESTING - ICVL Protocol")
    print("Test crops: 512×512×31 with training-matched normalization")
    print("="*80)

    # Configuration
    TEST_DIR = '/workspace/icvl_part/test_gauss'
    MODEL_PATH = './HSI_denoising_ICVL_resultsV15_noise30'
    RESULTS_DIR = './HSI_denoising_ICVL_resultsV15_noise30/test_results'
    TRAINING_NOISE_LEVEL = 10
    
    config = {
        'base_dim': 64,
        'target_bands': 31,  # ICVL standard
        'patch_size': 64,    # For patch-based processing if needed
        'test_crop_size': 512,  # ICVL test protocol
        'noise_level': TRAINING_NOISE_LEVEL,
    }

    os.makedirs(RESULTS_DIR, exist_ok=True)
    # Device setup
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    if torch.cuda.is_available():
        print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
        torch.cuda.empty_cache()

    # Load model
    print("\nLoading trained model...")
    full_model_path = os.path.join(MODEL_PATH, 'enhanced_denoising_pipeline_full.pth')
    best_model_path = os.path.join(MODEL_PATH, 'best_memory_optimized_model.pth')

    model_info = None
    if os.path.exists(full_model_path):
        print("Loading full model with metadata...")
        checkpoint = torch.load(full_model_path, map_location=device, weights_only=False)
        model_info = checkpoint
        config = checkpoint.get('config', config)
    elif os.path.exists(best_model_path):
        print("Loading best model checkpoint...")
        checkpoint = torch.load(best_model_path, map_location=device, weights_only=False)
        config = checkpoint.get('config', config)
    else:
        raise FileNotFoundError(f"No model found in {MODEL_PATH}")

    # Initialize model
    model = MemoryOptimizedUNet(
        in_channels=1,
        base_dim=config['base_dim'],
        window_sizes=[4, 8, 16],
        num_bands=config['target_bands']
    ).to(device)

    # Escape strict model parameter matching here
    #model.load_state_dict(checkpoint['model_state_dict'])
    #model.load_state_dict(checkpoint['model_state_dict'], strict=False)

    # ===================================================================
    # FIXED: Handle spectral_pos_embed size mismatches
    # ===================================================================
    # Get the saved state dict and current model's state dict
    state_dict = checkpoint['model_state_dict']
    model_dict = model.state_dict()
    
    # Prepare filtered state dict
    filtered_state_dict = {}
    adjusted_keys = []
    skipped_keys = []
    
    for k, v in state_dict.items():
        if k in model_dict:
            if v.shape == model_dict[k].shape:
                # Shapes match - load directly
                filtered_state_dict[k] = v
            else:
                # Size mismatch detected
                if 'spectral_pos_embed' in k:
                    # Handle spectral positional embedding mismatches
                    needed_bands = model_dict[k].shape[1]
                    available_bands = v.shape[1]
                    
                    if needed_bands <= available_bands:
                        # Checkpoint has more bands than we need - slice it
                        filtered_state_dict[k] = v[:, :needed_bands, :].clone()
                        adjusted_keys.append(f"{k}: {v.shape} -> {filtered_state_dict[k].shape}")
                    else:
                        # Checkpoint has fewer bands than we need - pad it
                        padded = torch.zeros_like(model_dict[k])
                        padded[:, :available_bands, :] = v
                        filtered_state_dict[k] = padded
                        adjusted_keys.append(f"{k}: {v.shape} -> {filtered_state_dict[k].shape} (padded)")
                else:
                    # Non-spectral parameter with mismatch - skip it
                    skipped_keys.append(f"{k}: checkpoint {v.shape} vs model {model_dict[k].shape}")
        else:
            # Key not in model - likely from different architecture
            skipped_keys.append(f"{k}: not found in current model")
    
    # Report what was adjusted
    if adjusted_keys:
        print(f"\n✓ Adjusted {len(adjusted_keys)} spectral_pos_embed parameters:")
        for key in adjusted_keys[:5]:  # Show first 5
            print(f"    {key}")
        if len(adjusted_keys) > 5:
            print(f"    ... and {len(adjusted_keys) - 5} more")
    
    if skipped_keys:
        print(f"\n⚠ Skipped {len(skipped_keys)} mismatched parameters:")
        for key in skipped_keys[:3]:  # Show first 3
            print(f"    {key}")
        if len(skipped_keys) > 3:
            print(f"    ... and {len(skipped_keys) - 3} more")
    
    # Load the filtered state dict (strict=False allows missing keys)
    missing_keys, unexpected_keys = model.load_state_dict(filtered_state_dict, strict=True)
    
    if missing_keys:
        print(f"\n⚠ Missing keys (will use random initialization): {len(missing_keys)}")
        if len(missing_keys) <= 5:
            for key in missing_keys:
                print(f"    {key}")
    
    if unexpected_keys:
        print(f"\n⚠ Unexpected keys (ignored): {len(unexpected_keys)}")
    
    print("\n✓ Model weights loaded successfully!")
    # ===================================================================
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model loaded successfully!")
    print(f"Parameters: {total_params / 1e6:.2f}M")
    print(f"Base dimension: {config['base_dim']}")
    print(f"Target bands: {config['target_bands']}")
    print(f"Training noise level: σ={config['noise_level']}")

    if model_info and 'best_psnr' in model_info:
        print(f"Training best PSNR: {model_info['best_psnr']:.2f} dB")

    # Load test data
    test_data = load_test_data(
        TEST_DIR, 
        target_bands=config['target_bands'],
        test_crop_size=config.get('test_crop_size', 512)
    )

    if not test_data:
        print("ERROR: No test data loaded!")
        return
        

    print(f"Loaded {len(test_data)} test images (512×512×31)")

    # Add static noise
    test_samples = add_static_noise_to_data(test_data, noise_level=config['noise_level'])

    print("\nNoise Analysis:")
    noise_types = {}
    actual_noise_levels = []
    for sample in test_samples:
        noise_type = sample.get('noise_type', 'unknown')
        noise_types[noise_type] = noise_types.get(noise_type, 0) + 1
        actual_noise_levels.append(sample.get('actual_noise_level', 0))
    
    print(f"Noise types distribution: {noise_types}")
    print(f"Actual noise levels - Mean: {np.mean(actual_noise_levels):.4f}, "
          f"Std: {np.std(actual_noise_levels):.4f}, "
          f"Range: [{np.min(actual_noise_levels):.4f}, {np.max(actual_noise_levels):.4f}]")
    
    # Verify noise is visible
    sample_noise = test_samples[0]['noisy'] - test_samples[0]['clean']
    print(f"Sample noise verification - Mean: {np.mean(sample_noise):.4f}, "
          f"Std: {np.std(sample_noise):.4f}, "
          f"Max: {np.max(np.abs(sample_noise)):.4f}")

    print(f"Created {len(test_samples)} test samples with static noise σ={config['noise_level']}")

    # Test model
    results = test_model_comprehensive(model, test_samples, device, patch_size=config['patch_size'])

    if not results:
        print("No test results generated!")
        return

    print(f"\nSuccessfully tested {len(results)} samples")

    # Summary statistics
    print("\n" + "="*60)
    print("SUMMARY STATISTICS:")
    print("="*60)
    print(f"Number of test samples: {len(results)}")
    print(f"Average Test Loss: {np.mean([r['loss'] for r in results]):.6f} ± {np.std([r['loss'] for r in results]):.6f}")
    print(f"Average PSNR:      {np.mean([r['output_psnr'] for r in results]):.3f} ± {np.std([r['output_psnr'] for r in results]):.3f} dB")
    print(f"Average SSIM:      {np.mean([r['ssim'] for r in results]):.4f} ± {np.std([r['ssim'] for r in results]):.4f}")
    print(f"Average SAM:       {np.mean([r['sam'] for r in results]):.4f} ± {np.std([r['sam'] for r in results]):.4f}")

    sorted_results = sorted(results, key=lambda x: x['output_psnr'])
    best_result = sorted_results[-1]
    worst_result = sorted_results[0]
    print(f"Best PSNR:  {best_result['filename']} ({best_result['output_psnr']:.3f} dB)")
    print(f"Worst PSNR: {worst_result['filename']} ({worst_result['output_psnr']:.3f} dB)")
    print("="*60)

    # Visualizations
    summary_stats = create_comprehensive_visualizations(results, RESULTS_DIR)

    sorted_results = sorted(results, key=lambda x: x['output_psnr'])
    best_result = sorted_results[-1]
    worst_result = sorted_results[0]
    median_result = sorted_results[len(sorted_results)//2]

    for title, result in [('Best Performance', best_result),
                         ('Median Performance', median_result),
                         ('Challenging Case', worst_result)]:
        create_sample_visualization(result, title, RESULTS_DIR)

    create_spectral_analysis(results, RESULTS_DIR)
    save_detailed_results(results, summary_stats, RESULTS_DIR, model_info)

    # Final summary
    print("\n" + "="*80)
    print("COMPREHENSIVE TEST RESULTS SUMMARY")
    print("="*80)

    print(f"Test Configuration:")
    print(f"  • Test Directory: {TEST_DIR}")
    print(f"  • Model Path: {MODEL_PATH}")
    print(f"  • Total Samples: {len(results)}")
    print(f"  • Unique Images: {len(set([r['filename'] for r in results]))}")
    print(f"  • Noise Level: σ={summary_stats['noise_level']}")
    print(f"  • Device: {device}")

    print(f"\nPerformance Metrics:")
    print(f"  • Mean PSNR: {summary_stats['mean_psnr']:.2f} ± {summary_stats['std_psnr']:.2f} dB")
    print(f"  • Mean Improvement: {summary_stats['mean_improvement']:.2f} dB")
    print(f"  • Max PSNR: {summary_stats['max_psnr']:.2f} dB")
    print(f"  • Min PSNR: {summary_stats['min_psnr']:.2f} dB")
    print(f"  • Mean SSIM: {summary_stats['mean_ssim']:.4f} ± {summary_stats['std_ssim']:.4f}")
    print(f"  • Mean SAM: {summary_stats['mean_sam']:.4f} ± {summary_stats['std_sam']:.4f}")

    print(f"\nTarget Achievement (PSNR > 40 dB):")
    achievement_rate = summary_stats['target_achievement_rate'] * 100
    target_count = int(achievement_rate * len(results) / 100)
    print(f"  • Success Rate: {achievement_rate:.1f}% ({target_count} / {len(results)} samples)")
    print(f"  • Status: {'EXCELLENT' if achievement_rate > 80 else 'GOOD' if achievement_rate > 50 else 'NEEDS IMPROVEMENT'}")

    print(f"\nModel Architecture:")
    print(f"  • Parameters: {total_params / 1e6:.2f}M")
    print(f"  • Base Dimension: {config['base_dim']}")
    print(f"  • Input Bands: {config['target_bands']}")
    print(f"  • Patch Size: {config['patch_size']}")

    if model_info and 'best_psnr' in model_info:
        training_psnr = model_info['best_psnr']
        test_psnr = summary_stats['mean_psnr']
        generalization = "Good" if abs(training_psnr - test_psnr) < 5 else "Moderate" if abs(training_psnr - test_psnr) < 10 else "Poor"
        print(f"\nGeneralization Analysis:")
        print(f"  • Training PSNR: {training_psnr:.2f} dB")
        print(f"  • Test PSNR: {test_psnr:.2f} dB")
        print(f"  • Difference: {abs(training_psnr - test_psnr):.2f} dB")
        print(f"  • Generalization: {generalization}")

    print(f"\n" + "="*80)
    print(f"Final Average PSNR: {summary_stats['mean_psnr']:.3f} ± {summary_stats['std_psnr']:.3f} dB")
    print(f"Final Average SSIM: {summary_stats['mean_ssim']:.4f} ± {summary_stats['std_ssim']:.4f}")
    print(f"Final Average SAM:  {summary_stats['mean_sam']:.4f} ± {summary_stats['std_sam']:.4f}")
    print(f"\nResults saved to: {RESULTS_DIR}")
    print("="*80)
    print("TESTING COMPLETED SUCCESSFULLY!")
    print("="*80)

    return results, summary_stats, model_info

if __name__ == '__main__':
    main()

HSI DENOISING MODEL TESTING - ICVL Protocol
Test crops: 512×512×31 with training-matched normalization
Using device: cuda
GPU Memory: 12.5 GB

Loading trained model...
Loading full model with metadata...

✓ Adjusted 18 spectral_pos_embed parameters:
    enc_stage2.blocks.0.spectral_attn.spectral_pos_embed: torch.Size([1, 16, 128]) -> torch.Size([1, 15, 128])
    enc_stage2.blocks.1.spectral_attn.spectral_pos_embed: torch.Size([1, 16, 128]) -> torch.Size([1, 15, 128])
    enc_stage3.blocks.0.spectral_attn.spectral_pos_embed: torch.Size([1, 8, 256]) -> torch.Size([1, 7, 256])
    enc_stage3.blocks.1.spectral_attn.spectral_pos_embed: torch.Size([1, 8, 256]) -> torch.Size([1, 7, 256])
    enc_stage3.blocks.2.spectral_attn.spectral_pos_embed: torch.Size([1, 8, 256]) -> torch.Size([1, 7, 256])
    ... and 13 more

✓ Model weights loaded successfully!
Model loaded successfully!
Parameters: 53.67M
Base dimension: 64
Target bands: 31
Training noise level: σ=50
Training best PSNR: 37.28 dB
Loadi

Loading test files: 100%|███████████████████████| 50/50 [00:45<00:00,  1.10it/s]



DATA LOADING SUMMARY (ICVL Protocol: 512×512×31)
✓ Successfully loaded: 50/50 files

Loaded 50 test images (512×512×31)
Adding Gaussian noise (σ=50) to 50 samples...
Noise protocol: MATCH TRAINING (scale [0,1] -> [0,255], add noise, clip, rescale)
✓ Successfully processed: 50/50 samples


Noise Analysis:
Noise types distribution: {'unknown': 50}
Actual noise levels - Mean: 0.0000, Std: 0.0000, Range: [0.0000, 0.0000]
Sample noise verification - Mean: 0.0278, Std: 0.1653, Max: 0.9747
Created 50 test samples with static noise σ=50

TESTING MODEL ON ALL SAMPLES


Testing samples: 100%|██████████████████████████| 50/50 [05:03<00:00,  6.06s/it]



TESTING COMPLETED - AVERAGE SCORES
Total samples tested: 50/50

Average Metrics:
  • Loss:            0.000108 ± 0.000038
  • Input PSNR:      15.75 ± 0.76 dB
  • Output PSNR:     39.93 ± 1.56 dB
  • PSNR Improvement: 24.17 ± 1.66 dB
  • SSIM:            0.9517 ± 0.0107
  • SAM:             0.0571 ± 0.0613

Performance Range:
  • Best PSNR:  44.55 dB (Labtest_0910-1504.mat)
  • Worst PSNR: 36.53 dB (nachal_0823-1149.mat)


Successfully tested 50 samples

SUMMARY STATISTICS:
Number of test samples: 50
Average Test Loss: 0.000108 ± 0.000038
Average PSNR:      39.928 ± 1.561 dB
Average SSIM:      0.9517 ± 0.0107
Average SAM:       0.0571 ± 0.0613
Best PSNR:  Labtest_0910-1504.mat (44.554 dB)
Worst PSNR: nachal_0823-1149.mat (36.532 dB)
Creating comprehensive visualizations...
Creating visualization for Labtest_0910-1504.mat
  Using generic RGB bands: [22, 13, 5]
  Saved visualization: sample_best_performance_Labtest_0910-1504.png
Creating visualization for objects_0924-1557.mat
  Using g