In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math

In [None]:
#OLD
def absmax_quantization(x, bit=8):
    """Quantize a tensor to int8 using absmax scaling.

    The tensor is multiplied by ``Qb / max(|x|)`` with ``Qb = 2**(bit - 1)``,
    rounded to the nearest integer and clamped to ``[-Qb, Qb - 1]`` before the
    cast, so the largest element can no longer land on ``Qb`` (128 for 8 bits),
    which is not representable in int8.

    Args:
        x: Tensor to quantize.
        bit: Bit width of the target range. The result is always stored as
            ``torch.int8``, so values above 8 do not fit.

    Returns:
        Tuple ``(quantized, max_val)``: the int8 tensor and the 0-dim tensor
        holding ``max(|x|)``, which is needed to dequantize.
    """
    Qb = 2 ** (bit - 1)

    # find the maximum absolute value in the tensor
    max_val = torch.max(torch.abs(x))

    # Scaling factor that maps max_val onto Qb. For an all-zero tensor the
    # ratio is inf (and 0 * inf would be NaN), so fall back to a factor of 1.
    ratio = Qb / max_val
    scale_factor = torch.where(max_val > 0, ratio, torch.ones_like(ratio))

    # Round to the nearest integer, then clamp into the signed range so the
    # cast below never sees an out-of-range value.
    quantized = torch.clamp(torch.round(x * scale_factor), -Qb, Qb - 1)

    return quantized.to(torch.int8), max_val

def absmax_dequantization(x, max_val, bit=8):
    """Rescale absmax-quantized values back to the original range.

    Args:
        x: Quantized values (or a result computed from them).
        max_val: The absmax value that was used when quantizing.
        bit: Bit width that was used when quantizing.

    Returns:
        ``x * max_val / 2**(bit - 1)`` as a ``torch.float32`` tensor.
    """
    levels = 2 ** (bit - 1)

    # Inverse of the quantization scale (levels / max_val).
    step = max_val / levels
    restored = x * step

    # float32 is the precision the values had before quantization
    return restored.to(torch.float32)

class BitLinear(nn.Module):
    """Linear layer with sign-binarized weights and absmax-quantized inputs.

    The forward pass computes ``dequant(quant(x) @ (sign(W - mean(W)) * beta).T)``
    where ``beta`` is the mean absolute value of the latent weights, taken
    per group.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        groups: Number of row groups of the weight matrix that are centered
            and scaled independently. Must divide ``out_features``.
        bit: Accepted for API compatibility; not used by this layer.
        nl_next: Accepted for API compatibility; not used by this layer.
        bias: Accepted for API compatibility; the layer has no bias term.

    Raises:
        ValueError: If ``groups`` is not a positive divisor of ``out_features``.
    """

    def __init__(self, in_features, out_features, groups=1, bit=8, nl_next=False, bias=True):
        super().__init__()
        # Fail at construction time instead of inside forward()'s view().
        if groups < 1 or out_features % groups != 0:
            raise ValueError(
                f"groups ({groups}) must be a positive divisor of out_features ({out_features})"
            )
        self.in_features = in_features
        self.out_features = out_features
        self.groups = groups

        # Latent full-precision weights; filled in by parameter_initialization().
        self.weights = nn.Parameter(torch.empty(self.out_features, self.in_features))
        self.parameter_initialization()

    def parameter_initialization(self):
        """Initialize the latent weights with Kaiming-uniform values."""
        nn.init.kaiming_uniform_(self.weights, a=math.sqrt(5))

    def _quantized_weights(self):
        """Return the binarized, beta-scaled weights, shape (out_features, in_features)."""
        grouped = self.weights.view(self.groups, -1, self.in_features)

        # beta is the 1-norm of the latent weights divided by n*m. It has to be
        # taken before sign(); on the +/-1 matrix it would always be about 1.
        beta = grouped.abs().mean(dim=[1, 2], keepdim=True)

        # normalize to zero mean
        centered = grouped - grouped.mean(dim=[1, 2], keepdim=True)

        # Straight-through estimator: the forward value is exactly sign(centered),
        # while the backward pass sees the identity. sign() alone has zero
        # gradient, which would leave self.weights untrainable.
        binarized = torch.sign(centered).detach() + (centered - centered.detach())

        return (binarized * beta).view(self.out_features, self.in_features)

    def forward(self, x):
        """Apply the quantized linear map to ``x`` and return a float32 tensor."""
        weights = self._quantized_weights()

        # Slicing keeps this working whether absmax_quantization returns
        # (quantized, max_val) or (quantized, max_val, min_val).
        quantized_input, gamma = absmax_quantization(x)[:2]

        output = torch.matmul(quantized_input.float(), weights.t())

        # Undo the input scaling on the matmul result.
        return absmax_dequantization(output, gamma)

