In [None]:
# === INR-based HSI Compression with SIREN (Patch-based) ===#

########## ===== Hyperspectral Image Compression Using Implicit Neural Representation =====#########
   ######### =========== By Shima Rezasoltani, Faisal Z. Qureshi =================##############
      ############# ============== Ontario Tech University ====================== ##########

import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import scipy.io as sio
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import mean_squared_error
from torch.quantization import quantize_dynamic
import math
import time
from scipy import linalg

try:
    from skimage.metrics import mean_squared_error as mse
except ImportError:
    from skimage.measure import compare_mse as mse

# Set device
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', DEVICE)

# Seed the NumPy and torch generators so weight init and patch/pixel shuffling
# are repeatable across Restart & Run All (GPU kernels may still introduce
# small run-to-run differences).
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# === Load and Normalize PaviaU ===
data = sio.loadmat('/home/data/Fahad/HSI_datasets/PaviaU/PaviaU.mat')
cube = data['paviaU']  # shape: [610, 340, 103]
# Global (whole-cube) statistics; kept so the reconstruction can be de-normalized later.
mean = cube.mean()
std = cube.std()
cube_norm = (cube - mean) / std  # Z-score normalization
original_cube = cube_norm.copy()

# === Patch Parameters ===
PATCH_SIZE = 64
OVERLAP = 0  # Optional overlap between patches
H, W, C = cube_norm.shape

# === Function to Create Patches ===
def create_patches(image, patch_size, overlap=0):
    """Cut an (H, W, C) array into square patches with per-pixel coordinates.

    The image is reflect-padded on its bottom and right edges so that a grid
    of ``patch_size`` x ``patch_size`` windows with stride
    ``patch_size - overlap`` tiles it exactly.

    Args:
        image: 3-D NumPy array laid out as (rows, cols, channels).
        patch_size: Side length of each square patch, in pixels.
        overlap: Number of pixels shared by neighbouring patches.

    Returns:
        A tuple ``(patches, patch_coords, (pad_h, pad_w))`` where
        ``patches[k]`` has shape ``(patch_size * patch_size, channels)``,
        ``patch_coords[k]`` has shape ``(patch_size * patch_size, 2)`` holding
        (row, col) coordinates scaled by the padded image size, and
        ``(pad_h, pad_w)`` is the amount of padding that was added.

    Raises:
        ValueError: If ``overlap`` is not smaller than ``patch_size``.
    """
    step = patch_size - overlap
    if step <= 0:
        raise ValueError("overlap must be smaller than patch_size")

    # Take the dimensions from the argument itself instead of notebook globals,
    # so the function works for any array it is handed.
    height, width, channels = image.shape

    # Pad the image if necessary
    pad_h = (patch_size - height % step) % step
    pad_w = (patch_size - width % step) % step

    padded_image = np.pad(image, ((0, pad_h), (0, pad_w), (0, 0)), mode='reflect')
    padded_H, padded_W = padded_image.shape[:2]

    patches = []
    patch_coords = []

    for i in range(0, padded_H - patch_size + 1, step):
        # Row coordinates depend only on i, so compute them once per patch row.
        # NOTE(review): np.linspace includes its endpoint, so with overlap=0 the
        # last row/col of one patch gets the same coordinate as the first
        # row/col of the next patch. reconstruct_image uses the identical
        # formula, so any change here must be made there as well.
        x_coords = np.linspace(-1 + (2*i)/padded_H, -1 + (2*(i+patch_size))/padded_H, patch_size)
        for j in range(0, padded_W - patch_size + 1, step):
            patch = padded_image[i:i+patch_size, j:j+patch_size]

            y_coords = np.linspace(-1 + (2*j)/padded_W, -1 + (2*(j+patch_size))/padded_W, patch_size)

            grid_x, grid_y = np.meshgrid(x_coords, y_coords, indexing='ij')
            coords = np.stack([grid_x, grid_y], axis=-1).reshape(-1, 2)

            patches.append(patch.reshape(-1, channels))
            patch_coords.append(coords)

    return patches, patch_coords, (pad_h, pad_w)

