# Binary Connect Implementation

In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable, Function
import torch.nn.functional as F

from torchvision import datasets, transforms
import numpy as np

batch_size = 128
n_epochs = 25
validation_steps = 10
learning_rate = 1e-3


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [2]:
train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('./data', train=True, download=True,
                        transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.5,), (0.5,))
                       ])), batch_size=batch_size, shuffle=True)

valid_loader = torch.utils.data.DataLoader(
        datasets.MNIST('./data', train=False, 
                        transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.5,), (0.5,))
                       ])), batch_size=batch_size, shuffle=True)

In [3]:
class Binarize(Function):
    def __init__(self, fc):
        super(Binarize, self).__init__()
        self.fc = fc

    @staticmethod        
    def forward(ctx, input, weight, bias):
        p_weight = torch.max(torch.zeros_like(weight), torch.min(torch.ones_like(weight), (weight + 1) / 2))
        binarized_weights = torch.bernoulli(p_weight) * 2 - 1

        output = F.linear(input, weight, bias)
        ctx.save_for_backward(input, weight, bias)
    
        return output

    @staticmethod    
    def backward(ctx, gradients):
        input, weight, bias = ctx.saved_tensors
        grad_input = None
        grad_weight = None
        grad_bias = None

        if ctx.needs_input_grad[0]:
            grad_input = gradients.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = gradients.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = gradients.sum(0)

        return grad_input, grad_weight, grad_bias
    
class BinarizedLinear(nn.Module):
    def __init__(self, input_size, output_size):
        super(BinarizedLinear, self).__init__()
        self.fc = nn.Linear(input_size, output_size)
        
    def forward(self, input):
        output = Binarize.apply(input, self.fc.weight, self.fc.bias)
        return output

In [4]:
class BinarizedDNNModel(nn.Module):
    def __init__(self, image_size, output_size=10, hidden_size=1024):
        super(BinarizedDNNModel, self).__init__()
        self.image_size = image_size
        
        self.fc1 = nn.Sequential(
                   BinarizedLinear(image_size * image_size, hidden_size),
                   nn.ReLU(),
                   nn.BatchNorm1d(hidden_size))
        self.fc2 = nn.Sequential(
                   BinarizedLinear(hidden_size, hidden_size),
                   nn.ReLU(),
                   nn.BatchNorm1d(hidden_size))
        self.fc3 = nn.Sequential(
                   BinarizedLinear(hidden_size, hidden_size),
                   nn.ReLU(),
                   nn.BatchNorm1d(hidden_size))
        self.output_layer = nn.Sequential(
                    BinarizedLinear(hidden_size, output_size),
                    nn.ReLU(),
                    nn.BatchNorm1d(output_size))
        
    def forward(self, x):
        x = x.view(-1, self.image_size * self.image_size)
        
        for layer in [self.fc1, self.fc2, self.fc3, self.output_layer]:
            x = layer(x)
        return x
    
class L2SVMLoss(nn.Module):
    def __init__(self):
        super(L2SVMLoss, self).__init__()
    
    def forward(self, output, target):
        y = one_hot_encoding(target)
        ot = output * y
        loss = torch.mean(torch.pow(F.relu(1 - ot), 2))
        return loss
    
def one_hot_encoding(labels):
    y = torch.eye(10) * 2 - 1
    return y[labels].to(device)

In [5]:
model = BinarizedDNNModel(image_size=28).to(device)
loss_function = L2SVMLoss().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

BINARIZED NEW MODEL


In [6]:
print("Training...")

for epoch in range(n_epochs):
    print("========[EPOCH {}/{}]========".format(epoch, n_epochs))
    
    # Training
    train_acc = 0
    train_loss = 0
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        loss = loss_function(outputs, labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        train_loss += loss.data.cpu().numpy()
        train_acc += np.mean(torch.argmax(outputs, 1).data.cpu().numpy() == labels.data.cpu().numpy())

    train_acc = train_acc / len(train_loader)
    print("[TRAIN ACCURACY]: {:.4f}".format(train_acc))
    print("[TRAIN LOSS]: {:.4f}".format(train_loss))

    if epoch % validation_steps == 0:
        # Validation
        valid_acc = 0
        for i, (images, labels) in enumerate(valid_loader):
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)

            valid_acc += np.mean(torch.argmax(outputs, 1).data.cpu().numpy() == labels.data.cpu().numpy())

        valid_acc = valid_acc / len(valid_loader)
        print("[VALIDATION ACCURACY]: {:.4f}".format(valid_acc))    

Training...
[TRAIN ACCURACY]: 0.7745
[TRAIN LOSS]: 336.6467
[VALIDATION ACCURACY]: 0.8721
[TRAIN ACCURACY]: 0.8820
[TRAIN LOSS]: 217.2840
[TRAIN ACCURACY]: 0.9021
[TRAIN LOSS]: 177.0586
[TRAIN ACCURACY]: 0.9152
[TRAIN LOSS]: 147.9012
[TRAIN ACCURACY]: 0.9232
[TRAIN LOSS]: 125.2606
[TRAIN ACCURACY]: 0.9288
[TRAIN LOSS]: 106.8999
[TRAIN ACCURACY]: 0.9334
[TRAIN LOSS]: 92.0358
[TRAIN ACCURACY]: 0.9361
[TRAIN LOSS]: 79.8763
[TRAIN ACCURACY]: 0.9408
[TRAIN LOSS]: 69.3813
[TRAIN ACCURACY]: 0.9444
[TRAIN LOSS]: 60.7310
[TRAIN ACCURACY]: 0.9460
[TRAIN LOSS]: 53.8554
[VALIDATION ACCURACY]: 0.9394
[TRAIN ACCURACY]: 0.9482
[TRAIN LOSS]: 47.8662
[TRAIN ACCURACY]: 0.9502
[TRAIN LOSS]: 42.8655
[TRAIN ACCURACY]: 0.9529
[TRAIN LOSS]: 38.4202
[TRAIN ACCURACY]: 0.9545
[TRAIN LOSS]: 34.7837
[TRAIN ACCURACY]: 0.9567
[TRAIN LOSS]: 31.5108
[TRAIN ACCURACY]: 0.9582
[TRAIN LOSS]: 29.0028
[TRAIN ACCURACY]: 0.9595
[TRAIN LOSS]: 26.7822
[TRAIN ACCURACY]: 0.9609
[TRAIN LOSS]: 25.0273
[TRAIN ACCURACY]: 0.9623
[TRA