In [4]:
#!pip install wandb -qU

In [5]:
from __future__ import annotations

from datetime import datetime, timezone
from pathlib import Path

import torch
import torch.nn.functional as F
import torchvision
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import transforms
from tqdm import tqdm

import wandb

## Weights & Biases Login


In [6]:
# Authenticate this session with Weights & Biases; the training and checkpoint-loading cells below use wandb.
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: fabianfuchs. Use `wandb login --relogin` to force relogin


True

## Model


The ConvNeXtV2 model source code is taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py, with drop path and the custom weight initialization removed.


In [7]:
class LayerNorm(nn.Module):
    """Layer normalization over the channel axis for two tensor layouts.

    ``channels_last`` (the default) expects the channels on the last axis,
    e.g. (batch_size, height, width, channels). ``channels_first`` expects
    them on axis 1, e.g. (batch_size, channels, height, width).
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        """Create the learnable scale/shift and remember the layout.

        Args:
            normalized_shape (int): Number of channels to normalize over.
            eps (float): Constant added to the variance for numerical stability.
            data_format (str): Either "channels_last" or "channels_first".

        Raises:
            NotImplementedError: If ``data_format`` is not one of the two supported layouts.
        """
        super().__init__()
        # Registration order (weight, then bias) determines the state_dict key order.
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        supported_formats = ("channels_last", "channels_first")
        if data_format not in supported_formats:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        """Normalize ``x`` over its channel axis and apply the learned affine transform."""
        if self.data_format == "channels_first":
            # Manual layer norm over axis 1, using the biased variance.
            mean = x.mean(1, keepdim=True)
            centered = x - mean
            variance = centered.pow(2).mean(1, keepdim=True)
            normalized = centered / torch.sqrt(variance + self.eps)
            # Broadcast the per-channel parameters over the two trailing spatial axes.
            return self.weight[:, None, None] * normalized + self.bias[:, None, None]
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)


class GRN(nn.Module):
    """Global Response Normalization (GRN) layer for channels-last inputs."""

    def __init__(self, dim):
        """Create zero-initialized scale (gamma) and shift (beta) parameters of shape (1, 1, 1, dim)."""
        super().__init__()
        parameter_shape = (1, 1, 1, dim)
        self.gamma = nn.Parameter(torch.zeros(parameter_shape))
        self.beta = nn.Parameter(torch.zeros(parameter_shape))

    def forward(self, x):
        """Rescale ``x`` by its relative per-channel response and add the input back as a residual."""
        # L2 norm over axes 1 and 2 gives one response value per sample and channel.
        response = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        # Divide by the mean response across channels; the epsilon avoids division by zero.
        mean_response = response.mean(dim=-1, keepdim=True)
        relative_response = response / (mean_response + 1e-6)
        calibrated = self.gamma * (x * relative_response)
        return calibrated + self.beta + x


class Block(nn.Module):
    """ConvNeXtV2 residual block: depthwise conv, LayerNorm, and an inverted-bottleneck MLP with GRN."""

    def __init__(self, dim, drop_path=0.0):
        """Build the layers of one block.

        Args:
            dim (int): Number of input (and output) channels.
            drop_path (float): Accepted for signature compatibility; this
                implementation does not use it. Default: 0.0
        """
        super().__init__()
        hidden_dim = 4 * dim  # expansion factor of the pointwise MLP
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
        self.norm = LayerNorm(dim, eps=1e-6)
        # The pointwise (1x1) convolutions are implemented as linear layers on channels-last tensors.
        self.pwconv1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.grn = GRN(hidden_dim)
        self.pwconv2 = nn.Linear(hidden_dim, dim)

    def forward(self, x):
        """Apply the block to a (N, C, H, W) tensor and add the residual connection."""
        residual = x
        features = self.dwconv(x)
        features = features.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        for layer in (self.norm, self.pwconv1, self.act, self.grn, self.pwconv2):
            features = layer(features)
        features = features.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
        return residual + features


class ConvNeXtV2(nn.Module):
    """ConvNeXt V2 with four stages, global average pooling and a linear classification head."""

    def __init__(
        self,
        in_chans=3,
        num_classes=1000,
        depths=(3, 3, 9, 3),
        dims=(96, 192, 384, 768),
        drop_path_rate=0.0,
    ):
        """Build the stem, the downsampling layers, the four stages and the head.

        Args:
            in_chans (int): Number of input image channels. Default: 3
            num_classes (int): Number of classes for the classification head. Default: 1000
            depths (sequence of int): Number of blocks in each of the four stages.
                Default: (3, 3, 9, 3)
            dims (sequence of int): Feature dimension of each of the four stages.
                Default: (96, 192, 384, 768)
            drop_path_rate (float): Maximum stochastic depth rate. The per-block rates are
                still computed and forwarded to ``Block``, but ``Block`` in this file does
                not use them, so the value currently has no effect. Default: 0.
        """
        super().__init__()
        # Immutable defaults plus a private copy: mutating ``model.depths`` (or the
        # caller's list) can no longer leak into other instances or the defaults.
        self.depths = list(depths)
        self.downsample_layers = nn.ModuleList()  # stem and 3 intermediate downsampling conv layers
        stem = nn.Sequential(
            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
        )
        self.downsample_layers.append(stem)
        for i in range(3):
            downsample_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList()  # 4 feature resolution stages, each consisting of multiple residual blocks
        # One rate per block, increasing linearly from 0 to drop_path_rate.
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0  # index of the first block of the current stage within dp_rates
        for i in range(4):
            stage = nn.Sequential(*[Block(dim=dims[i], drop_path=dp_rates[cur + j]) for j in range(depths[i])])
            self.stages.append(stage)
            cur += depths[i]

        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)  # final norm layer
        self.head = nn.Linear(dims[-1], num_classes)

    def forward_features(self, x):
        """Return the pooled, normalized feature vector of shape (N, dims[-1]) for a (N, C, H, W) input."""
        for downsample_layer, stage in zip(self.downsample_layers, self.stages):
            x = downsample_layer(x)
            x = stage(x)
        return self.norm(x.mean([-2, -1]))  # global average pooling, (N, C, H, W) -> (N, C)

    def forward(self, x):
        """Return the class logits of shape (N, num_classes)."""
        return self.head(self.forward_features(x))

## Functions


In [8]:
def get_classes() -> tuple:
    """Return the ten CIFAR-10 class labels, ordered by class index."""
    label_names = "plane car bird cat deer dog frog horse ship truck"
    return tuple(label_names.split())


def get_datasets(tasks: int) -> tuple[list[Subset], list[Subset]]:
    """Split the CIFAR-10 dataset into task-specific subsets of consecutive classes.

    The ten classes are divided into ``tasks`` contiguous groups; each task's
    subsets contain every train/test image whose label falls in its group.

    Args:
        tasks (int): Number of tasks to split the dataset into. Must be between
            1 and the number of classes (10).

    Returns:
        tuple[list[Subset], list[Subset]]: Two lists of length ``tasks`` with the
        train and test subsets of each task.

    Raises:
        ValueError: If ``tasks`` is outside ``1..len(get_classes())``. More tasks than
            classes would produce empty subsets, which later fail in a shuffling DataLoader.
    """
    classes = get_classes()
    # Validate before touching the disk/network so bad input fails fast.
    if not 1 <= tasks <= len(classes):
        msg = f"tasks must be between 1 and {len(classes)}, got {tasks}"
        raise ValueError(msg)

    # Scale every channel from [0, 1] to [-1, 1].
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    trainset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
    testset = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)

    # Class-index boundaries of the tasks, e.g. [0, 2, 4, 6, 8, 10] for five tasks.
    boundaries = torch.linspace(0, len(classes), tasks + 1, dtype=torch.int).tolist()
    train_targets = torch.tensor(trainset.targets)
    test_targets = torch.tensor(testset.targets)

    def indices_of(targets: torch.Tensor, class_ids: range) -> list[int]:
        """Return the sample indices of the given classes, grouped class by class."""
        return [
            index
            for class_id in class_ids
            for index in (targets == class_id).nonzero(as_tuple=False).flatten().tolist()
        ]

    trainsets = []
    testsets = []
    for first_class, end_class in zip(boundaries[:-1], boundaries[1:]):
        task_classes = range(first_class, end_class)
        trainsets.append(Subset(trainset, indices_of(train_targets, task_classes)))
        testsets.append(Subset(testset, indices_of(test_targets, task_classes)))
    return trainsets, testsets

In [9]:
def _train_one_epoch(
    model: nn.Module,
    trainloader: DataLoader,
    device: torch.device,
    criterion: nn.modules.loss._Loss,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.LRScheduler,
    epoch: int,
) -> None:
    """Run one optimization pass over ``trainloader`` and log every batch loss to wandb."""
    # The bar counts images, so use the true dataset size (the last batch may be smaller).
    with tqdm(total=len(trainloader.dataset), unit="images") as progress_bar:
        model.train()
        for i, (images, labels) in enumerate(trainloader):
            progress_bar.set_description(f"Epoch {epoch+1} Batch {i}")
            optimizer.zero_grad()

            prediction = model(images.to(device=device))
            loss = criterion(prediction, labels.to(device=device))
            loss.backward()
            optimizer.step()
            # The scheduler is advanced once per batch, not once per epoch.
            scheduler.step()

            # Convert once to a Python float so neither tqdm nor wandb holds on to the autograd graph.
            loss_value = loss.item()
            progress_bar.set_postfix(loss=loss_value)
            progress_bar.update(labels.size(0))
            wandb.log({"loss": loss_value})


def _evaluate(model: nn.Module, testloader: DataLoader, device: torch.device) -> float:
    """Return the top-1 accuracy of ``model`` on ``testloader`` in percent (0.0 for an empty loader)."""
    correct = 0
    total = 0
    model.eval()
    # No gradients are needed for evaluation.
    with tqdm(total=len(testloader.dataset), unit="images") as progress_bar, torch.no_grad():
        for i, (images, labels) in enumerate(testloader):
            progress_bar.set_description(f"Batch {i}")
            outputs = model(images.to(device=device))
            # The class with the highest logit is the prediction.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels.to(device=device)).sum().item()
            progress_bar.update(labels.size(0))
    # True division keeps the fractional part that floor division used to discard.
    return 100 * correct / total if total else 0.0


def train_standard_training_pipeline(
    run_name: str,
    trainset: Dataset,
    testset: Dataset,
    model: nn.Module,
    device: torch.device,
    criterion: nn.modules.loss._Loss,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.LRScheduler,
    epochs: int,
    batch_size: int,
) -> None:
    """Train ``model`` on ``trainset``, validate on ``testset`` after each epoch, and log to wandb.

    A wandb run named ``run_name`` plus a UTC timestamp is created in the
    "continual_learning" project. The batch loss and the per-epoch accuracy are
    logged, and the model weights are saved into the run directory after every
    epoch (``model{epoch}.pth``) and once more at the end (``model.pth``).

    Args:
        run_name (str): Prefix of the wandb run name.
        trainset (Dataset): Training data.
        testset (Dataset): Validation data evaluated after each epoch.
        model (nn.Module): Model to train; it is moved to ``device`` in place.
        device (torch.device): Device used for training and evaluation.
        criterion: Loss function called as ``criterion(prediction, labels)``.
        optimizer (torch.optim.Optimizer): Optimizer stepped once per batch.
        scheduler: Learning-rate scheduler stepped once per batch.
        epochs (int): Number of passes over ``trainset``.
        batch_size (int): Batch size of both data loaders.
    """
    # create dataloaders
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

    # move model to device
    model.to(device=device)

    # setup logging
    project_name = "continual_learning"
    run_name = f"{run_name}_{datetime.now(tz=timezone.utc).strftime('%Y_%m_%d_%H_%M_%S')}"
    wandb.init(
        # Project where this run will be logged.
        project=project_name,
        # Explicit run name (otherwise wandb assigns a random one).
        name=run_name,
        # Hyperparameters and run metadata.
        config={
            "architecture": "CNN",
            "dataset": "CIFAR-10",
            "epochs": epochs,
            "batch_size": batch_size,
        },
    )

    try:
        for epoch in range(epochs):
            _train_one_epoch(
                model=model,
                trainloader=trainloader,
                device=device,
                criterion=criterion,
                optimizer=optimizer,
                scheduler=scheduler,
                epoch=epoch,
            )

            accuracy = _evaluate(model=model, testloader=testloader, device=device)
            wandb.log({"accuracy": accuracy})

            # save a checkpoint for this epoch
            path = Path(wandb.run.dir).joinpath(f"model{epoch}.pth")
            torch.save(model.state_dict(), path)

        # save final model
        path = Path(wandb.run.dir).joinpath("model.pth")
        torch.save(model.state_dict(), path)
    finally:
        # Close the wandb run even if training fails, so the next run does not log into a stale one.
        wandb.finish()

In [10]:
def train_tasks_sequentially(  # noqa: PLR0913
    model: nn.Module,
    device: torch.device,
    criterion: nn.modules.loss._Loss,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.LRScheduler,
    epochs: int,
    batch_size: int,
    tasks: int,
) -> None:
    """Train one model on the CIFAR-10 tasks one after another.

    The dataset is split into ``tasks`` subsets and the standard training
    pipeline is run once per task, in order. The same model, optimizer and
    scheduler objects are passed to every run.

    Args:
        model (nn.Module): Model trained across all tasks.
        device (torch.device): Device to train on.
        criterion: Loss function.
        optimizer (torch.optim.Optimizer): Optimizer shared by all tasks.
        scheduler: Learning-rate scheduler shared by all tasks.
        epochs (int): Epochs to train per task.
        batch_size (int): Batch size of the data loaders.
        tasks (int): Number of tasks to split the dataset into.
    """
    trainsets, testsets = get_datasets(tasks=tasks)

    # Settings that are identical for every task's training run.
    shared_settings = {
        "model": model,
        "device": device,
        "criterion": criterion,
        "optimizer": optimizer,
        "scheduler": scheduler,
        "epochs": epochs,
        "batch_size": batch_size,
    }

    for task_index in range(tasks):
        task_number = task_index + 1  # run names are 1-based
        train_standard_training_pipeline(
            run_name=f"task{task_number}_of_{tasks}",
            trainset=trainsets[task_index],
            testset=testsets[task_index],
            **shared_settings,
        )

**FIXME:** `load_trained_model` appears to be broken: it seems to return the same model even when called with different run names.


In [21]:
def load_trained_model(run_name: str) -> ConvNeXtV2:
    """Download the final checkpoint of a wandb run and load it into a fresh ConvNeXtV2.

    Args:
        run_name (str): Last component of the wandb run path
            ``fabianfuchs/continual_learning/<run_name>``.

    Returns:
        ConvNeXtV2: Model with the weights of the run's ``model.pth``, on the CPU.

    Raises:
        FileNotFoundError: If wandb reports that ``model.pth`` does not exist for the run.
    """
    # Must match the architecture the checkpoints were trained with.
    convnext = ConvNeXtV2(in_chans=3, num_classes=10, depths=[2, 2, 6, 2], dims=[40, 80, 160, 320])

    # Give every run its own directory so checkpoints of different runs never
    # share (and overwrite) the same local ``model.pth``.
    checkpoint_dir = Path.cwd().joinpath("checkpoints", run_name)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    best_model = wandb.restore(
        "model.pth",
        run_path=f"fabianfuchs/continual_learning/{run_name}",
        root=checkpoint_dir,
        replace=True,
    )
    if best_model is None:
        msg = f"model.pth not found for run 'fabianfuchs/continual_learning/{run_name}'"
        raise FileNotFoundError(msg)

    # wandb.restore hands back an open file object; only its path is needed, so close it.
    checkpoint_path = best_model.name
    best_model.close()

    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    # NOTE: torch.load unpickles the file, so only load checkpoints from trusted runs.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    convnext.load_state_dict(state_dict)
    return convnext

In [14]:
def accuracy(testset: Dataset, model: nn.Module, device: torch.device, batch_size: int = 256) -> float:
    """Return the top-1 accuracy of ``model`` on ``testset`` in percent.

    Args:
        testset (Dataset): Dataset yielding ``(image, label)`` pairs.
        model (nn.Module): Model to evaluate; it is moved to ``device`` and set to eval mode.
        device (torch.device): Device to run the evaluation on.
        batch_size (int): Number of samples evaluated per forward pass. Default: 256

    Returns:
        float: Percentage of correctly classified samples, or 0.0 if ``testset`` is empty.
    """
    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)

    # Model and inputs must live on the same device.
    model.to(device=device)
    model.eval()

    correct = 0
    total = 0
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for images, labels in testloader:
            predictions = model(images.to(device=device))
            # The class with the highest logit is the prediction.
            predicted = predictions.argmax(dim=1)
            correct += (predicted == labels.to(device=device)).sum().item()
            # Count samples, not batches, so the result is right for any batch size.
            total += labels.size(0)

    if total == 0:
        return 0.0
    return 100 * correct / total

In [15]:
def average_accuracy(testsets: list[Dataset], model: nn.Module, device: torch.device) -> float:
    """Print the accuracy on each test set and return the unweighted mean over all of them.

    Args:
        testsets (list[Dataset]): One test set per task.
        model (nn.Module): Model to evaluate.
        device (torch.device): Device to run the evaluation on.

    Returns:
        float: Mean of the per-task accuracies in percent.
    """
    accuracy_sum = 0
    for testset in testsets:
        # The variable name is part of the printed text ("task_accuracy=...").
        task_accuracy = accuracy(testset=testset, model=model, device=device)
        print(f"{task_accuracy=}")
        accuracy_sum += task_accuracy
    return accuracy_sum / len(testsets)

In [27]:
# Score the checkpoint of wandb run "nod8s4pu" (the model after task 5) on the test split of each of the five tasks.
trainsets, testsets = get_datasets(tasks=5)
model_after_task5 = load_trained_model("nod8s4pu")
# Pass a torch.device, as the evaluation functions are annotated, instead of a bare string.
average_accuracy(testsets=testsets, model=model_after_task5, device=torch.device("cpu"))


Files already downloaded and verified
Files already downloaded and verified
task_accuracy=0.0
task_accuracy=0.0
task_accuracy=0.0
task_accuracy=0.0
task_accuracy=81.3


16.259999999999998

In [28]:
# Score the checkpoint of wandb run "ips9plod" (the model after task 4) on the test split of each of the five tasks.
trainsets, testsets = get_datasets(tasks=5)
model_after_task4 = load_trained_model("ips9plod")
# Pass a torch.device, as the evaluation functions are annotated, instead of a bare string.
average_accuracy(testsets=testsets, model=model_after_task4, device=torch.device("cpu"))


Files already downloaded and verified
Files already downloaded and verified
task_accuracy=0.1
task_accuracy=0.0
task_accuracy=0.0
task_accuracy=86.95
task_accuracy=0.0


17.41

In [29]:
# Score the checkpoint of wandb run "90uuu6hf" (the model after task 3) on the test split of each of the five tasks.
trainsets, testsets = get_datasets(tasks=5)
model_after_task3 = load_trained_model("90uuu6hf")
# Pass a torch.device, as the evaluation functions are annotated, instead of a bare string.
average_accuracy(testsets=testsets, model=model_after_task3, device=torch.device("cpu"))


Files already downloaded and verified
Files already downloaded and verified
task_accuracy=0.0
task_accuracy=0.0
task_accuracy=76.65
task_accuracy=0.0
task_accuracy=0.0


15.330000000000002

In [30]:
# Score the checkpoint of wandb run "la3fepkv" (the model after task 2) on the test split of each of the five tasks.
trainsets, testsets = get_datasets(tasks=5)
model_after_task2 = load_trained_model("la3fepkv")
# Pass a torch.device, as the evaluation functions are annotated, instead of a bare string.
average_accuracy(testsets=testsets, model=model_after_task2, device=torch.device("cpu"))


Files already downloaded and verified
Files already downloaded and verified
task_accuracy=0.0
task_accuracy=75.4
task_accuracy=0.0
task_accuracy=0.0
task_accuracy=0.0


15.080000000000002

In [31]:
# Score the checkpoint of wandb run "l7eev0gl" (the model after task 1) on the test split of each of the five tasks.
trainsets, testsets = get_datasets(tasks=5)
model_after_task1 = load_trained_model("l7eev0gl")
# Pass a torch.device, as the evaluation functions are annotated, instead of a bare string.
average_accuracy(testsets=testsets, model=model_after_task1, device=torch.device("cpu"))


Files already downloaded and verified
Files already downloaded and verified
task_accuracy=85.9
task_accuracy=0.0
task_accuracy=0.0
task_accuracy=0.0
task_accuracy=0.0


17.18

## Standard Setting


In [19]:
# Hyperparameters of the standard (single-task) training run; the cells below read these names.
epochs = 100  # passed to the LR scheduler below
batch_size = 64
lr = 0.001  # SGD learning rate, also used as max_lr of the OneCycleLR schedule
momentum = 0.9  # SGD momentum
device = torch.device("cpu")
tasks = 1  # number of sequential tasks CIFAR-10 is split into (1 = train on all classes at once)

In [20]:
# ConvNeXtV2 with a 10-class head; load_trained_model builds the same architecture to load checkpoints.
convnext = ConvNeXtV2(in_chans=3, num_classes=10, depths=[2, 2, 6, 2], dims=[40, 80, 160, 320])

In [None]:
# Plain SGD with momentum over all model parameters, configured from the hyperparameter cell.
optimizer = optim.SGD(convnext.parameters(), lr=lr, momentum=momentum)

In [None]:
# Cross-entropy loss, applied in the training loop to the model output and the integer class labels.
criterion = nn.CrossEntropyLoss()

In [None]:
# The training loop calls scheduler.step() once per batch, so steps_per_epoch must be the
# number of batches that are actually run per epoch. The previous value (5000 // tasks) was
# several times larger, which stretched the cycle so the learning rate never completed it.
cifar10_train_images = 50_000  # size of the CIFAR-10 training split
# Ceiling division gives the batches for the whole split; each additional task can add at most
# one extra partial batch, so this is an upper bound and the scheduler is never over-stepped.
batches_per_epoch = -(-cifar10_train_images // batch_size) + (tasks - 1)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=lr,
    steps_per_epoch=batches_per_epoch,
    epochs=epochs,
)

In [None]:
# Run the standard setting. epochs and batch_size are required parameters of
# train_tasks_sequentially (omitting them raises a TypeError), and tasks is taken from the
# hyperparameter cell so it stays consistent with the scheduler length.
train_tasks_sequentially(
    model=convnext,
    device=device,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    epochs=epochs,
    batch_size=batch_size,
    tasks=tasks,
)

In [None]:
# train_tasks_sequentially(
#     model=convnext,
#     device=torch.device("cpu"),
#     optimizer=optimizer,
#     scheduler=scheduler,
#     criterion=criterion,
#     tasks=5,
# )