# core

> Fill in a module description here

In [1]:
#| default_exp core

In [2]:
#| hide
from nbdev.showdoc import *

In [None]:
#!pip install git+https://github.com/nathanhubens/fasterai.git

In [3]:
import os
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from fastai.data.core import DataLoaders
from fastai.learner import Learner
from fasterai.sparse.sparsify_callback import SparsifyCallback
import torch.nn.functional as F
from torch.nn import MSELoss
from fastai.vision.all import *
from fasterai.sparse.all import *
from fasterai.prune.all import *
import copy

In [4]:
import sys
sys.path.append("/root/HSI_HypSpecNet11k/hsi-compression/models/")

from sscnet import SpectralSignalsCompressorNetwork

def load_model_weights(model, weight_path):
    """
    Load pre-trained weights into the model.

    Parameters:
        model: PyTorch model object
        weight_path: Path to the pre-trained weight file (.pth.tar)

    Returns:
        model: The model loaded with pre-trained weights
    """
    state_dict = torch.load(weight_path, weights_only=True)
    model.load_state_dict(state_dict['state_dict'])
    return model

model = SpectralSignalsCompressorNetwork()
weight_path = "/root/HSI_HypSpecNet11k/hsi-compression/results/weights/sscnet_2point5bpppc.pth.tar"
model = load_model_weights(model, weight_path)

In [5]:
%pdb on

Automatic pdb calling has been turned ON


In [None]:

base_directory = '/root/HSI_HypSpecNet11k/hsi-compression/datasets/hyspecnet-11k/patches/'
csv_file_path = '/root/HSI_HypSpecNet11k/hsi-compression/datasets/hyspecnet-11k/splits/easy/test.csv'

def load_paths(csv_file):
    df = pd.read_csv(csv_file, header=None)
    file_paths = [os.path.join(base_directory, x.strip()) for x in df[0]]
    print("Paths loaded successfully.")
    return file_paths

class NPYDataset(Dataset):
    def __init__(self, file_paths, transform=None):
        self.file_paths = file_paths
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        file_path = self.file_paths[idx]
        sample = np.load(file_path)
        if self.transform:
            sample = self.transform(sample)
        sample = torch.from_numpy(sample).float()
        return sample, sample

def transform_sample(sample):
    return (sample - np.mean(sample)) / np.std(sample)

