In [None]:
!pip install wandb -qU

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m207.3/207.3 kB[0m [31m13.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m267.1/267.1 kB[0m [31m16.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.7/62.7 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
from __future__ import annotations

from copy import deepcopy
from datetime import datetime, timezone
from pathlib import Path

import torch
import torch.nn.functional as F
import torchvision
from torch import nn, optim
from torch.utils.data import ConcatDataset, DataLoader, Dataset, Subset
from torchvision import transforms
from tqdm import tqdm

import wandb

## Weights & Biases Login


In [None]:
# Authenticate this session with Weights & Biases (the recorded output shows it prompting for an API key).
wandb.login()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

## Model


### ConvNeXtV2


The ConvNeXtV2 source code was taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py. Drop path and the custom weight initialization were removed, and a variable patch size was added.


In [None]:
class LayerNorm(nn.Module):
    """Layer normalisation over the channel axis for two tensor layouts.

    ``channels_last`` (the default) expects the channel axis to be the last one,
    e.g. (batch_size, height, width, channels); ``channels_first`` expects it at
    position 1, e.g. (batch_size, channels, height, width).
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        """Create the learnable scale/shift and remember the layout.

        Args:
            normalized_shape (int): Number of channels to normalise over.
            eps (float): Value added to the variance for numerical stability.
            data_format (str): Either "channels_last" or "channels_first".

        Raises:
            NotImplementedError: If ``data_format`` is not one of the two layouts.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        supported_formats = ("channels_last", "channels_first")
        if data_format not in supported_formats:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        """Normalise ``x`` along its channel axis and apply scale and shift."""
        if self.data_format == "channels_last":
            # the built-in kernel already normalises over the trailing axis
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        if self.data_format == "channels_first":
            # manual normalisation over axis 1 (biased variance, like F.layer_norm)
            mean = x.mean(1, keepdim=True)
            centered = x - mean
            variance = centered.pow(2).mean(1, keepdim=True)
            normalized = centered / torch.sqrt(variance + self.eps)
            # broadcast the per-channel parameters over the two trailing axes
            return self.weight[:, None, None] * normalized + self.bias[:, None, None]
        # only reachable if data_format was changed after construction
        return None


class GRN(nn.Module):
    """Global Response Normalization (GRN) layer.

    The forward pass indexes axes 1 and 2 as the ones to take the norm over and
    the last axis as the feature axis, which matches the 4-D ``gamma``/``beta``.
    """

    def __init__(self, dim):
        """Create zero-initialised scale (``gamma``) and shift (``beta``) of size ``dim``."""
        super().__init__()
        param_shape = (1, 1, 1, dim)
        self.gamma = nn.Parameter(torch.zeros(param_shape))
        self.beta = nn.Parameter(torch.zeros(param_shape))

    def forward(self, x):
        """Return ``gamma * (x * N(x)) + beta + x``, with N(x) the relative feature norm."""
        # L2 norm of every feature over axes 1 and 2
        global_norm = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        # divide by the mean norm across features (epsilon avoids division by zero)
        relative_norm = global_norm / (global_norm.mean(dim=-1, keepdim=True) + 1e-6)
        calibrated = x * relative_norm
        return self.gamma * calibrated + self.beta + x


class Block(nn.Module):
    """ConvNeXtV2 residual block.

    Depthwise 7x7 convolution, then (in channels-last layout) LayerNorm, a 4x
    pointwise expansion, GELU, GRN and a pointwise projection back to ``dim``.
    The block input is added to the result.
    """

    def __init__(self, dim, drop_path=0.0):
        """Build the layers of the block.

        Args:
            dim (int): Number of input (and output) channels.
            drop_path (float): Accepted for signature compatibility; this
                implementation never reads it, so no stochastic depth is applied.
        """
        super().__init__()
        hidden_dim = 4 * dim
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
        self.norm = LayerNorm(dim, eps=1e-6)
        # pointwise (1x1) convolutions are implemented as linear layers on channels-last data
        self.pwconv1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.grn = GRN(hidden_dim)
        self.pwconv2 = nn.Linear(hidden_dim, dim)

    def forward(self, x):
        """Apply the block to a (N, C, H, W) tensor and return a tensor of the same shape."""
        shortcut = x
        features = self.dwconv(x).permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        features = self.pwconv1(self.norm(features))
        features = self.grn(self.act(features))
        features = self.pwconv2(features).permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
        return shortcut + features


class ConvNeXtV2(nn.Module):
    """ConvNeXt V2 backbone followed by a linear classification head.

    The network alternates a downsampling layer and a stage of residual blocks.
    The number of stages is taken from ``len(dims)``, so it is no longer fixed
    to four; the default arguments build the same four-stage model as before.
    """

    def __init__(
        self,
        in_chans=3,
        num_classes=1000,
        depths=[3, 3, 9, 3],
        dims=[96, 192, 384, 768],
        drop_path_rate=0.0,
        patch_size=1,
    ):
        """Build the stem, the downsampling layers, the stages and the head.

        Args:
            in_chans (int): Number of input image channels. Default: 3
            num_classes (int): Number of classes for classification head. Default: 1000
            depths (list[int]): Number of blocks at each stage. Default: [3, 3, 9, 3]
            dims (list[int]): Feature dimension at each stage. Default: [96, 192, 384, 768]
            drop_path_rate (float): Upper end of the per-block rate schedule that is
                handed to ``Block(drop_path=...)``. Default: 0.
            patch_size (int): Kernel size and stride of the stem convolution. Default: 1

        Raises:
            ValueError: If ``dims`` is empty or ``depths`` and ``dims`` differ in length.
        """
        super().__init__()
        if len(dims) == 0:
            msg = "dims must contain at least one stage"
            raise ValueError(msg)
        if len(depths) != len(dims):
            msg = f"depths and dims must have the same length, got {len(depths)} and {len(dims)}"
            raise ValueError(msg)
        num_stages = len(dims)

        self.depths = depths
        self.downsample_layers = nn.ModuleList()  # stem and the intermediate downsampling conv layers
        stem = nn.Sequential(
            nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
        )
        self.downsample_layers.append(stem)
        # one 2x downsampling layer between every pair of consecutive stages
        for i in range(num_stages - 1):
            downsample_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList()  # feature resolution stages, each consisting of multiple residual blocks
        # linearly increasing rate per block, forwarded to Block for signature compatibility
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(num_stages):
            stage = nn.Sequential(*[Block(dim=dims[i], drop_path=dp_rates[cur + j]) for j in range(depths[i])])
            self.stages.append(stage)
            cur += depths[i]

        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)  # final norm layer
        self.head = nn.Linear(dims[-1], num_classes)

    def forward_features(self, x):
        """Run all stages and return the normalised, globally pooled features of shape (N, C)."""
        for downsample, stage in zip(self.downsample_layers, self.stages):
            x = stage(downsample(x))
        return self.norm(x.mean([-2, -1]))  # global average pooling, (N, C, H, W) -> (N, C)

    def forward(self, x):
        """Return the class logits of shape (N, num_classes)."""
        return self.head(self.forward_features(x))

### ConvMixer


Source code taken from https://github.com/kentaroy47/vision-transformers-cifar10/blob/main/README.md


In [None]:
class Residual(nn.Module):
    """Skip connection: wraps ``fn`` and adds the input back onto its output."""

    def __init__(self, fn):
        """Store the wrapped callable (registered as a submodule if it is a module)."""
        super().__init__()
        self.fn = fn

    def forward(self, x):
        """Return ``fn(x) + x``."""
        transformed = self.fn(x)
        return transformed + x


def ConvMixer(dim, depth, kernel_size=9, patch_size=7, n_classes=1000):
    """Build a ConvMixer classifier as an ``nn.Sequential``.

    Args:
        dim: Channel width used throughout the network.
        depth: Number of mixer layers.
        kernel_size: Kernel size of the depthwise convolution in each mixer layer.
        patch_size: Kernel size and stride of the patch-embedding convolution.
        n_classes: Size of the output layer.

    Returns:
        nn.Sequential: patch embedding (for 3-channel input), ``depth`` mixer
        layers, global average pooling, flatten and a linear head.
    """

    def mixer_layer():
        # depthwise convolution with a skip connection ...
        spatial_mixing = Residual(
            nn.Sequential(
                nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"),
                nn.GELU(),
                nn.BatchNorm2d(dim),
            )
        )
        # ... followed by a pointwise (1x1) convolution
        return nn.Sequential(
            spatial_mixing,
            nn.Conv2d(dim, dim, kernel_size=1),
            nn.GELU(),
            nn.BatchNorm2d(dim),
        )

    layers = [
        nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size),
        nn.GELU(),
        nn.BatchNorm2d(dim),
    ]
    layers.extend(mixer_layer() for _ in range(depth))
    layers += [
        nn.AdaptiveAvgPool2d((1, 1)),
        nn.Flatten(),
        nn.Linear(dim, n_classes),
    ]
    return nn.Sequential(*layers)

### SimpleCNN

In [None]:
class SimpleCNN(nn.Module):
    """Small two-convolution CNN with a ten-way linear output."""

    def __init__(self):
        """Create the two convolutions, the dropout layer and the two linear layers."""
        super().__init__()
        self.conv1 = nn.Conv2d(3, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        # 500 = 20 channels * 5 * 5, which is what the conv/pool stack yields for 32x32 inputs
        self.fc1 = nn.Linear(500, 70)
        self.fc2 = nn.Linear(70, 10)
        self.flatten = torch.nn.Flatten()

    def forward(self, x):
        """Return the logits of shape (N, 10) for a batch of 3-channel images."""
        # conv -> 2x2 max-pool -> ReLU, twice (channel dropout after the second conv)
        features = self.conv1(x)
        features = F.relu(F.max_pool2d(features, 2))
        features = self.conv2_drop(self.conv2(features))
        features = F.relu(F.max_pool2d(features, 2))
        # classifier head; the functional dropout is only active in training mode
        hidden = F.relu(self.fc1(self.flatten(features)))
        hidden = F.dropout(hidden, training=self.training)
        return self.fc2(hidden)

## Functions


### dataset


In [None]:
def get_classes() -> tuple:
    """Return the class labels of the CIFAR-10 dataset as a tuple of strings."""
    labels = (
        "plane",
        "car",
        "bird",
        "cat",
        "deer",
        "dog",
        "frog",
        "horse",
        "ship",
        "truck",
    )
    return labels


def get_datasets(tasks: int) -> tuple[list[Subset], list[Subset]]:
    """Download CIFAR-10 and split it class-wise into consecutive tasks.

    The class indices are cut into ``tasks`` contiguous ranges; every task gets
    all train and test samples whose label falls into its range.

    Args:
        tasks (int): Number of tasks to split the dataset into.

    Returns:
        tuple[list[Subset], list[Subset]]: The train subsets and the test subsets, one per task.
    """
    num_classes = len(get_classes())
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    transform = transforms.Compose([transforms.ToTensor(), normalize])
    trainset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
    testset = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)

    # tasks + 1 integer boundaries between 0 and the number of classes
    boundaries = torch.linspace(0, num_classes, tasks + 1, dtype=torch.int).tolist()
    train_targets = torch.tensor(trainset.targets)
    test_targets = torch.tensor(testset.targets)

    def sample_indices(targets, label):
        # positions of all samples that carry the given label
        return (targets == label).nonzero(as_tuple=False).flatten().tolist()

    trainsets, testsets = [], []
    for first_label, end_label in zip(boundaries[:-1], boundaries[1:]):
        task_labels = range(first_label, end_label)
        # indices stay grouped by class, in ascending class order
        train_indices = [idx for label in task_labels for idx in sample_indices(train_targets, label)]
        test_indices = [idx for label in task_labels for idx in sample_indices(test_targets, label)]
        trainsets.append(Subset(trainset, train_indices))
        testsets.append(Subset(testset, test_indices))
    return trainsets, testsets

### metrics


In [None]:
def accuracy(testset: Dataset, model: nn.Module, device: torch.device, batch_size=1) -> float:
    """Compute the top-1 accuracy of ``model`` on ``testset`` in percent.

    The model is switched to evaluation mode and stays in that mode afterwards.
    Inference runs without gradient tracking.

    Args:
        testset: Dataset yielding ``(images, labels)`` pairs.
        model: Network whose output has the class scores along dimension 1.
        device: Device the batches are moved to; the model must already live there.
        batch_size: Number of samples per evaluation batch.

    Returns:
        float: Percentage of correctly classified samples (0-100).

    Raises:
        ZeroDivisionError: If ``testset`` is empty.
    """
    testloader = DataLoader(testset, shuffle=False, batch_size=batch_size)

    model.eval()
    correct = 0
    # evaluation needs no autograd graph; skipping it saves memory and time
    with torch.no_grad():
        for images, labels in testloader:
            # calculate outputs by running images through the network
            predictions = model(images.to(device=device))
            # the class with the highest score is what we choose as prediction
            _, predicted = torch.max(predictions, 1)
            correct += (predicted == labels.to(device=device)).sum().item()
    return 100 * correct / len(testset)

In [None]:
def average_accuracy(
    testsets: list[Dataset],
    model: nn.Module,
    device: torch.device,
    return_intermediate: bool = False,  # noqa: FBT001, FBT002
) -> float | tuple[float, list[float]]:
    average_accuracy = 0
    average_accuracies = []
    for i in range(len(testsets)):
        task_accuracy = accuracy(testset=testsets[i], model=model, device=device)
        average_accuracy += task_accuracy
        average_accuracies.append(task_accuracy)
    if return_intermediate:
        return average_accuracy / len(testsets), average_accuracies
    return average_accuracy / len(testsets)

In [None]:
def forgetting_measure(average_accuracies_per_training_per_task: list[list[float]], current_task: int) -> float:
    """Compute the forgetting measure after training on ``current_task``.

    For every earlier task ``j`` the forgetting is the largest difference between
    an accuracy reached on ``j`` in the past and the accuracy on ``j`` now. The
    result is the mean over all earlier tasks. Negative values are kept: they
    mean the accuracy on old tasks improved (backward transfer).

    Args:
        average_accuracies_per_training_per_task: Entry ``[i][j]`` is the accuracy on
            task ``j`` measured after training on task ``i`` (``j <= i``).
        current_task: Zero-based index of the task that was trained last.

    Returns:
        float: Mean forgetting over the tasks ``0 .. current_task - 1``.

    Raises:
        ValueError: If ``current_task`` is smaller than 1 (there is no earlier task).
    """
    if current_task < 1:
        msg = "current_task must be >= 1: forgetting needs at least one previously trained task"
        raise ValueError(msg)

    accuracies = average_accuracies_per_training_per_task
    latest = accuracies[current_task]
    total_forgetting = 0
    for j in range(current_task):  # exclude current task
        # max over the past trainings; no clamping at zero, so the value can be negative
        total_forgetting += max(accuracies[i][j] - latest[j] for i in range(j, current_task))
    return total_forgetting / current_task

In [None]:
# Worked example on hand-made numbers: entry [i][j] is the accuracy on task j after training on task i.
average_accuracies_per_training_per_task = [[100], [50, 100], [25, 50, 100], [25, 25, 50, 100]]
forgetting_measure(average_accuracies_per_training_per_task=average_accuracies_per_training_per_task, current_task=3)

66.66666666666667

### train loop


In [None]:
def train_on_task(
    trainset: Dataset,
    testset: Dataset,
    model: nn.Module,
    device: torch.device,
    optimizer: torch.optim.Optimizer,
    epochs: int,
    batch_size: int,
    lr: float,
    criterion: nn.modules.loss._Loss | None = None,
    scheduler: torch.optim.lr_scheduler.LRScheduler | None = None,
):
    """Train ``model`` on one task and log progress to the active wandb run.

    After every epoch the weights are saved to ``model{epoch}.pth`` in the wandb
    run directory and the test accuracy is logged. After the last epoch the
    weights are saved once more as ``model.pth``.

    Args:
        trainset: Training data of the task.
        testset: Test data used for the per-epoch evaluation.
        model: Network to train; it is moved to ``device``.
        device: Device to train on.
        optimizer: Optimizer that holds the parameters of ``model``.
        epochs: Number of passes over ``trainset``.
        batch_size: Batch size for training and evaluation.
        lr: Peak learning rate of the default one-cycle schedule.
        criterion: Loss function; cross entropy if None.
        scheduler: Learning-rate scheduler stepped after every batch; a
            ``OneCycleLR`` spanning all epochs is created if None.
    """
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
    model.to(device=device)

    # fall back to the defaults for everything the caller left open
    criterion = nn.CrossEntropyLoss() if criterion is None else criterion
    if scheduler is None:
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=lr,
            steps_per_epoch=len(trainloader),
            epochs=epochs,
        )

    for epoch in range(epochs):
        # one pass over the training data; the bar counts images, not batches
        with tqdm(total=len(trainset), unit="images") as progress_bar:
            model.train()
            for batch_idx, (images, labels) in enumerate(trainloader):
                progress_bar.set_description(f"Epoch {epoch + 1} Batch {batch_idx}")
                optimizer.zero_grad()

                # forward pass and loss
                prediction = model(images.to(device=device))
                loss = criterion(prediction, labels.to(device=device))
                # backward pass, then parameter and learning-rate update
                loss.backward()
                optimizer.step()
                scheduler.step()

                progress_bar.set_postfix(loss=loss.item())
                progress_bar.update(labels.shape[0])
                wandb.log({"loss": loss})
                wandb.log({"lr": scheduler.get_last_lr()[0]})

        # checkpoint of this epoch
        epoch_checkpoint = Path(wandb.run.dir) / f"model{epoch}.pth"
        torch.save(model.state_dict(), epoch_checkpoint)

        # evaluation on the test set of the task
        wandb.log(
            {"test_accucracy": accuracy(testset=testset, model=model, device=device, batch_size=batch_size)},
        )

    # checkpoint of the fully trained model
    final_checkpoint = Path(wandb.run.dir) / "model.pth"
    torch.save(model.state_dict(), final_checkpoint)

### concurrent


In [None]:
def train_tasks_concurrently(  # noqa: PLR0913
    model_dict: dict,
    device: torch.device,
    epochs: int,
    batch_size: int,
    lr: float,
    weight_decay: float,
    criterion: nn.modules.loss._Loss | None = None,
    scheduler: torch.optim.lr_scheduler.LRScheduler | None = None,
) -> None:
    # build model
    constructor = model_dict.pop("constructor")
    model = constructor(**model_dict)
    # create optimizer
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    # setup logging
    project_name = "continual_learning"
    run_name = f"{datetime.now(tz=timezone.utc).strftime('%Y_%m_%d_%H_%M_%S')}"
    config = {
        "training_method": "concurrently",
        "model": model,
        "optimizer": optimizer,
        "dataset": "CIFAR-10",
        "epochs": epochs,
        "batch_size": batch_size,
        "lr": lr,
        "weight_decay": weight_decay,
        "num_parameters": sum(p.numel() for p in model.parameters()),
    }
    config.update(model_dict)
    wandb.init(
        project=project_name,
        name=run_name,
        config=config,
    )

    # get datasets
    trainsets, testsets = get_datasets(tasks=1)

    train_on_task(
        trainset=trainsets[0],
        testset=testsets[0],
        model=model,
        device=device,
        optimizer=optimizer,
        epochs=epochs,
        batch_size=batch_size,
        lr=lr,
        criterion=criterion,
        scheduler=scheduler,
    )
    # finish logging run
    wandb.finish()

### sequentially


#### naive

In [None]:
def train_tasks_sequentially(  # noqa: PLR0913
    model_dict: dict,
    device: torch.device,
    epochs: int,
    batch_size: int,
    tasks: int,
    lr: float,
    weight_decay: float,
    criterion: nn.modules.loss._Loss | None = None,
    scheduler: torch.optim.lr_scheduler.LRScheduler | None = None,
) -> None:
    # build model
    constructor = model_dict.pop("constructor")
    model = constructor(**model_dict)
    # create optimizer
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    # setup logging
    project_name = "continual_learning"
    run_name = f"{datetime.now(tz=timezone.utc).strftime('%Y_%m_%d_%H_%M_%S')}"
    config = {
        "training_method": "sequentially",
        "model": model,
        "optimizer": optimizer,
        "dataset": "CIFAR-10",
        "epochs": epochs,
        "batch_size": batch_size,
        "tasks": tasks,
        "lr": lr,
        "weight_decay": weight_decay,
        "num_parameters": sum(p.numel() for p in model.parameters()),
    }
    config.update(model_dict)
    wandb.init(
        project=project_name,
        name=run_name,
        config=config,
    )

    # get datasets
    trainsets, testsets = get_datasets(tasks=tasks)

    avg_accs_per_task = []
    for k in range(tasks):
        # train model on task
        train_on_task(
            trainset=trainsets[k],
            testset=testsets[k],
            model=model,
            device=device,
            optimizer=optimizer,
            epochs=epochs,
            batch_size=batch_size,
            lr=lr,
            criterion=criterion,
            scheduler=scheduler,
        )
        # evaluate model
        avg_acc, avg_accs = average_accuracy(
            testsets=testsets[: k + 1],  # include current task
            model=model,
            device=device,
            return_intermediate=True,
        )
        avg_accs_per_task.append(avg_accs)
        wandb.log({"accuracy_on_current_task_only": avg_accs[-1]})
        wandb.log({"average_accuracy": avg_acc})

        # calculate forgetting measure as defined here https://arxiv.org/pdf/2302.00487.pdf
        if k > 0:  # forgetting measure only makes sense, if we already trained on prior task
            wandb.log(
                {
                    "forgetting_measure": forgetting_measure(
                        average_accuracies_per_training_per_task=avg_accs_per_task,
                        current_task=k,
                    ),
                },
            )

        # save model
        path = Path(wandb.run.dir).joinpath(f"model_task{k}_of{tasks}.pth")
        torch.save(model.state_dict(), path)

    # finish logging run
    wandb.finish()

#### rehearsal


In [None]:
def train_tasks_sequentially_rehearsal(  # noqa: PLR0913
    model_dict: dict,
    device: torch.device,
    epochs: int,
    batch_size: int,
    tasks: int,
    lr: float,
    weight_decay: float,
    memory_size_per_task: int,
    criterion: nn.modules.loss._Loss | None = None,
    scheduler: torch.optim.lr_scheduler.LRScheduler | None = None,
) -> None:
    # build model
    constructor = model_dict.pop("constructor")
    model = constructor(**model_dict)
    # create optimizer
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    # setup logging
    project_name = "continual_learning"
    run_name = f"{datetime.now(tz=timezone.utc).strftime('%Y_%m_%d_%H_%M_%S')}"
    config = {
        "training_method": "sequentially with rehearsal",
        "model": model,
        "optimizer": optimizer,
        "dataset": "CIFAR-10",
        "epochs": epochs,
        "batch_size": batch_size,
        "tasks": tasks,
        "lr": lr,
        "weight_decay": weight_decay,
        "memory_size_per_task": memory_size_per_task,
        "num_parameters": sum(p.numel() for p in model.parameters()),
    }
    config.update(model_dict)
    wandb.init(
        project=project_name,
        name=run_name,
        config=config,
    )

    # get datasets
    trainsets, testsets = get_datasets(tasks=tasks)

    avg_accs_per_task = []
    memories = []
    for k in range(tasks):
        # train model on task
        train_on_task(
            trainset=ConcatDataset([trainsets[k], *memories]),
            testset=testsets[k],
            model=model,
            device=device,
            optimizer=optimizer,
            epochs=epochs,
            batch_size=batch_size,
            lr=lr,
            criterion=criterion,
            scheduler=scheduler,
        )
        # evaluate model
        avg_acc, avg_accs = average_accuracy(
            testsets=testsets[: k + 1],  # include current task
            model=model,
            device=device,
            return_intermediate=True,
        )
        avg_accs_per_task.append(avg_accs)
        wandb.log({"accuracy_on_current_task_only": avg_accs[-1]})
        wandb.log({"average_accuracy": avg_acc})

        # calculate forgetting measure as defined here https://arxiv.org/pdf/2302.00487.pdf
        if k > 0:  # forgetting measure only makes sense, if we already trained on prior task
            wandb.log(
                {
                    "forgetting_measure": forgetting_measure(
                        average_accuracies_per_training_per_task=avg_accs_per_task,
                        current_task=k,
                    ),
                },
            )

        # save model
        path = Path(wandb.run.dir).joinpath(f"model_task{k}_of{tasks}.pth")
        torch.save(model.state_dict(), path)

        # add come images and labels from current task to memory
        random_indices = torch.randint(low=0, high=len(trainsets[k - 1]), size=(memory_size_per_task,))
        memory_task = Subset(trainsets[k], random_indices)
        memories.append(memory_task)

    # finish logging run
    wandb.finish()

#### elastic weight consolidation


In [None]:
def calculate_fisher_optimal_parameters(
    trainset: Subset,
    model: nn.Module,
    device: torch.device,
    optimizer: torch.optim.Optimizer,
    batch_size: int,
    current_task: int,
    fisher_dict: dict,
    optpar_dict: dict,
    samples_for_fisher_approximation: int,
    criterion: nn.modules.loss._Loss,
):
    """Store the current parameters and a Fisher estimate for ``current_task``.

    Gradients of ``criterion`` are accumulated over randomly drawn training
    samples. For every named parameter the function then stores a copy of its
    value in ``optpar_dict[current_task]`` and the squared accumulated gradient
    divided by the number of samples in ``fisher_dict[current_task]``.

    Args:
        trainset: Training subset of the task that was just learned.
        model: Trained network; it is put into training mode.
        device: Device the batches are moved to.
        optimizer: Only used to reset the gradients before accumulating.
        batch_size: Batch size used while accumulating gradients.
        current_task: Key under which the results are stored.
        fisher_dict: Output mapping ``task -> {parameter name -> Fisher estimate}``.
        optpar_dict: Output mapping ``task -> {parameter name -> parameter copy}``.
        samples_for_fisher_approximation: Number of samples drawn with replacement.
        criterion: Loss whose gradients are accumulated.
    """
    model.train()
    optimizer.zero_grad()

    # random subset of the training data (indices may repeat)
    sample_count = samples_for_fisher_approximation
    sampled_indices = torch.randint(low=0, high=len(trainset.indices), size=(sample_count,))
    fisher_loader = DataLoader(
        Subset(trainset, sampled_indices),
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
    )

    # no zero_grad between batches: the gradients add up over the whole subset
    for images, labels in fisher_loader:
        outputs = model(images.to(device=device))
        criterion(outputs, labels.to(device=device)).backward()

    # snapshot the parameters and derive the Fisher estimate from the summed gradients
    optimal_params = optpar_dict[current_task] = {}
    fisher_values = fisher_dict[current_task] = {}
    for name, param in model.named_parameters():
        optimal_params[name] = param.data.clone()
        fisher_values[name] = param.grad.data.clone().pow(2) / sample_count

In [None]:
def train_on_task_with_elastic_weight_loss(
    trainset: Subset,
    testset: Subset,
    model: nn.Module,
    device: torch.device,
    optimizer: torch.optim.Optimizer,
    epochs: int,
    batch_size: int,
    lr: float,
    fisher_dict: dict,
    optpar_dict: dict,
    ewc_lambda: float,
    current_task: int,
    criterion: nn.modules.loss._Loss,
    scheduler: torch.optim.lr_scheduler.LRScheduler | None = None,
):
    """Train ``model`` on one task with an elastic weight consolidation penalty.

    The penalty sums, over all earlier tasks and all named parameters, the
    Fisher-weighted squared distance to the stored parameters. Half of it, scaled
    by ``ewc_lambda``, is added to the task loss. Progress, checkpoints and the
    test accuracy are logged to the active wandb run.

    Args:
        trainset: Training data of the task.
        testset: Test data used for the per-epoch evaluation.
        model: Network to train; it is moved to ``device``.
        device: Device to train on.
        optimizer: Optimizer that holds the parameters of ``model``.
        epochs: Number of passes over ``trainset``.
        batch_size: Batch size for training and evaluation.
        lr: Peak learning rate of the default one-cycle schedule.
        fisher_dict: Mapping ``task -> {parameter name -> Fisher estimate}``.
        optpar_dict: Mapping ``task -> {parameter name -> stored parameter}``.
        ewc_lambda: Weight of the penalty.
        current_task: Index of the task being trained; tasks below it are penalised.
        criterion: Task loss function.
        scheduler: Learning-rate scheduler stepped after every batch; a
            ``OneCycleLR`` spanning all epochs is created if None.
    """
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
    model.to(device=device)

    if scheduler is None:
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=lr,
            steps_per_epoch=len(trainloader),
            epochs=epochs,
        )

    for epoch in range(epochs):
        # one pass over the training data; the bar counts images, not batches
        with tqdm(total=len(trainset), unit="images") as progress_bar:
            model.train()
            for batch_idx, (images, labels) in enumerate(trainloader):
                progress_bar.set_description(f"Epoch {epoch + 1} Batch {batch_idx}")
                optimizer.zero_grad()

                # task loss
                prediction = model(images.to(device=device))
                loss = criterion(prediction, labels.to(device=device))

                # penalty for moving away from the parameters of earlier tasks
                # (stays the integer 0 while there is no earlier task)
                elastic_weight_loss = 0
                for task in range(current_task):
                    for name, param in model.named_parameters():
                        drift = (optpar_dict[task][name] - param).pow(2)
                        elastic_weight_loss = elastic_weight_loss + (fisher_dict[task][name] * drift).sum()
                scaled_penalty = 0.5 * elastic_weight_loss * ewc_lambda
                loss = loss + scaled_penalty

                # backward pass, then parameter and learning-rate update
                loss.backward()
                optimizer.step()
                scheduler.step()

                progress_bar.set_postfix(loss=loss.item())
                progress_bar.update(labels.shape[0])
                wandb.log({"loss": loss})
                wandb.log({"lr": scheduler.get_last_lr()[0]})
                wandb.log({"elastic_weight_loss": elastic_weight_loss})
                wandb.log({"elastic_weight_loss_scaled": scaled_penalty})

        # checkpoint of this epoch
        epoch_checkpoint = Path(wandb.run.dir) / f"model{epoch}.pth"
        torch.save(model.state_dict(), epoch_checkpoint)

        # evaluation on the test set of the task
        wandb.log(
            {"test_accucracy": accuracy(testset=testset, model=model, device=device, batch_size=batch_size)},
        )

    # checkpoint of the fully trained model
    final_checkpoint = Path(wandb.run.dir) / "model.pth"
    torch.save(model.state_dict(), final_checkpoint)

