## Generative Latent Replay with GMMs

This notebook tests latent replay on continual-learning benchmark problems and compares it with the proposed generative latent replay strategies, i.e. normalising the bottleneck representations and sampling replay data from a GMM fitted to the latent space.

In [3]:
import importlib

# ML imports
import torch
from torch.nn import CrossEntropyLoss
from torchvision import transforms

from avalanche.models import SimpleMLP
from avalanche.training import Naive
from avalanche.benchmarks.classic import PermutedMNIST #,PermutedOmniglot, RotatedOmniglot

# Local imports
import models
import utils

In [4]:
# Reload local modules after edits so changes are picked up without restarting the kernel.
importlib.reload(models)
importlib.reload(utils)

# ---- Config ----
SEED = 42            # fixes the task permutations and the torch RNG so runs are comparable
N_EXPERIENCES = 3    # number of permuted-MNIST experiences (tasks)
IMG_SIZE = 128       # input resolution fed to the strategy (AR1* used 128x128 CORE50 images)
TRAIN_MB_SIZE = 32
EVAL_MB_SIZE = 32
TRAIN_EPOCHS = 2

torch.manual_seed(SEED)
device = utils.get_device()

# NOTE(review): `model` is never passed to `models.LatentReplay` below, and SimpleMLP's
# default input_size (28*28) does not match the 3x128x128 tensors produced by `transform`.
# Presumably LatentReplay builds its own backbone — confirm, then delete this line.
model = SimpleMLP(num_classes=10)

# ---- CL benchmark ----
# Original AR1* uses CORE50 (n, 3, 128, 128) and a pretrained MobileNet, so the MNIST
# digits are resized to IMG_SIZE x IMG_SIZE and their single channel is replicated 3 times.
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.repeat(3, 1, 1)),  # (1, H, W) -> (3, H, W)
])

perm_mnist = PermutedMNIST(
    n_experiences=N_EXPERIENCES,
    seed=SEED,  # makes the per-experience pixel permutations reproducible
    train_transform=transform,
    eval_transform=transform,
)
train_stream = perm_mnist.train_stream
test_stream = perm_mnist.test_stream

# ---- Continual learning strategy ----
criterion = CrossEntropyLoss()
cl_strategy = models.LatentReplay(
    criterion,
    train_mb_size=TRAIN_MB_SIZE,
    train_epochs=TRAIN_EPOCHS,
    eval_mb_size=EVAL_MB_SIZE,
    device=device,
)

# ---- Train / evaluate ----
# After training on each experience, evaluate on the entire test stream so that
# performance on earlier experiences can be compared across steps.
results = []
for train_exp in train_stream:
    cl_strategy.train(train_exp)
    results.append(cl_strategy.eval(test_stream))



-- >> Start of training phase << --
 39%|███▉      | 736/1875 [07:36<20:31,  1.08s/it]

KeyboardInterrupt: 

In [74]:
# Snapshot this run's per-experience evaluation results under a run-specific name so a
# later run (which rebuilds `results`) can be compared against it. `list(...)` copies the
# outer list, so any further appends to `results` do not leak into the snapshot.
# NOTE(review): in the saved notebook the training cell was interrupted (KeyboardInterrupt),
# so `results` may be empty or partial here — re-run training to completion first.
# NOTE(review): the strategy trained above is `models.LatentReplay`; confirm the `_ar1`
# suffix is the intended label for these results.
results_ar1 = list(results)