***
## Replay-based approach
### (Rehearsal method)

Experiment reproducing the **Replay-based** approach for Continual Learning:  
Replay methods store a small subset of samples from previous sub-tasks and use it for rehearsal, i.e. for retraining on data from old sub-tasks. The stored samples are trained together with the samples of the current sub-task.

Basic approach:
 - Sample randomly from the current experience data
 - Fill fixed Random Memory
 - Replace examples randomly to maintain an approximately equal number of examples per experience
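
The three steps above can be sketched without any library as reservoir sampling over a stream of experiences. This is a minimal illustration, not Avalanche's implementation; the function name `reservoir_update` and the toy `(experience_id, index)` samples are assumptions made for this example.

```python
import random

def reservoir_update(memory, stream, max_size, seen=0, rng=random):
    """Update `memory` in place with items from `stream` (reservoir sampling).

    `seen` is the number of items observed before this call.
    Returns the updated count of observed items.
    """
    for item in stream:
        seen += 1
        if len(memory) < max_size:
            # The memory is not full yet: just store the item.
            memory.append(item)
        else:
            # Keep the new item with probability max_size / seen,
            # overwriting a uniformly chosen slot.
            j = rng.randrange(seen)
            if j < max_size:
                memory[j] = item
    return seen

memory, seen = [], 0
for exp_id in range(3):
    # Hypothetical experience: 100 samples tagged with their experience id.
    experience = [(exp_id, i) for i in range(100)]
    seen = reservoir_update(memory, experience, max_size=30, seen=seen)

print(len(memory), seen)
```

Because every observed sample ends up in the memory with the same probability, the share of each experience in the buffer stays approximately equal in expectation when experiences have similar sizes.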

`References:`
- ...Replay-based approaches for CL: https://arxiv.org/abs/2108.06758 
- Replay strategies: https://course.continualai.org/lectures/strategies
***

## Dataloader
**Dataloaders** are used to balance mini-batches across groups (e.g. tasks/classes/experiences). This is especially useful with unbalanced data.

*GroupBalancedDataLoader* takes a sequence of datasets and iterates over them by providing balanced mini-batches, where the number of samples is split equally among groups.

In [1]:
import avalanche

from avalanche.benchmarks import SplitMNIST
from avalanche.benchmarks.utils.data_loader import GroupBalancedDataLoader

benchmark = SplitMNIST(n_experiences=5, return_task_id=True)

dl = GroupBalancedDataLoader([exp.dataset for exp in benchmark.train_stream], batch_size=4)
for x, y, t in dl:
    print(t.tolist())
    break

[0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4]
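
To make the balancing behaviour concrete, here is a library-free sketch that builds one mini-batch by drawing the same number of samples from every group. It mirrors the output above, where each group contributes four samples, but it is only an illustration: `group_balanced_batches` is a hypothetical helper, and cycling over short groups is an assumption of this sketch.

```python
from itertools import cycle

def group_balanced_batches(groups, per_group, n_batches):
    """Yield batches holding `per_group` samples from each group.

    Groups shorter than the others are cycled, so they are oversampled.
    """
    iterators = [cycle(g) for g in groups]
    for _ in range(n_batches):
        batch = []
        for it in iterators:
            batch.extend(next(it) for _ in range(per_group))
        yield batch

# Three unbalanced groups of (sample, group_id) pairs with 10, 3 and 6 items.
groups = [[(f"x{g}_{i}", g) for i in range(n)] for g, n in enumerate([10, 3, 6])]

first = next(group_balanced_batches(groups, per_group=4, n_batches=1))
group_ids = [g for _, g in first]
print(group_ids)  # [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
```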


## Memory Buffers
**Memory buffers** store data up to a maximum capacity, and they implement policies to select which data to store and which to remove when the buffer is full.

The base class is the *ExemplarsBuffer*, which implements two methods:
 - update(strategy): given the strategy's state, it updates the buffer
 - resize(strategy, new_size): updates the maximum size and updates the buffer accordingly
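
The two-method interface can be illustrated with a toy stand-in. This sketch is not a subclass of *ExemplarsBuffer*: `ToyBuffer` is a hypothetical class, and for simplicity its `update` takes the new samples directly instead of the strategy's state. Its selection policy is plain random subsampling of old plus new samples, not reservoir sampling.

```python
import random

class ToyBuffer:
    """Hypothetical buffer mimicking the update/resize interface."""

    def __init__(self, max_size, seed=0):
        self.max_size = max_size
        self.buffer = []
        self._rng = random.Random(seed)

    def update(self, new_samples):
        # Merge old and new samples, then keep a random subset that fits.
        pool = self.buffer + list(new_samples)
        self._rng.shuffle(pool)
        self.buffer = pool[: self.max_size]

    def resize(self, new_size):
        # Change the capacity and drop samples that no longer fit.
        self.max_size = new_size
        self.buffer = self.buffer[:new_size]

buf = ToyBuffer(max_size=5)
sizes = []
buf.update(range(3))        # buffer not full yet
sizes.append(len(buf.buffer))
buf.update(range(10, 20))   # buffer full: some samples are discarded
sizes.append(len(buf.buffer))
buf.resize(2)               # shrinking also shrinks the stored data
sizes.append(len(buf.buffer))
print(sizes)  # [3, 5, 2]
```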

In [2]:
from avalanche.training.storage_policy import ReservoirSamplingBuffer
from types import SimpleNamespace

benchmark = SplitMNIST(5, return_task_id=False)
storage_p = ReservoirSamplingBuffer(max_size=30) # substitutes samples randomly

# At first, the buffer is empty.
print(f"Max buffer size: {storage_p.max_size}, current size: {len(storage_p.buffer)} \n")


for i in range(5):
    # After each update some samples are substituted with new data
    # Reservoir sampling selects these samples randomly
    strategy_state = SimpleNamespace(experience=benchmark.train_stream[i])
    storage_p.update(strategy_state)
    
    print(f"Max buffer size: {storage_p.max_size}, current size: {len(storage_p.buffer)}")
    print(f"class targets: {storage_p.buffer.targets}\n")

Max buffer size: 30, current size: 0 

Max buffer size: 30, current size: 30
class targets: [0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1]

Max buffer size: 30, current size: 30
class targets: [6, 0, 0, 6, 6, 6, 1, 2, 6, 0, 0, 0, 0, 6, 1, 0, 2, 1, 1, 0, 2, 6, 6, 2, 2, 6, 1, 2, 1, 2]

Max buffer size: 30, current size: 30
class targets: [6, 0, 9, 0, 6, 9, 6, 9, 6, 1, 2, 9, 3, 9, 6, 0, 0, 0, 3, 3, 0, 9, 6, 1, 0, 2, 3, 3, 1, 9]

Max buffer size: 30, current size: 30
class targets: [6, 0, 9, 0, 6, 9, 8, 6, 9, 6, 7, 1, 2, 9, 7, 3, 9, 7, 6, 0, 8, 0, 8, 8, 0, 3, 3, 0, 9, 8]

Max buffer size: 30, current size: 30
class targets: [4, 6, 0, 9, 0, 5, 6, 9, 8, 5, 6, 9, 6, 7, 1, 2, 9, 7, 5, 3, 9, 7, 6, 5, 0, 8, 0, 8, 8, 0]



## Replay plugin
The **Replay plugin** updates the rehearsal buffer after each experience and sets the dataloader that mixes buffered samples with the current data.

In [3]:
from avalanche.benchmarks.utils.data_loader import ReplayDataLoader
from avalanche.training.plugins import StrategyPlugin

class CustomReplay(StrategyPlugin):
    def __init__(self, storage_policy):
        super().__init__()
        self.storage_policy = storage_policy

    def before_training_exp(self, strategy,
                            num_workers: int = 0, shuffle: bool = True,
                            **kwargs):
        """ Here we set the dataloader. """
        if len(self.storage_policy.buffer) == 0:
            # first experience. We don't use the buffer, no need to change
            # the dataloader.
            return

        # replay dataloader samples mini-batches from the memory and current
        # data separately and combines them together.
        print("Override the dataloader.")
        strategy.dataloader = ReplayDataLoader(
            strategy.adapted_dataset,
            self.storage_policy.buffer,
            oversample_small_tasks=True,
            num_workers=num_workers,
            batch_size=strategy.train_mb_size,
            shuffle=shuffle)

    def after_training_exp(self, strategy: "BaseStrategy", **kwargs):
        """ We update the buffer after the experience.
            You can use a different callback to update the buffer in a different place
        """
        print("Buffer update.")
        self.storage_policy.update(strategy, **kwargs)
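
The way a replay dataloader combines current and memory samples can be sketched in plain Python. This is a simplified illustration, not the code of `ReplayDataLoader`: the helper `replay_batches` is hypothetical, and the half-and-half split between current and memory samples is an assumption of this sketch rather than the ratio Avalanche necessarily uses.

```python
from itertools import cycle

def replay_batches(current, memory, batch_size):
    """Yield mini-batches made of current samples followed by memory samples.

    The memory is usually much smaller than the current data, so it is
    cycled (oversampled) until the current data is exhausted.
    """
    half = batch_size // 2
    mem_iter = cycle(memory)
    for start in range(0, len(current), half):
        cur_part = current[start:start + half]
        mem_part = [next(mem_iter) for _ in range(half)]
        yield cur_part + mem_part

# Toy data: 8 current samples and a memory of only 3 samples.
current = [("cur", i) for i in range(8)]
memory = [("mem", i) for i in range(3)]

batches = list(replay_batches(current, memory, batch_size=4))
for b in batches:
    print([src for src, _ in b])
```

Every mini-batch contains old and new samples, which is what lets the model rehearse previous classes while it learns the current ones.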

Use the plugin to train the CL model

In [4]:
# Testing framework and test runner
import unittest

import torch
from torch.nn import CrossEntropyLoss
from torch.optim import Adam, SGD
from torchvision.transforms import Compose, ToTensor, Normalize, RandomCrop

# Avalanche library from ContinualAI
import avalanche

from avalanche.models import SimpleMLP
from avalanche.benchmarks.datasets import MNIST
from avalanche.benchmarks.generators import nc_benchmark

from avalanche.training import Naive
from avalanche.training.plugins import ReplayPlugin
from avalanche.training import ParametricBuffer
from avalanche.training import RandomExemplarsSelectionStrategy

# Loggers
from avalanche.logging import InteractiveLogger, TensorboardLogger

# Evaluation
from avalanche.training.plugins import EvaluationPlugin
from avalanche.evaluation.metrics import accuracy_metrics, timing_metrics, forgetting_metrics

# utils
from utils import arguments, get_average_metric, get_target_result

## Replay-based method
class Replay(unittest.TestCase): #TestCase class

    #### Split MNIST benchmark
    def test_smnist(self, override_args=None):
        
        train_transform = Compose([
            RandomCrop(28, padding=4),
            ToTensor(),
            Normalize((0.1307,), (0.3081,)) ])

        test_transform = Compose([
            ToTensor(),
            Normalize((0.1307,), (0.3081,)) ])

        mnist_train = MNIST('./data/mnist', train=True, download=True, transform=train_transform)

        mnist_test = MNIST('./data/mnist', train=False, download=True, transform=test_transform)
        
        
        # -> BENCHMARK based on MNIST and Class-Incremental learning: Split-MNIST benchmark
        scenario = nc_benchmark(train_dataset=mnist_train,
                                test_dataset =mnist_test,
                                n_experiences=5,
                                task_labels = False)
               
        # --- Strategy instantiation --- # 
        # 1. Model
        # 2. Optimizer
        # 3. Loss function
        
        # -> ADDITIONAL ARGUMENTS allow to customize training
        args = arguments({ 'cuda': 0,              # GPU device index
                           'epochs': 4,            # Training epochs
                           'learning_rate': 0.001, # Learning rate
                           'eval_mb_size':  100,   # Evaluation minibatch size
                           'train_mb_size': 100},  # Training minibatch size
                         override_args)            # Optional overrides of the defaults

        # Use the selected GPU if CUDA is available,
        # otherwise fall back to the CPU.
        device = torch.device(f"cuda:{args.cuda}"
                              if torch.cuda.is_available() and args.cuda >= 0 
                              else "cpu")
        
        model = SimpleMLP(num_classes=10, input_size=28*28, hidden_size=256, hidden_layers=1, drop_rate=0)
        optimizer = SGD(model.parameters(), lr=args.learning_rate)
        criterion = CrossEntropyLoss()
        
        storage_p = ParametricBuffer(max_size=500, groupby='class',
                                     selection_strategy=RandomExemplarsSelectionStrategy())
        
        # ------------------------ LOG ------------------------ #
        # logging results over-time to examine the experiment in real-time
        loggers = []
        
        # log to Tensorboard
        loggers.append(TensorboardLogger())
        
        # Avalanche logging module, displays a progress bar during training and evaluation
        interactive_logger = avalanche.logging.InteractiveLogger()
        
        # -------------------- EVALUATION -------------------- #
        # Metrics of main interest to be tracked
        eval_plugin = EvaluationPlugin(
            accuracy_metrics(minibatch=True, epoch=True, experience=True, stream=True),
            #timing_metrics(epoch=True),
            #forgetting_metrics(experience=True, stream=True),
            #loggers=[InteractiveLogger()],
            loggers=loggers,
            benchmark=scenario
        )

        # -> CONTINUAL LEARNING STRATEGY: Replay-based
        cl_strategy = avalanche.training.Naive(model, optimizer, criterion,
                                             # additional arguments
                                             train_mb_size=args.train_mb_size, 
                                             train_epochs=args.epochs,
                                             device=device,
                                             # Replay plugin
                                             plugins=[CustomReplay(storage_p)],
                                             # evaluation
                                             evaluator=eval_plugin
        )
        
        # --- Training loop --- #
        print('Starting experiment...')
        
        for experience in scenario.train_stream:
            print('Current experience {}, contains: {} patterns'.format(experience.current_experience, len(experience.dataset)))
            print('Current classes: ',experience.classes_in_this_experience)
            
            # Train
            cl_strategy.train(experience)
            print('Training completed')
            
            # Accuracy over the whole test set (no access to task-ID at inference time)
            print('Computing accuracy over the whole test set')
            res = cl_strategy.eval(scenario.test_stream)
            
        
        avg_stream_acc = get_average_metric(res)
        print(f"Replay-based method Average Stream Accuracy: {avg_stream_acc:.2f}")

### Run and Evaluate the experiment
- Create an instance of the strategy object
- Execute the strategy on a benchmark

In [5]:
# Create the strategy
s = Replay()

# Run the experiment with the default parameters
# (pass override_args to test_smnist to customize them)
s.test_smnist()


Starting experiment...
Current experience 0, contains: 12188 patterns
Current classes:  [0, 7]
Buffer update.
Training completed
Computing accuracy over the whole test set
Current experience 1, contains: 11800 patterns
Current classes:  [2, 4]
Override the dataloader.
Buffer update.
Training completed
Computing accuracy over the whole test set
Current experience 2, contains: 11982 patterns
Current classes:  [8, 3]
Override the dataloader.
Buffer update.
Training completed
Computing accuracy over the whole test set
Current experience 3, contains: 11867 patterns
Current classes:  [9, 6]
Override the dataloader.
Buffer update.
Training completed
Computing accuracy over the whole test set
Current experience 4, contains: 12163 patterns
Current classes:  [1, 5]
Override the dataloader.
Buffer update.
Training completed
Computing accuracy over the whole test set
Replay-based method Average Stream Accuracy: 0.46
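
`get_average_metric` is imported from a local `utils` module that is not shown here. A hypothetical stand-in might average all stream-accuracy entries of the dictionary returned by `cl_strategy.eval`. The function name `average_stream_metric` and the key prefix below are assumptions about the metric naming, so check them against the keys actually present in `res`. The values in the example are dummy numbers, not results of this experiment.

```python
def average_stream_metric(results,
                          prefix="Top1_Acc_Stream/eval_phase/test_stream"):
    """Average every metric in `results` whose name starts with `prefix`.

    `results` is assumed to be a dict mapping metric names to floats.
    """
    values = [v for k, v in results.items() if k.startswith(prefix)]
    if not values:
        raise KeyError(f"no metric starting with {prefix!r}")
    return sum(values) / len(values)

# Dummy dictionary with made-up values, only to show the call.
fake_results = {
    "Top1_Acc_Stream/eval_phase/test_stream/Task000": 0.5,
    "Top1_Acc_Stream/eval_phase/test_stream/Task001": 0.25,
    "Top1_Acc_Epoch/train_phase/train_stream/Task000": 1.0,
}
avg = average_stream_metric(fake_results)
print(f"{avg:.3f}")  # 0.375
```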