In [None]:
def train_tasks_sequentially_elastic_weight_consolidation(  # noqa: PLR0913
    model_dict: dict,
    device: torch.device,
    epochs: int,
    batch_size: int,
    tasks: int,
    lr: float,
    weight_decay: float,
    ewc_lambda: float,
    samples_for_fisher_approximation: int,
    criterion: nn.modules.loss._Loss | None = None,
    scheduler: torch.optim.lr_scheduler.LRScheduler | None = None,
) -> None:
    # build model
    constructor = model_dict.pop("constructor")
    model = constructor(**model_dict)
    # create optimizer
    # optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.95)
    if criterion is None:
        criterion = nn.CrossEntropyLoss()
    # setup logging
    project_name = "continual_learning"
    run_name = f"{datetime.now(tz=timezone.utc).strftime('%Y_%m_%d_%H_%M_%S')}"
    config = {
        "training_method": "sequentially with elastic weight consolidation",
        "model": model,
        "optimizer": optimizer,
        "dataset": "CIFAR-10",
        "epochs": epochs,
        "batch_size": batch_size,
        "tasks": tasks,
        "lr": lr,
        "weight_decay": weight_decay,
        "ewc_lambda": ewc_lambda,
        "samples_for_fisher_approximation": samples_for_fisher_approximation,
        "num_parameters": sum(p.numel() for p in model.parameters()),
    }
    config.update(model_dict)
    wandb.init(
        project=project_name,
        name=run_name,
        config=config,
    )

    # get datasets
    trainsets, testsets = get_datasets(tasks=tasks)

    avg_accs_per_task = []
    fisher_dict = {}
    optpar_dict = {}
    for k in range(tasks):
        train_on_task_with_elastic_weight_loss(
            trainset=trainsets[k],
            testset=testsets[k],
            model=model,
            device=device,
            optimizer=optimizer,
            epochs=epochs,
            batch_size=batch_size,
            lr=lr,
            fisher_dict=fisher_dict,
            optpar_dict=optpar_dict,
            ewc_lambda=ewc_lambda,
            current_task=k,
            criterion=criterion,
            scheduler=scheduler,
        )
        # evaluate model
        avg_acc, avg_accs = average_accuracy(
            testsets=testsets[: k + 1],  # include current task
            model=model,
            device=device,
            return_intermediate=True,
        )
        avg_accs_per_task.append(avg_accs)
        wandb.log({"accuracy_on_current_task_only": avg_accs[-1]})
        wandb.log({"average_accuracy": avg_acc})

        # calculate forgetting measure as defined here https://arxiv.org/pdf/2302.00487.pdf
        if k > 0:  # forgetting measure only makes sense, if we already trained on prior task
            wandb.log(
                {
                    "forgetting_measure": forgetting_measure(
                        average_accuracies_per_training_per_task=avg_accs_per_task,
                        current_task=k,
                    ),
                },
            )

        # save model
        path = Path(wandb.run.dir).joinpath(f"model_task{k}_of{tasks}.pth")
        torch.save(model.state_dict(), path)

        # gradients accumulated can be used to calculate fisher
        calculate_fisher_optimal_parameters(
            trainset=trainsets[k],
            model=model,
            device=device,
            optimizer=optimizer,
            batch_size=batch_size,
            current_task=k,
            fisher_dict=fisher_dict,
            optpar_dict=optpar_dict,
            samples_for_fisher_approximation=samples_for_fisher_approximation,
            criterion=criterion,
        )

    # finish logging run
    wandb.finish()

