# Quantization

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

In [2]:
x_f = torch.randn((4, 4)); x_f

tensor([[-2.2066,  0.3966,  0.5587,  1.1892],
        [ 0.3932,  1.6257, -0.6623, -0.0903],
        [-0.8323, -0.0078,  1.2867, -0.1683],
        [ 3.2285,  1.6256,  0.6275, -1.0032]])

In [3]:
def quantize(x, n_bits):
    
    n_levels = 2**n_bits
    x_min, x_max = x.min(), x.max()
    q_min, q_max = 0, n_levels-1
    
    scale = (x_max - x_min) / (q_max - q_min)
    print(scale)
    zp = q_min - x_min / scale
    
    x_int = torch.round(x/scale) + zp
    return torch.clamp(x_int, q_min, q_max), scale, zp
    

def dequantize(x, scale, zp):
    return (x - zp)*scale

In [4]:
x_f

tensor([[-2.2066,  0.3966,  0.5587,  1.1892],
        [ 0.3932,  1.6257, -0.6623, -0.0903],
        [-0.8323, -0.0078,  1.2867, -0.1683],
        [ 3.2285,  1.6256,  0.6275, -1.0032]])

In [5]:
x_q, s, zp = quantize(x_f, 8); x_q

tensor(0.0213)


tensor([[  0.0000, 122.5259, 129.5259, 159.5259],
        [121.5259, 179.5259,  72.5259,  99.5259],
        [ 64.5259, 103.5259, 163.5259,  95.5259],
        [254.5259, 179.5259, 132.5259,  56.5259]])

In [6]:
dequantize(x_q, s, zp)

tensor([[-2.2066,  0.4050,  0.5542,  1.1936],
        [ 0.3837,  1.6199, -0.6607, -0.0853],
        [-0.8312,  0.0000,  1.2788, -0.1705],
        [ 3.2184,  1.6199,  0.6181, -1.0018]])

In [7]:
class Quantizer():
    
    def __init__(self, n_bits):
        self.n_bits = n_bits
    
    def quantize_model(self, model):
        pass
    
    def quantize_layer(self, layer):
        W = layer.weights.data
        b = layer.bias.data
        
    def _quantize(self, x, mode='asymmetric'):
        
        n_levels = 2**self.n_bits
        x_min, x_max = x.min(), x.max()
        
        if mode == 'asymmetric':
            q_min, q_max = 0, n_levels-1
            self.scale = (x_max - x_min) / (q_max - q_min)
            self.zp = int(q_min - x_min / self.scale)

        elif mode == 'symmetric':
            q_min, q_max = -n_levels/2, (n_levels/2)-1
            self.scale = (x_max - x_min) / (q_max - q_min)
            self.zp = 0            
            
        else: 
            raise NotImplementedError
    
        x_int = torch.round(x/self.scale) + self.zp
        
        return torch.clamp(x_int, q_min, q_max)
    
    def _dequantize(self, x):
        return (x - self.zp)*self.scale

In [8]:
q = Quantizer(8)

In [9]:
x_f

tensor([[-2.2066,  0.3966,  0.5587,  1.1892],
        [ 0.3932,  1.6257, -0.6623, -0.0903],
        [-0.8323, -0.0078,  1.2867, -0.1683],
        [ 3.2285,  1.6256,  0.6275, -1.0032]])

In [10]:
x_q = q._quantize(x_f, 'symmetric'); x_q

tensor([[-104.,   19.,   26.,   56.],
        [  18.,   76.,  -31.,   -4.],
        [ -39.,    0.,   60.,   -8.],
        [ 127.,   76.,   29.,  -47.]])

In [11]:
q._dequantize(x_q)

tensor([[-2.2167,  0.4050,  0.5542,  1.1936],
        [ 0.3837,  1.6199, -0.6607, -0.0853],
        [-0.8312,  0.0000,  1.2788, -0.1705],
        [ 2.7069,  1.6199,  0.6181, -1.0018]])

