***
## Generative Replay

Experiment reproducing the **Generative Replay** strategy:
Deep Generative Replay for a Scholar consisting of a Solver and a Generator.

*Code implements Deep Generative Replay as described in:*
- Deep Generative Replay: https://arxiv.org/abs/1705.08690
***

In [1]:
# --- LIBRARIES AND UTILS ---
import argparse
import unittest

# Avalanche library from ContinualAI
import avalanche

from avalanche.models import SimpleMLP
from avalanche.benchmarks.datasets import MNIST
from avalanche.benchmarks.generators import nc_benchmark
from torchvision.transforms import Compose, ToTensor, Normalize, RandomCrop

import torch
from torch.nn import CrossEntropyLoss
from torch.optim import Adam

# Loggers
from avalanche.logging import InteractiveLogger, TensorboardLogger

# Evaluation
from avalanche.training.plugins import GenerativeReplayPlugin, EvaluationPlugin
from avalanche.evaluation.metrics import accuracy_metrics, timing_metrics, forgetting_metrics

# Extras: Model and utils
from utils import arguments, get_average_metric, get_target_result

***
### Generator: VAE

The generator implemented in *GenerativeModel.py* is wrapped below into a trainable strategy, which is then passed to the `generator_strategy` parameter of the Deep Generative Replay learning strategy.

In [2]:
# --- custom imports

# Import the VAE generative model and its loss function
from imports.GenerativeModel import VAE_model, VAE_loss
# Import the VAE training class
from imports.TrainingInstances import VAE_TrainingStrategy

# Generative model (VAE) that acts as the generator of the scholar
generator = VAE_model((1, 28, 28), nhid=2, device=None)

# Optimizer of the generator: only parameters that require gradients are optimized
optimizer_generator = Adam(
    [param for param in generator.parameters() if param.requires_grad],
    lr=0.01,
    weight_decay=0.0001,
)

# Trainable strategy wrapping the generator; this is the object to pass to the
# generator_strategy parameter of GR.
# The wrapped strategy itself employs GenerativeReplayPlugin: no separate
# generator is given to the plugin because the model being trained
# (strategy.model, i.e. the VAE_model) is the generator.
generator_TrainableStrategy = VAE_TrainingStrategy(
    generator,
    optimizer_generator,
    VAE_loss,
    # additional training arguments
    train_mb_size=100,
    train_epochs=4,
    eval_mb_size=1,
    device=None,
    plugins=[
        GenerativeReplayPlugin(
            replay_size=100,               # batch size of replay added to each data batch
            increasing_replay_size=False,  # if True, double replay data added to each batch
        ),
    ],
)

***
### Generative Replay strategy

**Generative Replay** strategy using a solver-generator pair

In [3]:
# --- GENERATIVE REPLAY TECHNIQUE ---
# Import the Generative Replay strategy instance
from imports.TrainingInstances import GR