### wandb


In [None]:
def load_weights_from_wandb(model: nn.Module, run_name: str) -> nn.Module:
    """Restore ``model.pth`` from a W&B run and load it into ``model``.

    The file is fetched from the run ``fabianfuchs/continual_learning/{run_name}``
    into ``<cwd>/checkpoints`` (an existing local copy is replaced) and its
    state dict is loaded into ``model`` in place.

    Args:
        model: Module that receives the stored state dict.
        run_name: Last component of the W&B run path.

    Returns:
        The same ``model`` instance, after ``load_state_dict`` has been applied.
    """
    best_model = wandb.restore(
        "model.pth",
        run_path=f"fabianfuchs/continual_learning/{run_name}",
        root=Path.cwd().joinpath("checkpoints"),
        replace=True,
    )

    # wandb.restore hands back an open file object; only its path is needed,
    # so release the handle right away instead of leaking it.
    checkpoint_path = best_model.name
    best_model.close()

    # Deserialize onto the CPU so that checkpoints written on a GPU machine can
    # also be read on a CPU-only host; load_state_dict then copies the values
    # into the model's existing parameters.
    model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
    return model

## Setup


In [None]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Training schedule shared by the experiment cells below.
tasks = 1
epochs = 1
batch_size = 64

# Optimisation hyperparameters.
lr = 0.01
momentum = 0.9
weight_decay = 0.01

