In [None]:
import torch
import torch.nn as nn
from torchmetrics import Accuracy
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn.functional as F
from torch.utils.data import DataLoader

# Setup

In [None]:
class MLP(nn.Module):
    """Small convolutional classifier that outputs per-class log-probabilities.

    Despite the name, this is a convolutional network: two unpadded 3x3
    convolutions (32 and 64 channels), a 2x2 max-pool with dropout, and a
    two-layer fully connected head.

    Args:
        n_classes: Number of output classes.
        input_channels: Number of channels of the input images.
        input_size: Side length, in pixels, of the square input images. It
            sizes the first fully connected layer; the default of 28 yields
            64 * 12 * 12 = 9216 flattened features.

    Raises:
        ValueError: If ``input_size`` is too small to survive the two
            convolutions and the pooling step.
    """

    def __init__(
        self,
        n_classes=10,
        input_channels=1,
        input_size=28,
    ) -> None:
        super().__init__()
        # Each unpadded 3x3 convolution shortens the side by 2 pixels and the
        # 2x2 max-pool halves (rounding down) what is left.
        pooled_size = (input_size - 4) // 2
        if pooled_size < 1:
            raise ValueError(
                f'input_size must be at least 6, got {input_size}'
            )

        self._conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=input_channels,
                out_channels=32,
                kernel_size=3,
            ),
            nn.ReLU(),
        )
        self._conv2 = nn.Sequential(
            nn.Conv2d(
                in_channels=32,
                out_channels=64,
                kernel_size=3,
            ),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(),
        )
        self._lf1 = nn.Sequential(
            nn.Linear(
                # 64 channels times the pooled spatial area.
                in_features=64 * pooled_size * pooled_size,
                out_features=128,
            ),
            nn.ReLU(),
            nn.Dropout(),
        )
        self._lf2 = nn.Linear(
            in_features=128,
            out_features=n_classes,
        )

    def forward(self, x) -> torch.Tensor:
        """Return log-probabilities of shape ``(batch, n_classes)``.

        ``x`` is expected to be a batch of images shaped
        ``(batch, input_channels, input_size, input_size)``.
        """
        x = self._conv1(x)
        x = self._conv2(x)
        # Keep the batch dimension, flatten channels and spatial positions.
        x = torch.flatten(x, 1)
        x = self._lf1(x)
        x = self._lf2(x)
        return F.log_softmax(x, dim=1)


In [None]:
# Shared configuration for every experiment in this notebook.
DATA_PATH = './tmp/'  # folder the datasets are downloaded into / read from
BATCH_SIZE_TRAIN = 64
BATCH_SIZE_TEST = 1000
SEED = 1
# Seed PyTorch from the constant so that changing SEED actually changes the
# seed (the same constant is also handed to the benchmark constructors).
torch.manual_seed(SEED)
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Bare expression: displays the selected device as the cell output.
DEVICE

In [None]:
# Preprocessing shared by both MNIST splits: convert to a tensor, then
# normalise with the single-channel mean/std constants given below.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
])

# Training split, served in shuffled mini-batches of BATCH_SIZE_TRAIN.
train_loader = DataLoader(
    datasets.MNIST(DATA_PATH, train=True, transform=transform, download=True),
    batch_size=BATCH_SIZE_TRAIN,
    shuffle=True,
)

# Held-out split, served in (also shuffled) mini-batches of BATCH_SIZE_TEST.
test_loader = DataLoader(
    datasets.MNIST(DATA_PATH, train=False, transform=transform, download=True),
    batch_size=BATCH_SIZE_TEST,
    shuffle=True,
)

# Classic approach

In [None]:
def train(model: nn.Module, train_loader: DataLoader, epoch: int, optimizer: torch.optim.Optimizer, log_interval: int = 250,) -> None:
    """Run one training epoch of ``model`` over ``train_loader``.

    Each batch is moved to the device the model's parameters live on, the
    negative log-likelihood loss is minimised with ``optimizer``, and a
    progress line is printed every ``log_interval`` batches.

    Args:
        model: Network that returns log-probabilities (input to ``F.nll_loss``).
        train_loader: Loader yielding ``(data, target)`` batches.
        epoch: Epoch number, used only in the progress message.
        optimizer: Optimizer bound to ``model``'s parameters.
        log_interval: Print progress every this many batches; a value of 0 or
            less disables the progress output.
    """
    model.train()

    # Follow the model's own device instead of relying on a notebook global.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    n_batches = len(train_loader)
    n_samples = len(train_loader.dataset)

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        if log_interval > 0 and batch_idx % log_interval == 0:
            print(
                f'Train Epoch: {epoch}, [{batch_idx * len(data)}/{n_samples} ({100. * batch_idx / n_batches:.0f}%)]\tLoss: {loss.item():.6f}'
            )

