In [None]:
import torch

# Load the serialized network from its .pt checkpoint, mapping all tensors to CPU.
# NOTE(review): torch.load unpickles the file's contents, so only open
# checkpoints from a trusted source.
file_path = "/shares/bulk/vgandham/SpikeFI/out/net/nmnist-lenet/nmnist-lenet_net2.pt"  # Replace with your actual file path
model_data = torch.load(file_path, map_location="cpu")

# Show the class of the loaded object first, then its full repr (one per line)
print(type(model_data), model_data, sep="\n")

<class 'demo.nets.nmnist.LeNetNetwork'>


In [3]:
# Pull the state dictionary out of the loaded network
state_dict = model_data.state_dict()

# List every entry name stored in the state dictionary (header line, then keys)
print("Available layers in the model:", state_dict.keys(), sep="\n")

Available layers in the model:
odict_keys(['slayer.srmKernel', 'slayer.refKernel', 'SC1.weight', 'SC2.weight', 'SC3.weight', 'SP1.weight', 'SP2.weight', 'SF1.weight', 'SF2.weight'])


In [4]:
# Dump the SC1 weight tensor (first weight entry in the state dict)
print("SC1 Weights:", state_dict["SC1.weight"], sep="\n")

# Dump the SF2 weight tensor (last entry in the state dict)
print("SF2 Weights:", state_dict["SF2.weight"], sep="\n")

