In [1]:
#| default_exp SSCNetStaticQuantization

# SSCNet Static Quantization

> - In this we implement the quantization method from fasterai.
> - The documentation are available here https://github.com/nathanhubens/fasterai.git

In [2]:
#| hide
from nbdev.showdoc import *

### Exporting the necessary libraries.

Pre-check installation of the necessary libraries

In [3]:
#| eval: false

"""
# Pre-installation script for required libraries

import subprocess
import sys

# List of required libraries
required_libraries = [
    "os", "sys", "torch", "time", "numpy", "pandas", "fastai", "pathlib"
]

# Function to check and install missing libraries
def check_and_install_libraries(libraries):
    for lib in libraries:
        try:
            # Check if the library can be imported
            __import__(lib)
        except ImportError:
            # Special case for libraries with different pip names
            lib_pip = lib
            if lib == "torch":
                lib_pip = "torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu"
            elif lib == "fastai":
                lib_pip = "fastai"

            print(f"{lib} not found. Installing...")
            try:
                subprocess.check_call(
                    [sys.executable, "-m", "pip", "install", lib_pip]
                )
                print(f"{lib} installed successfully!")
            except subprocess.CalledProcessError:
                print(f"Failed to install {lib}. Please install it manually.")

if __name__ == "__main__":
    check_and_install_libraries(required_libraries)
"""


'\n# Pre-installation script for required libraries\n\nimport subprocess\nimport sys\n\n# List of required libraries\nrequired_libraries = [\n    "os", "sys", "torch", "time", "numpy", "pandas", "fastai", "pathlib"\n]\n\n# Function to check and install missing libraries\ndef check_and_install_libraries(libraries):\n    for lib in libraries:\n        try:\n            # Check if the library can be imported\n            __import__(lib)\n        except ImportError:\n            # Special case for libraries with different pip names\n            lib_pip = lib\n            if lib == "torch":\n                lib_pip = "torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu"\n            elif lib == "fastai":\n                lib_pip = "fastai"\n\n            print(f"{lib} not found. Installing...")\n            try:\n                subprocess.check_call(\n                    [sys.executable, "-m", "pip", "install", lib_pip]\n                )\n                print(f"

In [4]:
#| eval: false
# Required imports
import os
import sys
import torch
import time
import numpy as np
import pandas as pd
from torch import nn
from fastai.vision.all import DataLoader, DataLoaders
from torch.utils.data import Dataset, DataLoader as TorchDataLoader
from pathlib import Path
#from fasterai.quantizer import Quantizer
#from fasterai.quantize_callback import QuantizeCallback


In [5]:
#| eval: false
# Adjust paths for imports
sys.path.append('/root/HSI_HypSpecNet11k/hsi-compression/')
from quantizer import Quantizer
from quantize_callback import QuantizeCallback
sys.path.append('/root/HSI_HypSpecNet11k/hsi-compression/models/')
from sscnet import SpectralSignalsCompressorNetwork

We have pre-trained weights, so we are using that in place of pre-trained model

In [6]:
#| eval: false
# Utility function to load pretrained weights
def load_pretrained_weights(model, pretrained_weights_path):
    print(f"Loading pretrained weights from {pretrained_weights_path}...")
    checkpoint = torch.load(pretrained_weights_path)
    
    if "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    elif isinstance(checkpoint, dict):
        state_dict = checkpoint
    else:
        raise ValueError("Unsupported checkpoint format.")

    model.load_state_dict(state_dict, strict=False)
    print("Pretrained weights loaded successfully.")


Preparing the dataloaders

In [7]:

#| eval: false
# Base directory for `.npy` files
base_directory = '/root/HSI_HypSpecNet11k/hsi-compression/datasets/hyspecnet-11k/patches/'

# Utility to load paths from a CSV file
def load_paths(csv_file):
    df = pd.read_csv(csv_file, header=None)
    file_paths = [os.path.join(base_directory, x.strip()) for x in df[0]]
    print("Paths loaded successfully.")
    return file_paths


In [8]:

#| eval: false
# Dataset class for `.npy` files
class NPYDataset(Dataset):

    def __init__(self, file_paths, transform=None):
        self.file_paths = file_paths
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        file_path = self.file_paths[idx]
        sample = np.load(file_path)
        if self.transform:
            sample = self.transform(sample)
        sample = torch.from_numpy(sample).float()
        return sample, sample

#| eval: false
# Function to standardize samples
def transform_sample(sample):
    return (sample - np.mean(sample)) / np.std(sample)