def test(model: nn.Module, test_loader: DataLoader,) -> None:
  model.eval()
  test_loss = 0

  preds = torch.tensor([]).to(DEVICE)
  expected = torch.tensor([]).to(DEVICE)

  with torch.no_grad():
    for data, target in test_loader:
      data, target = data.to(DEVICE), target.to(DEVICE)
      output = model(data)
      test_loss += F.nll_loss(output, target)

      pred = output.max(1)[1]
      preds = torch.cat((preds, pred))
      expected = torch.cat((expected, target))

  test_loss /= len(test_loader.dataset)
  acc = Accuracy(task="multiclass", num_classes=10).to(DEVICE)

  print(f'Test Accuracy: {acc(preds, expected).item():.2f}')
  print(f'Test set: Avg. loss: {test_loss:.6f}')

In [None]:
# Hyperparameters reused by every experiment below.
n_epochs, learning_rate = 3, 0.01

# Fresh network on the selected device, optimised with plain SGD.
model = MLP().to(DEVICE)
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

In [None]:
# Alternate one training epoch with one evaluation pass on the first task.
for epoch in range(n_epochs):
    train(model, train_loader, epoch, optimizer)
    test(model, test_loader)

# Permuted MNIST

In [None]:
# Same preprocessing as the first task, followed by a swap of the last two
# axes of every image tensor (a fixed rearrangement of the pixels).
transform2 = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    transforms.Lambda(lambda img: img.permute(0, 2, 1)),
])

# Second task, training split.
train_loader2 = DataLoader(
    datasets.MNIST(DATA_PATH, train=True, transform=transform2, download=True),
    batch_size=BATCH_SIZE_TRAIN,
    shuffle=True,
)

# Second task, held-out split.
test_loader2 = DataLoader(
    datasets.MNIST(DATA_PATH, train=False, transform=transform2, download=True),
    batch_size=BATCH_SIZE_TEST,
    shuffle=True,
)

In [None]:
# Evaluate the model trained on the first task on the second task's test
# split, before it has seen any second-task data.
test(model, test_loader2)

In [None]:
# Keep training the same model (and optimizer) on the second task only.
for epoch in range(n_epochs):
    train(model, train_loader2, epoch, optimizer)
    test(model, test_loader2)

In [None]:
# Compare the final model on both tasks to see how much of the first task
# it retained after training on the second.
print("Testing on the first task:")
test(model, test_loader)

print("Testing on the second task:")
test(model, test_loader2)

# Continual learning (CL)

### Naive strategy

In [None]:
from avalanche.training.supervised import Naive
from avalanche.benchmarks.classic import PermutedMNIST
from avalanche.training.plugins import EvaluationPlugin
from avalanche.evaluation.metrics import accuracy_metrics, loss_metrics
from avalanche.logging import InteractiveLogger

In [None]:
# Start the naive continual-learning run from a freshly initialised network.
model = MLP().to(DEVICE)
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

In [None]:
# Benchmark made of 5 experiences; keep handles on its two streams.
permuted_mnist = PermutedMNIST(n_experiences=5, seed=SEED)
train_stream, test_stream = permuted_mnist.train_stream, permuted_mnist.test_stream

# Naive strategy, logging accuracy (per experience and per stream) and
# stream-level loss to the console.
cl_strategy = Naive(
    model=model,
    optimizer=optimizer,
    criterion=F.nll_loss,
    device=DEVICE,
    train_epochs=n_epochs,
    train_mb_size=BATCH_SIZE_TRAIN,
    eval_mb_size=BATCH_SIZE_TEST,
    evaluator=EvaluationPlugin(
        accuracy_metrics(stream=True, experience=True),
        loss_metrics(stream=True),
        loggers=[InteractiveLogger()],
        strict_checks=True,
    ),
)

