In [None]:
%load_ext autoreload

In [2]:
%autoreload 2

## General configuration

In [3]:
import torch

In [None]:
# Print some information on the GPU
print(torch.cuda.is_available())
print(torch.cuda.device_count())
print(torch.cuda.get_device_name(torch.cuda.current_device()))

True
1
GeForce GTX 1050


In [None]:
# Load parameters
import configs.config_cifar10 as config

import importlib
importlib.reload(config)

config = config.config
chosen_importance = config['ewc']['importance']

AttributeError: 'module' object has no attribute 'reload'

In [6]:
config

{'task': {'num_tasks': 10, 'is_conv': True},
 'ewc': {'importance': 1000, 'sample_size': 250},
 'opt': {'lr': 0.001, 'l1_reg': 0, 'iters': 1, 'batch_size': 64, 'epochs': 15},
 'other': {'enable_tensorboard': True,
  'device': device(type='cuda'),
  'model_name': '',
  'run_name': 'cifar10'}}

## Load data

In [10]:
# Load num_tasks datasets
import numpy as np
from data import CustomDataLoader
from torchvision import datasets, transforms

num_tasks = config['task']['num_tasks']

transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))]
        )


train_loader = {}
test_loader = {}

for t in range(num_tasks):
  
  train_dataset = datasets.CIFAR10(root='./data/cifar10', train=True, download=True, transform=transform)
  test_dataset = datasets.CIFAR10(root='./data/cifar10', train=False, transform=transform)
  
  train_dataset.train_labels = np.equal(train_dataset.train_labels, t).astype(np.int64)
  test_dataset.test_labels = np.equal(test_dataset.test_labels, t).astype(np.int64)

  train_loader[t] = CustomDataLoader(train_dataset, batch_size=config['opt']['batch_size'], shuffle=True)
  test_loader[t] = CustomDataLoader(test_dataset, batch_size=config['opt']['batch_size'])

Files already downloaded and verified
Files already downloaded and verified


## Define models

Take inspiration from here: https://github.com/wchliao/multi-task-image-classification/blob/master/models.py

In [18]:
from torch import nn
from torch.jit import trace
import torch.nn.functional as F


class CNN(torch.nn.Module):
    
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.batchnorm1 = nn.BatchNorm2d(32)
        self.maxpool1 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1, stride=2)
        self.batchnorm2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.batchnorm3 = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 128, kernel_size=3, padding=1, stride=2)
        self.batchnorm4 = nn.BatchNorm2d(128)
        #self.conv5 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        #self.batchnorm5 = nn.BatchNorm2d(128)
        self.output = nn.Linear(2048, num_tasks * 2)
        self.current_task = 0
        
    def forward(self, input):
        x = F.relu(self.batchnorm1(self.conv1(input)))
        x = F.relu(self.batchnorm2(self.conv2(x)))
        x = self.maxpool1(F.relu(self.batchnorm3(self.conv3(x))))
        x = F.relu(self.batchnorm4(self.conv4(x)))
        #x = self.maxpool(self.relu(self.batchnorm5(self.conv5(x))))
        x = self.output(x.reshape(input.shape[0], -1))
        return x[:, self.current_task*2:self.current_task*2 + 2]
      
    def set_task(self, task):
      self.current_task = torch.tensor(task)

In [19]:
from models import KAF, elu

