# Orthogonal Gradient Descent

In [1]:
%load_ext autoreload
%autoreload 2

# Global imports and settings

In [2]:
from torchvision.transforms import ToTensor, Compose, Normalize

In [3]:
from tqdm import tqdm

In [4]:
from avalanche.benchmarks.datasets import MNIST
from avalanche.training.plugins import EvaluationPlugin
from avalanche.benchmarks.generators import nc_benchmark
from avalanche.evaluation.metrics import (
    bwt_metrics,
    forgetting_metrics,
    accuracy_metrics,
)

from torch.nn import CrossEntropyLoss
from torch.optim import Adam

  from .autonotebook import tqdm as notebook_tqdm


## Custom libraries

In [5]:
import sys

sys.path.append("..")

from base_code.training.ogd import OGD
from base_code.constants import DATASET_PATH
from base_code.models.cnn import CNN2D

# Dataset and definitions

## Preprocessing definitions

In [6]:
train_transform = Compose([
    ToTensor(),
    Normalize((0.1307,), (0.3081,))
])
test_transform = Compose([
    ToTensor(),
    Normalize((0.1307,), (0.3081,))
])

## Dataset loading

In [7]:
mnist_train = MNIST(DATASET_PATH, train=True, transform=train_transform)
mnist_test = MNIST(DATASET_PATH, train=False, transform=test_transform)

## Scenario creation with train test streams

In [8]:
scenario = nc_benchmark(
    mnist_train, mnist_test, len(mnist_train.classes), shuffle=True, seed=1234, task_labels=False
)

train_stream = scenario.train_stream
test_stream = scenario.test_stream

## Evaluation metrics definition

In [14]:
eval_plugin = EvaluationPlugin(
    bwt_metrics(experience=True, stream=True),
    accuracy_metrics(experience=True, stream=True),
    forgetting_metrics(experience=True, stream=True)
)



## Model, Optimizer, Loss and Strategy definitions

In [20]:
model = CNN2D(n_classes=scenario.n_classes, n_channels=1)
optimizer = Adam(model.parameters(), lr=0.001)
criterion = CrossEntropyLoss()
strategy = OGD(
    model, optimizer, criterion, train_mb_size=32, train_epochs=1, evaluator=eval_plugin
)

# Training and evaluation

In [21]:
results = []

for (until, experience) in tqdm(enumerate(train_stream, start=1)):
    strategy.train(experience, eval_streams=[test_stream])
    metrics = strategy.eval(test_stream[:until])
    print(metrics)

    results.append(metrics)

1it [00:03,  3.99s/it]

{'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000': 0.9674887892376681, 'StreamBWT/eval_phase/test_stream': 0.0, 'Top1_Acc_Stream/eval_phase/test_stream/Task000': 0.9674887892376681, 'StreamForgetting/eval_phase/test_stream': 0.0, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp000': -1.0, 'ExperienceForgetting/eval_phase/test_stream/Task000/Exp000': 1.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001': 1.0}


2it [00:12,  6.70s/it]

{'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000': 0.0, 'StreamBWT/eval_phase/test_stream': -0.9674887892376681, 'Top1_Acc_Stream/eval_phase/test_stream/Task000': 0.5240128068303095, 'StreamForgetting/eval_phase/test_stream': 0.9674887892376681, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp000': -0.9674887892376681, 'ExperienceForgetting/eval_phase/test_stream/Task000/Exp000': 0.9674887892376681, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001': 1.0}


3it [00:26, 10.01s/it]

{'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000': 0.0, 'StreamBWT/eval_phase/test_stream': -0.983744394618834, 'Top1_Acc_Stream/eval_phase/test_stream/Task000': 0.3419943820224719, 'StreamForgetting/eval_phase/test_stream': 0.983744394618834, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp000': -0.9674887892376681, 'ExperienceForgetting/eval_phase/test_stream/Task000/Exp000': 0.9674887892376681, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001': 0.0, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp001': -1.0, 'ExperienceForgetting/eval_phase/test_stream/Task000/Exp001': 1.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp002': 1.0}


4it [00:47, 14.18s/it]

{'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000': 0.0, 'StreamBWT/eval_phase/test_stream': -0.9891629297458894, 'Top1_Acc_Stream/eval_phase/test_stream/Task000': 0.26597938144329897, 'StreamForgetting/eval_phase/test_stream': 0.9891629297458894, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp000': -0.9674887892376681, 'ExperienceForgetting/eval_phase/test_stream/Task000/Exp000': 0.9674887892376681, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001': 0.0, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp001': -1.0, 'ExperienceForgetting/eval_phase/test_stream/Task000/Exp001': 1.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp002': 0.0, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp002': -1.0, 'ExperienceForgetting/eval_phase/test_stream/Task000/Exp002': 1.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp003': 1.0}


5it [01:15, 19.44s/it]