In [28]:
def absmax_quantization(x, bit=8, nl_next=False):
    """Quantize a tensor to int8 using absmax scaling.

    With ``nl_next=False`` the tensor is scaled by ``Qb / max(|x|)`` and
    clamped to ``[-Qb, Qb - 1]``, where ``Qb = 2**(bit - 1)``. With
    ``nl_next=True`` the tensor is first shifted by its minimum so every value
    is non-negative, scaled by ``Qb / max(x - min(x))`` and clamped to
    ``[0, Qb - 1]``.

    The clamp keeps the largest element from landing on ``Qb`` (128 for
    8 bits), which is not representable in int8.

    Args:
        x: Tensor to quantize.
        bit: Bit width of the target range. The result is always stored as
            ``torch.int8``, so values above 8 do not fit.
        nl_next: If True, use the shifted, non-negative mode.

    Returns:
        Tuple ``(quantized, max_val, min_val)``. ``max_val`` is the value the
        scale was derived from (``max(|x|)``, or ``max(x - min(x))`` when
        ``nl_next`` is True); ``min_val`` is ``min(x)``.
    """
    Qb = 2 ** (bit - 1)
    min_val = torch.min(x)

    if nl_next:
        # Shift so the smallest element becomes 0; the codes are then >= 0.
        values = x - min_val
        lower = 0
    else:
        values = x
        lower = -Qb

    # find the maximum absolute value of what is about to be scaled
    max_val = torch.max(torch.abs(values))

    # Scaling factor that maps max_val onto Qb. When max_val is 0 (all-zero
    # input, or constant input in shifted mode) the ratio is inf and 0 * inf
    # would be NaN, so fall back to a factor of 1.
    ratio = Qb / max_val
    scale_factor = torch.where(max_val > 0, ratio, torch.ones_like(ratio))

    # Round to the nearest integer, then clamp so the cast never overflows.
    quantized = torch.clamp(torch.round(values * scale_factor), lower, Qb - 1)

    return quantized.to(torch.int8), max_val, min_val

def absmax_dequantization(x, max_val, nl_next=False, min_val=None, bit=8):
    """Rescale absmax-quantized values back to the original range.

    Args:
        x: Quantized values (or a result computed from them).
        max_val: The scale reference that was returned when quantizing.
        nl_next: Set to True when the values were quantized in the shifted,
            non-negative mode.
        min_val: Offset that was subtracted before quantizing. It is added
            back only when ``nl_next`` is True and ``min_val`` is given;
            otherwise the result is just rescaled, as before.
        bit: Bit width that was used when quantizing.

    Returns:
        The dequantized values as a ``torch.float32`` tensor.
    """
    Qb = 2 ** (bit - 1)

    # Inverse of the quantization scale (Qb / max_val).
    reverse_scale_factor = max_val / Qb

    restored = (x * reverse_scale_factor).to(torch.float32)

    # Undo the shift of the non-negative mode. Without a min_val there is
    # nothing to add, so the call stays a plain rescale.
    if nl_next and min_val is not None:
        restored = restored + min_val

    return restored.to(torch.float32)  # return to float32 which is original precision

