In [1]:
# Enable IPython's autoreload so edits to imported project modules (e.g. base_code)
# are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2


In [2]:
import pickle
import sys
from tqdm import tqdm

from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torch import cuda, device

from torchvision import transforms


from avalanche.benchmarks.classic import SplitMNIST
from avalanche.models import SimpleCNN
from avalanche.training.plugins import EvaluationPlugin
from avalanche.logging import InteractiveLogger
from avalanche.evaluation.metrics import (
    accuracy_metrics,
    bwt_metrics,
    forgetting_metrics,
)

sys.path.append("../base_code/")

from base_code.constants import DATASETS_PATH, SAVED_METRICS_PATH
from base_code.seed import set_seed
from base_code.training import MWUNV1


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Run configuration: compute device, benchmark size, seed and metrics output location.
TORCH_DEVICE = device("cuda") if cuda.is_available() else device("cpu")  # GPU when available
N_EXPERIENCES = 5  # number of experiences the SplitMNIST benchmark is divided into
EXPERIMENT_SEED = 1234  # used by set_seed below and also passed to SplitMNIST as its seed

# Pickle destination, relative to SAVED_METRICS_PATH ("split_mnist/proposal_3.pkl").
METRICS_SUBFOLDER = "split_mnist"
METRICS_FILENAME = f"{METRICS_SUBFOLDER}/proposal_3.pkl"

set_seed(EXPERIMENT_SEED)


In [4]:
def _repeat_to_three_channels(image):
    """Tile a single-channel image tensor along dim 0 so it has 3 channels.

    Assumes ``image`` is a (1, H, W) tensor -- TODO confirm against the dataset output.
    No PIL conversion happens here; the input must already support ``.repeat``.
    """
    return image.repeat(3, 1, 1)


# Identical channel-tiling pipelines for training and evaluation.
# NOTE(review): passing these to SplitMNIST presumably replaces its default transforms
# (including any normalization) -- confirm that dropping them is intended.
train_transform = transforms.Compose([transforms.Lambda(_repeat_to_three_channels)])
test_transform = transforms.Compose([transforms.Lambda(_repeat_to_three_channels)])


In [5]:
# Build the SplitMNIST benchmark with N_EXPERIENCES experiences, a fixed seed, and the
# channel-tiling transforms for both training and evaluation data.
scenario = SplitMNIST(
    N_EXPERIENCES,
    seed=EXPERIMENT_SEED,
    dataset_root=DATASETS_PATH,
    train_transform=train_transform,
    eval_transform=test_transform,
)
train_stream, test_stream = scenario.train_stream, scenario.test_stream


In [6]:
# Evaluation: accuracy, forgetting and backward transfer, each reported per experience
# and over the whole stream, printed by the interactive logger.
metric_scope = dict(experience=True, stream=True)
eval_plugin = EvaluationPlugin(
    accuracy_metrics(**metric_scope),
    forgetting_metrics(**metric_scope),
    bwt_metrics(**metric_scope),
    loggers=[InteractiveLogger()],
)

# Network, optimizer and loss; the model and loss are moved to TORCH_DEVICE.
model = SimpleCNN(num_classes=scenario.n_classes).to(TORCH_DEVICE)
optimizer = SGD(
    model.parameters(),
    lr=1e-3,
    momentum=0.9,
    weight_decay=0.0005,
)
criterion = CrossEntropyLoss().to(TORCH_DEVICE)

# Hyperparameters specific to MWUNV1; their meaning is defined in base_code.training
# (not visible here).
mwun_hparams = dict(eps=0.001, lambda_e=10.0, lambda_f=10.0)
strategy = MWUNV1(
    model,
    optimizer,
    criterion,
    **mwun_hparams,
    train_epochs=5,
    train_mb_size=128,
    eval_mb_size=128,
    evaluator=eval_plugin,
    device=TORCH_DEVICE,
)


In [7]:
# Train on the experiences in order. After each one, evaluate on the *entire* test stream
# so accuracy on already-seen and not-yet-seen experiences is logged at every step.
for current_experience in tqdm(train_stream):
    strategy.train(current_experience)
    strategy.eval(test_stream)


  0%|          | 0/5 [00:00<?, ?it/s]

