In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim

from torch.autograd import Function

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

import time
from tqdm import tqdm

import numpy as np
import matplotlib.pyplot as plt

# Binarization

In [None]:
class BinaryQuantize(Function):
    '''
        Sign binarization with a smooth surrogate gradient, from IR-Net
        (https://github.com/htqin/IR-Net/blob/master/CIFAR-10/ResNet20/1w1a/modules/binaryfunction.py)

        forward : sign(input)
        backward: grad_output * k * t * (1 - tanh(t * input)^2), i.e. the
                  derivative of k * tanh(t * input). No gradient is returned
                  for k or t.
    '''
    @staticmethod
    def forward(ctx, input, k, t):
        # Keep the pre-sign values and both constants for the surrogate gradient.
        ctx.save_for_backward(input, k, t)
        return torch.sign(input)

    @staticmethod
    def backward(ctx, grad_output):
        x, k, t = ctx.saved_tensors
        # k and t may have been created on another device than the input.
        k = k.to(x.device)
        t = t.to(x.device)
        tanh_tx = torch.tanh(x * t)
        surrogate_slope = k * t * (1 - torch.pow(tanh_tx, 2))
        return surrogate_slope * grad_output, None, None

class Maxout(nn.Module):
    '''
        Channel-wise nonlinearity with two learnable slopes: one applied to
        the positive part of the input and one to the negative part.
        `neg_init` / `pos_init` are the initial slope values for every channel.
    '''
    def __init__(self, channel, neg_init=0.25, pos_init=1.0):
        super().__init__()
        # One slope per channel for each side of zero.
        self.neg_scale = nn.Parameter(neg_init*torch.ones(channel))
        self.pos_scale = nn.Parameter(pos_init*torch.ones(channel))
        self.relu = nn.ReLU()

    def forward(self, x):
        # Slopes are reshaped to (1, C, 1, 1) so they broadcast along dim 1 of x.
        pos_slope = self.pos_scale.view(1, -1, 1, 1)
        neg_slope = self.neg_scale.view(1, -1, 1, 1)
        positive_part = self.relu(x)
        negative_part = self.relu(-x)
        return pos_slope * positive_part - neg_slope * negative_part

class BinaryActivation(nn.Module):
    '''
        Binarizes activations around a learnable center (beta_a) with a
        learnable distance (alpha_a). Both start as scalars: alpha_a = 1,
        beta_a = 0.
    '''
    def __init__(self):
        super().__init__()
        self.alpha_a = nn.Parameter(torch.tensor(1.0))
        self.beta_a = nn.Parameter(torch.tensor(0.0))

    def gradient_approx(self, x):
        '''
            sign(x) in the forward pass, gradient of a piecewise quadratic in
            the backward pass
            (https://github.com/liuzechun/Bi-Real-net/blob/master/pytorch_implementation/BiReal18_34/birealnet.py)

            Surrogate:  -1            for x < -1
                        x^2 + 2x      for -1 <= x < 0
                        -x^2 + 2x     for 0 <= x < 1
                        1             for x >= 1
        '''
        # 0/1 float indicators for the three breakpoints.
        below_neg_one = (x < -1).type(torch.float32)
        below_zero = (x < 0).type(torch.float32)
        below_pos_one = (x < 1).type(torch.float32)

        # Build the surrogate from left to right, one breakpoint at a time.
        surrogate = (-1) * below_neg_one + (x*x + 2*x) * (1-below_neg_one)
        surrogate = surrogate * below_zero + (-x*x + 2*x) * (1-below_zero)
        surrogate = surrogate * below_pos_one + 1 * (1- below_pos_one)

        # Value comes from sign(x); gradient flows only through `surrogate`.
        hard_sign = torch.sign(x)
        return hard_sign.detach() - surrogate.detach() + surrogate

    def forward(self, x):
        centered = (x-self.beta_a)/self.alpha_a
        binarized = self.gradient_approx(centered)
        return self.alpha_a*(binarized + self.beta_a)