In [None]:
# Train on the whole training stream, then evaluate on the test stream;
# `results` holds whatever the strategy's eval call returns.
cl_strategy.train(train_stream)
results = cl_strategy.eval(test_stream)

In [None]:
# Keep the values whose metric name contains 'Exp' (presumably the
# per-experience entries -- confirm against the printed keys), then dump
# every metric.
naive_results = [value for name, value in results.items() if 'Exp' in name]
for key in results:
    print(f'{key}: {results[key]}')

### EWC

In [None]:
from avalanche.training.supervised import EWC

In [None]:
# Fresh network and optimizer so the EWC run does not inherit earlier weights.
model = MLP().to(DEVICE)
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# Same configuration as the naive run, plus ewc_lambda (presumably the weight
# of the EWC penalty -- confirm in the Avalanche docs).
cl_strategy = EWC(
    model=model,
    optimizer=optimizer,
    criterion=F.nll_loss,
    ewc_lambda=0.4,
    device=DEVICE,
    train_epochs=n_epochs,
    train_mb_size=BATCH_SIZE_TRAIN,
    eval_mb_size=BATCH_SIZE_TEST,
    evaluator=EvaluationPlugin(
        accuracy_metrics(stream=True, experience=True),
        loss_metrics(stream=True),
        loggers=[InteractiveLogger()],
        strict_checks=True,
    ),
)

In [None]:
# Train the EWC strategy on the whole training stream, then evaluate it on
# the test stream.
cl_strategy.train(train_stream)
results = cl_strategy.eval(test_stream)

In [None]:
# Collect the values whose metric name contains 'Exp', then print all metrics.
ewc_results = [value for name, value in results.items() if 'Exp' in name]
for key in results:
    print(f'{key}: {results[key]}')

### GEM

In [None]:
from avalanche.training.supervised import GEM

In [None]:
# Fresh network and optimizer so the GEM run does not inherit earlier weights.
model = MLP().to(DEVICE)
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# Same configuration as the naive run, plus patterns_per_exp (presumably the
# number of samples GEM stores per experience -- confirm in the Avalanche docs).
cl_strategy = GEM(
    model=model,
    optimizer=optimizer,
    criterion=F.nll_loss,
    patterns_per_exp=200,
    device=DEVICE,
    train_epochs=n_epochs,
    train_mb_size=BATCH_SIZE_TRAIN,
    eval_mb_size=BATCH_SIZE_TEST,
    evaluator=EvaluationPlugin(
        accuracy_metrics(stream=True, experience=True),
        loss_metrics(stream=True),
        loggers=[InteractiveLogger()],
        strict_checks=True,
    ),
)

In [None]:
# Train the GEM strategy on the whole training stream, then evaluate it on
# the test stream.
cl_strategy.train(train_stream)
results = cl_strategy.eval(test_stream)

In [None]:
# Collect the values whose metric name contains 'Exp', then print all metrics.
gem_results = [value for name, value in results.items() if 'Exp' in name]
for key in results:
    print(f'{key}: {results[key]}')

## Comparison

In [None]:
import seaborn as sns
import pandas as pd

In [None]:
# One box per strategy, drawn from the values collected after each run.
sns.boxplot(data=pd.DataFrame(dict(Naive=naive_results, EWC=ewc_results, GEM=gem_results)))

# Class-incremental learning (Split MNIST)

In [None]:
from avalanche.benchmarks.classic import SplitMNIST

In [None]:
# Split MNIST in two experiences with a fixed class order: the even digits
# come first in the list, the odd digits second.
split_mnist = SplitMNIST(
    n_experiences=2,
    seed=SEED,
    fixed_class_order=[0, 2, 4, 6, 8, 1, 3, 5, 7, 9],
)
# Backward-compatible alias: this cell used to store the SplitMNIST benchmark
# under this (misleading) name. Prefer `split_mnist`.
permuted_mnist = split_mnist

# Rebinds the stream names used by the cells below.
train_stream = split_mnist.train_stream
test_stream = split_mnist.test_stream

In [None]:
# Naive strategy on the class-incremental benchmark, from a fresh network.
model = MLP().to(DEVICE)
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