class GenRepl(unittest.TestCase):
    """Deep Generative Replay experiment wrapped in a ``unittest.TestCase``."""

    def test_smnist(self, override_args=None):
        """Train and evaluate Generative Replay on the Split-MNIST benchmark.

        MNIST is split into 5 class-incremental experiences without task
        labels. After training on each experience, the solver is evaluated on
        the whole test stream; the metric averaged from the last evaluation
        is printed at the end.

        Args:
            override_args: forwarded to ``arguments`` together with the
                default configuration below (presumably a dict of values that
                replace the defaults -- confirm against ``utils.arguments``).
        """
        # --- CONFIG
        # The configuration must be parsed first: the device selection below
        # reads ``args.cuda``. Reading it before this assignment raises
        # UnboundLocalError as soon as CUDA is available.
        args = arguments(
            {
                'cuda': 0,             # GPU index; a negative value forces CPU
                'train_epochs': 2,     # Training epochs
                'train_mb_size': 100,  # Train minibatch size
                'eval_mb_size': 100,   # Eval minibatch size
            },
            override_args,
        )

        use_cuda = torch.cuda.is_available() and args.cuda >= 0
        device = torch.device(f"cuda:{args.cuda}" if use_cuda else "cpu")

        # --- BENCHMARK and SCENARIO
        train_transform = Compose([
            RandomCrop(28, padding=4),
            ToTensor(),
            Normalize((0.1307,), (0.3081,)),
        ])
        test_transform = Compose([
            ToTensor(),
            Normalize((0.1307,), (0.3081,)),
        ])

        mnist_train = MNIST('./data/mnist', train=True, download=True,
                            transform=train_transform)
        mnist_test = MNIST('./data/mnist', train=False, download=True,
                           transform=test_transform)

        scenario = nc_benchmark(
            train_dataset=mnist_train,
            test_dataset=mnist_test,
            n_experiences=5,
            task_labels=False,
        )

        # --- STRATEGY COMPONENTS
        # Solver model: simple classifier
        model = SimpleMLP(num_classes=10, input_size=28 * 28, hidden_size=400,
                          hidden_layers=2, drop_rate=0)

        # Optimizer of the solver
        optimizer = Adam(model.parameters(), lr=0.001)

        # Loss function of the solver
        criterion = CrossEntropyLoss()

        # --- EVALUATION
        # Only accuracy is tracked; timing_metrics(epoch=True) and
        # forgetting_metrics(experience=True, stream=True) can be added to
        # the plugin when needed.
        eval_plugin = EvaluationPlugin(
            accuracy_metrics(minibatch=True, epoch=True, experience=True, stream=True),
            loggers=[InteractiveLogger()],  # displays a progress bar during training and evaluation
            benchmark=scenario,
        )

        # --- STRATEGY INSTANCE
        # NOTE(review): generator_TrainableStrategy is built at module level
        # with device=None -- confirm it ends up on the same device as the solver.
        cl_strategy = GR(
            model, optimizer, criterion,
            # additional arguments
            train_mb_size=args.train_mb_size,
            eval_mb_size=args.eval_mb_size,
            train_epochs=args.train_epochs,
            device=device,
            # evaluation
            evaluator=eval_plugin,
            # generator wrapped into a trainable strategy
            generator_strategy=generator_TrainableStrategy,
        )

        # --- TRAINING AND EVAL
        print('Starting experiment...')

        for experience in scenario.train_stream:
            print(f'Current experience {experience.current_experience}, '
                  f'contains: {len(experience.dataset)} patterns')
            print('Current classes: ', experience.classes_in_this_experience)

            # Train
            cl_strategy.train(experience)
            print('Training completed')

            # Accuracy over the whole test set (no access to task-ID at inference time)
            print('Computing accuracy over the whole test set')
            res = cl_strategy.eval(scenario.test_stream)

        # ``res`` holds the metrics of the evaluation run after the last experience
        avg_stream_acc = get_average_metric(res)
        print(f"Generative Replay - average stream accuracy: {avg_stream_acc:.2f}")

### Run and Evaluate the experiment
- Create an instance of the experiment class (`GenRepl`)
- Run the Generative Replay strategy on the Split-MNIST benchmark

In [4]:
# Instantiate the experiment test case (GenRepl is a unittest.TestCase subclass)
s = GenRepl()

# Run the Split-MNIST experiment; no override_args are passed here
s.test_smnist()


Starting experiment...
Current experience 0, contains: 12223 patterns
Current classes:  [2, 7]
-- >> Start of training phase << --
100%|██████████| 123/123 [00:04<00:00, 25.43it/s]
Epoch 0 ended.
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.9322
	Top1_Acc_MB/train_phase/train_stream/Task000 = 1.0000
100%|██████████| 123/123 [00:04<00:00, 25.79it/s]
Epoch 1 ended.
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.9804
	Top1_Acc_MB/train_phase/train_stream/Task000 = 1.0000
-- >> Start of training phase << --
100%|██████████| 123/123 [00:05<00:00, 20.54it/s]
Epoch 0 ended.
100%|██████████| 123/123 [00:05<00:00, 20.80it/s]
Epoch 1 ended.
100%|██████████| 123/123 [00:05<00:00, 20.93it/s]
Epoch 2 ended.
100%|██████████| 123/123 [00:05<00:00, 20.98it/s]
Epoch 3 ended.
-- >> End of training phase << --
-- >> End of training phase << --
Training completed
Computing accuracy over the whole test set
-- >> Start of eval phase << --
-- Starting eval on experience 0 (Task 0) from test str