In [1]:
import os
import time
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

from torch import autograd
from torch.autograd import Variable
from torch.utils.data import DataLoader

from torchvision import datasets, transforms
from biotorch.initialization.functions import add_fa_weight_matrices, override_backward

## Define Model

In [2]:
class LeNet(nn.Module):
    """
    Classic LeNet architecture: two 5x5 convolutions, each followed by a 2x2
    max-pool, then two fully connected layers that produce 10 logits.
    """

    def __init__(self, activation='tanh'):
        """
        :param activation: name of the non-linearity applied after conv1, conv2
            and fc1; one of 'relu', 'tanh' or 'sigmoid'
        :raises ValueError: if ``activation`` is not one of the supported names
        """
        super(LeNet, self).__init__()

        supported_activations = {
            'relu': torch.relu,
            'tanh': torch.tanh,
            'sigmoid': torch.sigmoid,
        }
        # Fail at construction time: an unknown name would otherwise leave
        # self.activation undefined and only surface inside forward().
        if activation not in supported_activations:
            raise ValueError(
                "Unsupported activation {!r}; expected one of {}".format(
                    activation, sorted(supported_activations)))
        self.activation = supported_activations[activation]

        # create layer operations
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5, stride=1)
        self.conv2 = nn.Conv2d(in_channels=20, out_channels=50, kernel_size=5, stride=1)
        self.pool = nn.MaxPool2d(2, 2)
        # 4 * 4 * 50 is the flattened conv2/pool output for a 28x28 input
        # (28 -> 24 -> 12 -> 8 -> 4 along each spatial dimension).
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, inputs):
        """
        Forward pass.

        :param inputs: image batch with shape [batch_size, 1, 28, 28]
        :return: logit outputs from the network, shape [batch_size, 10]
        """
        inputs = self.activation(self.conv1(inputs))
        inputs = self.pool(inputs)

        inputs = self.activation(self.conv2(inputs))
        inputs = self.pool(inputs)

        # Flatten everything except the batch dimension for the linear layers.
        inputs = inputs.view(inputs.size(0), -1)

        inputs = self.activation(self.fc1(inputs))
        inputs = self.fc2(inputs)
        return inputs

## Test Function

In [3]:
def test(model, test_loader, batch_size):
    test_loss = 0
    correct = 0
    # Desactivate the autograd engine in test
    with torch.no_grad():
        for data, target in test_loader:
            #data = data.view(batch_size, -1)
            inputs, targets = Variable(data), Variable(target)
            predictions = model(inputs)
            predictions = torch.squeeze(predictions)
            test_loss += F.nll_loss(predictions, targets, size_average=False).item()
            pred = predictions.data.max(1, keepdim=True)[1]
            correct += pred.eq(targets.data.view_as(pred)).sum()

    test_loss /= len(test_loader.dataset)
    return test_loss, correct

## Training code

In [4]:
# Seed PyTorch's global RNG so weight initialisation and data shuffling are
# repeatable across Restart-and-Run-All.
SEED = 0
torch.manual_seed(SEED)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0  # best test accuracy
start_epoch = 0

# mini-batch size shared by the train and test loaders
batch_size = 32

In [5]:
# set up datasets
print('==> Preparing data..')

# One preprocessing pipeline shared by both splits: convert the image to a
# tensor, then normalise it with the constants below.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

train_loader = DataLoader(
    datasets.MNIST('./data', train=True, download=True, transform=mnist_transform),
    batch_size=batch_size, shuffle=True, drop_last=True)

# Keep the final partial batch of the test split: evaluation results are
# divided by len(test_loader.dataset), so every test sample must be scored.
test_loader = DataLoader(
    datasets.MNIST('./data', train=False, download=True, transform=mnist_transform),
    batch_size=batch_size, shuffle=False, drop_last=False)

==> Preparing data..
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


99.1%IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100.0%


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


112.7%

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

Processing...
Done!



  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [6]:
# Switch on the cuDNN benchmark flag for the whole process.
torch.backends.cudnn.benchmark = True

In [7]:
# Baseline network trained with ordinary back-propagation
# ('tanh' is also LeNet's default activation).
model_bp = LeNet(activation='tanh')

In [8]:
# Feedback Alignment network: same architecture as the baseline, with the two
# biotorch functions applied to every sub-module in order.
model_fa = LeNet()
for fa_setup in (add_fa_weight_matrices, override_backward):
    model_fa.apply(fa_setup)