class LambdaLayer(nn.Module):
    '''
        Wraps an arbitrary callable as a module (for DownSample):
        forward(x) simply returns lambd(x).
    '''
    def __init__(self, lambd):
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        wrapped_fn = self.lambd
        return wrapped_fn(x)

class AdaBin_Conv2d(nn.Conv2d):
    '''
        AdaBin Convolution

        A Conv2d whose inputs and/or weights are binarized before the
        convolution:
          * a_bit == 1: the input is passed through BinaryActivation.
          * w_bit == 1: each output filter is standardized with its own mean
            (beta_w) and standard deviation (alpha_w), binarized with
            BinaryQuantize, and mapped back as wb * alpha_w + beta_w.
        Any other value of a_bit / w_bit leaves that operand in full precision.

        `k` and `t` are the constants handed to BinaryQuantize, whose backward
        uses the slope k * t * (1 - tanh(t * w)^2).
    '''
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False, a_bit=1, w_bit=1):
        super(AdaBin_Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
        self.a_bit = a_bit
        self.w_bit = w_bit
        # Non-persistent buffers: they follow the module across .to()/.cuda()
        # (no host-to-device copy in every backward pass) while staying out of
        # the state_dict, so existing checkpoints keep loading unchanged.
        self.register_buffer('k', torch.tensor([10]).float(), persistent=False)
        self.register_buffer('t', torch.tensor([0.1]).float(), persistent=False)
        self.binary_a = BinaryActivation()

        # Number of weights in ONE filter. With grouped convolutions each
        # filter only spans in_channels // groups input channels.
        self.filter_size = self.kernel_size[0]*self.kernel_size[1]*(self.in_channels // self.groups)

    def forward(self, inputs):
        if self.a_bit==1:
            inputs = self.binary_a(inputs)

        if self.w_bit==1:
            w = self.weight
            # Per-filter statistics, shaped (out_channels, 1, 1, 1) for broadcasting.
            beta_w = w.mean((1,2,3)).view(-1,1,1,1)
            alpha_w = torch.sqrt(((w-beta_w)**2).sum((1,2,3))/self.filter_size).view(-1,1,1,1)

            w = (w - beta_w)/alpha_w
            # autograd Functions are invoked through the static `apply`;
            # instantiating the Function class is deprecated.
            wb = BinaryQuantize.apply(w, self.k, self.t)
            weight = wb * alpha_w + beta_w
        else:
            weight = self.weight

        output = F.conv2d(inputs, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)

        return output

# Network Architecture

In [None]:
class BasicBlock_1w1a(nn.Module):
    '''
        Residual block made of two AdaBin 3x3 convolutions. Each convolution
        is followed by BatchNorm, a residual addition and a Maxout
        nonlinearity; the first residual uses `shortcut(x)`, the second uses
        the output of the first half of the block.
    '''
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        # First half: may change resolution (stride) and width (planes).
        self.conv1 = AdaBin_Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.nonlinear1 = Maxout(planes)

        # Second half: keeps resolution and width.
        self.conv2 = AdaBin_Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.nonlinear2 = Maxout(planes)

        # Identity shortcut unless the shape changes, in which case a
        # full-precision 1x1 convolution + BatchNorm projects the input.
        if stride != 1 or in_planes != planes:
            projected_planes = self.expansion * planes
            shortcut = nn.Sequential(
                nn.Conv2d(in_planes, projected_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(projected_planes),
            )
        else:
            shortcut = nn.Sequential()
        self.shortcut = shortcut

    def forward(self, x):
        first = self.bn1(self.conv1(x))
        first += self.shortcut(x)
        first = self.nonlinear1(first)

        second = self.bn2(self.conv2(first))
        second += first
        return self.nonlinear2(second)

class ResNet(nn.Module):
    '''
        ResNet backbone with a full-precision 3x3 stem, four residual stages
        (64, 128, 256, 512 planes) and a BatchNorm1d + Linear classifier head.

        Args:
            block: residual block class. It is called as
                block(in_planes, planes, stride) and must expose an
                `expansion` class attribute.
            num_blocks: sequence of four ints, number of blocks per stage.
            num_classes: size of the output logits.
            in_channels: number of channels of the input images (default 3).
                Pass 1 for grayscale data instead of replacing `conv1`
                after construction.

        forward() returns raw logits of shape (batch, num_classes).
    '''
    def __init__(self, block, num_blocks, num_classes=10, in_channels=3):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.nonlinear1 = Maxout(64)

        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)
        self.bn2 = nn.BatchNorm1d(512*block.expansion)

    def _make_layer(self, block, planes, num_blocks, stride):
        '''
            Builds one stage: the first block uses `stride`, the remaining
            ones use stride 1. Updates self.in_planes for the next stage.
        '''
        block_strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for block_stride in block_strides:
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.nonlinear1(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pooling. On a 4x4 final feature map this is the same
        # average the previous fixed 4x4 window computed, but it also yields
        # exactly 512*expansion features for any other input resolution.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(out.size(0), -1)
        out = self.bn2(out)
        out = self.linear(out)
        # Logits are returned un-normalized (no softmax is applied here).
        return out

def resnet18_1w1a(num_classes=10):
    '''
        Builds the ResNet-18 layout ([2, 2, 2, 2] BasicBlock_1w1a blocks).

        Args:
            num_classes: number of output logits (default 10, as before).
    '''
    return ResNet(BasicBlock_1w1a, [2,2,2,2], num_classes=num_classes)

## Helper functions

In [None]:
def plot_loss_curves(train_losses):
    '''
        Plots the recorded training losses (one point per entry of
        `train_losses`) on the current matplotlib axes and shows the figure.
    '''
    axes = plt.gca()
    axes.plot(train_losses, label='Training Loss')
    axes.set_xlabel('Batch')
    axes.set_ylabel('Loss')
    axes.legend()
    plt.show()

def display_images(model, test_loader, use_cuda=True, mean=0.5, std=0.5):
    '''
        Shows up to 25 images from the first batch of `test_loader` in a 5x5
        grid, titled with the true and the predicted label.

        Args:
            model: network used for the predictions; switched to eval mode.
            test_loader: iterable yielding (images, labels) batches.
                Assumes images are (batch, channels, H, W) - TODO confirm for
                loaders other than the ones built in this notebook.
            use_cuda: only consulted when `model` has no parameters; otherwise
                the batch is moved to the device the model already lives on,
                so a CUDA request on a CPU-only machine no longer crashes.
            mean, std: normalization that was applied to the images (scalar or
                per-channel sequence); undone before plotting so colour images
                fall back into the [0, 1] range imshow expects.
    '''
    model.eval()

    # Follow the model's device instead of trusting the flag.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if use_cuda and torch.cuda.is_available() else "cpu")

    with torch.no_grad():
        data, target = next(iter(test_loader))  # first batch only
        pred = model(data.to(device)).argmax(dim=1).cpu()
    data, target = data.cpu(), target.cpu()

    # Undo Normalize(mean, std); view(-1, 1, 1) lets scalars and per-channel
    # sequences broadcast over (C, H, W).
    shift = torch.as_tensor(mean, dtype=data.dtype).view(-1, 1, 1)
    scale = torch.as_tensor(std, dtype=data.dtype).view(-1, 1, 1)
    images = (data * scale + shift).clamp(0.0, 1.0)

    plt.figure(figsize=(10, 10))
    # The last/only batch may hold fewer than 25 samples.
    for i in range(min(25, images.size(0))):
        plt.subplot(5, 5, i+1)
        image = images[i]
        if image.size(0) == 1:
            # Single channel: drop the channel axis and use a gray colormap.
            plt.imshow(image.squeeze(0).numpy(), cmap='gray')
        else:
            # Colour image: CxHxW -> HxWxC.
            plt.imshow(image.permute(1, 2, 0).numpy())
        plt.title(f'True: {target[i].item()}, Pred: {pred[i].item()}')
        plt.axis('off')
    plt.tight_layout()
    plt.show()

# MNIST (GPU when available)

In [None]:
def main(use_cuda=True, num_epochs=10, batch_size=64, lr=0.001):
    '''
        Trains the binarized ResNet-18 on MNIST, reports test accuracy, then
        plots the per-batch training loss and a grid of test predictions.

        Args:
            use_cuda: request the GPU; the CPU is used when CUDA is unavailable.
            num_epochs: number of passes over the training set.
            batch_size: batch size for both the train and the test loader.
            lr: Adam learning rate.
    '''
    # Check for CUDA availability and set the device accordingly
    device = torch.device("cuda" if use_cuda and torch.cuda.is_available() else "cpu")
    print(f'Using device: {device}')

    # Data preparation
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))  # Normalize to range [-1, 1]
    ])
    train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
    test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Model initialization: swap the stem for 1-channel input first, then move
    # the whole model to the device once.
    model = resnet18_1w1a()
    model.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=False)
    model = model.to(device)

    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Training
    train_losses = []
    for epoch in range(num_epochs):
        print(f'Starting epoch {epoch + 1}')
        start_time = time.time()
        model.train()
        # tqdm wraps the loader so the bar's total always matches the number
        # of batches (the manual total of len // 10 was one update short).
        for data, target in tqdm(train_loader, desc=f'Epoch {epoch + 1} Training', position=0, leave=True):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())
        print(f'Finished epoch {epoch + 1} in {time.time() - start_time:.2f} seconds')
        # Loss of the last batch of the epoch.
        print(f'Epoch {epoch+1}, Loss: {loss.item()}')

    # Testing
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    accuracy = correct / len(test_loader.dataset)
    print(f'Test accuracy: {accuracy * 100:.2f}%')

    plot_loss_curves(train_losses)
    # Pass the device actually in use, not the request: with use_cuda=True on
    # a CPU-only machine the model lives on the CPU.
    display_images(model, test_loader, device.type == "cuda")

