In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

# 1) MNIST Dataset & Dataloaders
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST(root='.', train=True, download=True, transform=transform)
test_dataset  = datasets.MNIST(root='.', train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader  = torch.utils.data.DataLoader(test_dataset,  batch_size=1000, shuffle=False)


In [2]:
def fake_float_truncate(x: torch.Tensor, e_bits: int, m_bits: int) -> torch.Tensor:
    """
    Approximate 'float' with e_bits exponent bits and m_bits mantissa bits.
    Simplified approach: unbiased exponent in integer range + truncated mantissa.
    """
    eps = 1e-45
    abs_x = x.abs().clamp(min=eps)
    sign = x.sign()
    
    # exponent
    e = torch.floor(torch.log2(abs_x))
    min_e = -(2**(e_bits)) + 1
    max_e =  (2**(e_bits)) - 1
    e_clamped = torch.clamp(e, min_e, max_e)
    
    # fraction in [1,2) if x >= eps
    frac = abs_x / (2.0 ** e_clamped)
    
    # truncate mantissa
    scale = 2.0 ** m_bits
    frac_trunc = torch.floor(frac * scale) / scale
    
    return sign * (2.0 ** e_clamped) * frac_trunc


class FakeFloatFunction(torch.autograd.Function):
    """
    Custom autograd for 'fake-float' exponent+mantissa truncation.
    """
    @staticmethod
    def forward(ctx, x, e_bits, m_bits):
        # save for backward
        ctx.save_for_backward(x, e_bits, m_bits)
        
        # Round e_bits, m_bits to nearest integer for the forward pass
        e_bits_int = int(torch.round(e_bits).clamp(min=0.0).item())
        m_bits_int = int(torch.round(m_bits).clamp(min=1.0).item())
        
        out = fake_float_truncate(x, e_bits_int, m_bits_int)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        x, e_bits, m_bits = ctx.saved_tensors
        
        # 1) Gradient wrt x: straight-through
        grad_x = grad_output.clone()
        
        e_bits_int = int(torch.round(e_bits).clamp(min=0).item())
        m_bits_int = int(torch.round(m_bits).clamp(min=1).item())
        
        # 2) Gradient wrt e_bits: approximate with central difference
        grad_e_bits = None
        if e_bits.requires_grad:
            delta = 1
            e_plus2_int  = int(torch.round(e_bits + 2*delta).clamp(min=0).item())
            e_plus_int   = int(torch.round(e_bits + delta).clamp(min=0).item())
            e_minus_int  = int(torch.round(e_bits - delta).clamp(min=0).item())
            e_minus2_int = int(torch.round(e_bits - 2*delta).clamp(min=0).item())
            
            f_plus2  = fake_float_truncate(x, e_plus2_int,  m_bits_int)
            f_plus   = fake_float_truncate(x, e_plus_int,   m_bits_int)
            f_minus  = fake_float_truncate(x, e_minus_int,  m_bits_int)
            f_minus2 = fake_float_truncate(x, e_minus2_int, m_bits_int)
            
            #diff_e = (f_plus - f_minus) * grad_output
            #grad_e_bits = diff_e.sum() / (2.0 * delta)
            
            diff_e = -f_plus2 + 8*f_plus - 8*f_minus + f_minus2
            grad_e_bits = diff_e.sum() / (12.0 * delta)
        
        # 3) Gradient wrt m_bits: approximate with central difference
        grad_m_bits = None
        if m_bits.requires_grad:
            delta = 1.0
            m_plus2_int  = int(torch.round(m_bits + 2*delta).clamp(min=1).item())
            m_plus_int   = int(torch.round(m_bits + delta).clamp(min=1).item())
            m_minus_int  = int(torch.round(m_bits - delta).clamp(min=1).item())
            m_minus2_int = int(torch.round(m_bits - 2*delta).clamp(min=1).item())
            
            f_plus2  = fake_float_truncate(x, e_bits_int, m_plus2_int)
            f_plus   = fake_float_truncate(x, e_bits_int, m_plus_int)
            f_minus  = fake_float_truncate(x, e_bits_int, m_minus_int)
            f_minus2 = fake_float_truncate(x, e_bits_int, m_minus2_int)
            
            #diff_e = (f_plus - f_minus) * grad_output
            #grad_e_bits = diff_e.sum() / (2.0 * delta)
            
            diff_e = -f_plus2 + 8*f_plus - 8*f_minus + f_minus2
            grad_m_bits = diff_e.sum() / (12.0 * delta)
        
        return grad_x, grad_e_bits, grad_m_bits


In [3]:
class SimpleQuantizedMLP(nn.Module):
    def __init__(self, e_bits=4.0, m_bits=4.0):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 10)

        # Now we make them trainable:
        self.e_bits = nn.Parameter(torch.tensor(e_bits))
        self.m_bits = nn.Parameter(torch.tensor(m_bits))

    def forward(self, x):
        x = x.view(x.size(0), -1)

        w1 = FakeFloatFunction.apply(self.fc1.weight, self.e_bits, self.m_bits)
        b1 = FakeFloatFunction.apply(self.fc1.bias,   self.e_bits, self.m_bits)
        x  = F.relu(F.linear(x, w1, b1))

        w2 = FakeFloatFunction.apply(self.fc2.weight, self.e_bits, self.m_bits)
        b2 = FakeFloatFunction.apply(self.fc2.bias,   self.e_bits, self.m_bits)
        x  = F.linear(x, w2, b2)
        return x
    
    
