In [108]:
import torch

# Deserialize the saved network checkpoint (.pt) onto the CPU.
# NOTE(review): torch.load unpickles arbitrary Python objects; only open
# checkpoints from a trusted source. The absolute path below is machine-specific.
file_path = "/shares/bulk/vgandham/SpikeFI/out/net/nmnist-lenet/nmnist-lenet_net2.pt"  # Replace with your actual file path
model_data = torch.load(file_path, map_location="cpu")

# Show the Python type of the loaded object, then its repr.
print(type(model_data), model_data, sep="\n")

<class 'demo.nets.nmnist.LeNetNetwork'>
LeNetNetwork(
  (slayer): spikeLayer()
  (SC1): _convLayer(2, 6, kernel_size=(7, 7, 1), stride=(1, 1, 1), bias=False)
  (SC2): _convLayer(6, 16, kernel_size=(5, 5, 1), stride=(1, 1, 1), bias=False)
  (SC3): _convLayer(16, 120, kernel_size=(5, 5, 1), stride=(1, 1, 1), bias=False)
  (SP1): _poolLayer(1, 1, kernel_size=(2, 2, 1), stride=(2, 2, 1), bias=False)
  (SP2): _poolLayer(1, 1, kernel_size=(2, 2, 1), stride=(2, 2, 1), bias=False)
  (SF1): _denseLayer(120, 84, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
  (SF2): _denseLayer(84, 10, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
  (SDC): _dropoutLayer(p=0.4, inplace=False)
  (SDF): _dropoutLayer(p=0.2, inplace=False)
)


In [109]:
# Collect the model's named tensors (parameters and buffers).
# The returned tensors reference the model's own storage, so in-place
# changes to the model are visible through this dictionary as well.
state_dict = model_data.state_dict()

# List the name of every entry.
print("Available layers in the model:", state_dict.keys(), sep="\n")

Available layers in the model:
odict_keys(['slayer.srmKernel', 'slayer.refKernel', 'SC1.weight', 'SC2.weight', 'SC3.weight', 'SP1.weight', 'SP2.weight', 'SF1.weight', 'SF2.weight'])


In [110]:
# Walk the learnable parameters and report each one's shape.
for name, param in model_data.named_parameters():
    print(name, param.size())

SC1.weight torch.Size([6, 2, 7, 7, 1])
SC2.weight torch.Size([16, 6, 5, 5, 1])
SC3.weight torch.Size([120, 16, 5, 5, 1])
SP1.weight torch.Size([1, 1, 2, 2, 1])
SP2.weight torch.Size([1, 1, 2, 2, 1])
SF1.weight torch.Size([84, 120, 1, 1, 1])
SF2.weight torch.Size([10, 84, 1, 1, 1])


In [111]:
# Report the length of the two spike-layer kernels stored in the state dict.
for kernel in ("srmKernel", "refKernel"):
    print(f"{kernel} shape:", state_dict[f"slayer.{kernel}"].shape)

srmKernel shape: torch.Size([77])
refKernel shape: torch.Size([11])


In [112]:
# Value range (min / max as Python floats) of every tensor in the state dict.
layer_ranges = {
    layer_name: {"min": weights.min().item(), "max": weights.max().item()}
    for layer_name, weights in state_dict.items()
    if isinstance(weights, torch.Tensor)
}

# One line per entry.
for layer, stats in layer_ranges.items():
    print(f"{layer}: min={stats['min']}, max={stats['max']}")


slayer.srmKernel: min=0.0, max=1.0
slayer.refKernel: min=-20.0, max=-0.0
SC1.weight: min=-13.337733268737793, max=15.416525840759277
SC2.weight: min=-19.57712173461914, max=18.011754989624023
SC3.weight: min=-18.200239181518555, max=17.31533432006836
SP1.weight: min=11.0, max=11.0
SP2.weight: min=11.0, max=11.0
SF1.weight: min=-14.966136932373047, max=11.687015533447266
SF2.weight: min=-19.792848587036133, max=8.437464714050293


In [113]:
# Report the element dtype of every state-dict entry.
for layer, entry in state_dict.items():
    print(f"{layer}: {entry.dtype}")

slayer.srmKernel: torch.float32
slayer.refKernel: torch.float32
SC1.weight: torch.float32
SC2.weight: torch.float32
SC3.weight: torch.float32
SP1.weight: torch.float32
SP2.weight: torch.float32
SF1.weight: torch.float32
SF2.weight: torch.float32


In [114]:
# Report how many elements each state-dict entry holds.
for layer, entry in state_dict.items():
    num_weights = entry.numel()
    print(f"{layer}: {num_weights} weights")

slayer.srmKernel: 77 weights
slayer.refKernel: 11 weights
SC1.weight: 588 weights
SC2.weight: 2400 weights
SC3.weight: 48000 weights
SP1.weight: 4 weights
SP2.weight: 4 weights
SF1.weight: 10080 weights
SF2.weight: 840 weights


In [115]:
# Dump the first pooling layer's weight tensor.
print("SP1.weight weights:", state_dict["SP1.weight"], sep="\n")

SP1.weight weights:
tensor([[[[[11.],
           [11.]],

          [[11.],
           [11.]]]]])


In [116]:
# Dump the refractory kernel stored in the spike layer.
print("refKernel weights:", state_dict["slayer.refKernel"], sep="\n")

refKernel weights:
tensor([ -0.0000, -20.0000, -14.7152,  -8.1201,  -3.9830,  -1.8316,  -0.8086,
         -0.3470,  -0.1459,  -0.0604,  -0.0247])


In [None]:
# Saving the quantized weights

from spikefi.utils import quantization as qua


def symmetric_quantize_spikefi(tensor, dtype=torch.qint8):
    """
    Quantize ``tensor`` to integer levels derived from its own value range.

    The scale, zero point and integer dtype come from
    ``qua.quant_args_from_range`` applied to the tensor's min and max.
    Despite the function name, a zero point is added to the scaled values.
    No cast to an integer dtype is performed: the rounded, clamped levels
    are returned in a floating tensor.

    NOTE(review): for a constant tensor (min == max) the outcome depends on
    how ``quant_args_from_range`` handles a zero-width range -- confirm.
    """
    # Range of the input, passed to the helper as 0-dim tensors.
    lowest, highest = tensor.min(), tensor.max()

    # The helper also returns the dtype whose limits bound the levels.
    scale, zero_point, dtype = qua.quant_args_from_range(lowest, highest, dtype)
    limits = torch.iinfo(dtype)

    # Scale, round to the nearest level, shift by the zero point, then saturate.
    levels = (tensor / scale).round() + zero_point
    return torch.clamp(levels, limits.min, limits.max)


import copy

# Quantize every state-dict entry independently: each tensor gets its own
# scale / zero point inside symmetric_quantize_spikefi.
quantized_state_dict = {}
for layer_name, weights in state_dict.items():
    quantized_state_dict[layer_name] = symmetric_quantize_spikefi(weights)

# Print example quantized weights
print("Quantized srmKernel:", quantized_state_dict["slayer.srmKernel"])

# Load the quantized tensors into a deep copy of the network.
# load_state_dict copies values in place, and `state_dict` shares storage with
# `model_data`; loading into `model_data` directly would overwrite the original
# weights, so re-running this cell (or any later cell that reads `state_dict`)
# would quantize already-quantized values.
quantized_model = copy.deepcopy(model_data)
quantized_model.load_state_dict(quantized_state_dict, strict=False)

# NOTE(review): only the integer levels are saved; the per-tensor scale and
# zero point are not stored with the model -- confirm that is intended.
save_path = "quantized_model.pt"
torch.save(quantized_model, save_path)
print(f"Quantized weights saved at {save_path}")



Quantized srmKernel: tensor([-128.,  -65.,  -14.,   26.,   58.,   82.,  100.,  113.,  121.,  126.,
         127.,  126.,  123.,  118.,  111.,  104.,   96.,   87.,   78.,   69.,
          60.,   50.,   41.,   32.,   23.,   14.,    6.,   -2.,  -10.,  -17.,
         -24.,  -31.,  -38.,  -44.,  -49.,  -55.,  -60.,  -65.,  -69.,  -73.,
         -77.,  -81.,  -84.,  -88.,  -91.,  -93.,  -96.,  -98., -101., -103.,
        -105., -106., -108., -110., -111., -112., -114., -115., -116., -117.,
        -118., -119., -119., -120., -121., -121., -122., -122., -123., -123.,
        -124., -124., -124., -125., -125., -125., -125.])
Quantized weights saved at quantized_model.pt


In [117]:
import torch
from spikefi.utils import quantization as qua

def symmetric_quantize_spikefi(tensor, scale, zero_point, dtype=torch.int8):
    """
    Map ``tensor`` onto the integer levels of ``dtype`` for a given scale and zero point.

    Computes ``clamp(round(tensor / scale) + zero_point, qmin, qmax)``.
    Despite the function name this is an affine mapping: it is only symmetric
    when ``zero_point`` is 0. The result is not cast to an integer dtype; the
    levels are returned in a floating tensor.

    Args:
        tensor: Tensor of real values to quantize.
        scale: Strictly positive step size (number or tensor).
        zero_point: Integer level that represents the real value 0.
        dtype: Integer dtype whose ``torch.iinfo`` limits bound the levels.
            Defaults to ``torch.int8`` (levels in [-128, 127]).

    Returns:
        Tensor of rounded, saturated quantization levels.

    Raises:
        ValueError: If ``scale`` is not strictly positive (zero, negative or NaN).
    """
    # A zero / NaN scale would silently turn every value into inf or NaN levels.
    if not bool(torch.all(torch.as_tensor(scale) > 0)):
        raise ValueError(f"scale must be strictly positive, got {scale}")

    dt_info = torch.iinfo(dtype)
    qmin, qmax = dt_info.min, dt_info.max

    # Round to the nearest level, shift by the zero point, then saturate.
    quantized_tensor = torch.clamp((tensor / scale).round() + zero_point, qmin, qmax)

    return quantized_tensor

import copy

# Step 1: Find Global Min and Max Across All Weights
# Only entries whose name does not contain "slayer" define the range.
global_min = float("inf")
global_max = float("-inf")

for layer_name, weights in state_dict.items():
    if isinstance(weights, torch.Tensor) and "slayer" not in layer_name:  # Exclude neuron parameters
        global_min = min(global_min, weights.min().item())
        global_max = max(global_max, weights.max().item())

# inf > -inf only holds when the loop above matched no tensor at all.
if global_min > global_max:
    raise ValueError("state_dict contains no weight tensors to derive a quantization range from")

# Step 2: Compute a Global Scale and Zero Point
global_scale, global_zero_point, _ = qua.quant_args_from_range(global_min, global_max, torch.qint8)

# Step 3: Apply Global Quantization to Weights and Neuron Kernels
# The two spike-layer kernels deliberately share the weight scale. Any other
# "slayer" entry is passed through unchanged.
# NOTE(review): kernel values outside [global_min, global_max] saturate at the
# INT8 limits (the earlier range printout shows refKernel reaching -20.0,
# below the weight minimum) -- confirm this is acceptable.
neuron_kernels = ("slayer.srmKernel", "slayer.refKernel")
quantized_state_dict = {}

for layer_name, weights in state_dict.items():
    if "slayer" in layer_name and layer_name not in neuron_kernels:
        quantized_state_dict[layer_name] = weights
    elif isinstance(weights, torch.Tensor):
        quantized_state_dict[layer_name] = symmetric_quantize_spikefi(weights, global_scale, global_zero_point)

# Save the quantized model

# Load the quantized tensors into a deep copy of the network.
# load_state_dict copies values in place, and `state_dict` shares storage with
# `model_data`; loading into `model_data` directly would overwrite the original
# weights, so re-running this cell would compute the "global" range from
# already-quantized levels.
quantized_model = copy.deepcopy(model_data)
quantized_model.load_state_dict(quantized_state_dict, strict=False)

# NOTE(review): global_scale / global_zero_point are not stored with the
# model; they are needed to interpret the saved levels -- confirm.
save_path = "quantized_model_global_scale.pt"
torch.save(quantized_model, save_path)
print(f"Quantized weights saved at {save_path}")


Quantized weights saved at quantized_model_global_scale.pt


In [118]:
# Inspect the spike-response kernel after quantization with the global scale.
print("Quantized test:",  quantized_state_dict["slayer.srmKernel"])

Quantized test: tensor([ 5.,  7.,  8.,  9., 10., 11., 11., 11., 12., 12., 12., 12., 12., 11.,
        11., 11., 11., 11., 10., 10., 10., 10.,  9.,  9.,  9.,  9.,  9.,  8.,
         8.,  8.,  8.,  8.,  7.,  7.,  7.,  7.,  7.,  7.,  7.,  6.,  6.,  6.,
         6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.,  5.,  5.,  5.,
         5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,
         5.,  5.,  5.,  5.,  5.,  5.,  5.])


In [119]:
# Show the scale shared by every quantized tensor.
print(global_scale)

tensor(0.1483)


In [120]:
# Show the zero point shared by every quantized tensor.
print(global_zero_point)

tensor(5, dtype=torch.int32)


In [107]:

def quantize_neuron_params(x, scale=None, zero_point=None, dtype=torch.int8):
    """
    Quantize a neuron parameter to integer levels, returned as a floating tensor.

    Computes ``clamp(round(x / scale) + zero_point, qmin, qmax)``.

    Args:
        x: Python number or tensor to quantize.
        scale: Quantization step. When omitted, the notebook-level
            ``global_scale`` is used, so the cell defining it must have run.
        zero_point: Level that represents 0. When omitted, the notebook-level
            ``global_zero_point`` is used.
        dtype: Integer dtype whose ``torch.iinfo`` limits bound the levels.
            Defaults to ``torch.int8``.

    Returns:
        Tensor holding the rounded, saturated level(s).
    """
    # Explicit arguments make the dependency visible; the globals remain the
    # fallback so existing single-argument calls keep working.
    if scale is None:
        scale = global_scale
    if zero_point is None:
        zero_point = global_zero_point

    dt_info = torch.iinfo(dtype)
    qmin, qmax = dt_info.min, dt_info.max

    # Plain Python numbers have no .round() method; coerce them to a tensor so
    # the computation also works when `scale` is a plain float.
    values = x if isinstance(x, torch.Tensor) else torch.as_tensor(x, dtype=torch.float32)

    # Round to the nearest level, shift by the zero point, then saturate.
    quantized = torch.clamp((values / scale).round() + zero_point, qmin, qmax)

    return quantized


# Neuron configuration values to express on the weight quantization grid.
# NOTE(review): every entry, including the time constants and scale factors,
# is shifted by the weight zero point -- confirm that this is meaningful for
# parameters that are not in weight units.
neuron_param_dict = {
    'theta': 10,
    'tauSr': 10.0,
    'tauRef': 1.0,
    'scaleRef': 2,
    'tauRho': 1,
    'scaleRho': 1,
}

# Quantize each value with the shared scale / zero point.
quantized_neuron_dict = {
    name: quantize_neuron_params(value)
    for name, value in neuron_param_dict.items()
}

print(quantized_neuron_dict)


{'theta': tensor(72.), 'tauSr': tensor(72.), 'tauRef': tensor(12.), 'scaleRef': tensor(18.), 'tauRho': tensor(12.), 'scaleRho': tensor(12.)}