cl_strategy = Naive(
    model=model,
    optimizer=optimizer,
    criterion=F.nll_loss,
    device=DEVICE,
    train_epochs=n_epochs,
    train_mb_size=BATCH_SIZE_TRAIN,
    eval_mb_size=BATCH_SIZE_TEST,
    evaluator=EvaluationPlugin(
        accuracy_metrics(stream=True, experience=True),
        loss_metrics(stream=True),
        loggers=[InteractiveLogger()],
        strict_checks=True,
    ),
)

# Each item of the training stream is handled on its own: report its classes,
# then train on it. Evaluation happens once, after the last item.
for stream in train_stream:
    print(f'Classes in this experience: {stream.classes_in_this_experience}')
    cl_strategy.train(stream)
results = cl_strategy.eval(test_stream)

# Collect the values whose metric name contains 'Exp', then print all metrics.
naive_results = [value for name, value in results.items() if 'Exp' in name]
for key in results:
    print(f'{key}: {results[key]}')

In [None]:
# EWC strategy on the class-incremental benchmark, from a fresh network.
model = MLP().to(DEVICE)
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

cl_strategy = EWC(
    model=model,
    optimizer=optimizer,
    criterion=F.nll_loss,
    ewc_lambda=0.2,
    device=DEVICE,
    train_epochs=n_epochs,
    train_mb_size=BATCH_SIZE_TRAIN,
    eval_mb_size=BATCH_SIZE_TEST,
    evaluator=EvaluationPlugin(
        accuracy_metrics(stream=True, experience=True),
        loss_metrics(stream=True),
        loggers=[InteractiveLogger()],
        strict_checks=True,
    ),
)

# Train on one item of the stream at a time, reporting its classes first;
# evaluate once at the end.
for stream in train_stream:
    print(f'Classes in this experience: {stream.classes_in_this_experience}')
    cl_strategy.train(stream)
results = cl_strategy.eval(test_stream)

# Collect the values whose metric name contains 'Exp', then print all metrics.
ewc_results = [value for name, value in results.items() if 'Exp' in name]
for key in results:
    print(f'{key}: {results[key]}')

In [None]:
# GEM strategy on the class-incremental benchmark, from a fresh network.
model = MLP().to(DEVICE)
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

cl_strategy = GEM(
    model=model,
    optimizer=optimizer,
    criterion=F.nll_loss,
    patterns_per_exp=200,
    device=DEVICE,
    train_epochs=n_epochs,
    train_mb_size=BATCH_SIZE_TRAIN,
    eval_mb_size=BATCH_SIZE_TEST,
    evaluator=EvaluationPlugin(
        accuracy_metrics(stream=True, experience=True),
        loss_metrics(stream=True),
        loggers=[InteractiveLogger()],
        strict_checks=True,
    ),
)

# Train on one item of the stream at a time, reporting its classes first;
# evaluate once at the end.
for stream in train_stream:
    print(f'Classes in this experience: {stream.classes_in_this_experience}')
    cl_strategy.train(stream)
results = cl_strategy.eval(test_stream)

# Collect the values whose metric name contains 'Exp', then print all metrics.
gem_results = [value for name, value in results.items() if 'Exp' in name]
for key in results:
    print(f'{key}: {results[key]}')

In [None]:
# One box per strategy for the class-incremental runs above.
sns.boxplot(data=pd.DataFrame(dict(Naive=naive_results, EWC=ewc_results, GEM=gem_results)))

# Different dataset: CIFAR-100

In [None]:
from avalanche.benchmarks.generators import nc_benchmark

In [None]:
# Per-channel (RGB) statistics commonly quoted for the CIFAR-100 training set.
# The MNIST single-channel constants used for the earlier experiments do not
# describe these colour images.
CIFAR100_MEAN = (0.5071, 0.4865, 0.4409)
CIFAR100_STD = (0.2673, 0.2564, 0.2762)

# Both pipelines first resize to 28 pixels so that training and evaluation
# see objects at the same scale; only training adds the padded random crop.
train_transform = transforms.Compose(
    [
        transforms.Resize(28),
        transforms.RandomCrop(28, padding=4),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR100_MEAN, CIFAR100_STD),
    ]
)
test_transform = transforms.Compose(
    [
        transforms.Resize(28),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR100_MEAN, CIFAR100_STD),
    ]
)

