In [14]:
import time
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


Using device: cpu


In [15]:
batch_size_train = 128
batch_size_test = 1000

transform = transforms.ToTensor()

train_loader = DataLoader(
    datasets.MNIST('.', train=True, download=True, transform=transform),
    batch_size=batch_size_train,
    shuffle=True,
)

test_loader = DataLoader(
    datasets.MNIST('.', train=False, transform=transform),
    batch_size=batch_size_test,
    shuffle=False,
)

In [None]:
def lmul(a, b, M=7):
    """
    L-MUL approximation based on mantissa/exponent decomposition.
    a, b: tensors
    M: controls mantissa precision
    """
    a, b = a.float(), b.float()


    s1 = torch.sign(a)
    s2 = torch.sign(b)
    s = s1 * s2


    m1, e1 = torch.frexp(torch.abs(a))
    m2, e2 = torch.frexp(torch.abs(b))

    m1 = 2 * m1 - 1
    m2 = 2 * m2 - 1

 
    if M <= 3:
        L = M
    elif M == 4:
        L = 3
    else:
        L = 4

    bias = 1
    exponent = e1 + e2 - bias
    mantissa = 1 + m1 + m2 + 2**(-L)


    carry = (mantissa >= 2).float()
    mantissa = torch.where(carry == 1, mantissa / 2, mantissa)
    exponent = exponent + carry.long()

    out = s * torch.ldexp(mantissa, exponent)
    return out


def lmul_linear(x, W, b=None, M=7):
    """
    x: [B, D_in]
    W: [D_out, D_in]
    """
    B = x.shape[0]
    D_out, D_in = W.shape

    x_expanded = x.unsqueeze(1).expand(B, D_out, D_in)
    W_expanded = W.unsqueeze(0).expand(B, D_out, D_in)

    prod = lmul(x_expanded, W_expanded, M=M)
    out = prod.sum(dim=2)

    if b is not None:
        out = out + b

    return out



def lmul_conv2d(x, W, b=None, stride=1, padding=0, M=7):
    """
    x: [B, C_in, H, W]
    W: [C_out, C_in, kH, kW]
    """
    B, C_in, H, W_in = x.shape
    C_out, C_in_w, kH, kW = W.shape
    assert C_in == C_in_w, "Input channels mismatch"

    patches = F.unfold(x, kernel_size=(kH, kW),
                       padding=padding, stride=stride)
    K = patches.size(1)
    L = patches.size(2)

    patches = patches.transpose(1, 2)
    W_flat = W.view(C_out, -1)


    a = patches.unsqueeze(1) 
    bW = W_flat.unsqueeze(0).unsqueeze(2) 
    prod = lmul(a, bW, M=M)  

    out = prod.sum(dim=3) 

    H_out = int((H + 2*padding - kH) / stride + 1)
    W_out = int((W_in + 2*padding - kW) / stride + 1)
    out = out.view(B, C_out, H_out, W_out)

    if b is not None:
        out = out + b.view(1, -1, 1, 1)

    return out



In [None]:
class CNN(nn.Module):
    def __init__(self, use_lmul=False, M=7):
        super().__init__()
        self.use_lmul = use_lmul
        self.M = M


        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        if not self.use_lmul:

            x = F.relu(self.conv1(x))
            x = F.max_pool2d(x, 2)

            x = F.relu(self.conv2(x))
            x = F.max_pool2d(x, 2)

            x = x.view(x.size(0), -1)

            x = F.relu(self.fc1(x))
            x = self.fc2(x)

        else:

            x = lmul_conv2d(x, self.conv1.weight, self.conv1.bias,
                            stride=1, padding=1, M=self.M)
            x = F.relu(x)
            x = F.max_pool2d(x, 2)

            x = lmul_conv2d(x, self.conv2.weight, self.conv2.bias,
                            stride=1, padding=1, M=self.M)
            x = F.relu(x)
            x = F.max_pool2d(x, 2)

            x = x.view(x.size(0), -1)

            x = F.relu(lmul_linear(x, self.fc1.weight,
                                   self.fc1.bias, M=self.M))
            x = lmul_linear(x, self.fc2.weight, self.fc2.bias, M=self.M)

        return F.log_softmax(x, dim=1)


In [18]:
def train_model(model, optimizer, loader, epochs=1):
    model.train()
    for ep in range(epochs):
        total = 0
        correct = 0
        for x, y in loader:
            x, y = x.to(device), y.to(device)

            optimizer.zero_grad()
            out = model(x)
            loss = F.nll_loss(out, y)
            loss.backward()
            optimizer.step()

            pred = out.argmax(dim=1)
            total += y.size(0)
            correct += (pred == y).sum().item()

        print(f"Epoch {ep+1}/{epochs} – loss: {loss:.4f}, acc: {correct/total:.4f}")


def test_acc(model, loader):
    model.eval()
    total = 0
    correct = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            pred = out.argmax(dim=1)
            total += y.size(0)
            correct += (pred == y).sum().item()
    return 100.0 * correct / total




In [None]:
baseline_model = CNN(use_lmul=False, M=7).to(device)
optimizer = torch.optim.Adam(baseline_model.parameters(), lr=1e-3)

start = time.perf_counter()
train_model(baseline_model, optimizer, train_loader, epochs=5)
baseline_time = time.perf_counter() - start

baseline_acc = test_acc(baseline_model, test_loader)

print("\nBaseline CNN accuracy (MNIST): {:.2f}%".format(baseline_acc))
print("Baseline total time (train+test): {:.2f} s".format(baseline_time))


lmul_model = CNN(use_lmul=True, M=7).to(device)
lmul_model.load_state_dict(baseline_model.state_dict())

start = time.perf_counter()
lmul_acc = test_acc(lmul_model, test_loader)
lmul_time = time.perf_counter() - start

print("\nLMUL CNN accuracy (MNIST): {:.2f}%".format(lmul_acc))
print("LMUL evaluation time: {:.4f} s".format(lmul_time))


Epoch 1/5 – loss: 0.0910, acc: 0.9283
Epoch 2/5 – loss: 0.0621, acc: 0.9811
Epoch 3/5 – loss: 0.0205, acc: 0.9868
Epoch 4/5 – loss: 0.0271, acc: 0.9893
Epoch 5/5 – loss: 0.0493, acc: 0.9913

Baseline CNN accuracy (MNIST): 98.94%
Baseline total time (train+test): 78.43 s