# Tile the normalized cube; `padding` is kept for the reconstruction step.
patches, patch_coords, padding = create_patches(
    cube_norm,
    patch_size=PATCH_SIZE,
    overlap=OVERLAP,
)
print(f"Created {len(patches)} patches of size {PATCH_SIZE}x{PATCH_SIZE}")

In [None]:
# === Define Sine Activation ===
class Sine(nn.Module):
    """Element-wise activation computing ``sin(w0 * x)``."""

    def __init__(self, w0=30):
        super().__init__()
        # Frequency factor applied to the input before taking the sine.
        self.w0 = w0

    def forward(self, x):
        scaled = self.w0 * x
        return torch.sin(scaled)

# === SIREN Model ===
class SIREN(nn.Module):
    """Fully connected network with sine activations between linear layers.

    Layout: ``Linear(in_dim, hidden_dim)`` + ``Sine``, followed by
    ``hidden_layers`` repetitions of ``Linear(hidden_dim, hidden_dim)`` +
    ``Sine``, and a final ``Linear(hidden_dim, out_dim)`` without activation.
    """

    def __init__(self, in_dim=2, hidden_dim=512, out_dim=103, hidden_layers=4, w0=30):
        super().__init__()
        layers = [nn.Linear(in_dim, hidden_dim), Sine(w0)]
        for _ in range(hidden_layers):
            layers.extend((nn.Linear(hidden_dim, hidden_dim), Sine(w0)))
        layers.append(nn.Linear(hidden_dim, out_dim))

        # Plain list of the layers, kept as an attribute alongside the
        # Sequential that actually registers them as submodules.
        self.net = layers
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

# === Custom Weight Initialization ===
def init_weights(m, w0=30, is_first=False):
    """Re-initialize an ``nn.Linear`` layer in place; other modules are left untouched.

    Weights are drawn uniformly from ``[-bound, bound]`` where ``bound`` is
    ``1 / in_dim`` for the first layer and ``sqrt(6 / in_dim) / w0`` otherwise.
    The bias, when the layer has one, is set to zero.

    Args:
        m: Module to initialize.
        w0: Sine frequency factor used to scale the bound of non-first layers.
        is_first: Whether ``m`` is the network's input layer.
    """
    if not isinstance(m, nn.Linear):
        return

    in_dim = m.weight.shape[1]
    bound = 1 / in_dim if is_first else np.sqrt(6 / in_dim) / w0

    with torch.no_grad():
        m.weight.uniform_(-bound, bound)
        # Layers built with bias=False carry bias=None; skip them instead of crashing.
        if m.bias is not None:
            m.bias.fill_(0)
        

# === Instantiate & Initialize Model ===
model = SIREN(in_dim=2, hidden_dim=512, out_dim=C, hidden_layers=4)
model = model.to(DEVICE).float()

# Re-initialize every linear layer; only the layer at index 0 of the
# Sequential is treated as the first layer.
for position, module in enumerate(model.model):
    if not isinstance(module, nn.Linear):
        continue
    init_weights(module, w0=30, is_first=(position == 0))

In [None]:
# === Training Setup ===
optimizer = optim.Adam(model.parameters(), lr=2e-4)
loss_fn = nn.MSELoss()
# `verbose` is deprecated on PyTorch LR schedulers (and may be rejected by
# newer releases), so it is not passed; LR reductions are reported manually
# after each scheduler step instead.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)

# === Training Loop ===
epochs = 800
batch_size = 18384  # Batch size within each patch
best_loss = float('inf')

