In [1]:
%load_ext autoreload
%autoreload 2
import os
import time
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import torch.backends.cudnn as cudnn

from torch import autograd
from torch.autograd import Variable
from torch.utils.data import DataLoader

from torchvision import datasets, transforms
from biotorch.module.biomodule import BioModel

## Define Model

In [2]:
class LeNet(nn.Module):
    """
    LeNet-style convolutional classifier.

    Two 5x5 convolutions, each followed by the chosen activation and 2x2 max pooling,
    then two fully connected layers that produce 10 output scores (logits).
    The first linear layer expects 4 * 4 * 50 flattened features, which is what the
    convolution/pooling stack yields for single-channel 28x28 inputs.
    """

    # Element-wise non-linearities selectable through the ``activation`` argument
    _ACTIVATIONS = {
        'relu': torch.relu,
        'tanh': torch.tanh,
        'sigmoid': torch.sigmoid,
    }

    def __init__(self, activation='tanh'):
        """
        :param activation: name of the non-linearity applied after each convolution and
            after the first linear layer; one of 'relu', 'tanh' or 'sigmoid'
        :raises ValueError: if ``activation`` is not one of the supported names
        """
        super(LeNet, self).__init__()

        # Fail at construction time instead of leaving ``self.activation`` undefined,
        # which would only surface later as an AttributeError inside ``forward``.
        if activation not in self._ACTIVATIONS:
            raise ValueError(
                "Unsupported activation {!r}; expected one of {}".format(
                    activation, sorted(self._ACTIVATIONS)))
        self.activation = self._ACTIVATIONS[activation]

        # create layer operations
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5, stride=1)
        self.conv2 = nn.Conv2d(in_channels=20, out_channels=50, kernel_size=5, stride=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, inputs):
        """
        Forward pass through the convolutional and fully connected layers.

        :param inputs: image batch with shape [batch_size, 1, height, width]; the size of
            ``fc1`` corresponds to 28x28 images
        :return: logit outputs from the network with shape [batch_size, 10]
        """
        inputs = self.activation(self.conv1(inputs))
        inputs = self.pool(inputs)

        inputs = self.activation(self.conv2(inputs))
        inputs = self.pool(inputs)

        # Flatten everything except the batch dimension before the linear layers
        inputs = inputs.view(inputs.size(0), -1)

        inputs = self.activation(self.fc1(inputs))
        inputs = self.fc2(inputs)
        return inputs

## Test Function

In [3]:
def test(model, test_loader, batch_size, device):
    test_loss = 0
    correct = 0
    # Desactivate the autograd engine in test
    with torch.no_grad():
        for data, target in test_loader:
            #data = data.view(batch_size, -1)
            inputs, targets = data.to(device), target.to(device)
            predictions = model(inputs)
            predictions = torch.squeeze(predictions)
            test_loss += F.nll_loss(predictions, targets, size_average=False).item()
            pred = predictions.data.max(1, keepdim=True)[1]
            correct += pred.eq(targets.data.view_as(pred)).sum()

    test_loss /= len(test_loader.dataset)
    return test_loss, correct

## Training code

In [4]:
# Run-wide settings used by the cells below
batch_size = 32  # mini-batch size handed to both data loaders
start_epoch = 0
best_acc = 0  # best test accuracy

# Fall back to the CPU whenever CUDA is not available
device = 'cpu' if not torch.cuda.is_available() else 'cuda'

In [5]:
# set up datasets
print('==> Preparing data..')

# Identical preprocessing for both splits: tensor conversion, then normalisation with
# the given per-channel mean and standard deviation.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

train_loader = DataLoader(
    datasets.MNIST('./data', train=True, download=True, transform=mnist_transform),
    batch_size=batch_size, shuffle=True, drop_last=True)

# Keep the final partial batch of the test split: dropping it would silently exclude
# samples from evaluation while accuracy is still reported against the full dataset size.
test_loader = DataLoader(
    datasets.MNIST('./data', train=False, download=True, transform=mnist_transform),
    batch_size=batch_size, shuffle=False, drop_last=False)

==> Preparing data..


In [6]:
# Turn on the cuDNN benchmark flag
torch.backends.cudnn.benchmark = True

## Feedback Alignment (FA) Experiment

In [19]:
# Reference network trained with ordinary back-propagation (default tanh activation)
model_bp = LeNet(activation='tanh')

