In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

# 1) MNIST Dataset & Dataloaders
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset  = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader  = torch.utils.data.DataLoader(test_dataset,  batch_size=1000, shuffle=False)

classes = ('zero', 'one', 'two', 'three', 'four', 'five', 'sis', 'seven', 'eight', 'nine')

input_size = (28, 28)

#train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
#train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)

#test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
#test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64,shuffle=False, num_workers=2)

#classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

#input_size = (32, 32)

In [2]:
def param_to_bit(x : torch.Tensor) -> torch.Tensor:
    return torch.exp(x)

def bit_to_param(x : torch.Tensor) -> torch.Tensor:
    return torch.log(x)

def fake_float_truncate(x: torch.Tensor, e_bits: int, m_bits: int) -> torch.Tensor:
    """
    Approximate 'float' with e_bits exponent bits and m_bits mantissa bits.
    Simplified approach: unbiased exponent in integer range + truncated mantissa.
    """
    eps = 1e-45
    abs_x = x.abs().clamp(min=eps)
    sign = x.sign()
    
    # exponent
    e = torch.floor(torch.log2(abs_x))
    min_e = -(2**(e_bits-1)) + 1
    max_e =  (2**(e_bits-1)) - 1
    e_clamped = torch.clamp(e, min_e, max_e)
    
    # fraction in [1,2) if x >= eps
    frac = abs_x / (2.0 ** e_clamped)
    
    # truncate mantissa
    scale = 2.0 ** m_bits
    frac_trunc = torch.floor(frac * scale) / scale
    
    return sign * (2.0 ** e_clamped) * frac_trunc


class FakeFloatFunction(torch.autograd.Function):
    """
    Custom autograd for 'fake-float' exponent+mantissa truncation.
    """
    @staticmethod
    def forward(ctx, x, e_bits_param, m_bits_param):
        
        # save for backward
        ctx.save_for_backward(x, e_bits_param, m_bits_param)
        
        # Round e_bits, m_bits to nearest integer for the forward pass
        e_bits_int = int(torch.round(param_to_bit(e_bits_param)).clamp(min=1.0).item())
        m_bits_int = int(torch.round(param_to_bit(m_bits_param)).clamp(min=1.0).item())
        
        out = fake_float_truncate(x, e_bits_int, m_bits_int)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        x, e_bits_param, m_bits_param = ctx.saved_tensors
        
        # 1) Gradient wrt x: straight-through
        grad_x = grad_output.clone()
        
        e_bits_int = int(torch.round(param_to_bit(e_bits_param)).clamp(min=1.0).item())
        m_bits_int = int(torch.round(param_to_bit(m_bits_param)).clamp(min=1.0).item())
        
        # 2) Gradient wrt e_bits: approximate with central difference
        grad_e_bits = None
        if e_bits_param.requires_grad:
            delta = 0.1
            e_plus2_int  = int(torch.round(param_to_bit(e_bits_param + 2*delta)).clamp(min=1.0).item())
            e_plus_int   = int(torch.round(param_to_bit(e_bits_param +   delta)).clamp(min=1.0).item())
            e_minus_int  = int(torch.round(param_to_bit(e_bits_param -   delta)).clamp(min=1.0).item())
            e_minus2_int = int(torch.round(param_to_bit(e_bits_param - 2*delta)).clamp(min=1.0).item())
            
            f_plus2  = fake_float_truncate(x, e_plus2_int,  m_bits_int)
            f_plus   = fake_float_truncate(x, e_plus_int,   m_bits_int)
            f_minus  = fake_float_truncate(x, e_minus_int,  m_bits_int)
            f_minus2 = fake_float_truncate(x, e_minus2_int, m_bits_int)
            
            #diff_e = (f_plus - f_minus) * grad_output
            #grad_e_bits = diff_e.sum() / (2.0 * delta)
            
            diff_e = -f_plus2 + 8*f_plus - 8*f_minus + f_minus2
            grad_e_bits = diff_e.sum() / (12.0 * delta)
        
        # 3) Gradient wrt m_bits: approximate with central difference
        grad_m_bits = None
        if m_bits_param.requires_grad:
            delta = 0.1
            m_plus2_int  = int(torch.round(param_to_bit(m_bits_param + 2*delta)).clamp(min=1.0).item())
            m_plus_int   = int(torch.round(param_to_bit(m_bits_param +   delta)).clamp(min=1.0).item())
            m_minus_int  = int(torch.round(param_to_bit(m_bits_param -   delta)).clamp(min=1.0).item())
            m_minus2_int = int(torch.round(param_to_bit(m_bits_param - 2*delta)).clamp(min=1.0).item())
            
            f_plus2  = fake_float_truncate(x, e_bits_int, m_plus2_int)
            f_plus   = fake_float_truncate(x, e_bits_int, m_plus_int)
            f_minus  = fake_float_truncate(x, e_bits_int, m_minus_int)
            f_minus2 = fake_float_truncate(x, e_bits_int, m_minus2_int)
            
            #diff_e = (f_plus - f_minus) * grad_output
            #grad_e_bits = diff_e.sum() / (2.0 * delta)
            
            diff_e = -f_plus2 + 8*f_plus - 8*f_minus + f_minus2
            grad_m_bits = diff_e.sum() / (12.0 * delta)
        
        return grad_x, grad_e_bits, grad_m_bits


