## Generative Latent Replay with GMMs

Code to test latent replay on benchmark problems and compare it with the proposed generative latent replay strategies, i.e. normalising bottleneck representations and sampling from a GMM fitted on the latent space.

In [None]:
import random
import numpy as np
from matplotlib import pyplot as plt

# ML imports
import torch
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torchvision import transforms as T

from avalanche.training import Naive, Replay, plugins
from avalanche.benchmarks.classic import (
    RotatedMNIST,
    PermutedMNIST,
    PermutedOmniglot,
    RotatedOmniglot,
)

# Local imports
import utils
from models import SimpleCNN, SimpleMLP
from strategies import LatentReplay, GenerativeLatentReplay


Hyperparameter settings

In [None]:
# Config
device = utils.get_device()

# Dataset specific attributes
dataset = RotatedMNIST  # PermutedMNIST
n_classes = 10
n_experiences = 3

# Model specification
model = "mlp"  # supported values: "mlp", "cnn"
hidden_size = 64
n_hidden_layers = 4

# Frozen backbone
freeze_depth = 2
assert freeze_depth <= n_hidden_layers, (
    f"freeze_depth ({freeze_depth}) must not exceed "
    f"n_hidden_layers ({n_hidden_layers})"
)

# The latent layer index is freeze_depth scaled by a per-architecture factor.
# NOTE(review): presumably the factor is the number of sub-modules per hidden
# block in models.py (3 for the MLP, 2 for the CNN) - confirm against models.
if model == "mlp":
    latent_layer_number = freeze_depth * 3
elif model == "cnn":
    latent_layer_number = freeze_depth * 2
else:
    # Fail here instead of with a NameError on latent_layer_number further down.
    raise ValueError(f'Unsupported model {model!r}; expected "mlp" or "cnn"')

# Hyperparams
lr = 0.001
l2 = 0.0005
momentum = 0.9

n_epochs = 40
train_mb_size = 128
eval_mb_size = 512

replay_buffer_size = 5000
cl_plugins = None  # [plugins.EarlyStoppingPlugin(patience=1, val_stream_name="train_stream")]  # JA: need to set params

# Forwarded unchanged as `eval_every` to every strategy built below.
eval_every = -1

# Reproducibility: seed every RNG source used by the notebook.
SEED = 109
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)


Building base model

In [None]:
# Model
# One independently initialised network per strategy defined below
# (the strategies index networks[0] .. networks[5]).
n = 6

if model == "mlp":
    networks = [SimpleMLP(n_classes, hidden_size, n_hidden_layers) for _ in range(n)]
    # The MLP branch flattens each image tensor before it reaches the network.
    transform = T.Compose([T.ToTensor(), T.Lambda(torch.flatten)])

elif model == "cnn":
    networks = [SimpleCNN(n_classes, hidden_size, n_hidden_layers) for _ in range(n)]
    transform = T.Compose([T.ToTensor()])

else:
    # Without this branch `networks` / `transform` would be left undefined
    # (or stale from a previous run of the cell) for an unknown model name.
    raise ValueError(f'Unsupported model {model!r}; expected "mlp" or "cnn"')


Building Continual Learning methods for comparison

In [None]:
# Continual learning strategies based on GenerativeLatentReplay.
# Both variants share every setting except the network instance and the
# layer index at which replay happens.
shared_gen_kwargs = dict(
    rm_sz=replay_buffer_size,
    train_mb_size=train_mb_size,
    train_epochs=n_epochs,
    eval_mb_size=eval_mb_size,
    device=device,
    eval_every=eval_every,
    plugins=cl_plugins,
)

# Replay at the configured latent layer.
gen_lat_replay_strategy = GenerativeLatentReplay(
    model=networks[4],
    latent_layer_num=latent_layer_number,
    **shared_gen_kwargs,
)

# Replay at layer 0.
# NOTE(review): presumably this means nothing is frozen and replay happens on
# raw inputs - confirm against strategies.GenerativeLatentReplay.
gen_replay_strategy = GenerativeLatentReplay(
    model=networks[5],
    latent_layer_num=0,
    **shared_gen_kwargs,
)