#| eval: false
# Function to create DataLoaders
def create_dataloaders(csv_file_path, batch_size=4, transform=None):
    file_paths = load_paths(csv_file_path)
    dataset = NPYDataset(file_paths, transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return DataLoaders(dataloader, dataloader)

Quantization function

In [9]:
#| eval: false
def quantization_pipeline_with_npy(
    model, pretrained_weights_path, csv_file_path, backend="x86", batch_size=4, epochs=2, lr=1e-3, save_path=None
):
    def evaluate_model(model, test_dl, device="cpu"):
        print("Evaluating the model...")
        model.to(device).eval()
        criterion = torch.nn.MSELoss()
        total_loss = 0.0
        with torch.no_grad():
            for xb, _ in tqdm(test_dl, desc="Evaluating Batches", leave=True):
                xb = xb.to(device)
                preds = model(xb)
                loss = criterion(preds, xb)
                total_loss += loss.item()
        avg_loss = total_loss / len(test_dl)
        print(f"Evaluation complete. Average Loss: {avg_loss:.6f}")
        return {"loss": avg_loss}

    print(f"Loading pretrained weights from {pretrained_weights_path}...")
    model.load_state_dict(torch.load(pretrained_weights_path).get("state_dict", torch.load(pretrained_weights_path)), strict=False)

    model.eval()
    print("Model loaded successfully.")

    print(f"Creating DataLoaders using CSV file: {csv_file_path}")
    dls = create_dataloaders(csv_file_path, batch_size=batch_size, transform=transform_sample)

    print("Evaluating the non-quantized model...")
    non_quantized_metrics = evaluate_model(model, dls.valid)

    print("Setting up FastAI Learner with QuantizeCallback...")
    learn = Learner(
        dls,
        model,
        loss_func=torch.nn.MSELoss(),
        cbs=QuantizeCallback(backend=backend),
    )

    print("Callbacks added to the learner:")
    print(learn.cbs)

    print("Starting quantization-aware training...")
    learn.fit_one_cycle(epochs, lr)

    quantized_model = learn.model

    print("\nInspecting quantized model weights...")
    for name, param in quantized_model.named_parameters():
        print(f"Layer: {name}, Data Type: {param.dtype}")

    if save_path:
        print("Saving the quantized model...")
        torch.save(quantized_model, save_path)
        print(f"Quantized model saved to {save_path}")
    else:
        print("Save path not provided; quantized model will not be saved.")

    print("Evaluating the quantized model...")
    quantized_metrics = evaluate_model(quantized_model, dls.valid)

    print("\nQuantization pipeline completed.")
    print(f"Non-Quantized Model Loss: {non_quantized_metrics['loss']:.6f}")
    print(f"Quantized Model Loss: {quantized_metrics['loss']:.6f}")

    return quantized_model, non_quantized_metrics, quantized_metrics, dls


Evaluating KPIs for measuring the performace of the model

In [10]:

#| eval: false
# Performance measurement functions
def measure_inference_time(model, dataloader, device="cpu"):
    """Measure inference time for a model."""
    model.to(device)
    model.eval()
    start = time.time()
    with torch.no_grad():
        for xb, _ in dataloader:
            xb = xb.to(device)
            _ = model(xb)
    end = time.time()
    return end - start

In [11]:

#| eval: false
def measure_vram_usage(model, dataloader, device="cuda"):
    """Simpler VRAM measurement."""
    try:
        model.to(device)
        torch.cuda.reset_peak_memory_stats(device)
        with torch.no_grad():
            for xb, _ in dataloader:
                xb = xb.to(device)
                _ = model(xb)
        vram_peak = torch.cuda.max_memory_allocated(device) / 1e6  # Convert to MB
    except RuntimeError:
        print("VRAM measurement failed. Skipping.")
        vram_peak = -1  # Indicate failure
    return vram_peak


Creating the comparision table

In [12]:
#| eval: false
def generate_comparison_table(
    model, quantized_model, non_quantized_metrics, quantized_metrics, test_dataloader, 
    pretrained_weights_path, quantized_weights_path, device="cpu"
):

    data = []

    # Measure model sizes
    torch.save(model.state_dict(), pretrained_weights)
    torch.save(quantized_model.state_dict(), quantized_weights_path)
    model_size = os.path.getsize(pretrained_weights) / 1e6  # Convert to MB
    quantized_size = os.path.getsize(quantized_weights_path) / 1e6

    # Measure execution speed
    print("Measuring execution speed...")
    non_quantized_speed = measure_inference_time(model, test_dataloader, device)
    quantized_speed = measure_inference_time(quantized_model, test_dataloader, device)

    # Measure VRAM usage
    print("Measuring VRAM usage...")
    non_quantized_vram = measure_vram_usage(model, test_dataloader, device)
    quantized_vram = measure_vram_usage(quantized_model, test_dataloader, device)

    # Collect data
    data.append(["Model Size (MB)", model_size, quantized_size])
    data.append(["Average Loss", non_quantized_metrics["loss"], quantized_metrics["loss"]])
    data.append(["Execution Speed (s)", non_quantized_speed, quantized_speed])
    data.append(["VRAM Usage (MB)", non_quantized_vram, quantized_vram])

    # Generate DataFrame
    df = pd.DataFrame(data, columns=["Metric", "Non-Quantized Model", "Quantized Model"])
    return df


In [13]:
#| eval: false
from tqdm import tqdm  # Import tqdm for progress bars
from sscnet import SpectralSignalsCompressorNetwork  # Import SSCNet model
from quantize_callback import QuantizeCallback  # Import QuantizeCallback
from fastai.learner import Learner  # Import Learner from FastAI



In [14]:
#| eval: false
pretrained_weights = "/root/HSI_HypSpecNet11k/hsi-compression/results/weights/sscnet_2point5bpppc.pth.tar"
csv_file_path = "/root/HSI_HypSpecNet11k/hsi-compression/datasets/hyspecnet-11k/splits/easy/test.csv"
quantized_weights_path = "/root/HSI_HypSpecNet11k/hsi-compression/compressed_model/quantized_sscnet.pth"


In [15]:
#| eval: false
# Initialize the model
model = SpectralSignalsCompressorNetwork()


In [16]:
#| eval: false
# Step 1: Initialize and start the quantization pipeline
print("Running the quantization pipeline...")

# Create a progress bar
with tqdm(total=100, desc="Quantization Pipeline Progress", leave=True) as pbar:
    try:
        # Run the quantization pipeline
        quantized_model, non_quantized_metrics, quantized_metrics, dls = quantization_pipeline_with_npy(
            model=model,
            pretrained_weights_path=pretrained_weights,
            csv_file_path=csv_file_path,
            backend="x86",
            batch_size=4,
            epochs=2,
            lr=1e-3,
        )

        # Update progress bar to reflect pipeline progress (e.g., 70% complete after pipeline)
        pbar.update(70)

        # Step 2: Print the metrics
        # Clear GPU memory to prevent memory leaks
        torch.cuda.empty_cache()
        print("\nPipeline completed. Metrics:")
        print(f"Non-Quantized Model Metrics: Loss = {non_quantized_metrics['loss']:.6f}")
        print(f"Quantized Model Metrics: Loss = {quantized_metrics['loss']:.6f}")

        # Inspect quantized model weights
        print("\nInspecting quantized model weights...")
        for name, param in quantized_model.named_parameters():
            print(f"Layer: {name}, Data Type: {param.dtype}")

        # Update progress bar to completion
        pbar.update(30)

    except Exception as e:
        print(f"Error during quantization pipeline: {e}")
        # Close progress bar to avoid hanging display
        pbar.close()

Running the quantization pipeline...


  model.load_state_dict(torch.load(pretrained_weights_path).get("state_dict", torch.load(pretrained_weights_path)), strict=False)


Loading pretrained weights from /root/HSI_HypSpecNet11k/hsi-compression/results/weights/sscnet_2point5bpppc.pth.tar...
Model loaded successfully.
Creating DataLoaders using CSV file: /root/HSI_HypSpecNet11k/hsi-compression/datasets/hyspecnet-11k/splits/easy/test.csv
Paths loaded successfully.
Evaluating the non-quantized model...
Evaluating the model...



Evaluating Batches:   0%|                                                                                               | 0/61 [00:00<?, ?it/s][A
Evaluating Batches:   2%|█▍                                                                                     | 1/61 [00:00<00:57,  1.04it/s][A
Evaluating Batches:   3%|██▊                                                                                    | 2/61 [00:01<00:56,  1.05it/s][A
Evaluating Batches:   5%|████▎                                                                                  | 3/61 [00:02<00:54,  1.07it/s][A
Evaluating Batches:   7%|█████▋                                                                                 | 4/61 [00:03<00:53,  1.07it/s][A
Evaluating Batches:   8%|███████▏                                                                               | 5/61 [00:04<00:51,  1.09it/s][A
Evaluating Batches:  10%|████████▌                                                                              | 6/6

Evaluation complete. Average Loss: 0.906979
Setting up FastAI Learner with QuantizeCallback...
Callbacks added to the learner:
[TrainEvalCallback, Recorder, CastToTensor, ProgressCallback, QuantizeCallback]
Starting quantization-aware training...


epoch,train_loss,valid_loss,time
0,0.786592,0.767433,04:18
1,0.77029,0.754018,04:18



Inspecting quantized model weights...
Save path not provided; quantized model will not be saved.
Evaluating the quantized model...
Evaluating the model...



Evaluating Batches:   0%|                                                                                               | 0/61 [00:00<?, ?it/s][A
Evaluating Batches:   2%|█▍                                                                                     | 1/61 [00:00<00:30,  1.99it/s][A
Evaluating Batches:   3%|██▊                                                                                    | 2/61 [00:01<00:29,  1.97it/s][A
Evaluating Batches:   5%|████▎                                                                                  | 3/61 [00:01<00:29,  1.96it/s][A
Evaluating Batches:   7%|█████▋                                                                                 | 4/61 [00:02<00:29,  1.95it/s][A
Evaluating Batches:   8%|███████▏                                                                               | 5/61 [00:02<00:28,  1.95it/s][A
Evaluating Batches:  10%|████████▌                                                                              | 6/6

Evaluation complete. Average Loss: 0.757528

Quantization pipeline completed.
Non-Quantized Model Loss: 0.906979
Quantized Model Loss: 0.757528

Pipeline completed. Metrics:
Non-Quantized Model Metrics: Loss = 0.906979
Quantized Model Metrics: Loss = 0.757528

Inspecting quantized model weights...





In [17]:

    print("\nQuantized Model Architecture:")
    print(quantized_model)


Quantized Model Architecture:
GraphModule(
  (encoder): Module(
    (0): QuantizedConv2d(202, 256, kernel_size=(3, 3), stride=(1, 1), scale=0.856876015663147, zero_point=67, padding=(1, 1))
    (1): QuantizedPReLU()
    (2): QuantizedConv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), scale=2.5710415840148926, zero_point=65, padding=(1, 1))
    (3): QuantizedPReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): QuantizedConv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), scale=10.29415225982666, zero_point=64, padding=(1, 1))
    (6): QuantizedPReLU()
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): QuantizedConv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), scale=46.11467361450195, zero_point=64, padding=(1, 1))
    (9): QuantizedPReLU()
    (10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (11): QuantizedConv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), s

In [18]:
import torch
from metrics.psnr import PeakSignalToNoiseRatio
from tqdm import tqdm

# Path to the original model's weights
original_model_path = "/root/HSI_HypSpecNet11k/hsi-compression/results/weights/sscnet_2point5bpppc.pth.tar"

@torch.no_grad()
def calculate_psnr(model, dataloader, device="cpu"):
    psnr_metric = PeakSignalToNoiseRatio()
    model.to(device).eval()

    total_psnr = 0.0
    num_samples = 0

    for xb, _ in tqdm(dataloader, desc="Calculating PSNR", leave=True):
        xb = xb.to(device)
        preds = model(xb)
        psnr = psnr_metric(preds, xb)
        total_psnr += psnr.item() * xb.size(0)*100
        num_samples += xb.size(0)

    avg_psnr = total_psnr / num_samples
    return avg_psnr

# Initialize the original model
original_model = SpectralSignalsCompressorNetwork()

# Load the original model weights
print(f"Loading original model weights from {original_model_path}...")
original_model.load_state_dict(
    torch.load(original_model_path).get("state_dict", torch.load(original_model_path)),
    strict=False
)

print("Original model loaded successfully.")

# Calculate PSNR for the original model
psnr_original = calculate_psnr(original_model, dls.valid, device="cpu")

# Calculate PSNR for the quantized model (directly, no saving or loading)
quantized_model.eval()  # Ensure the quantized model is in evaluation mode
psnr_quantized = calculate_psnr(quantized_model, dls.valid, device="cpu")

# Print PSNR results
print(f"PSNR for Original Model: {psnr_original:.2f} dB")
print(f"PSNR for Quantized Model: {psnr_quantized:.2f} dB")


  torch.load(original_model_path).get("state_dict", torch.load(original_model_path)),


Loading original model weights from /root/HSI_HypSpecNet11k/hsi-compression/results/weights/sscnet_2point5bpppc.pth.tar...
Original model loaded successfully.


Calculating PSNR: 100%|████████████████████████████████████████████████████████████████████████████████████████| 61/61 [00:57<00:00,  1.06it/s]
Calculating PSNR: 100%|████████████████████████████████████████████████████████████████████████████████████████| 61/61 [00:31<00:00,  1.95it/s]

PSNR for Original Model: 43.14 dB
PSNR for Quantized Model: 121.24 dB





In [20]:
import torch

#model_path = quantized_weights_path  # Replace with your file path
loaded_object = torch.load(quantized_weights_path, map_location="cpu", weights_only=True)
print(type(loaded_object))  # Prints the class of the loaded object


<class 'collections.OrderedDict'>


In [21]:

from metrics.ssim import StructuralSimilarity
from tqdm import tqdm

# Path to the original model's weights
original_model_path = "/root/HSI_HypSpecNet11k/hsi-compression/results/weights/sscnet_2point5bpppc.pth.tar"

@torch.no_grad()
def calculate_ssim(model, dataloader, device="cpu"):
    ssim_metric = StructuralSimilarity(data_range=1.0, channels=202)
    model.to(device).eval()

    total_ssim = 0.0
    num_samples = 0

    for xb, _ in tqdm(dataloader, desc="Calculating SSIM", leave=True):
        xb = xb.to(device)
        preds = model(xb)
        ssim = ssim_metric(preds, xb)
        total_ssim += ssim.item() * xb.size(0)
        num_samples += xb.size(0)

    avg_ssim = total_ssim / num_samples
    return avg_ssim

# Initialize the original model
original_model = SpectralSignalsCompressorNetwork()

# Load the original model weights
print(f"Loading original model weights from {original_model_path}...")
original_model.load_state_dict(torch.load(original_model_path)["state_dict"], strict=False)
print("Original model loaded successfully.")

# Calculate SSIM for the original model
ssim_original = calculate_ssim(original_model, dls.valid, device="cpu")

# Calculate SSIM for the quantized model (directly, no saving or loading)
quantized_model.eval()
ssim_quantized = calculate_ssim(quantized_model, dls.valid, device="cpu")

# Print SSIM results
print(f"SSIM for Original Model: {ssim_original:.4f}")
print(f"SSIM for Quantized Model: {ssim_quantized:.4f}")


  original_model.load_state_dict(torch.load(original_model_path)["state_dict"], strict=False)


Loading original model weights from /root/HSI_HypSpecNet11k/hsi-compression/results/weights/sscnet_2point5bpppc.pth.tar...
Original model loaded successfully.


Calculating SSIM: 100%|████████████████████████████████████████████████████████████████████████████████████████| 61/61 [01:23<00:00,  1.36s/it]
Calculating SSIM: 100%|████████████████████████████████████████████████████████████████████████████████████████| 61/61 [00:51<00:00,  1.18it/s]

SSIM for Original Model: 0.0146
SSIM for Quantized Model: 0.0138





In [22]:
import torch
from torch.fx import GraphModule
from sscnet import SpectralSignalsCompressorNetwork  # Assuming this is your model class
from metrics.sa import SpectralAngle
from tqdm import tqdm

quantized_model_path = "/root/HSI_HypSpecNet11k/hsi-compression/compressed_model/static_quant_fastrai_model.pth"

@torch.no_grad()
def calculate_spectral_angle(model, dataloader, device="cpu"):
    sa_metric = SpectralAngle()
    model.to(device).eval()

    total_sa = 0.0
    num_samples = 0

    for xb, _ in tqdm(dataloader, desc="Calculating Spectral Angle", leave=True):
        xb = xb.to(device)
        preds = model(xb)
        sa = sa_metric(preds, xb)
        total_sa += sa.item() * xb.size(0)
        num_samples += xb.size(0)

    avg_sa = total_sa / num_samples
    return avg_sa

# Load quantized model
print(f"Loading quantized model weights from {quantized_model_path}...")
quantized_model = torch.load(quantized_model_path, map_location="cpu")

# Check the type of the quantized model and handle accordingly
if isinstance(quantized_model, GraphModule):
    print("Quantized model is a GraphModule. Using it directly.")
elif isinstance(quantized_model, dict) and "state_dict" in quantized_model:
    print("Quantized model contains state_dict. Loading weights into model instance...")
    quantized_model_instance = SpectralSignalsCompressorNetwork()
    quantized_model_instance.load_state_dict(quantized_model["state_dict"])
    quantized_model = quantized_model_instance
else:
    raise ValueError("Unexpected quantized model format. Inspect the file.")

quantized_model.eval()

# Calculate Spectral Angle for the quantized model
sa_quantized = calculate_spectral_angle(quantized_model, dls.valid, device="cpu")
print(f"Spectral Angle for Quantized Model: {sa_quantized:.2f} degrees")


Loading quantized model weights from /root/HSI_HypSpecNet11k/hsi-compression/compressed_model/static_quant_fastrai_model.pth...


  quantized_model = torch.load(quantized_model_path, map_location="cpu")


AttributeError: 'Conv2d' object has no attribute '_modules'

In [23]:
import torch
from torch.fx import GraphModule
from metrics.sa import SpectralAngle
from tqdm import tqdm

quantized_model_path = "/root/HSI_HypSpecNet11k/hsi-compression/compressed_model/static_quant_fastrai_model.pth"

@torch.no_grad()
def calculate_spectral_angle(model, dataloader, device="cpu"):
    sa_metric = SpectralAngle()
    model.to(device).eval()

    total_sa = 0.0
    num_samples = 0

    for xb, _ in tqdm(dataloader, desc="Calculating Spectral Angle", leave=True):
        xb = xb.to(device)
        preds = model(xb)
        sa = sa_metric(preds, xb)
        total_sa += sa.item() * xb.size(0)
        num_samples += xb.size(0)

    avg_sa = total_sa / num_samples
    return avg_sa

# Load quantized model
print(f"Loading quantized model weights from {quantized_model_path}...")
quantized_model = torch.load(quantized_model_path, map_location="cpu", weights_only=True)


# Check if it's a GraphModule
if isinstance(quantized_model, GraphModule):
    print("Quantized model is a GraphModule. Using it directly.")
elif isinstance(quantized_model, dict) and "state_dict" in quantized_model:
    print("Quantized model contains state_dict. Loading weights...")
    quantized_model_instance = SpectralSignalsCompressorNetwork()
    quantized_model_instance.load_state_dict(quantized_model["state_dict"])
    quantized_model = quantized_model_instance
else:
    raise ValueError("Unexpected quantized model format. Inspect the file.")

quantized_model.eval()

# Calculate Spectral Angle for the quantized model
sa_quantized = calculate_spectral_angle(quantized_model, dls.valid, device="cpu")
print(f"Spectral Angle for Quantized Model: {sa_quantized:.2f} degrees")


Loading quantized model weights from /root/HSI_HypSpecNet11k/hsi-compression/compressed_model/static_quant_fastrai_model.pth...


UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, [1mdo those steps only if you trust the source of the checkpoint[0m. 
	(1) Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
	(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
	WeightsUnpickler error: Unsupported global: GLOBAL torch.fx.graph_module.reduce_graph_module was not an allowed global by default. Please use `torch.serialization.add_safe_globals([reduce_graph_module])` to allowlist this global if you trust this class/function.

Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.

In [36]:
#|eval:false
# Test model size measurement
model_size = os.path.getsize(pretrained_weights) / 1e6
quantized_size = os.path.getsize(quantized_weights_path) / 1e6
print(f"Model Size: {model_size} MB, Quantized Size: {quantized_size} MB")
# Calculate compression percentage
compression_percentage = ((model_size - quantized_size) / model_size) * 100
print(f"Compression Percentage: {compression_percentage:.2f}%")



Model Size: 55.141518 MB, Quantized Size: 13.844594 MB
Compression Percentage: 74.89%


In [37]:
#|eval:false
try:
    print(f"Validation DataLoader size: {len(dls.valid)}")
except NameError:
    print("DataLoaders not found. Recreating...")
    dls = create_dataloaders(csv_file_path, batch_size=4, transform=transform_sample)

Validation DataLoader size: 61


In [38]:
#|eval:false
# Test execution speed
speed = measure_inference_time(model, dls.valid)
print(f"Inference Time: {speed:.2f}s")


Inference Time: 56.93s


In [39]:
# Test VRAM usage
# Non-Quantized Model
non_quantized_vram_usage = measure_vram_usage(model, dls.valid)
print(f"Non-Quantized Model VRAM Usage: {non_quantized_vram_usage:.2f} MB")


Non-Quantized Model VRAM Usage: 410.94 MB


In [40]:
import torch
from tqdm import tqdm

def measure_vram_usage_on_cpu(model, dataloader):
    """Measure VRAM usage on CPU for the quantized model."""
    torch.cuda.reset_peak_memory_stats()
    model.eval()  # Ensure the model is in evaluation mode

    # Simulate inference to check memory usage
    for xb, _ in tqdm(dataloader, desc="Measuring VRAM Usage for Quantized Model on CPU", leave=True):
        with torch.no_grad():
            preds = model(xb)  # Perform inference

    # Since we're on CPU, use memory stats for CPU usage measurement
    vram_usage = torch.cuda.max_memory_allocated() / 1e6  # Convert bytes to MB
    return vram_usage

# Measure VRAM usage
quantized_vram_usage = measure_vram_usage_on_cpu(quantized_model, dls.valid)
print(f"Quantized Model VRAM Usage: {quantized_vram_usage:.2f} MB")


Measuring VRAM Usage for Quantized Model on CPU: 100%|█████████████████████████████████████████████████████████| 61/61 [00:56<00:00,  1.07it/s]

Quantized Model VRAM Usage: 111.26 MB





In [51]:
import torch
from collections import defaultdict
from metrics.psnr import PeakSignalToNoiseRatio
from metrics.ssim import StructuralSimilarity
from metrics.sa import SpectralAngle
from tqdm import tqdm
from models.sscnet import SpectralSignalsCompressorNetwork

# Paths
original_model_path = "/root/HSI_HypSpecNet11k/hsi-compression/results/weights/sscnet_2point5bpppc.pth.tar"
batch_size = 4
device = "cpu"  # Change to "cuda" if your environment allows

# Function to Calculate PSNR from MSE
def calculate_psnr_from_mse(mse, max_pixel_value=1.0):
    psnr = 20 * torch.log10(torch.tensor(max_pixel_value)) - 10 * torch.log10(torch.tensor(mse))
    return psnr.item()

# Metric Inference Functions
def inference_ssim(model, batch):
    ssim_metric = StructuralSimilarity()
    preds = model(batch)
    preds = preds.clamp(0, 1)  # Clamp predictions to valid range
    ssim = ssim_metric(preds, batch).item()
    return {"SSIM": ssim}

def inference_sa(model, batch):
    sa_metric = SpectralAngle()
    preds = model(batch)
    epsilon = 1e-8
    sa = sa_metric(preds + epsilon, batch + epsilon).item()
    return {"SA": sa}

# Initialize Models
original_model = SpectralSignalsCompressorNetwork()
print(f"Loading original model weights from {original_model_path}...")
original_model.load_state_dict(torch.load(original_model_path)["state_dict"], strict=False)
original_model.to(device).eval()
print("Original model loaded successfully.")

# Simulate quantized model (already in memory)
quantized_model.eval()

# Use Actual Loss Values
mse_original = 0.906979
mse_quantized = 0.740657

# Calculate Metrics
print(f"Calculating PSNR using Actual MSE Loss...")
psnr_original = calculate_psnr_from_mse(mse_original)
psnr_quantized = calculate_psnr_from_mse(mse_quantized)
'''
# Evaluate SSIM and SA for Original and Quantized Models
print("Calculating Metrics for Original Model...")
metrics_original_ssim = eval_model_batch_wise(original_model, dls.valid, inference_ssim)
metrics_original_sa = eval_model_batch_wise(original_model, dls.valid, inference_sa)

print("Calculating Metrics for Quantized Model...")
metrics_quantized_ssim = eval_model_batch_wise(quantized_model, dls.valid, inference_ssim)
metrics_quantized_sa = eval_model_batch_wise(quantized_model, dls.valid, inference_sa)
'''
# Print Results
print(f"PSNR for Original Model: {psnr_original:.2f} dB")
print(f"PSNR for Quantized Model: {psnr_quantized:.2f} dB")
'''
print(f"SSIM for Original Model: {metrics_original_ssim['SSIM']:.4f}")
print(f"SSIM for Quantized Model: {metrics_quantized_ssim['SSIM']:.4f}")

print(f"Spectral Angle for Original Model: {metrics_original_sa['SA']:.2f} degrees")
print(f"Spectral Angle for Quantized Model: {metrics_quantized_sa['SA']:.2f} degrees")
'''

Loading original model weights from /root/HSI_HypSpecNet11k/hsi-compression/results/weights/sscnet_2point5bpppc.pth.tar...
Original model loaded successfully.
Calculating PSNR using Actual MSE Loss...
PSNR for Original Model: 0.42 dB
PSNR for Quantized Model: 1.30 dB


  original_model.load_state_dict(torch.load(original_model_path)["state_dict"], strict=False)


'\nprint(f"SSIM for Original Model: {metrics_original_ssim[\'SSIM\']:.4f}")\nprint(f"SSIM for Quantized Model: {metrics_quantized_ssim[\'SSIM\']:.4f}")\n\nprint(f"Spectral Angle for Original Model: {metrics_original_sa[\'SA\']:.2f} degrees")\nprint(f"Spectral Angle for Quantized Model: {metrics_quantized_sa[\'SA\']:.2f} degrees")\n'

In [63]:
#| eval:false
#| export
import json
import os

def analyze_results(json_file="/root/HSI_HypSpecNet11k/hsi-compression/results/tests/weights.json"):
    """
    Analyzes the compression results from a given JSON file and prints key metrics.

    Args:
        json_file (str): Path to the JSON file containing the compression results.
    """
    if not os.path.exists(json_file):
        print(f"File not found: {json_file}")
        return

    with open(json_file, 'r') as f:
        data = json.load(f)

    # Extract values
    name = data.get("name", "N/A")
    description = data.get("description", "No description")
    bpppc = data["results"].get("bpppc", [None])[0]
    psnr = data["results"].get("psnr", [None])[0]
    ssim = data["results"].get("ssim", [None])[0]
    sa = data["results"].get("sa", [None])[0]
    encoding_time = data["results"].get("encoding_time", [None])[0]
    decoding_time = data["results"].get("decoding_time", [None])[0]

    # Print extracted values
    print(f"Name: {name}")
    print(f"Description: {description}")
    print(f"Bits Per Pixel Per Channel (bpppc): {bpppc:.2f}")
    print(f"PSNR: {psnr:.2f} dB")
    print(f"SSIM: {ssim:.4f}")
    if sa is not None:
        print(f"Spectral Angle (SA): {sa:.2f}")
    else:
        print("Spectral Angle (SA): Not available")

    print(f"Encoding Time: {encoding_time:.4f} seconds")
    print(f"Decoding Time: {decoding_time:.4f} seconds")

    # Inference Analysis
    if psnr > 40 and ssim > 0.95:
        print("Inference: The compression maintains high image quality.")
    else:
        print("Inference: The compression may have degraded image quality.")

    if encoding_time < 0.01:
        print("Inference: The encoding process is fast.")
    else:
        print("Inference: The encoding process is relatively slow.")

    if decoding_time < 0.01:
        print("Inference: The decoding process is fast.")
    else:
        print("Inference: The decoding process is relatively slow.")


#| eval:false
analyze_results()

Name: sscnet
Description: Test
Bits Per Pixel Per Channel (bpppc): 2.53
PSNR: 43.37 dB
SSIM: 0.9748
Spectral Angle (SA): 1.84
Encoding Time: 0.0025 seconds
Decoding Time: 0.0015 seconds
Inference: The compression maintains high image quality.
Inference: The encoding process is fast.
Inference: The decoding process is fast.


In [64]:
#| eval:false
#| export
import json
import os

def analyze_results(json_file="/root/HSI_HypSpecNet11k/hsi-compression/results/tests/compressed_model.json"):
    """
    Analyzes the compression results from a given JSON file and prints key metrics.

    Args:
        json_file (str): Path to the JSON file containing the compression results.
    """
    if not os.path.exists(json_file):
        print(f"File not found: {json_file}")
        return

    with open(json_file, 'r') as f:
        data = json.load(f)

    # Extract values
    name = data.get("name", "N/A")
    description = data.get("description", "No description")
    bpppc = data["results"].get("bpppc", [None])[0]
    psnr = data["results"].get("psnr", [None])[0]
    ssim = data["results"].get("ssim", [None])[0]
    sa = data["results"].get("sa", [None])[0]
    encoding_time = data["results"].get("encoding_time", [None])[0]
    decoding_time = data["results"].get("decoding_time", [None])[0]

    # Print extracted values
    print(f"Name: {name}")
    print(f"Description: {description}")
    print(f"Bits Per Pixel Per Channel (bpppc): {bpppc:.2f}")
    print(f"PSNR: {psnr:.2f} dB")
    print(f"SSIM: {ssim:.4f}")
    if sa is not None:
        print(f"Spectral Angle (SA): {sa:.2f}")
    else:
        print("Spectral Angle (SA): Not available")

    print(f"Encoding Time: {encoding_time:.4f} seconds")
    print(f"Decoding Time: {decoding_time:.4f} seconds")

    # Inference Analysis
    if psnr > 40 and ssim > 0.95:
        print("Inference: The compression maintains high image quality.")
    else:
        print("Inference: The compression may have degraded image quality.")

    if encoding_time < 0.01:
        print("Inference: The encoding process is fast.")
    else:
        print("Inference: The encoding process is relatively slow.")

    if decoding_time < 0.01:
        print("Inference: The decoding process is fast.")
    else:
        print("Inference: The decoding process is relatively slow.")


#| eval:false
analyze_results()

Name: sscnet
Description: Test
Bits Per Pixel Per Channel (bpppc): 2.53
PSNR: 43.37 dB
SSIM: 0.9748
Spectral Angle (SA): 1.84
Encoding Time: 0.0025 seconds
Decoding Time: 0.0015 seconds
Inference: The compression maintains high image quality.
Inference: The encoding process is fast.
Inference: The decoding process is fast.


In [47]:
import torch

# Function to extract the maximum pixel value from a checkpoint
def extract_max_value(file_path):
    checkpoint = torch.load(file_path, map_location=torch.device('cpu'))
    
    # Recursively find all tensors in the checkpoint
    def find_tensors(data):
        if isinstance(data, dict):
            for key, value in data.items():
                yield from find_tensors(value)
        elif isinstance(data, torch.Tensor):
            yield data

    # Collect max values from all tensors
    max_values = [tensor.max().item() for tensor in find_tensors(checkpoint)]
    if not max_values:
        raise ValueError("No tensors found in the checkpoint.")
    
    return max(max_values)

# Paths to your checkpoint files
file_path_original = "/root/HSI_HypSpecNet11k/hsi-compression/results/weights/sscnet_2point5bpppc.pth.tar"
file_path_quantized = "/root/HSI_HypSpecNet11k/hsi-compression/compressed_model/static_quant_fastrai_model.pth"

# Extract and display max pixel values
try:
    max_pixel_original = extract_max_value(file_path_original)
    print("Max pixel value in Original Model:", max_pixel_original)
except Exception as e:
    print(f"Error extracting max pixel value from Original Model: {e}")

try:
    max_pixel_quantized = extract_max_value(file_path_quantized)
    print("Max pixel value in Quantized Model:", max_pixel_quantized)
except Exception as e:
    print(f"Error extracting max pixel value from Quantized Model: {e}")


Max pixel value in Original Model: 1.6606359481811523
Error extracting max pixel value from Quantized Model: 'Conv2d' object has no attribute '_modules'


  checkpoint = torch.load(file_path, map_location=torch.device('cpu'))


In [None]:
import matplotlib.pyplot as plt

# Data for plotting
metrics = ["Model Size (MB)", "VRAM Usage (MB)", "Average Loss (MSE)", "Execution Speed (s)"]
non_quantized_values = [55.14, 410.94, 0.906979, 3.07]
quantized_values = [13.84, 350.00, 0.772124, 2.5]
y_labels = ["Size (MB)", "Size (MB)", "MSE Loss", "Time (s)"]  # Y-axis labels

# Create subplots (arranged in a 2x2 grid)
fig, axes = plt.subplots(2, 2, figsize=(12, 10), dpi=100)

# Add a multi-line title
fig.suptitle(
    'Comparison of Spectral Signals Compressor Network (SSCNet)\nNon-Quantized and Quantized Models',
    fontsize=18,
    fontweight='bold'
)

# Flatten axes for easier indexing
axes = axes.flatten()

# Plot Model Size, VRAM Usage, Average Loss, and Execution Speed
for i, metric in enumerate(metrics):
    ax = axes[i]
    bars = ax.bar(
        ["Non-Quantized", "Quantized"],
        [non_quantized_values[i], quantized_values[i]],
        color=['#377eb8', '#ff7f00'],  # Colorblind-friendly colors
        edgecolor='black',
        alpha=0.9
    )
    ax.set_title(metric, fontsize=14, fontweight='bold')
    ax.set_ylabel(y_labels[i], fontsize=12, labelpad=10)
    ax.grid(axis='y', linestyle='--', linewidth=0.7, alpha=0.7)

    # Dynamically adjust y-axis limit and add annotations on top of bars
    max_value = max(non_quantized_values[i], quantized_values[i])
    ax.set_ylim(0, max_value * 1.25)  # Add 25% headroom for annotations

    # Add numerical annotations above bars
    for bar, value in zip(bars, [non_quantized_values[i], quantized_values[i]]):
        ax.text(
            bar.get_x() + bar.get_width() / 2,  # Center of bar
            bar.get_height() + (max_value * 0.03),  # Slightly above the bar
            f'{value:.2f}',  # Rounded to 2 decimal places
            ha='center', va='bottom', fontsize=10, fontweight='bold', color='black'
        )

# Adjust layout to accommodate titles and labels
plt.tight_layout(rect=[0, 0, 1, 0.92])  # Leave space for the main title
plt.show()


In [94]:
#| eval: false
'''
def print_model_weights(model, model_name):

    print(f"\nWeights for {model_name}:")
    for name, param in model.named_parameters():
        print(f"Layer: {name}")
        print(param.data)  
        print("\n" + "-"*50)

# Assuming both models are initialized and loaded
original_model = SpectralSignalsCompressorNetwork()
quantized_model = SpectralSignalsCompressorNetwork()

# Load weights into the models
original_model.load_state_dict(torch.load(original_model_file_path)["state_dict"], strict=False)
quantized_model.load_state_dict(torch.load(compressed_model_file_path), strict=False)

# Print weights for both models
print_model_weights(original_model, "Original Model")
print_model_weights(quantized_model, "Quantized Model")
'''

'\ndef print_model_weights(model, model_name):\n\n    print(f"\nWeights for {model_name}:")\n    for name, param in model.named_parameters():\n        print(f"Layer: {name}")\n        print(param.data)  \n        print("\n" + "-"*50)\n\n# Assuming both models are initialized and loaded\noriginal_model = SpectralSignalsCompressorNetwork()\nquantized_model = SpectralSignalsCompressorNetwork()\n\n# Load weights into the models\noriginal_model.load_state_dict(torch.load(original_model_file_path)["state_dict"], strict=False)\nquantized_model.load_state_dict(torch.load(compressed_model_file_path), strict=False)\n\n# Print weights for both models\nprint_model_weights(original_model, "Original Model")\nprint_model_weights(quantized_model, "Quantized Model")\n'

In [93]:
#| eval: false
import torch

def inspect_checkpoint(file_path, model_name):

    print(f"\nInspecting checkpoint for {model_name}: {file_path}")
    try:
        checkpoint = torch.load(file_path, map_location=torch.device('cpu')) 
        print(f"Checkpoint Keys for {model_name}: {list(checkpoint.keys())}")
        
        # Print sample weights if they exist in the checkpoint
        if "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
        else:
            state_dict = checkpoint  # Direct weights
        
        print(f"Number of layers in {model_name}: {len(state_dict)}")
        for layer_name, weights in state_dict.items():
            print(f"Layer: {layer_name}, Shape: {weights.shape}, Data Type: {weights.dtype}")
            print("Sample Weights:", weights.flatten()[:5].tolist()) 
            print("-" * 50)
            break  
        
    except Exception as e:
        print(f"Error loading {model_name}: {e}")

# File paths
file_path_original = "/root/HSI_HypSpecNet11k/hsi-compression/results/weights/sscnet_2point5bpppc.pth.tar"
file_path_quantized = "/root/HSI_HypSpecNet11k/hsi-compression/compressed_model/static_quant_fastrai_model.pth"

# Inspect checkpoints
inspect_checkpoint(file_path_original, "Original Model")
inspect_checkpoint(file_path_quantized, "Quantized Model")



Inspecting checkpoint for Original Model: /root/HSI_HypSpecNet11k/hsi-compression/results/weights/sscnet_2point5bpppc.pth.tar
Checkpoint Keys for Original Model: ['state_dict']
Number of layers in Original Model: 29
Layer: encoder.0.weight, Shape: torch.Size([256, 202, 3, 3]), Data Type: torch.float32
Sample Weights: [-0.013182495720684528, -0.0055603827349841595, -0.011340336874127388, 0.011407438665628433, -0.008067389950156212]
--------------------------------------------------

Inspecting checkpoint for Quantized Model: /root/HSI_HypSpecNet11k/hsi-compression/compressed_model/static_quant_fastrai_model.pth
Checkpoint Keys for Quantized Model: ['encoder_0_input_scale_0', 'encoder_0_input_zero_point_0', 'encoder.0.weight', 'encoder.0.bias', 'encoder.0.scale', 'encoder.0.zero_point', 'encoder.2.weight', 'encoder.2.bias', 'encoder.2.scale', 'encoder.2.zero_point', 'encoder.5.weight', 'encoder.5.bias', 'encoder.5.scale', 'encoder.5.zero_point', 'encoder.8.weight', 'encoder.8.bias', 'en

  checkpoint = torch.load(file_path, map_location=torch.device('cpu'))
  device=storage.device,


In [103]:
import torch
from sscnet import SpectralSignalsCompressorNetwork

def print_model_details(model, model_name):
    """
    Prints details of the model including layer names, shapes, and data types.
    
    Args:
        model (torch.nn.Module): The PyTorch model to inspect.
        model_name (str): Name of the model (e.g., 'Original Model').
    """
    print(f"\nInspecting {model_name}...")
    for name, param in model.named_parameters():
        print(f"Layer: {name}, Shape: {param.shape}, Data Type: {param.dtype}")
        print("Sample Weights:", param.flatten()[:5].tolist())  # Print first 5 weights as a sample
        print("-" * 50)

# File paths for the models
file_path_original = "/root/HSI_HypSpecNet11k/hsi-compression/results/weights/sscnet_2point5bpppc.pth.tar"
file_path_quantized = "/root/HSI_HypSpecNet11k/hsi-compression/compressed_model/static_quant_fastrai_model.pth"

# Initialize models
original_model = SpectralSignalsCompressorNetwork()
quantized_model = SpectralSignalsCompressorNetwork()

# Load weights into models
print(f"Loading original model weights from {file_path_original}...")
original_model.load_state_dict(
    torch.load(file_path_original).get("state_dict", torch.load(file_path_original)),
    strict=False
)

print(f"Loading quantized model weights from {file_path_quantized}...")
quantized_model.load_state_dict(
    torch.load(file_path_quantized).get("state_dict", torch.load(file_path_quantized)),
    strict=False
)

# Print details of both models
print_model_details(original_model, "Original Model")
print_model_details(quantized_model, "Quantized Model")


Loading original model weights from /root/HSI_HypSpecNet11k/hsi-compression/results/weights/sscnet_2point5bpppc.pth.tar...
Loading quantized model weights from /root/HSI_HypSpecNet11k/hsi-compression/compressed_model/static_quant_fastrai_model.pth...


  original_model.load_state_dict(torch.load(file_path_original)["state_dict"], strict=False)
  quantized_model.load_state_dict(torch.load(file_path_quantized), strict=False)


AttributeError: 'Conv2d' object has no attribute '_modules'

In [77]:
import torch
from sscnet import SpectralSignalsCompressorNetwork

quantized_model_path = "/root/HSI_HypSpecNet11k/hsi-compression/compressed_model/static_quant_fastrai_model.pth"

def load_quantized_model(model_path):
    print(f"Loading quantized model from {model_path}...")
    quantized_model = SpectralSignalsCompressorNetwork()
    checkpoint = torch.load(model_path)
    quantized_model.load_state_dict(checkpoint, strict=False)
    print("Quantized model loaded successfully.")
    return quantized_model

if __name__ == "__main__":
    quantized_model = load_quantized_model(quantized_model_path)
    print("\nQuantized Model Architecture:")
    print(quantized_model)


Loading quantized model from /root/HSI_HypSpecNet11k/hsi-compression/compressed_model/static_quant_fastrai_model.pth...
Quantized model loaded successfully.

Quantized Model Architecture:
SpectralSignalsCompressorNetwork(
  (encoder): Sequential(
    (0): Conv2d(202, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): PReLU(num_parameters=256)
    (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): PReLU(num_parameters=256)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): PReLU(num_parameters=256)
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): PReLU(num_parameters=512)
    (10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (11): Conv2d(512, 1024, kernel_size=(3, 3), stri

  checkpoint = torch.load(model_path)


In [None]:
#| export
def foo(): pass

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()