<a href="https://colab.research.google.com/github/gogamid/ml-notebooks/blob/main/sccl-experiments/ewc_avalanche.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install Avalanche (continual-learning library) directly from the main branch
# of its GitHub repository.
# NOTE(review): this install is unpinned, so the library version (and its API)
# can differ between runs of this notebook; consider pinning a release or a
# specific commit so results are reproducible.
%pip install git+https://github.com/ContinualAI/avalanche.git

In [4]:
import torch
import wandb
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from avalanche.benchmarks.classic import SplitMNIST
from avalanche.evaluation.metrics import accuracy_metrics
from avalanche.models import SimpleMLP
from avalanche.logging import InteractiveLogger, TextLogger, WandBLogger
from avalanche.training.plugins import EvaluationPlugin
from avalanche.training import EWC

# EWC on Split MNIST with Avalanche.
#
# Trains a single-head MLP sequentially on the experiences of Split MNIST with
# the EWC strategy. After every training experience it evaluates on the full
# test stream and appends the returned metrics to `results`. Metrics go to the
# console, to `logSplitMNIST.txt` (appended), and to Weights & Biases.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 5 experiences over the 10 MNIST digits, using a fixed (not shuffled) class
# order so every run sees the experiences in the same sequence.
scenario = SplitMNIST(
    n_experiences=5,
    fixed_class_order=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
)

# One output head covering every class in the benchmark.
model = SimpleMLP(
    num_classes=scenario.n_classes,
    hidden_size=512,
    hidden_layers=1,
    drop_rate=0,
)

# Keep an explicit handle on the text-log file so it can be closed (and thus
# flushed) when the experiment ends, instead of leaking an anonymous handle.
log_file = open("logSplitMNIST.txt", "a")
loggers = [InteractiveLogger(), TextLogger(log_file)]
loggers.append(WandBLogger(project_name="avalanche", run_name="ewc_split_mnist"))

eval_plugin = EvaluationPlugin(
    accuracy_metrics(experience=True, stream=True),
    loggers=loggers,
)

# Model, optimizer and criterion are passed by keyword. The Avalanche install
# above tracks the main branch, and keyword arguments work with both older
# and newer strategy constructors.
# NOTE(review): with ewc_lambda=1 the logged results show accuracy on
# Exp000-Exp002 dropping to 0.0 after the last experience (severe forgetting).
# Consider tuning ewc_lambda if retaining old classes is the goal.
cl_strategy = EWC(
    model=model,
    optimizer=SGD(model.parameters(), lr=0.001),
    criterion=CrossEntropyLoss(),
    ewc_lambda=1,
    mode="separate",
    decay_factor=None,
    train_mb_size=256,
    train_epochs=10,
    eval_mb_size=128,
    device=device,
    evaluator=eval_plugin,
)

# TRAINING LOOP
print('Starting experiment...')
results = []
try:
    for experience in scenario.train_stream:
        print("Start of experience ", experience.current_experience)
        cl_strategy.train(experience)
        print('Training completed')

        print('Computing accuracy on the whole test set')
        results.append(cl_strategy.eval(scenario.test_stream))
finally:
    # Always release the log file and end the W&B run, even if training or
    # evaluation raises, so neither is left open/dangling.
    log_file.close()
    wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.018 MB uploaded\r'), FloatProgress(value=0.06638550452983442, max=1.…

0,1
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000,█▂▁▁▁
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001,▁█▁▁▁
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp002,▁▁█▁▁
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp003,▁▁▁█▁
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp004,▁▁▁▁█
Top1_Acc_Stream/eval_phase/test_stream/Task000,▇█▁▄▃
TrainingExperience,▁▃▅▆█

0,1
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000,0.0
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001,0.0
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp002,0.0
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp003,0.01662
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp004,0.95915
Top1_Acc_Stream/eval_phase/test_stream/Task000,0.1935
TrainingExperience,4.0


Starting experiment...
Start of experience  0
-- >> Start of training phase << --
100%|██████████| 50/50 [00:02<00:00, 21.73it/s]
Epoch 0 ended.
100%|██████████| 50/50 [00:01<00:00, 25.40it/s]
Epoch 1 ended.
100%|██████████| 50/50 [00:01<00:00, 25.41it/s]
Epoch 2 ended.
100%|██████████| 50/50 [00:01<00:00, 25.81it/s]
Epoch 3 ended.
100%|██████████| 50/50 [00:02<00:00, 23.64it/s]
Epoch 4 ended.
100%|██████████| 50/50 [00:02<00:00, 16.67it/s]
Epoch 5 ended.
100%|██████████| 50/50 [00:01<00:00, 25.24it/s]
Epoch 6 ended.
100%|██████████| 50/50 [00:02<00:00, 24.84it/s]
Epoch 7 ended.
100%|██████████| 50/50 [00:02<00:00, 24.56it/s]
Epoch 8 ended.
100%|██████████| 50/50 [00:02<00:00, 24.62it/s]
Epoch 9 ended.
-- >> End of training phase << --
Training completed
Computing accuracy on the whole test set
-- >> Start of eval phase << --
-- Starting eval on experience 0 (Task 0) from test stream --
100%|██████████| 17/17 [00:00<00:00, 34.30it/s]
> Eval on experience 0 (Task 0) from test stream end

VBox(children=(Label(value='0.001 MB of 0.041 MB uploaded\r'), FloatProgress(value=0.029783055919207033, max=1…

0,1
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000,█▁▁▁▁
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001,▁█▁▁▁
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp002,▁▁█▁▁
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp003,▁▁▁█▁
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp004,▁▁▁▁█
Top1_Acc_Stream/eval_phase/test_stream/Task000,█▅▁▄▄
TrainingExperience,▁▃▅▆█

0,1
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000,0.0
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001,0.0
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp002,0.0
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp003,0.02266
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp004,0.95764
Top1_Acc_Stream/eval_phase/test_stream/Task000,0.1944
TrainingExperience,4.0
