In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math

In [None]:
#OLD
def absmax_quantization(x, bit=8):
    Qb = 2**(bit - 1)
    
    # find the maximum absolute value in the tensor
    max_val = torch.max(torch.abs(x))
    
    # using the max values, we can calculate the scaling factor for each value in the tensor to map it to the range appropriate range
    scale_factor = Qb / max_val
    
    # now we can quantize the tensor, rounding to the nearest integer
    x = torch.round(x * scale_factor)
    
    return x.to(torch.int8), max_val

def absmax_dequantization(x, max_val, bit=8):
    Qb = 2**(bit - 1)
    
    reverse_scale_factor = max_val / Qb
    
    x = x * reverse_scale_factor
    
    return x.to(torch.float32) # return to float32 which is original precision

class BitLinear(nn.Module):
    def __init__(self, in_features, out_features, groups=1, bit=8, nl_next=False, bias=True):
        super(BitLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.groups = groups
        # print(f"input: {in_features}, output: {out_features}, groups: {groups}")
        
        self.weights = nn.Parameter(torch.Tensor(self.out_features, self.in_features))
        
        # print(f"weights: {self.weights.shape}")
        # Upon initialization, the weights will be randomly initialized using the kaiming uniform method
        self.parameter_initialization()
        
    def parameter_initialization(self):
        nn.init.kaiming_uniform_(self.weights, a=math.sqrt(5))
        
    def forward(self, x):
        weights = self.weights.view(self.groups, -1, self.in_features)
        
        # normalize to zero mean
        weights = weights - weights.mean(dim=[1, 2], keepdim=True)
        
        # quantize weights
        weights = torch.sign(weights)
        
        # calculate beta as 1-norm of weights divided by n*m
        beta = (torch.norm(weights, p=1, dim=[1, 2], keepdim=True) / 
                              (weights.shape[1] * weights.shape[2]))
        
        #scale the weights by beta
        weights = weights * beta
        
        #reshape to original shape
        weights = weights.view(self.out_features, self.in_features)
        
        # get quantized inputs
        quantized_input, gamma = absmax_quantization(x)
        
        # forward pass
        # print(f"weights: {weights}")
        # print(f"quantized input: {quantized_input}")
        output = torch.matmul(quantized_input.float(), weights.t())
        
        # print(f"output: {output}")
        output = absmax_dequantization(output, gamma)
        
        return output

In [16]:
def absmax_quantization(x, bit=8, nl_next=False):
    Qb = 2**(bit - 1)
    
    # find the maximum absolute value in the tensor
    max_val = torch.max(torch.abs(x))
    min_val = torch.min(x)
    
    print(f"data type before quantization: {x.type()}")
    
    if nl_next:
        shifted_x = x - min_val
        max_val = torch.max(torch.abs(shifted_x))
        
        scale_factor = Qb / max_val
        x = torch.round(shifted_x * scale_factor)
    else:
        # using the max values, we can calculate the scaling factor for each value in the tensor to map it to the range appropriate range
        scale_factor = Qb / max_val
        
        # now we can quantize the tensor, rounding to the nearest integer
        x = torch.round(x * scale_factor)
    
    dequant = max_val / Qb
    
    return x.to(torch.int8), dequant, max_val, min_val

def absmax_dequantization(x, max_val, nl_next=False, min_val=None, bit=8):
    Qb = 2**(bit - 1)
    
    reverse_scale_factor = max_val / Qb
    
    x = x * reverse_scale_factor
    
    return x.to(torch.float32) # return to float32 which is original precision

class BitLinear(nn.Module):
    def __init__(self, in_features, out_features, groups=1, bit=8, nl_next=False, bias=True):
        super(BitLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.groups = groups
        self.nl_next = nl_next
        
        if bias is True:
            self.bias = nn.Parameter(torch.randn(self.out_features))
        else:
            self.register_parameter("bias", None)
        
        self.weights = nn.Parameter(torch.randn(self.out_features, self.in_features))
        
        # # print(f"weights: {self.weights.shape}")
        # # Upon initialization, the weights will be randomly initialized using the kaiming uniform method
        # self.parameter_initialization()
        
    # def parameter_initialization(self):
    #     nn.init.kaiming_uniform_(self.weights, a=math.sqrt(5))
        
    def forward(self, x):
        
        input_norm = F.layer_norm(x, (self.in_features,))
        
        input_quant, dequant, gamma, eta = absmax_quantization(input_norm, nl_next=self.nl_next)
        
        print(f"data type after quantization: {input_quant.type()}")
        
        weight_quant = torch.sign(self.weights)
        
        print(f"weight quant: {weight_quant}")
        
        output = torch.matmul(input_quant.float(), weight_quant.t())
        
        if self.bias is not None:
            output = output + self.bias.unsqueeze(0).expand_as(output)
            
        beta = torch.norm(self.weights, p=1) / (self.in_features * self.out_features)
        
        output = output * dequant * beta
        
        return output

In [17]:
# make a test input of values ranging from -1000 to 1000
input = torch.rand(1, 16) * 2000 - 1000
print(f"x: {input}")
output, dequant, max_val, min_val = absmax_quantization(input)
print(f"output: {output}, {output.type()}")
recon_input = dequant * output
print(f"recon_input: {recon_input}, {recon_input.type()}")

x: tensor([[ 953.1671,  -93.2529, -873.9128,  526.3042,  212.9384,  716.6527,
          397.1938, -655.5541, -177.6257,  196.9402,  722.6052,   18.9512,
         -131.3445,  375.2721,  221.9041, -706.2743]])
data type before quantization: torch.FloatTensor
output: tensor([[ 127,  -13, -117,   71,   29,   96,   53,  -88,  -24,   26,   97,    3,
          -18,   50,   30,  -95]], dtype=torch.int8), torch.CharTensor
recon_input: tensor([[ 945.7205,  -96.8060, -871.2543,  528.7099,  215.9519,  714.8754,
          394.6707, -655.3024, -178.7188,  193.6121,  722.3220,   22.3399,
         -134.0391,  372.3309,  223.3985, -707.4287]]), torch.FloatTensor


In [18]:
test_layer = BitLinear(16, 16, bit=8)
test_layer(input)

data type before quantization: torch.FloatTensor
data type after quantization: torch.CharTensor
weight quant: tensor([[ 1.,  1., -1., -1., -1., -1., -1.,  1., -1., -1., -1., -1., -1.,  1.,
          1., -1.],
        [-1.,  1., -1., -1., -1., -1.,  1., -1., -1., -1.,  1., -1.,  1., -1.,
          1., -1.],
        [ 1.,  1.,  1.,  1., -1.,  1., -1.,  1.,  1., -1., -1.,  1.,  1., -1.,
         -1.,  1.],
        [ 1.,  1.,  1.,  1., -1., -1.,  1.,  1.,  1., -1., -1., -1.,  1.,  1.,
          1., -1.],
        [-1.,  1.,  1.,  1., -1., -1.,  1.,  1., -1.,  1., -1.,  1.,  1., -1.,
         -1., -1.],
        [ 1., -1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1., -1.,  1., -1., -1.,
         -1., -1.],
        [ 1.,  1., -1., -1.,  1.,  1., -1., -1.,  1.,  1., -1., -1.,  1., -1.,
         -1., -1.],
        [ 1., -1.,  1., -1., -1., -1.,  1.,  1.,  1., -1., -1., -1.,  1., -1.,
         -1., -1.],
        [-1.,  1., -1., -1.,  1.,  1.,  1., -1.,  1.,  1., -1.,  1., -1.,  1.,
         -1., -1.],

tensor([[ 0.8729,  1.8419, -4.7915, -1.7068, -4.7331,  0.8427,  3.0279, -3.6637,
          2.5646,  6.4944,  5.7991,  1.8301,  0.7808, -5.5449,  5.2603,  2.9647]],
       grad_fn=<MulBackward0>)