def create_dataloaders(csv_file_path, batch_size=4, transform=None):
    file_paths = load_paths(csv_file_path)
    dataset = NPYDataset(file_paths, transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return DataLoaders(dataloader, dataloader)

dls = create_dataloaders(
    csv_file_path=csv_file_path,
    batch_size=4,
    transform=transform_sample
)

class GradualPruningSchedule:
    def __init__(self, start_pct=0.0, end_pct=1.0, n_steps=100):
        self.start_pct = start_pct
        self.end_pct = end_pct
        self.n_steps = n_steps

    def __call__(self, target_sparsity, pct_train):
        if pct_train < self.start_pct:
            return 0.0
        elif pct_train > self.end_pct:
            return target_sparsity
        else:
            sparsity = (pct_train - self.start_pct) / (self.end_pct - self.start_pct)
            return sparsity * target_sparsity

sparsity = 50
granularity = "layer"
context = "global"
criteria = "l1"
schedule = GradualPruningSchedule(start_pct=0.0, end_pct=0.5, n_steps=100)

def dummy_loss(output, target):
    return F.mse_loss(output, target)

import sys
sys.path.append("/root/HSI_HypSpecNet11k/hsi-compression/models/")
from sscnet import SpectralSignalsCompressorNetwork

model = SpectralSignalsCompressorNetwork()
learn = Learner(dls=dls, model=model, loss_func=dummy_loss)

sp_cb = SparsifyCallback(
    sparsity=sparsity,
    granularity=granularity,
    context=context,
    criteria=criteria,
    schedule=schedule
)

learn.unfreeze()
learn.fit_one_cycle(1, cbs=sp_cb)
m_sp = copy.deepcopy(learn.model)

print("Sparsification Results:")
for name, param in m_sp.named_parameters():
    if param.requires_grad:
        sparsity_level = (param.data == 0).float().mean().item() * 100
        print(f"Sparsity for {name}: {sparsity_level:.2f}%")


Paths loaded successfully.
Pruning of layer until a sparsity of [50]%
Saving Weights at epoch 0


epoch,train_loss,valid_loss,time


TypeError: Exception occured in `SparsifyCallback` when calling event `before_batch`:
	can't multiply sequence by non-int of type 'float'

> [0;32m/tmp/ipykernel_222247/3958308271.py[0m(54)[0;36m__call__[0;34m()[0m
[0;32m     52 [0;31m        [0;32melse[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     53 [0;31m            [0msparsity[0m [0;34m=[0m [0;34m([0m[0mpct_train[0m [0;34m-[0m [0mself[0m[0;34m.[0m[0mstart_pct[0m[0;34m)[0m [0;34m/[0m [0;34m([0m[0mself[0m[0;34m.[0m[0mend_pct[0m [0;34m-[0m [0mself[0m[0;34m.[0m[0mstart_pct[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 54 [0;31m            [0;32mreturn[0m [0msparsity[0m [0;34m*[0m [0mtarget_sparsity[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     55 [0;31m[0;34m[0m[0m
[0m[0;32m     56 [0;31m[0msparsity[0m [0;34m=[0m [0;36m50[0m[0;34m[0m[0;34m[0m[0m
[0m


In [25]:
import fasterai.prune
print(dir(fasterai.prune))


['__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', 'all', 'prune_callback', 'pruner']


In [18]:
# 3. Apply Pruning to the Model
def apply_pruning(model, sparsity, context, criteria, schedule):
    """
    Apply pruning to a model.

    Parameters:
        model: PyTorch model object
        sparsity: Desired sparsity level
        context: Context (e.g., filter)
        criteria: Pruning criteria (e.g., filter L1 norm)
        schedule: Schedule for pruning (e.g., fixed)

    Returns:
        pruned_model: A copy of the pruned model
    """
    learn = Learner(dls=None, model=model)  # No dataloaders needed for just pruning
    cbs = PruneCallback(sparsity=sparsity, context=context, criteria=criteria, schedule=schedule)
    learn.fit(1, cbs=cbs)  # Dummy training loop
    pruned_model = copy.deepcopy(learn.model)
    return pruned_model

In [12]:
# 4. Sparsify and Then Prune
def sparsify_then_prune(model, sparsity_sparse, sparsity_prune, context_sparse, context_prune, criteria_sparse, criteria_prune, schedule_sparse, schedule_prune):
    """
    Apply sparsification followed by pruning to a model.

    Parameters:
        model: PyTorch model object
        sparsity_sparse: Sparsity level for sparsification
        sparsity_prune: Sparsity level for pruning
        context_sparse: Context for sparsification
        context_prune: Context for pruning
        criteria_sparse: Sparsification criteria
        criteria_prune: Pruning criteria
        schedule_sparse: Schedule for sparsification
        schedule_prune: Schedule for pruning

    Returns:
        final_model: A model that has been sparsified and pruned
    """
    sparsified_model = apply_sparsification(
        model, sparsity_sparse, context_sparse, criteria_sparse, schedule_sparse
    )
    final_model = apply_pruning(
        sparsified_model, sparsity_prune, context_prune, criteria_prune, schedule_prune
    )
    return final_model

In [13]:
# 5. Model Integrity Check
def check_model_integrity(model1, model2):
    """
    Check the integrity of two models by comparing their input and output channels.

    Parameters:
        model1: PyTorch model object (original or sparsified)
        model2: PyTorch model object (pruned or sparsified + pruned)

    Returns:
        bool: True if integrity is maintained, False otherwise
    """
    # Example check: matching input-output dimensions for layers
    for layer1, layer2 in zip(model1.modules(), model2.modules()):
        if isinstance(layer1, nn.Conv2d) and isinstance(layer2, nn.Conv2d):
            if layer1.out_channels != layer2.out_channels or layer1.in_channels != layer2.in_channels:
                return False
    return True

In [None]:
#| export
def foo(): pass

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()