In [22]:
import torch 
import torch.nn as nn

# n bits quantization

def float2bin(f, fixed_exp, nbits=8):
    s = torch.sign(f)
    f = f * s
    s = (s * (-1) + 1.) * 0.5
    s= s.unsqueeze(-1)
    f = f/(2**fixed_exp)
    m = integer2bit(f - f % 1,num_bits = nbits-1)
    dtype = f.type()
    out = torch.cat([s, m], dim=-1).type(dtype)
    return out

def integer2bit(integer, num_bits=7):
    dtype = integer.type()
    exponent_bits = -torch.arange(-(num_bits - 1), 1).type(dtype)
    exponent_bits = exponent_bits.repeat(integer.shape + (1,))
    out = integer.unsqueeze(-1) / 2 ** exponent_bits
    return (out - (out % 1)) % 2

def bin2float(b, fixed_exp, nbits=8, device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')):
    dtype = torch.float32
    s = torch.index_select(b, -1, torch.arange(0, 1).to(device))
    m = torch.index_select(b, -1, torch.arange(1,nbits).to(device))
    out = ((-1) ** s).squeeze(-1).type(dtype)
    exponents = -torch.arange(-(nbits - 2.), 1.).to(device)
    exponents = exponents.repeat(b.shape[:-1] + (1,))
    e_decimal = torch.sum(m * 2 ** (exponents), dim=-1)
    out *= e_decimal * 2 ** fixed_exp
    return out

def quantize_weights_nbits(model, nbits):
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
            print(f"The current layer is: {name}: ")
            weight = module.weight.data
            # step size
            weight_max = torch.max(weight)
            fixed_exp = torch.ceil(torch.log2(weight_max/(2**(nbits-1)-1))) + 1

            # quantize to binary, and back to floating pt
            binary = float2bin(weight,fixed_exp,nbits).to(torch.int8)
            quantized_f = bin2float(binary,fixed_exp,nbits)

            # update weight
            module.weight.data = quantized_f
            print(f"finished quantized {name} weights to {nbits} BITs")

    return None


In [33]:
import sys
import torch
sys.path.append('../quantization_utils')
sys.path.append('../dataset')

from All_Dataloader import *
from _Loading_All_Model import * 

batch_size = 64
device = torch.device('cpu')
dataset_path = '../../Torch_condaENV/Working_folder/dataset/'

# Load some pretrained model
weight_path = './pre_trained_normal-nmnist_snn_300e.t7'
model = DVS128_Model(batch_size=batch_size).to(device)
checkpoint = torch.load(weight_path,map_location=device)
model.load_state_dict(checkpoint['net'])

# Dataset (TEST and Subset Loader)
_ , test_loader = choose_dataset(target="DVS128_Gesture",batch_size=batch_size,T_BIN=15,dataset_path=dataset_path)


In [16]:
check_accuracy(loader=test_loader, model=model)

Checking on testing data
Got 9799/9984 with accuracy 98.15


tensor(0.9815)

In [34]:
quantize_weights_nbits(model,4)

The current layer is: fc1: 
finished quantized fc1 weights to 4 BITs
The current layer is: fc2: 
finished quantized fc2 weights to 4 BITs


In [35]:
check_accuracy(loader=test_loader, model=model)

Checking on testing data
Got 905/9984 with accuracy 9.06


tensor(0.0906)

In [32]:
print(checkpoint['acc'])

KeyError: 'acc'