In [None]:
%load_ext autoreload
%autoreload 2

# Libraries

In [None]:
from torchvision.transforms import Compose, ToTensor
from torchsummary import summary
import pickle
import pandas as pd
import plotly.express as px
from tqdm import tqdm

In [None]:
from avalanche.benchmarks.classic import PermutedMNIST
from avalanche.models import SimpleMLP
from avalanche.training.plugins import EvaluationPlugin
from avalanche.logging import InteractiveLogger
from avalanche.evaluation.metrics import accuracy_metrics

from torch.nn import CrossEntropyLoss
from torch.optim import SGD

## Custom Libraries

In [None]:
import sys

sys.path.append("../base_code/")

from base_code.constants import DATASETS_PATH, SAVED_METRICS_PATH
from base_code.training import CEWCV1
from base_code.plugins import WeightStoragePlugin

# Dataset and definitions

## Preprocessing definitions

In [None]:
# Train and test images go through the same preprocessing: a single
# ToTensor() step wrapped in a Compose pipeline.
train_transform = Compose([ToTensor()])
test_transform = Compose([ToTensor()])

## Dataset loading

We load the Permuted MNIST benchmark, which is built on the MNIST (Modified NIST) handwritten-digit dataset.

In [None]:
# Build the Permuted MNIST benchmark directly (no manual MNIST loading):
# 10 is passed as the first positional argument, the seed is fixed so the
# run is reproducible, and the data lives under DATASETS_PATH.
scenario = PermutedMNIST(
    10,
    seed=1234,
    dataset_root=DATASETS_PATH,
)

## Scenario creation with train test streaming

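At this point we take the train and test streams from the scenario. In Permuted MNIST every experience contains all ten digit classes, but the image pixels are shuffled with a permutation that is fixed within an experience and differs between experiences. (In a class-incremental setting, by contrast, each experience would introduce a new class: first we would train with a class $a$, and in the following experience with a class $b$, where $a \neq b$.)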

In [None]:
# Pull the training and test streams out of the scenario in one step.
train_stream, test_stream = scenario.train_stream, scenario.test_stream

## Evaluation metrics definition

In [None]:
# Evaluator that tracks accuracy both per experience and over the whole
# stream, reporting through an InteractiveLogger.
eval_plugin = EvaluationPlugin(
    accuracy_metrics(stream=True, experience=True),
    loggers=[InteractiveLogger()],
)

## Plugin definitions

In [None]:
# List of extra plugins for this run; it holds a single WeightStoragePlugin.
# NOTE(review): no later cell visible in this notebook references
# `model_plugins` — confirm whether it should be handed to the strategy.
model_plugins = [WeightStoragePlugin()]

## Model, Optimizer, Loss function and Strategy definition

* `model`: Multi Layer Perceptron
* `Optimizer`: SGD (stochastic gradient descent)
* `Loss function`: Cross Entropy
* `Strategy`: Elastic Weight Consolidation

In [None]:
# Multi-layer perceptron over flattened 28x28 inputs, with one output per
# class in the scenario.
model = SimpleMLP(
    num_classes=scenario.n_classes,
    input_size=28 * 28,
    hidden_layers=2,
    hidden_size=100,
)

# Plain SGD over all model parameters, paired with a cross-entropy loss.
optimizer = SGD(model.parameters(), lr=1e-3)
criterion = CrossEntropyLoss()

# Custom CEWCV1 strategy from base_code.training.
# NOTE(review): ewc_lambda_l1 / ewc_lambda_l2 are presumably the weights of
# two penalty terms — confirm their meaning against CEWCV1.
strategy = CEWCV1(
    model,
    optimizer,
    criterion,
    ewc_lambda_l1=1.0,
    ewc_lambda_l2=1.0,
    train_epochs=5,
    train_mb_size=128,
    eval_mb_size=128,
    evaluator=eval_plugin,
)

# Training and evaluation

**TODO:** Review why the training is behaving strangely.

In [None]:
# Evaluation output collected after each training experience.
results = []

# Only the first training experience is used here (slice [:1]); after
# training on it, the strategy is evaluated on the complete test stream.
selected_experiences = train_stream[:1]
for experience in tqdm(selected_experiences):
    strategy.train(experience)
    results.append(strategy.eval(test_stream))

In [None]:
# Show whatever the strategy's get_store_loss() returns as the cell output.
# NOTE(review): presumably the loss values recorded during training —
# confirm against the CEWCV1 implementation.
strategy.get_store_loss()

In [None]:
# Accuracy history per column label. Keys are strings ("Task0" ... "Task9"
# and "Overall"), so the mapping is annotated with `str` keys.
accuracies: dict[str, list[float]] = {}

# Query the evaluator once instead of on every loop iteration; the result is
# a mapping from metric name to its recorded history.
all_metrics = eval_plugin.get_all_metrics()

for i in range(10):
    # Zero-pad the experience id through a format spec rather than a
    # hard-coded "00" prefix, so the key stays three digits wide.
    metric_name = f"Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp{i:03d}"
    # Element [1] of a metric entry is taken as the list of recorded values
    # (presumably entry [0] holds the matching step counters — confirm);
    # keep at most the 10 most recent values.
    accuracies[f"Task{i}"] = all_metrics[metric_name][1][-10:]

# Stream-level accuracy over the whole test stream, same trailing window.
accuracies["Overall"] = all_metrics[
    "Top1_Acc_Stream/eval_phase/test_stream/Task000"
][1][-10:]

In [None]:
# One column per entry of `accuracies`, one row per recorded evaluation.
acc_df = pd.DataFrame(accuracies)

# Number the rows 0..n-1. The length comes from the frame itself instead of a
# hard-coded 10: assigning range(10) raises a length-mismatch ValueError
# whenever fewer than 10 evaluations were recorded (e.g. when training ran on
# only part of the stream). With exactly 10 rows the result is unchanged.
acc_df.index = range(len(acc_df))

In [None]:
# Line chart with one trace per column of acc_df, plotted against the row
# index; the y axis is pinned to [0, 1] so runs are visually comparable.
fig = px.line(acc_df, x=acc_df.index, y=acc_df.columns, range_y=[0, 1])

# Give the figure a title, axis titles and a legend title so it can be read
# on its own when the notebook is skimmed.
fig.update_layout(
    title="CEWC v1 on Permuted MNIST: test accuracy per evaluation",
    xaxis_title="Evaluation round",
    yaxis_title="Top-1 accuracy",
    legend_title_text="Test experience",
)
fig.show()

# Store metrics

In [None]:
# Persist the accuracy histories to SAVED_METRICS_PATH / "cewc_v1.pkl".
# The context manager guarantees the file handle is flushed and closed even
# if pickling fails; a bare open() inside the call would leave it open.
with open(SAVED_METRICS_PATH / "cewc_v1.pkl", "wb") as metrics_file:
    pickle.dump(accuracies, metrics_file)