-- >> Start of training phase << --
100%|██████████| 88/88 [00:13<00:00,  6.36it/s]
Epoch 0 ended.
100%|██████████| 88/88 [00:13<00:00,  6.63it/s]
Epoch 1 ended.
100%|██████████| 88/88 [00:13<00:00,  6.66it/s]
Epoch 2 ended.
100%|██████████| 88/88 [00:13<00:00,  6.64it/s]
Epoch 3 ended.
100%|██████████| 88/88 [00:13<00:00,  6.45it/s]
Epoch 4 ended.
-- >> End of training phase << --
-- >> Start of eval phase << --
-- Starting eval on experience 0 (Task 0) from test stream --
100%|██████████| 15/15 [00:00<00:00, 30.22it/s]
> Eval on experience 0 (Task 0) from test stream ended.
	Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000 = 0.5998
-- Starting eval on experience 1 (Task 0) from test stream --
100%|██████████| 16/16 [00:00<00:00, 29.99it/s]
> Eval on experience 1 (Task 0) from test stream ended.
	Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001 = 0.0000
-- Starting eval on experience 2 (Task 0) from test stream --
100%|██████████| 16/16 [00:00<00:00, 30.64it/s]
> Eval on experienc

 20%|██        | 1/5 [01:09<04:39, 69.90s/it]


> Eval on experience 4 (Task 0) from test stream ended.
	Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp004 = 0.0000
-- >> End of eval phase << --
	StreamBWT/eval_phase/test_stream = 0.0000
	StreamForgetting/eval_phase/test_stream = 0.0000
	Top1_Acc_Stream/eval_phase/test_stream/Task000 = 0.1124
-- >> Start of training phase << --
100%|██████████| 93/93 [00:13<00:00,  6.73it/s]
Epoch 0 ended.
100%|██████████| 93/93 [00:14<00:00,  6.64it/s]
Epoch 1 ended.
100%|██████████| 93/93 [00:13<00:00,  6.66it/s]
Epoch 2 ended.
100%|██████████| 93/93 [00:14<00:00,  6.62it/s]
Epoch 3 ended.
100%|██████████| 93/93 [00:14<00:00,  6.63it/s]
Epoch 4 ended.
-- >> End of training phase << --
-- >> Start of eval phase << --
-- Starting eval on experience 0 (Task 0) from test stream --
100%|██████████| 15/15 [00:00<00:00, 29.95it/s]
> Eval on experience 0 (Task 0) from test stream ended.
	ExperienceBWT/eval_phase/test_stream/Task000/Exp000 = -0.5998
	ExperienceForgetting/eval_phase/test_stream/Task000/Exp

 40%|████      | 2/5 [02:22<03:34, 71.48s/it]


> Eval on experience 4 (Task 0) from test stream ended.
	Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp004 = 0.0000
-- >> End of eval phase << --
	StreamBWT/eval_phase/test_stream = -0.5998
	StreamForgetting/eval_phase/test_stream = 0.5998
	Top1_Acc_Stream/eval_phase/test_stream/Task000 = 0.0974
-- >> Start of training phase << --
100%|██████████| 93/93 [00:13<00:00,  6.69it/s]
Epoch 0 ended.
100%|██████████| 93/93 [00:13<00:00,  6.65it/s]
Epoch 1 ended.
100%|██████████| 93/93 [00:14<00:00,  6.60it/s]
Epoch 2 ended.
100%|██████████| 93/93 [00:14<00:00,  6.63it/s]
Epoch 3 ended.
100%|██████████| 93/93 [00:14<00:00,  6.57it/s]
Epoch 4 ended.
-- >> End of training phase << --
-- >> Start of eval phase << --
-- Starting eval on experience 0 (Task 0) from test stream --
100%|██████████| 15/15 [00:00<00:00, 28.70it/s]
> Eval on experience 0 (Task 0) from test stream ended.
	ExperienceBWT/eval_phase/test_stream/Task000/Exp000 = -0.5998
	ExperienceForgetting/eval_phase/test_stream/Task000/Ex

 60%|██████    | 3/5 [03:35<02:24, 72.17s/it]


> Eval on experience 4 (Task 0) from test stream ended.
	Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp004 = 0.0000
-- >> End of eval phase << --
	StreamBWT/eval_phase/test_stream = -0.5427
	StreamForgetting/eval_phase/test_stream = 0.5427
	Top1_Acc_Stream/eval_phase/test_stream/Task000 = 0.0958
