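# Orthogonal Gradient Descent

Trains an MLP on class-incremental MNIST (each of the 10 digits arrives as its own experience) with the Orthogonal Gradient Descent (OGD) strategy. After every training experience it reports accuracy, forgetting and backward transfer (BWT) on all test experiences seen so far.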

In [1]:
# Reload edited modules (e.g. the local `base_code` package imported below)
# automatically before each cell runs, so library changes apply without a
# kernel restart.
%load_ext autoreload
%autoreload 2

# Global imports and settings

In [2]:
from torchvision.transforms import ToTensor, Compose, Normalize

In [3]:
from tqdm import tqdm

In [4]:
from avalanche.benchmarks.datasets import MNIST
from avalanche.training.plugins import EvaluationPlugin
from avalanche.benchmarks.generators import nc_benchmark
from avalanche.evaluation.metrics import (
    bwt_metrics,
    forgetting_metrics,
    accuracy_metrics,
)

import torch
from torch.nn import CrossEntropyLoss
from torch.optim import SGD

  from .autonotebook import tqdm as notebook_tqdm


## Custom libraries

In [5]:
import sys

# Make the project package `base_code` importable. ".." is resolved relative to
# the current working directory. NOTE(review): this assumes the notebook is
# started from its own folder, one level below the directory containing
# `base_code` — confirm.
# Guarded so that re-running this cell does not keep appending duplicates.
if ".." not in sys.path:
    sys.path.append("..")

from base_code.training.ogd import OGD
from base_code.constants import DATASETS_PATH
from base_code.models.mlp import MLP

# Dataset and definitions

## Preprocessing definitions

In [6]:
# Standard MNIST training-set statistics (single grey channel), applied after
# ToTensor has scaled pixel values into [0, 1].
MNIST_MEAN = (0.1307,)
MNIST_STD = (0.3081,)

# Training pipeline: tensor conversion + normalisation (no augmentation).
train_transform = Compose([
    ToTensor(),
    Normalize(MNIST_MEAN, MNIST_STD),
])
# Test pipeline mirrors the training one, so both splits are normalised
# identically.
test_transform = Compose([
    ToTensor(),
    Normalize(MNIST_MEAN, MNIST_STD),
])

## Dataset loading

In [7]:
# download=True lets a fresh environment fetch MNIST into DATASETS_PATH, so
# Restart & Run All works without a manual download step. Files that are
# already present are reused.
mnist_train = MNIST(DATASETS_PATH, train=True, download=True, transform=train_transform)
mnist_test = MNIST(DATASETS_PATH, train=False, download=True, transform=test_transform)

## Scenario creation with train/test streams

In [8]:
# Class-incremental benchmark: n_experiences equals the number of MNIST
# classes, so each experience introduces a single new class. The class order
# is shuffled with a fixed seed for reproducibility, and no task labels are
# given to the strategy.
scenario = nc_benchmark(
    mnist_train,
    mnist_test,
    n_experiences=len(mnist_train.classes),
    task_labels=False,
    shuffle=True,
    seed=1234,
)

train_stream, test_stream = scenario.train_stream, scenario.test_stream

## Evaluation metrics definition

In [9]:
# Backward transfer, accuracy and forgetting, each reported per experience
# and aggregated over the evaluated stream.
eval_plugin = EvaluationPlugin(
    bwt_metrics(stream=True, experience=True),
    accuracy_metrics(stream=True, experience=True),
    forgetting_metrics(stream=True, experience=True),
)



## Model, Optimizer, Loss and Strategy definitions

In [10]:
# Seed torch's global RNG so the MLP's weight initialisation is reproducible
# across runs. 1234 matches the seed used for the benchmark's class order.
torch.manual_seed(1234)

# MNIST images are 1 x 28 x 28 after ToTensor.
model = MLP(n_classes=scenario.n_classes, n_channels=1, width=28, height=28)
optimizer = SGD(model.parameters(), lr=0.001)
criterion = CrossEntropyLoss()
# OGD strategy: mini-batches of 10 samples, 5 epochs per experience; metrics
# are collected by the shared evaluation plugin.
strategy = OGD(
    model, optimizer, criterion, train_mb_size=10, train_epochs=5, evaluator=eval_plugin
)

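# Training and evaluation

**Note:** the run recorded below was interrupted manually (KeyboardInterrupt) during the sixth experience, after about 9h40m. The third experience alone took over six hours.

After each experience, accuracy on every earlier experience is 0.0 while the newest one reaches 1.0. In other words, OGD is not preventing catastrophic forgetting in this run. Exp000 also reports 0.0 accuracy immediately after being trained on. **TODO:** investigate.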

In [11]:
# Train on one experience at a time, then evaluate on every test experience
# seen so far (test_stream[:until]) to track forgetting as new classes arrive.
# Each evaluation's metric dict is kept in `results`.
results = []

for until, experience in tqdm(
    enumerate(train_stream, start=1), total=len(train_stream), desc="experiences"
):
    # eval_streams=[] disables evaluation during training; evaluation is
    # done explicitly right after each experience instead.
    strategy.train(experience, eval_streams=[])
    metrics = strategy.eval(test_stream[:until])
    # tqdm.write prints the same text as print() without corrupting the bar.
    tqdm.write(str(metrics))

    results.append(metrics)

1it [00:22, 22.67s/it]

{'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000': 0.0, 'StreamBWT/eval_phase/test_stream': 0.0, 'Top1_Acc_Stream/eval_phase/test_stream/Task000': 0.0, 'StreamForgetting/eval_phase/test_stream': 0.0}


2it [07:14, 251.45s/it]

{'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000': 0.0, 'StreamBWT/eval_phase/test_stream': 0.0, 'Top1_Acc_Stream/eval_phase/test_stream/Task000': 0.5240128068303095, 'StreamForgetting/eval_phase/test_stream': 0.0, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp000': -0.0, 'ExperienceForgetting/eval_phase/test_stream/Task000/Exp000': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001': 1.0}


3it [6:40:04, 10899.49s/it]

{'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000': 0.0, 'StreamBWT/eval_phase/test_stream': -0.5, 'Top1_Acc_Stream/eval_phase/test_stream/Task000': 0.3419943820224719, 'StreamForgetting/eval_phase/test_stream': 0.5, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp000': -0.0, 'ExperienceForgetting/eval_phase/test_stream/Task000/Exp000': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001': 0.0, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp001': -1.0, 'ExperienceForgetting/eval_phase/test_stream/Task000/Exp001': 1.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp002': 1.0}


4it [7:11:40, 7345.04s/it] 

{'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000': 0.0, 'StreamBWT/eval_phase/test_stream': -0.6666666666666666, 'Top1_Acc_Stream/eval_phase/test_stream/Task000': 0.26597938144329897, 'StreamForgetting/eval_phase/test_stream': 0.6666666666666666, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp000': -0.0, 'ExperienceForgetting/eval_phase/test_stream/Task000/Exp000': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001': 0.0, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp001': -1.0, 'ExperienceForgetting/eval_phase/test_stream/Task000/Exp001': 1.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp002': 0.0, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp002': -1.0, 'ExperienceForgetting/eval_phase/test_stream/Task000/Exp002': 1.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp003': 1.0}


5it [7:49:17, 5510.04s/it]

{'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000': 0.0, 'StreamBWT/eval_phase/test_stream': -0.75, 'Top1_Acc_Stream/eval_phase/test_stream/Task000': 0.19801570897064902, 'StreamForgetting/eval_phase/test_stream': 0.75, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp000': -0.0, 'ExperienceForgetting/eval_phase/test_stream/Task000/Exp000': 0.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp001': 0.0, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp001': -1.0, 'ExperienceForgetting/eval_phase/test_stream/Task000/Exp001': 1.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp002': 0.0, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp002': -1.0, 'ExperienceForgetting/eval_phase/test_stream/Task000/Exp002': 1.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp003': 0.0, 'ExperienceBWT/eval_phase/test_stream/Task000/Exp003': -1.0, 'ExperienceForgetting/eval_phase/test_stream/Task000/Exp003': 1.0, 'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp004': 1.0}


5it [9:40:42, 6968.47s/it]


KeyboardInterrupt: 

In [None]:
# Latest recorded value of every metric tracked by `eval_plugin`. It is still
# available if the training loop above was interrupted.
eval_plugin.get_last_metrics()