In [32]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math

In [35]:
def absmax_quantization(x, bit=8):
    Qb = 2**(bit - 1)
    # find the maximum absolute value in the tensor
    max_val = torch.max(torch.abs(x))
    # using the max values, we can calculate the scaling factor for each value in the tensor to map it to the range appropriate range
    scale_factor = Qb / max_val
    # now we can quantize the tensor, rounding to the nearest integer
    x = torch.round(x * scale_factor)
    # x = (x * scale_factor).round()
    return x.to(torch.int8), max_val

In [36]:
def absmax_dequantization(x, max_val, bit=8):
    Qb = 2**(bit - 1)
    
    reverse_scale_factor = max_val / Qb
    
    x = x * reverse_scale_factor
    
    return x.to(torch.float32) # return to float32 which is original precision

In [37]:
input = torch.randn(1, 25)
print(f"input: {input}")
output, max_val = absmax_quantization(input)
print(f"output: {output}")
print(f"output dtype: {output.dtype}")

dequant_output = absmax_dequantization(output, max_val)
print(f"dequant output: {dequant_output}")
print(f"dequant output dtype: {dequant_output.dtype}")

input: tensor([[-0.8894, -0.5349,  0.6481,  0.3875,  0.0349, -1.7664,  0.5548, -0.4214,
          1.7691,  0.0169,  0.3873,  0.8700,  0.1457,  0.5463, -0.9343, -0.4479,
          1.4549, -0.9986,  2.8610,  0.6735, -0.0676,  0.2125, -0.1026, -0.6561,
          0.0320]])
output: tensor([[ -40,  -24,   29,   17,    2,  -79,   25,  -19,   79,    1,   17,   39,
            7,   24,  -42,  -20,   65,  -45, -128,   30,   -3,   10,   -5,  -29,
            1]], dtype=torch.int8)
output dtype: torch.int8
dequant output: tensor([[-0.8941, -0.5364,  0.6482,  0.3800,  0.0447, -1.7658,  0.5588, -0.4247,
          1.7658,  0.0224,  0.3800,  0.8717,  0.1565,  0.5364, -0.9388, -0.4470,
          1.4528, -1.0058, -2.8610,  0.6705, -0.0671,  0.2235, -0.1118, -0.6482,
          0.0224]])
dequant output dtype: torch.float32


In [30]:
class BitLinear(nn.Module):
    def __init__(self, in_features, out_features, groups=1, bit=8, nl_next=False, bias=True):
        super(BitLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.groups = groups
        print(f"input: {in_features}, output: {out_features}, groups: {groups}")
        
        self.weights = nn.Parameter(torch.Tensor(self.out_features, self.in_features))
        
        print(f"weights: {self.weights.shape}")
        # Upon initialization, the weights will be randomly initialized using the kaiming uniform method
        self.parameter_initialization()
        
    def parameter_initialization(self):
        nn.init.kaiming_uniform_(self.weights, a=math.sqrt(5))
        
    def forward(self, x):
        weights = self.weights.view(self.groups, -1, self.in_features)
        
        # normalize to zero mean
        weights = weights - weights.mean(dim=[1, 2], keepdim=True)
        
        # quantize weights
        weights = torch.sign(weights)
        
        # calculate beta as 1-norm of weights divided by n*m
        beta = (torch.norm(weights, p=1, dim=[1, 2], keepdim=True) / 
                              (weights.shape[1] * weights.shape[2]))
        
        #scale the weights by beta
        weights = weights * beta
        
        #reshape to original shape
        weights = weights.view(self.out_features, self.in_features)
        
        # get quantized inputs
        quantized_input, gamma = absmax_quantization(x)
        
        # forward pass
        print(f"weights: {weights}")
        print(f"quantized input: {quantized_input}")
        output = torch.matmul(quantized_input.float(), weights.t())
        
        print(f"output: {output}")
        output = absmax_dequantization(output, gamma)
        
        return output

In [34]:
input = torch.randn(1, 25)
print(f"input: {input}")
bitlinear = BitLinear(25, 10)
output = bitlinear(input)
print(f"output: {output}")
print(f"output dtype: {output.dtype}")

input: tensor([[-1.0619, -0.2755,  0.9684,  1.3968,  1.9292,  0.2815, -0.9012,  0.8763,
         -0.5202, -0.5438,  0.1914, -1.1180,  1.4784, -1.0089, -0.1801, -3.5963,
         -0.5111,  0.5066,  0.9411,  1.2713,  1.6487, -0.1585, -1.4995,  0.9017,
          1.5501]])
input: 25, output: 10, groups: 1
weights: torch.Size([10, 25])
weights: tensor([[ 1., -1., -1., -1., -1.,  1.,  1., -1., -1., -1., -1., -1., -1., -1.,
         -1., -1.,  1.,  1., -1., -1.,  1.,  1.,  1.,  1.,  1.],
        [ 1.,  1.,  1.,  1., -1.,  1., -1., -1., -1.,  1., -1.,  1.,  1., -1.,
          1.,  1., -1.,  1., -1., -1.,  1.,  1.,  1.,  1., -1.],
        [ 1., -1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1., -1.,  1.,
          1.,  1.,  1., -1.,  1., -1., -1., -1., -1.,  1., -1.],
        [-1.,  1., -1.,  1., -1., -1., -1.,  1.,  1.,  1.,  1.,  1., -1., -1.,
          1.,  1., -1.,  1., -1.,  1., -1., -1., -1., -1., -1.],
        [ 1., -1.,  1., -1.,  1.,  1.,  1., -1.,  1., -1.,  1., -1., -1.,  1.,
   