-- >> Start of training phase << --
100%|██████████| 102/102 [00:15<00:00,  6.58it/s]
Epoch 0 ended.
100%|██████████| 102/102 [00:15<00:00,  6.42it/s]
Epoch 1 ended.
100%|██████████| 102/102 [00:15<00:00,  6.45it/s]
Epoch 2 ended.
100%|██████████| 102/102 [00:15<00:00,  6.44it/s]
Epoch 3 ended.
100%|██████████| 102/102 [00:15<00:00,  6.48it/s]
Epoch 4 ended.
-- >> End of training phase << --
-- >> Start of eval phase << --
-- Starting eval on experience 0 (Task 0) from test stream --
100%|██████████| 15/15 [00:00<00:00, 29.42it/s]
> Eval on experience 0 (Task 0) from test stream ended.
	ExperienceBWT/eval_phase/test_stream/Task000/Exp000 = -0.5998
	ExperienceForgetting/eval_phase/test_stream/

 80%|████████  | 4/5 [04:57<01:15, 75.88s/it]


> Eval on experience 4 (Task 0) from test stream ended.
	Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp004 = 0.0000
-- >> End of eval phase << --
	StreamBWT/eval_phase/test_stream = -0.5241
	StreamForgetting/eval_phase/test_stream = 0.5241
	Top1_Acc_Stream/eval_phase/test_stream/Task000 = 0.1135
-- >> Start of training phase << --
100%|██████████| 95/95 [00:14<00:00,  6.38it/s]
Epoch 0 ended.
100%|██████████| 95/95 [00:14<00:00,  6.45it/s]
Epoch 1 ended.
100%|██████████| 95/95 [00:15<00:00,  6.24it/s]
Epoch 2 ended.
100%|██████████| 95/95 [00:15<00:00,  6.15it/s]
Epoch 3 ended.
100%|██████████| 95/95 [00:15<00:00,  6.27it/s]
Epoch 4 ended.
-- >> End of training phase << --
-- >> Start of eval phase << --
-- Starting eval on experience 0 (Task 0) from test stream --
100%|██████████| 15/15 [00:00<00:00, 27.13it/s]
> Eval on experience 0 (Task 0) from test stream ended.
	ExperienceBWT/eval_phase/test_stream/Task000/Exp000 = -0.5998
	ExperienceForgetting/eval_phase/test_stream/Task000/Ex

100%|██████████| 5/5 [06:15<00:00, 75.09s/it]


> Eval on experience 4 (Task 0) from test stream ended.
	Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp004 = 0.4925
-- >> End of eval phase << --
	StreamBWT/eval_phase/test_stream = -0.5243
	StreamForgetting/eval_phase/test_stream = 0.5243
	Top1_Acc_Stream/eval_phase/test_stream/Task000 = 0.0980





In [8]:
# save metrics
# Collect the metric histories logged by eval_plugin and pickle them to
# SAVED_METRICS_PATH / METRICS_FILENAME as {"accuracies", "forgettings", "bwt"}.
# get_all_metrics() maps each metric name to an (x_values, y_values) pair; only the
# y values (index 1) are stored. The "Task{i}" keys index the *experiences* of the stream
# (all of them are logged under Task000); the key names are kept for existing readers.
all_metrics = eval_plugin.get_all_metrics()


def _metric_history(metric_name: str) -> list[float]:
    """Return the logged values of ``metric_name``, or an empty list if none were logged.

    Uses ``.get`` so a missing metric neither raises (if the mapping is a plain dict) nor
    inserts an empty entry into the plugin's own results (if it is a defaultdict).
    """
    return all_metrics.get(metric_name, ([], []))[1]


accuracies: dict[str, list[float]] = {}
forgettings: dict[str, list[float]] = {}
bwt: dict[str, list[float]] = {}

for i in range(N_EXPERIENCES):
    exp_key = f"Task000/Exp{i:03d}"
    accuracies[f"Task{i}"] = _metric_history(
        f"Top1_Acc_Exp/eval_phase/test_stream/{exp_key}"
    )
    # NOTE(review): forgetting/BWT for the last trained experience do not show up in the
    # training log, so those entries are expected to be empty lists -- confirm.
    forgettings[f"Task{i}"] = _metric_history(
        f"ExperienceForgetting/eval_phase/test_stream/{exp_key}"
    )
    bwt[f"Task{i}"] = _metric_history(
        f"ExperienceBWT/eval_phase/test_stream/{exp_key}"
    )

accuracies["Overall"] = _metric_history(
    "Top1_Acc_Stream/eval_phase/test_stream/Task000"
)

metrics_path = SAVED_METRICS_PATH / METRICS_FILENAME
# METRICS_FILENAME contains a subfolder that may not exist yet; create it before writing.
metrics_path.parent.mkdir(parents=True, exist_ok=True)
# `with` guarantees the file handle is flushed and closed.
with open(metrics_path, "wb") as metrics_file:
    pickle.dump(
        {"accuracies": accuracies, "forgettings": forgettings, "bwt": bwt},
        metrics_file,
    )
