In [1]:
import sys
import os
# Make the directory one level above this notebook importable, so the project
# packages imported in the next cell (networks, utils, configs, Trainer) resolve.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

In [2]:
from networks import NoKafnet, Kafnet
import utils.datasetsUtils.MINST as MINST
from utils.datasetsUtils.taskManager import SingleTargetClassificationTask, NoTask
import configs.configClasses as configClasses
from torchvision.transforms import transforms
import torch
from Trainer import Trainer
import matplotlib.pyplot as plt
from collections import defaultdict
import copy
import numpy as np

In [3]:
def describe_cuda_devices():
    """Collect basic CUDA information for display.

    Returns:
        list: ``[is_available, device_count]`` plus, only when CUDA is
        available, the name of the current device. The name is not queried
        otherwise because ``torch.cuda.current_device()`` raises on machines
        without a usable GPU, which used to crash this cell.
    """
    available = torch.cuda.is_available()
    info = [available, torch.cuda.device_count()]
    if available:
        info.append(torch.cuda.get_device_name(torch.cuda.current_device()))
    return info


for item in describe_cuda_devices():
    print(item)

True
1
GeForce GTX 1050


In [4]:
# Base configuration shared by all three runs; only MODEL_NAME differs below.
config = configClasses.OnlineLearningConfig()
config.EPOCHS = 20
config.L1_REG = 0
config.LR = 1e-3
config.EWC_IMPORTANCE = 500
# NOTE(review): all three configs share this SAVE_PATH; presumably Trainer
# keeps runs apart via MODEL_NAME — confirm, otherwise runs may overwrite
# or load each other's checkpoints.
config.SAVE_PATH = './models/permuted_minst/multikaf'
config.MODEL_NAME = 'multikaf'
config.IS_CONVOLUTIONAL = False
print(config)

# Every variant is an independent deep copy, so tweaking one run's settings can
# never leak into another through shared mutable attributes (the KAF variant
# previously used a shallow copy.copy, unlike the EWC one).
config_ewc = copy.deepcopy(config)
config_ewc.MODEL_NAME = 'ewc'
print(config_ewc)

config_ewc_kaf = copy.deepcopy(config)
config_ewc_kaf.MODEL_NAME = 'ewc_kaf'
print(config_ewc_kaf)
# Legacy (misspelled) name kept so later cells that reference it keep working.
confing_ewt_kaf = config_ewc_kaf

CONFIG PARAMETERS
BATCH_SIZE: 64
DEVICE: cuda
EPOCHS: 20
EWC_IMPORTANCE: 500
EWC_SAMPLE_SIZE: 250
EWC_TYPE: <class 'networks.continual_learning.OnlineEWC'>
GAMMA: 1.0
IS_CONVOLUTIONAL: False
ITERS: 1
L1_REG: 0
LOSS: cross_entropy
LR: 0.001
MODEL_NAME: multikaf
OPTIMIZER: SGD
RUN_NAME: default
SAVE_PATH: ./models/permuted_minst/multikaf
USE_EWC: True
USE_TENSORBOARD: True

CONFIG PARAMETERS
BATCH_SIZE: 64
DEVICE: cuda
EPOCHS: 20
EWC_IMPORTANCE: 500
EWC_SAMPLE_SIZE: 250
EWC_TYPE: <class 'networks.continual_learning.OnlineEWC'>
GAMMA: 1.0
IS_CONVOLUTIONAL: False
ITERS: 1
L1_REG: 0
LOSS: cross_entropy
LR: 0.001
MODEL_NAME: ewc
OPTIMIZER: SGD
RUN_NAME: default
SAVE_PATH: ./models/permuted_minst/multikaf
USE_EWC: True
USE_TENSORBOARD: True

CONFIG PARAMETERS
BATCH_SIZE: 64
DEVICE: cuda
EPOCHS: 20
EWC_IMPORTANCE: 500
EWC_SAMPLE_SIZE: 250
EWC_TYPE: <class 'networks.continual_learning.OnlineEWC'>
GAMMA: 1.0
IS_CONVOLUTIONAL: False
ITERS: 1
L1_REG: 0
LOSS: cross_entropy
LR: 0.001
MODEL_NAME: ewc

In [5]:
# Load MNIST (downloading into ../data/minst if needed) and build 4 permuted
# variants of it, one per task, each split 80/20 into train/test.
dataset = MINST.PermutedMINST(
    '../data/minst',
    download=True,
    n_permutation=4,
    force_download=False,
    train_split=0.8,
    transform=None,
    target_transform=None,
)
dataset.load_dataset()

# NOTE(review): this copy is not referenced by any later cell of this notebook
# (each Trainer deep-copies `dataset` itself); consider dropping it to save memory.
dataset_no_ewt = copy.deepcopy(dataset)

