In [1]:
import os
from pathlib import Path
import sys

# Set the project root relative to this notebook.
# The notebook is assumed to start inside 'notebooks/', so the root is normally
# the parent directory. Once this cell has run, the working directory already
# is the root (it contains 'src/'), so re-running the cell must not climb
# another level.
_cwd = Path.cwd()
project_root = _cwd if (_cwd / "src").is_dir() else _cwd.parent
os.chdir(project_root)            # ensures all relative paths work from root

# allows `from mils_pruning import ...`; skip the append if an earlier run of
# this cell already registered the directory
_src_dir = str(project_root / "src")
if _src_dir not in sys.path:
    sys.path.append(_src_dir)


In [2]:
import torch
import torch.nn as nn


class Binarize01(torch.autograd.Function):
    """
    Custom binarization function: 0 for x <= 0, 1 for x > 0.
    Uses Straight-Through Estimator (STE) for gradients.

    A floating-point input keeps its dtype in the output, so the op can be
    combined with half/double tensors without a dtype mismatch. Any
    non-floating input produces a float32 result.
    """
    @staticmethod
    def forward(ctx, input):
        """Return a tensor of 0s and 1s with the same shape as `input`."""
        out_dtype = input.dtype if input.is_floating_point() else torch.float32
        return (input > 0).to(out_dtype)

    @staticmethod
    def backward(ctx, grad_output):
        """Straight-through estimator: hand the incoming gradient back unchanged."""
        return grad_output


class BinaryLinear01(nn.Module):
    """
    Fully connected layer whose weights are binarized to 0/1 on every forward
    pass; a bias is added to the product and the sum is batch-normalized.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
    """
    def __init__(self, in_features, out_features):
        super().__init__()
        # Real-valued latent weights drawn from N(0, 0.1^2); forward() only
        # uses whether each entry is > 0.
        latent = torch.empty(out_features, in_features)
        nn.init.normal_(latent, mean=0.0, std=0.1)
        self.weight = nn.Parameter(latent)
        self.bias = nn.Parameter(torch.zeros(out_features))
        self.bn = nn.BatchNorm1d(out_features)

    def forward(self, x):
        binary_weight = Binarize01.apply(self.weight)
        projected = nn.functional.linear(x, binary_weight)
        return self.bn(projected + self.bias)


class BinarizedMLP01(nn.Module):
    """
    MLP with 0/1 binarized weights and activations.

    Both hidden layers are followed by 0/1 binarization; the output layer's
    result is returned as-is (no binarization).

    Args:
        input_shape: per-sample input shape, e.g. (10, 10). All dimensions are
            multiplied together to obtain the flattened feature count, so
            shapes of any rank are accepted.
        nodes_h1: width of the first hidden layer.
        nodes_h2: width of the second hidden layer.
        num_classes: number of output units (defaults to 10).
    """
    def __init__(self, input_shape, nodes_h1, nodes_h2, num_classes=10):
        super().__init__()
        # Flattened feature count = product of every per-sample dimension.
        in_dim = 1
        for dim in input_shape:
            in_dim *= dim
        self.fc1 = BinaryLinear01(in_dim, nodes_h1)
        self.fc2 = BinaryLinear01(nodes_h1, nodes_h2)
        self.fc3 = BinaryLinear01(nodes_h2, num_classes)

    def forward(self, x):
        # Flatten everything except the batch dimension. reshape (unlike view)
        # also accepts non-contiguous batches, copying only when necessary.
        x = x.reshape(x.size(0), -1)

        x = self.fc1(x)
        x = Binarize01.apply(x)

        x = self.fc2(x)
        x = Binarize01.apply(x)

        x = self.fc3(x)
        return x


In [5]:
from mils_pruning import get_mnist_data_loaders, train, EarlyStopping
import torch
import torch.nn as nn
import torch.optim as optim

# Train on the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# One run is requested, and only the first training loader is used below.
train_loaders, val_loader, test_loader = get_mnist_data_loaders(num_runs=1)
train_loader = train_loaders[0]

# NOTE(review): no RNG seed is set in this cell — confirm whether the loaders
# or `train` seed internally; otherwise runs are not reproducible.
model = BinarizedMLP01(input_shape=(10, 10), nodes_h1=64, nodes_h2=32)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
early_stopping = EarlyStopping(patience=5, min_delta=0)

train(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    criterion=criterion,
    early_stopping=early_stopping,
    epochs=50,
    device=device,
    experiment_id="bin01_test",
)

Epoch [1/50], Train Loss: 2.1256, Train Acc: 32.78%, Val Loss: 1.8353, Val Acc: 52.24%
Epoch [2/50], Train Loss: 1.6034, Train Acc: 57.46%, Val Loss: 1.4503, Val Acc: 59.62%
Epoch [3/50], Train Loss: 1.3581, Train Acc: 65.19%, Val Loss: 1.2558, Val Acc: 64.85%
Epoch [4/50], Train Loss: 1.2108, Train Acc: 68.37%, Val Loss: 1.1684, Val Acc: 68.22%
Epoch [5/50], Train Loss: 1.1188, Train Acc: 70.17%, Val Loss: 1.1109, Val Acc: 69.00%
Epoch [6/50], Train Loss: 1.0679, Train Acc: 70.95%, Val Loss: 1.0634, Val Acc: 70.30%
Epoch [7/50], Train Loss: 1.0247, Train Acc: 71.63%, Val Loss: 1.0018, Val Acc: 71.41%
Epoch [8/50], Train Loss: 0.9874, Train Acc: 72.26%, Val Loss: 1.0277, Val Acc: 68.82%
Epoch [9/50], Train Loss: 0.9609, Train Acc: 72.61%, Val Loss: 1.0139, Val Acc: 69.98%
Epoch [10/50], Train Loss: 0.9440, Train Acc: 72.89%, Val Loss: 0.9436, Val Acc: 72.44%
Epoch [11/50], Train Loss: 0.9304, Train Acc: 73.02%, Val Loss: 0.9588, Val Acc: 71.11%
Epoch [12/50], Train Loss: 0.9133, Train 