In [1]:
import numpy as np
import torch
from torchvision import datasets, transforms

# 1. LOAD AND PREPROCESS MNIST
transform = transforms.Compose([
    transforms.ToTensor()
])

train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset  = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

train_images = train_dataset.data.numpy().astype(np.float32)
test_images  = test_dataset.data.numpy().astype(np.float32)
train_labels = np.array(train_dataset.targets)
test_labels  = np.array(test_dataset.targets)

# Normalize về [-1, 1]
train_images = (train_images / 255 - 0.5) / 0.5 # shape (60000,28,28)
test_images  = (test_images / 255 - 0.5) / 0.5 # shape (10000,28,28)

# Flatten
x_train = train_images.reshape(len(train_images), -1) # reshape(60000,784)
x_test  = test_images.reshape(len(test_images), -1)

# 2. DEFINE BNN
class SignActivation:
    def forward(self, x):
        self.x = x
        return np.where(self.x >= 0, 1, -1)
    def backward(self, grad_output, mode, layer):
        # Straight-Through Estimator
        grad_input = grad_output * (np.abs(self.x) <= 25)  # Ở ĐÂY, TA CHỌN NGƯỠNG (THRESHOLD LÀ 25, ĐƯỢC XEM NHƯ LÀ "VÙNG CHO PHÉP GRADIENT ĐI QUA")
        saturated = np.sum(np.abs(self.x) > 25)
        total = self.x.size
        if mode == 1:
          print(f"Saturated ratio in sign layer {layer}: {saturated/total:.2%}")
         # print(f"Input: {np.round(self.x.flatten())}")
          print("\n")
        return grad_input

class FcLayer:
    def __init__(self, in_features, out_features):
        self.weight = np.random.randn(out_features, in_features) * 0.1
        self.bias = np.zeros((1, out_features))
    def forward(self, x):
        self.x = x
        self.binary_weight = np.where(self.weight >= 0, 1, -1)
        return np.dot(x ,self.binary_weight.T) + self.bias
    def backward(self, grad_output, lr):
        grad_w = np.dot(grad_output.T, self.x)
        grad_b = np.sum(grad_output, axis=0, keepdims=True)
        grad_input = np.dot(grad_output, self.binary_weight)
        # Update real weights
        self.weight -= lr * grad_w
        self.bias -= lr * grad_b
        self.weight = np.clip(self.weight, -1, 1)
        return grad_input

class BNN:
    def __init__(self):
        self.sign0 = SignActivation()
        self.fc1 = FcLayer(784, 512)
        self.sign1 = SignActivation()
        self.fc2 = FcLayer(512, 10)
        self.sign2 = SignActivation()
    def forward(self, x):
        x = self.sign0.forward(x)
        x = self.fc1.forward(x)
        x = self.sign1.forward(x)
        x = self.fc2.forward(x)
        x = self.sign2.forward(x)
        return x
    def backward(self, grad_output, lr, epoch, mode):
        grad = self.sign2.backward(grad_output, mode, 2)
        grad = self.fc2.backward(grad, lr)
        grad = self.sign1.backward(grad, mode, 1)
        grad = self.fc1.backward(grad, lr)
        grad = self.sign0.backward(grad, 0, 0)
        return grad

# 3. LOSS FUNCTION
def mse_loss(pred, target):
    loss = np.mean((pred - target) ** 2)
    grad = 2 * (pred - target) / target.shape[0]
    return loss, grad

def label_to_binary(y, num_classes=10):
    batch_size = y.shape[0]
    binary = -np.ones((batch_size, num_classes))
    binary[np.arange(batch_size), y] = 1
    return binary

# ======================================================
# 4. TRAINING
# ======================================================
model = BNN()
lr = 0.001
epochs = 25
batch_size = 64
best_acc = 0

for epoch in range(epochs):
    total_loss = 0
    correct = 0

    # ĐẢO CÁC BỨC ẢNH
    idx = np.random.permutation(len(x_train))
    x_train = x_train[idx]
    train_labels = train_labels[idx]

    for i in range(0, len(x_train), batch_size):
        x_batch = x_train[i:i+batch_size]
        y_batch = train_labels[i:i+batch_size]
        y_bin = label_to_binary(y_batch)

        # Forward
        out = model.forward(x_batch)
        loss, grad = mse_loss(out, y_bin)
        total_loss += loss

        print_flag = (i == 0) or (i + batch_size >= len(x_train))  # đầu và cuối thôi

        # Backward
        model.backward(grad, lr, epoch, print_flag)

        # Accuracy
        preds = np.argmax(out, axis=1)
        correct += np.sum(preds == y_batch)

    acc = correct / len(x_train) * 100
    print(f"Epoch {epoch+1:02d}| Loss: {loss:.2f} Train Acc: {acc:.2f}%")

  # EVALUATE
    correct_test = 0
    for i in range(0, len(x_test), batch_size):
        x_batch = x_test[i:i+batch_size]
        y_batch = test_labels[i:i+batch_size]
        out = model.forward(x_batch)
        preds = np.argmax(out, axis=1)
        correct_test += np.sum(preds == y_batch)
    acc_test = correct_test / len(x_test) * 100
    if acc_test > best_acc:
        best_acc = acc_test

    print(f"Test Accuracy: {acc_test:.2f}%\n")