for epoch in range(epochs):
    epoch_loss = 0
    processed_patches = 0

    # Shuffle patch order
    patch_order = np.random.permutation(len(patches))

    for patch_idx in patch_order:
        # Get current patch data
        current_coords = torch.tensor(patch_coords[patch_idx], dtype=torch.float32, device=DEVICE)
        current_values = torch.tensor(patches[patch_idx], dtype=torch.float32, device=DEVICE)

        # Shuffle within patch
        permutation = torch.randperm(current_coords.shape[0])
        patch_loss = 0
        num_batches = 0

        for i in range(0, current_coords.shape[0], batch_size):
            idx = permutation[i:i+batch_size]
            batch_coords = current_coords[idx]
            batch_values = current_values[idx]

            optimizer.zero_grad()
            preds = model(batch_coords)
            loss = loss_fn(preds, batch_values)
            loss.backward()
            optimizer.step()

            patch_loss += loss.item()
            num_batches += 1

        # Divide by the number of batches actually run. The former
        # `n // batch_size + 1` over-counted by one whenever the patch size was
        # an exact multiple of batch_size.
        avg_patch_loss = patch_loss / max(num_batches, 1)
        epoch_loss += avg_patch_loss
        processed_patches += 1

    avg_epoch_loss = epoch_loss / processed_patches

    lr_before = optimizer.param_groups[0]['lr']
    scheduler.step(avg_epoch_loss)
    lr_after = optimizer.param_groups[0]['lr']
    if lr_after < lr_before:
        print(f"Reducing learning rate to {lr_after:.2e}")

    # Save best model
    if avg_epoch_loss < best_loss:
        best_loss = avg_epoch_loss
        torch.save(model.state_dict(), '/home/data/Fahad/models/best_INRpatch_model_weights4.pth')
        print(f"Saved new best model with loss: {best_loss:.6f}")

    print(f"Epoch {epoch+1}/{epochs} | Loss: {avg_epoch_loss:.6f} | LR: {optimizer.param_groups[0]['lr']:.2e}")

In [None]:

# Load best model
# map_location remaps the stored tensors onto the current DEVICE, so the
# checkpoint also loads when it was written on a different device
# (e.g. saved on GPU, reloaded on a CPU-only kernel).
model.load_state_dict(
    torch.load('/home/data/Fahad/models/best_INRpatch_model_weights4.pth', map_location=DEVICE)
)
model.eval()


In [None]:
# === Model Size Calculation ===
def get_model_size(model, quantization_bits=32):
    """Return the storage size of ``model`` in bits.

    Every tensor entry of ``model.state_dict()`` contributes its element count
    multiplied by ``quantization_bits``; non-tensor entries are ignored.
    """
    return sum(
        entry.numel() * quantization_bits
        for entry in model.state_dict().values()
        if isinstance(entry, torch.Tensor)
    )


# Size of the trained network when every stored value takes 32 bits
original_size_bits = get_model_size(model, quantization_bits=32)
# quantized_size_bits = get_model_size(quantized_model, 8)

# bits -> bytes -> KB
print(f"Original model size: {original_size_bits / 8 / 1024:.2f} KB")
# print(f"Quantized model size: {quantized_size_bits / 8 / 1024:.2f} KB")

In [None]:
# === Compression Ratio and Bitrate in bpppc ===
# NOTE(review): the uncompressed size assumes 32 bits per sample; confirm this
# matches the bit depth the cube is actually stored at, since the ratio scales
# directly with it.
original_cube_size = 32 * (H * W * C)  # Original uncompressed size in bits

# Bitrate in bpppc (bits per pixel per channel)
bitrate_original = original_size_bits / (H * W * C)
# bitrate_quantized = quantized_size_bits / (H * W * C)

# Ratio of raw cube bits to model bits
compression_ratio_original = original_cube_size / original_size_bits
# compression_ratio_quantized = original_cube_size / quantized_size_bits

print(f"\nCompression Ratios:")
print(f"Original model: {compression_ratio_original:.2f}:1")
# print(f"Quantized model: {compression_ratio_quantized:.2f}:1")

print(f"\nBitrates (bpppc):")
print(f"Original model: {bitrate_original:.4f} bpppc")
# print(f"Quantized model: {bitrate_quantized:.6f} bpppc")