In [3]:
class SimpleQuantizedMLP(nn.Module):
    def __init__(self, e_bits=4.0, m_bits=4.0):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, len(classes))

        # Now we make them trainable:
        self.e_bits = nn.Parameter(torch.tensor(e_bits))
        self.m_bits = nn.Parameter(torch.tensor(m_bits))

    def forward(self, x):
        x = x.view(x.size(0), -1)

        w1 = FakeFloatFunction.apply(self.fc1.weight, self.e_bits, self.m_bits)
        b1 = FakeFloatFunction.apply(self.fc1.bias,   self.e_bits, self.m_bits)
        x  = F.relu(F.linear(x, w1, b1))

        w2 = FakeFloatFunction.apply(self.fc2.weight, self.e_bits, self.m_bits)
        b2 = FakeFloatFunction.apply(self.fc2.bias,   self.e_bits, self.m_bits)
        x  = F.linear(x, w2, b2)
        return x
    
    
class SimpleQuantizedCNN(nn.Module):
    def __init__(self, e_bits=8.0, m_bits=23.0):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1,  out_channels=32, kernel_size=3, padding=1)        
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, len(classes))
        
        self.i_e_bits_param = nn.ParameterList([
            nn.Parameter(bit_to_param(torch.tensor(e_bits)), requires_grad=True),  # layer 0
            nn.Parameter(bit_to_param(torch.tensor(e_bits)), requires_grad=True),  # layer 1
            nn.Parameter(bit_to_param(torch.tensor(e_bits)), requires_grad=True),
            nn.Parameter(bit_to_param(torch.tensor(e_bits)), requires_grad=True),
        ])
        
        self.i_m_bits_param = nn.ParameterList([
            nn.Parameter(bit_to_param(torch.tensor(m_bits)), requires_grad=True),  # layer 0
            nn.Parameter(bit_to_param(torch.tensor(m_bits)), requires_grad=True),  # layer 1
            nn.Parameter(bit_to_param(torch.tensor(m_bits)), requires_grad=True),
            nn.Parameter(bit_to_param(torch.tensor(m_bits)), requires_grad=True),
        ])
        
        self.w_e_bits_param = nn.ParameterList([
            nn.Parameter(bit_to_param(torch.tensor(e_bits)), requires_grad=True),  # layer 0
            nn.Parameter(bit_to_param(torch.tensor(e_bits)), requires_grad=True),  # layer 1
            nn.Parameter(bit_to_param(torch.tensor(e_bits)), requires_grad=True),
            nn.Parameter(bit_to_param(torch.tensor(e_bits)), requires_grad=True),
        ])
        
        self.w_m_bits_param = nn.ParameterList([
            nn.Parameter(bit_to_param(torch.tensor(m_bits)), requires_grad=True),  # layer 0
            nn.Parameter(bit_to_param(torch.tensor(m_bits)), requires_grad=True),  # layer 1
            nn.Parameter(bit_to_param(torch.tensor(m_bits)), requires_grad=True),
            nn.Parameter(bit_to_param(torch.tensor(m_bits)), requires_grad=True),
        ])
          
        self.b_e_bits_param = nn.ParameterList([
            nn.Parameter(bit_to_param(torch.tensor(e_bits)), requires_grad=True),  # layer 0
            nn.Parameter(bit_to_param(torch.tensor(e_bits)), requires_grad=True),  # layer 1
            nn.Parameter(bit_to_param(torch.tensor(e_bits)), requires_grad=True),
            nn.Parameter(bit_to_param(torch.tensor(e_bits)), requires_grad=True),
        ])
        
        self.b_m_bits_param = nn.ParameterList([
            nn.Parameter(bit_to_param(torch.tensor(m_bits)), requires_grad=True),  # layer 0
            nn.Parameter(bit_to_param(torch.tensor(m_bits)), requires_grad=True),  # layer 1
            nn.Parameter(bit_to_param(torch.tensor(m_bits)), requires_grad=True),
            nn.Parameter(bit_to_param(torch.tensor(m_bits)), requires_grad=True),
        ])
              
    def forward(self, x):
        
        x = FakeFloatFunction.apply(x, self.i_e_bits_param[0], self.i_m_bits_param[0])
        w1 = FakeFloatFunction.apply(self.conv1.weight, self.w_e_bits_param[0], self.w_m_bits_param[0])
        b1 = FakeFloatFunction.apply(self.conv1.bias,   self.b_e_bits_param[0], self.b_m_bits_param[0]) if self.conv1.bias is not None else None
        x  = F.relu(F.conv2d(x, w1, b1, stride=1, padding=1))
        x = F.max_pool2d(x, 2)  # 32x32 -> 16x16
        
        x = FakeFloatFunction.apply(x, self.i_e_bits_param[1], self.i_m_bits_param[1])
        w2 = FakeFloatFunction.apply(self.conv2.weight, self.w_e_bits_param[1], self.w_m_bits_param[1])
        b2 = FakeFloatFunction.apply(self.conv2.bias,   self.b_e_bits_param[1], self.b_m_bits_param[1]) if self.conv2.bias is not None else None
        x  = F.relu(F.conv2d(x, w2, b2, stride=1, padding=1))
        x = F.max_pool2d(x, 2)  # 16x16 -> 8x8
        
        x  = x.view(x.size(0), -1)
        
        x = FakeFloatFunction.apply(x, self.i_e_bits_param[2], self.i_m_bits_param[2])
        w_fc1 = FakeFloatFunction.apply(self.fc1.weight, self.w_e_bits_param[2], self.w_m_bits_param[2])
        b_fc1 = FakeFloatFunction.apply(self.fc1.bias,   self.b_e_bits_param[2], self.b_m_bits_param[2])
        x  = F.linear(x, w_fc1, b_fc1)

        x = FakeFloatFunction.apply(x, self.i_e_bits_param[3], self.i_m_bits_param[3])
        w_fc2 = FakeFloatFunction.apply(self.fc2.weight, self.w_e_bits_param[3], self.w_m_bits_param[3])
        b_fc2 = FakeFloatFunction.apply(self.fc2.bias,   self.b_e_bits_param[3], self.b_m_bits_param[3])
        x  = F.linear(x, w_fc2, b_fc2)

        return x
    
    def printBitWidths(self):
        for i, (eb, mb) in enumerate(zip(self.i_e_bits_param, self.i_m_bits_param)):
            print(f"Layer {i} input e_bits (float) = {param_to_bit(eb).item()},  m_bits (float) = {param_to_bit(mb).item()}")
        for i, (eb, mb) in enumerate(zip(self.w_e_bits_param, self.w_m_bits_param)):
            print(f"Layer {i} weight e_bits (float) = {param_to_bit(eb).item()},  m_bits (float) = {param_to_bit(mb).item()}")
        for i, (eb, mb) in enumerate(zip(self.b_e_bits_param, self.b_m_bits_param)):
            print(f"Layer {i} bias e_bits (float) = {param_to_bit(eb).item()},  m_bits (float) = {param_to_bit(mb).item()}")


