In [1]:
%load_ext autoreload
%autoreload 2
import os
import time
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

from torch import autograd
from torch.autograd import Variable
from torch.utils.data import DataLoader

from torchvision import datasets, transforms
from biotorch.initialization.functions import add_fa_weight_matrices, override_backward

In [2]:
def test(model, test_loader, batch_size):
    test_loss = 0
    correct = 0
    # Desactivate the autograd engine in test
    with torch.no_grad():
        for data, target in test_loader:
            #data = data.view(batch_size, -1)
            inputs, targets = Variable(data), Variable(target)
            predictions = model(inputs)
            predictions = torch.squeeze(predictions)
            test_loss += F.nll_loss(predictions, targets, size_average=False).item()
            pred = predictions.data.max(1, keepdim=True)[1]
            correct += pred.eq(targets.data.view_as(pred)).sum()

    test_loss /= len(test_loader.dataset)
    return test_loss, correct

In [3]:
# Run-wide settings shared by the cells below.
batch_size = 32   # mini-batch size used by both data loaders
start_epoch = 0
best_acc = 0      # best test accuracy
# Prefer the GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [4]:
# set up datasets
print('==> Preparing data..')

# Per-channel (R, G, B) mean / std of the CIFAR-10 training images after
# ToTensor() scaling to [0, 1]. These replace the single-channel MNIST values
# (0.1307 / 0.3081), which do not describe this dataset.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)

# One preprocessing pipeline shared by both splits: resize the 32x32 images to
# 224 px, convert to a float tensor, then normalise each channel.
cifar_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# drop_last=True keeps every training batch at exactly `batch_size` samples.
train_loader = DataLoader(
    datasets.CIFAR10('./data', train=True, download=True, transform=cifar_transform),
    batch_size=batch_size, shuffle=True, drop_last=True)

# Keep every test image (drop_last=False): accuracy is reported against
# len(test_loader.dataset), so dropping the final partial batch would
# silently count those images as errors.
test_loader = DataLoader(
    datasets.CIFAR10('./data', train=False, download=True, transform=cifar_transform),
    batch_size=batch_size, shuffle=False, drop_last=False)

==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified


In [5]:
# create the network
from torchvision.models import resnet
model_fa = resnet.resnet18()
model_fa.fc = nn.Linear(512, 10)
model_fa.apply(add_fa_weight_matrices)
model_fa.apply(override_backward)
model_bp = resnet.resnet18()
model_bp.fc = nn.Linear(512, 10)

cudnn.benchmark = True

loss_crossentropy = torch.nn.CrossEntropyLoss()
optimizer_fa = torch.optim.RMSprop(model_fa.parameters(), lr=1e-4, weight_decay=0.)
optimizer_bp = torch.optim.RMSprop(model_bp.parameters(), lr=1e-4, weight_decay=0.)

In [6]:
# Bare expression as the cell's last statement: the notebook prints the
# module's repr, i.e. the layer-by-layer structure of the FA model.
model_fa

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [7]:
# Training log written by the training cell.
# - The original 'results' + 'conv_log2.txt' concatenated to the file name
#   'resultsconv_log2.txt'; build the intended results/conv_log2.txt path and
#   make sure the directory exists.
# - buffering=1 selects line buffering, so each logged line reaches disk even
#   if the training cell is interrupted before the file is closed.
# The handle is intentionally left open across cells; call
# logger_train.close() once training is finished.
os.makedirs('results', exist_ok=True)
logger_train = open(os.path.join('results', 'conv_log2.txt'), 'w', buffering=1)

In [9]:
epochs = 10
log_every = 10  # print / log one line every `log_every` mini-batches


def _timed_update(model, optimizer, loss):
    """Clear gradients, backpropagate ``loss`` and take one optimiser step.

    Returns the wall-clock seconds spent on this single update.
    """
    started = time.time()
    model.zero_grad()
    loss.backward()
    optimizer.step()
    return time.time() - started