In [None]:
# === Reconstruction Function ===
def reconstruct_image(model, original_shape, patch_size, overlap, padding):
    """Rebuild the full image by evaluating ``model`` patch by patch.

    The patch grid and coordinate formula mirror ``create_patches``: each patch
    is predicted from its (row, col) coordinates, cropped to the un-padded
    image area, accumulated, and averaged where patches overlap.

    Args:
        model: Callable mapping an ``(N, 2)`` float32 tensor of coordinates to
            an ``(N, channels)`` tensor.
        original_shape: ``(rows, cols, channels)`` of the un-padded image.
        patch_size: Side length of each square patch, in pixels.
        overlap: Number of pixels shared by neighbouring patches.
        padding: ``(pad_h, pad_w)`` as returned by ``create_patches``.

    Returns:
        NumPy array of shape ``original_shape`` with the predicted values.
    """
    # Build the inputs on the device the model's parameters live on, rather
    # than on a notebook-level global, so a CPU-resident model also works.
    # Models exposing no parameters fall back to the CPU.
    first_param = next(model.parameters(), None) if hasattr(model, 'parameters') else None
    device = first_param.device if first_param is not None else torch.device('cpu')

    # Initialize reconstructed image
    rec_image = np.zeros(original_shape)
    count = np.zeros(original_shape[:2])  # To handle overlapping regions

    padded_H = original_shape[0] + padding[0]
    padded_W = original_shape[1] + padding[1]
    step = patch_size - overlap

    with torch.no_grad():
        for i in range(0, padded_H - patch_size + 1, step):
            # Row coordinates depend only on i; must stay identical to the
            # formula used when the training patches were created.
            x_coords = np.linspace(-1 + (2*i)/padded_H, -1 + (2*(i+patch_size))/padded_H, patch_size)
            end_i = min(i+patch_size, original_shape[0])

            for j in range(0, padded_W - patch_size + 1, step):
                y_coords = np.linspace(-1 + (2*j)/padded_W, -1 + (2*(j+patch_size))/padded_W, patch_size)

                grid_x, grid_y = np.meshgrid(x_coords, y_coords, indexing='ij')
                coords = np.stack([grid_x, grid_y], axis=-1).reshape(-1, 2)
                coords = torch.tensor(coords, dtype=torch.float32, device=device)

                # Predict patch values
                pred = model(coords).cpu().numpy().reshape(patch_size, patch_size, -1)

                # Crop away the part of the patch that lies in the padding
                end_j = min(j+patch_size, original_shape[1])

                rec_image[i:end_i, j:end_j] += pred[:end_i-i, :end_j-j]
                count[i:end_i, j:end_j] += 1

    # Average overlapping regions
    rec_image /= count[..., np.newaxis]
    return rec_image

# Reconstruct the (still normalized) cube with the trained model
print("\nReconstructing with original model...")
reconstructed_original = reconstruct_image(
    model,
    original_shape=cube_norm.shape,
    patch_size=PATCH_SIZE,
    overlap=OVERLAP,
    padding=padding,
)

# print("Reconstructing with quantized model...")
# reconstructed_quantized = reconstruct_image(quantized_model, cube_norm.shape, PATCH_SIZE, OVERLAP, padding)

In [None]:
# For MS-SSIM
# NOTE(review): the name `ms_ssim` is bound to skimage's `structural_similarity`,
# i.e. single-scale SSIM, so values computed through this alias are SSIM even
# though they are labelled MS-SSIM downstream.
try:
    from skimage.metrics import structural_similarity as ms_ssim
except ImportError:
    # Fallback for older versions of skimage
    # NOTE(review): presumably `compare_msssim` exists in some older release -
    # verify; if this import fails too, `ms_ssim` is set to None and
    # calculate_ms_ssim raises ImportError when called.
    try:
        from skimage.measure import compare_msssim as ms_ssim
    except ImportError:
        ms_ssim = None

# === Image Quality Metrics ===
def calculate_psnr(original, reconstructed, data_range=1.0):
    """Return the mean of the per-band PSNR values.

    Bands are taken along the last axis of ``original``; ``data_range`` is
    forwarded unchanged to the per-band ``psnr`` call.
    """
    band_count = original.shape[2]
    per_band = [
        psnr(original[..., band], reconstructed[..., band], data_range=data_range)
        for band in range(band_count)
    ]
    return np.mean(per_band)