In [4]:
def bitwidth_penalty(model, lambda_bw=1e-3):
    """
    Computes a penalty term for the bitwidth parameters in 'model'.
    'lambda_bw' is the weight/scale for this regularization.
    """
    penalty = 0.0
    
    # If the model has multiple layers with e_bits and m_bits in a ParameterList:
    for eb, mb in zip(model.w_e_bits_param, model.w_m_bits_param):
        # Option A: Penalize the raw float value (the "continuous" version)
        
        eb_ = eb - bit_to_param(torch.tensor(1.0))
        mb_ = mb - bit_to_param(torch.tensor(1.0))
        
        penalty += eb_*eb_ + mb_*mb_
        
        # Option B (alternative): Penalize the rounded integer version
        # penalty += torch.round(eb) + torch.round(mb)
    
    for eb, mb in zip(model.b_e_bits_param, model.b_m_bits_param):
        eb_ = eb - bit_to_param(torch.tensor(1.0))
        mb_ = mb - bit_to_param(torch.tensor(1.0))
        penalty += eb_*eb_ + mb_*mb_

    for eb, mb in zip(model.i_e_bits_param, model.i_m_bits_param):
        eb_ = eb - bit_to_param(torch.tensor(1.0))
        mb_ = mb - bit_to_param(torch.tensor(1.0))
        penalty += eb_*eb_ + mb_*mb_
                    
    return lambda_bw * penalty

