# Averaged Gradient Episodic Memory (A-GEM) on Permuted MNIST

In [None]:
# Reload edited Python modules (e.g. the local base_code package) before each
# cell runs, so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Global imports and settings

In [None]:
import pickle
from torchvision.transforms import ToTensor, Compose, Normalize
import pandas as pd
import seaborn as sns
from tqdm import tqdm

In [None]:
from avalanche.benchmarks.classic import PermutedMNIST
from avalanche.models import SimpleMLP
from avalanche.training.plugins import EvaluationPlugin
from avalanche.training import AGEM
from avalanche.logging import InteractiveLogger
from avalanche.evaluation.metrics import (
    accuracy_metrics,
)

from torch.nn import CrossEntropyLoss
from torch.optim import SGD

## Custom libraries

In [None]:
import sys

# Make the parent of the current working directory importable so the local
# `base_code` package resolves (presumably the repository root — confirm if
# the notebook is launched from a different directory).
sys.path.append("..")

# Project-wide filesystem locations for raw datasets and saved metric files.
from base_code.constants import DATASETS_PATH, SAVED_METRICS_PATH

# Dataset and definitions

## Preprocessing definitions

In [None]:
# MNIST normalisation statistics as (mean, std) for its single grey channel.
_MNIST_STATS = ((0.1307,), (0.3081,))

# Convert each image to a tensor, then standardise it with the statistics above.
# NOTE(review): neither transform is passed to PermutedMNIST in this notebook,
# so the benchmark uses its own defaults; confirm whether these are still needed.
train_transform = Compose([ToTensor(), Normalize(*_MNIST_STATS)])
test_transform = Compose([ToTensor(), Normalize(*_MNIST_STATS)])

## Dataset loading

In [None]:
# Permuted-MNIST benchmark split into 10 sequential experiences. The fixed
# seed makes the generated permutations reproducible across runs.
scenario = PermutedMNIST(
    n_experiences=10,
    seed=1234,
    dataset_root=DATASETS_PATH,
)

## Train and test streams of the scenario

In [None]:
# Per-experience streams used for training and for evaluation, respectively.
train_stream, test_stream = scenario.train_stream, scenario.test_stream

## Evaluation metrics definition

In [None]:
# Track top-1 accuracy both per experience and over the whole stream, and
# report progress interactively while training/evaluation runs.
eval_plugin = EvaluationPlugin(
    accuracy_metrics(stream=True, experience=True),
    loggers=[InteractiveLogger()],
)

## Model, Optimizer, Loss and Strategy definitions

In [None]:
# MLP over flattened 28x28 images: two hidden layers of 100 units each.
model = SimpleMLP(
    num_classes=scenario.n_classes,
    input_size=28 * 28,
    hidden_size=100,
    hidden_layers=2,
)
optimizer = SGD(model.parameters(), lr=0.001)
criterion = CrossEntropyLoss()

# A-GEM continual-learning strategy; `patterns_per_exp` is the number of
# samples stored in memory from each experience.
strategy = AGEM(
    model,
    optimizer,
    criterion,
    patterns_per_exp=10,
    train_mb_size=256,
    train_epochs=5,
    eval_mb_size=128,
    evaluator=eval_plugin,
)

# Training and evaluation

In [None]:
# Train on the experiences one after another. After each one, evaluate on the
# entire test stream so accuracy on earlier experiences (forgetting) is tracked.
results = []
for exp in tqdm(train_stream):
    strategy.train(exp)
    results.append(strategy.eval(test_stream))

# Plotting results

In [None]:
# Gather accuracy curves from the evaluation plugin:
# - "Task{i}": test accuracy on experience i, one value per evaluation round;
# - "Overall": accuracy over the whole test stream, one value per round.
# get_all_metrics() maps a metric name to (x_values, metric_values); [1] keeps
# the metric values. Fetch it once instead of once per key.
all_metrics = eval_plugin.get_all_metrics()
n_experiences = len(test_stream)

# "Task000": the scenario is built without task ids, so every experience is
# assumed to carry task label 0. Avalanche zero-pads experience ids to three
# digits, hence {i:03d} (a literal "00" prefix breaks from experience 10 on).
accuracies: dict[str, list[float]] = {
    f"Task{i}": all_metrics[f"Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp{i:03d}"][1]
    for i in range(n_experiences)
}
accuracies["Overall"] = all_metrics["Top1_Acc_Stream/eval_phase/test_stream/Task000"][1]

In [None]:
# One row per evaluation round (row i = after training on experience i) and
# one column per accuracy curve. pandas already assigns a 0..n-1 RangeIndex,
# so the index length follows the data instead of a hard-coded count.
acc_df = pd.DataFrame(accuracies)

In [None]:
# One solid line per accuracy curve against the evaluation round, with a
# marker at every round.
sns.lineplot(data=acc_df, markers=True, dashes=False)

# Store metrics

In [None]:
# Persist the accuracy curves. The `with` block guarantees the file is
# flushed and closed even if pickling fails.
with open(SAVED_METRICS_PATH / "agem.pkl", "wb") as metrics_file:
    pickle.dump(accuracies, metrics_file)