In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F


class CustomNet(nn.Module):
    """Four-layer MLP whose forward pass can use binary (+1/-1) weights.

    Every layer is ``dropout -> linear (no bias) -> ReLU -> BatchNorm1d``; the
    last layer skips the ReLU, so ``forward`` returns batch-normalized outputs
    of size ``out_features``.

    Binary training follows the BinaryConnect recipe: the real-valued weights of
    ``fc1``..``fc4`` are the parameters being learned. With
    ``use_binary_weights=True`` the forward pass uses their {-1, +1}
    binarization, and a straight-through estimator routes the gradient back to
    the real-valued weights (plain ``torch.sign`` has zero gradient almost
    everywhere and would leave them untrained). ``binary_fc1``..``binary_fc4``
    are kept so existing state dicts still load; they hold a snapshot of the
    binarized weights refreshed by ``binarize_weights`` and are not read by
    ``forward``.

    Args:
        in_features: size of each flattened input sample.
        num_units: width of the three hidden layers.
        out_features: number of outputs (e.g. classes).
        drop_prob: dropout probability applied before every linear layer.
        eps, momentum, batch_affine: forwarded to every ``nn.BatchNorm1d``.
    """

    def __init__(self, in_features, num_units, out_features, drop_prob=0.5, eps=1e-5, momentum=0.1, batch_affine=True):
        super(CustomNet, self).__init__()

        self.in_features = in_features

        self.dropout1 = nn.Dropout(p=drop_prob)
        self.dropout2 = nn.Dropout(p=drop_prob)
        self.dropout3 = nn.Dropout(p=drop_prob)
        self.dropout4 = nn.Dropout(p=drop_prob)

        # Real-valued weights: these are what the optimizer actually learns.
        self.fc1 = nn.Linear(in_features, num_units, bias=False)
        self.bn1 = nn.BatchNorm1d(num_units, eps=eps, momentum=momentum, affine=batch_affine)

        self.fc2 = nn.Linear(num_units, num_units, bias=False)
        self.bn2 = nn.BatchNorm1d(num_units, eps=eps, momentum=momentum, affine=batch_affine)

        self.fc3 = nn.Linear(num_units, num_units, bias=False)
        self.bn3 = nn.BatchNorm1d(num_units, eps=eps, momentum=momentum, affine=batch_affine)

        self.fc4 = nn.Linear(num_units, out_features, bias=False)
        self.bn4 = nn.BatchNorm1d(out_features, eps=eps, momentum=momentum, affine=batch_affine)

        # Snapshot of the binarized weights (written by binarize_weights); kept
        # so checkpoints containing these keys remain loadable.
        self.binary_fc1 = nn.Linear(in_features, num_units, bias=False)
        self.binary_fc2 = nn.Linear(num_units, num_units, bias=False)
        self.binary_fc3 = nn.Linear(num_units, num_units, bias=False)
        self.binary_fc4 = nn.Linear(num_units, out_features, bias=False)

    def _real_layers(self):
        """Return the learnable real-valued linear layers in forward order."""
        return (self.fc1, self.fc2, self.fc3, self.fc4)

    def _binary_layers(self):
        """Return the binary-snapshot linear layers, aligned with ``_real_layers``."""
        return (self.binary_fc1, self.binary_fc2, self.binary_fc3, self.binary_fc4)

    @staticmethod
    def _sign(weight):
        """Map ``weight`` to {-1, +1}; unlike ``torch.sign``, zeros map to +1."""
        return torch.where(weight >= 0, torch.ones_like(weight), -torch.ones_like(weight))

    @classmethod
    def _ste_binarize(cls, weight):
        """Binarize ``weight`` for the forward pass with a straight-through gradient.

        ``weight - weight.detach()`` is exactly zero in value but has gradient 1
        w.r.t. ``weight``, so the result equals ``_sign(weight)`` exactly while
        backpropagation treats the binarization as the identity.
        """
        return cls._sign(weight) + (weight - weight.detach())

    def binarize_weights(self):
        """Clip the real-valued weights and refresh the binary weight snapshot.

        Intended to be called after every optimizer step. Clipping ``fc1``..``fc4``
        to [-1, 1] (as in BinaryConnect) keeps the real-valued weights near the
        binarization threshold so their signs can still flip; the signs of the
        clipped weights are then copied into ``binary_fc1``..``binary_fc4``.
        """
        with torch.no_grad():
            for real, binary in zip(self._real_layers(), self._binary_layers()):
                real.weight.clamp_(-1.0, 1.0)
                binary.weight.copy_(self._sign(real.weight))

    def forward(self, x, use_binary_weights=False):
        """Run the network.

        Args:
            x: input batch; the training loop in this notebook passes flattened
                images of shape (batch, in_features).
            use_binary_weights: if True, every linear layer uses the {-1, +1}
                binarization of the corresponding ``fc*`` weight (gradients still
                reach the real-valued weight); otherwise the real-valued weights
                are used directly.

        Returns:
            Batch-normalized outputs with ``out_features`` columns.
        """
        dropouts = (self.dropout1, self.dropout2, self.dropout3, self.dropout4)
        norms = (self.bn1, self.bn2, self.bn3, self.bn4)
        last = len(dropouts) - 1

        for index, (dropout, fc, bn) in enumerate(zip(dropouts, self._real_layers(), norms)):
            x = dropout(x)
            weight = self._ste_binarize(fc.weight) if use_binary_weights else fc.weight
            x = F.linear(x, weight)
            if index < last:
                # Hidden layers apply ReLU before batch norm; the output layer does not.
                x = torch.relu(x)
            x = bn(x)

        return x


transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

# Initialize the neural network
in_features = 28 * 28  # MNIST image size
num_units = 512  # Number of units in hidden layers
out_features = 10  # Number of classes in MNIST
drop_prob = 0.5  # Dropout probability

custom_net = CustomNet(in_features, num_units, out_features, drop_prob=drop_prob)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(custom_net.parameters(), lr=0.001)

# Train loop
num_epoch = 50
# Be explicit: dropout and batch norm must run in training mode here.
custom_net.train()
for epoch in range(num_epoch):
    running_loss = 0.0
    num_batches = 0
    for inputs, labels in train_loader:
        # Flatten each image to a vector of in_features values.
        inputs = inputs.view(inputs.size(0), -1)
        optimizer.zero_grad()

        # Forward pass with binary weights during training
        outputs_binary = custom_net(inputs, use_binary_weights=True)
        loss_binary = criterion(outputs_binary, labels)

        # The optimizer updates whichever parameters receive gradients; the
        # real-valued fc* weights are trained only if CustomNet.forward routes
        # gradients to them through the binarization.
        loss_binary.backward()
        optimizer.step()

        # Post-update hook on the network's weights (binarization / clipping).
        custom_net.binarize_weights()

        running_loss += loss_binary.item()
        num_batches += 1

    # Report the mean over all batches, not just the last (noisy) mini-batch;
    # max(..., 1) avoids dividing by zero on an empty loader.
    mean_loss = running_loss / max(num_batches, 1)
    print(f'Epoch {epoch + 1}/{num_epoch}, Loss Binary (epoch mean): {mean_loss:.4f}')

print('Training finished.')


Epoch 1/50, Loss Binary: 2.2529666423797607
Epoch 2/50, Loss Binary: 2.307882070541382
Epoch 3/50, Loss Binary: 2.246013641357422
Epoch 4/50, Loss Binary: 2.2048065662384033
Epoch 5/50, Loss Binary: 2.2872867584228516
Epoch 6/50, Loss Binary: 2.182563543319702
Epoch 7/50, Loss Binary: 1.9375221729278564
Epoch 8/50, Loss Binary: 1.9222781658172607
Epoch 9/50, Loss Binary: 2.002084493637085
Epoch 10/50, Loss Binary: 2.001797914505005
Epoch 11/50, Loss Binary: 2.078160524368286
Epoch 12/50, Loss Binary: 2.222705841064453
Epoch 13/50, Loss Binary: 1.9903569221496582
Epoch 14/50, Loss Binary: 2.1237151622772217
Epoch 15/50, Loss Binary: 1.8991446495056152
Epoch 16/50, Loss Binary: 2.1372315883636475
Epoch 17/50, Loss Binary: 1.8820286989212036
Epoch 18/50, Loss Binary: 2.05547833442688
Epoch 19/50, Loss Binary: 1.8926335573196411
Epoch 20/50, Loss Binary: 1.9684401750564575
Epoch 21/50, Loss Binary: 1.9883837699890137
Epoch 22/50, Loss Binary: 1.7623765468597412
Epoch 23/50, Loss Binary: 1.

In [None]:
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader

# MNIST test split, preprocessed with the same ToTensor + Normalize((0.5,), (0.5,))
# pipeline as the training cell.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

test_dataset = datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
# shuffle=False keeps the evaluation order fixed from run to run.
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

# Evaluate the trained model using binarized weights during the forward pass
def test_model(model, test_loader, use_binary_weights=True):
    model.eval()
    correct_top1 = 0
    correct_top5 = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.view(inputs.size(0), -1)

            # Forward pass using binarized weights during testing
            outputs = model(inputs, use_binary_weights=use_binary_weights)

            # Calculate top-1 accuracy
            _, predicted_top1 = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct_top1 += (predicted_top1 == labels).sum().item()

            # Calculate top-5 accuracy
            _, predicted_top5 = torch.topk(outputs.data, 5, dim=1)
            correct_top5 += (predicted_top5 == labels.view(-1, 1)).sum().item()

    accuracy_top1 = correct_top1 / total
    accuracy_top5 = correct_top5 / total
    print(f'Test Accuracy (Top-1) with Binarized Weights: {accuracy_top1 * 100:.2f}%')
    print(f'Test Accuracy (Top-5) with Binarized Weights: {accuracy_top5 * 100:.2f}%')

# Evaluate the trained network on the MNIST test split, running the forward
# pass with binarized weights.
test_model(model=custom_net, test_loader=test_loader, use_binary_weights=True)


Test Accuracy (Top-1) with Binarized Weights: 65.28%
Test Accuracy (Top-5) with Binarized Weights: 96.44%
