# Importing Libraries

In the following block we import the main libraries used for creating a neural network (NN) and processing its output.

In [None]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.autograd import Variable
from models.binarized_modules import  BinarizeLinear,BinarizeConv2d
from models.binarized_modules import  Binarize,HingeLoss
import matplotlib.pyplot as plt

# Load MNIST

In the next block the MNIST dataset is created and wrapped in PyTorch's standard DataLoader. This allows us to simply iterate over train_loader and test_loader when training and testing the network, without having to create the batches manually. 

In [None]:
# Preprocessing data: the images are only converted to tensors here.
# The mean/std normalization shown below is kept for reference but is
# currently disabled. If it is re-enabled, recall that the dummy input we
# feed in the ASIC implementation must be normalized in the same way.

#transform = transforms.Compose([transforms.ToTensor(),
                               # transforms.Normalize((0.1307,), (0.3081,))])

# MNIST train and test splits from torchvision.datasets, stored under ../data
train_data = datasets.MNIST(
    root='../data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
test_data = datasets.MNIST(
    root='../data',
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)

# Data loaders that hand out the samples in batches; only the training
# loader shuffles
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=1000)

# Show some example images and the associated label to verify that the data is loaded correctly

# The MNIST targets are the digits 0-9, so the title of each image is simply
# the digit itself
labels_map = {digit: str(digit) for digit in range(10)}

figure = plt.figure(figsize=(8, 8))
cols, rows = 10, 1
for i in range(1, cols * rows + 1):
    # Draw a random training sample for every subplot
    sample_idx = torch.randint(len(train_data), size=(1,)).item()
    img, label = train_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    # squeeze() drops the size-1 dimensions so imshow receives a 2-D image
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

# MNIST classification with Binary Neural Network
## Sign function 
The function implemented below is the sign() function mentioned in the paper. However, this function is not used for training, as it would not allow the gradients needed for gradient descent to be computed. The idea is to use this function after the training has been performed (TODO). 

In [None]:
def my_sign(a):
    """
    Element-wise sign function that maps every entry to +1 or -1.

    Entries that are greater than or equal to zero become +1, all other
    entries (including NaN) become -1.

    Args:
        a: input tensor of any shape

    Returns:
        A new tensor with the same shape as ``a``, using the default floating
        point dtype and located on the same device as ``a``.
    """
    # The comparison yields a boolean mask; cast to float it is 1 / 0, which
    # "* 2 - 1" turns into +1 / -1 without any Python-level loop.
    non_negative = (a >= 0).to(torch.get_default_dtype())
    return non_negative * 2 - 1

## Create a class for the pytorch BNN
Here I am creating my own definition of the network. `__init__` is the constructor and creates the instances of the layers I want to use. The forward function, instead, performs the forward pass of the network, applying the previously created layers in the order in which they are called. 

In [None]:
class MY_BNN(nn.Module):
    """
    PyTorch neural network. Network layers are defined in __init__ and forward
    pass implemented in forward.

    The network consists of three BinarizeLinear -> BatchNorm1d -> Hardtanh
    stages (the third one with dropout applied before the batch norm),
    followed by a regular nn.Linear output layer and a log-softmax.

    Args:
        in_features: number of features in input layer
        hidden_dim: number of features in hidden dimension
        out_features: number of features in output layer
    """

    def __init__(self, in_features, hidden_dim, out_features):
        super(MY_BNN, self).__init__()

        # Remembered so that forward() can flatten its input to the width
        # the first layer was built for
        self.in_features = in_features

        self.fc1 = BinarizeLinear(in_features, hidden_dim)
        self.htanh1 = nn.Hardtanh()
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.fc2 = BinarizeLinear(hidden_dim, hidden_dim)
        self.htanh2 = nn.Hardtanh()
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        self.fc3 = BinarizeLinear(hidden_dim, hidden_dim)
        self.htanh3 = nn.Hardtanh()
        self.bn3 = nn.BatchNorm1d(hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, out_features)
        # forward() always feeds a 2-D (batch, out_features) tensor, so the
        # class dimension is dim 1
        self.logsoftmax = nn.LogSoftmax(dim=1)
        self.drop = nn.Dropout(0.5)

    def forward(self, x):
        """
        Forward pass of the network.

        Args:
            x: input tensor; it is reshaped to (-1, in_features) before the
               first layer

        Returns:
            Tensor of shape (batch, out_features) holding the log-softmax of
            the output layer.
        """
        # Flatten every sample into a vector of in_features values
        x = x.view(-1, self.in_features)

        # Hardtanh is applied after every batch norm; my_sign() is not used
        # in the forward pass
        x = self.fc1(x)
        x = self.bn1(x)
        x = self.htanh1(x)

        x = self.fc2(x)
        x = self.bn2(x)
        x = self.htanh2(x)

        x = self.fc3(x)
        x = self.drop(x)
        x = self.bn3(x)
        x = self.htanh3(x)

        x = self.fc4(x)
        return self.logsoftmax(x)