In [None]:
# Registry of model configuration dicts; the experiment loops below iterate over it.
models = []

In [None]:
# convnext_minimal_v2 = {
#     "constructor": ConvNeXtV2,
#     "in_chans": 3,
#     "num_classes": 10,
#     "depths": [1, 1, 1, 1],
#     "dims": [5, 10, 20, 40],
#     "patch_size": 1,
# }
# models.append(convnext_minimal_v2)

In [None]:
# convnext_minimal = {
#     "constructor": ConvNeXtV2,
#     "name": "ConvNeXtV2",
#     "in_chans": 3,
#     "num_classes": 10,
#     "depths": [2, 2, 2, 2],
#     "dims": [128, 128, 128, 128],
#     "patch_size": 1,
# }
# models.append(convnext_minimal)

In [None]:
# convnext_atto = {
#     "constructor": ConvNeXtV2,
#     "name": "ConvNeXtV2",
#     "in_chans": 3,
#     "num_classes": 10,
#     "depths": [2, 2, 6, 2],
#     "dims": [40, 80, 160, 320],
#     "patch_size": 1,
# }
# models.append(convnext_atto)

In [None]:
# convnext_tiny = {
#     "constructor": ConvNeXtV2,
#     "name": "ConvNeXtV2",
#     "in_chans": 3,
#     "num_classes": 10,
#     "depths": [3, 3, 9, 3],
#     "dims": [96, 192, 384, 768],
#     "patch_size": 1,
# }
# models.append(convnext_tiny)

In [None]:
# convnext_base = {
#     "constructor": ConvNeXtV2,
#     "name": "ConvNeXtV2",
#     "in_chans": 3,
#     "num_classes": 10,
#     "depths": [3, 3, 27, 3],
#     "dims": [128, 256, 512, 1024],
#     "patch_size": 1,
# }
# models.append(convnext_base)

In [None]:
# conv_mixer_minimal = {
#     "constructor": ConvMixer,
#     "name": "ConvMixer",
#     "dim": 128,
#     "depth": 4,
#     "kernel_size": 7,
#     "patch_size": 1,
#     "n_classes": 10,
# }
# models.append(conv_mixer_minimal)

In [None]:
# conv_mixer_atto = {
#     "constructor": ConvMixer,
#     "name": "ConvMixer",
#     "dim": 128,
#     "depth": 8,
#     "kernel_size": 7,
#     "patch_size": 1,
#     "n_classes": 10,
# }
# models.append(conv_mixer_atto)

In [None]:
# conv_mixer_tiny = {
#     "constructor": ConvMixer,
#     "name": "ConvMixer",
#     "dim": 256,
#     "depth": 8,
#     "kernel_size": 7,
#     "patch_size": 1,
#     "n_classes": 10,
# }
# models.append(conv_mixer_tiny)

In [None]:
# Register the SimpleCNN configuration; only its constructor is specified.
simple_cnn = {"constructor": SimpleCNN}
models.append(simple_cnn)

## Standard Setting

In [None]:
# Standard setting: call train_tasks_concurrently once per registered configuration.
concurrent_settings = {
    "device": device,
    "epochs": epochs,
    "batch_size": batch_size,
    "lr": lr,
    "weight_decay": weight_decay,
}
for model in models:
    # Each call gets its own deep copy of the configuration dict.
    train_tasks_concurrently(model_dict=deepcopy(model), **concurrent_settings)

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
loss,███▇▇▇▇▇▅▇▅▆▅▄▅▅▄▄▄▄▄▃▃▅▄▅▄▄▁▄▃▂▄▂▄▅▃▃▂▄
lr,▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▇▇▇██
test_accucracy,▁

0,1
loss,1.72945
lr,0.0012
test_accucracy,43.34


Files already downloaded and verified
Files already downloaded and verified


Epoch 1 Batch 781: 100%|██████████| 50000/50000 [00:41<00:00, 1211.24images/s, loss=1.71]


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
loss,███▆▆▇▄▅▆▄▅▅▄▅▅▇▅▇▇▅▃▄▆▄▂▅▄▅▄▃▃▁▃▃▅▁▃▅▇▂
lr,▁▁▂▂▃▄▅▆▆▇███████▇▇▇▇▆▆▆▅▅▄▄▄▃▃▂▂▂▂▁▁▁▁▁
test_accucracy,▁

0,1
loss,1.70865
lr,0.0
test_accucracy,38.28




## Sequential without modifications


In [None]:
# Number of tasks passed to the sequential experiments below
# (replaces the value of 1 assigned in the Setup cell).
tasks = 5

In [None]:
# Sequential baseline: call train_tasks_sequentially once per registered configuration.
sequential_settings = {
    "device": device,
    "epochs": epochs,
    "batch_size": batch_size,
    "tasks": tasks,
    "lr": lr,
    "weight_decay": weight_decay,
}
for model in models:
    # Each call gets its own deep copy of the configuration dict.
    train_tasks_sequentially(model_dict=deepcopy(model), **sequential_settings)

Files already downloaded and verified
Files already downloaded and verified


  self.pid = os.fork()
  self.pid = os.fork()
Epoch 1 Batch 156: 100%|██████████| 10000/10000 [00:13<00:00, 758.26images/s, loss=0.335]
Epoch 1 Batch 156: 100%|██████████| 10000/10000 [00:08<00:00, 1197.13images/s, loss=0.679]
Epoch 1 Batch 156: 100%|██████████| 10000/10000 [00:07<00:00, 1253.81images/s, loss=0.751]
Epoch 1 Batch 156: 100%|██████████| 10000/10000 [00:10<00:00, 991.39images/s, loss=0.986]
Epoch 1 Batch 156: 100%|██████████| 10000/10000 [00:08<00:00, 1212.22images/s, loss=1.45]


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
accuracy_on_current_task_only,█▁▁▁▁
average_accuracy,█▂▂▁▁
forgetting_measure,█▃▂▁
loss,▂▁▁▁▁▁▁▁█▂▁▁▁▁▁▁▇▄▂▁▁▁▁▁▆▃▂▂▁▁▁▁▄▃▂▂▂▂▂▂
lr,▂▅██▆▄▃▁▁▆█▇▆▄▂▁▂▆█▇▆▄▂▁▂▇█▇▆▃▂▁▂▇█▇▅▃▂▁
test_accucracy,█▁▁▁▁