{'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000': 0.0, 'StreamBWT/eval_phase/test_stream': -0.991872197309417, 'Top1_Acc_Stream/eval_phase/test_stream/Task000': 0.19801570897064902, 'StreamForgetting/eval_phase/test_stream': 0.991872197309417, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp000': -0.9674887892376681, 'ExperienceForgetting/eval_phase/test_stream/Task000/Exp000': 0.9674887892376681, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001': 0.0, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp001': -1.0, 'ExperienceForgetting/eval_phase/test_stream/Task000/Exp001': 1.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp002': 0.0, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp002': -1.0, 'ExperienceForgetting/eval_phase/test_stream/Task000/Exp002': 1.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp003': 0.0, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp003': -1.0, 'ExperienceForgetting/eval_phase/test_stream/Task000/Exp003': 1.0, 'Top1_Acc_Exp/eval_phase/test_st

6it [01:56, 26.50s/it]

{'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000': 0.0, 'StreamBWT/eval_phase/test_stream': -0.9934977578475337, 'Top1_Acc_Stream/eval_phase/test_stream/Task000': 0.17256712844193603, 'StreamForgetting/eval_phase/test_stream': 0.9934977578475337, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp000': -0.9674887892376681, 'ExperienceForgetting/eval_phase/test_stream/Task000/Exp000': 0.9674887892376681, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001': 0.0, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp001': -1.0, 'ExperienceForgetting/eval_phase/test_stream/Task000/Exp001': 1.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp002': 0.0, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp002': -1.0, 'ExperienceForgetting/eval_phase/test_stream/Task000/Exp002': 1.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp003': 0.0, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp003': -1.0, 'ExperienceForgetting/eval_phase/test_stream/Task000/Exp003': 1.0, 'Top1_Acc_Exp/eval_phase/test_

7it [02:55, 37.28s/it]

{'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000': 0.0, 'StreamBWT/eval_phase/test_stream': -0.9945814648729447, 'Top1_Acc_Stream/eval_phase/test_stream/Task000': 0.16256087081065598, 'StreamForgetting/eval_phase/test_stream': 0.9945814648729447, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp000': -0.9674887892376681, 'ExperienceForgetting/eval_phase/test_stream/Task000/Exp000': 0.9674887892376681, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001': 0.0, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp001': -1.0, 'ExperienceForgetting/eval_phase/test_stream/Task000/Exp001': 1.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp002': 0.0, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp002': -1.0, 'ExperienceForgetting/eval_phase/test_stream/Task000/Exp002': 1.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp003': 0.0, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp003': -1.0, 'ExperienceForgetting/eval_phase/test_stream/Task000/Exp003': 1.0, 'Top1_Acc_Exp/eval_phase/test_

8it [04:13, 50.33s/it]

{'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000': 0.0, 'StreamBWT/eval_phase/test_stream': -0.995355541319667, 'Top1_Acc_Stream/eval_phase/test_stream/Task000': 0.12833957553058678, 'StreamForgetting/eval_phase/test_stream': 0.995355541319667, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp000': -0.9674887892376681, 'ExperienceForgetting/eval_phase/test_stream/Task000/Exp000': 0.9674887892376681, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001': 0.0, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp001': -1.0, 'ExperienceForgetting/eval_phase/test_stream/Task000/Exp001': 1.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp002': 0.0, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp002': -1.0, 'ExperienceForgetting/eval_phase/test_stream/Task000/Exp002': 1.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp003': 0.0, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp003': -1.0, 'ExperienceForgetting/eval_phase/test_stream/Task000/Exp003': 1.0, 'Top1_Acc_Exp/eval_phase/test_st

9it [05:47, 63.96s/it]

{'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000': 0.0, 'StreamBWT/eval_phase/test_stream': -0.9959360986547086, 'Top1_Acc_Stream/eval_phase/test_stream/Task000': 0.10901001112347053, 'StreamForgetting/eval_phase/test_stream': 0.9959360986547086, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp000': -0.9674887892376681, 'ExperienceForgetting/eval_phase/test_stream/Task000/Exp000': 0.9674887892376681, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001': 0.0, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp001': -1.0, 'ExperienceForgetting/eval_phase/test_stream/Task000/Exp001': 1.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp002': 0.0, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp002': -1.0, 'ExperienceForgetting/eval_phase/test_stream/Task000/Exp002': 1.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp003': 0.0, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp003': -1.0, 'ExperienceForgetting/eval_phase/test_stream/Task000/Exp003': 1.0, 'Top1_Acc_Exp/eval_phase/test_

10it [07:51, 47.12s/it]

{'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000': 0.0, 'StreamBWT/eval_phase/test_stream': -0.9963876432486298, 'Top1_Acc_Stream/eval_phase/test_stream/Task000': 0.101, 'StreamForgetting/eval_phase/test_stream': 0.9963876432486298, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp000': -0.9674887892376681, 'ExperienceForgetting/eval_phase/test_stream/Task000/Exp000': 0.9674887892376681, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001': 0.0, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp001': -1.0, 'ExperienceForgetting/eval_phase/test_stream/Task000/Exp001': 1.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp002': 0.0, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp002': -1.0, 'ExperienceForgetting/eval_phase/test_stream/Task000/Exp002': 1.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp003': 0.0, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp003': -1.0, 'ExperienceForgetting/eval_phase/test_stream/Task000/Exp003': 1.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000




In [12]:
for r in results:
    print(r)

{}
{}
{}
{}
{}
{}
{}
{}
{}
{}