# Run the MNIST experiment, requesting the GPU; main() selects the CPU
# device when CUDA is not available.
if __name__ == "__main__":
    main(use_cuda=True)  # Set to False to use CPU


# MNIST on CPU

In [None]:
def main(use_cuda=True, num_epochs=10, batch_size=64, lr=0.001):
    '''
        Trains the binarized ResNet-18 on MNIST and reports the total training
        and testing time in minutes, the test accuracy, the per-batch loss
        curve and a grid of test predictions.

        Args:
            use_cuda: request the GPU; the CPU is used when CUDA is unavailable.
            num_epochs: number of passes over the training set.
            batch_size: batch size for both the train and the test loader.
            lr: Adam learning rate.
    '''
    # Check for CUDA availability and set the device accordingly
    device = torch.device("cuda" if use_cuda and torch.cuda.is_available() else "cpu")
    print(f'Using device: {device}')

    # Data preparation
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))  # Normalize to range [-1, 1]
    ])
    train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
    test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Model initialization: swap the stem for 1-channel input first, then move
    # the whole model to the device once.
    model = resnet18_1w1a()
    model.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=False)
    model = model.to(device)

    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Training
    train_losses = []
    total_train_time = time.time()
    for epoch in range(num_epochs):
        print(f'Starting epoch {epoch + 1}')
        model.train()
        # tqdm wraps the loader so the bar's total always matches the number
        # of batches (the manual total of len // 10 was one update short).
        for data, target in tqdm(train_loader, desc=f'Epoch {epoch + 1} Training', position=0, leave=True):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())
        # Loss of the last batch of the epoch.
        print(f'Epoch {epoch+1}, Loss: {loss.item()}')
    # The division must happen in the expression: '.2f/60' is not a valid
    # format spec and raised ValueError only after training had finished.
    print(f'Finished training in {(time.time() - total_train_time) / 60:.2f} minutes')

    # Testing
    total_test_time = time.time()
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    accuracy = correct / len(test_loader.dataset)
    print(f'Test accuracy: {accuracy * 100:.2f}%')
    print(f'Finished testing in {(time.time() - total_test_time) / 60:.2f} minutes')

    plot_loss_curves(train_losses)
    # Pass the device actually in use, not the request: with use_cuda=True on
    # a CPU-only machine the model lives on the CPU.
    display_images(model, test_loader, device.type == "cuda")

