In [49]:
%load_ext autoreload
%autoreload 2

# Libraries

In [3]:
from torchvision.transforms import Compose, ToTensor, Normalize, RandomCrop, Grayscale

In [5]:
from avalanche.benchmarks.datasets import MNIST
from avalanche.benchmarks.classic import PermutedMNIST
from avalanche.training import EWC, GEM, AGEM
from avalanche.training.plugins import EvaluationPlugin
from avalanche.benchmarks.generators import nc_benchmark
from avalanche.evaluation.metrics import (
    bwt_metrics,
    forgetting_metrics,
    accuracy_metrics,
    Accuracy,
    AccuracyPerTaskPluginMetric
)

from torch.nn import CrossEntropyLoss
from torch.optim import Adam

In [6]:
from tqdm import tqdm

## Custom Libraries

In [7]:
import sys
sys.path.append("../base_code/")

from base_code.constants import DATASET_PATH
from base_code.models.cnn import CNN2D

# Dataset and definitions

## Preprocessing definitions

In [8]:
train_transform = Compose([
    RandomCrop(28, padding=28),
    ToTensor(),
    Normalize((0.1307,), (0.3081,)),
])

test_transform = Compose([
    ToTensor(),
    Normalize((0.1307,), (0.3081,)),
])

## Dataset loading

We load state-of-the-art dataset Modified NIST

In [9]:
mnist_train = MNIST(DATASET_PATH, train=True, transform=train_transform)
mnist_test = MNIST(DATASET_PATH, train=False, transform=test_transform)

## Scenario creation with train test streaming

In this point, we define our scenario considering a training where in every experience of it, a new class is presented. This is, first we train with a class $a$, the following experience we train with class $b$ ($a \neq b$)

In [10]:
scenario = nc_benchmark(
    mnist_train, mnist_test,
    n_experiences=len(mnist_train.classes), shuffle=True, seed=1234, task_labels=False
)

train_stream = scenario.train_stream
test_stream = scenario.test_stream

## Evaluation metrics definition

In [11]:
eval_plugin = EvaluationPlugin(
    bwt_metrics(experience=True, stream=True),
    accuracy_metrics(experience=True, stream=True),
    forgetting_metrics(experience=True, stream=True)
)



## Model, Optimizer, Loss function and Strategy definition

* `model`: Multi Layer Perceptron
* `Optimizer`: Adam
* `Loss function`: Cross Entropy
* `Strategy`: Elastic Weight Consolidation

In [16]:
model = CNN2D(n_classes=scenario.n_classes, n_channels=1)
optimizer = Adam(model.parameters(), lr=0.001)
criterion = CrossEntropyLoss()
strategy = EWC(
    model, optimizer, criterion, ewc_lambda=1.0, train_epochs=5, train_mb_size=32, evaluator=eval_plugin
)

# Training and evaluation

Revisar porque el entrenamiento se está comportando de forma rara

In [17]:
results = list()

for experience in tqdm(train_stream):
    print("Current class:", experience.classes_in_this_experience[0])
    strategy.train(experience, eval_streams=[test_stream])

    print("Training completed\nComputing accuracy for whole test set")
    results.append(strategy.eval(test_stream))
    r = strategy.eval(experience)
    print(r)

  0%|          | 0/10 [00:00<?, ?it/s]

Current class: 5
Training completed
Computing accuracy for whole test set


 10%|█         | 1/10 [00:15<02:18, 15.39s/it]

