In [1]:
import torch
from torch.nn import CrossEntropyLoss
from torch.optim import SGD

from avalanche.benchmarks.classic import PermutedMNIST
from avalanche.training.plugins import EvaluationPlugin
from avalanche.evaluation.metrics import accuracy_metrics
from avalanche.models import SimpleMLP
from avalanche.training.supervised import Naive


  from .autonotebook import tqdm as notebook_tqdm


In [2]:

# Config — every knob a reader may want to tune lives here.
SEED = 1234            # seeds both the MNIST permutations and torch (weight init, shuffling)
N_EXPERIENCES = 3      # number of permuted-MNIST experiences in the stream
TRAIN_EPOCHS = 2       # epochs per experience
BATCH_SIZE = 32        # used for both training and evaluation minibatches
LEARNING_RATE = 0.001
MOMENTUM = 0.9
NUM_WORKERS = 4        # DataLoader workers passed to strategy.train

# Seed torch *before* building the model so its initial weights are repeatable.
# NOTE(review): CUDA kernels may still be nondeterministic; this makes CPU runs repeatable.
torch.manual_seed(SEED)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# model
model = SimpleMLP(num_classes=10)

# CL Benchmark Creation — without `seed`, each run draws new random permutations,
# so the accuracies recorded in this notebook could not be reproduced.
perm_mnist = PermutedMNIST(n_experiences=N_EXPERIENCES, seed=SEED)
train_stream = perm_mnist.train_stream
test_stream = perm_mnist.test_stream

# Prepare for training & testing
optimizer = SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)
criterion = CrossEntropyLoss()
eval_plugin = EvaluationPlugin(
    accuracy_metrics(minibatch=True, epoch=True, epoch_running=True,
                     experience=True, stream=True))

# Continual learning strategy: Naive fine-tunes on each experience in turn
# (no replay/regularization), so earlier experiences are expected to be forgotten.
cl_strategy = Naive(
    model, optimizer, criterion, train_mb_size=BATCH_SIZE, train_epochs=TRAIN_EPOCHS,
    eval_mb_size=BATCH_SIZE, evaluator=eval_plugin, device=device)

# train and test loop: after training on each experience, evaluate on the
# whole test stream and keep the returned metrics dict.
results = []
for experience in train_stream:
    cl_strategy.train(experience, num_workers=NUM_WORKERS)
    results.append(cl_strategy.eval(test_stream))

# One evaluation snapshot per training experience.
assert len(results) == N_EXPERIENCES

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to /home/thealch3mist/.avalanche/data/mnist/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:01<00:00, 8056688.52it/s] 


Extracting /home/thealch3mist/.avalanche/data/mnist/MNIST/raw/train-images-idx3-ubyte.gz to /home/thealch3mist/.avalanche/data/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to /home/thealch3mist/.avalanche/data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 331041.49it/s]


Extracting /home/thealch3mist/.avalanche/data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz to /home/thealch3mist/.avalanche/data/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to /home/thealch3mist/.avalanche/data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 2503850.83it/s]


Extracting /home/thealch3mist/.avalanche/data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to /home/thealch3mist/.avalanche/data/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to /home/thealch3mist/.avalanche/data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 3698413.66it/s]


Extracting /home/thealch3mist/.avalanche/data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to /home/thealch3mist/.avalanche/data/mnist/MNIST/raw





In [8]:
import pprint

# Show the metric dicts gathered after each training experience, with keys sorted.
printer = pprint.PrettyPrinter()
printer.pprint(results)

[{'Top1_Acc_Epoch/train_phase/train_stream/Task000': 0.9238666666666666,
  'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000': 0.9469,
  'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001': 0.1097,
  'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp002': 0.143,
  'Top1_Acc_MB/train_phase/train_stream/Task000': 0.96875,
  'Top1_Acc_Stream/eval_phase/test_stream/Task000': 0.39986666666666665,
  'Top1_RunningAcc_Epoch/train_phase/train_stream/Task000': 0.9238666666666666},
 {'Top1_Acc_Epoch/train_phase/train_stream/Task000': 0.9346166666666667,
  'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000': 0.9375,
  'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001': 0.9529,
  'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp002': 0.1224,
  'Top1_Acc_MB/train_phase/train_stream/Task000': 0.90625,
  'Top1_Acc_Stream/eval_phase/test_stream/Task000': 0.6709333333333334,
  'Top1_RunningAcc_Epoch/train_phase/train_stream/Task000': 0.9346166666666667},
 {'Top1_Acc_Epoch/train_phase/train_stream/Tas