In [1]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before each cell executes so edits to imported code are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2


In [2]:
import pickle
import sys
from tqdm import tqdm

from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torch import cuda, device

from torchvision import transforms


from avalanche.benchmarks.classic import SplitMNIST
from avalanche.models import SimpleCNN
from avalanche.training import EWC
from avalanche.training.plugins import EvaluationPlugin
from avalanche.logging import InteractiveLogger
from avalanche.evaluation.metrics import (
    accuracy_metrics,
    bwt_metrics,
    forgetting_metrics,
)

# Extend the module search path so the project helpers below can be imported.
# NOTE(review): this adds the base_code directory itself, whereas the imports
# below use the `base_code.` package prefix, which resolves against the
# *parent* of base_code. Presumably that parent is already importable some
# other way (working directory, installed package) — confirm, or append the
# parent directory instead.
sys.path.append("../base_code/")

from base_code.constants import DATASETS_PATH, SAVED_METRICS_PATH
from base_code.seed import set_seed


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Experiment configuration: everything a reader might want to tune.
EXPERIMENT_SEED = 1234
N_EXPERIENCES = 5

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
TORCH_DEVICE = device("cuda") if cuda.is_available() else device("cpu")

# Location of the pickled metrics, relative to SAVED_METRICS_PATH.
METRICS_SUBFOLDER = "split_mnist"
METRICS_FILENAME = METRICS_SUBFOLDER + "/ewc.pkl"

# Seed through the project helper before any model or benchmark is built.
set_seed(EXPERIMENT_SEED)


In [4]:
# Tile each sample three times along its first dimension so a single-channel
# image becomes a 3-channel one.
# NOTE(review): assumes samples reach this transform as (1, H, W) tensors
# (`.repeat` is a tensor method) — confirm against the dataset in use.
train_transform = transforms.Compose([
    transforms.Lambda(lambda image: image.repeat(3, 1, 1)),
])
test_transform = transforms.Compose([
    transforms.Lambda(lambda image: image.repeat(3, 1, 1)),
])


In [5]:
# Build the SplitMNIST benchmark with N_EXPERIENCES experiences, seeded with
# the experiment seed and using the channel-tiling transforms for both splits.
scenario = SplitMNIST(
    n_experiences=N_EXPERIENCES,
    seed=EXPERIMENT_SEED,
    dataset_root=DATASETS_PATH,
    train_transform=train_transform,
    eval_transform=test_transform,
)

# Streams iterated by the training loop and passed to evaluation.
train_stream, test_stream = scenario.train_stream, scenario.test_stream


In [6]:
# Accuracy, forgetting and backward-transfer metrics, each tracked both per
# experience and over the whole stream; results go to the interactive logger.
eval_plugin = EvaluationPlugin(
    accuracy_metrics(stream=True, experience=True),
    forgetting_metrics(stream=True, experience=True),
    bwt_metrics(stream=True, experience=True),
    loggers=[InteractiveLogger()],
)

# Network, optimiser and loss; the model and criterion live on TORCH_DEVICE.
model = SimpleCNN(num_classes=scenario.n_classes).to(TORCH_DEVICE)
optimizer = SGD(
    model.parameters(),
    lr=1e-3,
    momentum=0.9,
    weight_decay=5e-4,
)
criterion = CrossEntropyLoss().to(TORCH_DEVICE)

# EWC strategy: 5 epochs per experience, mini-batches of 128 for both
# training and evaluation, reporting through the evaluation plugin above.
strategy = EWC(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    ewc_lambda=50,
    train_mb_size=128,
    train_epochs=5,
    eval_mb_size=128,
    evaluator=eval_plugin,
    device=TORCH_DEVICE,
)


In [7]:
# Train on each experience in order; after every experience, evaluate on the
# full test stream so per-experience accuracy, forgetting and BWT are logged.
# The order matters: evaluation must follow the training step it measures.
for experience in tqdm(train_stream):
    strategy.train(experience)
    strategy.eval(test_stream)


  0%|          | 0/5 [00:00<?, ?it/s]

