## Generative Latent Replay

Experimental code for testing generative latent replay on benchmark continual-learning problems.

That is: the bottleneck (latent) representations are normalised, a Gaussian mixture model (GMM) is fitted on the latent space, and replay samples are drawn from the fitted GMM.

In [None]:
# Standard library imports
import copy
import importlib

# ML imports
import torch
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torchvision import transforms as T
from avalanche.training import Naive, Replay, plugins
from avalanche.benchmarks.classic import RotatedMNIST, PermutedMNIST

# Local imports
from src import utils, plotting, models
from src.strategies import LatentReplay, GenerativeLatentReplay


Setup

In [None]:
# Reproducibility: a single fixed seed, handed to the project's seeding helper
# here and passed again as the benchmark seed when the dataset is built.
SEED = 731
utils.set_seed(SEED)

Problem definition

In [23]:
# Dataset-specific attributes
n_experiences = 5

# Convert each image to a tensor, then flatten it to a 1-D vector
transform = T.Compose([T.ToTensor(), T.Lambda(torch.flatten)])

# Rotation angles in degrees; the list has as many entries as `n_experiences`.
# NOTE(review): Avalanche documents rotations in [-180, 180]; 300 is the same
# angle as -60 — confirm the installed version accepts it as written.
rotation_angles = [0, 60, 120, 180, 300]

# Rotated-MNIST benchmark, using the same transform for training and evaluation
benchmark_kwargs = dict(
    n_experiences=n_experiences,
    train_transform=transform,
    eval_transform=transform,
    seed=SEED,
    rotations_list=rotation_angles,
)
experiences = RotatedMNIST(**benchmark_kwargs)

In [25]:
# Pick up local edits to `src.utils` without restarting the kernel.
# (`importlib` is imported once, in the notebook's import cell.)
importlib.reload(utils)

# Sample cap handed to `utils.shrink_dataset`
# (presumably applied to each experience — confirm against the helper).
SAMPLES_PER_EXPERIENCE = 1000

# NOTE(review): this rebinds `experiences`, so re-running the cell shrinks the
# already-shrunk benchmark rather than the original one.
experiences = utils.shrink_dataset(experiences, SAMPLES_PER_EXPERIENCE)

1000
1000
1000
1000
1000
1000
1000
1000
1000
1000


<avalanche.benchmarks.utils.avalanche_dataset.AvalancheSubset at 0x1987737ee20>

Hyperparameters

In [None]:
# --- Replay buffer ---
replay_buffer_size = 600

# --- Frozen backbone ---
freeze_depth = 1
latent_layer_number = 3 * freeze_depth

# --- Optimiser hyperparameters ---
lr = 0.1  # 0.001
l2 = 0.0005
momentum = 0.9

# Evaluation period; also reused below as the early-stopping patience
eval_rate = 1

# Early stopping monitored on the "train_stream/Task000" validation stream
early_stopping = plugins.EarlyStoppingPlugin(
    patience=eval_rate,
    val_stream_name="train_stream/Task000",
    margin=0.005,
)

# Keyword arguments shared by every strategy constructor
strategy_kwargs = dict(
    train_epochs=40,
    train_mb_size=128,
    eval_mb_size=512,
    device=utils.get_device(),
    plugins=[early_stopping],
    eval_every=eval_rate,
)


Building base model

In [None]:
# Architecture selector; the model-building cell recognises "mlp" and "cnn"
model = "mlp"

# Constructor arguments forwarded to the selected model class
model_kwargs = dict(
    drop_rate=0,
    num_classes=10,
    hidden_size=100,
    hidden_layers=2,
)

In [None]:
# One independently initialised network per strategy built further down
# (the strategies index `networks[0]` .. `networks[5]`).
n_models = 6

# NOTE(review): `transform` is reassigned here, but the benchmark cell has
# already been built with its own `transform` by the time this runs top to
# bottom — the value set here only takes effect if the dataset cell is re-run.
if model == "mlp":
    networks = [models.SimpleMLP(**model_kwargs) for _ in range(n_models)]
    transform = T.Compose([T.ToTensor(), T.Lambda(torch.flatten)])

elif model == "cnn":
    networks = [models.SimpleCNN(**model_kwargs) for _ in range(n_models)]
    transform = T.Compose([T.ToTensor()])

else:
    # Fail here rather than with a NameError on `networks` in a later cell
    raise ValueError(f"Unknown model type {model!r}; expected 'mlp' or 'cnn'.")

Building Continual Learning methods for comparison