In [12]:
class Net(nn.Module):
    def __init__(self, mnist=True):
        super().__init__()
          
        self.conv1 = nn.Conv2d(3, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)
    
    def forward(self, x):

        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [13]:
net = Net()

In [14]:
net.conv1.bias.data

tensor([ 0.0280,  0.0547,  0.0385,  0.1155, -0.0160, -0.0367,  0.0064,  0.0154,
         0.0815, -0.0021,  0.0553, -0.0576, -0.0516, -0.0505, -0.1036, -0.0582,
        -0.0221,  0.0859, -0.0632,  0.0949])

In [15]:
class Quantizer():
    
    def __init__(self, n_bits):
        self.n_bits = n_bits
    
    def quantize_model(self, model):
        for k, m in enumerate(model.modules()):
            if not isinstance(m, Net):
                W_q, b_q = self.quantize_layer(m)
                m.weight.data  = W_q
                m.bias.data = b_q
            
        return model    
    
    def quantize_layer(self, layer, mode='asymmetric'):
        W = layer.weight.data
        b = layer.bias.data

        W_q = self._quantize(W, mode=mode)
        b_q = self._quantize(b, mode=mode)
        
        return W_q, b_q
    
    def _compute_params(self, x_min, x_max, q_min, q_max):
        
        scale = (x_max - x_min) / (q_max - q_min)
        initial_zp = q_min - x_min / scale

        zp = 0
        if initial_zp < q_min:
            zp = q_min
        elif initial_zp > q_max:
            zp = q_max
        else:
            zp = initial_zp

        return scale, int(zp)
        
        
    def _quantize(self, x, mode='asymmetric'):
        
        n_levels = 2**self.n_bits
        x_min, x_max = x.min(), x.max()
        
        if mode == 'asymmetric':
            q_min, q_max = 0, n_levels-1
            self.scale, self.zp = self._compute_params(x_min, x_max, q_min, q_max)

        elif mode == 'symmetric':
            q_min, q_max = -n_levels/2, (n_levels/2)-1
            self.scale, self.zp = self._compute_params(x_min, x_max, q_min, q_max)
            self.zp = 0
        else: 
            raise NotImplementedError
    
        x_q = torch.round(x/self.scale) + self.zp
        
        return torch.clamp(x_q, q_min, q_max)
    
    def _dequantize(self, x):
        return (x - self.zp)*self.scale

In [16]:
q = Quantizer(8)

In [17]:
q._quantize(x_f, 'symmetric')

tensor([[-104.,   19.,   26.,   56.],
        [  18.,   76.,  -31.,   -4.],
        [ -39.,    0.,   60.,   -8.],
        [ 127.,   76.,   29.,  -47.]])

In [18]:
q.quantize_layer(net.conv1, mode='asymmetric')