# Display the resulting module as the cell output.
model_fa

LeNet(
  (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=800, out_features=500, bias=True)
  (fc2): Linear(in_features=500, out_features=10, bias=True)
)

In [9]:
# Loss function shared by both models, plus one RMSprop optimizer per model
# configured with identical hyper-parameters.
rmsprop_kwargs = dict(lr=1e-4, weight_decay=0.)
loss_crossentropy = nn.CrossEntropyLoss()
optimizer_fa = optim.RMSprop(model_fa.parameters(), **rmsprop_kwargs)
optimizer_bp = optim.RMSprop(model_bp.parameters(), **rmsprop_kwargs)

In [10]:
# NOTE(review): the two literals concatenate to 'resultsbp_vs_fa.txt' in the
# working directory; a 'results/' sub-directory may have been intended - confirm.
# buffering=1 selects line buffering, so every completed log line reaches the
# file even though no close() for this handle is visible in the notebook.
logger_train = open('results' + 'bp_vs_fa.txt', 'w', buffering=1)

In [12]:
epochs = 1
for epoch in range(epochs):
    for idx_batch, (inputs, targets) in enumerate(train_loader):
        # Forward pass of both networks on the same mini-batch. The models
        # consume the image tensors directly, so no flattening is needed.
        outputs_fa = model_fa(inputs)
        outputs_bp = model_bp(inputs)

        # calculate loss
        loss_bp = loss_crossentropy(outputs_bp, targets)
        loss_fa = loss_crossentropy(outputs_fa, targets)

        # Time only zero_grad + backward + optimizer step for each model.
        # perf_counter is the monotonic clock meant for measuring intervals.
        t_fa = time.perf_counter()
        model_fa.zero_grad()
        loss_fa.backward()
        optimizer_fa.step()
        t_avg_fa = time.perf_counter() - t_fa

        t_bp = time.perf_counter()
        model_bp.zero_grad()
        loss_bp.backward()
        optimizer_bp.step()
        t_avg_bp = time.perf_counter() - t_bp

        # Report every 100 steps; the timings are those of the current step.
        if (idx_batch + 1) % 100 == 0:
            train_log = 'epoch {} step {} loss_fa {} loss_bp {}'.format(
                epoch, idx_batch + 1, loss_fa.item(), loss_bp.item())

            times = ' time_fa {} time_bp {}'.format(t_avg_fa, t_avg_bp)
            time_dif = t_avg_fa - t_avg_bp
            print(train_log)
            print(times)
            print(time_dif)
            logger_train.write(train_log + '\n')
            # Push the line to disk now; the log handle stays open for the
            # whole run.
            logger_train.flush()

    # Test models
    test_loss_fa, correct_fa = test(model_fa, test_loader, batch_size)
    test_loss_bp, correct_bp = test(model_bp, test_loader, batch_size)

    num_test = len(test_loader.dataset)
    print('\n[Epoch {}] Test results'.format(epoch))
    print('\tFA: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        test_loss_fa, correct_fa, num_test, 100. * correct_fa / num_test))
    print('\tBP: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss_bp, correct_bp, num_test, 100. * correct_bp / num_test))

epoch 0 step 100 loss_fa 0.12437755614519119 loss_bp 0.008606676012277603
 time_fa 0.008363008499145508 time_bp 0.006242036819458008
0.0021209716796875
epoch 0 step 200 loss_fa 0.21762007474899292 loss_bp 0.05049574002623558
 time_fa 0.009228944778442383 time_bp 0.006242036819458008
0.002986907958984375
epoch 0 step 300 loss_fa 0.24455222487449646 loss_bp 0.077779121696949
 time_fa 0.008696556091308594 time_bp 0.006389141082763672
0.002307415008544922
epoch 0 step 400 loss_fa 0.32020407915115356 loss_bp 0.18049514293670654
 time_fa 0.009351491928100586 time_bp 0.005378246307373047
0.003973245620727539
epoch 0 step 500 loss_fa 0.15975359082221985 loss_bp 0.06126094236969948
 time_fa 0.008693695068359375 time_bp 0.00639033317565918
0.0023033618927001953
epoch 0 step 600 loss_fa 0.22966253757476807 loss_bp 0.007768119685351849
 time_fa 0.009076356887817383 time_bp 0.00702977180480957
0.0020465850830078125
epoch 0 step 700 loss_fa 0.2555553913116455 loss_bp 0.02789762057363987
 time_fa 0.0