-- >> Start of training phase << --
100%|██████████| 88/88 [00:13<00:00,  6.61it/s]
Epoch 0 ended.
100%|██████████| 88/88 [00:13<00:00,  6.44it/s]
Epoch 1 ended.
100%|██████████| 88/88 [00:13<00:00,  6.57it/s]
Epoch 2 ended.
100%|██████████| 88/88 [00:13<00:00,  6.55it/s]
Epoch 3 ended.
100%|██████████| 88/88 [00:13<00:00,  6.54it/s]
Epoch 4 ended.
-- >> End of training phase << --
-- >> Start of eval phase << --
-- Starting eval on experience 0 (Task 0) from test stream --
100%|██████████| 15/15 [00:00<00:00, 26.75it/s]
> Eval on experience 0 (Task 0) from test stream ended.
	Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000 = 0.5245
-- Starting eval on experience 1 (Task 0) from test stream --
100%|██████████| 16/16 [00:00<00:00, 28.39it/s]
> Eval on experience 1 (Task 0) from test stream ended.
	Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001 = 0.0000
-- Starting eval on experience 2 (Task 0) from test stream --
100%|██████████| 16/16 [00:00<00:00, 27.92it/s]
> Eval on experienc

 20%|██        | 1/5 [01:22<05:31, 82.91s/it]


> Eval on experience 4 (Task 0) from test stream ended.
	Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp004 = 0.0000
-- >> End of eval phase << --
	StreamBWT/eval_phase/test_stream = 0.0000
	StreamForgetting/eval_phase/test_stream = 0.0000
	Top1_Acc_Stream/eval_phase/test_stream/Task000 = 0.0983
-- >> Start of training phase << --
100%|██████████| 93/93 [00:14<00:00,  6.43it/s]
Epoch 0 ended.
100%|██████████| 93/93 [00:14<00:00,  6.51it/s]
Epoch 1 ended.
100%|██████████| 93/93 [00:13<00:00,  6.65it/s]
Epoch 2 ended.
100%|██████████| 93/93 [00:14<00:00,  6.62it/s]
Epoch 3 ended.
100%|██████████| 93/93 [00:14<00:00,  6.62it/s]
Epoch 4 ended.
-- >> End of training phase << --
-- >> Start of eval phase << --
-- Starting eval on experience 0 (Task 0) from test stream --
100%|██████████| 15/15 [00:00<00:00, 29.69it/s]
> Eval on experience 0 (Task 0) from test stream ended.
	ExperienceBWT/eval_phase/test_stream/Task000/Exp000 = -0.5245
	ExperienceForgetting/eval_phase/test_stream/Task000/Exp

 40%|████      | 2/5 [02:49<04:15, 85.14s/it]


> Eval on experience 4 (Task 0) from test stream ended.
	Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp004 = 0.0000
-- >> End of eval phase << --
	StreamBWT/eval_phase/test_stream = -0.5245
	StreamForgetting/eval_phase/test_stream = 0.5245
	Top1_Acc_Stream/eval_phase/test_stream/Task000 = 0.0974
-- >> Start of training phase << --
100%|██████████| 93/93 [00:14<00:00,  6.63it/s]
Epoch 0 ended.
100%|██████████| 93/93 [00:13<00:00,  6.66it/s]
Epoch 1 ended.
100%|██████████| 93/93 [00:13<00:00,  6.66it/s]
Epoch 2 ended.
100%|██████████| 93/93 [00:14<00:00,  6.63it/s]
Epoch 3 ended.
100%|██████████| 93/93 [00:13<00:00,  6.66it/s]
Epoch 4 ended.
-- >> End of training phase << --
-- >> Start of eval phase << --
-- Starting eval on experience 0 (Task 0) from test stream --
100%|██████████| 15/15 [00:00<00:00, 30.13it/s]
> Eval on experience 0 (Task 0) from test stream ended.
	ExperienceBWT/eval_phase/test_stream/Task000/Exp000 = -0.5245
	ExperienceForgetting/eval_phase/test_stream/Task000/Ex

 60%|██████    | 3/5 [04:15<02:50, 85.44s/it]


> Eval on experience 4 (Task 0) from test stream ended.
	Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp004 = 0.0000
-- >> End of eval phase << --
	StreamBWT/eval_phase/test_stream = -0.5050
	StreamForgetting/eval_phase/test_stream = 0.5050
	Top1_Acc_Stream/eval_phase/test_stream/Task000 = 0.0958