def calculate_ms_ssim(original, reconstructed, data_range=1.0):
    """Return the mean over bands of the similarity score given by ``ms_ssim``.

    Each band (last axis) is scored as a 2-D image with a 7-pixel window.

    NOTE(review): in this notebook ``ms_ssim`` is an alias for skimage's
    ``structural_similarity``, so the result is single-scale SSIM.

    Raises:
        ImportError: If no ``ms_ssim`` implementation could be imported.
    """
    if ms_ssim is None:
        raise ImportError("MS-SSIM not available in your skimage version. Requires scikit-image >= 0.19")

    per_band = [
        ms_ssim(
            original[..., band],
            reconstructed[..., band],
            data_range=data_range,
            channel_axis=None,
            win_size=7,
        )
        for band in range(original.shape[2])
    ]
    return np.mean(per_band)

def calculate_sam(original, reconstructed):
    """Calculate the mean Spectral Angle Mapper (SAM) in degrees.

    Both inputs are (rows, cols, bands) arrays. For every pixel the angle
    between the original and reconstructed spectra is computed; pixels where
    either spectrum has zero magnitude are assigned a cosine of 0 (90 degrees).

    Args:
        original: Reference cube.
        reconstructed: Reconstructed cube with the same shape.

    Returns:
        Mean spectral angle over all pixels, in degrees.
    """
    # Work in float64: with integer inputs (e.g. uint16) the squares below
    # would silently wrap around, and the cosine buffer would be an integer
    # array that truncates every value.
    original = np.asarray(original, dtype=np.float64)
    reconstructed = np.asarray(reconstructed, dtype=np.float64)

    # Flatten spatial dimensions
    orig_flat = original.reshape(-1, original.shape[2])
    rec_flat = reconstructed.reshape(-1, reconstructed.shape[2])

    # Calculate dot product and magnitudes
    dot_product = np.sum(orig_flat * rec_flat, axis=1)
    orig_mag = np.sqrt(np.sum(orig_flat**2, axis=1))
    rec_mag = np.sqrt(np.sum(rec_flat**2, axis=1))

    # Avoid division by zero
    denom = orig_mag * rec_mag
    mask = denom > 0
    cos_theta = np.zeros_like(dot_product)
    cos_theta[mask] = dot_product[mask] / denom[mask]

    # Clamp to avoid numerical errors
    cos_theta = np.clip(cos_theta, -1.0 + 1e-10, 1.0 - 1e-10)
    angles = np.arccos(cos_theta)

    # Convert to degrees and return mean
    return np.mean(np.rad2deg(angles))

In [None]:
# Denormalize
# NOTE(review): this overwrites its own input, so re-running the cell without
# re-running the reconstruction cell applies the de-normalization twice.
reconstructed_original = reconstructed_original * std + mean
# reconstructed_quantized = reconstructed_quantized * std + mean

# Dynamic range of the reference cube, shared by every range-dependent metric.
# Converted to Python floats so the range cannot wrap around should `cube`
# have a small integer dtype.
data_range = float(cube.max()) - float(cube.min())

# === Calculate Metrics ===
print("\n=== Quality Metrics for Original Model ===")
psnr_val = calculate_psnr(cube, reconstructed_original, data_range=data_range)
sam_val = calculate_sam(cube, reconstructed_original)
# Pass the real data range: the default of 1.0 does not match de-normalized
# data and would make the score inconsistent with the PSNR above.
ms_ssim_val = calculate_ms_ssim(cube, reconstructed_original, data_range=data_range)
print(f"PSNR: {psnr_val:.2f} dB")
print(f"MS_SSIM: {ms_ssim_val:.4f}")
print(f"SAM: {sam_val:.4f} degrees")

# print("\n=== Quality Metrics for Quantized Model ===")
# psnr_val_q = calculate_psnr(cube, reconstructed_quantized, data_range=cube.max()-cube.min())
# ssim_val_q = calculate_ssim(cube, reconstructed_quantized, data_range=cube.max()-cube.min())
# sam_val_q = calculate_sam(cube, reconstructed_quantized)
# print(f"PSNR: {psnr_val_q:.2f} dB")
# print(f"SSIM: {ssim_val_q:.4f}")
# print(f"SAM: {sam_val_q:.2f} degrees")