In [5]:
def train(model, device, train_loader, optimizer, epoch, lambda_bw=1e-3):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss_ce = F.cross_entropy(output, target)
        penalty_bw = bitwidth_penalty(model, lambda_bw) 
        loss = loss_ce + penalty_bw
        loss.backward()
        optimizer.step()

        if batch_idx % 200 == 0:
            print(f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} "
                  f"({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}")

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    print(f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)}"
          f" ({accuracy:.2f}%)\n")

In [None]:
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# Create model
# model = SimpleQuantizedMLP(e_bits=4.0, m_bits=4.0).to(device)
model = SimpleQuantizedCNN(e_bits=8.0, m_bits=23.0).to(device)

# Create optimizer (SGD or Adam)
#optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Train for some epochs
num_epochs = 100
for epoch in range(1, num_epochs + 1):
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)
    model.printBitWidths()


Test set: Average loss: 0.0516, Accuracy: 9837/10000 (98.37%)

Layer 0 input e_bits (float) = 3.4134433269500732,  m_bits (float) = 18.93144416809082
Layer 1 input e_bits (float) = 5.547675132751465,  m_bits (float) = 4.0709357261657715
Layer 2 input e_bits (float) = 5.518855571746826,  m_bits (float) = 4.070084095001221
Layer 3 input e_bits (float) = 3.8000473976135254,  m_bits (float) = 5.490313529968262
Layer 0 weight e_bits (float) = 3.4714467525482178,  m_bits (float) = 7.524622440338135
Layer 1 weight e_bits (float) = 3.8684537410736084,  m_bits (float) = 17.12986946105957
Layer 2 weight e_bits (float) = 4.358256816864014,  m_bits (float) = 4.055737018585205
Layer 3 weight e_bits (float) = 3.1102652549743652,  m_bits (float) = 8.874690055847168
Layer 0 bias e_bits (float) = 3.2837061882019043,  m_bits (float) = 9.214698791503906
Layer 1 bias e_bits (float) = 3.5657315254211426,  m_bits (float) = 10.116086959838867
Layer 2 bias e_bits (float) = 3.109066963195801,  m_bits (float) 