***
## AR1

### Benchmark: Split CIFAR10

AR1 is a continual learning approach that combines architectural and regularization strategies.  
It builds on the CWR architectural technique with two modifications (mean-shift and zero initialization) and extends it by allowing weights to be tuned across batches, subject to a regularization constraint (as in Synaptic Intelligence).

AR1 computational overhead is the sum of modified CWR and SI overhead.  
Given its low computational overhead and the fact that Stochastic Gradient Descent is typically stopped early after 2 epochs, AR1 is suitable for online implementations.

`References:`
- Three scenarios for continual learning: https://arxiv.org/abs/1904.07734

**Avalanche AR1 strategy:**
https://avalanche-api.continualai.org/en/v0.1.0/generated/avalanche.training.AR1.html

AR1 **does not** accept a model or an optimizer; it uses its own (MobileNet with SGD).
***

In [1]:
# Testing framework and test runner
import unittest

import torch
from torchvision import transforms
from torchvision.transforms import ToTensor, Resize
from torch.nn import CrossEntropyLoss

# Avalanche library from ContinualAI
import avalanche

from avalanche.benchmarks import SplitCIFAR10
from avalanche.training import AR1

# Loggers
from avalanche.logging import InteractiveLogger, TensorboardLogger

# Evaluation
from avalanche.training.plugins import EvaluationPlugin
from avalanche.evaluation.metrics import accuracy_metrics, timing_metrics, forgetting_metrics

# Extras: Model and utils
from utils import arguments, get_average_metric, get_target_result

## AR1 technique
class AR1(unittest.TestCase):
    """Run Avalanche's AR1 strategy on the Split CIFAR10 benchmark.

    NOTE: this class keeps the name ``AR1`` because callers instantiate it as
    ``AR1()``. That name shadows the ``AR1`` strategy imported from
    ``avalanche.training``, so the strategy is always referenced through its
    fully-qualified name ``avalanche.training.AR1`` inside this class.
    """

    #### Split CIFAR10 benchmark
    def test_scifar10(self, override_args=None):
        """Train AR1 on each Split CIFAR10 experience and report accuracy.

        After every training experience the strategy is evaluated on the whole
        test stream; the metrics of the last evaluation are summarized with
        ``get_average_metric`` and printed.

        Args:
            override_args: optional overrides for the default hyperparameters
                declared below. They are handed to ``arguments`` together with
                the defaults (presumably a mapping of name -> value; verify
                against ``utils.arguments``).
        """

        # --- TRANSFORMATIONS
        # Train and eval use the same preprocessing steps.
        # NOTE(review): (0.1307,)/(0.3081,) are the statistics commonly used
        # for MNIST; presumably carried over from another example. Confirm
        # whether CIFAR10 or ImageNet statistics are intended here.
        preprocessing_steps = [
            Resize(224),
            ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ]
        train_transform = transforms.Compose(list(preprocessing_steps))
        test_transform = transforms.Compose(list(preprocessing_steps))
        # ---------

        # --- SCENARIO
        scenario = SplitCIFAR10(
            n_experiences=5,
            train_transform=train_transform,
            eval_transform=test_transform,
        )
        # ---------

        # --- ADDITIONAL ARGUMENTS allow to customize training
        # Every entry below is forwarded to the strategy, so values supplied
        # through ``override_args`` actually take effect.
        args = arguments({ 'cuda': 0,              # GPU
                           'learning_rate': 0.001,   # Learning rate (SGD optimizer)
                           'L2': 0.0005,           # L2 penalty used for weight decay
                           'rm_sz': 1500,          # Size of replay buffer, shared across classes
                           'latent_layer_num': 19, # Number of layer to use as the Latent Replay Layer
                           'ewc_lambda': 0,        # Synaptic Intelligence lambda term (0 = no SI regularization)
                           'train_epochs' : 4,     # Training epochs
                           'eval_mb_size' : 128,   # Eval minibatch size
                           'train_mb_size': 128}, override_args) # Train minibatch size

        # Use the requested GPU when CUDA is available, otherwise fall back
        # to the CPU (a negative ``cuda`` value also forces the CPU).
        device = torch.device(f"cuda:{args.cuda}"
                              if torch.cuda.is_available() and args.cuda >= 0
                              else "cpu")

        # --- LOG
        # Log to Tensorboard and show progress on the console during
        # training and evaluation.
        loggers = [TensorboardLogger(), InteractiveLogger()]
        # ---------

        # --- EVALUATION
        # Metrics of main interest to be tracked
        eval_plugin = EvaluationPlugin(
            accuracy_metrics(minibatch=True, epoch=True, experience=True, stream=True),
            loggers=loggers,
            benchmark=scenario
        )
        # ---------

        # --- CONTINUAL LEARNING STRATEGY: AR1
        # The fully-qualified name is required: inside this module the bare
        # name ``AR1`` refers to this TestCase class.
        # NOTE(review): ``freeze_below_layer`` is left at the library default;
        # confirm it is still appropriate when ``latent_layer_num`` is
        # overridden.
        cl_strategy = avalanche.training.AR1(
            criterion=CrossEntropyLoss(),
            lr=args.learning_rate,
            l2=args.L2,
            rm_sz=args.rm_sz,
            latent_layer_num=args.latent_layer_num,
            ewc_lambda=args.ewc_lambda,
            train_epochs=args.train_epochs,
            train_mb_size=args.train_mb_size,
            eval_mb_size=args.eval_mb_size,
            device=device,
            evaluator=eval_plugin,
        )
        # ---------

        # --- Training loop --- #
        print('Starting experiment...')

        # Metrics returned by the most recent evaluation; stays None only if
        # the train stream yields no experience.
        res = None
        for experience in scenario.train_stream:
            print('Current experience {}, contains: {} patterns'.format(experience.current_experience, len(experience.dataset)))
            print('Current classes: ',experience.classes_in_this_experience)

            # Train
            cl_strategy.train(experience)
            print('Training completed')

            # Accuracy over the whole test set (no access to task-ID at inference time)
            print('Computing accuracy over the whole test set')
            res = cl_strategy.eval(scenario.test_stream)

        # Fail with a clear message instead of passing None downstream.
        self.assertIsNotNone(res, "train stream produced no experiences")

        avg_stream_acc = get_average_metric(res)
        print(f"AR1-SplitCIFAR10 Average Stream Accuracy: {avg_stream_acc:.2f}")

### Run and Evaluate the experiment
- Create an instance of the test-case class that wraps the AR1 strategy
- Execute the strategy on a benchmark

In [3]:
# Instantiate the unittest-based experiment wrapper defined above
# (this AR1 is the TestCase class, not avalanche.training.AR1)
s = AR1()

# Run the Split CIFAR10 experiment; no override_args are passed,
# so the defaults declared inside test_scifar10 are used
s.test_scifar10()


Files already downloaded and verified
Files already downloaded and verified
Starting experiment...
Current experience 0, contains: 10000 patterns
Current classes:  [2, 5]
-- >> Start of training phase << --
-- Starting training on experience 0 (Task 0) from train stream --


KeyboardInterrupt: 