def _as_batch_scores(outputs):
    """Flatten model outputs to (batch, classes), always keeping the batch axis."""
    return outputs.reshape(outputs.size(0), -1)


def _test_summary(tag, avg_loss, n_correct, n_total):
    """Format one line of the per-epoch test report for the model named ``tag``."""
    return '\t{}: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        tag, avg_loss, int(n_correct), n_total, 100. * int(n_correct) / n_total)


# NOTE(review): `device` is computed in an earlier cell, but neither the models
# nor the batches are moved to it, so this loop runs wherever the models were
# created.
try:
    for epoch in range(epochs):
        # Make training mode explicit at the start of every epoch so that an
        # evaluation pass which switches to eval mode cannot leak into training.
        model_fa.train()
        model_bp.train()

        for idx_batch, (inputs, targets) in enumerate(train_loader):
            # Forward pass through both models on the same mini-batch.
            outputs_fa = _as_batch_scores(model_fa(inputs))
            outputs_bp = _as_batch_scores(model_bp(inputs))

            loss_bp = loss_crossentropy(outputs_bp, targets)
            loss_fa = loss_crossentropy(outputs_fa, targets)

            # Time only the update (zero_grad + backward + step) of each model.
            time_fa = _timed_update(model_fa, optimizer_fa, loss_fa)
            time_bp = _timed_update(model_bp, optimizer_bp, loss_bp)

            if (idx_batch + 1) % log_every == 0:
                train_log = 'epoch {} step {} loss_fa {} loss_bp {}'.format(
                    epoch, idx_batch + 1, loss_fa.item(), loss_bp.item())
                times = ' time_fa {} time_bp {}'.format(time_fa, time_bp)
                print(train_log)
                print(times)
                print(time_fa - time_bp)  # extra seconds FA needs over BP for this step
                logger_train.write(train_log + '\n')

        # Test models
        test_loss_fa, correct_fa = test(model_fa, test_loader, batch_size)
        test_loss_bp, correct_bp = test(model_bp, test_loader, batch_size)

        n_test = len(test_loader.dataset)
        print('\n[Epoch {}] Test results'.format(epoch))
        print(_test_summary('FA', test_loss_fa, correct_fa, n_test))
        print(_test_summary('BP', test_loss_bp, correct_bp, n_test) + '\n')
finally:
    # Push buffered log lines to disk even when the run is interrupted
    # (e.g. KeyboardInterrupt); the file stays open so the cell can be re-run.
    logger_train.flush()

epoch 0 step 10 loss_fa 2.4385271072387695 loss_bp 2.2538998126983643
 time_fa 3.3898894786834717 time_bp 0.2558324337005615
3.13405704498291
epoch 0 step 20 loss_fa 2.305401086807251 loss_bp 1.7154983282089233
 time_fa 3.171701669692993 time_bp 0.331129789352417
2.840571880340576
epoch 0 step 30 loss_fa 2.1611950397491455 loss_bp 1.985498309135437
 time_fa 3.168858051300049 time_bp 0.29947781562805176
2.869380235671997
epoch 0 step 40 loss_fa 1.9220905303955078 loss_bp 1.6860783100128174
 time_fa 3.1835267543792725 time_bp 0.25115537643432617
2.9323713779449463
epoch 0 step 50 loss_fa 2.0569775104522705 loss_bp 1.4631348848342896
 time_fa 3.280702590942383 time_bp 0.2445669174194336
3.036135673522949
epoch 0 step 60 loss_fa 1.7535927295684814 loss_bp 1.3928078413009644
 time_fa 3.214118480682373 time_bp 0.24637341499328613
2.967745065689087
epoch 0 step 70 loss_fa 1.9716498851776123 loss_bp 1.5075316429138184
 time_fa 3.203533411026001 time_bp 0.24327421188354492
2.960259199142456
epo

KeyboardInterrupt: 