# Avalanche


In [1]:
#!pip install avalanche-lib[all]

In [2]:
from datetime import datetime, timezone

import torch
from avalanche.benchmarks.classic import SplitCIFAR10
from avalanche.evaluation.metrics import (
    accuracy_metrics,
    forgetting_metrics,
    loss_metrics,
)
from avalanche.evaluation.metrics.checkpoint import WeightCheckpoint
from avalanche.logging import InteractiveLogger, WandBLogger
from avalanche.models import SimpleMLP
from avalanche.training.plugins import EarlyStoppingPlugin, EvaluationPlugin
from avalanche.training.supervised import EWC


ModuleNotFoundError: No module named 'avalanche'

In [None]:
# Scenario: the Split CIFAR-10 benchmark, divided into 5 experiences.
# shuffle=False is passed so the class order is presumably deterministic
# (no random permutation) across notebook runs — confirm against Avalanche docs.
benchmark = SplitCIFAR10(
    shuffle=False,
    n_experiences=5,
)


In [None]:
# Optimization settings shared by every run of the EWC sweep.
batch_size, epochs, lr, momentum = (
    32,    # mini-batch size, used for both training and evaluation
    1,     # training epochs per experience
    0.01,  # SGD learning rate
    0.9,   # SGD momentum
)


In [None]:
# Hyper-parameter sweep over EWC's regularization strength: 10 evenly spaced
# values in [280, 300]. Each value trains a fresh model on the whole train
# stream, evaluates on the full test stream after every experience, and logs
# to its own Weights & Biases run.
device = "cuda" if torch.cuda.is_available() else "cpu"  # don't crash on CPU-only machines
all_results = {}  # ewc_lambda (float) -> list with one eval result per trained experience
# .tolist() yields plain Python floats, so the same value is used for the
# W&B config and for the EWC strategy (previously a 0-d tensor was logged).
for ewc_lambda in torch.linspace(280, 300, 10).tolist():
    model = SimpleMLP(num_classes=benchmark.n_classes, input_size=3 * 32 * 32)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    criterion = torch.nn.CrossEntropyLoss()
    interactive_logger = InteractiveLogger()
    wandb_logger = WandBLogger(
        project_name="continual learning",
        run_name=f"avalanche_{datetime.now(tz=timezone.utc).strftime('%Y_%m_%d_%H_%M_%S')}",
        log_artifacts=True,
        config={
            "ewc_lambda": ewc_lambda,
            "batch_size": batch_size,
            "epochs": epochs,
            "lr": lr,
            "optimizer": optimizer,
            "momentum": momentum,
            "model": model,
        },
    )
    try:
        eval_plugin = EvaluationPlugin(
            accuracy_metrics(
                minibatch=False,
                epoch=True,
                epoch_running=False,
                experience=True,
                stream=True,
                trained_experience=True,
            ),
            loss_metrics(
                minibatch=False,
                epoch=False,
                epoch_running=True,
                experience=False,
                stream=False,
            ),
            forgetting_metrics(experience=True, stream=True),
            WeightCheckpoint(),
            loggers=[interactive_logger, wandb_logger],
        )
        # NOTE(review): an EarlyStoppingPlugin used to be constructed here but was
        # never passed to the strategy (no `plugins=[...]`), so it had no effect.
        # With epochs=1 there is also no later epoch to stop early. If early
        # stopping is wanted, re-add it via `plugins=[...]`, presumably together
        # with periodic evaluation on a validation stream — confirm in Avalanche docs.
        cl_strategy = EWC(
            model=model,
            optimizer=optimizer,
            criterion=criterion,
            ewc_lambda=ewc_lambda,
            mode="separate",
            train_mb_size=batch_size,
            train_epochs=epochs,
            eval_mb_size=batch_size,
            device=device,
            evaluator=eval_plugin,
        )
        results = []
        for experience in benchmark.train_stream:
            print("Start of experience: ", experience.current_experience)
            print("Current Classes: ", experience.classes_in_this_experience)

            cl_strategy.train(experience)
            print("Training completed")

            print("Computing accuracy on the whole test set")
            results.append(cl_strategy.eval(benchmark.test_stream))
        # Keep every lambda's results; `results` alone only holds the last run.
        all_results[ewc_lambda] = results
    finally:
        # Close this W&B run so the next lambda starts a fresh run instead of
        # logging into the still-active one (also closes it if training raised).
        wandb_logger.wandb.finish()
    # The strategy and optimizer also reference the model; drop all three so
    # its memory can actually be reclaimed before the next iteration.
    del cl_strategy, optimizer, model
