In [None]:
%load_ext autoreload
%autoreload 2

# Libraries

In [None]:
from torchvision.transforms import Compose, ToTensor
from torchsummary import summary
import pickle
import pandas as pd
import plotly.express as px
from tqdm import tqdm

In [None]:
from avalanche.benchmarks.classic import PermutedMNIST
from avalanche.models import SimpleMLP
from avalanche.training import EWC
from avalanche.training.plugins import EvaluationPlugin
from avalanche.logging import InteractiveLogger
from avalanche.evaluation.metrics import accuracy_metrics, forgetting_metrics, bwt_metrics, timing_metrics, cpu_usage_metrics, ram_usage_metrics

from torch.nn import CrossEntropyLoss
from torch.optim import SGD

## Custom Libraries

In [None]:
import sys

# Extend the module search path so the project's own code can be imported
# from this notebook.
# NOTE(review): the entry appended here is "../base_code/" itself, while the
# imports below address the `base_code` package; that only resolves if the
# directory *containing* base_code is importable by some other means — confirm.
sys.path.append("../base_code/")

from base_code.constants import DATASETS_PATH, SAVED_METRICS_PATH
from base_code.plugins import WeightStoragePlugin

# Dataset and definitions

## Dataset loading

We load the Permuted MNIST benchmark, which is built on top of the MNIST (Modified NIST) handwritten-digit dataset.

In [None]:
# Build the Permuted MNIST benchmark; the seed is fixed so the benchmark is
# reproducible, and the data is read from / stored under DATASETS_PATH.
scenario = PermutedMNIST(
    10,
    seed=1234,
    dataset_root=DATASETS_PATH,
)

## Scenario creation with train test streaming

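At this point, we define our scenario. In Permuted MNIST every experience contains all ten digit classes, but each experience applies its own fixed random permutation to the image pixels. That is, we first train on the data permuted with permutation $a$, and in the following experience we train on the data permuted with permutation $b$ ($a \neq b$).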

In [None]:
# Pull both experience streams out of the benchmark in one step.
train_stream, test_stream = scenario.train_stream, scenario.test_stream

## Evaluation metrics definition

In [None]:
# Metric groups tracked during training/evaluation: accuracy, forgetting,
# backward transfer, timing, CPU usage and RAM usage.
tracked_metrics = [
    accuracy_metrics(experience=True, stream=True, trained_experience=True),
    forgetting_metrics(experience=True, stream=True),
    bwt_metrics(experience=True, stream=True),
    timing_metrics(epoch=True, epoch_running=True),
    cpu_usage_metrics(experience=True, stream=True),
    ram_usage_metrics(experience=True, stream=True),
]

# Results are reported through the interactive logger.
eval_plugin = EvaluationPlugin(*tracked_metrics, loggers=[InteractiveLogger()])

## Plugin definitions

In [None]:
# Plugins attached to the strategy; the weight-storage plugin is kept first
# because it is later retrieved by position (model_plugins[0]).
weight_storage_plugin = WeightStoragePlugin()
model_plugins = [weight_storage_plugin]

## Model, Optimizer, Loss function and Strategy definition

* `model`: Multi Layer Perceptron
* `Optimizer`: Stochastic Gradient Descent (SGD)
* `Loss function`: Cross Entropy
* `Strategy`: Elastic Weight Consolidation

In [None]:
# MLP classifier over flattened 28x28 inputs (hidden_layers=2, hidden_size=100).
model = SimpleMLP(
    num_classes=scenario.n_classes,
    input_size=28 * 28,
    hidden_layers=2,
    hidden_size=100,
)

# Plain SGD over all model parameters, trained with cross-entropy loss.
optimizer = SGD(model.parameters(), lr=1e-3)
criterion = CrossEntropyLoss()

# EWC strategy configuration: regularisation strength, epochs per experience,
# mini-batch sizes, and the plugins / evaluator defined in earlier cells.
ewc_settings = dict(
    ewc_lambda=1.0,
    train_epochs=5,
    train_mb_size=128,
    plugins=model_plugins,
    evaluator=eval_plugin,
    eval_mb_size=128,
)
strategy = EWC(model, optimizer, criterion, **ewc_settings)

# Training and evaluation

**TODO:** review why the training is behaving strangely.

In [None]:
# Collected evaluation outputs: two entries per experience, first the
# train-stream evaluation and then the test-stream evaluation.
results = []

for experience in tqdm(train_stream):
    strategy.train(experience)

    # After each training experience, evaluate on the whole train stream
    # and then on the test stream, keeping both result dictionaries.
    for stream in (train_stream, test_stream):
        results.append(strategy.eval(stream))

# Get metrics

## Training Accuracies

In [None]:
# Top-1 accuracy measured on the train stream, per experience and overall.
# Each metric entry is indexed with [1] to keep only its second element
# (presumably the list of recorded values, the first being the matching
# step indices — TODO confirm against EvaluationPlugin.get_all_metrics()).
all_metrics = eval_plugin.get_all_metrics()  # fetch once, not once per lookup

# Keys are task labels ("Task0", ..., "Overall"), hence `str` keys.
# The experience id is zero-padded to three digits (Exp000, Exp001, ...).
training_accuracies: dict[str, list[float]] = {
    f"Task{i}": all_metrics[
        f"Top1_Acc_Exp/eval_phase/train_stream/Task000/Exp{i:03d}"
    ][1]
    for i in range(10)
}

training_accuracies["Overall"] = all_metrics[
    "Top1_Acc_Stream/eval_phase/train_stream/Task000"
][1]

## Evaluation Accuracies

In [None]:
# Top-1 accuracy measured on the test stream, per experience and overall.
# Each metric entry is indexed with [1] to keep only its second element
# (presumably the list of recorded values, the first being the matching
# step indices — TODO confirm against EvaluationPlugin.get_all_metrics()).
all_metrics = eval_plugin.get_all_metrics()  # fetch once, not once per lookup

# Keys are task labels ("Task0", ..., "Overall"), hence `str` keys.
# The experience id is zero-padded to three digits (Exp000, Exp001, ...).
accuracies: dict[str, list[float]] = {
    f"Task{i}": all_metrics[
        f"Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp{i:03d}"
    ][1]
    for i in range(10)
}

accuracies["Overall"] = all_metrics[
    "Top1_Acc_Stream/eval_phase/test_stream/Task000"
][1]

## Forgetting measure

In [None]:
# Forgetting measured on the test stream, per experience and overall.
# Only experiences 0-8 are read here; no entry is looked up for the last one.
# Each metric entry is indexed with [1] to keep only its second element
# (presumably the list of recorded values — TODO confirm against
# EvaluationPlugin.get_all_metrics()).
all_metrics = eval_plugin.get_all_metrics()  # fetch once, not once per lookup

# Keys are task labels ("Task0", ..., "Overall"), hence `str` keys.
# The experience id is zero-padded to three digits (Exp000, Exp001, ...).
forgetting_measures: dict[str, list[float]] = {
    f"Task{i}": all_metrics[
        f"ExperienceForgetting/eval_phase/test_stream/Task000/Exp{i:03d}"
    ][1]
    for i in range(9)
}
forgetting_measures["Overall"] = all_metrics[
    "StreamForgetting/eval_phase/test_stream"
][1]

## Backward Transfer

In [None]:
# Backward transfer (BWT) measured on the test stream, per experience and overall.
# Only experiences 0-8 are read here; no entry is looked up for the last one.
# Each metric entry is indexed with [1] to keep only its second element
# (presumably the list of recorded values — TODO confirm against
# EvaluationPlugin.get_all_metrics()).
all_metrics = eval_plugin.get_all_metrics()  # fetch once, not once per lookup

# Keys are task labels ("Task0", ..., "Overall"), hence `str` keys.
# The experience id is zero-padded to three digits (Exp000, Exp001, ...).
bwts: dict[str, list[float]] = {
    f"Task{i}": all_metrics[
        f"ExperienceBWT/eval_phase/test_stream/Task000/Exp{i:03d}"
    ][1]
    for i in range(9)
}
bwts["Overall"] = all_metrics[
    "StreamBWT/eval_phase/test_stream"
][1]

# Plotting metrics

## Training accuracies

In [None]:
# One column per task plus "Overall"; one row per evaluation round (0-9).
train_df = pd.DataFrame(training_accuracies, index=range(10))

# Line chart of every column against the evaluation round, y-axis fixed to [0, 1].
fig = px.line(
    train_df,
    x=train_df.index,
    y=train_df.columns,
    range_y=[0, 1],
    title="Training Accuracy vs Task",
)
fig.show()

## Evaluation accuracies per experience

In [None]:
# One column per task plus "Overall"; one row per evaluation round (0-9).
acc_df = pd.DataFrame(accuracies, index=range(10))

# Line chart of every column against the evaluation round, y-axis fixed to [0, 1].
fig = px.line(
    acc_df,
    x=acc_df.index,
    y=acc_df.columns,
    range_y=[0, 1],
    title="Test Accuracy vs Task",
)
fig.show()

## Forgetting measure / BWT

In [None]:
# The forgetting series have different lengths, so every series is
# left-padded with None up to the longest one before building the DataFrame.
# List concatenation creates new lists, so `forgetting_measures` itself is
# left untouched and no deep copy is needed.
max_len = max(map(len, forgetting_measures.values()))
forgetting_padded = {
    task: [None] * (max_len - len(values)) + values
    for task, values in forgetting_measures.items()
}

# Index by evaluation round; the length follows the data instead of being
# hard-coded, so the cell keeps working if the number of rounds changes.
forgetting_df = pd.DataFrame(forgetting_padded, index=range(max_len))

fig = px.line(
    forgetting_df,
    x=forgetting_df.index,
    y=forgetting_df.columns,
    title="Forgetting vs Task",
)
fig.show()

# Store metrics

In [None]:
# Persist every metric gathered by the evaluation plugin.
# The context manager guarantees the file is flushed and closed even if
# pickling fails (the bare open() call used before leaked the handle).
all_metrics = eval_plugin.get_all_metrics()
with open(SAVED_METRICS_PATH / "pmnist" / "ewc.pkl", "wb") as metrics_file:
    pickle.dump(all_metrics, metrics_file)

# Store weights

In [None]:
# Persist the weights recorded by the first plugin in `model_plugins`.
# The context manager guarantees the file is flushed and closed even if
# pickling fails (the bare open() call used before leaked the handle).
stored_weights = model_plugins[0].weights
with open(SAVED_METRICS_PATH / "pmnist" / "ewc_weights.pkl", "wb") as weights_file:
    pickle.dump(stored_weights, weights_file)