class KAFCNN(torch.nn.Module):

    def __init__(self):
        super(KAFCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.batchnorm1 = nn.BatchNorm2d(32)
        self.maxpool = nn.MaxPool2d(kernel_size=(2, 2), stride=2)
        self.conv2 = nn.Conv2d(32, 52, kernel_size=3, padding=1, stride=2)
        self.batchnorm2 = nn.BatchNorm2d(52)
        self.conv3 = nn.Conv2d(52, 52, kernel_size=3, padding=1)
        self.batchnorm3 = nn.BatchNorm2d(52)
        self.conv4 = nn.Conv2d(52, 96, kernel_size=3, padding=1, stride=2)
        self.batchnorm4 = nn.BatchNorm2d(96)
        self.kaf1 = KAF(32, init_fcn=elu, is_conv=True, D=15)
        self.kaf2 = KAF(52, init_fcn=elu, is_conv=True, D=15)
        self.kaf3 = KAF(52, init_fcn=elu, is_conv=True, D=15)
        self.kaf4 = KAF(96, init_fcn=elu, is_conv=True, D=15)
        #self.kaf5 = KAF(96, init_fcn=elu, is_conv=True)
        self.output = nn.Linear(1536, num_tasks * 2)
        self.current_task = 0

    def forward(self, input):
        x = self.kaf1(self.batchnorm1(self.conv1(input)))
        x = self.kaf2(self.batchnorm2(self.conv2(x)))
        x = self.maxpool(self.kaf3(self.batchnorm3(self.conv3(x))))
        x = self.kaf4(self.batchnorm4(self.conv4(x)))
        #x = self.maxpool(self.kaf5(self.batchnorm5(self.conv5(x))))
        x = self.output(x.reshape(input.shape[0], -1))
        return x[:, self.current_task*2:self.current_task*2 + 2]
    
    def set_task(self, task):
      self.current_task = task

In [20]:
net = CNN().to(config['other']['device'])

In [21]:
kafnet = KAFCNN().to(config['other']['device'])

In [22]:
print('Numero di parametri rete classica: ', sum([torch.numel(p) for p in net.parameters()]))
print('Numero di parametri KAFNET: ', sum([torch.numel(p) for p in kafnet.parameters()]))

Numero di parametri rete classica:  138948
Numero di parametri KAFNET:  95428


## Main experiment

In [25]:
torch.save(net.state_dict(), 'models/cnn.pt')
torch.save(kafnet.state_dict(), 'models/kafcnn.pt')

In [26]:
print('Training CNN with no penalty')
from train import repeat_train_n_times
config['ewc']['importance'] = 0.0 # Temporarily disable importance
config['other']['model_name'] = 'CNN' # Select a name for the model saving
loss, loss_full, acc = repeat_train_n_times(net, train_loader, test_loader, \
                                            net_init_path='models/cnn.pt', 
                                            config=config)

Training CNN with no penalty


*** ITER  1  of  1  ***


Training for task 1 of 2...


Training for task 2 of 2...





In [27]:
print('Training CNN with EWC penalty')
config['ewc']['importance'] = chosen_importance # Re-enable importance
config['other']['model_name'] = 'EWC-CNN' # Select a name for the model saving
loss_ewc, loss_full_ewc, acc_ewc = repeat_train_n_times(net, train_loader, test_loader,\
                                            net_init_path='models/cnn.pt', 
                                            config=config)

Training CNN with EWC penalty


*** ITER  1  of  1  ***


Training for task 1 of 2...


Training for task 2 of 2...





In [28]:
print('Training KAF-CNN network with no penalty')
config['ewc']['importance'] = 0.0 # Temporarily disable importance
config['other']['model_name'] = 'KAF-CNN' # Select a name for the model saving
loss_kaf, loss_full_kaf, acc_kaf = repeat_train_n_times(kafnet, train_loader, test_loader,\
                                            net_init_path='models/kafcnn.pt', 
                                            config=config)

Training KAF-CNN network with no penalty


*** ITER  1  of  1  ***


Training for task 1 of 2...


Training for task 2 of 2...





In [29]:
print('Training KAF-CNN network with EWC penalty')
config['ewc']['importance'] = chosen_importance # Re-enable importance
config['other']['model_name'] = 'EWC-KAF-CNN' # Select a name for the model saving
loss_kaf_ewc, loss_full_kaf_ewc, acc_kaf_ewc = repeat_train_n_times(kafnet, train_loader, test_loader,\
                                            net_init_path='models/kafcnn.pt',
                                            config=config)

Training KAF-CNN network with EWC penalty


*** ITER  1  of  1  ***


Training for task 1 of 2...


KeyboardInterrupt: 

## Some plots

In [None]:
%matplotlib inline

In [None]:
import matplotlib.pyplot as plt

epochs = config['opt']['epochs']

# Loss plots
f, ax = plt.subplots(num_tasks, 1, sharex=True, figsize=(12, 12))

# Plot loss for standard network / standard training
for t, v in loss.items():
  ax[t].plot(list(range(t * epochs, (t + 1) * epochs)), v, '--', label='Standard')

# Plot loss for standard network / EWC training
for t, v in loss_ewc.items():
  ax[t].plot(list(range(t * epochs, (t + 1) * epochs)), v, label='EWC')
  
# Plot loss for KAF network / standard training
for t, v in loss_kaf.items():
  ax[t].plot(list(range(t * epochs, (t + 1) * epochs)), v, label='KAF')
  
# Plot loss for KAF network / EWC training
for t, v in loss_kaf_ewc.items():
  ax[t].plot(list(range(t * epochs, (t + 1) * epochs)), v, label='KAF-EWC')
  

# Add legend
handles, labels = ax[-1].get_legend_handles_labels()
f.legend(handles, labels, loc='upper right')

f.subplots_adjust(hspace=0)
plt.show()

In [None]:
# Accuracy plots
f, ax = plt.subplots(num_tasks, 1, sharex=True, figsize=(12, 12))

# Plot accuracy for standard network / standard training
for t, v in acc.items():
  x_axis = list(range(t * epochs, num_tasks * epochs))
  ax[t].plot(x_axis, v, '--', label='Standard')
  
# Plot accuracy for standard network / EWC training
for t, v in acc_ewc.items():
  x_axis = list(range(t * epochs, num_tasks * epochs))
  ax[t].plot(x_axis, v, label='EWC')
  
# Plot accuracy for KAF network / standard training
for t, v in acc_kaf.items():
  x_axis = list(range(t * epochs, num_tasks * epochs))
  ax[t].plot(x_axis, v, label='KAF')
  
# Plot accuracy for KAF network / EWC training
for t, v in acc_ewc_kaf.items():
  x_axis = list(range(t * epochs, num_tasks * epochs))
  ax[t].plot(x_axis, v, label='KAF-EWC')
  
 
# Add legend
handles, labels = ax[-1].get_legend_handles_labels()
f.legend(handles, labels, loc='upper right')

plt.subplots_adjust(hspace=0)