class BitLinear(nn.Module):
    """Linear layer with sign-binarized weights and absmax-quantized inputs.

    The forward pass computes ``dequant(quant(x) @ (sign(W - mean(W)) * beta).T)``
    where ``beta`` is the mean absolute value of the latent weights, taken
    per group.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        groups: Number of row groups of the weight matrix that are centered
            and scaled independently. Must divide ``out_features``.
        bit: Accepted for API compatibility; not used by this layer.
        nl_next: Accepted for API compatibility; not used by this layer.
        bias: Accepted for API compatibility; the layer has no bias term.

    Raises:
        ValueError: If ``groups`` is not a positive divisor of ``out_features``.
    """

    def __init__(self, in_features, out_features, groups=1, bit=8, nl_next=False, bias=True):
        super().__init__()
        # Fail at construction time instead of inside forward()'s view().
        if groups < 1 or out_features % groups != 0:
            raise ValueError(
                f"groups ({groups}) must be a positive divisor of out_features ({out_features})"
            )
        self.in_features = in_features
        self.out_features = out_features
        self.groups = groups

        # Latent full-precision weights; filled in by parameter_initialization().
        self.weights = nn.Parameter(torch.empty(self.out_features, self.in_features))
        self.parameter_initialization()

    def parameter_initialization(self):
        """Initialize the latent weights with Kaiming-uniform values."""
        nn.init.kaiming_uniform_(self.weights, a=math.sqrt(5))

    def _quantized_weights(self):
        """Return the binarized, beta-scaled weights, shape (out_features, in_features)."""
        grouped = self.weights.view(self.groups, -1, self.in_features)

        # beta is the 1-norm of the latent weights divided by n*m. It has to be
        # taken before sign(); on the +/-1 matrix it would always be about 1.
        beta = grouped.abs().mean(dim=[1, 2], keepdim=True)

        # normalize to zero mean
        centered = grouped - grouped.mean(dim=[1, 2], keepdim=True)

        # Straight-through estimator: the forward value is exactly sign(centered),
        # while the backward pass sees the identity. sign() alone has zero
        # gradient, which would leave self.weights untrainable.
        binarized = torch.sign(centered).detach() + (centered - centered.detach())

        return (binarized * beta).view(self.out_features, self.in_features)

    def forward(self, x):
        """Apply the quantized linear map to ``x`` and return a float32 tensor."""
        weights = self._quantized_weights()

        # absmax_quantization returns (quantized, max_val, min_val); only the
        # first two are needed here. Unpacking into two names raised
        # "too many values to unpack".
        quantized_input, gamma = absmax_quantization(x)[:2]

        output = torch.matmul(quantized_input.float(), weights.t())

        # Undo the input scaling on the matmul result.
        return absmax_dequantization(output, gamma)

In [32]:
# Round-trip demo: draw 16 values uniformly from [-1000, 1000), quantize them,
# dequantize them again and print every stage for visual comparison.
# NOTE(review): `input` shadows the Python builtin for the rest of the kernel
# session, and no RNG seed is set, so the printed values change on every run.
input = torch.rand(1, 16) * 2000 - 1000
print(f"x: {input}")
# Default mode (nl_next=False): scale by the absmax and round.
output, max_val, min_val = absmax_quantization(input)
print(f"output: {output}")
recon_input = absmax_dequantization(output, max_val)
print(f"recon_input: {recon_input}")
# Shifted mode (nl_next=True): the minimum is subtracted before scaling, so
# max_val here is the maximum of the shifted values.
output, max_val, min_val = absmax_quantization(input, nl_next=True)
print(f"output: {output}")
# Called without nl_next/min_val, so this only rescales; the result is still
# offset by min_val at this point.
recon_input = absmax_dequantization(output, max_val)
print(f"recon input: {recon_input}")
# Add the offset back by hand to return to the original range.
recon_input = recon_input + min_val
print(f"recon input: {recon_input}")

x: tensor([[-394.6103,  970.2345, -608.3282, -295.1694, -837.0201,  673.8638,
           42.3396,  -14.7834, -921.3655, -161.1737, -599.8037,  768.5757,
         -322.7930,  949.8514,  105.5886, -375.3987]])
output: tensor([[ -52,  127,  -80,  -39, -110,   89,    6,   -2, -122,  -21,  -79,  101,
          -43,  125,   14,  -50]], dtype=torch.int8)
recon_input: tensor([[-394.1578,  962.6545, -606.3965, -295.6183, -833.7953,  674.6161,
           45.4797,  -15.1599, -924.7548, -159.1791, -598.8166,  765.5757,
         -325.9381,  947.4946,  106.1194, -378.9979]])
output: tensor([[ 36, 127,  21,  42,   6, 108,  65,  61,   0,  51,  22, 114,  41, 127,
          69,  37]], dtype=torch.int8)
recon input: tensor([[ 532.0125, 1876.8219,  310.3406,  620.6812,   88.6687, 1596.0375,
          960.5781,  901.4656,    0.0000,  753.6844,  325.1187, 1684.7062,
          605.9031, 1876.8219, 1019.6906,  546.7906]])
recon input: tensor([[-389.3530,  955.4564, -611.0249, -300.6843, -832.6967,  674.6720,