In [None]:
def _fresh_strategy_kwargs():
    """Return a copy of `strategy_kwargs` that owns its plugin list.

    Avalanche strategies keep the `plugins` list they are given and append
    their own plugins to it (evaluator, periodic evaluation, replay buffer).
    Passing the same list object to several strategies would therefore make
    them all share each other's plugins, so every strategy gets its own list
    holding its own copies of the configured plugins.
    """
    shared_plugins = strategy_kwargs.get("plugins") or []
    return {
        **strategy_kwargs,
        "plugins": [copy.deepcopy(plugin) for plugin in shared_plugins],
    }


def _make_sgd(network):
    """Build the SGD optimiser used by the Avalanche baselines for `network`."""
    return SGD(network.parameters(), lr=lr, momentum=momentum, weight_decay=l2)


# Baseline
naive_strategy = Naive(
    model=networks[0],
    optimizer=_make_sgd(networks[0]),
    evaluator=utils.get_eval_plugin('naive'),
    **_fresh_strategy_kwargs(),
)

# Benchmark - raw and latent replay
replay_strategy = Replay(
    model=networks[1],
    criterion=CrossEntropyLoss(),
    optimizer=_make_sgd(networks[1]),
    # Use the same buffer size as the other replay strategies instead of
    # Avalanche's default, so the comparison is like for like.
    mem_size=replay_buffer_size,
    evaluator=utils.get_eval_plugin('replay'),
    **_fresh_strategy_kwargs(),
)

# Sanity check - should perform similar to "Replay"
replay_thawed_strategy = LatentReplay(
    model=networks[2],
    latent_layer_num=0,
    subsample_replays=True,
    rm_sz=replay_buffer_size,
    evaluator=utils.get_eval_plugin('replay_reimplemented'),
    **_fresh_strategy_kwargs(),
)

# Latent replay at the frozen-backbone depth
lat_replay_strategy = LatentReplay(
    model=networks[3],
    rm_sz=replay_buffer_size,
    latent_layer_num=latent_layer_number,
    subsample_replays=True,
    evaluator=utils.get_eval_plugin('latent_replay'),
    **_fresh_strategy_kwargs(),
)

# Generative replay with latent_layer_num=0
gen_replay_strategy = GenerativeLatentReplay(
    model=networks[4],
    rm_sz=replay_buffer_size,
    latent_layer_num=0,
    evaluator=utils.get_eval_plugin('generative_replay'),
    **_fresh_strategy_kwargs(),
)

# Continual learning strategy
gen_lat_replay_strategy = GenerativeLatentReplay(
    model=networks[5],
    rm_sz=replay_buffer_size,
    latent_layer_num=latent_layer_number,
    evaluator=utils.get_eval_plugin('generative_latent_replay'),
    **_fresh_strategy_kwargs(),
)

In [None]:
# Strategies to run, keyed by display name. Uncomment entries to include the
# other methods in the comparison.
selected_strategies = {
    "Generative Latent Replay": gen_lat_replay_strategy,
    # "Latent Replay": lat_replay_strategy,
    # "Replay": replay_strategy,
    # "Naive": naive_strategy,
    # "Replay (re-imp)": replay_thawed_strategy,
    # "Generative Replay": gen_replay_strategy,
}

# Pair every selected strategy with its own, initially empty, results list
strategies = {
    name: {"model": strategy, "results": []}
    for name, strategy in selected_strategies.items()
}

Training loop

In [None]:
# Train every selected strategy on each experience of the training stream, in order.
# NOTE(review): `test_stream` is fetched but never used in this cell, and the
# per-strategy "results" lists are left empty — confirm that metrics are meant
# to come from each strategy's evaluator instead.
for entry in strategies.values():
    train_stream, test_stream = experiences.train_stream, experiences.test_stream
    active_strategy = entry["model"]

    for train_exp in train_stream:
        active_strategy.train(train_exp)

Plotting

In [None]:
# Pick up local edits to `src.plotting` without restarting the kernel.
# (`importlib` is imported once, in the notebook's import cell.)
importlib.reload(plotting)

# Plot every selected strategy, identified by its display name
# (`loss=True` presumably adds the loss curves — confirm against the helper).
plotting.plot_multiple_results(strategies.keys(), n_experiences, loss=True)

In [None]:
# Keep a second reference to the current `strategies` dict. This is an alias,
# not a copy: both names point at the same object until `strategies` is rebound.
strategies_cache = strategies

In [None]:
# NOTE(review): `results` is not assigned anywhere in the visible notebook, so on
# a fresh kernel this line raises NameError — confirm where it should come from.
utils.results_to_df(strategies.keys(), results)