In [None]:
# CIFAR-100 splits, stored under the same DATA_PATH folder as the MNIST data
# instead of a second hard-coded copy of that path.
cifar_train = datasets.CIFAR100(root=DATA_PATH, transform=train_transform, train=True, download=True,)
cifar_test = datasets.CIFAR100(root=DATA_PATH, transform=test_transform, train=False, download=True,)

In [None]:
# Build a 5-experience benchmark from the two CIFAR-100 splits, with
# shuffling seeded by SEED and task labels switched off.
scenario = nc_benchmark(
    train_dataset=cifar_train,
    test_dataset=cifar_test,
    n_experiences=5,
    task_labels=False,
    shuffle=True,
    seed=SEED,
)

cifar_train_stream, cifar_test_stream = scenario.train_stream, scenario.test_stream

In [None]:
# Naive strategy on CIFAR-100: 100 output classes, 3 input channels.
model = MLP(n_classes=100, input_channels=3).to(DEVICE)
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

cl_strategy = Naive(
    model=model,
    optimizer=optimizer,
    criterion=F.nll_loss,
    device=DEVICE,
    train_epochs=n_epochs,
    train_mb_size=BATCH_SIZE_TRAIN,
    eval_mb_size=BATCH_SIZE_TEST,
    evaluator=EvaluationPlugin(
        accuracy_metrics(stream=True, experience=True),
        loss_metrics(stream=True),
        loggers=[InteractiveLogger()],
        strict_checks=True,
    ),
)

# Train on the whole CIFAR training stream, then evaluate on its test stream.
cl_strategy.train(cifar_train_stream)
results = cl_strategy.eval(cifar_test_stream)

# Collect the values whose metric name contains 'Exp', then print all metrics.
naive_results = [value for name, value in results.items() if 'Exp' in name]
for key in results:
    print(f'{key}: {results[key]}')

In [None]:
# EWC strategy on CIFAR-100: 100 output classes, 3 input channels.
model = MLP(n_classes=100, input_channels=3).to(DEVICE)
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

cl_strategy = EWC(
    model=model,
    optimizer=optimizer,
    criterion=F.nll_loss,
    ewc_lambda=0.2,
    device=DEVICE,
    train_epochs=n_epochs,
    train_mb_size=BATCH_SIZE_TRAIN,
    eval_mb_size=BATCH_SIZE_TEST,
    evaluator=EvaluationPlugin(
        accuracy_metrics(stream=True, experience=True),
        loss_metrics(stream=True),
        loggers=[InteractiveLogger()],
        strict_checks=True,
    ),
)

# Train on the whole CIFAR training stream, then evaluate on its test stream.
cl_strategy.train(cifar_train_stream)
results = cl_strategy.eval(cifar_test_stream)

# Collect the values whose metric name contains 'Exp', then print all metrics.
ewc_results = [value for name, value in results.items() if 'Exp' in name]
for key in results:
    print(f'{key}: {results[key]}')

In [None]:
# GEM strategy on CIFAR-100: 100 output classes, 3 input channels.
model = MLP(n_classes=100, input_channels=3).to(DEVICE)
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

cl_strategy = GEM(
    model=model,
    optimizer=optimizer,
    criterion=F.nll_loss,
    patterns_per_exp=200,
    device=DEVICE,
    train_epochs=n_epochs,
    train_mb_size=BATCH_SIZE_TRAIN,
    eval_mb_size=BATCH_SIZE_TEST,
    evaluator=EvaluationPlugin(
        accuracy_metrics(stream=True, experience=True),
        loss_metrics(stream=True),
        loggers=[InteractiveLogger()],
        strict_checks=True,
    ),
)

# Train on the whole CIFAR training stream, then evaluate on its test stream.
cl_strategy.train(cifar_train_stream)
results = cl_strategy.eval(cifar_test_stream)

# Collect the values whose metric name contains 'Exp', then print all metrics.
gem_results = [value for name, value in results.items() if 'Exp' in name]
for key in results:
    print(f'{key}: {results[key]}')

In [None]:
# One box per strategy for the CIFAR-100 runs above.
sns.boxplot(data=pd.DataFrame(dict(Naive=naive_results, EWC=ewc_results, GEM=gem_results)))