# Run the MNIST experiment on the CPU for the timing comparison.
if __name__ == "__main__":
    main(use_cuda=False)  # use_cuda=False makes main() select the CPU device


# CIFAR-10 (GPU when available)

In [None]:
def display_images(model, dataloader, use_cuda, mean=0.5, std=0.5):
    '''
        Shows up to 25 images from the first batch of `dataloader` in a 5x5
        grid, each labelled with the model's prediction.

        Args:
            model: network used for the predictions; switched to eval mode.
            dataloader: iterable yielding (images, labels) batches.
                Assumes images are (batch, channels, H, W) - TODO confirm for
                loaders other than the ones built in this notebook.
            use_cuda: only consulted when `model` has no parameters; otherwise
                the batch is moved to the device the model already lives on,
                so a CUDA request on a CPU-only machine no longer crashes.
            mean, std: normalization that was applied to the images (scalar or
                per-channel sequence). It is undone before plotting because
                imshow clips RGB float data outside [0, 1].

        If `dataloader.dataset` exposes a `classes` list, the class name is
        shown; otherwise the predicted class index is shown.
    '''
    model.eval()

    # Follow the model's device instead of trusting the flag.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if use_cuda and torch.cuda.is_available() else "cpu")

    with torch.no_grad():
        data, _ = next(iter(dataloader))  # We only need one batch of images
        pred = model(data.to(device)).argmax(dim=1)

    # Undo Normalize(mean, std); view(-1, 1, 1) lets scalars and per-channel
    # sequences broadcast over (C, H, W).
    data = data.cpu()
    shift = torch.as_tensor(mean, dtype=data.dtype).view(-1, 1, 1)
    scale = torch.as_tensor(std, dtype=data.dtype).view(-1, 1, 1)

    # Convert the images and labels to numpy arrays
    images = (data * scale + shift).clamp(0.0, 1.0).numpy()
    labels = pred.cpu().numpy()
    class_names = getattr(getattr(dataloader, "dataset", None), "classes", None)

    # Plot the first 25 images from the batch (fewer if the batch is smaller)
    plt.figure(figsize=(10,10))
    for i in range(min(25, len(images))):
        plt.subplot(5,5,i+1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        if images[i].shape[0] == 1:
            # Single channel: drop the channel axis and use a gray colormap.
            plt.imshow(images[i][0], cmap='gray')
        else:
            # Transpose the image dimensions from CxHxW to HxWxC
            plt.imshow(np.transpose(images[i], (1, 2, 0)))
        label_idx = int(labels[i])
        plt.xlabel(class_names[label_idx] if class_names is not None else label_idx)
    plt.show()

In [None]:
def main(use_cuda=True, num_epochs=10, batch_size=64, lr=0.001):
    '''
        Trains the binarized ResNet-18 on CIFAR-10 and reports the total
        training and testing time in seconds, the test accuracy, the
        per-batch loss curve and a grid of test predictions.

        Args:
            use_cuda: request the GPU; the CPU is used when CUDA is unavailable.
            num_epochs: number of passes over the training set.
            batch_size: batch size for both the train and the test loader.
            lr: Adam learning rate.
    '''
    # Check for CUDA availability and set the device accordingly
    device = torch.device("cuda" if use_cuda and torch.cuda.is_available() else "cpu")
    print(f'Using device: {device}')

    # Data preparation
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize to range [-1, 1]
    ])
    train_dataset = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
    test_dataset = datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Model initialization. The stock stem already takes 3-channel input, so
    # it is not replaced; the model is moved to the device once.
    model = resnet18_1w1a().to(device)

    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Training
    train_losses = []
    total_train_time = time.time()
    for epoch in range(num_epochs):
        print(f'Starting epoch {epoch + 1}')
        model.train()
        # tqdm wraps the loader so the bar's total always matches the number
        # of batches (the manual total of len // 10 was one update short).
        for data, target in tqdm(train_loader, desc=f'Epoch {epoch + 1} Training', position=0, leave=True):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())
        # Loss of the last batch of the epoch.
        print(f'Epoch {epoch+1}, Loss: {loss.item()}')
    # '.2f' (two decimals) instead of '2f' (minimum width 2, six decimals).
    print(f'Finished training in {time.time() - total_train_time:.2f} seconds')

    # Testing
    total_test_time = time.time()
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    accuracy = correct / len(test_loader.dataset)
    print(f'Test accuracy: {accuracy * 100:.2f}%')
    print(f'Finished testing in {time.time() - total_test_time:.2f} seconds')

    plot_loss_curves(train_losses)
    # Pass the device actually in use, not the request: with use_cuda=True on
    # a CPU-only machine the model lives on the CPU.
    display_images(model, test_loader, device.type == "cuda")