## Initialize parameters and criterion of the network 

In [None]:
# Hyper-parameters of the network and of the training

in_features = 28*28  # every 28x28-pixel input image is flattened into an array of 784 values
hidden_dim = 100  # number of neurons in each hidden layer
out_features = 10  # one output per class: the digits 0 to 9
learning_rate = 0.001  # size of the step taken in the direction of the gradient
epochs = 10  # number of passes over the full training set

# Loss function: cross entropy
criterion = nn.CrossEntropyLoss()

# Network instance and the Adam optimizer that updates its parameters
model = MY_BNN(in_features, hidden_dim, out_features)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)



## Definition of the training function

In [None]:
train_losses = []  # hold the (detached) loss for each batch -> used to display training afterwards

def train(epoch):
    """
    Train the model for one pass over train_loader.

    Uses the notebook-level model, train_loader, optimizer and criterion,
    and appends the loss of every batch to train_losses.

    Args:
        epoch: number of the current epoch; it drives the learning-rate
               schedule and is shown in the progress printout
    """
    model.train()

    # Learning-rate schedule: divide the learning rate by 10 every 40 epochs.
    # This is done once per epoch, before the batch loop, so that the decay
    # is not re-applied on every single batch of that epoch.
    if epoch % 40 == 0:
        optimizer.param_groups[0]['lr'] = optimizer.param_groups[0]['lr'] * 0.1

    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        # Store a detached tensor: keeping `loss` itself would keep the
        # autograd graph of every batch alive for the whole training
        train_losses.append(loss.detach())

        loss.backward()

        # Parameters that carry an `org` attribute get p.org copied into
        # p.data before the optimizer step, and the updated values, clamped
        # to [-1, 1], copied back into p.org afterwards.
        # NOTE(review): `org` is presumably set by the binarized layers in
        # models.binarized_modules -- confirm there.
        for p in model.parameters():
            if hasattr(p, 'org'):
                p.data.copy_(p.org)
        optimizer.step()
        for p in model.parameters():
            if hasattr(p, 'org'):
                p.org.copy_(p.data.clamp_(-1, 1))

        # Progress report every 100 batches
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

## The training loop

In [None]:
# Run the training; the epochs are numbered from 1 to `epochs`
for epoch in range(1, epochs + 1):
    train(epoch)

## Display loss vs iterations

In [None]:
# Convert every recorded loss into a plain Python float. float() also works
# for tensors that still require grad or that live on another device, where
# .numpy() would raise.
new = [float(element) for element in train_losses]

# Loss of every training batch, in the order in which the batches were seen
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.plot(new)
plt.grid()
plt.savefig('foo.pdf')

## Define a function for testing the trained network

In [None]:
def test():
    """
    Evaluate the model on test_loader and print average loss and accuracy.

    Uses the notebook-level model, test_loader and criterion. The model is
    switched to evaluation mode and no gradients are tracked.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            # criterion is nn.CrossEntropyLoss() with its default 'mean'
            # reduction, i.e. it returns the mean loss of the batch. Weight
            # it by the batch size so that dividing by the dataset size
            # below yields the true per-sample average.
            total_loss += criterion(output, target).item() * data.size(0)
            pred = output.max(1, keepdim=True)[1]  # get the index of the max log-probability
            # .item() keeps `correct` a plain int, so it is not printed as a tensor
            correct += pred.eq(target.view_as(pred)).sum().item()

    num_samples = len(test_loader.dataset)
    test_loss = total_loss / num_samples
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, num_samples,
        100. * correct / num_samples))

## Test it
After this section the accuracy of the trained network will be printed

In [None]:
# Evaluate the trained model on the test set; test() prints the average loss and the accuracy
test()

## Test a single prediction after training

In [None]:
# Take one sample of the test set, show it together with its true label
img, label = test_data[11]
print(label)
plt.imshow(img.squeeze(), cmap="gray")

# Make sure the model is in evaluation mode: in training mode BatchNorm1d
# cannot process a batch that holds a single sample, and dropout would make
# the prediction random. No gradients are needed for inference.
model.eval()
with torch.no_grad():
    output = model(img)

# Index of the largest log-probability = predicted digit
pred = output.max(1, keepdim=True)[1]
print(pred)

## Printing the weights
This section must be refined to make sure that the weights of each layer are printed in a proper manner. 

In [None]:
# Print every parameter of the model, preceded by its name and shape so that
# it is clear which layer each tensor belongs to
for name, param in model.named_parameters():
    print('{}: shape {}'.format(name, tuple(param.shape)))
    print(param)