0,1
accuracy_on_current_task_only,50.0
average_accuracy,10.0
forgetting_measure,58.425
loss,1.44741
lr,0.0
test_accucracy,50.0


## Sequential with rehearsal


In [None]:
# Rehearsal sweep: one run per (configuration, memory size per task) pair.
rehearsal_settings = {
    "device": device,
    "epochs": epochs,
    "batch_size": batch_size,
    "tasks": tasks,
    "lr": lr,
    "weight_decay": weight_decay,
}
memory_sizes = [1000, 2000, 5000, 10000]
for model in models:
    for memory_size_per_task in memory_sizes:
        # Each call gets its own deep copy of the configuration dict.
        train_tasks_sequentially_rehearsal(
            model_dict=deepcopy(model),
            memory_size_per_task=memory_size_per_task,
            **rehearsal_settings,
        )

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112892766666644, max=1.0…

Files already downloaded and verified
Files already downloaded and verified


  self.pid = os.fork()
  self.pid = os.fork()
Epoch 1 Batch 156: 100%|██████████| 10000/10000 [00:09<00:00, 1107.50images/s, loss=0.472]
Epoch 1 Batch 171: 100%|██████████| 11000/11000 [00:09<00:00, 1113.68images/s, loss=1.11]
Epoch 1 Batch 187: 100%|██████████| 12000/12000 [00:10<00:00, 1126.76images/s, loss=1.24]
Epoch 1 Batch 203: 100%|██████████| 13000/13000 [00:11<00:00, 1102.56images/s, loss=1.28]
Epoch 1 Batch 218: 100%|██████████| 14000/14000 [00:12<00:00, 1125.79images/s, loss=1.9]


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
accuracy_on_current_task_only,█▁▁▁▁
average_accuracy,█▂▂▁▁
forgetting_measure,█▃▂▁
loss,▂▁▁▁▁▁▁▆▂▂▂▂▂▂█▄▂▂▂▂▂▂▆▄▃▂▂▂▂▂▂▅▄▃▃▃▃▃▃▂
lr,▂▅█▇▅▃▁▂▇█▇▅▃▁▂▅█▇▆▄▂▁▂▅██▆▅▃▁▁▃▇█▇▆▄▃▂▁
test_accucracy,█▁▁▁▁

0,1
accuracy_on_current_task_only,50.0
average_accuracy,10.0
forgetting_measure,58.9875
loss,1.89557
lr,0.0
test_accucracy,50.0


Files already downloaded and verified
Files already downloaded and verified


Epoch 1 Batch 156: 100%|██████████| 10000/10000 [00:08<00:00, 1196.02images/s, loss=0.31]
Epoch 1 Batch 187: 100%|██████████| 12000/12000 [00:10<00:00, 1188.82images/s, loss=1.19]
Epoch 1 Batch 218: 100%|██████████| 14000/14000 [00:12<00:00, 1135.73images/s, loss=1.6]
Epoch 1 Batch 249: 100%|██████████| 16000/16000 [00:14<00:00, 1107.11images/s, loss=2.09]
Epoch 1 Batch 281: 100%|██████████| 18000/18000 [00:16<00:00, 1071.45images/s, loss=1.92]


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
accuracy_on_current_task_only,█▁▁▁▁
average_accuracy,█▂▂▁▁
forgetting_measure,█▃▂▁
loss,▂▁▁▁▁▁▄▂▂▂▂▂▂█▃▃▃▃▃▃▂▆▄▃▃▃▃▃▃▃▆▅▄▄▄▃▃▃▃▄
lr,▃▇█▅▃▁▃██▆▄▂▁▃██▇▅▄▂▁▃▆█▇▇▅▃▁▁▂▆▇█▇▆▅▃▂▁
test_accucracy,█▁▁▁▁

0,1
accuracy_on_current_task_only,50.0
average_accuracy,10.0
forgetting_measure,58.875
loss,1.91674
lr,0.0
test_accucracy,50.0


Files already downloaded and verified
Files already downloaded and verified


Epoch 1 Batch 156: 100%|██████████| 10000/10000 [00:08<00:00, 1184.57images/s, loss=0.91]
Epoch 1 Batch 234: 100%|██████████| 15000/15000 [00:13<00:00, 1132.13images/s, loss=1.3]
Epoch 1 Batch 312: 100%|██████████| 20000/20000 [00:17<00:00, 1124.63images/s, loss=1.84]
Epoch 1 Batch 390: 100%|██████████| 25000/25000 [00:22<00:00, 1120.47images/s, loss=2.08]
Epoch 1 Batch 468: 100%|██████████| 30000/30000 [00:26<00:00, 1130.36images/s, loss=2.28]


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
accuracy_on_current_task_only,█▁▁▁▁
average_accuracy,█▂▂▁▁
forgetting_measure,█▃▂▁
loss,▂▁▁▁█▂▂▂▂▂▅▂▂▂▂▂▂▂▃▃▂▂▂▂▂▂▂▂▃▃▃▂▂▂▂▂▂▂▂▂
lr,▃█▆▂▁▇█▆▄▁▁▆█▇▆▅▂▁▂▄▇█▇▆▄▃▂▁▁▄▆██▇▆▅▄▂▂▁
test_accucracy,█▁▁▁▁

0,1
accuracy_on_current_task_only,50.0
average_accuracy,10.0
forgetting_measure,59.1
loss,2.28002
lr,0.0
test_accucracy,50.0


Files already downloaded and verified
Files already downloaded and verified


Epoch 1 Batch 156: 100%|██████████| 10000/10000 [00:09<00:00, 1064.04images/s, loss=0.262]
Epoch 1 Batch 312: 100%|██████████| 20000/20000 [00:17<00:00, 1155.23images/s, loss=1.35]
Epoch 1 Batch 468: 100%|██████████| 30000/30000 [00:35<00:00, 842.26images/s, loss=1.55] 
Epoch 1 Batch 624: 100%|██████████| 40000/40000 [00:41<00:00, 966.01images/s, loss=2.07] 
Epoch 1 Batch 781: 100%|██████████| 50000/50000 [00:52<00:00, 947.78images/s, loss=1.75] 


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
accuracy_on_current_task_only,█▂▁▄▁
average_accuracy,█▃▂▁▁
forgetting_measure,█▁▃▂
loss,▂▁▁▄▃▃▃▃▇▄▃▃▃▃▃▃█▄▅▄▄▄▄▄▄▄█▆▅▅▄▄▄▄▄▄▄▄▄▄
lr,▇█▂▃█▆▃▂▁▅█▇▆▄▃▁▁▄▇██▇▅▄▂▁▁▂▄▇██▇▇▆▄▃▂▁▁
test_accucracy,█▂▁▄▁

0,1
accuracy_on_current_task_only,41.6
average_accuracy,30.67
forgetting_measure,30.8875
loss,1.75228
lr,0.0
test_accucracy,41.6


## Sequential with elastic weight consolidation


In [None]:
# EWC sweep: one run per (configuration, lambda, Fisher sample count) combination.
ewc_settings = {
    "device": device,
    "epochs": epochs,
    "batch_size": batch_size,
    "tasks": tasks,
    "lr": lr,
    "weight_decay": weight_decay,
}
ewc_lambdas = [75000]
fisher_sample_counts = [1000]
for model in models:
    for ewc_lambda in ewc_lambdas:
        for samples_for_fisher_approximation in fisher_sample_counts:
            # Each call gets its own deep copy of the configuration dict.
            train_tasks_sequentially_elastic_weight_consolidation(
                model_dict=deepcopy(model),
                ewc_lambda=ewc_lambda,
                samples_for_fisher_approximation=samples_for_fisher_approximation,
                **ewc_settings,
            )

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
accuracy_on_current_task_only,▁
average_accuracy,▁
elastic_weight_loss,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█
elastic_weight_loss_scaled,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█
loss,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█
lr,▁▂▃▅▇████▇▆▆▅▄▃▃▂▂▁▁▁▂▄▆▇███▇▇▆▅▅▄▃▂▂▁▁▁
test_accucracy,▁

0,1
accuracy_on_current_task_only,77.35
average_accuracy,77.35
elastic_weight_loss,
elastic_weight_loss_scaled,
loss,
lr,0.0
test_accucracy,77.35


Files already downloaded and verified
Files already downloaded and verified


Epoch 1 Batch 156: 100%|██████████| 10000/10000 [00:09<00:00, 1004.07images/s, loss=0.469]
Epoch 1 Batch 156: 100%|██████████| 10000/10000 [00:09<00:00, 1023.85images/s, loss=0.987]
Epoch 1 Batch 156: 100%|██████████| 10000/10000 [00:10<00:00, 968.53images/s, loss=nan]
Epoch 1 Batch 156: 100%|██████████| 10000/10000 [00:11<00:00, 858.78images/s, loss=nan]
Epoch 1 Batch 156: 100%|██████████| 10000/10000 [00:10<00:00, 980.76images/s, loss=nan]


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
accuracy_on_current_task_only,█▆▁▁▁
average_accuracy,█▃▂▁▁
elastic_weight_loss,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█
elastic_weight_loss_scaled,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█
forgetting_measure,█▄▂▁
loss,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█
lr,▂▅██▆▄▃▁▁▆█▇▆▄▂▁▂▆█▇▆▄▂▁▂▇█▇▆▃▂▁▂▇█▇▅▃▂▁
test_accucracy,█▆▁▁▁

0,1
accuracy_on_current_task_only,0.0
average_accuracy,10.0
elastic_weight_loss,
elastic_weight_loss_scaled,
forgetting_measure,22.3
loss,
lr,0.0
test_accucracy,0.0


### Avalanche

In [None]:
!pip install avalanche-lib[all]