class SimpleQuantizedCNN(nn.Module):
    def __init__(self, e_bits=4.0, m_bits=4.0):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1,  out_channels=16, kernel_size=3, padding=1)        
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        self.fc = nn.Linear(32 * 28 * 28, 10)

        self.input_e_bits = nn.Parameter(torch.tensor(e_bits), requires_grad=True)
        self.input_m_bits = nn.Parameter(torch.tensor(m_bits), requires_grad=True)

        self.output_e_bits = nn.Parameter(torch.tensor(e_bits), requires_grad=True)
        self.output_m_bits = nn.Parameter(torch.tensor(m_bits), requires_grad=True)
        
        self.w_e_bits = nn.ParameterList([
            nn.Parameter(torch.tensor(e_bits), requires_grad=True),  # layer 0
            nn.Parameter(torch.tensor(e_bits), requires_grad=True),  # layer 1
            nn.Parameter(torch.tensor(e_bits), requires_grad=True),
        ])
        
        self.w_m_bits = nn.ParameterList([
            nn.Parameter(torch.tensor(m_bits), requires_grad=True),  # layer 0
            nn.Parameter(torch.tensor(m_bits), requires_grad=True),  # layer 1
            nn.Parameter(torch.tensor(m_bits), requires_grad=True),
        ])
          
        self.b_e_bits = nn.ParameterList([
            nn.Parameter(torch.tensor(e_bits), requires_grad=True),  # layer 0
            nn.Parameter(torch.tensor(e_bits), requires_grad=True),  # layer 1
            nn.Parameter(torch.tensor(e_bits), requires_grad=True),
        ])
        
        self.b_m_bits = nn.ParameterList([
            nn.Parameter(torch.tensor(m_bits), requires_grad=True),  # layer 0
            nn.Parameter(torch.tensor(m_bits), requires_grad=True),  # layer 1
            nn.Parameter(torch.tensor(m_bits), requires_grad=True),
        ])
              
    def forward(self, x):
        x = FakeFloatFunction.apply(x, self.input_e_bits, self.input_m_bits)
        
        w1 = FakeFloatFunction.apply(self.conv1.weight, self.w_e_bits[0], self.w_m_bits[0])
        b1 = FakeFloatFunction.apply(self.conv1.bias,   self.b_e_bits[0], self.b_m_bits[0]) if self.conv1.bias is not None else None
        x  = F.relu(F.conv2d(x, w1, b1, stride=1, padding=1))

        w2 = FakeFloatFunction.apply(self.conv2.weight, self.w_e_bits[1], self.w_m_bits[1])
        b2 = FakeFloatFunction.apply(self.conv2.bias,   self.b_e_bits[1], self.b_m_bits[1]) if self.conv2.bias is not None else None
        x  = F.relu(F.conv2d(x, w2, b2, stride=1, padding=1))

        x  = x.view(x.size(0), -1)
        
        w_fc = FakeFloatFunction.apply(self.fc.weight, self.w_e_bits[2], self.w_m_bits[2])
        b_fc = FakeFloatFunction.apply(self.fc.bias,   self.b_e_bits[2], self.b_m_bits[2])
        x  = F.linear(x, w_fc, b_fc)

        x = FakeFloatFunction.apply(x, self.output_e_bits, self.output_m_bits)

        return x
    
    def printBitWidths(self):
        print(f"Input e_bits ", self.input_e_bits.item(), " m_bits ", self.input_m_bits.item())
        print(f"Output e_bits ", self.output_e_bits.item(), " m_bits ", self.output_m_bits.item())
        for i, (eb, mb) in enumerate(zip(self.w_e_bits, self.w_m_bits)):
            print(f"Layer {i} weight e_bits (float) = {eb.item()},  m_bits (float) = {mb.item()}")
        for i, (eb, mb) in enumerate(zip(self.b_e_bits, self.b_m_bits)):
            print(f"Layer {i} bias e_bits (float) = {eb.item()},  m_bits (float) = {mb.item()}")


