# Binary Connect Implementation

In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable, Function
import torch.nn.functional as F

from torchvision import datasets, transforms
import numpy as np

batch_size = 128
n_epochs = 1000
validation_steps = 10
learning_rate = 5e-3
stochastic = True

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [2]:
train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('./data', train=True, download=True,
                        transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.5,), (0.5,))
                       ])), batch_size=batch_size, shuffle=True)

valid_loader = torch.utils.data.DataLoader(
        datasets.MNIST('./data', train=False, 
                        transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.5,), (0.5,))
                       ])), batch_size=batch_size, shuffle=True)

In [3]:
class Binarize(Function):
    @staticmethod        
    def forward(ctx, input, weight, bias, stochastic=False):
        if stochastic:
            p_weight = torch.max(torch.zeros_like(weight), torch.min(torch.ones_like(weight), (weight + 1) / 2))
            binarized_weights = torch.bernoulli(p_weight) * 2 - 1
        else:
            binarized_weights = torch.sign(weight)
        
        output = F.linear(input, binarized_weights, bias)
        ctx.save_for_backward(input, binarized_weights, bias)
        
        return output

    @staticmethod    
    def backward(ctx, gradients):
        input, binarized_weight, bias = ctx.saved_tensors
        grad_input = None
        grad_weight = None
        grad_bias = None

        if ctx.needs_input_grad[0]:
            grad_input = gradients.mm(binarized_weight)
        if ctx.needs_input_grad[1]:
            grad_weight = gradients.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = gradients.sum(0)
        
#         print(gradients)
#         return grad_input, torch.ones_like(grad_weight), grad_bias, None

        return grad_input, grad_weight, grad_bias, None

class BinarizedLinear(nn.Module):
    def __init__(self, input_size, output_size, stochastic):
        super(BinarizedLinear, self).__init__()
        self.stochastic = stochastic
        self.fc = nn.Linear(input_size, output_size)
        torch.nn.init.xavier_uniform_(self.fc.weight.data)
#         torch.nn.init.xavier_uniform(self.fc.bias.data)

    def forward(self, input):
        self.fc.weight.data.clamp_(min=-1, max=1)
        if self.stochastic and not self.training:
            output = self.fc(input)
        else:
            output = Binarize.apply(input, self.fc.weight, self.fc.bias, self.stochastic)
        return output

class BinarizedDNNModel(nn.Module):
    def __init__(self, image_size, output_size=10, hidden_size=1024, stochastic=False):
        super(BinarizedDNNModel, self).__init__()
        self.image_size = image_size
        self.stochastic = stochastic
        self.fc1 = nn.Sequential(
                   BinarizedLinear(image_size * image_size, hidden_size, stochastic),
                   nn.ReLU(),
                   nn.BatchNorm1d(hidden_size))
        self.fc2 = nn.Sequential(
                   BinarizedLinear(hidden_size, hidden_size, stochastic),
                   nn.ReLU(),
                   nn.BatchNorm1d(hidden_size)
        )
        self.fc3 = nn.Sequential(
                   BinarizedLinear(hidden_size, hidden_size, stochastic),
                   nn.ReLU(),
                   nn.BatchNorm1d(hidden_size))
        self.output_layer = nn.Sequential(
                    BinarizedLinear(hidden_size, output_size, stochastic),
                    nn.ReLU(),
                    nn.BatchNorm1d(output_size))
        
    def forward(self, x):
        x = x.view(-1, self.image_size * self.image_size)
        
        for layer in [self.fc1, self.fc2, self.fc3, self.output_layer]:
            x = layer(x)
        return x
    
class L2SVMLoss(nn.Module):
    def __init__(self):
        super(L2SVMLoss, self).__init__()
    
    def forward(self, output, target):
        y = one_hot_encoding(target)
        ot = output * y
        loss = torch.mean(torch.pow(F.relu(1 - ot), 2))
        return loss
    
def one_hot_encoding(labels):
    y = torch.eye(10) * 2 - 1
    return y[labels].to(device)

