In [1]:
%load_ext autoreload
%autoreload 2
import os
import time
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

from torch import autograd
from torch.autograd import Variable
from torch.utils.data import DataLoader

from torchvision import datasets, transforms
from biotorch.initialization.functions import convert_module, count_layers
from biotorch.layers import Conv2d, Linear
from collections import OrderedDict, defaultdict

In [2]:
def test(model, test_loader, batch_size):
    test_loss = 0
    correct = 0
    # Desactivate the autograd engine in test
    with torch.no_grad():
        for data, target in test_loader:
            #data = data.view(batch_size, -1)
            inputs, targets = data.to(device), target.to(device)
            predictions = model(inputs)
            predictions = torch.squeeze(predictions)
            test_loss += F.nll_loss(predictions, targets, size_average=False).item()
            pred = predictions.data.max(1, keepdim=True)[1]
            correct += pred.eq(targets.data.view_as(pred)).sum()

    test_loss /= len(test_loader.dataset)
    return test_loss, correct

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0  # best test accuracy
start_epoch = 0
batch_size = 128

In [4]:
# set up datasets
print('==> Preparing data..')
train_loader = DataLoader(datasets.CIFAR10('./data', train=True, download=True,
                                                 transform=transforms.Compose([
                                                     transforms.Resize(224),
                                                     transforms.ToTensor(),
                                                     transforms.Normalize((0.1307,), (0.3081,))
                                                 ])),
                                  batch_size=batch_size, shuffle=True, drop_last=True)

test_loader = DataLoader(datasets.CIFAR10('./data', train=False, download=True,
                                                transform=transforms.Compose([
                                                    transforms.Resize(224),
                                                    transforms.ToTensor(),
                                                    transforms.Normalize((0.1307,), (0.3081,))
                                                ])),
                                 batch_size=batch_size, shuffle=False, drop_last=True)

==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified


In [5]:
# create the network
from torchvision.models import resnet
model_fa = resnet.resnet18(pretrained=True)
model_fa.fc = nn.Linear(512, 10)

In [6]:
convert_module(model_fa)

All the 20 <class 'torch.nn.modules.conv.Conv2d'> layers were converted successfully
All the 1 <class 'torch.nn.modules.linear.Linear'> layers were converted successfully


In [11]:
str(type(model_fa.conv1))

"<class 'biotorch.layers.conv.Conv2d'>"

In [29]:
for i, layer in enumerate(model_fa.modules()):
    print(type(layer))

<class 'torchvision.models.resnet.ResNet'>
<class 'biotorch.layers.conv.Conv2d'>
<class 'torch.nn.modules.batchnorm.BatchNorm2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.nn.modules.pooling.MaxPool2d'>
<class 'torch.nn.modules.container.Sequential'>
<class 'torchvision.models.resnet.BasicBlock'>
<class 'biotorch.layers.conv.Conv2d'>
<class 'torch.nn.modules.batchnorm.BatchNorm2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'biotorch.layers.conv.Conv2d'>
<class 'torch.nn.modules.batchnorm.BatchNorm2d'>
<class 'torchvision.models.resnet.BasicBlock'>
<class 'biotorch.layers.conv.Conv2d'>
<class 'torch.nn.modules.batchnorm.BatchNorm2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'biotorch.layers.conv.Conv2d'>
<class 'torch.nn.modules.batchnorm.BatchNorm2d'>
<class 'torch.nn.modules.container.Sequential'>
<class 'torchvision.models.resnet.BasicBlock'>
<class 'biotorch.layers.conv.Conv2d'>
<class 'torch.nn.modules.batchnorm.BatchNorm2d'>
<class 'torch.nn.modules

