## CL-Gym Example: Stable A-GEM on Rotated MNIST

In this example, we use Averaged Gradient Episodic Memory (A-GEM) to train on the Rotated MNIST benchmark. We use a stable variant of [A-GEM](https://arxiv.org/abs/1812.00420.pdf) that adopts the training parameters from [Stable SGD](https://proceedings.neurips.cc/paper/2020/file/518a38cc9a0173d0b2dc088166981cf8-Paper.pdf).

## 1. Defining Parameters
First, we define the parameters (config) for our experiment.
All parameters live in a single Python dictionary and cover different aspects of a continual learning experiment. For example:
-  How many tasks should we learn?
-  What batch size will we use?
-  Which optimizer will we use?
-  Where should we store our outputs?

In [1]:
import torch
import cl_gym as cl
# first let's create params/config for our experiment

def make_params() -> dict:
    import os
    from pathlib import Path
    import uuid

    params = {
            # benchmark
            'num_tasks': 5,
            'epochs_per_task': 10,
            'per_task_memory_examples': 25,
            'batch_size_train': 64,
            'batch_size_memory': 32,
            'batch_size_validation': 256,

            # algorithm
            'optimizer': 'SGD',
            'learning_rate': 0.01,
            'momentum': 0.8,
            'learning_rate_decay': 1.0,
            'criterion': torch.nn.CrossEntropyLoss(),
            # 'cuda' selects the default GPU; pin a specific one with e.g. 'cuda:4'
            'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    }

    trial_id = str(uuid.uuid4())
    params['trial_id'] = trial_id
    params['output_dir'] = os.path.join("./outputs", trial_id)
    Path(params['output_dir']).mkdir(parents=True, exist_ok=True)

    return params
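
The `learning_rate_decay` entry is one of the Stable SGD knobs: the learning rate is shrunk by a constant factor each time a new task starts. The sketch below shows that schedule in isolation. The helper name `lr_for_task` is hypothetical, and whether CL-Gym applies the decay in exactly this way is an assumption. With the value of `1.0` used above, the learning rate stays constant across tasks.

```python
def lr_for_task(initial_lr: float, decay: float, task: int) -> float:
    """Learning rate for a 1-indexed task under per-task exponential decay."""
    if task < 1:
        raise ValueError("task must be >= 1")
    return initial_lr * decay ** (task - 1)


# Values from the params above: decay of 1.0 means every task uses 0.01.
constant = [lr_for_task(0.01, 1.0, t) for t in range(1, 6)]

# A hypothetical decay of 0.8, shown only for comparison.
decayed = [lr_for_task(0.01, 0.8, t) for t in range(1, 6)]

print(constant)
print(decayed)
```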

## 2. Training our continual learning algorithm

Before seeing the code, let's explain the components one more time:
### 2.1 Benchmark
* We use the `RotatedMNIST` benchmark for this example. Each task applies a progressively larger rotation to the MNIST digits, as illustrated below:
<div>
<img src="https://user-images.githubusercontent.com/8312051/122752221-845a0300-d245-11eb-8892-7c4119ffe1a5.png" width="300"/>
</div>

In CL-Gym we use RotatedMNIST as follows:
```python
benchmark = cl.benchmarks.RotatedMNIST(num_tasks=5)
```
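
To make the "gradual rotation" concrete, here is a minimal sketch of the angle assigned to each task. It assumes the first task is unrotated and each later task adds a fixed increment; the helper `task_rotations` is hypothetical and CL-Gym's exact convention may differ. The increment of 22.5 degrees is the `per_task_rotation` value that appears in the training cell of this notebook.

```python
def task_rotations(num_tasks: int, per_task_rotation: float) -> list:
    """Rotation angle in degrees for each task, assuming task 1 is unrotated."""
    return [(task - 1) * per_task_rotation for task in range(1, num_tasks + 1)]


angles = task_rotations(num_tasks=5, per_task_rotation=22.5)
print(angles)  # [0.0, 22.5, 45.0, 67.5, 90.0]
```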
---------

### 2.2 Backbone

* We use an MLP model with two hidden layers, like this:
<div>
<img src="https://user-images.githubusercontent.com/8312051/122753641-67beca80-d247-11eb-87d3-dec5cc2e63d6.png" width="300"/>
</div>

To import our backbone, we use:
```python
backbone = cl.backbones.MLP2Layers(input_dim=784, hidden_dim_1=100, hidden_dim_2=100, output_dim=10)
```

You can also create your own PyTorch models. The backbone in CL-Gym is simply a lightweight wrapper around PyTorch's ``nn.Module``.
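
As a quick sanity check on the size of this backbone, the sketch below counts the trainable parameters of a fully connected network with the dimensions used above. It assumes plain linear layers with biases and nothing else (for example, no normalization layers); whether `MLP2Layers` matches that exactly is an assumption, and `mlp_param_count` is a hypothetical helper.

```python
def mlp_param_count(layer_dims):
    """Count weights and biases of a fully connected network."""
    total = 0
    for fan_in, fan_out in zip(layer_dims, layer_dims[1:]):
        total += fan_in * fan_out + fan_out  # weight matrix + bias vector
    return total


# input 784 -> hidden 100 -> hidden 100 -> output 10
n_params = mlp_param_count([784, 100, 100, 10])
print(n_params)  # 89610
```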

--------

### 2.3 Collecting metrics with Callbacks

The `MetricCollector` callback evaluates the model at the end of each epoch, logs the metrics, plots the accuracies to file, and saves the validation accuracies as NumPy arrays (see the outputs folder).
```python
metric_callback = cl.callbacks.MetricCollector(num_tasks=5,
                                              eval_interval='epoch',
                                              epochs_per_task=1)
```
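
The two summary numbers printed at the end of a run are average accuracy and average forgetting. The sketch below uses the definitions common in the continual learning literature: average accuracy is the mean over all tasks after the last task has been learned, and forgetting is the mean drop from each earlier task's best accuracy to its final accuracy. The function names and the example matrix are hypothetical, and it is an assumption that CL-Gym's forgetting meter follows exactly this definition.

```python
def average_accuracy(acc):
    """Mean accuracy over all tasks after training on the last task.

    acc[i][j] is the accuracy on task j measured after training on task i
    (0-indexed); row i holds entries for tasks 0..i.
    """
    last = len(acc) - 1
    return sum(acc[last][j] for j in range(last + 1)) / (last + 1)


def average_forgetting(acc):
    """Mean drop from each earlier task's best accuracy to its final accuracy."""
    last = len(acc) - 1
    if last == 0:
        return 0.0  # a single task cannot be forgotten
    drops = []
    for j in range(last):
        best = max(acc[i][j] for i in range(j, last))
        drops.append(best - acc[last][j])
    return sum(drops) / len(drops)


# Hypothetical accuracies for three tasks (not results from this notebook).
acc = [
    [96.0],
    [90.0, 95.0],
    [80.0, 85.0, 94.0],
]
avg_acc = average_accuracy(acc)
avg_forget = average_forgetting(acc)
print(avg_acc, avg_forget)
```

The averaging step can be checked against the logs at the end of this notebook: the mean of the five final accuracies 17.42, 28.28, 53.44, 86.71 and 96.97 is 56.564, which matches the reported `final avg-acc`.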

-------

### 2.4  Using off-the-shelf continual learning algorithms

CL-Gym includes several continual learning algorithms. Here we use the A-GEM algorithm with the Stable SGD training parameters instead of those from the original paper:

```python
algorithm = cl.algorithms.AGEM(backbone, benchmark, params)
```
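
For intuition, the core of A-GEM is a single projection step. The gradient on the current mini-batch is compared with a reference gradient computed on a mini-batch drawn from episodic memory; if their inner product is negative, the component that conflicts with the reference gradient is removed. The sketch below shows that rule on plain Python lists. It is an illustration of the update from the A-GEM paper, not CL-Gym's implementation, and the function names are hypothetical.

```python
def dot(u, v):
    """Inner product of two equally sized flat vectors."""
    return sum(a * b for a, b in zip(u, v))


def agem_project(grad_batch, grad_ref):
    """Return the gradient A-GEM would apply.

    grad_batch: flattened gradient on the current task's mini-batch.
    grad_ref:   flattened gradient on a mini-batch sampled from episodic memory.
    """
    inner = dot(grad_batch, grad_ref)
    if inner >= 0:
        # No conflict with the memory gradient: use the batch gradient unchanged.
        return list(grad_batch)
    # Conflict: subtract the projection of grad_batch onto grad_ref.
    scale = inner / dot(grad_ref, grad_ref)
    return [g - scale * r for g, r in zip(grad_batch, grad_ref)]


projected = agem_project([1.0, -2.0], [0.0, 1.0])
print(projected)  # [1.0, 0.0]
```

After projection the inner product with the reference gradient is zero, so the step no longer increases the loss on the memory batch to first order.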

You can also use other algorithms. For example, for the Experience Replay method, you can use:
```python
algorithm = cl.algorithms.ERRingBuffer(backbone, benchmark, params)
```
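
As a rough picture of what a ring-buffer memory does, the sketch below keeps a fixed number of the most recent examples per task and samples replay batches uniformly from everything stored. The class `RingBufferMemory` is hypothetical and is not CL-Gym's API; keying the buffers by task (rather than by class) is an assumption made to mirror the `per_task_memory_examples` parameter.

```python
import random
from collections import deque


class RingBufferMemory:
    """Fixed-size FIFO memory per task (hypothetical helper, not CL-Gym's API)."""

    def __init__(self, per_task_capacity):
        self.per_task_capacity = per_task_capacity
        self.buffers = {}

    def add(self, task, example):
        # deque(maxlen=...) drops the oldest item automatically once full.
        buffer = self.buffers.setdefault(task, deque(maxlen=self.per_task_capacity))
        buffer.append(example)

    def sample(self, batch_size, rng=random):
        # Draw without replacement from everything stored across tasks.
        pool = [ex for buffer in self.buffers.values() for ex in buffer]
        return rng.sample(pool, min(batch_size, len(pool)))


memory = RingBufferMemory(per_task_capacity=3)
for i in range(5):
    memory.add(task=1, example=("task1", i))
memory.add(task=2, example=("task2", 0))

replay_batch = memory.sample(batch_size=2)
print(list(memory.buffers[1]))  # the three most recent task-1 examples
```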


-------

### 2.5 Gluing everything together with the Trainer

The `Trainer` orchestrates the experiment by handling the non-research parts of a continual learning experiment, such as iterating over tasks and epochs and invoking the callbacks.

```python
trainer = cl.trainer.ContinualTrainer(algorithm, params, callbacks=[metric_callback])
```
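
The orchestration itself is conceptually simple. The sketch below shows a stripped-down loop that trains on tasks one after another and notifies callbacks after every epoch. All names here (`run_continual_loop`, `RecordingCallback`, `on_epoch_end`) are hypothetical and chosen for illustration; CL-Gym's trainer exposes its own hooks and does considerably more.

```python
class RecordingCallback:
    """Hypothetical callback that records when it was notified."""

    def __init__(self):
        self.events = []

    def on_epoch_end(self, task, epoch):
        self.events.append((task, epoch))


def run_continual_loop(num_tasks, epochs_per_task, train_one_epoch, callbacks):
    """Train on tasks sequentially, notifying callbacks after each epoch."""
    for task in range(1, num_tasks + 1):
        for epoch in range(1, epochs_per_task + 1):
            train_one_epoch(task, epoch)
            for callback in callbacks:
                callback.on_epoch_end(task, epoch)


trained = []
recorder = RecordingCallback()
run_continual_loop(
    num_tasks=2,
    epochs_per_task=3,
    train_one_epoch=lambda task, epoch: trained.append((task, epoch)),
    callbacks=[recorder],
)
print(recorder.events)
```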


The code below puts these components together:

In [2]:
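# These modules are not imported from cl_gym and are only needed for the
# commented-out AGEM_Sensitive variant in train().
# from datasets.FairMNIST import NoiseMNIST
# from trainers.FairContinualTrainer import FairContinualTrainer
# from algorithms.agem_sensitive import AGEM_Sensitive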

def train(params):
    # Benchmark. The Rotated MNIST setup described above is kept for reference:
    # benchmark = cl.benchmarks.RotatedMNIST(num_tasks=params['num_tasks'],
    #                                        per_task_memory_examples=params['per_task_memory_examples'],
    #                                        per_task_rotation=22.5)
    # The active line uses Split CIFAR-10 instead.
    benchmark = cl.benchmarks.SplitCIFAR10(num_tasks=params['num_tasks'],
                                           per_task_memory_examples=params['per_task_memory_examples'])

    # Backbone. The MLP with 2 hidden layers is kept for reference:
    # backbone = cl.backbones.MLP2Layers(input_dim=784, hidden_dim_1=32, hidden_dim_2=32, output_dim=10)
    # The active backbone is a small multi-head ResNet-18.
    backbone = cl.backbones.ResNet18Small(multi_head=True, num_classes_per_head=5, num_classes=10)

    # Algorithm: A-GEM
    algorithm = cl.algorithms.AGEM(backbone, benchmark, params)
    # algorithm = AGEM_Sensitive(backbone, benchmark, params)

    # algorithm = cl.algorithms.ERRingBuffer(backbone, benchmark, params)

    # Callbacks
    metric_manager_callback = cl.callbacks.MetricCollector(num_tasks=params['num_tasks'],
                                                           eval_interval='epoch',
                                                           epochs_per_task=params['epochs_per_task'])

    # Make trainer
    trainer = cl.trainer.ContinualTrainer(algorithm, params, callbacks=[metric_manager_callback])

    trainer.run()
    print("final avg-acc", metric_manager_callback.meters['accuracy'].compute_final())
    print("final avg-forget", metric_manager_callback.meters['forgetting'].compute_final())


params = make_params()
train(params)

Files already downloaded and verified
Files already downloaded and verified
---------------------------- Task 1 -----------------------
[1] Eval metrics for task 1 >> {'accuracy': 92.45, 'loss': 0.0007453862130641937}
[2] Eval metrics for task 1 >> {'accuracy': 94.8, 'loss': 0.0005630276277661324}
[3] Eval metrics for task 1 >> {'accuracy': 94.8, 'loss': 0.0005299254320561886}
[4] Eval metrics for task 1 >> {'accuracy': 95.4, 'loss': 0.00047980944998562335}
[5] Eval metrics for task 1 >> {'accuracy': 95.6, 'loss': 0.0005012850500643254}
[6] Eval metrics for task 1 >> {'accuracy': 96.2, 'loss': 0.0004827506206929684}
[7] Eval metrics for task 1 >> {'accuracy': 96.2, 'loss': 0.0004735014699399471}
[8] Eval metrics for task 1 >> {'accuracy': 96.15, 'loss': 0.0005718770250678063}
[9] Eval metrics for task 1 >> {'accuracy': 96.2, 'loss': 0.0005731589794158936}
[10] Eval metrics for task 1 >> {'accuracy': 95.7, 'loss': 0.0006911326423287391}
---------------------------- Task 2 --------------

In [None]:
grad_ref
[25] Eval metrics for task 1 >> {'accuracy': 44.18, 'loss': 0.013980306243896484}
[25] Eval metrics for task 2 >> {'accuracy': 56.54, 'loss': 0.00928453779220581}
[25] Eval metrics for task 3 >> {'accuracy': 70.54, 'loss': 0.00509959619641304}
[25] Eval metrics for task 4 >> {'accuracy': 90.14, 'loss': 0.0013555611550807953}
[25] Eval metrics for task 5 >> {'accuracy': 96.27, 'loss': 0.0004954223014414311}
final avg-acc 71.53399999999999


grad_batch
[25] Eval metrics for task 1 >> {'accuracy': 17.42, 'loss': 0.023993745803833007}
[25] Eval metrics for task 2 >> {'accuracy': 28.28, 'loss': 0.01793102569580078}
[25] Eval metrics for task 3 >> {'accuracy': 53.44, 'loss': 0.009112693345546723}
[25] Eval metrics for task 4 >> {'accuracy': 86.71, 'loss': 0.001737047752737999}
[25] Eval metrics for task 5 >> {'accuracy': 96.97, 'loss': 0.00040261796098202467}
final avg-acc 56.564
final avg-forget 50.2775