SC1 Weights:
tensor([[[[[ -5.1831],
           [  6.1414],
           [  7.0232],
           [ -1.1213],
           [  8.4697],
           [  6.8815],
           [  6.3043]],

          [[  2.4414],
           [ -3.1192],
           [  3.0804],
           [ -0.3441],
           [ -0.4664],
           [ -4.1038],
           [  2.0989]],

          [[ -6.9814],
           [ -5.1634],
           [ -0.1439],
           [  6.8133],
           [ 10.4784],
           [  6.2772],
           [  2.8187]],

          [[ -1.7098],
           [  5.7819],
           [  6.3244],
           [  5.2780],
           [ -1.4829],
           [ -7.2240],
           [ -2.7214]],

          [[  6.5970],
           [ -5.5290],
           [  0.3880],
           [ -2.8634],
           [  9.5725],
           [  8.3570],
           [-10.4210]],

          [[  3.4205],
           [ -4.9066],
           [  8.4940],
           [ 10.1891],
           [  5.2176],
           [-11.6931],
           [-13.0708]],

         

In [5]:
# Report the dtype of every tensor in the state dict, one "name: dtype" line each
for layer, entry in state_dict.items():
    print(layer, entry.dtype, sep=": ")

slayer.srmKernel: torch.float32
slayer.refKernel: torch.float32
SC1.weight: torch.float32
SC2.weight: torch.float32
SC3.weight: torch.float32
SP1.weight: torch.float32
SP2.weight: torch.float32
SF1.weight: torch.float32
SF2.weight: torch.float32


In [6]:
# Report how many elements each state-dict tensor holds
for layer, entry in state_dict.items():
    num_weights = entry.numel()
    print(layer + ":", num_weights, "weights")

slayer.srmKernel: 77 weights
slayer.refKernel: 11 weights
SC1.weight: 588 weights
SC2.weight: 2400 weights
SC3.weight: 48000 weights
SP1.weight: 4 weights
SP2.weight: 4 weights
SF1.weight: 10080 weights
SF2.weight: 840 weights


In [7]:
# Show the contents of the slayer.srmKernel entry under a header line
print("srmKernel weights:", state_dict["slayer.srmKernel"], sep="\n")

srmKernel weights:
tensor([0.0000, 0.2460, 0.4451, 0.6041, 0.7288, 0.8244, 0.8951, 0.9449, 0.9771,
        0.9947, 1.0000, 0.9953, 0.9825, 0.9631, 0.9384, 0.9098, 0.8781, 0.8442,
        0.8088, 0.7725, 0.7358, 0.6990, 0.6626, 0.6268, 0.5918, 0.5578, 0.5249,
        0.4932, 0.4628, 0.4337, 0.4060, 0.3796, 0.3546, 0.3309, 0.3084, 0.2873,
        0.2674, 0.2487, 0.2311, 0.2146, 0.1991, 0.1847, 0.1712, 0.1586, 0.1468,
        0.1359, 0.1257, 0.1162, 0.1074, 0.0992, 0.0916, 0.0845, 0.0780, 0.0719,
        0.0663, 0.0611, 0.0563, 0.0518, 0.0477, 0.0439, 0.0404, 0.0372, 0.0342,
        0.0314, 0.0289, 0.0266, 0.0244, 0.0224, 0.0206, 0.0189, 0.0174, 0.0159,
        0.0146, 0.0134, 0.0123, 0.0113, 0.0103])


In [8]:
# Show the contents of the slayer.refKernel entry under a header line
print("refKernel weights:", state_dict["slayer.refKernel"], sep="\n")

refKernel weights:
tensor([ -0.0000, -20.0000, -14.7152,  -8.1201,  -3.9830,  -1.8316,  -0.8086,
         -0.3470,  -0.1459,  -0.0604,  -0.0247])


In [None]:
# Saving the quantized weights

from spikefi.utils import quantization as qua 

def quantize_to_int8(tensor):
    """Map a tensor onto int8 codes using a scale and zero point.

    The tensor's own minimum and maximum are passed to
    ``qua.quant_args_from_range`` together with ``torch.int8``. The scale and
    zero point it returns are applied as ``round(tensor / scale) + zero_point``;
    the result is clamped to [-128, 127] and cast to ``torch.int8``.

    Args:
        tensor: Tensor to quantize (every state-dict entry in this notebook
            is float32).

    Returns:
        Tuple ``(int8_tensor, scale, zero_point)``.

    Side effects:
        Prints the dtype reported by ``qua.quant_args_from_range``.
    """
    lowest = tensor.min()
    highest = tensor.max()

    # Derive the quantization arguments from the observed value range
    scale, zero_point, dtype = qua.quant_args_from_range(lowest, highest, torch.int8)

    print("Converted to: ", dtype)

    # Scale, round, shift by the zero point, then saturate to the int8 range
    shifted = torch.round(tensor / scale) + zero_point
    int8_tensor = shifted.clamp(-128, 127).to(torch.int8)

    return int8_tensor, scale, zero_point


quantized_state_dict = {}
# Per-entry (scale, zero_point) pairs. They are kept because the int8 codes
# alone cannot be mapped back to real values.
quantization_params = {}

# Quantize every state-dict entry to int8.
# NOTE(review): this includes slayer.srmKernel and slayer.refKernel, not only
# the layer weights - confirm that quantizing the kernels is intended.
for layer_name, weights in state_dict.items():
    quantized_weights, scale, zero_point = quantize_to_int8(weights)
    quantized_state_dict[layer_name] = quantized_weights
    quantization_params[layer_name] = (scale, zero_point)

# Print example quantized weights; the label names the tensor actually shown
print("Quantized SC1.weight:", quantized_state_dict["SC1.weight"])

# Update the model's weights with the quantized versions.
# NOTE(review): load_state_dict copies values into the model's existing
# tensors, which are float32 here, so the saved model presumably stores
# integer-valued float32 weights rather than torch.int8 tensors - confirm.
# NOTE(review): this mutates model_data in place, and state_dict presumably
# shares storage with it, so re-running this cell would quantize already
# quantized values - restart and reload the model before re-running.
model_data.load_state_dict(quantized_state_dict, strict=False)

save_path = "quantized_model_int8.pt"
torch.save(model_data, save_path)
print(f"Quantized weights saved at {save_path}")
print("✅ Quantized model saved in int8 format!")



Quantized srmKernel: tensor([[[[[ -55],
           [  45],
           [  53],
           [ -19],
           [  66],
           [  52],
           [  47]],

          [[  13],
           [ -37],
           [  18],
           [ -12],
           [ -13],
           [ -45],
           [  10]],

          [[ -71],
           [ -55],
           [ -10],
           [  51],
           [  84],
           [  47],
           [  16]],

          [[ -24],
           [  42],
           [  47],
           [  38],
           [ -22],
           [ -73],
           [ -33]],

          [[  50],
           [ -58],
           [  -6],
           [ -34],
           [  76],
           [  65],
           [-101]],

          [[  21],
           [ -53],
           [  66],
           [  81],
           [  37],
           [-113],
           [-125]],

          [[ -64],
           [ -26],
           [ -41],
           [ -79],
           [ -98],
           [-122],
           [ -20]]],


         [[[   9],
           [ 