['T_destination',
 '__annotations__',
 '__call__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattr__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_apply',
 '_backward_hooks',
 '_buffers',
 '_call_impl',
 '_forward_hooks',
 '_forward_impl',
 '_forward_pre_hooks',
 '_get_backward_hooks',
 '_get_name',
 '_is_full_backward_hook',
 '_load_from_state_dict',
 '_load_state_dict_pre_hooks',
 '_make_layer',
 '_maybe_warn_non_full_backward_hook',
 '_modules',
 '_named_members',
 '_non_persistent_buffers_set',
 '_norm_layer',
 '_parameters',
 '_register_load_state_dict_pre_hook',
 '_register_state_dict_hook',
 '_replicate_for_data_parallel',
 '_save_to_state_dict',
 '_slow_forward',
 '_state_d

In [7]:
w = model_fa.conv1.weight
print(w)

Parameter containing:
tensor([[[[-1.0419e-02, -6.1356e-03, -1.8098e-03,  ...,  5.6615e-02,
            1.7083e-02, -1.2694e-02],
          [ 1.1083e-02,  9.5276e-03, -1.0993e-01,  ..., -2.7124e-01,
           -1.2907e-01,  3.7424e-03],
          [-6.9434e-03,  5.9089e-02,  2.9548e-01,  ...,  5.1972e-01,
            2.5632e-01,  6.3573e-02],
          ...,
          [-2.7535e-02,  1.6045e-02,  7.2595e-02,  ..., -3.3285e-01,
           -4.2058e-01, -2.5781e-01],
          [ 3.0613e-02,  4.0960e-02,  6.2850e-02,  ...,  4.1384e-01,
            3.9359e-01,  1.6606e-01],
          [-1.3736e-02, -3.6746e-03, -2.4084e-02,  ..., -1.5070e-01,
           -8.2230e-02, -5.7828e-03]],

         [[-1.1397e-02, -2.6619e-02, -3.4641e-02,  ...,  3.2521e-02,
            6.6221e-04, -2.5743e-02],
          [ 4.5687e-02,  3.3603e-02, -1.0453e-01,  ..., -3.1253e-01,
           -1.6051e-01, -1.2826e-03],
          [-8.3730e-04,  9.8420e-02,  4.0210e-01,  ...,  7.0789e-01,
            3.6887e-01,  1.2455e-01]

In [8]:
replace_layers(model_fa)

replaced:  model_fa conv1
replaced:  model_fa fc
replaced:  0 conv1
replaced:  0 conv2
replaced:  1 conv1
replaced:  1 conv2
replaced:  0 conv1
replaced:  0 conv2
replaced:  downsample 0
replaced:  1 conv1
replaced:  1 conv2
replaced:  0 conv1
replaced:  0 conv2
replaced:  downsample 0
replaced:  1 conv1
replaced:  1 conv2
replaced:  0 conv1
replaced:  0 conv2
replaced:  downsample 0
replaced:  1 conv1
replaced:  1 conv2


In [10]:
q = model_fa.conv1.weight
print(q)

Parameter containing:
tensor([[[[-1.0419e-02, -6.1356e-03, -1.8098e-03,  ...,  5.6615e-02,
            1.7083e-02, -1.2694e-02],
          [ 1.1083e-02,  9.5276e-03, -1.0993e-01,  ..., -2.7124e-01,
           -1.2907e-01,  3.7424e-03],
          [-6.9434e-03,  5.9089e-02,  2.9548e-01,  ...,  5.1972e-01,
            2.5632e-01,  6.3573e-02],
          ...,
          [-2.7535e-02,  1.6045e-02,  7.2595e-02,  ..., -3.3285e-01,
           -4.2058e-01, -2.5781e-01],
          [ 3.0613e-02,  4.0960e-02,  6.2850e-02,  ...,  4.1384e-01,
            3.9359e-01,  1.6606e-01],
          [-1.3736e-02, -3.6746e-03, -2.4084e-02,  ..., -1.5070e-01,
           -8.2230e-02, -5.7828e-03]],

         [[-1.1397e-02, -2.6619e-02, -3.4641e-02,  ...,  3.2521e-02,
            6.6221e-04, -2.5743e-02],
          [ 4.5687e-02,  3.3603e-02, -1.0453e-01,  ..., -3.1253e-01,
           -1.6051e-01, -1.2826e-03],
          [-8.3730e-04,  9.8420e-02,  4.0210e-01,  ...,  7.0789e-01,
            3.6887e-01,  1.2455e-01]

In [7]:
model_fa = nn.DataParallel(model_fa)
model_fa.to(device)

DataParallel(
  (module): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track

In [8]:
model_bp = resnet.resnet18(pretrained=True)
model_bp.fc = nn.Linear(512, 10)
model_bp.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [9]:
cudnn.benchmark = True

loss_crossentropy = torch.nn.CrossEntropyLoss()
optimizer_fa = torch.optim.RMSprop(model_fa.parameters(), lr=1e-4, weight_decay=0.)
optimizer_bp = torch.optim.RMSprop(model_bp.parameters(), lr=1e-4, weight_decay=0.)

In [10]:
model_fa.modules()[0]

TypeError: 'generator' object is not subscriptable

In [None]:
model_fa

In [11]:
logger_train = open('results' + 'conv_log2.txt', 'w')

In [12]:
epochs = 10
for epoch in range(epochs):
    for idx_batch, (inputs, targets) in enumerate(train_loader):
        # flatten the inputs from square image to 1d vector
        #inputs = inputs.view(batch_size, -1)
        # wrap them into varaibles
        inputs, targets = inputs.to(device), targets.to(device)
        # get outputs from the model
        #print("inputs = ", inputs.size())
        outputs_fa = model_fa(inputs)
        outputs_bp = model_bp(inputs)
        # print(outputs_fa.size())
        # print(outputs_bp.size())
        # calculate loss
        outputs_fa = torch.squeeze(outputs_fa)
        outputs_bp = torch.squeeze(outputs_bp)
        # print(outputs_fa.size())
        # print(outputs_bp.size())

        # print("-"*20)
        #print("targets.size() = ", targets.size())
        # input()
        
        loss_bp = loss_crossentropy(outputs_bp, targets)
        loss_fa = loss_crossentropy(outputs_fa, targets)
        # print(loss_bp, loss_fa)
        
        t_fa = time.time()
        model_fa.zero_grad()
        loss_fa.backward()
        optimizer_fa.step()
        t_avg_fa = time.time() - t_fa
    
        t_bp = time.time()
        model_bp.zero_grad()
        loss_bp.backward()
        optimizer_bp.step()
        t_avg_bp = time.time() - t_bp

        if (idx_batch + 1) % 100 == 0:
            train_log = 'epoch ' + str(epoch) + ' step ' + str(idx_batch + 1) + \
                        ' loss_fa ' + str(loss_fa.data.item()) + ' loss_bp ' + str(loss_bp.data.item())
                         
            times = ' time_fa '+ str(t_avg_fa) + ' time_bp ' + str(t_avg_bp)
            time_dif = t_avg_fa - t_avg_bp
            print(train_log)
            print(times)
            print(time_dif)
            logger_train.write(train_log + '\n')

    # Test models
    test_loss_fa, correct_fa = test(model_fa, test_loader, batch_size)    
    test_loss_bp, correct_bp = test(model_bp, test_loader, batch_size)

    print('\n[Epoch {}] Test results'.format(epoch))
    print('\tFA: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(test_loss_fa,
                                                                      correct_fa, len(test_loader.dataset), 100. * correct_fa / len(test_loader.dataset)))
    print('\tBP: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss_bp,
                                                                        correct_bp, len(test_loader.dataset), 100. * correct_bp / len(test_loader.dataset)))

epoch 0 step 100 loss_fa 1.00102698802948 loss_bp 0.3163323402404785
 time_fa 0.017718076705932617 time_bp 0.20459294319152832
-0.1868748664855957
epoch 0 step 200 loss_fa 0.8590705990791321 loss_bp 0.2614181339740753
 time_fa 0.017383813858032227 time_bp 0.2072312831878662
-0.18984746932983398
epoch 0 step 300 loss_fa 0.6312354803085327 loss_bp 0.26848942041397095
 time_fa 0.013391733169555664 time_bp 0.21135663986206055
-0.19796490669250488





[Epoch 0] Test results
	FA: Avg. loss: -2.2622, Accuracy: 7575/10000 (76%)
	BP: Avg. loss: -7.5389, Accuracy: 9340/10000 (93%)



KeyboardInterrupt: 