In [4]:
def bitwidth_penalty(model, lambda_bw=1e-3):
    """
    Computes a penalty term for the bitwidth parameters in 'model'.
    'lambda_bw' is the weight/scale for this regularization.
    """
    penalty = 0.0
    
    penalty += 4.0*model.input_e_bits*model.input_e_bits + model.input_m_bits*model.input_m_bits
    penalty += 4.0*model.output_e_bits*model.output_e_bits + model.output_m_bits*model.output_m_bits
    
    # If the model has multiple layers with e_bits and m_bits in a ParameterList:
    for eb, mb in zip(model.w_e_bits, model.w_m_bits):
        # Option A: Penalize the raw float value (the "continuous" version)
        penalty += 4.0*eb*eb + mb*mb
        
        # Option B (alternative): Penalize the rounded integer version
        # penalty += torch.round(eb) + torch.round(mb)
    
    for eb, mb in zip(model.b_e_bits, model.b_m_bits):
        penalty += 4.0*eb*eb + mb*mb
            
    return lambda_bw * penalty

In [5]:
def train(model, device, train_loader, optimizer, epoch, lambda_bw=1e-3):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss_ce = F.cross_entropy(output, target)
        penalty_bw = bitwidth_penalty(model, lambda_bw) 
        loss = loss_ce + penalty_bw
        loss.backward()
        optimizer.step()

        if batch_idx % 200 == 0:
            print(f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} "
                  f"({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}")

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    print(f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)}"
          f" ({accuracy:.2f}%)\n")

In [6]:

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# Create model
# model = SimpleQuantizedMLP(e_bits=4.0, m_bits=4.0).to(device)
model = SimpleQuantizedCNN(e_bits=8.0, m_bits=24.0).to(device)

# Create optimizer (SGD or Adam)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# or: optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Train for some epochs
num_epochs = 100
for epoch in range(1, num_epochs + 1):
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)
    model.printBitWidths()


Test set: Average loss: 0.0680, Accuracy: 9788/10000 (97.88%)

Input e_bits  3.7857253551483154  m_bits  19.797115325927734
Output e_bits  3.783262014389038  m_bits  19.923742294311523
Layer 0 weight e_bits (float) = 3.78328800201416,  m_bits (float) = 19.92359161376953
Layer 1 weight e_bits (float) = 3.7832722663879395,  m_bits (float) = 19.923786163330078
Layer 2 weight e_bits (float) = 3.782982587814331,  m_bits (float) = 19.923580169677734
Layer 0 bias e_bits (float) = 3.783287763595581,  m_bits (float) = 19.9235897064209
Layer 1 bias e_bits (float) = 3.783287525177002,  m_bits (float) = 19.923603057861328
Layer 2 bias e_bits (float) = 3.783287763595581,  m_bits (float) = 19.92359733581543

Test set: Average loss: 0.0466, Accuracy: 9853/10000 (98.53%)

Input e_bits  2.4999098777770996  m_bits  17.499897003173828
Output e_bits  1.7866649627685547  m_bits  16.518877029418945
Layer 0 weight e_bits (float) = 1.7762789726257324,  m_bits (float) = 16.509624481201172
Layer 1 weight e_bit

KeyboardInterrupt: 