In [1]:
import copy
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.utils
from PIL import Image
from sklearn.decomposition import PCA
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms

In [2]:
# Define a simple CNN model for full-precision and binary networks, sign-based binarization approach
class SimpleCNN(nn.Module):
    """Small CNN (two conv layers, two fully connected layers) with an optional binary mode.

    In binary mode the outputs of ``conv1``, ``conv2`` and ``fc1`` are replaced
    by ``sign(x) * mean(|x|)`` before the ReLU; ``fc2`` always stays full precision.

    Args:
        binary: When True, binarize the activations as described above.
        in_channels: Number of channels of the input images. Defaults to 3
            (RGB), matching the CIFAR-10 batches this notebook feeds the model.
        input_size: Spatial size of the input images, either a single int for
            square images or a ``(height, width)`` pair. Defaults to 32.
            ``in_channels=1, input_size=28`` gives a 1-channel network whose
            ``fc1`` has ``64 * 7 * 7`` input features.

    Raises:
        ValueError: If ``input_size`` is too small to survive two 2x2 poolings.
    """

    def __init__(self, binary=False, in_channels=3, input_size=32):
        super(SimpleCNN, self).__init__()
        if isinstance(input_size, int):
            height = width = input_size
        else:
            height, width = input_size
        # The padded 3x3 convolutions keep the spatial size; each of the two
        # 2x2 max-pools halves it (rounding down), hence the division by 4.
        pooled_h, pooled_w = height // 4, width // 4
        if pooled_h < 1 or pooled_w < 1:
            raise ValueError(
                f"input_size={input_size!r} is too small: two 2x2 poolings need inputs of at least 4x4"
            )

        self.binary = binary
        self.flat_features = 64 * pooled_h * pooled_w
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(self.flat_features, 128)
        self.fc2 = nn.Linear(128, 10)
        self.pool = nn.MaxPool2d(2, 2)

    def binarize(self, x):
        """Return ``sign(x) * mean(|x|)`` with a straight-through gradient.

        The forward value is the scaled sign of ``x``. ``torch.sign`` has a
        zero gradient almost everywhere, which would starve every layer that
        feeds a binarized activation, so the backward pass treats this
        operation as the identity instead.
        """
        scaled_sign = torch.sign(x) * torch.mean(torch.abs(x))
        # (x - x.detach()) is zero in the forward pass but carries d/dx = 1.
        return scaled_sign.detach() + (x - x.detach())

    def _maybe_binarize(self, x):
        """Binarize ``x`` only when the model was built with ``binary=True``."""
        return self.binarize(x) if self.binary else x

    def forward(self, x):
        """Map a batch of images ``(N, in_channels, H, W)`` to class logits ``(N, 10)``."""
        x = self.pool(F.relu(self._maybe_binarize(self.conv1(x))))
        x = self.pool(F.relu(self._maybe_binarize(self.conv2(x))))

        # Flatten per sample; unlike view(-1, ...), a spatial-size mismatch
        # surfaces as a shape error in fc1 instead of corrupting the batch axis.
        x = torch.flatten(x, 1)

        x = F.relu(self._maybe_binarize(self.fc1(x)))
        return self.fc2(x)

In [3]:
def get_data_loaders(batch_size=64, data_dir='./data', num_workers=4, download=True):
    """Build the CIFAR-10 training and test data loaders.

    Args:
        batch_size: Number of images per batch for both loaders.
        data_dir: Directory where the CIFAR-10 files are stored (and
            downloaded to when ``download`` is True).
        num_workers: Worker processes used by each loader; pass 0 to load
            in the main process.
        download: Forwarded to ``datasets.CIFAR10``. Fetching the archive
            needs network access; pass False for offline runs where
            ``data_dir`` is already populated.

    Returns:
        A ``(train_loader, test_loader)`` tuple. Only the training loader
        shuffles its samples.
    """
    # Convert images to tensors, then normalize every RGB channel with
    # mean 0.5 and std 0.5.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # Load both CIFAR-10 splits through torchvision's built-in dataset class.
    train_dataset = datasets.CIFAR10(root=data_dir, train=True, transform=transform, download=download)
    test_dataset = datasets.CIFAR10(root=data_dir, train=False, transform=transform, download=download)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, test_loader

In [4]:
# Training Function
def train_model(model, train_loader, criterion, optimizer, epochs=10):
    """Train ``model`` in place and report the summed batch loss per epoch.

    Batches are moved to whatever device the model's parameters live on, so
    the function works on CPU-only machines as well as on CUDA devices.

    Args:
        model: The ``nn.Module`` to optimize.
        train_loader: Iterable yielding ``(inputs, targets)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, targets)``.
        optimizer: Optimizer already bound to the model's parameters.
        epochs: Number of full passes over ``train_loader``.

    Returns:
        A list with one float per epoch: the sum of the batch losses
        (the same value that is printed).
    """
    model.train()
    # Follow the model's own device instead of assuming CUDA is available.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    epoch_losses = []
    for epoch in range(epochs):
        total_loss = 0.0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        epoch_losses.append(total_loss)
        print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")
    return epoch_losses


In [5]:
# Evaluation Function
def evaluate_model(model, test_loader):
    """Compute and print the top-1 accuracy of ``model`` over ``test_loader``.

    The model is switched to evaluation mode (and left in it). Batches are
    moved to the device of the model's parameters, so no CUDA device is
    required.

    Args:
        model: The ``nn.Module`` to evaluate.
        test_loader: Iterable yielding ``(inputs, targets)`` batches.

    Returns:
        Accuracy as a percentage in ``[0, 100]``; ``0.0`` when the loader
        yields no samples.
    """
    model.eval()
    # Follow the model's own device instead of assuming CUDA is available.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            # The index of the largest logit is the predicted class.
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()
    # Guard the division so an empty loader does not raise ZeroDivisionError.
    accuracy = 100. * correct / total if total else 0.0
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

In [6]:
# Main Execution
if __name__ == "__main__":
    # Pick the device once so that placing the models does not hard-require CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Seed the RNG so weight initialization and batch shuffling are repeatable
    # across "Restart & Run All".
    torch.manual_seed(42)

    train_loader, test_loader = get_data_loaders()

    # Full-Precision Model
    model_fp = SimpleCNN(binary=False).to(device)
    optimizer_fp = optim.Adam(model_fp.parameters(), lr=0.001)
    # The same loss object is shared by both training runs.
    criterion = nn.CrossEntropyLoss()

    print("Training Full-Precision Model")
    train_model(model_fp, train_loader, criterion, optimizer_fp)
    acc_fp = evaluate_model(model_fp, test_loader)

    # Binary Model: same architecture and optimizer settings, with
    # binarized activations.
    model_bnn = SimpleCNN(binary=True).to(device)
    optimizer_bnn = optim.Adam(model_bnn.parameters(), lr=0.001)

    print("Training Binary Neural Network")
    train_model(model_bnn, train_loader, criterion, optimizer_bnn)
    acc_bnn = evaluate_model(model_bnn, test_loader)

    # Compare Results
    print(f"Full-Precision Accuracy: {acc_fp:.2f}%")
    print(f"Binary Neural Network Accuracy: {acc_bnn:.2f}%")


URLError: <urlopen error [Errno -3] Temporary failure in name resolution>