print(f"Best Test Accuracy: {best_acc:.2f}%")


100%|██████████| 9.91M/9.91M [00:00<00:00, 61.5MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.69MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 14.8MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 12.5MB/s]


Saturated ratio in sign layer 2: 21.25%


Saturated ratio in sign layer 1: 38.73%


Saturated ratio in sign layer 2: 76.88%


Saturated ratio in sign layer 1: 48.78%


Epoch 01| Loss: 0.09 Train Acc: 66.45%
Test Accuracy: 76.73%

Saturated ratio in sign layer 2: 79.69%


Saturated ratio in sign layer 1: 48.15%


Saturated ratio in sign layer 2: 79.38%


Saturated ratio in sign layer 1: 48.48%


Epoch 02| Loss: 0.10 Train Acc: 80.49%
Test Accuracy: 79.10%

Saturated ratio in sign layer 2: 80.78%


Saturated ratio in sign layer 1: 48.22%


Saturated ratio in sign layer 2: 83.75%


Saturated ratio in sign layer 1: 47.43%


Epoch 03| Loss: 0.05 Train Acc: 83.41%
Test Accuracy: 83.74%

Saturated ratio in sign layer 2: 79.38%


Saturated ratio in sign layer 1: 47.51%


Saturated ratio in sign layer 2: 79.69%


Saturated ratio in sign layer 1: 46.34%


Epoch 04| Loss: 0.11 Train Acc: 84.74%
Test Accuracy: 81.84%

Saturated ratio in sign layer 2: 79.38%


Saturated ratio in sign layer 1: 48.33

In [2]:
output = model.forward(x_test[5])
label = test_labels[5]
print(f"Predicted label: {output}")
print(f"True label: {label}")

Predicted label: [[-1  1 -1 -1 -1 -1 -1 -1 -1 -1]]
True label: 1


In [3]:
output = model.forward(x_test[9])
label = test_labels[9]
print(f"Predicted label: {output}")
print(f"True label: {label}")

Predicted label: [[-1 -1 -1 -1 -1 -1 -1 -1 -1  1]]
True label: 9


In [4]:
output = model.forward(x_test[12])
label = test_labels[12]
print(f"Predicted label: {output}")
print(f"True label: {label}")

Predicted label: [[-1 -1 -1 -1 -1 -1 -1 -1 -1 -1]]
True label: 9


In [5]:
output = model.forward(x_test[100])
label = test_labels[100]
print(f"Predicted label: {output}")
print(f"True label: {label}")

Predicted label: [[-1 -1 -1 -1 -1 -1  1 -1 -1 -1]]
True label: 6


In [6]:
!pip install numpy torch torchvision



In [7]:
# Run python in the next file
!python BNN_code.py

python3: can't open file '/content/BNN_code.py': [Errno 2] No such file or directory


In [8]:
print(f"Best Test Accuracy: {best_acc:.2f}%")

# ======================================================
# EXPORT 1-BIT
# ======================================================
# Only save the model when it achieves the best accuracy
print("\nPreparing to export weights...")

# 1. PERFORM 1-BIT CONVERSION
print("Binarizing weights (1-bit conversion)...")
fc1_w_bin = np.where(model.fc1.weight >= 0, 1, -1)
fc2_w_bin = np.where(model.fc2.weight >= 0, 1, -1)

# Print out for verification
print(f"Shape of binarized fc1_w: {fc1_w_bin.shape}")
print(f"Some sample values: {fc1_w_bin[0, :5]}")

# 2. SAVE BINARIZED WEIGHTS
np.savez(
    "bnn_weights_1bit.npz",  # Output file name
    fc1_w=fc1_w_bin,
    fc1_b=model.fc1.bias,
    fc2_w=fc2_w_bin,
    fc2_b=model.fc2.bias
)

print("\nSuccessfully exported 1-bit weights to file 'bnn_weights_1bit.npz'!")


Best Test Accuracy: 91.37%

Preparing to export weights...
Binarizing weights (1-bit conversion)...
Shape of binarized fc1_w: (512, 784)
Some sample values: [-1 -1 -1  1 -1]

Successfully exported 1-bit weights to file 'bnn_weights_1bit.npz'!
