In [137]:
import torch
import torch.nn as nn
from torch.autograd import Variable, Function
import torch.nn.functional as F

from torchvision import datasets, transforms
import numpy as np

batch_size = 128
n_epochs = 1000
validation_steps = 10
learning_rate = 5e-3
stochastic = True

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [138]:
x = np.random.randn(16)
layer = nn.Linear(16, 32).to(device)
x = torch.FloatTensor(x).to(device)

In [159]:
def gradscale(x, scale):
    yOut = x
    yGrad = x * scale
    y = torch.detach(yOut - yGrad) + yGrad
    return y

def roundpass(x):
    yOut = torch.round(x)
    yGrad = x
    y = torch.detach(yOut - yGrad) + yGrad
    return y

class QuantizedLinear(nn.Module):
    def __init__(self, original_linear, bitwidth=16):
        super(QuantizedLinear, self).__init__()
        self.weight = original_linear.weight
        self.bias = original_linear.bias
        self.Qn = - (2 ** (bitwidth - 1))
        self.Qp = 2 ** (bitwidth - 1) - 1
        
        # Initialize weights
        detached_weights = ql.weight.data.cpu().numpy()
        step = np.array((2 * np.mean(np.abs(detached_weights))) / np.sqrt(self.Qp))
        step = nn.Parameter(torch.from_numpy(step))
        self.register_parameter('step_size', step)
        self.step_size.requires_grad = True
        
        self.grad_scale_factor = np.sqrt(np.prod(self.weight.size()) * self.Qp)
        
    def forward(self, inputs):
        s = gradscale(self.step_size, self.grad_scale_factor)
        quantized_weights = self.weight / s
        quantized_weights = torch.clamp(quantized_weights, self.Qn, self.Qp)
        quantized_weights = roundpass(quantized_weights)
        output = F.linear(inputs, quantized_weights, self.bias)
        return output

In [160]:
ql = QuantizedLinear(layer).to(device)
ql.grad_scale_factor

4095.9374995231556

In [161]:
output = ql.forward(x)
loss = output.sum()
loss.backward()

In [162]:
ql.step_size.requires_grad

True

In [163]:
ql.step_size.grad

tensor(5.1513e+09, device='cuda:0', dtype=torch.float64)

In [164]:
ql.weight.grad

tensor([[-4764.1489,  -519.7213, -1442.0760, -2907.0669,  -155.3403,  1988.7479,
         -1607.5035,  4598.8130, -1256.9056, -1727.8313,  1623.0482,  1882.9502,
           848.3790,  -821.1964,  3214.6448,  -766.5775],
        [-4764.1489,  -519.7213, -1442.0760, -2907.0669,  -155.3403,  1988.7479,
         -1607.5035,  4598.8130, -1256.9056, -1727.8313,  1623.0482,  1882.9502,
           848.3790,  -821.1964,  3214.6448,  -766.5775],
        [-4764.1489,  -519.7213, -1442.0760, -2907.0669,  -155.3403,  1988.7479,
         -1607.5035,  4598.8130, -1256.9056, -1727.8313,  1623.0482,  1882.9502,
           848.3790,  -821.1964,  3214.6448,  -766.5775],
        [-4764.1489,  -519.7213, -1442.0760, -2907.0669,  -155.3403,  1988.7479,
         -1607.5035,  4598.8130, -1256.9056, -1727.8313,  1623.0482,  1882.9502,
           848.3790,  -821.1964,  3214.6448,  -766.5775],
        [-4764.1489,  -519.7213, -1442.0760, -2907.0669,  -155.3403,  1988.7479,
         -1607.5035,  4598.8130, -1256.