***
## Learning without Forgetting - LwF

### Benchmark: Split MNIST

This experiment reproduces the **Learning without Forgetting** (LwF) method, a hybrid of Distillation Networks and fine-tuning. Fine-tuning here means re-training an already trained model M, at a low learning rate, on a new and more specific dataset D<sub>new</sub> that differs from the dataset D<sub>old</sub> on which M was originally trained.

Unlike rehearsal-based continual learning techniques, LwF uses only the new data: it assumes that the data used to pre-train the network is no longer available.  
It is a *transfer learning technique*.
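
To make the combination concrete, the objective can be sketched as a cross-entropy term on the new data plus a distillation term that keeps the outputs for the old classes close to those of a frozen copy of the previous model. The sketch below is plain Python for illustration only, not Avalanche's implementation; the function names are hypothetical, and `alpha` and `temperature` are assumed to play the roles of the `lwf_alpha` and `lwf_temperature` settings used in this experiment.

```python
import math

def softmax(logits, temperature=1.0):
    """Softmax over a list of logits, softened by the given temperature."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(logits, target):
    """Standard classification loss for a single sample."""
    return -math.log(softmax(logits)[target])

def distillation_loss(new_logits, old_logits, temperature=1.0):
    """Cross-entropy between the softened outputs of the old and new model."""
    p_old = softmax(old_logits, temperature)
    p_new = softmax(new_logits, temperature)
    return -sum(po * math.log(pn) for po, pn in zip(p_old, p_new))

def lwf_loss(new_logits, old_logits, target, alpha=1.0, temperature=1.0):
    """Hypothetical per-sample LwF objective: new-task loss + alpha * distillation.

    old_logits are the outputs of the frozen previous model on the same input,
    covering only the classes seen so far (assumed to be the first units).
    """
    n_old = len(old_logits)
    ce = cross_entropy(new_logits, target)
    kd = distillation_loss(new_logits[:n_old], old_logits, temperature)
    return ce + alpha * kd

# Example: two old classes, one new class, sample labelled with the new class.
loss = lwf_loss(new_logits=[0.5, -0.2, 1.3], old_logits=[0.8, -0.1], target=2)
```

The distillation term is smallest when the new model reproduces the old model's outputs on the old classes, which is what lets LwF limit forgetting without access to D<sub>old</sub>.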

`References:`
- Learning without Forgetting: https://arxiv.org/abs/1606.09282
- Three scenarios for continual learning: https://arxiv.org/abs/1904.07734

**Avalanche LwF strategy:**
https://avalanche-api.continualai.org/en/v0.1.0/generated/avalanche.training.LwF.html
***

In [1]:
# Testing framework and test runner
import unittest

import torch
from torch.nn import CrossEntropyLoss
from torch.optim import Adam, SGD

# Avalanche library from ContinualAI
import avalanche

# Models and benchmarks
from avalanche.models import SimpleMLP
from avalanche.benchmarks.classic import SplitMNIST

# Loggers
from avalanche.logging import InteractiveLogger, TensorboardLogger, TextLogger

# Evaluation
from avalanche.training.plugins import EvaluationPlugin
from avalanche.evaluation.metrics import accuracy_metrics, timing_metrics, forgetting_metrics, confusion_matrix_metrics, StreamConfusionMatrix

# Extras: Model and utils
from utils import arguments, get_average_metric, get_target_result

## LwF technique
class LwF(unittest.TestCase): #TestCase class

    #### Split MNIST benchmark
    def test_smnist(self, override_args=None):
        
        scenario = SplitMNIST(n_experiences=5, return_task_id=False, fixed_class_order=[0,1,2,3,4,5,6,7,8,9])
               
        # --- Strategy instantiation --- # 
        # 1. Model
        # 2. Optimizer
        # 3. Loss function
        
        # -> ADDITIONAL ARGUMENTS allow customizing training
        args = arguments({ 'cuda': 0,              # GPU index (a negative value forces CPU)
                           'lwf_alpha': 1,         # Penalty hyperparameter for LwF
                           'lwf_temperature': 1,   # Temperature for softmax used in distillation
                           'learning_rate': 0.001, # Learning rate
                           'train_epochs': 100,    # Training epochs
                           'train_mb_size': 128}, override_args) # Train minibatch size

        # Select the compute device: use the requested GPU if CUDA is available,
        # otherwise fall back to the CPU.
        device = torch.device(f"cuda:{args.cuda}"
                              if torch.cuda.is_available() and args.cuda >= 0 
                              else "cpu")
        
        model = SimpleMLP(num_classes=10, input_size=28*28, hidden_size=256, hidden_layers=1, drop_rate=0)

        optimizer = SGD(model.parameters(), lr=args.learning_rate)
        criterion = CrossEntropyLoss()
        
        # ------------------------ LOG ------------------------ #
        # log results over time to monitor the experiment in real time
        loggers = []
        
        # log to Tensorboard
        loggers.append(TensorboardLogger())
        
        # log to a text file (out.txt) through TextLogger
        loggers.append(TextLogger(open('out.txt','w')))
        
        # log to InteractiveLogger, displays a progress bar during training and evaluation
        loggers.append(InteractiveLogger())
        
        # -------------------- EVALUATION -------------------- #
        # Metrics of main interest to be tracked
        eval_plugin = EvaluationPlugin(
            accuracy_metrics(experience=True, stream=True, trained_experience=True),
            #confusion_matrix_metrics(num_classes=scenario.n_classes, save_image=True, stream=True),
            StreamConfusionMatrix(),
            #timing_metrics(epoch=True),
            #forgetting_metrics(experience=True, stream=True),
            loggers=loggers,
            benchmark=scenario
        )

        # -> CONTINUAL LEARNING STRATEGY: LwF
        cl_strategy = avalanche.training.LwF(model, optimizer, criterion,
                                             # additional arguments
                                             alpha=args.lwf_alpha, 
                                             temperature=args.lwf_temperature,
                                             train_mb_size=args.train_mb_size, 
                                             train_epochs=args.train_epochs,
                                             device=device,
                                             # evaluation
                                             evaluator=eval_plugin,
        )
        
        # --- Training loop --- #
        print('Starting experiment...')
        
        for experience in scenario.train_stream:
            print('Current experience {}, contains: {} patterns'.format(experience.current_experience, len(experience.dataset)))
            print('Current classes: ',experience.classes_in_this_experience)
            
            # Train
            cl_strategy.train(experience)
            print('Training completed')
            
            # Accuracy over the whole test set (no access to task-ID at inference time)
            print('Computing accuracy over the whole test set')
            res = cl_strategy.eval(scenario.test_stream)
            
        
        avg_stream_acc = get_average_metric(res)
        print(f"LwF-SplitMNIST Average Stream Accuracy: {avg_stream_acc:.2f}")

### Run and Evaluate the experiment
- Create an instance of the `LwF` experiment class defined above
- Execute the strategy on a benchmark

In [2]:
# Create the experiment object (a unittest.TestCase that wraps the LwF strategy)
s = LwF()

# Run the experiment with the default parameters
s.test_smnist()


Starting experiment...
Current experience 0, contains: 12665 patterns
Current classes:  [0, 1]
-- >> Start of training phase << --
-- Starting training on experience 0 (Task 0) from train stream --
100%|██████████| 99/99 [00:03<00:00, 30.58it/s]
Epoch 0 ended.
100%|██████████| 99/99 [00:03<00:00, 30.48it/s]
Epoch 1 ended.
100%|██████████| 99/99 [00:03<00:00, 30.14it/s]
Epoch 2 ended.
100%|██████████| 99/99 [00:03<00:00, 30.61it/s]
Epoch 3 ended.
100%|██████████| 99/99 [00:03<00:00, 30.31it/s]
Epoch 4 ended.
100%|██████████| 99/99 [00:03<00:00, 29.58it/s]
Epoch 5 ended.
100%|██████████| 99/99 [00:03<00:00, 30.63it/s]
Epoch 6 ended.
100%|██████████| 99/99 [00:03<00:00, 30.69it/s]
Epoch 7 ended.
100%|██████████| 99/99 [00:03<00:00, 30.61it/s]
Epoch 8 ended.
100%|██████████| 99/99 [00:03<00:00, 30.72it/s]
Epoch 9 ended.
100%|██████████| 99/99 [00:03<00:00, 30.52it/s]
Epoch 10 ended.
100%|██████████| 99/99 [00:03<00:00, 30.65it/s]
Epoch 11 ended.
100%|██████████| 99/99 [00:03<00:00, 30.70it

-- Starting training on experience 1 (Task 0) from train stream --
100%|██████████| 95/95 [00:03<00:00, 29.92it/s]
Epoch 0 ended.
100%|██████████| 95/95 [00:03<00:00, 30.04it/s]
Epoch 1 ended.
100%|██████████| 95/95 [00:03<00:00, 29.98it/s]
Epoch 2 ended.
100%|██████████| 95/95 [00:03<00:00, 30.00it/s]
Epoch 3 ended.
100%|██████████| 95/95 [00:03<00:00, 29.88it/s]
Epoch 4 ended.
100%|██████████| 95/95 [00:03<00:00, 29.63it/s]
Epoch 5 ended.
100%|██████████| 95/95 [00:03<00:00, 29.67it/s]
Epoch 6 ended.
100%|██████████| 95/95 [00:03<00:00, 30.00it/s]
Epoch 7 ended.
100%|██████████| 95/95 [00:03<00:00, 30.02it/s]
Epoch 8 ended.
100%|██████████| 95/95 [00:03<00:00, 29.63it/s]
Epoch 9 ended.
100%|██████████| 95/95 [00:03<00:00, 29.50it/s]
Epoch 10 ended.
100%|██████████| 95/95 [00:03<00:00, 29.73it/s]
Epoch 11 ended.
100%|██████████| 95/95 [00:03<00:00, 29.76it/s]
Epoch 12 ended.
100%|██████████| 95/95 [00:03<00:00, 29.83it/s]
Epoch 13 ended.
100%|██████████| 95/95 [00:03<00:00, 29.70it/s]

100%|██████████| 88/88 [00:02<00:00, 30.03it/s]
Epoch 0 ended.
100%|██████████| 88/88 [00:02<00:00, 30.17it/s]
Epoch 1 ended.
100%|██████████| 88/88 [00:02<00:00, 30.19it/s]
Epoch 2 ended.
100%|██████████| 88/88 [00:02<00:00, 30.26it/s]
Epoch 3 ended.
100%|██████████| 88/88 [00:02<00:00, 30.15it/s]
Epoch 4 ended.
100%|██████████| 88/88 [00:02<00:00, 30.20it/s]
Epoch 5 ended.
100%|██████████| 88/88 [00:02<00:00, 30.23it/s]
Epoch 6 ended.
100%|██████████| 88/88 [00:02<00:00, 30.18it/s]
Epoch 7 ended.
100%|██████████| 88/88 [00:02<00:00, 30.21it/s]
Epoch 8 ended.
100%|██████████| 88/88 [00:02<00:00, 29.90it/s]
Epoch 9 ended.
100%|██████████| 88/88 [00:02<00:00, 30.17it/s]
Epoch 10 ended.
100%|██████████| 88/88 [00:02<00:00, 30.17it/s]
Epoch 11 ended.
100%|██████████| 88/88 [00:02<00:00, 30.19it/s]
Epoch 12 ended.
100%|██████████| 88/88 [00:02<00:00, 29.85it/s]
Epoch 13 ended.
100%|██████████| 88/88 [00:02<00:00, 29.87it/s]
Epoch 14 ended.
100%|██████████| 88/88 [00:02<00:00, 30.04it/s]
Ep

100%|██████████| 96/96 [00:03<00:00, 30.05it/s]
Epoch 1 ended.
100%|██████████| 96/96 [00:03<00:00, 30.45it/s]
Epoch 2 ended.
100%|██████████| 96/96 [00:03<00:00, 30.43it/s]
Epoch 3 ended.
100%|██████████| 96/96 [00:03<00:00, 30.31it/s]
Epoch 4 ended.
100%|██████████| 96/96 [00:03<00:00, 30.32it/s]
Epoch 5 ended.
100%|██████████| 96/96 [00:03<00:00, 30.35it/s]
Epoch 6 ended.
100%|██████████| 96/96 [00:03<00:00, 30.45it/s]
Epoch 7 ended.
100%|██████████| 96/96 [00:03<00:00, 29.48it/s]
Epoch 8 ended.
100%|██████████| 96/96 [00:03<00:00, 30.01it/s]
Epoch 9 ended.
100%|██████████| 96/96 [00:03<00:00, 30.24it/s]
Epoch 10 ended.
100%|██████████| 96/96 [00:03<00:00, 30.28it/s]
Epoch 11 ended.
100%|██████████| 96/96 [00:03<00:00, 30.29it/s]
Epoch 12 ended.
100%|██████████| 96/96 [00:03<00:00, 30.05it/s]
Epoch 13 ended.
100%|██████████| 96/96 [00:03<00:00, 30.26it/s]
Epoch 14 ended.
100%|██████████| 96/96 [00:03<00:00, 30.27it/s]
Epoch 15 ended.
100%|██████████| 96/96 [00:03<00:00, 30.32it/s]
E

100%|██████████| 93/93 [00:03<00:00, 29.97it/s]
Epoch 2 ended.
100%|██████████| 93/93 [00:03<00:00, 30.14it/s]
Epoch 3 ended.
100%|██████████| 93/93 [00:03<00:00, 30.22it/s]
Epoch 4 ended.
100%|██████████| 93/93 [00:03<00:00, 30.33it/s]
Epoch 5 ended.
100%|██████████| 93/93 [00:03<00:00, 29.95it/s]
Epoch 6 ended.
100%|██████████| 93/93 [00:03<00:00, 30.19it/s]
Epoch 7 ended.
100%|██████████| 93/93 [00:03<00:00, 30.34it/s]
Epoch 8 ended.
100%|██████████| 93/93 [00:03<00:00, 30.29it/s]
Epoch 9 ended.
100%|██████████| 93/93 [00:03<00:00, 30.27it/s]
Epoch 10 ended.
100%|██████████| 93/93 [00:03<00:00, 30.37it/s]
Epoch 11 ended.
100%|██████████| 93/93 [00:03<00:00, 30.24it/s]
Epoch 12 ended.
100%|██████████| 93/93 [00:03<00:00, 30.22it/s]
Epoch 13 ended.
100%|██████████| 93/93 [00:03<00:00, 30.23it/s]
Epoch 14 ended.
100%|██████████| 93/93 [00:03<00:00, 30.24it/s]
Epoch 15 ended.
100%|██████████| 93/93 [00:03<00:00, 30.15it/s]
Epoch 16 ended.
100%|██████████| 93/93 [00:03<00:00, 30.17it/s]