(tensor([[[[201.,  37., 117.,  45., 194.],
           [ 88., 172., 251.,  73., 170.],
           [  0.,  37., 218.,  36., 214.],
           [206., 158.,  17., 173.,  37.],
           [100.,  71.,  60., 109.,  88.]],
 
          [[102.,  83.,  73., 202.,  69.],
           [181.,  29., 107.,  10.,  18.],
           [ 19.,  61., 168., 105., 139.],
           [ 11., 237., 116., 147., 156.],
           [ 96., 159.,  65., 155., 229.]],
 
          [[ 48.,  27., 152., 154., 130.],
           [194., 136.,  84.,  12.,  57.],
           [152., 203.,   3., 225., 122.],
           [ 80.,  82., 198., 247., 225.],
           [137.,   9., 224.,  81.,  41.]]],
 
 
         [[[ 36., 197., 181., 186., 163.],
           [ 16., 124.,  76., 242.,  85.],
           [ 47., 228., 125., 136., 171.],
           [236.,  38.,  65., 181., 115.],
           [187.,  42., 133.,  10.,   2.]],
 
          [[121., 238., 146., 108.,  98.],
           [186., 188., 237.,  79.,  83.],
           [204.,  66., 130., 219., 127

In [20]:
q_model = q.quantize_model(net)

In [31]:
q_model.fc1.weight.data

tensor([[ 12.,   6.,  89.,  ..., 155., 200.,  22.],
        [160., 135., 148.,  ...,  51., 198., 244.],
        [207.,  65., 132.,  ..., 170., 146., 241.],
        ...,
        [181., 106., 227.,  ..., 127., 217.,   0.],
        [179.,  90., 164.,  ..., 186.,  57.,  26.],
        [109., 210., 151.,  ..., 197., 184., 231.]])

In [195]:
class Net(nn.Module):
    def __init__(self, mnist=True):
        super().__init__()
          
        self.conv1 = nn.Conv2d(3, 20, 5, 1)
        self.bn1 = nn.BatchNorm2d(20)
    
    def forward(self, x):

        x = self.conv1(x)
        x = self.bn1(x)
        return x

In [196]:
net = Net()

In [197]:
def bn_folding(conv, bn):
    w = conv.weight
    b = conv.bias
    
    gamma = bn.weight
    beta = bn.bias
    mean = bn.running_mean
    var = torch.sqrt(bn.running_var + bn.eps)
    
    w_f = w * (gamma/var).view(-1,1,1,1)
    b_f = gamma * ((b - mean)/var) + beta
    return w_f, b_f

In [198]:
x = torch.randn((1, 3, 32, 32))

In [199]:
net.eval()

Net(
  (conv1): Conv2d(3, 20, kernel_size=(5, 5), stride=(1, 1))
  (bn1): BatchNorm2d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

In [201]:
w_f, b_f = bn_folding(net.conv1, net.bn1)

In [202]:
class Net_f(nn.Module):
    def __init__(self, mnist=True):
        super().__init__()
          
        self.conv1 = nn.Conv2d(3, 20, 5, 1)
    def forward(self, x):

        x = self.conv1(x)
        return x

In [203]:
net_f = Net_f()

In [204]:
net_f.eval()

Net_f(
  (conv1): Conv2d(3, 20, kernel_size=(5, 5), stride=(1, 1))
)

In [205]:
net_f.conv1.weight.data = w_f
net_f.conv1.bias.data = b_f

In [219]:
torch.all(torch.lt(torch.abs(torch.add(net(x), -net_f(x))), 1e-6))

tensor(1, dtype=torch.uint8)

In [224]:
def fuse(conv, bn):
    w = conv.weight
    b = conv.bias
    
    gamma = bn.weight
    beta = bn.bias
    mean = bn.running_mean
    var = torch.sqrt(bn.running_var + bn.eps)
    
    w_f = w * (gamma/var).view(-1,1,1,1)
    b_f = gamma * ((b - mean)/var) + beta
    
    fused_conv = nn.Conv2d(conv.in_channels,
                         conv.out_channels,
                         conv.kernel_size,
                         conv.stride,
                         conv.padding,
                         bias=True)
    fused_conv.weight = nn.Parameter(w)
    fused_conv.bias = nn.Parameter(b)
    return fused_conv

In [284]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()  
        self.conv1 = nn.Conv2d(3, 20, 5, 1)
        self.bn1 = nn.BatchNorm2d(20)
        self.fuse = fuse(self.conv1, self.bn1)
        self.relu1 = nn.ReLU(inplace=True)
    
    def forward(self, x, fusion=False):
        if fusion:
            x = self.fuse(x)           
        else: 
            x = self.conv1(x)
            x = self.bn1(x)
            
        x = self.relu1(x)    
        return x

In [285]:
net = Net()

In [286]:
%%timeit
net(x);

341 µs ± 3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [287]:
%%timeit
net(x, True);

238 µs ± 1.79 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