In [20]:
# Feedback Alignment network: a fresh LeNet wrapped by BioModel in 'FA' mode
model_fa = BioModel(LeNet(), mode='FA')

All the 2 <class 'torch.nn.modules.conv.Conv2d'> layers were converted successfully
All the 2 <class 'torch.nn.modules.linear.Linear'> layers were converted successfully


In [21]:
# Uncomment for multiple GPUs
# model_fa = nn.DataParallel(model_fa, device_ids=[0, 1, 2, 3])

In [22]:
# Display the module tree of the FA model.
# NOTE(review): the saved output of this cell shows a DataParallel wrapper although the
# wrapping line in the previous cell is commented out; the output looks stale and
# should be regenerated with Restart & Run All.
model_fa

DataParallel(
  (module): BioModel(
    (model): LeNet(
      (conv1): Conv2dFA(1, 20, kernel_size=(5, 5), stride=(1, 1))
      (conv2): Conv2dFA(20, 50, kernel_size=(5, 5), stride=(1, 1))
      (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (fc1): LinearFA(in_features=800, out_features=500, bias=True)
      (fc2): LinearFA(in_features=500, out_features=10, bias=True)
    )
  )
)

In [23]:
# Force CPU execution for this experiment; comment the line out to keep the
# previously selected device (e.g. to run on GPU)
device = "cpu"

In [24]:
# Print the type of every (sub)module of the FA model; the enumerate index was never
# used, so the modules are iterated directly.
for layer in model_fa.modules():
    print(type(layer))

<class 'torch.nn.parallel.data_parallel.DataParallel'>
<class 'biotorch.module.biomodule.BioModel'>
<class '__main__.LeNet'>
<class 'biotorch.layers.conv.Conv2dFA'>
<class 'biotorch.layers.conv.Conv2dFA'>
<class 'torch.nn.modules.pooling.MaxPool2d'>
<class 'biotorch.layers.linear.LinearFA'>
<class 'biotorch.layers.linear.LinearFA'>


In [25]:
# Echo the currently selected device as the cell output
device

'cpu'

In [28]:
# Place both networks on the selected device (Module.to returns the module itself);
# the final expression is left bare so the BP model is shown as the cell output.
model_fa = model_fa.to(device)
model_bp.to(device)

LeNet(
  (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=800, out_features=500, bias=True)
  (fc2): Linear(in_features=500, out_features=10, bias=True)
)

In [29]:
# Shared loss plus one RMSprop optimizer per network, both with the same hyper-parameters
loss_crossentropy = nn.CrossEntropyLoss()
rmsprop_args = dict(lr=1e-4, weight_decay=0.)
optimizer_fa = optim.RMSprop(model_fa.parameters(), **rmsprop_args)
optimizer_bp = optim.RMSprop(model_bp.parameters(), **rmsprop_args)

In [30]:
# Training log for the BP vs FA comparison. The plain concatenation 'results' + name
# produced 'resultsbp_vs_fa.txt' in the working directory; join the parts as a path and
# make sure the directory exists before opening the file.
os.makedirs('results', exist_ok=True)
logger_train = open(os.path.join('results', 'bp_vs_fa.txt'), 'w')

In [31]:
epochs = 10

# CUDA work is queued asynchronously, so wall-clock readings are only meaningful if the
# device has finished its pending work; on CPU no synchronisation is needed.
sync_cuda = torch.cuda.is_available() and str(device).startswith('cuda')

for epoch in range(epochs):
    model_fa.train()
    model_bp.train()
    for idx_batch, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)

        # Forward pass of both networks on the same mini-batch
        outputs_fa = torch.squeeze(model_fa(inputs))
        outputs_bp = torch.squeeze(model_bp(inputs))

        loss_bp = loss_crossentropy(outputs_bp, targets)
        loss_fa = loss_crossentropy(outputs_fa, targets)

        # Time backward pass + parameter update of the BP model
        if sync_cuda:
            torch.cuda.synchronize()
        t_bp = time.perf_counter()
        model_bp.zero_grad()
        loss_bp.backward()
        optimizer_bp.step()
        if sync_cuda:
            torch.cuda.synchronize()
        t_avg_bp = time.perf_counter() - t_bp

        # Time backward pass + parameter update of the FA model
        if sync_cuda:
            torch.cuda.synchronize()
        t_fa = time.perf_counter()
        model_fa.zero_grad()
        loss_fa.backward()
        optimizer_fa.step()
        if sync_cuda:
            torch.cuda.synchronize()
        t_avg_fa = time.perf_counter() - t_fa

        if (idx_batch + 1) % 100 == 0:
            train_log = 'epoch {} step {} loss_fa {} loss_bp {}'.format(
                epoch, idx_batch + 1, loss_fa.item(), loss_bp.item())
            times = ' time_fa {} time_bp {}'.format(t_avg_fa, t_avg_bp)
            time_dif = t_avg_fa - t_avg_bp
            print(train_log)
            print(times)
            print(time_dif)
            logger_train.write(train_log + '\n')
            # Flush so the log on disk stays current even if the run is interrupted
            logger_train.flush()

    # Test models
    model_fa.eval()
    model_bp.eval()
    test_loss_fa, correct_fa = test(model_fa, test_loader, batch_size, device)
    test_loss_bp, correct_bp = test(model_bp, test_loader, batch_size, device)

    n_test = len(test_loader.dataset)
    print('\n[Epoch {}] Test results'.format(epoch))
    print('\tFA: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        test_loss_fa, correct_fa, n_test, 100. * correct_fa / n_test))
    print('\tBP: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss_bp, correct_bp, n_test, 100. * correct_bp / n_test))

epoch 0 step 100 loss_fa 2.3209924697875977 loss_bp 0.5322349667549133
 time_fa 0.003175497055053711 time_bp 0.0010976791381835938
0.002077817916870117
epoch 0 step 200 loss_fa 2.255208730697632 loss_bp 0.17781995236873627
 time_fa 0.0028536319732666016 time_bp 0.0009067058563232422
0.0019469261169433594
epoch 0 step 300 loss_fa 2.1853020191192627 loss_bp 0.22670413553714752
 time_fa 0.003612041473388672 time_bp 0.0010962486267089844
0.0025157928466796875
epoch 0 step 400 loss_fa 1.8642054796218872 loss_bp 0.2274066060781479
 time_fa 0.0030164718627929688 time_bp 0.0012524127960205078
0.001764059066772461
epoch 0 step 500 loss_fa 1.1944801807403564 loss_bp 0.19956040382385254
 time_fa 0.0030024051666259766 time_bp 0.0009572505950927734
0.002045154571533203
epoch 0 step 600 loss_fa 0.9679815769195557 loss_bp 0.33286726474761963
 time_fa 0.003511190414428711 time_bp 0.0010750293731689453
0.0024361610412597656
epoch 0 step 700 loss_fa 0.7103651165962219 loss_bp 0.20738591253757477
 time_f

KeyboardInterrupt: 

## Direct Feedback Alignment (DFA) Experiment

In [29]:
# Fresh networks for the DFA comparison, created in the same order (BP first, then DFA)
model_bp, model_dfa = LeNet(), LeNet()

In [30]:
# Wrap the LeNet with BioModel in 'DFA' mode; output_dim is set to 10, the size of the
# network's final layer
model_dfa = BioModel(
    model_dfa,
    mode='DFA',
    output_dim=10,
)

All the 2 <class 'torch.nn.modules.conv.Conv2d'> layers were converted successfully
All the 2 <class 'torch.nn.modules.linear.Linear'> layers were converted successfully


In [31]:
# Cross-entropy criterion shared by the BP and DFA models
loss_crossentropy = nn.CrossEntropyLoss()

In [32]:
# Use CUDA when it is available, otherwise run on the CPU
# (assign 'cpu' here instead to force CPU execution)
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [33]:
# Replicate the DFA model over the visible GPUs, capped at the four devices originally
# targeted. Hard-coding ids 0-3 fails on machines that expose fewer than four GPUs.
gpu_ids = list(range(min(4, torch.cuda.device_count())))
model_dfa = nn.DataParallel(model_dfa, device_ids=gpu_ids)

In [34]:
# Place both networks on the selected device (Module.to returns the module itself);
# the final expression is left bare so the BP model is shown as the cell output.
model_dfa = model_dfa.to(device)
model_bp.to(device)

LeNet(
  (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=800, out_features=500, bias=True)
  (fc2): Linear(in_features=500, out_features=10, bias=True)
)

In [35]:
# One RMSprop optimizer per network, both with the same hyper-parameters
rmsprop_args = dict(lr=1e-4, weight_decay=0.)
optimizer_dfa = optim.RMSprop(model_dfa.parameters(), **rmsprop_args)
optimizer_bp = optim.RMSprop(model_bp.parameters(), **rmsprop_args)

In [36]:
# Training log for the BP vs DFA comparison. The plain concatenation 'results' + name
# produced 'resultsbp_vs_dfa.txt' in the working directory; join the parts as a path and
# make sure the directory exists before opening the file.
os.makedirs('results', exist_ok=True)
logger_train = open(os.path.join('results', 'bp_vs_dfa.txt'), 'w')

In [37]:
epochs = 10

# CUDA work is queued asynchronously, so wall-clock readings are only meaningful if the
# device has finished its pending work; on CPU no synchronisation is needed.
sync_cuda = torch.cuda.is_available() and str(device).startswith('cuda')

for epoch in range(epochs):
    model_dfa.train()
    model_bp.train()
    for idx_batch, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)

        # Forward pass of both networks on the same mini-batch; the DFA model is also
        # given the targets and the loss function
        outputs_dfa = torch.squeeze(model_dfa(inputs, targets, loss_crossentropy))
        outputs_bp = torch.squeeze(model_bp(inputs))

        loss_bp = loss_crossentropy(outputs_bp, targets)
        loss_dfa = loss_crossentropy(outputs_dfa, targets)

        # Time backward pass + parameter update of the BP model
        if sync_cuda:
            torch.cuda.synchronize()
        t_bp = time.perf_counter()
        model_bp.zero_grad()
        loss_bp.backward()
        optimizer_bp.step()
        if sync_cuda:
            torch.cuda.synchronize()
        t_avg_bp = time.perf_counter() - t_bp

        # Time backward pass + parameter update of the DFA model
        if sync_cuda:
            torch.cuda.synchronize()
        t_dfa = time.perf_counter()
        model_dfa.zero_grad()
        loss_dfa.backward()
        optimizer_dfa.step()
        if sync_cuda:
            torch.cuda.synchronize()
        t_avg_dfa = time.perf_counter() - t_dfa

        if (idx_batch + 1) % 100 == 0:
            train_log = 'epoch {} step {} loss_dfa {} loss_bp {}'.format(
                epoch, idx_batch + 1, loss_dfa.item(), loss_bp.item())
            times = ' time_dfa {} time_bp {}'.format(t_avg_dfa, t_avg_bp)
            time_dif = t_avg_dfa - t_avg_bp
            print(train_log)
            print(times)
            print(time_dif)
            logger_train.write(train_log + '\n')
            # Flush so the log on disk stays current even if the run is interrupted
            logger_train.flush()

    # Test models
    model_dfa.eval()
    model_bp.eval()
    test_loss_dfa, correct_dfa = test(model_dfa, test_loader, batch_size, device)
    test_loss_bp, correct_bp = test(model_bp, test_loader, batch_size, device)

    n_test = len(test_loader.dataset)
    print('\n[Epoch {}] Test results'.format(epoch))
    print('\tDFA: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        test_loss_dfa, correct_dfa, n_test, 100. * correct_dfa / n_test))
    print('\tBP: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss_bp, correct_bp, n_test, 100. * correct_bp / n_test))

epoch 0 step 100 loss_dfa 2.3258461952209473 loss_bp 0.369176983833313
 time_dfa 0.004075765609741211 time_bp 0.0010900497436523438
0.002985715866088867
epoch 0 step 200 loss_dfa 2.341024160385132 loss_bp 0.2684066891670227
 time_dfa 0.0046710968017578125 time_bp 0.001165628433227539
0.0035054683685302734
epoch 0 step 300 loss_dfa 2.0307843685150146 loss_bp 0.31385377049446106
 time_dfa 0.004360198974609375 time_bp 0.0012257099151611328
0.003134489059448242
epoch 0 step 400 loss_dfa 1.46736741065979 loss_bp 0.11081911623477936
 time_dfa 0.003952741622924805 time_bp 0.00116729736328125
0.0027854442596435547
epoch 0 step 500 loss_dfa 1.253302812576294 loss_bp 0.19032679498195648
 time_dfa 0.00403285026550293 time_bp 0.0011668205261230469
0.002866029739379883
epoch 0 step 600 loss_dfa 0.869157075881958 loss_bp 0.14167854189872742
 time_dfa 0.00380706787109375 time_bp 0.0011725425720214844
0.0026345252990722656
epoch 0 step 700 loss_dfa 0.5580421090126038 loss_bp 0.06907501071691513
 time_

KeyboardInterrupt: 