In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from avalanche.benchmarks import SplitMNIST
from avalanche.training.supervised import EWC
from avalanche.training.plugins import EvaluationPlugin
from avalanche.logging import TensorboardLogger
from torch.utils.tensorboard import SummaryWriter
from avalanche.evaluation.metrics import accuracy_metrics, loss_metrics, timing_metrics, forgetting_metrics
from avalanche.benchmarks.classic import CORe50
from avalanche.logging import InteractiveLogger


  from .autonotebook import tqdm as notebook_tqdm


In [2]:


class SimpleCNN(nn.Module):
    def __init__(self, num_classes=50):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(128 * 16 * 16, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = x.view(-1, 128 * 16 * 16)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


In [8]:
benchmark = CORe50(scenario="nc")

Loading labels...
Loading LUP...
Loading labels names...
Files already downloaded and verified


In [9]:
import numpy as np
from torch.utils.data import Subset

# Definizione delle classi che vogliamo includere
target_classes = list(range(10))  # Utilizziamo solo le prime 10 classi

# Funzione per filtrare le classi desiderate
def filter_classes(dataset, target_classes):
    targets = np.array(dataset.targets)
    mask = np.isin(targets, target_classes)
    indices = np.where(mask)[0]
    return Subset(dataset, indices)

# Creazione del train e test stream con classi filtrate
train_datasets = [filter_classes(exp.dataset, target_classes) for exp in benchmark.train_stream]
test_datasets = [filter_classes(exp.dataset, target_classes) for exp in benchmark.test_stream]


In [10]:


# Creazione del logger di TensorBoard
tb_logger = TensorboardLogger()

# Plugin per la valutazione
eval_plugin = EvaluationPlugin(    
    accuracy_metrics(minibatch=True, epoch=True, experience=True, stream=True),
    loss_metrics(minibatch=True, epoch=True, experience=True, stream=True),
    timing_metrics(epoch=True, epoch_running=True),
    forgetting_metrics(experience=True, stream=True),\
    loggers=[tb_logger])
"""
# Definizione del plugin di valutazione con TensorBoard
eval_plugin = EvaluationPlugin(
    accuracy_metrics(minibatch=True, epoch=True, experience=True, stream=True),
    loss_metrics(minibatch=True, epoch=True, experience=True, stream=True),
    loggers=[tb_logger]
)
"""
# Lista dei modelli
models = [SimpleCNN(num_classes=len(target_classes))]



In [11]:

# Iterazione sui modelli e allenamento
for model in models:
    
    print(f"Training model: {model.__class__.__name__}")

    # Definizione della strategia EWC
    cl_strategy = EWC(
        model=model,
        optimizer=torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9),
        criterion=torch.nn.CrossEntropyLoss(),
        ewc_lambda=0.4,  # lambda per regolarizzazione EWC
        train_mb_size=32,
        train_epochs=1,
        eval_mb_size=100,
        device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
        evaluator=eval_plugin
    )
    # Esecuzione dell'allenamento e valutazione con i dataset filtrati
    for train_dataset, test_dataset in zip(train_datasets, test_datasets):
        print(f"Training on dataset with {len(train_dataset)} samples")
        cl_strategy.train(train_dataset)
        print("Training completed")
        
        print("Evaluating on test dataset")
        cl_strategy.eval(test_dataset)
        print("Evaluation completed")


Training model: SimpleCNN
Training on dataset with 23980 samples


AttributeError: 'Subset' object has no attribute 'origin_stream'