{'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000': 1.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp002': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp003': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp004': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp005': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp006': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp007': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp008': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp009': 0.0, 'StreamBWT/eval_phase/test_stream': -1.0, 'Top1_Acc_Stream/eval_phase/test_stream/Task000': 0.0892, 'StreamForgetting/eval_phase/test_stream': 1.0, 'Top1_Acc_Exp/eval_phase/train_stream/Task000/Exp000': 1.0, 'StreamBWT/eval_phase/train_stream': 0.0, 'Top1_Acc_Stream/eval_phase/train_stream/Task000': 1.0, 'StreamForgetting/eval_phase/train_stream': 0.0, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp000':

 20%|██        | 2/10 [00:33<02:15, 16.90s/it]

{'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001': 1.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp002': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp003': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp004': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp005': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp006': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp007': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp008': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp009': 0.0, 'StreamBWT/eval_phase/test_stream': -1.0, 'Top1_Acc_Stream/eval_phase/test_stream/Task000': 0.0982, 'StreamForgetting/eval_phase/test_stream': 1.0, 'Top1_Acc_Exp/eval_phase/train_stream/Task000/Exp000': 1.0, 'StreamBWT/eval_phase/train_stream': 0.0, 'Top1_Acc_Stream/eval_phase/train_stream/Task000': 1.0, 'StreamForgetting/eval_phase/train_stream': 0.0, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp000':

 30%|███       | 3/10 [00:52<02:05, 17.98s/it]

{'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp002': 1.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp003': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp004': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp005': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp006': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp007': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp008': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp009': 0.0, 'StreamBWT/eval_phase/test_stream': -1.0, 'Top1_Acc_Stream/eval_phase/test_stream/Task000': 0.0974, 'StreamForgetting/eval_phase/test_stream': 1.0, 'Top1_Acc_Exp/eval_phase/train_stream/Task000/Exp000': 1.0, 'StreamBWT/eval_phase/train_stream': 0.0, 'Top1_Acc_Stream/eval_phase/train_stream/Task000': 1.0, 'StreamForgetting/eval_phase/train_stream': 0.0, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp000':

 40%|████      | 4/10 [01:13<01:54, 19.07s/it]

{'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp002': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp003': 1.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp004': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp005': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp006': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp007': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp008': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp009': 0.0, 'StreamBWT/eval_phase/test_stream': -1.0, 'Top1_Acc_Stream/eval_phase/test_stream/Task000': 0.1032, 'StreamForgetting/eval_phase/test_stream': 1.0, 'Top1_Acc_Exp/eval_phase/train_stream/Task000/Exp000': 1.0, 'StreamBWT/eval_phase/train_stream': 0.0, 'Top1_Acc_Stream/eval_phase/train_stream/Task000': 1.0, 'StreamForgetting/eval_phase/train_stream': 0.0, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp000':

 50%|█████     | 5/10 [01:35<01:40, 20.08s/it]

{'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp002': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp003': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp004': 1.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp005': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp006': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp007': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp008': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp009': 0.0, 'StreamBWT/eval_phase/test_stream': -1.0, 'Top1_Acc_Stream/eval_phase/test_stream/Task000': 0.0958, 'StreamForgetting/eval_phase/test_stream': 1.0, 'Top1_Acc_Exp/eval_phase/train_stream/Task000/Exp000': 1.0, 'StreamBWT/eval_phase/train_stream': 0.0, 'Top1_Acc_Stream/eval_phase/train_stream/Task000': 1.0, 'StreamForgetting/eval_phase/train_stream': 0.0, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp000':

 60%|██████    | 6/10 [01:59<01:25, 21.41s/it]

{'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp002': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp003': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp004': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp005': 1.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp006': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp007': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp008': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp009': 0.0, 'StreamBWT/eval_phase/test_stream': -1.0, 'Top1_Acc_Stream/eval_phase/test_stream/Task000': 0.1009, 'StreamForgetting/eval_phase/test_stream': 1.0, 'Top1_Acc_Exp/eval_phase/train_stream/Task000/Exp000': 1.0, 'StreamBWT/eval_phase/train_stream': 0.0, 'Top1_Acc_Stream/eval_phase/train_stream/Task000': 1.0, 'StreamForgetting/eval_phase/train_stream': 0.0, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp000':

 70%|███████   | 7/10 [02:27<01:10, 23.50s/it]

{'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp002': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp003': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp004': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp005': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp006': 1.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp007': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp008': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp009': 0.0, 'StreamBWT/eval_phase/test_stream': -1.0, 'Top1_Acc_Stream/eval_phase/test_stream/Task000': 0.1135, 'StreamForgetting/eval_phase/test_stream': 1.0, 'Top1_Acc_Exp/eval_phase/train_stream/Task000/Exp000': 1.0, 'StreamBWT/eval_phase/train_stream': 0.0, 'Top1_Acc_Stream/eval_phase/train_stream/Task000': 1.0, 'StreamForgetting/eval_phase/train_stream': 0.0, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp000':

 80%|████████  | 8/10 [02:55<00:49, 24.96s/it]

{'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp002': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp003': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp004': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp005': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp006': 1.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp007': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp008': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp009': 0.0, 'StreamBWT/eval_phase/test_stream': -0.8571428571428571, 'Top1_Acc_Stream/eval_phase/test_stream/Task000': 0.1135, 'StreamForgetting/eval_phase/test_stream': 0.8571428571428571, 'Top1_Acc_Exp/eval_phase/train_stream/Task000/Exp000': 1.0, 'StreamBWT/eval_phase/train_stream': 0.0, 'Top1_Acc_Stream/eval_phase/train_stream/Task000': 0.0, 'StreamForgetting/eval_phase/train_stream': 0.0, 'ExperienceBWT/eval_phas

 90%|█████████ | 9/10 [03:23<00:25, 25.89s/it]

{'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp002': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp003': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp004': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp005': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp006': 1.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp007': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp008': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp009': 0.0, 'StreamBWT/eval_phase/test_stream': -0.75, 'Top1_Acc_Stream/eval_phase/test_stream/Task000': 0.1135, 'StreamForgetting/eval_phase/test_stream': 0.75, 'Top1_Acc_Exp/eval_phase/train_stream/Task000/Exp000': 1.0, 'StreamBWT/eval_phase/train_stream': 0.0, 'Top1_Acc_Stream/eval_phase/train_stream/Task000': 0.0, 'StreamForgetting/eval_phase/train_stream': 0.0, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp000

100%|██████████| 10/10 [03:52<00:00, 23.27s/it]

{'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp002': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp003': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp004': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp005': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp006': 1.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp007': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp008': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp009': 0.0, 'StreamBWT/eval_phase/test_stream': -0.6666666666666666, 'Top1_Acc_Stream/eval_phase/test_stream/Task000': 0.1135, 'StreamForgetting/eval_phase/test_stream': 0.6666666666666666, 'Top1_Acc_Exp/eval_phase/train_stream/Task000/Exp000': 1.0, 'StreamBWT/eval_phase/train_stream': 0.0, 'Top1_Acc_Stream/eval_phase/train_stream/Task000': 0.0, 'StreamForgetting/eval_phase/train_stream': 0.0, 'ExperienceBWT/eval_phas




In [132]:
for result in results:
    print(result)

{'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000': 1.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp002': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp003': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp004': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp005': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp006': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp007': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp008': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp009': 0.0, 'StreamBWT/eval_phase/test_stream': 0.0, 'Top1_Acc_Stream/eval_phase/test_stream/Task000': 0.0892, 'StreamForgetting/eval_phase/test_stream': 0.0}
{'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001': 1.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp002': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp003': 0.0, 'Top1_Acc_Exp/eval_phase/

In [68]:
for i, result in enumerate(results, start=1):
    keys = result.keys()
    keys = filter(lambda x: "Acc" in x,  keys)
    print("Evaluating after training task", i)
    
    for j, key in enumerate(keys, start=1):
        print(f"{key}: {result[key]}")


Evaluating after training task 1
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000: 1.0
Top1_Acc_Exp/eval_phase/test_stream/Task001/Exp001: 0.0
Top1_Acc_Exp/eval_phase/test_stream/Task002/Exp002: 0.0
Top1_Acc_Exp/eval_phase/test_stream/Task003/Exp003: 0.0
Top1_Acc_Exp/eval_phase/test_stream/Task004/Exp004: 0.0
Top1_Acc_Exp/eval_phase/test_stream/Task005/Exp005: 0.0
Top1_Acc_Exp/eval_phase/test_stream/Task006/Exp006: 0.0
Top1_Acc_Exp/eval_phase/test_stream/Task007/Exp007: 0.0
Top1_Acc_Exp/eval_phase/test_stream/Task008/Exp008: 0.0
Top1_Acc_Exp/eval_phase/test_stream/Task009/Exp009: 0.0
Top1_Acc_Stream/eval_phase/test_stream/Task009: 0.0892
Evaluating after training task 2
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000: 0.0
Top1_Acc_Exp/eval_phase/test_stream/Task001/Exp001: 1.0
Top1_Acc_Exp/eval_phase/test_stream/Task002/Exp002: 0.0
Top1_Acc_Exp/eval_phase/test_stream/Task003/Exp003: 0.0
Top1_Acc_Exp/eval_phase/test_stream/Task004/Exp004: 0.0
Top1_Acc_Exp/eval_phase/test_stream/Tas

In [None]:
model.eval