../data/minst/download
task #0 with train 56000 and test 14000 images (label: 0)
task #1 with train 56000 and test 14000 images (label: 1)
task #2 with train 56000 and test 14000 images (label: 2)
task #3 with train 56000 and test 14000 images (label: 3)


In [6]:
# Hidden width of the plain MLP baseline; the KAF variants use 80% of it.
# NOTE(review): the 0.8 factor presumably keeps parameter counts comparable
# across models (see the counts printed below) — confirm intent.
BASE_HIDDEN_SIZE = 400
KAF_HIDDEN_SIZE = int(BASE_HIDDEN_SIZE * 0.8)
n_classes = len(dataset.class_to_idx)

net_multikaf = Kafnet.MultiKAFMLP(n_classes, hidden_size=KAF_HIDDEN_SIZE, kaf_init_fcn=None,
                                  trainable_dict=False, kernel_combination='softmax')

kaf_ewt = Kafnet.KAFMLP(n_classes, hidden_size=KAF_HIDDEN_SIZE, kaf_init_fcn=None)

net_ewt = NoKafnet.MLP(n_classes, hidden_size=BASE_HIDDEN_SIZE)


def count_parameters(model):
    """Return the total number of scalar entries across ``model.parameters()``."""
    return sum(p.numel() for p in model.parameters())


print('Numero di parametri rete multikaf: ', count_parameters(net_multikaf))
print('Numero di parametri rete classica: ', count_parameters(net_ewt))
print('Numero di parametri rete kaf: ', count_parameters(kaf_ewt))

Numero di parametri rete multikaf:  517450
Numero di parametri rete classica:  638810
Numero di parametri rete kaf:  479050


In [7]:
# Multi-KAF network trained with EWC, on its own deep copy of the task data.
# NOTE(review): Trainer.load() appears to return previously saved metrics, or
# something falsy when none exist, in which case all tasks are trained — confirm.
trainer_multikaf = Trainer(net_multikaf, copy.deepcopy(dataset), config)
metrics_multikaf = trainer_multikaf.load() or trainer_multikaf.all_tasks()

KeyboardInterrupt: 

In [None]:
# Single-KAF MLP trained with EWC, on its own deep copy of the task data.
# NOTE(review): Trainer.load() appears to return previously saved metrics, or
# something falsy when none exist, in which case all tasks are trained — confirm.
trainer_kaf_ewt = Trainer(kaf_ewt, copy.deepcopy(dataset), confing_ewt_kaf)
metrics_kaf_ewt = trainer_kaf_ewt.load() or trainer_kaf_ewt.all_tasks()

In [None]:
# Plain MLP baseline trained with EWC, on its own deep copy of the task data.
# NOTE(review): Trainer.load() appears to return previously saved metrics, or
# something falsy when none exist, in which case all tasks are trained — confirm.
trainer_ewt = Trainer(net_ewt, copy.deepcopy(dataset), config_ewc)
metrics_ewt = trainer_ewt.load() or trainer_ewt.all_tasks()

In [None]:
# Compare per-task accuracy curves of the three EWC-trained models.
# Keys are the legend labels; insertion order fixes plotting (and colour) order.
runs = {
    'ewc': metrics_ewt,
    'multikaf ewc': metrics_multikaf,
    'kaf ewc': metrics_kaf_ewt,
}

n_task = len(metrics_ewt['tasks'])

print('Multikaf', metrics_multikaf['metrics'])
print('Kaf', metrics_kaf_ewt['metrics'])
print('Ewc', metrics_ewt['metrics'])

# Longest accuracy history over every run and task (previously the kaf run was
# ignored here). Tasks seen later have shorter histories, so each curve is
# right-aligned to end at tot_epochs.
tot_epochs = max(
    (len(task_metrics['accuracy'])
     for run in runs.values()
     for task_metrics in run['tasks'].values()),
    default=0,
)

fig = plt.figure(figsize=(12, 24))

ax = None
for i, task in enumerate(metrics_ewt['tasks'].keys()):
    ax = fig.add_subplot(n_task, 1, i + 1, sharex=ax)

    for label, run in runs.items():
        accuracy = run['tasks'][task]['accuracy']
        # x is built per series: a single x derived from one run's length made
        # matplotlib raise whenever another run's history had a different length.
        x = range(tot_epochs - len(accuracy), tot_epochs)
        ax.plot(x, accuracy, label=label)

    ax.set_xticks(range(0, tot_epochs, 5), minor=False)
    ax.set_title("Task {}".format(task))
    ax.set_ylabel('accuracy')
    ax.legend(loc="lower left")
    ax.grid(True, axis='x')

if ax is not None:
    ax.set_xlabel('epoch')

# Leave room between rows for each subplot's title and tick labels;
# hspace=0.01 drew every title on top of the previous row's tick labels.
fig.subplots_adjust(hspace=0.3)
plt.show()