In [None]:
# Baseline 1: plain fine-tuning on each experience, no replay.
naive_strategy = Naive(
    model=networks[0],
    optimizer=SGD(networks[0].parameters(), lr=lr, momentum=momentum, weight_decay=l2),
    train_mb_size=train_mb_size,
    train_epochs=n_epochs,
    eval_mb_size=eval_mb_size,
    device=device,
    eval_every=eval_every,
)

# Baseline 2: Avalanche's experience replay.
# mem_size is set explicitly so that this baseline uses the same buffer size
# as the latent-replay strategies (which receive rm_sz=replay_buffer_size);
# otherwise Avalanche's default buffer size would be used and the comparison
# would not be like-for-like.
replay_strategy = Replay(
    model=networks[1],
    optimizer=SGD(networks[1].parameters(), lr=lr, momentum=momentum, weight_decay=l2),
    criterion=CrossEntropyLoss(),
    mem_size=replay_buffer_size,
    train_mb_size=train_mb_size,
    train_epochs=n_epochs,
    eval_mb_size=eval_mb_size,
    device=device,
    eval_every=eval_every,
    plugins=cl_plugins,
)

In [None]:
# Settings shared by both LatentReplay variants; only the network instance
# and the replay layer index differ between them.
shared_latent_kwargs = dict(
    rm_sz=replay_buffer_size,
    train_mb_size=train_mb_size,
    train_epochs=n_epochs,
    eval_mb_size=eval_mb_size,
    device=device,
    subsample_replays=True,
    eval_every=eval_every,
    plugins=cl_plugins,
)

# Sanity check - should perform similarly to "Replay"
replay_thawed_strategy = LatentReplay(
    model=networks[2],
    latent_layer_num=0,
    **shared_latent_kwargs,
)

# Latent replay at the configured latent layer.
lat_replay_strategy = LatentReplay(
    model=networks[3],
    latent_layer_num=latent_layer_number,
    **shared_latent_kwargs,
)

In [None]:
# Registry of the strategies to run: display label -> strategy object plus a
# list that the training loop fills with one evaluation result per experience.
# Disabled strategies are left commented out so they can be re-enabled easily.
strategies = {
    label: {"model": cl_strategy, "results": []}
    for label, cl_strategy in (
        # ("Naive", naive_strategy),
        # ("Replay", replay_strategy),
        # ("Replay (re-imp)", replay_thawed_strategy),
        ("Latent Replay", lat_replay_strategy),
        # ("Generative Replay", gen_replay_strategy),
        ("Generative Latent Replay", gen_lat_replay_strategy),
    )
}

Training loop

In [None]:
# Arguments accepted by every benchmark selectable in the config cell.
benchmark_kwargs = dict(
    n_experiences=n_experiences,
    train_transform=transform,
    eval_transform=transform,
    seed=SEED,
)
# Rotation angles only make sense for the Rotated* benchmarks; forwarding them
# to a Permuted* benchmark would fail with an unexpected-keyword TypeError.
# NOTE(review): the list holds three angles, matching n_experiences == 3 in
# the config - keep the two in sync when changing either.
if dataset in (RotatedMNIST, RotatedOmniglot):
    benchmark_kwargs["rotations_list"] = [0, 60, 300]

for strategy, entry in strategies.items():
    # A fresh benchmark is built for every strategy, with identical arguments.
    experiences = dataset(**benchmark_kwargs)
    train_stream = experiences.train_stream
    test_stream = experiences.test_stream

    for train_exp in train_stream:
        # Train on the current experience, then evaluate on the whole stream
        # and store the result.
        entry["model"].train(train_exp)
        # NOTE(review): evaluation runs on train_stream while test_stream is
        # left unused - confirm this is intended before reporting results.
        entry["results"].append(entry["model"].eval(train_stream))


Plotting

In [None]:
# Re-import utils so that edits to the helper module are picked up without
# restarting the kernel.
import importlib

importlib.reload(utils)

# One panel per strategy, laid out in a single row with a shared y-axis.
# NOTE(review): with a single enabled strategy plt.subplots returns a bare
# Axes instead of an array - verify utils.plot_multiple_results handles that.
n_panels = len(strategies)
fig, axes = plt.subplots(nrows=1, ncols=n_panels, sharey="row")

per_strategy_results = [entry["results"] for entry in strategies.values()]

utils.plot_multiple_results(
    per_strategy_results,
    strategies.keys(),
    axes,
    fig,
    n_experiences,
)
