# Incremental learning on image classification

## Libraries and packages


In [1]:
!pip3 install 'torch==1.4.0'
!pip3 install 'torchvision==0.5.0'
!pip3 install 'Pillow-SIMD'
!pip3 install 'tqdm'

Collecting torch==1.4.0
[?25l  Downloading https://files.pythonhosted.org/packages/24/19/4804aea17cd136f1705a5e98a00618cb8f6ccc375ad8bfa437408e09d058/torch-1.4.0-cp36-cp36m-manylinux1_x86_64.whl (753.4MB)
[K     |████████████████████████████████| 753.4MB 21kB/s 
[31mERROR: torchvision 0.6.0+cu101 has requirement torch==1.5.0, but you'll have torch 1.4.0 which is incompatible.[0m
[?25hInstalling collected packages: torch
  Found existing installation: torch 1.5.0+cu101
    Uninstalling torch-1.5.0+cu101:
      Successfully uninstalled torch-1.5.0+cu101
Successfully installed torch-1.4.0
Collecting torchvision==0.5.0
[?25l  Downloading https://files.pythonhosted.org/packages/7e/90/6141bf41f5655c78e24f40f710fdd4f8a8aff6c8b7c6f0328240f649bdbe/torchvision-0.5.0-cp36-cp36m-manylinux1_x86_64.whl (4.0MB)
[K     |████████████████████████████████| 4.0MB 3.5MB/s 
Installing collected packages: torchvision
  Found existing installation: torchvision 0.6.0+cu101
    Uninstalling torchvision-0



In [1]:
import os
import urllib
import logging

import numpy as np

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim
from torch.utils.data import Dataset, Subset, DataLoader, ConcatDataset
from torch.backends import cudnn

import torchvision
from torchvision import transforms
from torchvision.models import resnet34

from PIL import Image
from tqdm import tqdm

from copy import deepcopy

from sklearn.metrics import confusion_matrix

In [3]:
# GitHub credentials for cloning private repository
username = 'LilMowgli'
password = '_Kora3030_'

# Download packages from repository
password = urllib.parse.quote(password)
!git clone https://$username:$password@github.com/manuelemacchia/incremental-learning-image-classification.git
password = ''

!mv -v incremental-learning-image-classification/* .
!rm -rf incremental-learning-image-classification README.md

Cloning into 'incremental-learning-image-classification'...
remote: Enumerating objects: 674, done.[K
remote: Total 674 (delta 0), reused 0 (delta 0), pack-reused 674[K
Receiving objects: 100% (674/674), 2.80 MiB | 11.83 MiB/s, done.
Resolving deltas: 100% (359/359), done.
renamed 'incremental-learning-image-classification/data' -> './data'
renamed 'incremental-learning-image-classification/icarlSVM.ipynb' -> './icarlSVM.ipynb'
renamed 'incremental-learning-image-classification/joint_training.ipynb' -> './joint_training.ipynb'
renamed 'incremental-learning-image-classification/losses' -> './losses'
renamed 'incremental-learning-image-classification/model' -> './model'
renamed 'incremental-learning-image-classification/notebook.ipynb' -> './notebook.ipynb'
renamed 'incremental-learning-image-classification/README.md' -> './README.md'
renamed 'incremental-learning-image-classification/report' -> './report'
renamed 'incremental-learning-image-classification/studies-classifier.ipynb' -> 

In [2]:
from data.cifar100 import Cifar100
from model.resnet_cifar import resnet32
from model.manager import Manager
from model.icarl import Exemplars
from model.icarl import iCaRL
from utils import plot

  import pandas.util.testing as tm


## Arguments

In [3]:
# Directories
DATA_DIR = 'data'       # Directory where the dataset will be downloaded

# Settings
DEVICE = 'cuda'         # Device handed to the training classes (Manager, LWF, iCaRL)

# Dataset

# NOTE(review): not referenced by the cells visible here (they use RANDOM_STATES) — confirm before removing
RANDOM_STATE = None

RANDOM_STATES = [658, 423, 422]      # One random state per run, for reproducibility of results
                                     # Note: different random states give very different
                                     # splits and therefore very different results.

NUM_CLASSES = 100       # Total number of classes
NUM_BATCHES = 10        # Number of incremental class batches (NUM_CLASSES / CLASS_BATCH_SIZE)
CLASS_BATCH_SIZE = 10   # Size of batch of classes for incremental learning

VAL_SIZE = 0.1          # Proportion of validation set with respect to training set (between 0 and 1)

# Training
BATCH_SIZE = 64         # Batch size (iCaRL sets this to 128)
LR = 2                  # Initial learning rate

MOMENTUM = 0.9          # Momentum for stochastic gradient descent (SGD)
WEIGHT_DECAY = 1e-5     # Weight decay from iCaRL

NUM_RUNS = 3            # Number of runs of every method
                        # Note: this should be at least 3 to have a fair benchmark

NUM_EPOCHS = 70         # Total number of training epochs
MILESTONES = [49, 63]   # Step down policy from iCaRL (MultiStepLR)
                        # Decrease the learning rate by gamma at each milestone
GAMMA = 0.2             # Gamma factor from iCaRL

## Fine tuning

### Data preparation

In [None]:
# Training pipeline: random 32x32 crop (after 4-pixel padding) and random
# horizontal flip, followed by tensor conversion and per-channel normalization.
train_transform = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),  # PIL Image -> torch.Tensor
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# Evaluation pipeline: no augmentation, same tensor conversion and normalization.
test_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

In [None]:
# For every run, build one train / validation / test DataLoader per batch of classes.
# xxx_dataloaders[run][split] is the loader of that split in that run.
train_dataloaders = [[] for _ in range(NUM_RUNS)]
val_dataloaders = [[] for _ in range(NUM_RUNS)]
test_dataloaders = [[] for _ in range(NUM_RUNS)]

for run_i in range(NUM_RUNS):

    random_state = RANDOM_STATES[run_i]

    # One iteration per incremental batch of classes
    for split_i in range(NUM_BATCHES):

        # Download dataset only at first instantiation
        download = (run_i == 0 and split_i == 0)

        # Create CIFAR100 dataset (fresh objects, since set_classes_batch is
        # called on them below)
        train_dataset = Cifar100(DATA_DIR, train=True, download=download,
                                 random_state=random_state, transform=train_transform)
        test_dataset = Cifar100(DATA_DIR, train=False, download=False,
                                random_state=random_state, transform=test_transform)

        # Training: only the classes of the current batch.
        # Test: all the classes seen up to and including the current batch.
        train_dataset.set_classes_batch(train_dataset.batch_splits[split_i])
        test_dataset.set_classes_batch(
            [test_dataset.batch_splits[i] for i in range(0, split_i+1)])

        # Define train and validation indices
        train_indices, val_indices = train_dataset.train_val_split(
            VAL_SIZE, random_state)

        train_dataloaders[run_i].append(DataLoader(Subset(train_dataset, train_indices),
                                                   batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True))

        # NOTE(review): shuffle/drop_last on the validation loader means the last
        # incomplete batch is never evaluated — confirm this is intended
        val_dataloaders[run_i].append(DataLoader(Subset(train_dataset, val_indices),
                                                 batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True))

        # Dataset with all seen classes. Not shuffled, so that predictions collected
        # batch by batch stay aligned with the order of the dataset targets.
        test_dataloaders[run_i].append(DataLoader(test_dataset,
                                                  batch_size=BATCH_SIZE, shuffle=False, num_workers=4))

In [None]:
# Sanity check: visualize a batch of images from the test loader of run 0, split 5
dataiter = iter(test_dataloaders[0][5])
# The built-in next() works on any iterator; the Python-2 style .next() method is
# not provided by DataLoader iterators in newer PyTorch releases.
images, labels = next(dataiter)

plot.image_grid(images, one_channel=False)
# Labels present in the batch and how many times each one occurs
unique_labels = np.unique(labels, return_counts=True)
unique_labels

### Execution

In [None]:
# Score containers: one dictionary per run, keyed by "Split i"
train_loss_history = []
train_accuracy_history = []
val_loss_history = []
val_accuracy_history = []
test_accuracy_history = []

# Iterate over runs
for train_dataloader, val_dataloader, test_dataloader in zip(train_dataloaders,
                                                             val_dataloaders, test_dataloaders):

    # Open a fresh score dictionary for this run
    train_loss_history.append({})
    train_accuracy_history.append({})
    val_loss_history.append({})
    val_accuracy_history.append({})
    test_accuracy_history.append({})

    net = resnet32()  # Define the net

    criterion = nn.BCEWithLogitsLoss()  # Define the loss

    # Iterate over the incremental batches of classes of this run
    for i, (train_split, val_split, test_split) in enumerate(
            zip(train_dataloader, val_dataloader, test_dataloader)):

        current_split = "Split %i"%(i)
        print(current_split)

        # Optimizer and scheduler are rebuilt at every split: the learning rate
        # schedule restarts and net.parameters() is read again after the output
        # layer has been extended at the end of the previous split
        parameters_to_optimize = net.parameters()
        optimizer = optim.SGD(parameters_to_optimize, lr=LR,
                              momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                   milestones=MILESTONES, gamma=GAMMA)

        # Define Manager Object
        manager = Manager(DEVICE, net, criterion, optimizer, scheduler,
                          train_split, val_split, test_split)

        scores = manager.train(NUM_EPOCHS)  # train the model

        # scores = (train loss, train accuracy, validation loss, validation accuracy)
        train_loss_history[-1][current_split] = scores[0]
        train_accuracy_history[-1][current_split] = scores[1]
        val_loss_history[-1][current_split] = scores[2]
        val_accuracy_history[-1][current_split] = scores[3]

        # Test the model on classes seen until now
        test_accuracy, all_preds = manager.test()

        test_accuracy_history[-1][current_split] = test_accuracy

        # Add one output node per class of the next batch to the last FC layer
        manager.increment_classes(n=CLASS_BATCH_SIZE)
        

In [None]:
# Confusion matrix over last run test predictions.
# `test_dataset` and `all_preds` are the values left behind by the last
# iterations of the loops above (last dataset built, last split tested).
# NOTE(review): pairing targets and predictions position by position is only
# valid if the test DataLoader yields samples in dataset order (no shuffling) and
# `test_dataset.targets` contains exactly the tested samples — confirm.
targets = test_dataset.targets
preds = all_preds.to('cpu').numpy()

plot.heatmap_cm(targets, preds)

In [None]:
def mean_std_scores(train_loss_history, train_accuracy_history,
                   val_loss_history, val_accuracy_history, test_accuracy_history):
  '''Aggregate the scores of every split over the runs.

  Args:
      train_loss_history, train_accuracy_history, val_loss_history,
      val_accuracy_history, test_accuracy_history: lists with one dictionary
          per run, each mapping a split name (e.g. 'Split 0') to the score
          obtained by that run on that split.

  Returns:
      tuple of five np.ndarray (train_loss, train_accuracy, val_loss,
      val_accuracy, test_accuracy). Each array has one row per split, in the
      key order of the first run of train_loss_history, and row i is
      [mean, std] of the split score over the runs of that history.
  '''
  # keys = 'Split i'
  keys = list(train_loss_history[0].keys())

  def _mean_std(history):
    # The number of runs is taken from the history itself rather than from a
    # notebook-level constant, so the function also works on loaded scores.
    rows = []
    for key in keys:
      values = np.array([run_scores[key] for run_scores in history])
      rows.append([values.mean(), values.std()])
    return np.array(rows)

  # Return averaged scores
  return (_mean_std(train_loss_history), _mean_std(train_accuracy_history),
          _mean_std(val_loss_history), _mean_std(val_accuracy_history),
          _mean_std(test_accuracy_history))

In [None]:
# Aggregate the histories over the runs: every result holds one
# [mean, std] row per split
(train_loss, train_accuracy, val_loss,
 val_accuracy, test_accuracy) = mean_std_scores(
    train_loss_history, train_accuracy_history,
    val_loss_history, val_accuracy_history, test_accuracy_history)

In [None]:
# Plot the aggregated training and validation scores
# NOTE(review): the meaning of the trailing None argument is defined in utils.plot — confirm
plot.train_val_scores(train_loss, train_accuracy, val_loss, val_accuracy, None)

In [None]:
# Plot the aggregated test accuracy
# NOTE(review): the meaning of the trailing None argument is defined in utils.plot — confirm
plot.test_scores(test_accuracy, None)

In [None]:
# @todo: create utils package for functions

import ast

def load_json_scores(root):
  '''Load the score histories stored as JSON files inside `root`.

  Args:
      root: directory containing train_loss_history.json,
          train_accuracy_history.json, val_loss_history.json,
          val_accuracy_history.json and test_accuracy_history.json.

  Returns:
      tuple (train_loss_history, train_accuracy_history, val_loss_history,
      val_accuracy_history, test_accuracy_history).

  Raises:
      FileNotFoundError: if one of the files is missing.
  '''
  import ast
  import json

  def _read(name):
    with open(os.path.join(root, name + '.json')) as f:
      text = f.read()
    try:
      # Proper JSON parser: also understands null / true / false / NaN,
      # which a Python literal parser rejects
      return json.loads(text)
    except json.JSONDecodeError:
      # Backward compatibility with files written as Python literals
      return ast.literal_eval(text)

  names = ('train_loss_history', 'train_accuracy_history', 'val_loss_history',
           'val_accuracy_history', 'test_accuracy_history')

  return tuple(_read(name) for name in names)

In [None]:
# @todo: create utils package for functions
import json

def save_json_scores(root, train_loss_history, train_accuracy_history,
                   val_loss_history, val_accuracy_history, test_accuracy_history):
    """Write every score history to its own JSON file inside `root`.

    Args:
        root: destination directory; it is created when it does not exist.
        train_loss_history, train_accuracy_history, val_loss_history,
        val_accuracy_history, test_accuracy_history: JSON-serializable
            score histories, saved as '<argument name>.json'.
    """
    # Without this, the first open() fails on a fresh working directory
    os.makedirs(root, exist_ok=True)

    histories = {
        'train_loss_history': train_loss_history,
        'train_accuracy_history': train_accuracy_history,
        'val_loss_history': val_loss_history,
        'val_accuracy_history': val_accuracy_history,
        'test_accuracy_history': test_accuracy_history,
    }

    for name, history in histories.items():
        with open(os.path.join(root, name + '.json'), 'w') as fout:
            json.dump(history, fout)

In [None]:
# Persist the score histories computed above as JSON files in the 'scores' directory
save_json_scores('scores', train_loss_history, train_accuracy_history,
                   val_loss_history, val_accuracy_history, test_accuracy_history)

In [None]:
# Colab only: zip the 'scores' directory and download the archive through the browser
from google.colab import files
!zip -r scores.zip scores
files.download("scores.zip")

## Learning Without Forgetting

### Arguments

In [4]:
# Training settings for Learning Without Forgetting
# These assignments replace the values of the "Arguments" section for every cell
# executed afterwards (BATCH_SIZE goes from 64 to 128; the other two are unchanged).
RANDOM_STATES = [658, 423, 422]  # One random state per run
BATCH_SIZE = 128
LR = 2

### Data preparation

In [5]:
# Transformations for Learning Without Forgetting

# Training: random crop and horizontal flip, then tensor conversion and normalization
train_transform = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),  # PIL Image -> torch.Tensor
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# Evaluation: tensor conversion and normalization only
test_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

In [None]:
# For every run, build one train / validation / test DataLoader per batch of classes.
# xxx_dataloaders[run][split] is the loader of that split in that run.
train_dataloaders = [[] for _ in range(NUM_RUNS)]
val_dataloaders = [[] for _ in range(NUM_RUNS)]
test_dataloaders = [[] for _ in range(NUM_RUNS)]

for run_i in range(NUM_RUNS):

  random_state = RANDOM_STATES[run_i]

  # One iteration per incremental batch of classes
  for split_i in range(NUM_BATCHES):

    # Download dataset only at first instantiation
    download = (run_i == 0 and split_i == 0)

    # Create CIFAR100 dataset (fresh objects, since set_classes_batch is
    # called on them below)
    train_dataset = Cifar100(DATA_DIR, train=True, download=download,
                             random_state=random_state, transform=train_transform)
    test_dataset = Cifar100(DATA_DIR, train=False, download=False,
                            random_state=random_state, transform=test_transform)

    # Training: only the classes of the current batch.
    # Test: all the classes seen up to and including the current batch.
    train_dataset.set_classes_batch(train_dataset.batch_splits[split_i])
    test_dataset.set_classes_batch([test_dataset.batch_splits[i] for i in range(0, split_i+1)])

    # Define train and validation indices
    train_indices, val_indices = train_dataset.train_val_split(VAL_SIZE, random_state)

    train_dataloaders[run_i].append(DataLoader(Subset(train_dataset, train_indices),
                                               batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True))

    val_dataloaders[run_i].append(DataLoader(Subset(train_dataset, val_indices),
                                             batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True))

    # Dataset with all seen classes. Not shuffled, so that the predictions
    # returned by the test step stay aligned with the order of the dataset.
    test_dataloaders[run_i].append(DataLoader(test_dataset,
                                              batch_size=BATCH_SIZE, shuffle=False, num_workers=4))

In [None]:
# Sanity check: visualize a batch of images from the test loader of run 0, split 5
dataiter = iter(test_dataloaders[0][5])
# The built-in next() works on any iterator; the Python-2 style .next() method is
# not provided by DataLoader iterators in newer PyTorch releases.
images, labels = next(dataiter)

plot.image_grid(images, one_channel=False)
# Labels present in the batch and how many times each one occurs
unique_labels = np.unique(labels, return_counts=True)
unique_labels

In [None]:
from torch.nn import BCEWithLogitsLoss
from copy import deepcopy

'''BCE formulation:
 let x = logits, z = labels. The logistic loss is

  z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
'''

   
CLASS_BATCH_SIZE = 10  # Number of classes introduced at every incremental step


class LWF():
  """Learning Without Forgetting trainer for class-incremental learning.

  While the first batch of classes is being learned, the network is trained
  on the one-hot labels only. On later batches the targets of the old classes
  are the sigmoid outputs of the previous network and the targets of the new
  classes are the one-hot labels.

  Args:
    device: device on which batches and networks are placed (e.g. 'cuda').
    net: network to train; its output layer must be reachable as `net.fc`.
    old_net: network trained on the previous classes, None for the first
      batch of classes.
    criterion: loss applied to (logits, targets); BCEWithLogitsLoss is used
      when None is given.
    optimizer: optimizer built on the parameters of `net`.
    scheduler: learning rate scheduler, stepped once per epoch.
    train_dataloader: iterable of (images, labels) training batches.
    test_dataloader: iterable of (images, labels) test batches.
    num_classes: number of classes seen so far, current batch included.
      NOTE(review): labels are assumed to be integers in [0, num_classes),
      numbered in order of appearance — confirm against the dataset class.
  """

  def __init__(self, device, net, old_net, criterion, optimizer, scheduler,
               train_dataloader, test_dataloader, num_classes=10):

    self.device = device

    self.net = net
    self.best_net = self.net  # same object as self.net: no model selection is done
    self.old_net = old_net # None for first ten classes

    # Classifier criterion: honor the one supplied by the caller
    self.criterion = criterion if criterion is not None else BCEWithLogitsLoss()
    self.optimizer = optimizer
    self.scheduler = scheduler

    self.train_dataloader = train_dataloader
    self.test_dataloader = test_dataloader

    self.num_classes = num_classes # can be incremented outside methods in the main, or inside methods
    self.order = np.arange(100)  # output column assigned to each class index

    self.sigmoid = nn.Sigmoid()

  def warm_up(self):
    """Placeholder, does nothing."""
    pass

  def increment_classes(self, n=10):
    """Add n classes in the final fully connected layer.

    The weights and the bias already learned for the existing classes are
    copied into the new layer, which is created on the device of the old one.
    """
    old_fc = self.net.fc
    in_features = old_fc.in_features  # size of each input sample
    out_features = old_fc.out_features  # size of each output sample

    new_fc = nn.Linear(in_features, out_features + n, bias=old_fc.bias is not None)
    new_fc = new_fc.to(old_fc.weight.device)

    new_fc.weight.data[:out_features] = old_fc.weight.data
    if old_fc.bias is not None:
      # The bias is part of what was learned for the old classes as well
      new_fc.bias.data[:out_features] = old_fc.bias.data

    self.net.fc = new_fc

  def to_onehot(self, targets):
    '''One-hot encode class labels over the classes seen so far.

    Args:
      targets: integer labels (tensor or list) with values below
        self.num_classes.
    Returns:
      float tensor of size [len(targets), self.num_classes] on self.device.
    '''
    # The identity matrix is built on the target device so that it can be
    # indexed by labels living on that device
    return torch.eye(self.num_classes, device=self.device)[targets]

  def _class_index(self, start, stop):
    """Return the output columns of classes order[start:stop] as an index tensor."""
    return torch.as_tensor(self.order[start:stop], dtype=torch.long, device=self.device)

  def _optimize(self, outputs, targets, labels):
    """Compute the loss, count the correct predictions and update the network.

    Returns:
      (loss tensor, number of samples whose highest output is their label).
    """
    loss = self.criterion(outputs, targets) # BCE Loss with sigmoids over outputs

    # Get predictions
    _, preds = torch.max(outputs.data, 1)

    # Accuracy over NEW IMAGES, not over all images
    running_corrects = torch.sum(preds == labels.data).item()

    # Backward pass: computes gradients
    loss.backward()

    self.optimizer.step()

    return loss, running_corrects

  def do_first_batch(self, batch, labels):
    """Training step used while the first batch of classes is being learned.

    Returns:
      (loss tensor, number of correct predictions in the batch).
    """
    batch = batch.to(self.device)
    labels = labels.to(self.device) # new classes labels

    # Zero-ing the gradients
    self.optimizer.zero_grad()

    # One hot encoding of new task labels
    one_hot_labels = self.to_onehot(labels) # Size = [batch size, num_classes]

    # New net forward pass
    outputs = self.net(batch)

    return self._optimize(outputs, one_hot_labels, labels)

  def do_batch(self, batch, labels):
    """Training step with distillation, used from the second batch of classes on.

    Returns:
      (loss tensor, number of correct predictions in the batch).
    """
    batch = batch.to(self.device)
    labels = labels.to(self.device) # new classes labels

    # Zero-ing the gradients
    self.optimizer.zero_grad()

    num_old_classes = self.num_classes - CLASS_BATCH_SIZE
    old_classes = self._class_index(0, num_old_classes)
    new_classes = self._class_index(num_old_classes, self.num_classes)
    out_classes = self._class_index(0, self.num_classes)

    # One hot encoding of new task labels, restricted to the new classes
    one_hot_labels = self.to_onehot(labels)[:, new_classes]

    # Old net forward pass. Its outputs are fixed targets: no autograd graph is
    # needed, which saves memory and keeps gradients out of the old network.
    with torch.no_grad():
      old_outputs = self.sigmoid(self.old_net(batch))[:, old_classes]

    # Combine new and old class targets
    targets = torch.cat((old_outputs, one_hot_labels), 1)

    # New net forward pass, restricted to the classes seen so far
    outputs = self.net(batch)[:, out_classes]

    return self._optimize(outputs, targets, labels)

  def do_epoch(self, current_epoch):
    """Train the network for one epoch.

    Args:
      current_epoch: epoch number, only used in the printed log.
    Returns:
      (train_loss, train_accuracy): loss averaged over the batches and
      accuracy averaged over the samples of the epoch.
    """
    self.net.train()

    running_train_loss = 0
    running_corrects = 0
    total = 0
    batch_idx = 0

    print(f"Epoch: {current_epoch}, LR: {self.scheduler.get_last_lr()}")

    # The kind of step only depends on whether there are old classes
    first_task = self.num_classes == CLASS_BATCH_SIZE
    step = self.do_first_batch if first_task else self.do_batch

    for images, labels in self.train_dataloader:

      loss, corrects = step(images, labels)

      running_train_loss += loss.item()
      running_corrects += corrects
      total += labels.size(0)
      batch_idx += 1

    self.scheduler.step()

    # Calculate average scores
    train_loss = running_train_loss / batch_idx # Average over all batches
    train_accuracy = running_corrects / float(total) # Average over all samples

    print(f"Train loss: {train_loss}, Train accuracy: {train_accuracy}")

    return (train_loss, train_accuracy)

  def train(self, num_epochs):
    """Train the network for a specified number of epochs

    Args:
        num_epochs (int): number of epochs for training the network.
    Returns:
        train_loss: loss computed on the last epoch
        train_accuracy: accuracy computed on the last epoch
        Both are None when num_epochs is 0.
    """

    # @todo: is the return behaviour intended? (scores of the last epoch)

    self.net = self.net.to(self.device)
    if self.old_net is not None:
      self.old_net = self.old_net.to(self.device)
      self.old_net.train(False)

    cudnn.benchmark = True  # Let cuDNN pick the fastest algorithms (the flag has to be assigned)

    self.best_loss = float("inf")
    self.best_epoch = 0

    train_loss, train_accuracy = None, None

    for epoch in range(num_epochs):
      # Run an epoch (start counting from 1)
      train_loss, train_accuracy = self.do_epoch(epoch+1)

      print("")

    return (train_loss, train_accuracy)

  def test(self):
    """Test the model.
    Returns:
        accuracy (float): accuracy of the model on the test set
        all_preds (tensor): predicted class of every test sample, in the
            order in which the test dataloader yields them
    """

    self.best_net.train(False)  # Set Network to evaluation mode

    running_corrects = 0
    total = 0

    batch_preds = []  # predictions of every batch

    # Inference only: no autograd graph is needed
    with torch.no_grad():
      for images, labels in self.test_dataloader:
        images = images.to(self.device)
        labels = labels.to(self.device)
        total += labels.size(0)

        # Forward Pass
        outputs = self.best_net(images)

        # Get predictions
        _, preds = torch.max(outputs.data, 1)

        # Update Corrects
        running_corrects += torch.sum(preds == labels.data).item()

        batch_preds.append(preds)

    # Calculate accuracy
    accuracy = running_corrects / float(total)

    # Concatenate once instead of growing a tensor at every batch
    all_preds = torch.cat(batch_preds, dim=0)

    print(f"Test accuracy: {accuracy}")

    return (accuracy, all_preds)

In [None]:
# Score containers: one dictionary per run, keyed by "Split i"
train_loss_history = []
train_accuracy_history = []
test_accuracy_history = []

# Iterate over runs
for train_dataloader, test_dataloader in zip(train_dataloaders,
                                             test_dataloaders):

    # Open a fresh score dictionary for this run
    train_loss_history.append({})
    train_accuracy_history.append({})
    test_accuracy_history.append({})

    net = resnet32()  # Define the net

    criterion = nn.BCEWithLogitsLoss()  # Define the loss

    # No previous network exists while the first batch of classes is learned.
    # Resetting it here also prevents a network from leaking between runs.
    old_net = None

    # Iterate over the incremental batches of classes of this run
    for i, (train_split, test_split) in enumerate(zip(train_dataloader,
                                                      test_dataloader)):

        # Optimizer and scheduler are rebuilt at every split: the learning rate
        # schedule restarts and net.parameters() must include the output layer
        # created by increment_classes() at the end of the previous split
        parameters_to_optimize = net.parameters()
        optimizer = optim.SGD(parameters_to_optimize, lr=LR,
                              momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                   milestones=MILESTONES, gamma=GAMMA)

        current_split = "Split %i"%(i)
        print(current_split)

        # Number of classes seen so far, current batch included
        num_classes = CLASS_BATCH_SIZE*(i+1)

        lwf = LWF(DEVICE, net, old_net, criterion, optimizer, scheduler,
                  train_split, test_split, num_classes)

        scores = lwf.train(NUM_EPOCHS)  # train the model

        # scores = (train loss, train accuracy) of the last epoch
        train_loss_history[-1][current_split] = scores[0]
        train_accuracy_history[-1][current_split] = scores[1]

        # Test the model on classes seen until now
        test_accuracy, all_preds = lwf.test()

        test_accuracy_history[-1][current_split] = test_accuracy

        # Keep a copy of the network as it is now: it provides the distillation
        # targets of the next split. Then make room for the next batch of classes.
        old_net = deepcopy(lwf.net)
        lwf.increment_classes(n=CLASS_BATCH_SIZE)

## iCaRL

### Data preparation

In [6]:
# Transformations for iCaRL

# Training: random crop and horizontal flip, then tensor conversion and normalization
train_transform = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),  # PIL Image -> torch.Tensor
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# Evaluation: tensor conversion and normalization only
test_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

In [7]:
# For every run, build one train / validation / test dataset per batch of classes.
# xxx_subsets[run][split] is the dataset of that split in that run.
train_subsets = [[] for _ in range(NUM_RUNS)]
val_subsets = [[] for _ in range(NUM_RUNS)]
test_subsets = [[] for _ in range(NUM_RUNS)]

for run_i in range(NUM_RUNS):
    # One iteration per incremental batch of classes
    for split_i in range(NUM_BATCHES):
        # Download dataset only at first instantiation
        download = (run_i == 0 and split_i == 0)

        # Create CIFAR100 dataset (fresh objects, since set_classes_batch is
        # called on them below)
        train_dataset = Cifar100(DATA_DIR, train=True, download=download, random_state=RANDOM_STATES[run_i], transform=train_transform)
        test_dataset = Cifar100(DATA_DIR, train=False, download=False, random_state=RANDOM_STATES[run_i], transform=test_transform)

        # Training: only the classes of the current batch.
        # Test: all the classes seen up to and including the current batch.
        train_dataset.set_classes_batch(train_dataset.batch_splits[split_i])
        test_dataset.set_classes_batch([test_dataset.batch_splits[i] for i in range(0, split_i+1)])

        # Define train and validation indices
        train_indices, val_indices = train_dataset.train_val_split(VAL_SIZE, RANDOM_STATES[run_i])

        # Define subsets
        train_subsets[run_i].append(Subset(train_dataset, train_indices))
        val_subsets[run_i].append(Subset(train_dataset, val_indices))
        test_subsets[run_i].append(test_dataset)

### iCaRL Class

In [22]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torch.backends import cudnn

import numpy as np
import numpy.ma as ma
from math import floor
from copy import deepcopy
import random

sigmoid = nn.Sigmoid() # Module-level sigmoid (element-wise logistic function)

# NOTE(review): not referenced in the part of the file visible here; presumably
# the number of exemplars whose features are examined — confirm
NUM_EXAM_EXEMPLARS = 10

class Exemplars(torch.utils.data.Dataset):
    """Flat dataset built from per-class lists of exemplars.

    The label of an exemplar is the position, in `exemplars`, of the list it
    comes from.

    Args:
        exemplars: sequence of per-class exemplar sequences, i.e.
            [[ex0_class0, ex1_class0, ...], ..., [ex0_classN, ex1_classN, ...]]
        transform: optional callable applied to an exemplar when it is read.
    """

    def __init__(self, exemplars, transform=None):
        self.dataset = []
        self.targets = []

        for label, class_exemplars in enumerate(exemplars):
            # Flatten the exemplars of this class, then label all of them
            for exemplar in class_exemplars:
                self.dataset.append(exemplar)
            self.targets += [label] * len(class_exemplars)

        self.transform = transform

    def __getitem__(self, index):
        """Return the (exemplar, label) pair at `index`, transformed if requested."""
        sample = self.dataset[index]
        label = self.targets[index]

        if self.transform is None:
            return sample, label

        return self.transform(sample), label

    def __len__(self):
        """Return the total number of exemplars over all the classes."""
        return len(self.targets)

class iCaRL:
    """Implement iCaRL, a strategy for simultaneously learning classifiers and a
    feature representation in the class-incremental setting.
    """

    def __init__(self, device, net, lr, momentum, weight_decay, milestones, gamma, num_epochs, batch_size, train_transform, test_transform):
        self.device = device
        self.net = net

        # Set hyper-parameters
        self.LR = lr
        self.MOMENTUM = momentum
        self.WEIGHT_DECAY = weight_decay
        self.MILESTONES = milestones
        self.GAMMA = gamma
        self.NUM_EPOCHS = num_epochs
        self.BATCH_SIZE = batch_size
        
        # Set transformations
        self.train_transform = train_transform
        self.test_transform = test_transform

        # List of exemplar sets. Each set contains memory_size/num_classes exemplars
        # with num_classes the number of classes seen until now by the network.
        self.exemplars = []

        # @DEBUG
        self.exemplars_rand = []

        # Initialize the copy of the old network, used to compute outputs of the
        # previous network for the distillation loss, to None. This is useful to
        # correctly apply the first function when training the network for the
        # first time.
        self.old_net = None
        self.old_nets = []

        # store saved features in a dictionary where key = split,
        # value = [last_net_features, old_nets_features]
        self.examined_exemplars_loader = None
        self.examined_exemplars_features = dict()

        # Maximum number of exemplars
        self.memory_size = 2000
    
        # Loss function
        self.criterion = nn.BCEWithLogitsLoss()

        # If True, test on the best model found (e.g., minimize loss). If False,
        # test on the last model build (of the last epoch).
        self.VALIDATE = False

    def classify(self, batch, train_dataset=None):
        """Mean-of-exemplars classifier used to classify images into the set of
        classes observed so far.

        Args:
            batch (torch.tensor): batch to classify
        Returns:
            label (int): class label assigned to the image
        """

        batch_features = self.extract_features(batch) # (batch size, 64)
        for i in range(batch_features.size(0)):
            batch_features[i] = batch_features[i]/batch_features[i].norm() # Normalize sample feature representation
        batch_features = batch_features.to(self.device)

        if self.cached_means is None:
            print("Computing mean of exemplars... ", end="")

            self.cached_means = []

            # Number of known classes
            num_classes = len(self.exemplars)

            # Compute the means of classes with all the data available,
            # including training data which contains samples belonging to
            # the latest 10 classes. This will remove noise from the mean
            # estimate, improving the results.
            if train_dataset is not None:
                train_features_list = [[] for _ in range(10)]

                for train_sample, label in train_dataset:
                    features = self.extract_features(train_sample, batch=False, transform=self.test_transform)
                    features = features/features.norm()
                    train_features_list[label % 10].append(features)

            # Compute means of exemplars for all known classes
            for y in range(num_classes):
                if (train_dataset is not None) and (y in range(num_classes-10, num_classes)):
                    features_list = train_features_list[y % 10]
                else:
                    features_list = []

                for exemplar in self.exemplars[y]:
                    features = self.extract_features(exemplar, batch=False, transform=self.test_transform)
                    features = features/features.norm() # Normalize the feature representation of the exemplar
                    features_list.append(features)
                
                features_list = torch.stack(features_list)
                class_means = features_list.mean(dim=0)
                class_means = class_means/class_means.norm() # Normalize the class means

                self.cached_means.append(class_means)
            
            self.cached_means = torch.stack(self.cached_means).to(self.device)
            print("done")

        preds = []
        for i in range(batch_features.size(0)):
            f_arg = torch.norm(batch_features[i] - self.cached_means, dim=1)
            preds.append(torch.argmin(f_arg))
        
        return torch.stack(preds)
    
    # @DEBUG
    def classify_rand(self, batch, train_dataset=None):
        """Mean-of-exemplars classifier used to classify images into the set of
        classes observed so far.

        Args:
            batch (torch.tensor): batch to classify
        Returns:
            label (int): class label assigned to the image
        """

        batch_features = self.extract_features(batch) # (batch size, 64)
        for i in range(batch_features.size(0)):
            batch_features[i] = batch_features[i]/batch_features[i].norm() # Normalize sample feature representation
        batch_features = batch_features.to(self.device)

        if self.cached_means is None:
            print("Computing mean of exemplars... ", end="")

            self.cached_means = []

            # Number of known classes
            num_classes = len(self.exemplars_rand)

            # Compute the means of classes with all the data available,
            # including training data which contains samples belonging to
            # the latest 10 classes. This will remove noise from the mean
            # estimate, improving the results.
            if train_dataset is not None:
                train_features_list = [[] for _ in range(10)]

                for train_sample, label in train_dataset:
                    features = self.extract_features(train_sample, batch=False, transform=self.test_transform)
                    features = features/features.norm()
                    train_features_list[label % 10].append(features)

            # Compute means of exemplars for all known classes
            for y in range(num_classes):
                if (train_dataset is not None) and (y in range(num_classes-10, num_classes)):
                    features_list = train_features_list[y % 10]
                else:
                    features_list = []

                for exemplar in self.exemplars_rand[y]:
                    features = self.extract_features(exemplar, batch=False, transform=self.test_transform)
                    features = features/features.norm() # Normalize the feature representation of the exemplar
                    features_list.append(features)
                
                features_list = torch.stack(features_list)
                class_means = features_list.mean(dim=0)
                class_means = class_means/class_means.norm() # Normalize the class means

                self.cached_means.append(class_means)
            
            self.cached_means = torch.stack(self.cached_means).to(self.device)
            print("done")

        preds = []
        for i in range(batch_features.size(0)):
            f_arg = torch.norm(batch_features[i] - self.cached_means, dim=1)
            preds.append(torch.argmin(f_arg))
        
        return torch.stack(preds)
    
    def extract_features(self, sample, batch=True, transform=None):
        """Extract features from single sample or from batch.
        
        Args:
            sample (PIL image or torch.tensor): sample(s) from which to
                extract features
            batch (bool): if True, sample is a torch.tensor containing a batch
                of images with dimensions (batch_size, 3, 32, 32)
            transform: transformations to apply to the PIL image before
                processing
        Returns:
            features: torch.tensor, 1-D of dimension 64 for single samples or
                2-D of dimension (batch_size, 64) for batch
        """

        assert not (batch is False and transform is None), "if a PIL image is passed to extract_features, a transform must be defined"

        self.net.train(False)
        if self.best_net is not None: self.best_net.train(False)
        if self.old_net is not None: 
          self.old_net.train(False)
          for i in range(len(self.old_nets)):
            self.old_nets[i] = self.old_nets[i].train(False)


        if batch is False: # Treat sample as single PIL image
            sample = transform(sample)
            sample = sample.unsqueeze(0) # https://stackoverflow.com/a/59566009/6486336

        sample = sample.to(self.device)

        if self.VALIDATE:
            features = self.best_net.features(sample)
        else:
            features = self.net.features(sample)

        if batch is False:
            features = features[0]

        return features

    def incremental_train(self, split, train_dataset, val_dataset):
        """Adjust internal knowledge based on the additional information
        available in the new observations.

        Args:
            split (int): current split number, counting from zero
            train_dataset: dataset for training the model
            val_dataset: dataset for validating the model
        Returns:
            train_logs: tuple of four metrics (train_loss, train_accuracy,
            val_loss, val_accuracy) obtained during network training
        """

        self.split = split

        if split is not 0:
            # Increment the number of output nodes for the new network by 10
            self.increment_classes(10)
            last_net_preds, old_nets_preds= self.save_exemplars_features()
            # update features dictionary for analisys
            self.examined_exemplars_features[self.split] = [last_net_preds, old_nets_preds]

        # Improve network parameters upon receiving new classes. Effectively
        # train a new network starting from the current network parameters.
        train_logs = self.update_representation(train_dataset, val_dataset)

        # Compute the number of exemplars per class
        num_classes = self.output_neurons_count()
        m = floor(self.memory_size / num_classes)

        print(f"Target number of exemplars per class: {m}")
        print(f"Target total number of exemplars: {m*num_classes}")

        # Reduce pre-existing exemplar sets in order to fit new exemplars
        for y in range(len(self.exemplars)):
            self.exemplars[y] = self.reduce_exemplar_set(self.exemplars[y], m)

        # @DEBUG
        for y in range(len(self.exemplars_rand)):
            self.exemplars_rand[y] = self.reduce_exemplar_set(self.exemplars[y], m)

        # Construct exemplar set for new classes
        new_exemplars = self.construct_exemplar_set(train_dataset, m)
        self.exemplars.extend(new_exemplars)
        

        # @DEBUG
        new_exemplars = self.construct_exemplar_set_rand(train_dataset, m)
        self.exemplars_rand.extend(new_exemplars)

        return train_logs


    def save_exemplars_features(self):
      # Take predictions on old_classes (num_classes-10) with
      # 1. self.old_net
      # 2. self.old_nets, keeping prediction of each net over the classes it's been trained on
      # Predictions are taken on exemplars of a single class, self.examined_exemplas, which are the same
      # across all the training procedure from split 0 to split 9
      # Predictions must be passed into a sigmoid since we want to compare the targets seen by the
      # currently trained net

      dataiter = iter(self.examined_exemplars_loader)
      images, _ = dataiter.next()
      images = images.to(self.device)

      last_net_preds = sigmoid(self.old_net(images))

      old_nets_preds = sigmoid(self.old_nets[0](images))[:, :10]
      old_nets_preds_brave = sigmoid(self.old_nets[0](images))

      for i in range(1, len(self.old_nets[1:])):
        old_nets_preds = torch.cat( (old_nets_preds, sigmoid(self.old_nets[i](images))[:, i*10:(i+1)*10] ), dim=1)
        old_nets_preds_brave = torch.cat( (old_nets_preds_brave, sigmoid(self.old_nets[i](images))[:, i*10:] ), dim=1)
      
      if torch.all(torch.eq(old_nets_preds, old_nets_preds_brave)): 
        print("You have been brave and that is good")
      else:
        print("Bravery does not pay")

      if(old_nets_preds.size() != last_net_preds):
        print("Correct size of prediction matrices")
      
      assert old_nets_preds.size() != last_net_preds, "Error: The two predictions matrices must be of the same size"

      # if old_nets_preds.size() != last_net_preds
      #   print("Error: The 2 predictions matrices must be of the same size")

      return last_net_preds, old_nets_preds
      

    def update_representation(self, train_dataset, val_dataset):
        """Train the network on the new data mixed with the stored exemplars
        and snapshot the result.

        Args:
            train_dataset: dataset for training the model
            val_dataset: dataset for validating the model
        Returns:
            train_logs: tuple of four metrics (train_loss, train_accuracy,
            val_loss, val_accuracy) obtained during network training
        """

        # Report how many exemplars are stored across all classes.
        stored_exemplars = sum([len(class_exemplars) for class_exemplars in self.exemplars])
        print(f"Length of exemplars set: {stored_exemplars}")

        # Stored exemplars come first, followed by the new training data.
        combined_dataset = ConcatDataset([
            Exemplars(self.exemplars, self.train_transform),
            train_dataset,
        ])

        # Train the network on the combined dataset
        # @todo: include exemplars in validation set?
        train_logs = self.train(combined_dataset, val_dataset)

        # Freeze a copy of the trained network: it provides the targets of
        # the distillation loss while the next split is being learned, and
        # it is also appended to the history of snapshots.
        snapshot = deepcopy(self.net)
        self.old_net = snapshot
        self.old_nets.append(snapshot)

        return train_logs

    def construct_exemplar_set(self, dataset, m):
        """...

        Args:
            dataset: dataset containing a split (samples from 10 classes) from
                which to take exemplars
            m (int): target number of exemplars per class
        Returns:
            exemplars: list of samples extracted from the dataset
        """

        dataset.dataset.disable_transform()

        samples = [[] for _ in range(10)]
        for image, label in dataset:
            label = label % 10 # Map labels to 0-9 range
            samples[label].append(image)

        dataset.dataset.enable_transform()

        # Initialize exemplar sets
        exemplars = [[] for _ in range(10)]

        # Iterate over classes
        for y in range(10):
            print(f"Extracting exemplars from class {y} of current split... ", end="")


            # Transform samples to tensors and apply normalization
            transformed_samples = torch.zeros((len(samples[y]), 3, 32, 32)).to(self.device)
            for i in range(len(transformed_samples)):
                transformed_samples[i] = self.test_transform(samples[y][i])

            # Extract features from samples
            samples_features = self.extract_features(transformed_samples).to(self.device)

            # Compute the feature mean of the current class
            features_mean = samples_features.mean(dim=0)

            # Initializes indices vector, containing the index of each exemplar chosen
            idx = []

            # See iCaRL algorithm 4
            for k in range(1, m+1): # k = 1, ..., m -- Choose m exemplars
                if k == 1: # No exemplars chosen yet, sum to 0 vector
                    f_sum = torch.zeros(64).to(self.device)
                else: # Sum of features of all exemplars chosen until now (j = 1, ..., k-1)
                    f_sum = samples_features[idx].sum(dim=0)

                # Compute argument of argmin function
                f_arg = torch.norm(features_mean - 1/k * samples_features + f_sum, dim=1)

                # Mask exemplars that were already taken, as we do not want to store the
                # same exemplar more than once
                mask = np.zeros(len(f_arg), int)
                mask[idx] = 1
                f_arg_masked = ma.masked_array(f_arg.cpu().detach().numpy(), mask=mask)

                # Compute the nearest available exemplar
                exemplar_idx = np.argmin(f_arg_masked)

                idx.append(exemplar_idx)
            
            # Save exemplars to exemplar set
            for i in idx:
                exemplars[y].append(samples[y][i])

            print(f"Extracted {len(exemplars[y])} exemplars.")
                
            # take first ten exemplars of first class, only once during the
            # whole incremental training procedure
            if self.split == 0 and y == 0:
              examined_exemplars = [exemplars[y][:NUM_EXAM_EXEMPLARS]] # add a dimension

              examined_exemplars_dataset = Exemplars(examined_exemplars, self.train_transform)
              self.examined_exemplars_loader = DataLoader(examined_exemplars_dataset,
                                                          batch_size = NUM_EXAM_EXEMPLARS)
              print("Saved {} exemplars for features analysis".format(len(examined_exemplars[0])))
                
            
            
        return exemplars

    def construct_exemplar_set_rand(self, dataset, m):
        """Randomly sample m elements from a dataset without replacement.

        Args:
            dataset: dataset containing a split (samples from 10 classes) from
                which to take exemplars
            m (int): target number of exemplars per class
        Returns:
            exemplars: list of samples extracted from the dataset
        """

        dataset.dataset.disable_transform()

        samples = [[] for _ in range(10)]
        for image, label in dataset:
            label = label % 10 # Map labels to 0-9 range
            samples[label].append(image)

        dataset.dataset.enable_transform()

        exemplars = [[] for _ in range(10)]

        for y in range(10):
            print(f"Randomly extracting exemplars from class {y} of current split... ", end="")

            # Randomly choose m samples from samples[y] without replacement
            exemplars[y] = random.sample(samples[y], m)

            print(f"Extracted {len(exemplars[y])} exemplars.")

        return exemplars

    def reduce_exemplar_set(self, exemplar_set, m):
        """Shrink an exemplar set by keeping only its leading elements.

        Args:
            exemplar_set: sliceable sequence of exemplars belonging to a
                certain class
            m (int): target number of exemplars
        Returns:
            exemplar_set: the first m exemplars of the given sequence (all
                of them when it holds fewer than m)
        """

        leading_m = slice(None, m)
        return exemplar_set[leading_m]

    def train(self, train_dataset, val_dataset):
        """Train the network for a specified number of epochs, and save
        the best performing model on the validation set.
        
        Args:
            train_dataset: dataset for training the model
            val_dataset: dataset for validating the model
        Returns: train_logs: tuple of four metrics (train_loss, train_accuracy,
            val_loss, val_accuracy) obtained during network training. If
            validation is enabled, return scores of the best epoch, otherwise
            return scores of the last epoch.
        """

        # Define the optimization algorithm
        parameters_to_optimize = self.net.parameters()
        self.optimizer = optim.SGD(parameters_to_optimize, 
                                   lr=self.LR,
                                   momentum=self.MOMENTUM,
                                   weight_decay=self.WEIGHT_DECAY)
        
        # Define the learning rate decaying policy
        self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer,
                                                        milestones=self.MILESTONES,
                                                        gamma=self.GAMMA)

        # Create DataLoaders for training and validation
        self.train_dataloader = DataLoader(train_dataset, batch_size=self.BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True)
        self.val_dataloader = DataLoader(val_dataset, batch_size=self.BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True)

        # Send networks to chosen device
        self.net = self.net.to(self.device)
        if self.old_net is not None: 
          self.old_net = self.old_net.to(self.device)
          for i in range(len(self.old_nets)):
            self.old_nets[i] = self.old_nets[i].to(self.device)

        cudnn.benchmark  # Calling this optimizes runtime

        self.best_val_loss = float('inf')
        self.best_val_accuracy = 0
        self.best_train_loss = float('inf')
        self.best_train_accuracy = 0
        
        self.best_net = None
        self.best_epoch = -1

        for epoch in range(self.NUM_EPOCHS):
            # Run an epoch (start counting form 1)
            train_loss, train_accuracy = self.do_epoch(epoch+1)
        
            # Validate after each epoch 
            val_loss, val_accuracy = self.validate()    

            # Validation criterion: best net is the one that minimizes the loss
            # on the validation set.
            if self.VALIDATE and val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                self.best_val_accuracy = val_accuracy
                self.best_train_loss = train_loss
                self.best_train_accuracy = train_accuracy

                self.best_net = deepcopy(self.net)
                self.best_epoch = epoch
                print("Best model updated")

        if self.VALIDATE:
            val_loss = self.best_val_loss
            val_accuracy = self.best_val_accuracy
            train_loss = self.best_train_loss
            train_accuracy = self.best_train_accuracy

            print(f"Best model found at epoch {self.best_epoch+1}")

        return train_loss, train_accuracy, val_loss, val_accuracy
    
    def do_epoch(self, current_epoch):
        """Trains model for one epoch.
        
        Args:
            current_epoch (int): current epoch number (begins from 1)
        Returns:
            train_loss: average training loss over all batches of the
                current epoch.
            train_accuracy: training accuracy of the current epoch over
                all samples.
        """

        # Set the current network in training mode
        self.net.train()
        if self.old_net is not None: 
          self.old_net.train(False)
          for i in range(len(self.old_nets)):
            self.old_nets[i] = self.old_nets[i].train(False)
        if self.best_net is not None: self.best_net.train(False)

        running_train_loss = 0
        running_corrects = 0
        total = 0
        batch_idx = 0

        print(f"Epoch: {current_epoch}, LR: {self.scheduler.get_last_lr()}")

        for images, labels in self.train_dataloader:
            loss, corrects = self.do_batch(images, labels)

            running_train_loss += loss.item()
            running_corrects += corrects
            total += labels.size(0)
            batch_idx += 1

        self.scheduler.step()

        # Calculate average scores
        train_loss = running_train_loss / batch_idx # Average over all batches
        train_accuracy = running_corrects / float(total) # Average over all samples

        print(f"Train loss: {train_loss}, Train accuracy: {train_accuracy}")

        return train_loss, train_accuracy

    def do_batch(self, batch, labels):
        """Train network for a batch. Loss is applied here.

        Args:
            batch: batch of data used for training the network
            labels: targets of the batch
        Returns:
            loss: output of the criterion applied
            running_corrects: number of correctly classified elements
        """
        batch = batch.to(self.device)
        labels = labels.to(self.device)

        # Zero-ing the gradients
        self.optimizer.zero_grad()
        
        # One-hot encoding of labels of the new training data (new classes)
        # Size: batch size (rows) by number of classes seen until now (columns)
        #
        # e.g., suppose we have four images in a batch, and each incremental
        #   step adds three new classes. At the second step, the one-hot
        #   encoding may return the following tensor:
        #
        #       tensor([[0., 0., 0., 1., 0., 0.],   # image 0 (label 3)
        #               [0., 0., 0., 0., 1., 0.],   # image 1 (label 4)
        #               [0., 0., 0., 0., 0., 1.],   # image 2 (label 5)
        #               [0., 0., 0., 0., 1., 0.]])  # image 3 (label 4)
        #
        #   The first three elements of each vector will always be 0, as the
        #   new training batch does not contain images belonging to classes
        #   already seen in previous steps.
        #
        #   The last three elements of each vector will contain the actual
        #   information about the class of each image (one-hot encoding of the
        #   label). Therefore, we slice the tensor and remove the columns 
        #   related to old classes (all zeros).
        num_classes = self.output_neurons_count() # Number of classes seen until now, including new classes
        one_hot_labels = self.to_onehot(labels)[:, num_classes-10:num_classes]

        if self.old_net is None:
            # Network is training for the first time, so we only apply the
            # classification loss.
            targets = one_hot_labels

        else:
            # Old net forward pass. We compute the outputs of the old network
            # and apply a sigmoid function. These are used in the distillation
            # loss. We discard the output of the new neurons, as they are not
            # considered in the distillation loss.
            old_net_outputs = sigmoid(self.old_net(batch))[:, :num_classes-10]


            # Concatenate the outputs of the old network and the one-hot encoded
            # labels along dimension 1 (columns).
            # 
            # Each row refers to an image in the training set, and contains:
            # - the output of the old network for that image, used by the
            #   distillation loss
            # - the one-hot label of the image, used by the classification loss
            targets = torch.cat((old_net_outputs, one_hot_labels), dim=1)

        # Forward pass
        outputs = self.net(batch)
        loss = self.criterion(outputs, targets)

        # Get predictions
        _, preds = torch.max(outputs.data, 1)

        # Accuracy over NEW IMAGES, not over all images
        running_corrects = torch.sum(preds == labels.data).data.item() 

        # Backward pass: computes gradients
        loss.backward()

        self.optimizer.step()

        return loss, running_corrects

    def validate(self):
        """Validate the model.
        
        Returns:
            val_loss: average loss function computed on the network outputs
                of the validation set (val_dataloader).
            val_accuracy: accuracy computed on the validation set.
        """

        self.net.train(False)
        if self.old_net is not None: self.old_net.train(False)
        if self.best_net is not None: self.best_net.train(False)

        running_val_loss = 0
        running_corrects = 0
        total = 0
        batch_idx = 0

        for images, labels in self.val_dataloader:
            images = images.to(self.device)
            labels = labels.to(self.device)
            total += labels.size(0)

            # One hot encoding of new task labels 
            one_hot_labels = self.to_onehot(labels)

            # New net forward pass
            outputs = self.net(images)  
            loss = self.criterion(outputs, one_hot_labels) # BCE Loss with sigmoids over outputs

            running_val_loss += loss.item()

            # Get predictions
            _, preds = torch.max(outputs.data, 1)

            # Update the number of correctly classified validation samples
            running_corrects += torch.sum(preds == labels.data).data.item()

            batch_idx += 1

        # Calculate scores
        val_loss = running_val_loss / batch_idx
        val_accuracy = running_corrects / float(total)

        print(f"Validation loss: {val_loss}, Validation accuracy: {val_accuracy}")

        return val_loss, val_accuracy

    def test(self, test_dataset, train_dataset=None):
        """Test the model.

        Args:
            test_dataset: dataset on which to test the network
            train_dataset: training set used to train the last split, if
                available
        Returns:
            accuracy (float): accuracy of the model on the test set
        """

        self.net.train(False)
        if self.best_net is not None: self.best_net.train(False)  # Set Network to evaluation mode
        if self.old_net is not None: self.old_net.train(False)

        self.test_dataloader = DataLoader(test_dataset, batch_size=self.BATCH_SIZE, shuffle=True, num_workers=4)

        running_corrects = 0
        total = 0

        # To store all predictions
        all_preds = torch.tensor([])
        all_preds = all_preds.type(torch.LongTensor)

        # Clear mean of exemplars cache
        self.cached_means = None
        
        # Disable transformations for train_dataset, if available, as we will
        # need original PIL images from which to extract features.
        if train_dataset is not None: train_dataset.dataset.disable_transform()

        for images, labels in self.test_dataloader:
            images = images.to(self.device)
            labels = labels.to(self.device)

            total += labels.size(0)
            
            with torch.no_grad():
                preds = self.classify(images, train_dataset)

            running_corrects += torch.sum(preds == labels.data).data.item()

            all_preds = torch.cat(
                (all_preds.to(self.device), preds.to(self.device)), dim=0
            )

        if train_dataset is not None: train_dataset.dataset.enable_transform()

        # Calculate accuracy
        accuracy = running_corrects / float(total)  

        print(f"Test accuracy (iCaRL): {accuracy} ", end="")

        if train_dataset is None:
            print("(only exemplars)")
        else:
            print("(exemplars and training data)")

        return accuracy, all_preds

    # @DEBUG
    def test_rand(self, test_dataset, train_dataset=None):
        """Test the model.

        Args:
            test_dataset: dataset on which to test the network
            train_dataset: training set used to train the last split, if
                available
        Returns:
            accuracy (float): accuracy of the model on the test set
        """

        self.net.train(False)
        if self.best_net is not None: self.best_net.train(False)  # Set Network to evaluation mode
        if self.old_net is not None: self.old_net.train(False)

        self.test_dataloader = DataLoader(test_dataset, batch_size=self.BATCH_SIZE, shuffle=True, num_workers=4)

        running_corrects = 0
        total = 0

        # To store all predictions
        all_preds = torch.tensor([])
        all_preds = all_preds.type(torch.LongTensor)

        # Clear mean of exemplars cache
        self.cached_means = None
        
        # Disable transformations for train_dataset, if available, as we will
        # need original PIL images from which to extract features.
        if train_dataset is not None: train_dataset.dataset.disable_transform()

        for images, labels in self.test_dataloader:
            images = images.to(self.device)
            labels = labels.to(self.device)

            total += labels.size(0)
            
            with torch.no_grad():
                preds = self.classify_rand(images, train_dataset)

            running_corrects += torch.sum(preds == labels.data).data.item()

            all_preds = torch.cat(
                (all_preds.to(self.device), preds.to(self.device)), dim=0
            )

        if train_dataset is not None: train_dataset.dataset.enable_transform()

        # Calculate accuracy
        accuracy = running_corrects / float(total)  

        print(f"Test accuracy (iCaRL): {accuracy} ", end="")

        if train_dataset is None:
            print("(only exemplars)")
        else:
            print("(exemplars and training data)")

        return accuracy, all_preds

    def test_without_classifier(self, test_dataset):
        """Test the model without classifier, using the outputs of the
        network instead.

        Args:
            test_dataset: dataset on which to test the network
        Returns:
            accuracy (float): accuracy of the model on the test set
        """

        self.net.train(False)
        if self.best_net is not None: self.best_net.train(False) # Set Network to evaluation mode
        if self.old_net is not None: self.old_net.train(False)

        self.test_dataloader = DataLoader(test_dataset, batch_size=self.BATCH_SIZE, shuffle=True, num_workers=4)

        running_corrects = 0
        total = 0

        all_preds = torch.tensor([]) # to store all predictions
        all_preds = all_preds.type(torch.LongTensor)
        
        for images, labels in self.test_dataloader:
            images = images.to(self.device)
            labels = labels.to(self.device)
            total += labels.size(0)

            # Forward Pass
            with torch.no_grad():
                if self.VALIDATE:
                    outputs = self.best_net(images)
                else:
                    outputs = self.net(images)

            # Get predictions
            _, preds = torch.max(outputs.data, 1)

            # Update Corrects
            running_corrects += torch.sum(preds == labels.data).data.item()

            # Append batch predictions
            all_preds = torch.cat(
                (all_preds.to(self.device), preds.to(self.device)), dim=0
            )

        # Calculate accuracy
        accuracy = running_corrects / float(total)  

        print(f"Test accuracy (hybrid1): {accuracy}")

        return accuracy, all_preds
    
    def increment_classes(self, n=10):
        """Add n classes in the final fully connected layer.

        ``self.net.fc`` is replaced by a new ``nn.Linear`` with ``n`` extra
        output units. The parameters learned for the existing classes (both
        the weight rows and, when the old layer has one, the bias entries)
        are copied into the new layer, so the outputs for the old classes
        are unchanged by the expansion. The parameters of the ``n`` new
        units keep the default ``nn.Linear`` initialization.

        Args:
            n (int): number of output units to add
        """

        old_fc = self.net.fc
        in_features = old_fc.in_features  # size of each input sample
        out_features = old_fc.out_features  # size of each output sample

        new_fc = nn.Linear(in_features, out_features + n)
        new_fc.weight.data[:out_features] = old_fc.weight.data

        # Carry over the learned bias as well; without this the bias of the
        # old classes would be silently re-initialized at every increment.
        if old_fc.bias is not None:
            new_fc.bias.data[:out_features] = old_fc.bias.data

        self.net.fc = new_fc
    
    def output_neurons_count(self):
        """Number of output units of the final fully connected layer (``self.net.fc``)."""

        classifier = self.net.fc
        return classifier.out_features
    
    def feature_neurons_count(self):
        """Number of input features of the final fully connected layer
        (``self.net.fc``), i.e. the size of the feature extractor's output."""

        classifier = self.net.fc
        return classifier.in_features
    
    def to_onehot(self, targets):
        """One-hot encode class indices (for BCE loss).

        Rows of an identity matrix are selected by ``targets``; the width of
        the encoding equals the current number of output units of
        ``self.net.fc``.

        Args:
            targets: dataloader.dataset.targets of the new task images
        Returns:
            the one-hot encoded targets, moved to ``self.device``
        """
        identity = torch.eye(self.net.fc.out_features)
        encoded = identity[targets]

        return encoded.to(self.device)

### Execution

In [12]:
# iCaRL hyperparameters (all of them are passed to the iCaRL constructor
# in the run cell below)

# Training length and batching
NUM_EPOCHS = 1
BATCH_SIZE = 64

# Optimizer settings
LR = 2
MOMENTUM = 0.9
WEIGHT_DECAY = 1e-5

# Learning rate schedule settings.
# NOTE(review): both milestones lie beyond NUM_EPOCHS, so if they are epoch
# indices of an LR decay schedule they are never reached with this
# configuration - confirm whether NUM_EPOCHS = 1 is a debugging value.
MILESTONES = [49, 63]
GAMMA = 0.2

In [19]:
# Evaluation protocols to run after every incremental training step
TEST_ICARL = True # Run test with iCaRL (exemplars + train dataset)
TEST_HYBRID1 = False # Run test with hybrid1

# One list per run; each list receives one log dictionary per split
logs_icarl = [[] for _ in range(NUM_RUNS)]
logs_icarl_rand = [[] for _ in range(NUM_RUNS)] # @DEBUG
logs_hybrid1 = [[] for _ in range(NUM_RUNS)]

for run_i in range(NUM_RUNS):
    # Every run starts from a fresh network and a fresh iCaRL instance
    net = resnet32()
    icarl = iCaRL(DEVICE, net, LR, MOMENTUM, WEIGHT_DECAY, MILESTONES, GAMMA, NUM_EPOCHS, BATCH_SIZE, train_transform, test_transform)

    for split_i in range(10):
        print(f"## Split {split_i} of run {run_i} ##")

        train_subset = train_subsets[run_i][split_i]
        val_subset = val_subsets[run_i][split_i]

        train_logs = icarl.incremental_train(split_i, train_subset, val_subset)

        test_subset = test_subsets[run_i][split_i]

        # Ground-truth labels of the test subset, read one sample at a time
        # in dataset order (default DataLoader: batch_size=1, no shuffling)
        all_targets = torch.stack([label[0] for _, label in DataLoader(test_subset)])

        if TEST_ICARL:
            icarl_entry = {}
            logs_icarl[run_i].append(icarl_entry)
            logs_icarl_rand[run_i].append({}) # @DEBUG

            print("Herding:")
            acc, all_preds = icarl.test(test_subset, train_subset)

            # print("Random:") # @DEBUG
            # acc_rand, _ = icarl.test_rand(test_subsets[run_i][split_i], train_subsets[run_i][split_i]) # @DEBUG

            # logs_icarl_rand[run_i][split_i]['accuracy'] = acc_rand # @DEBUG

            # NOTE(review): the confusion matrix pairs all_targets (dataset
            # order) with all_preds element by element; this is only
            # meaningful if icarl.test iterates the test set without
            # shuffling - confirm.
            icarl_entry['accuracy'] = acc
            icarl_entry['conf_mat'] = confusion_matrix(all_targets.to('cpu'), all_preds.to('cpu'))

            # Training curves returned by incremental_train
            icarl_entry['train_loss'] = train_logs[0]
            icarl_entry['train_accuracy'] = train_logs[1]
            icarl_entry['val_loss'] = train_logs[2]
            icarl_entry['val_accuracy'] = train_logs[3]

        if TEST_HYBRID1:
            hybrid1_entry = {}
            logs_hybrid1[run_i].append(hybrid1_entry)

            acc, all_preds = icarl.test_without_classifier(test_subset)

            hybrid1_entry['accuracy'] = acc

## Split 0 of run 0 ##
Length of exemplars set: 0
Epoch: 1, LR: [2]
Train loss: 0.3802311824900763, Train accuracy: 0.10647321428571428
Validation loss: 0.32623486859457834, Validation accuracy: 0.09598214285714286
Target number of exemplars per class: 200
Target total number of exemplars: 2000
Extracting exemplars from class 0 of current split... Extracted 200 exemplars.
Saved 10 exemplars for features analysis
Extracting exemplars from class 1 of current split... Extracted 200 exemplars.
Extracting exemplars from class 2 of current split... Extracted 200 exemplars.
Extracting exemplars from class 3 of current split... Extracted 200 exemplars.
Extracting exemplars from class 4 of current split... Extracted 200 exemplars.
Extracting exemplars from class 5 of current split... Extracted 200 exemplars.
Extracting exemplars from class 6 of current split... Extracted 200 exemplars.
Extracting exemplars from class 7 of current split... Extracted 200 exemplars.
Extracting exemplars from class

KeyboardInterrupt: ignored

In [20]:
# Keep a handle on the features stored by the iCaRL instance of the last run.
# NOTE(review): presumably filled during exemplar extraction (see the
# "Saved 10 exemplars for features analysis" log line) - confirm in iCaRL.
features_dictionary = icarl.examined_exemplars_features

In [21]:
# Bare expression: the whole dictionary is rendered as the cell output
features_dictionary

{1: [tensor([[0.0823, 0.1041, 0.1039, 0.1009, 0.0850, 0.0780, 0.0955, 0.0756, 0.0833,
           0.1196],
          [0.0747, 0.0923, 0.0935, 0.0950, 0.0799, 0.0725, 0.0876, 0.0653, 0.0779,
           0.1063],
          [0.0758, 0.1063, 0.1019, 0.1081, 0.0843, 0.0789, 0.0954, 0.0740, 0.0863,
           0.1166],
          [0.0800, 0.1028, 0.1018, 0.1009, 0.0840, 0.0764, 0.0946, 0.0738, 0.0828,
           0.1166],
          [0.0754, 0.1077, 0.1028, 0.1104, 0.0850, 0.0803, 0.0960, 0.0748, 0.0876,
           0.1179],
          [0.0756, 0.1078, 0.1026, 0.1100, 0.0848, 0.0797, 0.0961, 0.0747, 0.0874,
           0.1174],
          [0.0750, 0.1070, 0.1020, 0.1098, 0.0844, 0.0795, 0.0956, 0.0741, 0.0871,
           0.1167],
          [0.1210, 0.0953, 0.1095, 0.0828, 0.0931, 0.0841, 0.0944, 0.0769, 0.0734,
           0.1283],
          [0.0755, 0.1079, 0.1029, 0.1103, 0.0850, 0.0803, 0.0960, 0.0749, 0.0876,
           0.1180],
          [0.0759, 0.0979, 0.0970, 0.0991, 0.0816, 0.0745, 0.0909, 0.0

### Plots

In [None]:
def _collect_scores(key):
    """Return logs_icarl[run][split][key] for the 10 splits of every run as an array."""
    return np.array([
        [logs_icarl[run_i][i][key] for i in range(10)]
        for run_i in range(NUM_RUNS)
    ])


def _mean_std_over_runs(scores):
    """Stack mean and standard deviation over the run axis (axis 0), then transpose."""
    return np.array([scores.mean(0), scores.std(0)]).transpose()


# Raw scores, indexed by [run, split, ...]
train_loss = _collect_scores('train_loss')
train_accuracy = _collect_scores('train_accuracy')
val_loss = _collect_scores('val_loss')
val_accuracy = _collect_scores('val_accuracy')
test_accuracy = _collect_scores('accuracy')

# Mean and standard deviation across runs
train_loss_stats = _mean_std_over_runs(train_loss)
train_accuracy_stats = _mean_std_over_runs(train_accuracy)
val_loss_stats = _mean_std_over_runs(val_loss)
val_accuracy_stats = _mean_std_over_runs(val_accuracy)
test_accuracy_stats = _mean_std_over_runs(test_accuracy)

In [None]:
# Plot the training/validation loss and accuracy from the (mean, std)
# statistics computed above.
# NOTE(review): `plot` is presumably a project helper module; its import is
# not visible in this part of the notebook - confirm.
plot.train_val_scores(train_loss_stats, train_accuracy_stats, val_loss_stats, val_accuracy_stats)

In [None]:
# Plot the test accuracy from the (mean, std) statistics computed above
plot.test_scores(test_accuracy_stats)