In [4]:
class dumdums(nn.Module):
    def __init__(self, image_size, output_size=10, hidden_size=1024, stochastic=False):
        super(dumdums, self).__init__()
        self.image_size = image_size
        self.stochastic = stochastic
        self.fc1 = nn.Sequential(
                   BinarizedLinear(image_size * image_size, output_size, stochastic),
                   nn.ReLU()
#                    nn.BatchNorm1d(output_size)
        )

    def forward(self, x):
        x = x.view(-1, self.image_size * self.image_size)
        
        x = self.fc1(x)
        return x

In [5]:
# bl = BinarizedLinear(28 * 28, 32, stochastic=stochastic).to(device)
# model = dumdums(image_size=28, stochastic=stochastic).to(device)
# loss_function = L2SVMLoss().to(device)
# optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

In [6]:
# outputs = model(images.view(-1, 28 * 28))
# o = torch.optim.SGD(model.parameters(), lr=1)
# loss = outputs.mean()
# loss.backward()
# o.step()
# model.fc1[0].fc.weight

In [7]:
# loss

In [8]:
# model.fc1[0].fc.weight.grad

In [9]:
model = BinarizedDNNModel(image_size=28, stochastic=stochastic).to(device)
loss_function = L2SVMLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [10]:
model.fc1[0].fc.weight

Parameter containing:
tensor([[ 0.0060, -0.0551, -0.0473,  ..., -0.0324, -0.0363,  0.0262],
        [ 0.0284, -0.0271, -0.0083,  ..., -0.0391,  0.0121,  0.0484],
        [ 0.0444, -0.0215, -0.0035,  ..., -0.0386, -0.0057, -0.0344],
        ...,
        [-0.0002, -0.0181, -0.0109,  ..., -0.0474,  0.0352, -0.0453],
        [ 0.0076, -0.0486, -0.0331,  ...,  0.0208,  0.0525,  0.0497],
        [-0.0533,  0.0404, -0.0516,  ..., -0.0511,  0.0401, -0.0077]],
       device='cuda:0', requires_grad=True)

In [None]:
print("Training...")

for epoch in range(n_epochs):
    print("========[EPOCH {}/{}]========".format(epoch, n_epochs))
    
    # Training
    train_acc = 0
    train_loss = 0
    model.train()
#     print(model.fc1[0].fc.weight, model.fc1[0].fc.weight.grad)
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        loss = loss_function(outputs, labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        train_loss += loss.data.cpu().numpy()
        train_acc += np.mean(torch.argmax(outputs, 1).data.cpu().numpy() == labels.data.cpu().numpy())
    
#         if i % 100 == 0:
#             print(model.fc1[0].fc.weight.grad.mean(0))
    train_acc = train_acc / len(train_loader)
    print("[TRAIN ACCURACY]: {:.4f}".format(train_acc))
    print("[TRAIN LOSS]: {:.4f}".format(train_loss))

    model.eval()
    if epoch % validation_steps == 0:
        # Validation
        valid_acc = 0
        for i, (images, labels) in enumerate(valid_loader):
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)

            valid_acc += np.mean(torch.argmax(outputs, 1).data.cpu().numpy() == labels.data.cpu().numpy())

        valid_acc = valid_acc / len(valid_loader)
        print("[VALIDATION ACCURACY]: {:.4f}".format(valid_acc))    

Training...
[TRAIN ACCURACY]: 0.1476
[TRAIN LOSS]: 282.7784
[VALIDATION ACCURACY]: 0.1143
[TRAIN ACCURACY]: 0.2232
[TRAIN LOSS]: 159.1542
[TRAIN ACCURACY]: 0.2740
[TRAIN LOSS]: 146.0057


In [None]:
model.fc1[0].fc.weight.grad.mean(0)

In [None]:
model.train()
for i in range(50000):
    outputs = model(images)
    loss = loss_function(outputs, labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if i % 10 == 0:
        print(i, loss, np.mean(torch.argmax(outputs, 1).data.cpu().numpy() == labels.data.cpu().numpy()))