-- >> Start of training phase << --
100%|██████████| 102/102 [00:15<00:00,  6.64it/s]
Epoch 0 ended.
100%|██████████| 102/102 [00:15<00:00,  6.51it/s]
Epoch 1 ended.
100%|██████████| 102/102 [00:15<00:00,  6.54it/s]
Epoch 2 ended.
100%|██████████| 102/102 [00:15<00:00,  6.55it/s]
Epoch 3 ended.
100%|██████████| 102/102 [00:15<00:00,  6.57it/s]
Epoch 4 ended.
-- >> End of training phase << --
-- >> Start of eval phase << --
-- Starting eval on experience 0 (Task 0) from test stream --
100%|██████████| 15/15 [00:00<00:00, 29.28it/s]
> Eval on experience 0 (Task 0) from test stream ended.
	ExperienceBWT/eval_phase/test_stream/Task000/Exp000 = -0.5245
	ExperienceForgetting/eval_phase/test_stream/

 80%|████████  | 4/5 [05:50<01:29, 89.26s/it]


> Eval on experience 4 (Task 0) from test stream ended.
	Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp004 = 0.0000
-- >> End of eval phase << --
	StreamBWT/eval_phase/test_stream = -0.4990
	StreamForgetting/eval_phase/test_stream = 0.4990
	Top1_Acc_Stream/eval_phase/test_stream/Task000 = 0.1135
-- >> Start of training phase << --
100%|██████████| 95/95 [00:14<00:00,  6.60it/s]
Epoch 0 ended.
100%|██████████| 95/95 [00:14<00:00,  6.51it/s]
Epoch 1 ended.
100%|██████████| 95/95 [00:14<00:00,  6.64it/s]
Epoch 2 ended.
100%|██████████| 95/95 [00:14<00:00,  6.56it/s]
Epoch 3 ended.
100%|██████████| 95/95 [00:14<00:00,  6.57it/s]
Epoch 4 ended.
-- >> End of training phase << --
-- >> Start of eval phase << --
-- Starting eval on experience 0 (Task 0) from test stream --
100%|██████████| 15/15 [00:00<00:00, 29.00it/s]
> Eval on experience 0 (Task 0) from test stream ended.
	ExperienceBWT/eval_phase/test_stream/Task000/Exp000 = -0.5245
	ExperienceForgetting/eval_phase/test_stream/Task000/Ex

100%|██████████| 5/5 [07:19<00:00, 87.83s/it]


> Eval on experience 4 (Task 0) from test stream ended.
	Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp004 = 0.4925
-- >> End of eval phase << --
	StreamBWT/eval_phase/test_stream = -0.5055
	StreamForgetting/eval_phase/test_stream = 0.5055
	Top1_Acc_Stream/eval_phase/test_stream/Task000 = 0.0980





In [9]:
# save metrics
# Fetch the recorded metrics once; every lookup below reads this one mapping.
all_metrics = eval_plugin.get_all_metrics()

# Keys are string labels ("Task0", ..., "Overall"), so these map str -> values.
accuracies: dict[str, list[float]] = {}
forgettings: dict[str, list[float]] = {}
bwt: dict[str, list[float]] = {}

for i in range(N_EXPERIENCES):
    # Metric names end in a zero-padded experience id, e.g. "Exp003".
    exp_key = f"eval_phase/test_stream/Task000/Exp{i:03d}"
    # Each entry is indexed with [1] to keep the recorded values
    # (presumably index 0 holds the matching step counters — confirm).
    # NOTE(review): the final eval log shows no forgetting/BWT line for the
    # last experience, so those two lists may come back empty for it —
    # verify that downstream consumers handle that.
    accuracies[f"Task{i}"] = all_metrics[f"Top1_Acc_Exp/{exp_key}"][1]
    forgettings[f"Task{i}"] = all_metrics[f"ExperienceForgetting/{exp_key}"][1]
    bwt[f"Task{i}"] = all_metrics[f"ExperienceBWT/{exp_key}"][1]

accuracies["Overall"] = all_metrics[
    "Top1_Acc_Stream/eval_phase/test_stream/Task000"
][1]

metrics_path = SAVED_METRICS_PATH / METRICS_FILENAME
# METRICS_FILENAME includes a subfolder; create it so a long training run is
# not lost to a missing directory (assumes SAVED_METRICS_PATH is a
# pathlib.Path, as the `/` join suggests).
metrics_path.parent.mkdir(parents=True, exist_ok=True)

# The context manager guarantees the file is flushed and closed.
with open(metrics_path, "wb") as metrics_file:
    pickle.dump(
        {
            "accuracies": accuracies,
            "forgettings": forgettings,
            "bwt": bwt,
        },
        metrics_file,
    )