# Run the CIFAR-10 experiment, requesting the GPU; main() selects the CPU
# device when CUDA is not available.
if __name__ == "__main__":
    main(use_cuda=True)  # Set to False to use CPU


# CIFAR-10 on CPU

In [None]:
def main(use_cuda=True, num_epochs=10, batch_size=64, lr=0.001):
    '''
        Trains the binarized ResNet-18 on CIFAR-10 and reports the total
        training and testing time in seconds, the test accuracy, the
        per-batch loss curve and a grid of test predictions.

        Args:
            use_cuda: request the GPU; the CPU is used when CUDA is unavailable.
            num_epochs: number of passes over the training set.
            batch_size: batch size for both the train and the test loader.
            lr: Adam learning rate.
    '''
    # Check for CUDA availability and set the device accordingly
    device = torch.device("cuda" if use_cuda and torch.cuda.is_available() else "cpu")
    print(f'Using device: {device}')

    # Data preparation
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize to range [-1, 1]
    ])
    train_dataset = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
    test_dataset = datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Model initialization. The stock stem already takes 3-channel input, so
    # it is not replaced; the model is moved to the device once.
    model = resnet18_1w1a().to(device)

    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Training
    train_losses = []
    total_train_time = time.time()
    for epoch in range(num_epochs):
        print(f'Starting epoch {epoch + 1}')
        model.train()
        # tqdm wraps the loader so the bar's total always matches the number
        # of batches (the manual total of len // 10 was one update short).
        for data, target in tqdm(train_loader, desc=f'Epoch {epoch + 1} Training', position=0, leave=True):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())
        # Loss of the last batch of the epoch.
        print(f'Epoch {epoch+1}, Loss: {loss.item()}')
    # '.2f' (two decimals) instead of '2f' (minimum width 2, six decimals).
    print(f'Finished training in {time.time() - total_train_time:.2f} seconds')

    # Testing
    total_test_time = time.time()
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    accuracy = correct / len(test_loader.dataset)
    print(f'Test accuracy: {accuracy * 100:.2f}%')
    print(f'Finished testing in {time.time() - total_test_time:.2f} seconds')

    plot_loss_curves(train_losses)
    # Pass the device actually in use, not the request: with use_cuda=True on
    # a CPU-only machine the model lives on the CPU.
    display_images(model, test_loader, device.type == "cuda")

# Run the CIFAR-10 experiment on the CPU for the timing comparison.
if __name__ == "__main__":
    main(use_cuda=False)  # use_cuda=False makes main() select the CPU device