Collecting avalanche-lib[all]
  Downloading avalanche_lib-0.5.0-py3-none-any.whl (971 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m971.9/971.9 kB[0m [31m6.2 MB/s[0m eta [36m0:00:00[0m
Collecting gputil (from avalanche-lib[all])
  Downloading GPUtil-1.4.0.tar.gz (5.5 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting pytorchcv (from avalanche-lib[all])
  Downloading pytorchcv-0.0.67-py2.py3-none-any.whl (532 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m532.4/532.4 kB[0m [31m37.4 MB/s[0m eta [36m0:00:00[0m
Collecting torchmetrics (from avalanche-lib[all])
  Downloading torchmetrics-1.3.2-py3-none-any.whl (841 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m841.5/841.5 kB[0m [31m48.3 MB/s[0m eta [36m0:00:00[0m
Collecting qpsolvers[open_source_solvers] (from avalanche-lib[all])
  Downloading qpsolvers-4.3.2-py3-none-any.whl (79 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [3

In [None]:
from avalanche.benchmarks.classic import SplitCIFAR10

In [None]:
from avalanche.training.supervised import EWC

In [None]:
from avalanche.evaluation.metrics.checkpoint import WeightCheckpoint
from avalanche.logging import InteractiveLogger, WandBLogger
from avalanche.training.plugins import EvaluationPlugin, EarlyStoppingPlugin
from avalanche.evaluation.metrics import (
    forgetting_metrics,
    accuracy_metrics,
    loss_metrics,
)

In [None]:
from avalanche.models import SimpleMLP

In [None]:
# Scenario: CIFAR-10 split into 5 experiences with class-order shuffling turned off
# (the recorded output shows experience 0 holding classes [0, 1]).
benchmark = SplitCIFAR10(n_experiences=5, shuffle=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /root/.avalanche/data/cifar10/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:02<00:00, 78229449.17it/s]


Extracting /root/.avalanche/data/cifar10/cifar-10-python.tar.gz to /root/.avalanche/data/cifar10
Files already downloaded and verified


In [None]:
# model = ConvMixer(
#     dim= 128,
#     depth= 4,
#     kernel_size= 7,
#     patch_size= 1,
#     n_classes= 10,
# )

In [None]:
class Net(nn.Module):
    """Small CNN: two 5x5 convolutions followed by two linear layers, 10 outputs.

    ``fc1`` expects 500 flattened features, i.e. 20 channels of 5x5 maps, which is
    what the two unpadded conv + 2x2 max-pool stages produce for 3x32x32 inputs.
    """

    def __init__(self):
        super().__init__()
        # convolutional stages (channel dropout is applied after the second one)
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=10, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        # classifier head
        self.fc1 = nn.Linear(in_features=500, out_features=70)
        self.fc2 = nn.Linear(in_features=70, out_features=10)
        self.flatten = nn.Flatten()

    def forward(self, x):
        """Return the raw class scores (no softmax) for a batch of images."""
        # stage 1: conv -> 2x2 max-pool -> ReLU
        feature_maps = self.conv1(x)
        feature_maps = F.relu(F.max_pool2d(feature_maps, 2))

        # stage 2: conv -> channel dropout -> 2x2 max-pool -> ReLU
        feature_maps = self.conv2_drop(self.conv2(feature_maps))
        feature_maps = F.relu(F.max_pool2d(feature_maps, 2))

        # head: flatten -> linear -> ReLU -> dropout (active only in training mode) -> linear
        hidden = F.relu(self.fc1(self.flatten(feature_maps)))
        hidden = F.dropout(hidden, training=self.training)
        return self.fc2(hidden)

In [None]:
# Hyperparameters for the Avalanche runs; batch_size replaces the value from the Setup cell.
epochs = 1
batch_size = 32

# SGD settings.
lr = 0.01
momentum = 0.9

In [None]:
# torch.cat((torch.linspace(0.6, 1, 3), torch.linspace(2, 10, 9), torch.linspace(15, 100, 15)))


In [None]:
# Sweep the EWC regularisation strength: one fresh model and one W&B run per value.
for ewc_lambda in torch.linspace(280, 300, 10):
    # Plain Python float for the strategy and the logged config (instead of a 0-d tensor).
    ewc_lambda_value = ewc_lambda.item()

    # model = SimpleMLP(num_classes=benchmark.n_classes, input_size=3*32*32)
    model = Net()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    criterion = nn.CrossEntropyLoss()

    interactive_logger = InteractiveLogger()
    wandb_logger = WandBLogger(
        project_name="continual learning",
        run_name=f"avalanche_{datetime.now(tz=timezone.utc).strftime('%Y_%m_%d_%H_%M_%S')}",
        log_artifacts=True,
        config={
            "ewc_lambda": ewc_lambda_value,
            "batch_size": batch_size,
            "epochs": epochs,
            "lr": lr,
            "optimizer": optimizer,
            "momentum": momentum,
            "model": model,
        },
    )

    eval_plugin = EvaluationPlugin(
        accuracy_metrics(
            minibatch=False,
            epoch=True,
            epoch_running=False,
            experience=True,
            stream=True,
            trained_experience=True,
        ),
        loss_metrics(
            minibatch=False,
            epoch=False,
            epoch_running=True,
            experience=False,
            stream=False,
        ),
        forgetting_metrics(experience=True, stream=True),
        WeightCheckpoint(),
        loggers=[interactive_logger, wandb_logger],
    )

    # NOTE(review): this plugin is created but never handed to the strategy below
    # (there is no `plugins=[plugin]` argument), so early stopping is currently inactive.
    plugin = EarlyStoppingPlugin(
        patience=3,
        val_stream_name="eval_phase/test_stream/Task000",
        metric_name="Accuracy_On_Trained_Experiences",
        mode="max",
    )

    cl_strategy = EWC(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        ewc_lambda=ewc_lambda_value,
        mode="separate",
        train_mb_size=batch_size,
        train_epochs=epochs,
        eval_mb_size=batch_size,
        # Reuse the notebook-wide device (CUDA with CPU fallback) instead of
        # hard-coding "cuda", which fails on machines without a GPU.
        device=device,
        evaluator=eval_plugin,
    )

    # Evaluation results after each experience for the current lambda.
    results = []
    for experience in benchmark.train_stream:
        print("Start of experience: ", experience.current_experience)
        print("Current Classes: ", experience.classes_in_this_experience)

        cl_strategy.train(experience)
        print("Training completed")

        print("Computing accuracy on the whole test set")
        results.append(cl_strategy.eval(benchmark.test_stream))

    # Close this iteration's W&B run explicitly so the final run of the sweep is
    # finished as well (same as the custom training loops, which call wandb.finish()).
    wandb.finish()
    del model

VBox(children=(Label(value='0.799 MB of 0.799 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
Accuracy_On_Trained_Experiences/eval_phase/test_stream/Task000,█▃▂▁▁
ExperienceForgetting/eval_phase/test_stream/Task000/Exp000,██▁▁
ExperienceForgetting/eval_phase/test_stream/Task000/Exp001,▁▁▁
ExperienceForgetting/eval_phase/test_stream/Task000/Exp002,▁▁
ExperienceForgetting/eval_phase/test_stream/Task000/Exp003,▁
RunningLoss_Epoch/train_phase/train_stream/Task000,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█
StreamForgetting/eval_phase/test_stream,▁█▇▅▄
Top1_Acc_Epoch/train_phase/train_stream/Task000,█▆▆▁▁
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000,█▁▁▅▅
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001,▁█▁▁▁

0,1
Accuracy_On_Trained_Experiences/eval_phase/test_stream/Task000,0.1
ExperienceForgetting/eval_phase/test_stream/Task000/Exp000,0.3455
ExperienceForgetting/eval_phase/test_stream/Task000/Exp001,0.527
ExperienceForgetting/eval_phase/test_stream/Task000/Exp002,0.6535
ExperienceForgetting/eval_phase/test_stream/Task000/Exp003,0.0
RunningLoss_Epoch/train_phase/train_stream/Task000,
StreamForgetting/eval_phase/test_stream,0.3815
Top1_Acc_Epoch/train_phase/train_stream/Task000,0.0
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000,0.5
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001,0.0


Start of experience:  0
Current Classes:  [0, 1]
-- >> Start of training phase << --
100%|██████████| 313/313 [00:08<00:00, 36.35it/s]
Epoch 0 ended.
	RunningLoss_Epoch/train_phase/train_stream/Task000 = 0.7437
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.6364
-- >> End of training phase << --
Training completed
Computing accuracy on the whole test set
-- >> Start of eval phase << --
-- Starting eval on experience 0 (Task 0) from test stream --
100%|██████████| 63/63 [00:00<00:00, 64.33it/s]
> Eval on experience 0 (Task 0) from test stream ended.
	Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000 = 0.8335
-- Starting eval on experience 1 (Task 0) from test stream --
100%|██████████| 63/63 [00:00<00:00, 64.54it/s]
> Eval on experience 1 (Task 0) from test stream ended.
	Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001 = 0.0000
-- Starting eval on experience 2 (Task 0) from test stream --
100%|██████████| 63/63 [00:01<00:00, 57.75it/s]
> Eval on experience 2 (Task 0) from test

VBox(children=(Label(value='0.654 MB of 0.654 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
Accuracy_On_Trained_Experiences/eval_phase/test_stream/Task000,█▃▂▂▁
ExperienceForgetting/eval_phase/test_stream/Task000/Exp000,███▁
ExperienceForgetting/eval_phase/test_stream/Task000/Exp001,▁▁▁
ExperienceForgetting/eval_phase/test_stream/Task000/Exp002,▁▁
ExperienceForgetting/eval_phase/test_stream/Task000/Exp003,▁
RunningLoss_Epoch/train_phase/train_stream/Task000,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█
StreamForgetting/eval_phase/test_stream,▁█▇▇▆
Top1_Acc_Epoch/train_phase/train_stream/Task000,█▆▆▇▁
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000,█▁▁▁▅
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001,▁█▁▁▁

0,1
Accuracy_On_Trained_Experiences/eval_phase/test_stream/Task000,0.1
ExperienceForgetting/eval_phase/test_stream/Task000/Exp000,0.3335
ExperienceForgetting/eval_phase/test_stream/Task000/Exp001,0.6525
ExperienceForgetting/eval_phase/test_stream/Task000/Exp002,0.6785
ExperienceForgetting/eval_phase/test_stream/Task000/Exp003,0.7615
RunningLoss_Epoch/train_phase/train_stream/Task000,
StreamForgetting/eval_phase/test_stream,0.6065
Top1_Acc_Epoch/train_phase/train_stream/Task000,0.0112
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000,0.5
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001,0.0


Start of experience:  0
Current Classes:  [0, 1]
-- >> Start of training phase << --
100%|██████████| 313/313 [00:07<00:00, 40.86it/s]
Epoch 0 ended.
	RunningLoss_Epoch/train_phase/train_stream/Task000 = 0.7072
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.6579
-- >> End of training phase << --
Training completed
Computing accuracy on the whole test set
-- >> Start of eval phase << --
-- Starting eval on experience 0 (Task 0) from test stream --
100%|██████████| 63/63 [00:01<00:00, 58.66it/s]
> Eval on experience 0 (Task 0) from test stream ended.
	Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000 = 0.7365
-- Starting eval on experience 1 (Task 0) from test stream --
100%|██████████| 63/63 [00:01<00:00, 59.94it/s]
> Eval on experience 1 (Task 0) from test stream ended.
	Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001 = 0.0000
-- Starting eval on experience 2 (Task 0) from test stream --
100%|██████████| 63/63 [00:00<00:00, 63.68it/s]
> Eval on experience 2 (Task 0) from test

VBox(children=(Label(value='0.401 MB of 0.401 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
Accuracy_On_Trained_Experiences/eval_phase/test_stream/Task000,█▃▂▁▁
ExperienceForgetting/eval_phase/test_stream/Task000/Exp000,▁▁▁▁
ExperienceForgetting/eval_phase/test_stream/Task000/Exp001,▁▁▁
ExperienceForgetting/eval_phase/test_stream/Task000/Exp002,▁▁
ExperienceForgetting/eval_phase/test_stream/Task000/Exp003,▁
RunningLoss_Epoch/train_phase/train_stream/Task000,▁▁▁▁▁▁▁▁█
StreamForgetting/eval_phase/test_stream,▁█▅▃▃
Top1_Acc_Epoch/train_phase/train_stream/Task000,█▁▁▁▁
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000,█▁▁▁▁
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001,▁▁▁▁▁

0,1
Accuracy_On_Trained_Experiences/eval_phase/test_stream/Task000,0.1
ExperienceForgetting/eval_phase/test_stream/Task000/Exp000,0.2365
ExperienceForgetting/eval_phase/test_stream/Task000/Exp001,0.0
ExperienceForgetting/eval_phase/test_stream/Task000/Exp002,0.0
ExperienceForgetting/eval_phase/test_stream/Task000/Exp003,0.0
RunningLoss_Epoch/train_phase/train_stream/Task000,
StreamForgetting/eval_phase/test_stream,0.05913
Top1_Acc_Epoch/train_phase/train_stream/Task000,0.0
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000,0.5
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001,0.0


Start of experience:  0
Current Classes:  [0, 1]
-- >> Start of training phase << --
100%|██████████| 313/313 [00:08<00:00, 37.42it/s]
Epoch 0 ended.
	RunningLoss_Epoch/train_phase/train_stream/Task000 = 0.6719
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.6719
-- >> End of training phase << --
Training completed
Computing accuracy on the whole test set
-- >> Start of eval phase << --
-- Starting eval on experience 0 (Task 0) from test stream --
100%|██████████| 63/63 [00:01<00:00, 44.93it/s]
> Eval on experience 0 (Task 0) from test stream ended.
	Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000 = 0.8235
-- Starting eval on experience 1 (Task 0) from test stream --
100%|██████████| 63/63 [00:00<00:00, 63.58it/s]
> Eval on experience 1 (Task 0) from test stream ended.
	Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001 = 0.0000
-- Starting eval on experience 2 (Task 0) from test stream --
100%|██████████| 63/63 [00:00<00:00, 63.36it/s]
> Eval on experience 2 (Task 0) from test

VBox(children=(Label(value='0.672 MB of 0.672 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
Accuracy_On_Trained_Experiences/eval_phase/test_stream/Task000,█▃▂▂▁
ExperienceForgetting/eval_phase/test_stream/Task000/Exp000,███▁
ExperienceForgetting/eval_phase/test_stream/Task000/Exp001,▁▁▁
ExperienceForgetting/eval_phase/test_stream/Task000/Exp002,▁▁
ExperienceForgetting/eval_phase/test_stream/Task000/Exp003,▁
RunningLoss_Epoch/train_phase/train_stream/Task000,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█
StreamForgetting/eval_phase/test_stream,▁██▇▇
Top1_Acc_Epoch/train_phase/train_stream/Task000,█▇▇▇▁
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000,█▁▁▁▅
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001,▁█▁▁▁

0,1
Accuracy_On_Trained_Experiences/eval_phase/test_stream/Task000,0.1
ExperienceForgetting/eval_phase/test_stream/Task000/Exp000,0.3235
ExperienceForgetting/eval_phase/test_stream/Task000/Exp001,0.7145
ExperienceForgetting/eval_phase/test_stream/Task000/Exp002,0.7195
ExperienceForgetting/eval_phase/test_stream/Task000/Exp003,0.8475
RunningLoss_Epoch/train_phase/train_stream/Task000,
StreamForgetting/eval_phase/test_stream,0.65125
Top1_Acc_Epoch/train_phase/train_stream/Task000,0.0017
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000,0.5
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001,0.0


Start of experience:  0
Current Classes:  [0, 1]
-- >> Start of training phase << --
100%|██████████| 313/313 [00:08<00:00, 37.37it/s]
Epoch 0 ended.
	RunningLoss_Epoch/train_phase/train_stream/Task000 = 0.7278
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.6542
-- >> End of training phase << --
Training completed
Computing accuracy on the whole test set
-- >> Start of eval phase << --
-- Starting eval on experience 0 (Task 0) from test stream --
100%|██████████| 63/63 [00:01<00:00, 60.18it/s]
> Eval on experience 0 (Task 0) from test stream ended.
	Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000 = 0.8160
-- Starting eval on experience 1 (Task 0) from test stream --
100%|██████████| 63/63 [00:01<00:00, 59.55it/s]
> Eval on experience 1 (Task 0) from test stream ended.
	Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001 = 0.0000
-- Starting eval on experience 2 (Task 0) from test stream --
100%|██████████| 63/63 [00:01<00:00, 60.12it/s]
> Eval on experience 2 (Task 0) from test

VBox(children=(Label(value='0.612 MB of 0.612 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
Accuracy_On_Trained_Experiences/eval_phase/test_stream/Task000,█▃▂▁▁
ExperienceForgetting/eval_phase/test_stream/Task000/Exp000,█▁▁▁
ExperienceForgetting/eval_phase/test_stream/Task000/Exp001,▁▁▁
ExperienceForgetting/eval_phase/test_stream/Task000/Exp002,▁▁
ExperienceForgetting/eval_phase/test_stream/Task000/Exp003,▁
RunningLoss_Epoch/train_phase/train_stream/Task000,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█
StreamForgetting/eval_phase/test_stream,▁█▅▄▃
Top1_Acc_Epoch/train_phase/train_stream/Task000,█▆▁▁▁
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000,█▁▅▅▅
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001,▁█▁▁▁

0,1
Accuracy_On_Trained_Experiences/eval_phase/test_stream/Task000,0.1
ExperienceForgetting/eval_phase/test_stream/Task000/Exp000,0.316
ExperienceForgetting/eval_phase/test_stream/Task000/Exp001,0.6735
ExperienceForgetting/eval_phase/test_stream/Task000/Exp002,0.0
ExperienceForgetting/eval_phase/test_stream/Task000/Exp003,0.0
RunningLoss_Epoch/train_phase/train_stream/Task000,
StreamForgetting/eval_phase/test_stream,0.24737
Top1_Acc_Epoch/train_phase/train_stream/Task000,0.0
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000,0.5
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001,0.0


Start of experience:  0
Current Classes:  [0, 1]
-- >> Start of training phase << --
100%|██████████| 313/313 [00:08<00:00, 39.06it/s]
Epoch 0 ended.
	RunningLoss_Epoch/train_phase/train_stream/Task000 = 0.6888
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.6662
-- >> End of training phase << --
Training completed
Computing accuracy on the whole test set
-- >> Start of eval phase << --
-- Starting eval on experience 0 (Task 0) from test stream --
100%|██████████| 63/63 [00:01<00:00, 59.31it/s]
> Eval on experience 0 (Task 0) from test stream ended.
	Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000 = 0.7975
-- Starting eval on experience 1 (Task 0) from test stream --
100%|██████████| 63/63 [00:01<00:00, 59.93it/s]
> Eval on experience 1 (Task 0) from test stream ended.
	Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001 = 0.0000
-- Starting eval on experience 2 (Task 0) from test stream --
100%|██████████| 63/63 [00:01<00:00, 60.65it/s]
> Eval on experience 2 (Task 0) from test

VBox(children=(Label(value='0.491 MB of 0.491 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
Accuracy_On_Trained_Experiences/eval_phase/test_stream/Task000,█▃▂▁▁
ExperienceForgetting/eval_phase/test_stream/Task000/Exp000,██▁▁
ExperienceForgetting/eval_phase/test_stream/Task000/Exp001,▁▁▁
ExperienceForgetting/eval_phase/test_stream/Task000/Exp002,▁▁
ExperienceForgetting/eval_phase/test_stream/Task000/Exp003,▁
RunningLoss_Epoch/train_phase/train_stream/Task000,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█
StreamForgetting/eval_phase/test_stream,▁█▇▆▄
Top1_Acc_Epoch/train_phase/train_stream/Task000,█▆▆▁▁
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000,█▁▁▅▅
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001,▁█▁▁▁

0,1
Accuracy_On_Trained_Experiences/eval_phase/test_stream/Task000,0.1
ExperienceForgetting/eval_phase/test_stream/Task000/Exp000,0.2975
ExperienceForgetting/eval_phase/test_stream/Task000/Exp001,0.591
ExperienceForgetting/eval_phase/test_stream/Task000/Exp002,0.656
ExperienceForgetting/eval_phase/test_stream/Task000/Exp003,0.0
RunningLoss_Epoch/train_phase/train_stream/Task000,
StreamForgetting/eval_phase/test_stream,0.38612
Top1_Acc_Epoch/train_phase/train_stream/Task000,0.0
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000,0.5
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001,0.0


Start of experience:  0
Current Classes:  [0, 1]
-- >> Start of training phase << --
100%|██████████| 313/313 [00:07<00:00, 39.39it/s]
Epoch 0 ended.
	RunningLoss_Epoch/train_phase/train_stream/Task000 = 0.7194
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.6436
-- >> End of training phase << --
Training completed
Computing accuracy on the whole test set
-- >> Start of eval phase << --
-- Starting eval on experience 0 (Task 0) from test stream --
100%|██████████| 63/63 [00:01<00:00, 59.23it/s]
> Eval on experience 0 (Task 0) from test stream ended.
	Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000 = 0.8440
-- Starting eval on experience 1 (Task 0) from test stream --
100%|██████████| 63/63 [00:01<00:00, 59.75it/s]
> Eval on experience 1 (Task 0) from test stream ended.
	Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001 = 0.0000
-- Starting eval on experience 2 (Task 0) from test stream --
100%|██████████| 63/63 [00:01<00:00, 57.05it/s]
> Eval on experience 2 (Task 0) from test

VBox(children=(Label(value='0.823 MB of 0.823 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
Accuracy_On_Trained_Experiences/eval_phase/test_stream/Task000,█▃▂▁▁
ExperienceForgetting/eval_phase/test_stream/Task000/Exp000,██▁▁
ExperienceForgetting/eval_phase/test_stream/Task000/Exp001,▁▁▁
ExperienceForgetting/eval_phase/test_stream/Task000/Exp002,▁▁
ExperienceForgetting/eval_phase/test_stream/Task000/Exp003,▁
RunningLoss_Epoch/train_phase/train_stream/Task000,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█
StreamForgetting/eval_phase/test_stream,▁█▇▆▄
Top1_Acc_Epoch/train_phase/train_stream/Task000,█▇▇▁▁
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000,█▁▁▅▅
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001,▁█▁▁▁

0,1
Accuracy_On_Trained_Experiences/eval_phase/test_stream/Task000,0.1
ExperienceForgetting/eval_phase/test_stream/Task000/Exp000,0.344
ExperienceForgetting/eval_phase/test_stream/Task000/Exp001,0.6805
ExperienceForgetting/eval_phase/test_stream/Task000/Exp002,0.6415
ExperienceForgetting/eval_phase/test_stream/Task000/Exp003,0.0
RunningLoss_Epoch/train_phase/train_stream/Task000,
StreamForgetting/eval_phase/test_stream,0.4165
Top1_Acc_Epoch/train_phase/train_stream/Task000,0.0
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000,0.5
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001,0.0


Start of experience:  0
Current Classes:  [0, 1]
-- >> Start of training phase << --
100%|██████████| 313/313 [00:08<00:00, 38.69it/s]
Epoch 0 ended.
	RunningLoss_Epoch/train_phase/train_stream/Task000 = 0.7241
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.6502
-- >> End of training phase << --
Training completed
Computing accuracy on the whole test set
-- >> Start of eval phase << --
-- Starting eval on experience 0 (Task 0) from test stream --
100%|██████████| 63/63 [00:01<00:00, 58.13it/s]
> Eval on experience 0 (Task 0) from test stream ended.
	Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000 = 0.7645
-- Starting eval on experience 1 (Task 0) from test stream --
100%|██████████| 63/63 [00:01<00:00, 58.01it/s]
> Eval on experience 1 (Task 0) from test stream ended.
	Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001 = 0.0000
-- Starting eval on experience 2 (Task 0) from test stream --
100%|██████████| 63/63 [00:01<00:00, 60.97it/s]
> Eval on experience 2 (Task 0) from test

VBox(children=(Label(value='0.491 MB of 0.491 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
Accuracy_On_Trained_Experiences/eval_phase/test_stream/Task000,█▄▂▁▁
ExperienceForgetting/eval_phase/test_stream/Task000/Exp000,██▁▁
ExperienceForgetting/eval_phase/test_stream/Task000/Exp001,▁▁▁
ExperienceForgetting/eval_phase/test_stream/Task000/Exp002,▁▁
ExperienceForgetting/eval_phase/test_stream/Task000/Exp003,▁
RunningLoss_Epoch/train_phase/train_stream/Task000,▂▂▁▁▁▁▁▁▅▃▃▂▂▂▂▂█▄▃▃▂▂▂▂
StreamForgetting/eval_phase/test_stream,▁██▆▅
Top1_Acc_Epoch/train_phase/train_stream/Task000,█▆▇▁▁
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000,█▁▁▆▆
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001,▁█▁▁▁

0,1
Accuracy_On_Trained_Experiences/eval_phase/test_stream/Task000,0.1
ExperienceForgetting/eval_phase/test_stream/Task000/Exp000,0.2645
ExperienceForgetting/eval_phase/test_stream/Task000/Exp001,0.679
ExperienceForgetting/eval_phase/test_stream/Task000/Exp002,0.645
ExperienceForgetting/eval_phase/test_stream/Task000/Exp003,0.0
RunningLoss_Epoch/train_phase/train_stream/Task000,
StreamForgetting/eval_phase/test_stream,0.39713
Top1_Acc_Epoch/train_phase/train_stream/Task000,0.0
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000,0.5
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001,0.0


Start of experience:  0
Current Classes:  [0, 1]
-- >> Start of training phase << --
100%|██████████| 313/313 [00:07<00:00, 39.73it/s]
Epoch 0 ended.
	RunningLoss_Epoch/train_phase/train_stream/Task000 = 0.6891
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.6604
-- >> End of training phase << --
Training completed
Computing accuracy on the whole test set
-- >> Start of eval phase << --
-- Starting eval on experience 0 (Task 0) from test stream --
100%|██████████| 63/63 [00:01<00:00, 58.48it/s]
> Eval on experience 0 (Task 0) from test stream ended.
	Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000 = 0.8380
-- Starting eval on experience 1 (Task 0) from test stream --
100%|██████████| 63/63 [00:01<00:00, 57.62it/s]
> Eval on experience 1 (Task 0) from test stream ended.
	Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001 = 0.0000
-- Starting eval on experience 2 (Task 0) from test stream --
100%|██████████| 63/63 [00:01<00:00, 58.84it/s]
> Eval on experience 2 (Task 0) from test

VBox(children=(Label(value='0.870 MB of 0.870 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
Accuracy_On_Trained_Experiences/eval_phase/test_stream/Task000,█▃▂▁▁
ExperienceForgetting/eval_phase/test_stream/Task000/Exp000,██▁▁
ExperienceForgetting/eval_phase/test_stream/Task000/Exp001,▁▁▁
ExperienceForgetting/eval_phase/test_stream/Task000/Exp002,▁▁
ExperienceForgetting/eval_phase/test_stream/Task000/Exp003,▁
RunningLoss_Epoch/train_phase/train_stream/Task000,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█
StreamForgetting/eval_phase/test_stream,▁█▇▆▅
Top1_Acc_Epoch/train_phase/train_stream/Task000,█▇▇▁▁
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000,█▁▁▅▅
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001,▁█▁▁▁

0,1
Accuracy_On_Trained_Experiences/eval_phase/test_stream/Task000,0.1
ExperienceForgetting/eval_phase/test_stream/Task000/Exp000,0.338
ExperienceForgetting/eval_phase/test_stream/Task000/Exp001,0.678
ExperienceForgetting/eval_phase/test_stream/Task000/Exp002,0.715
ExperienceForgetting/eval_phase/test_stream/Task000/Exp003,0.0
RunningLoss_Epoch/train_phase/train_stream/Task000,
StreamForgetting/eval_phase/test_stream,0.43275
Top1_Acc_Epoch/train_phase/train_stream/Task000,0.0
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000,0.5
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001,0.0


Start of experience:  0
Current Classes:  [0, 1]
-- >> Start of training phase << --
100%|██████████| 313/313 [00:08<00:00, 38.19it/s]
Epoch 0 ended.
	RunningLoss_Epoch/train_phase/train_stream/Task000 = 0.6966
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.6541
-- >> End of training phase << --
Training completed
Computing accuracy on the whole test set
-- >> Start of eval phase << --
-- Starting eval on experience 0 (Task 0) from test stream --
100%|██████████| 63/63 [00:01<00:00, 55.97it/s]
> Eval on experience 0 (Task 0) from test stream ended.
	Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000 = 0.8310
-- Starting eval on experience 1 (Task 0) from test stream --
100%|██████████| 63/63 [00:01<00:00, 58.17it/s]
> Eval on experience 1 (Task 0) from test stream ended.
	Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001 = 0.0000
-- Starting eval on experience 2 (Task 0) from test stream --
100%|██████████| 63/63 [00:01<00:00, 58.20it/s]
> Eval on experience 2 (Task 0) from test

VBox(children=(Label(value='0.164 MB of 0.164 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
Accuracy_On_Trained_Experiences/eval_phase/test_stream/Task000,█▂▂▁▁
ExperienceForgetting/eval_phase/test_stream/Task000/Exp000,▁▁▁▁
ExperienceForgetting/eval_phase/test_stream/Task000/Exp001,▁▁▁
ExperienceForgetting/eval_phase/test_stream/Task000/Exp002,▁▁
ExperienceForgetting/eval_phase/test_stream/Task000/Exp003,▁
RunningLoss_Epoch/train_phase/train_stream/Task000,▁▁▁▁▁▁▁▁█
StreamForgetting/eval_phase/test_stream,▁█▅▃▃
Top1_Acc_Epoch/train_phase/train_stream/Task000,█▂▁▁▁
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000,█▁▁▁▁
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001,▁▁▁▁▁

0,1
Accuracy_On_Trained_Experiences/eval_phase/test_stream/Task000,0.1
ExperienceForgetting/eval_phase/test_stream/Task000/Exp000,0.331
ExperienceForgetting/eval_phase/test_stream/Task000/Exp001,0.0
ExperienceForgetting/eval_phase/test_stream/Task000/Exp002,0.0
ExperienceForgetting/eval_phase/test_stream/Task000/Exp003,0.0
RunningLoss_Epoch/train_phase/train_stream/Task000,
StreamForgetting/eval_phase/test_stream,0.08275
Top1_Acc_Epoch/train_phase/train_stream/Task000,0.0
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000,0.5
Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001,0.0


Start of experience:  0
Current Classes:  [0, 1]
-- >> Start of training phase << --
100%|██████████| 313/313 [00:09<00:00, 32.40it/s]
Epoch 0 ended.
	RunningLoss_Epoch/train_phase/train_stream/Task000 = 0.7250
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.6432
-- >> End of training phase << --
Training completed
Computing accuracy on the whole test set
-- >> Start of eval phase << --
-- Starting eval on experience 0 (Task 0) from test stream --
100%|██████████| 63/63 [00:01<00:00, 40.23it/s]
> Eval on experience 0 (Task 0) from test stream ended.
	Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000 = 0.7785
-- Starting eval on experience 1 (Task 0) from test stream --
100%|██████████| 63/63 [00:01<00:00, 51.25it/s]
> Eval on experience 1 (Task 0) from test stream ended.
	Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001 = 0.0000
-- Starting eval on experience 2 (Task 0) from test stream --
100%|██████████| 63/63 [00:01<00:00, 57.93it/s]
> Eval on experience 2 (Task 0) from test