In [3]:
#!pip install wandb -qU

In [4]:
from __future__ import annotations

from copy import deepcopy
from datetime import datetime, timezone
from pathlib import Path

import torch
import torch.nn.functional as F
import torchvision
from torch import nn, optim
from torch.utils.data import ConcatDataset, DataLoader, Dataset, Subset
from torchvision import transforms
from tqdm import tqdm

import wandb

## Weights and Bias Login


In [5]:
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: fabianfuchs. Use `wandb login --relogin` to force relogin


True

## Model


### ConvNeXtV2


Got source code for the ConvNeXtV2 model from https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py and removed drop path and custom weight initialization. Added variable patch size.


In [6]:
class LayerNorm(nn.Module):
    """LayerNorm that supports two data formats: channels_last (default) or channels_first.
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
    shape (batch_size, height, width, channels) while channels_first corresponds to inputs
    with shape (batch_size, channels, height, width).
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x


class GRN(nn.Module):
    """GRN (Global Response Normalization) layer"""

    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))

    def forward(self, x):
        Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * Nx) + self.beta + x


class Block(nn.Module):
    """ConvNeXtV2 Block."""

    def __init__(self, dim, drop_path=0.0):
        """ConvNeXtV2 Block.

        Args:
            dim (int): Number of input channels.
            drop_path (float): Stochastic depth rate. Default: 0.0
        """
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
        self.norm = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, 4 * dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.grn = GRN(4 * dim)
        self.pwconv2 = nn.Linear(4 * dim, dim)

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.grn(x)
        x = self.pwconv2(x)
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

        return input + x


class ConvNeXtV2(nn.Module):
    """ConvNeXt V2."""

    def __init__(
        self,
        in_chans=3,
        num_classes=1000,
        depths=[3, 3, 9, 3],
        dims=[96, 192, 384, 768],
        drop_path_rate=0.0,
        patch_size=1,
    ):
        """ConvNeXt V2.

        Args:
            in_chans (int): Number of input image channels. Default: 3
            num_classes (int): Number of classes for classification head. Default: 1000
            depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
            dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
            drop_path_rate (float): Stochastic depth rate. Default: 0.
            head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
        """
        super().__init__()
        self.depths = depths
        self.downsample_layers = nn.ModuleList()  # stem and 3 intermediate downsampling conv layers
        stem = nn.Sequential(
            nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
        )
        self.downsample_layers.append(stem)
        for i in range(3):
            downsample_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList()  # 4 feature resolution stages, each consisting of multiple residual blocks
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(4):
            stage = nn.Sequential(*[Block(dim=dims[i], drop_path=dp_rates[cur + j]) for j in range(depths[i])])
            self.stages.append(stage)
            cur += depths[i]

        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)  # final norm layer
        self.head = nn.Linear(dims[-1], num_classes)

    def forward_features(self, x):
        for i in range(4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
        return self.norm(x.mean([-2, -1]))  # global average pooling, (N, C, H, W) -> (N, C)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x

### ConvMixer


source code from here https://github.com/kentaroy47/vision-transformers-cifar10/blob/main/README.md


In [7]:
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x


def ConvMixer(dim, depth, kernel_size=9, patch_size=7, n_classes=1000):
    return nn.Sequential(
        nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size),
        nn.GELU(),
        nn.BatchNorm2d(dim),
        *[
            nn.Sequential(
                Residual(
                    nn.Sequential(
                        nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"), nn.GELU(), nn.BatchNorm2d(dim)
                    )
                ),
                nn.Conv2d(dim, dim, kernel_size=1),
                nn.GELU(),
                nn.BatchNorm2d(dim),
            )
            for i in range(depth)
        ],
        nn.AdaptiveAvgPool2d((1, 1)),
        nn.Flatten(),
        nn.Linear(dim, n_classes),
    )

## Functions


### dataset


In [8]:
def get_classes() -> tuple:
    """Return class labels of CIFAR-10 dataset."""
    return ("plane", "car", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck")


def get_datasets(tasks: int) -> tuple[list[Subset], list[Subset]]:
    """Split CIFAR-10 dataset into task specific subsets.

    Args:
        tasks (int): Number of tasks to split the dataset into.

    Returns:
        tuple[list[Subset], list[Subset]]: Tuple containing two list with the train and test subsets.
    """
    classes = get_classes()
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    trainset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
    testset = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)
    classes_per_task = torch.linspace(0, len(classes), tasks + 1, dtype=torch.int)
    trainsets = []
    testsets = []
    train_targets = torch.tensor(trainset.targets)
    test_targets = torch.tensor(testset.targets)
    for i in range(len(classes_per_task) - 1):
        train_indices = []
        test_indices = []
        for j in range(classes_per_task[i], classes_per_task[i + 1]):
            train_indices.extend((train_targets == j).nonzero(as_tuple=False).flatten().tolist())
            test_indices.extend((test_targets == j).nonzero(as_tuple=False).flatten().tolist())
        trainsets.append(Subset(trainset, train_indices))
        testsets.append(Subset(testset, test_indices))
    return trainsets, testsets

### metrics


In [9]:
def accuracy(testset: Dataset, model: nn.Module, device: torch.device, batch_size=1) -> float:
    testloader = DataLoader(testset, shuffle=False, batch_size=batch_size)

    model.eval()
    correct = 0
    for images, labels in testloader:
        # calculate outputs by running images through the network
        predictions = model(images.to(device=device))
        # the class with the highest energy is what we choose as prediction
        _, predicted = torch.max(predictions.data, 1)
        correct += (predicted == labels.to(device=device)).sum().item()
    return 100 * correct / len(testset)

In [10]:
def average_accuracy(
    testsets: list[Dataset],
    model: nn.Module,
    device: torch.device,
    return_intermediate: bool = False,  # noqa: FBT001, FBT002
) -> float | tuple[float, list[float]]:
    average_accuracy = 0
    average_accuracies = []
    for i in range(len(testsets)):
        task_accuracy = accuracy(testset=testsets[i], model=model, device=device)
        average_accuracy += task_accuracy
        average_accuracies.append(task_accuracy)
    if return_intermediate:
        return average_accuracy / len(testsets), average_accuracies
    return average_accuracy / len(testsets)

In [11]:
def forgetting_measure(average_accuracies_per_training_per_task: list[list[float]], current_task: int) -> float:
    forgetting_measure = 0
    for j in range(current_task):  # exclude current task
        f = 0
        for i in range(j, current_task):  # exclude current task
            f_ = (
                average_accuracies_per_training_per_task[i][j]
                - average_accuracies_per_training_per_task[current_task][j]
            )
            if f_ > f:
                f = f_
        forgetting_measure += f
    return forgetting_measure / current_task

In [12]:
average_accuracies_per_training_per_task = [[100], [50, 100], [25, 50, 100], [25, 25, 50, 100]]
forgetting_measure(average_accuracies_per_training_per_task=average_accuracies_per_training_per_task, current_task=3)

66.66666666666667

### train loop


In [13]:
def train_on_task(
    trainset: Dataset,
    testset: Dataset,
    model: nn.Module,
    device: torch.device,
    optimizer: torch.optim.Optimizer,
    epochs: int,
    batch_size: int,
    lr: float,
    criterion: nn.modules.loss._Loss | None = None,
    scheduler: torch.optim.lr_scheduler.LRScheduler | None = None,
):
    # create dataloaders
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

    # move model to device
    model.to(device=device)

    if criterion is None:
        criterion = nn.CrossEntropyLoss()
    if scheduler is None:
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=lr,
            steps_per_epoch=len(trainloader),
            epochs=epochs,
        )

    # training
    for epoch in range(epochs):
        # train one epoch
        with tqdm(total=len(trainset), unit="images") as progress_bar:
            model.train()
            for i, (images, labels) in enumerate(trainloader):
                progress_bar.set_description(f"Epoch {epoch+1} Batch {i}")
                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                prediction = model(images.to(device=device))
                # calc loss
                loss = criterion(prediction, labels.to(device=device))
                # backward
                loss.backward()
                # optimizer
                optimizer.step()
                # scheduler
                scheduler.step()
                progress_bar.set_postfix(loss=loss.item())
                progress_bar.update(labels.shape[0])
                wandb.log({"loss": loss})
                wandb.log({"lr": scheduler.get_last_lr()[0]})
        # save model
        path = Path(wandb.run.dir).joinpath(f"model{epoch}.pth")
        torch.save(model.state_dict(), path)

        # eval
        test_accucracy = accuracy(testset=testset, model=model, device=device, batch_size=batch_size)
        wandb.log({"test_accucracy": test_accucracy})

    # save final model
    path = Path(wandb.run.dir).joinpath("model.pth")
    torch.save(model.state_dict(), path)

### concurrent


In [14]:
def train_tasks_concurrently(  # noqa: PLR0913
    model_dict: dict,
    device: torch.device,
    epochs: int,
    batch_size: int,
    lr: float,
    weight_decay: float,
    criterion: nn.modules.loss._Loss | None = None,
    scheduler: torch.optim.lr_scheduler.LRScheduler | None = None,
) -> None:
    # build model
    constructor = model_dict.pop("constructor")
    model_name = model_dict.pop("name")
    model = constructor(**model_dict)
    # create optimizer
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    # setup logging
    project_name = "continual_learning"
    run_name = f"{datetime.now(tz=timezone.utc).strftime('%Y_%m_%d_%H_%M_%S')}"
    config = {
        "training_method": "concurrently",
        "model_name": model_name,
        "dataset": "CIFAR-10",
        "epochs": epochs,
        "batch_size": batch_size,
        "lr": lr,
        "weight_decay": weight_decay,
        "num_parameters": sum(p.numel() for p in model.parameters()),
    }
    config.update(model_dict)
    wandb.init(
        project=project_name,
        name=run_name,
        config=config,
    )

    # get datasets
    trainsets, testsets = get_datasets(tasks=1)

    train_on_task(
        trainset=trainsets[0],
        testset=testsets[0],
        model=model,
        device=device,
        optimizer=optimizer,
        epochs=epochs,
        batch_size=batch_size,
        lr=lr,
        criterion=criterion,
        scheduler=scheduler,
    )
    # finish logging run
    wandb.finish()

### sequentially


In [15]:
def train_tasks_sequentially(  # noqa: PLR0913
    model_dict: dict,
    device: torch.device,
    epochs: int,
    batch_size: int,
    tasks: int,
    lr: float,
    weight_decay: float,
    criterion: nn.modules.loss._Loss | None = None,
    scheduler: torch.optim.lr_scheduler.LRScheduler | None = None,
) -> None:
    # build model
    constructor = model_dict.pop("constructor")
    model_name = model_dict.pop("name")
    model = constructor(**model_dict)
    # create optimizer
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    # setup logging
    project_name = "continual_learning"
    run_name = f"{datetime.now(tz=timezone.utc).strftime('%Y_%m_%d_%H_%M_%S')}"
    config = {
        "training_method": "sequentially",
        "model_name": model_name,
        "dataset": "CIFAR-10",
        "epochs": epochs,
        "batch_size": batch_size,
        "tasks": tasks,
        "lr": lr,
        "weight_decay": weight_decay,
        "num_parameters": sum(p.numel() for p in model.parameters()),
    }
    config.update(model_dict)
    wandb.init(
        project=project_name,
        name=run_name,
        config=config,
    )

    # get datasets
    trainsets, testsets = get_datasets(tasks=tasks)

    avg_accs_per_task = []
    for k in range(tasks):
        # train model on task
        train_on_task(
            trainset=trainsets[k],
            testset=testsets[k],
            model=model,
            device=device,
            optimizer=optimizer,
            epochs=epochs,
            batch_size=batch_size,
            lr=lr,
            criterion=criterion,
            scheduler=scheduler,
        )
        # evaluate model
        avg_acc, avg_accs = average_accuracy(
            testsets=testsets[: k + 1],  # include current task
            model=model,
            device=device,
            return_intermediate=True,
        )
        avg_accs_per_task.append(avg_accs)
        wandb.log({"accuracy_on_current_task_only": avg_accs[-1]})
        wandb.log({"average_accuracy": avg_acc})

        # calculate forgetting measure as defined here https://arxiv.org/pdf/2302.00487.pdf
        if k > 0:  # forgetting measure only makes sense, if we already trained on prior task
            wandb.log(
                {
                    "forgetting_measure": forgetting_measure(
                        average_accuracies_per_training_per_task=avg_accs_per_task,
                        current_task=k,
                    ),
                },
            )

        # save model
        path = Path(wandb.run.dir).joinpath(f"model_task{k}_of{tasks}.pth")
        torch.save(model.state_dict(), path)

    # finish logging run
    wandb.finish()

check accuracy_on_current_task_only


#### rehearsal


In [16]:
def train_tasks_sequentially_rehearsal(  # noqa: PLR0913
    model_dict: dict,
    device: torch.device,
    epochs: int,
    batch_size: int,
    tasks: int,
    lr: float,
    weight_decay: float,
    memory_size_per_task: int,
    criterion: nn.modules.loss._Loss | None = None,
    scheduler: torch.optim.lr_scheduler.LRScheduler | None = None,
) -> None:
    # build model
    constructor = model_dict.pop("constructor")
    model_name = model_dict.pop("name")
    model = constructor(**model_dict)
    # create optimizer
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    # setup logging
    project_name = "continual_learning"
    run_name = f"{datetime.now(tz=timezone.utc).strftime('%Y_%m_%d_%H_%M_%S')}"
    config = {
        "training_method": "sequentially with rehearsal",
        "model_name": model_name,
        "dataset": "CIFAR-10",
        "epochs": epochs,
        "batch_size": batch_size,
        "tasks": tasks,
        "lr": lr,
        "weight_decay": weight_decay,
        "memory_size_per_task": memory_size_per_task,
        "num_parameters": sum(p.numel() for p in model.parameters()),
    }
    config.update(model_dict)
    wandb.init(
        project=project_name,
        name=run_name,
        config=config,
    )

    # get datasets
    trainsets, testsets = get_datasets(tasks=tasks)

    avg_accs_per_task = []
    memories = []
    for k in range(tasks):
        # train model on task
        train_on_task(
            trainset=ConcatDataset([trainsets[k], *memories]),
            testset=testsets[k],
            model=model,
            device=device,
            optimizer=optimizer,
            epochs=epochs,
            batch_size=batch_size,
            lr=lr,
            criterion=criterion,
            scheduler=scheduler,
        )
        # evaluate model
        avg_acc, avg_accs = average_accuracy(
            testsets=testsets[: k + 1],  # include current task
            model=model,
            device=device,
            return_intermediate=True,
        )
        avg_accs_per_task.append(avg_accs)
        wandb.log({"accuracy_on_current_task_only": avg_accs[-1]})
        wandb.log({"average_accuracy": avg_acc})

        # calculate forgetting measure as defined here https://arxiv.org/pdf/2302.00487.pdf
        if k > 0:  # forgetting measure only makes sense, if we already trained on prior task
            wandb.log(
                {
                    "forgetting_measure": forgetting_measure(
                        average_accuracies_per_training_per_task=avg_accs_per_task,
                        current_task=k,
                    ),
                },
            )

        # save model
        path = Path(wandb.run.dir).joinpath(f"model_task{k}_of{tasks}.pth")
        torch.save(model.state_dict(), path)

        # add come images and labels from current task to memory
        random_indices = torch.randint(low=0, high=len(trainsets[k - 1]), size=(memory_size_per_task,))
        memory_task = Subset(trainsets[k], random_indices)
        memories.append(memory_task)

    # finish logging run
    wandb.finish()

#### elastic weight consolidation


In [None]:
def calculate_fisher_optimal_parameters(
    trainset: Subset,
    model: nn.Module,
    device: torch.device,
    optimizer: torch.optim.Optimizer,
    batch_size: int,
    current_task: int,
    fisher_dict: dict,
    optpar_dict: dict,
    samples_for_fisher_approximation: int,
    criterion: nn.modules.loss._Loss | None = None,
):
    model.train()
    optimizer.zero_grad()

    # get subset from trainset
    indices = torch.randint(low=0, high=len(trainset.indices), size=(samples_for_fisher_approximation))
    dataset = Subset(trainset.dataset, trainset.indices[indices])
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    # accumulating gradients
    for images, labels in dataloader:
        # forward
        prediction = model(images.to(device=device))
        # calc loss
        loss = criterion(prediction, labels.to(device=device))
        # backward
        loss.backward()

    # gradients accumulated can be used to calculate fisher
    for name, param in model.named_parameters():
        optpar_dict[current_task][name] = param.data.clone()
        fisher_dict[current_task][name] = param.grad.data.clone().pow(2)


In [17]:
def train_on_task_with_elastic_weight_loss(
    trainset: Subset,
    testset: Subset,
    model: nn.Module,
    device: torch.device,
    optimizer: torch.optim.Optimizer,
    epochs: int,
    batch_size: int,
    lr: float,
    fisher_dict: dict,
    optpar_dict: dict,
    ewc_lambda: float,
    current_task: int,
    criterion: nn.modules.loss._Loss | None = None,
    scheduler: torch.optim.lr_scheduler.LRScheduler | None = None,
):
    # create dataloaders
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

    # move model to device
    model.to(device=device)

    if criterion is None:
        criterion = nn.CrossEntropyLoss()
    if scheduler is None:
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=lr,
            steps_per_epoch=len(trainloader),
            epochs=epochs,
        )

    # training
    for epoch in range(epochs):
        # train one epoch
        with tqdm(total=len(trainset), unit="images") as progress_bar:
            model.train()
            for i, (images, labels) in enumerate(trainloader):
                progress_bar.set_description(f"Epoch {epoch+1} Batch {i}")
                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                prediction = model(images.to(device=device))
                # calc loss
                loss = criterion(prediction, labels.to(device=device))

                # add elastic weight loss
                for task in range(current_task):
                    for name, param in model.named_parameters():
                        fisher = fisher_dict[task][name]
                        optpar = optpar_dict[task][name]
                        loss += (fisher * (optpar - param).pow(2)).sum() * ewc_lambda

                # backward
                loss.backward()
                # optimizer
                optimizer.step()
                # scheduler
                scheduler.step()
                progress_bar.set_postfix(loss=loss.item())
                progress_bar.update(labels.shape[0])
                wandb.log({"loss": loss})
                wandb.log({"lr": scheduler.get_last_lr()[0]})
        # save model
        path = Path(wandb.run.dir).joinpath(f"model{epoch}.pth")
        torch.save(model.state_dict(), path)

        # eval
        test_accucracy = accuracy(testset=testset, model=model, device=device, batch_size=batch_size)
        wandb.log({"test_accucracy": test_accucracy})

    # save final model
    path = Path(wandb.run.dir).joinpath("model.pth")
    torch.save(model.state_dict(), path)

In [35]:
def train_tasks_sequentially_elastic_weight_consolidation(  # noqa: PLR0913
    model_dict: dict,
    device: torch.device,
    epochs: int,
    batch_size: int,
    tasks: int,
    lr: float,
    weight_decay: float,
    ewc_lambda: float,
    samples_for_fisher_approximation: int,
    criterion: nn.modules.loss._Loss | None = None,
    scheduler: torch.optim.lr_scheduler.LRScheduler | None = None,
) -> None:
    # build model
    constructor = model_dict.pop("constructor")
    model_name = model_dict.pop("name")
    model = constructor(**model_dict)
    # create optimizer
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    # setup logging
    project_name = "continual_learning"
    run_name = f"{datetime.now(tz=timezone.utc).strftime('%Y_%m_%d_%H_%M_%S')}"
    config = {
        "training_method": "sequentially with elastic weight consolidation",
        "model_name": model_name,
        "dataset": "CIFAR-10",
        "epochs": epochs,
        "batch_size": batch_size,
        "tasks": tasks,
        "lr": lr,
        "weight_decay": weight_decay,
        "ewc_lambda": ewc_lambda,
        "num_parameters": sum(p.numel() for p in model.parameters()),
    }
    config.update(model_dict)
    wandb.init(
        project=project_name,
        name=run_name,
        config=config,
    )

    # get datasets
    trainsets, testsets = get_datasets(tasks=tasks)

    avg_accs_per_task = []
    fisher_dict = {}
    optpar_dict = {}
    for k in range(tasks):
        train_on_task_with_elastic_weight_loss(
            trainset=trainsets[k],
            testset=testsets[k],
            model=model,
            device=device,
            optimizer=optimizer,
            epochs=epochs,
            batch_size=batch_size,
            lr=lr,
            fisher_dict=fisher_dict,
            optpar_dict=optpar_dict,
            ewc_lambda=ewc_lambda,
            current_task=k,
            criterion=criterion,
            scheduler=scheduler,
        )
        # evaluate model
        avg_acc, avg_accs = average_accuracy(
            testsets=testsets[: k + 1],  # include current task
            model=model,
            device=device,
            return_intermediate=True,
        )
        avg_accs_per_task.append(avg_accs)
        wandb.log({"accuracy_on_current_task_only": avg_accs[-1]})
        wandb.log({"average_accuracy": avg_acc})

        # calculate forgetting measure as defined here https://arxiv.org/pdf/2302.00487.pdf
        if k > 0:  # forgetting measure only makes sense, if we already trained on prior task
            wandb.log(
                {
                    "forgetting_measure": forgetting_measure(
                        average_accuracies_per_training_per_task=avg_accs_per_task,
                        current_task=k,
                    ),
                },
            )

        # save model
        path = Path(wandb.run.dir).joinpath(f"model_task{k}_of{tasks}.pth")
        torch.save(model.state_dict(), path)

        # gradients accumulated can be used to calculate fisher
        calculate_fisher_optimal_parameters(
            trainset=trainsets[k],
            model=model,
            device=device,
            optimizer=optimizer,
            batch_size=batch_size,
            current_task=k,
            fisher_dict=fisher_dict,
            optpar_dict=optpar_dict,
            samples_for_fisher_approximation=samples_for_fisher_approximation,
        )

    # finish logging run
    wandb.finish()

### wandb


In [17]:
def load_weights_from_wandb(model: nn.Module, run_name: str) -> nn.Module:
    best_model = wandb.restore(
        "model.pth",
        run_path=f"fabianfuchs/continual_learning/{run_name}",
        root=Path.cwd().joinpath("checkpoints"),
        replace=True,
    )

    # use the "name" attribute of the returned object if your framework expects a filename, e.g. as in Keras
    model.load_state_dict(torch.load(best_model.name))
    return model

## Standard Setting


In [18]:
epochs = 20
batch_size = 64
lr = 0.01
momentum = 0.9
weight_decay = 0.01
device = torch.device("cuda:0")
tasks = 1

In [31]:
models = []

In [20]:
# convnext_minimal = {
#     "constructor": ConvNeXtV2,
#     "name": "ConvNeXtV2",
#     "in_chans": 3,
#     "num_classes": 10,
#     "depths": [2, 2, 2, 2],
#     "dims": [128, 128, 128, 128],
#     "patch_size": 1,
# }
# models.append(convnext_minimal)

In [21]:
# convnext_atto = {
#     "constructor": ConvNeXtV2,
#     "name": "ConvNeXtV2",
#     "in_chans": 3,
#     "num_classes": 10,
#     "depths": [2, 2, 6, 2],
#     "dims": [40, 80, 160, 320],
#     "patch_size": 1,
# }
# models.append(convnext_atto)

In [22]:
# convnext_tiny = {
#     "constructor": ConvNeXtV2,
#     "name": "ConvNeXtV2",
#     "in_chans": 3,
#     "num_classes": 10,
#     "depths": [3, 3, 9, 3],
#     "dims": [96, 192, 384, 768],
#     "patch_size": 1,
# }
# models.append(convnext_tiny)

In [23]:
# convnext_base = {
#     "constructor": ConvNeXtV2,
#     "name": "ConvNeXtV2",
#     "in_chans": 3,
#     "num_classes": 10,
#     "depths": [3, 3, 27, 3],
#     "dims": [128, 256, 512, 1024],
#     "patch_size": 1,
# }
# models.append(convnext_base)

In [32]:
conv_mixer_minimal = {
    "constructor": ConvMixer,
    "name": "ConvMixer",
    "dim": 128,
    "depth": 4,
    "kernel_size": 7,
    "patch_size": 1,
    "n_classes": 10,
}
models.append(conv_mixer_minimal)

In [25]:
# conv_mixer_atto = {
#     "constructor": ConvMixer,
#     "name": "ConvMixer",
#     "dim": 128,
#     "depth": 8,
#     "kernel_size": 7,
#     "patch_size": 1,
#     "n_classes": 10,
# }
# models.append(conv_mixer_atto)

In [26]:
# conv_mixer_tiny = {
#     "constructor": ConvMixer,
#     "name": "ConvMixer",
#     "dim": 256,
#     "depth": 8,
#     "kernel_size": 7,
#     "patch_size": 1,
#     "n_classes": 10,
# }
# models.append(conv_mixer_tiny)

In [27]:
# for model in models:
#     train_tasks_concurrently(
#         model_dict=model,
#         device=device,
#         epochs=epochs,
#         batch_size=batch_size,
#         lr=lr,
#         weight_decay=weight_decay,
#     )

## Sequential without modifications


In [28]:
tasks = 5

In [29]:
for model in models:
    train_tasks_sequentially(
        model_dict=deepcopy(model),
        device=device,
        epochs=epochs,
        batch_size=batch_size,
        tasks=tasks,
        lr=lr,
        weight_decay=weight_decay,
    )

Files already downloaded and verified
Files already downloaded and verified


Epoch 1 Batch 156: 100%|██████████| 10000/10000 [00:25<00:00, 399.98images/s, loss=1.63] 
Epoch 2 Batch 156: 100%|██████████| 10000/10000 [00:20<00:00, 490.73images/s, loss=0.277]
Epoch 3 Batch 156: 100%|██████████| 10000/10000 [00:26<00:00, 381.29images/s, loss=0.619] 
Epoch 4 Batch 156: 100%|██████████| 10000/10000 [00:24<00:00, 409.37images/s, loss=0.372] 
Epoch 5 Batch 156: 100%|██████████| 10000/10000 [00:24<00:00, 411.30images/s, loss=0.268] 
Epoch 6 Batch 156: 100%|██████████| 10000/10000 [00:24<00:00, 405.78images/s, loss=0.0778]
Epoch 7 Batch 156: 100%|██████████| 10000/10000 [00:24<00:00, 410.48images/s, loss=0.0924]
Epoch 8 Batch 156: 100%|██████████| 10000/10000 [00:23<00:00, 427.65images/s, loss=0.0751] 
Epoch 9 Batch 156: 100%|██████████| 10000/10000 [00:24<00:00, 400.50images/s, loss=0.00683]
Epoch 10 Batch 156: 100%|██████████| 10000/10000 [00:24<00:00, 414.18images/s, loss=0.0864] 
Epoch 11 Batch 156: 100%|██████████| 10000/10000 [00:24<00:00, 401.13images/s, loss=0.01

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
accuracy_on_current_task_only,▇▁▅█▇
average_accuracy,█▃▂▁▁
forgetting_measure,█▁▁▄
loss,▂▁▁▁▁▁▁▁█▂▁▁▁▁▁▁▂▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr,▂▅██▆▄▃▁▁▄█▇▆▄▂▁▂▅██▆▄▂▁▂▆█▇▆▄▂▁▂▆█▇▆▄▂▁
test_accucracy,▅▆▇▇████▁▄▅▅▄▆▆▆▄▅▆▅▇▇▇▇▂▆▂▇████▇▇▆▇▇███

0,1
accuracy_on_current_task_only,97.45
average_accuracy,19.49
forgetting_measure,95.4375
loss,5e-05
lr,0.0
test_accucracy,97.45


## Sequential with rehearsal


In [30]:
for model in models:
    for memory_size_per_task in [1000, 2000, 5000, 10000]:
        train_tasks_sequentially_rehearsal(
            model_dict=model,
            device=device,
            epochs=epochs,
            batch_size=batch_size,
            tasks=tasks,
            lr=lr,
            weight_decay=weight_decay,
            memory_size_per_task=memory_size_per_task,
        )

KeyError: 'constructor'

## Sequential with elastic weight consolidation


In [None]:
for model in models:
    for ewc_lambda in [0.2, 0.4, 0.6]:
        train_tasks_sequentially_elastic_weight_consolidation(
            model_dict=model,
            device=device,
            epochs=epochs,
            batch_size=batch_size,
            tasks=tasks,
            lr=lr,
            weight_decay=weight_decay,
            ewc_lambda=ewc_lambda,
        )