> **University of Pisa** \
> **M.Sc. Computer Science, Artificial Intelligence** \
> **Continual learning 2022/23** \
> **Authors**
* Andrea Iommi - a.iommi2@studenti.unipi.it

# Memory Replay GANs
# Learning to generate images from new categories without forgetting
#### [(original paper)](https://proceedings.neurips.cc/paper/2018/hash/a57e8915461b83adefb011530b711704-Abstract.html)
### Notebooks
*   Classical acGAN in offline settings
*   Classical acGAN in online settings
*   acGAN with joint retraining
*   **acGAN with replay alignment**


In [None]:
import torch
from Trainer import Trainer
from Utils import custom_mnist
from Plot_functions import generate_classes, plot_history

In [None]:
config = dict(
    device="cuda" if torch.cuda.is_available() else "cpu",
    num_classes=10,
    img_size=32,
    channels=1,
    n_epochs=[100,150],
    batch_size=32,
    embedding=100, # dimension of the latent embedding
    lr_g=7e-5, # learning rate for the generator
    lr_d=7e-5 # learning rate for the discriminator
)

## acGAN with replay alignment (training)
We now build the acGAN with replay alignment. The architecture is very similar to the classical acGAN, and the model still trains only on the data of the current experience. What changes in this implementation is an additional loss term, the **replay alignment** loss.

The behavior is as follows. In the first experience, the model learns as a classical acGAN. From the second experience onward, the generator optimizes not only the adversarial and auxiliary losses but **also minimizes the L2 distance between its own output and the output of the generator from the previous experience** on the classes seen in the past.
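
As a sketch, the objective described above can be written as follows. The notation is introduced here and is not taken from the notebook's code: $G_t$ is the generator being trained, $G_{t-1}$ is a frozen copy from the previous experience, $\mathcal{C}_{\text{past}}$ is the set of previously seen classes, and $\lambda_{RA}$ is the weight of the alignment term. Sampling past classes uniformly and using the squared norm are assumptions.

$$
\mathcal{L}_{RA} = \mathbb{E}_{z \sim p_z,\; c \sim \mathcal{U}(\mathcal{C}_{\text{past}})} \left[ \lVert G_t(z, c) - G_{t-1}(z, c) \rVert_2^2 \right]
$$

$$
\mathcal{L}_G = \mathcal{L}_{adv} + \mathcal{L}_{aux} + \lambda_{RA}\, \mathcal{L}_{RA}
$$

In the first experience $\mathcal{C}_{\text{past}}$ is empty, so the last term vanishes and the objective reduces to that of a classical acGAN.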

In [None]:
experiences = [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]
trainer = Trainer(config=config)
history = trainer.fit_replay_alignment(experiences=custom_mnist(experiences=experiences))
# we removed all training logs

## Loss functions and Accuracy

The charts below show the loss of both the generator and the discriminator. The chart on the right shows the accuracy; also in this setting, the accuracy is computed only on the classes of the current experience.

In [None]:
plot_history(history)

## Evaluate the architecture

As we can see, the model generates all digits quite well, even though it encountered them in distinct experiences.
In this specific run, the first digits are rendered better than the last ones. This could be caused by the value of the constant in the *replay alignment loss*, which controls the strength of the regularization.
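
To make the role of that constant concrete, here is a minimal PyTorch sketch of how a weighted replay alignment term might enter the generator loss. It is not the code inside `Trainer.fit_replay_alignment`, which is not shown in this notebook. `TinyConditionalGenerator`, `freeze_copy`, `replay_alignment_loss`, `generator_loss` and `lambda_ra` are hypothetical names introduced for illustration, and sampling past classes uniformly is an assumption.

```python
import copy

import torch
import torch.nn as nn


class TinyConditionalGenerator(nn.Module):
    """Hypothetical stand-in for the notebook's generator: (z, label) -> image."""

    def __init__(self, latent_dim=100, num_classes=10, img_size=32, channels=1):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, latent_dim)
        self.net = nn.Sequential(
            nn.Linear(latent_dim, channels * img_size * img_size),
            nn.Tanh(),
        )
        self.out_shape = (channels, img_size, img_size)

    def forward(self, z, labels):
        h = z * self.label_emb(labels)  # condition the noise on the class label
        return self.net(h).view(-1, *self.out_shape)


def freeze_copy(generator):
    """Snapshot of the generator at the end of an experience, with gradients disabled."""
    snapshot = copy.deepcopy(generator).eval()
    snapshot.requires_grad_(False)
    return snapshot


def replay_alignment_loss(generator, prev_generator, past_classes,
                          batch_size, latent_dim, device="cpu"):
    """Mean squared difference between current and previous outputs on past classes."""
    z = torch.randn(batch_size, latent_dim, device=device)
    # Assumption: past classes are sampled uniformly.
    idx = torch.randint(len(past_classes), (batch_size,), device=device)
    labels = torch.as_tensor(past_classes, device=device)[idx]
    with torch.no_grad():
        target = prev_generator(z, labels)  # same (z, label) pairs for both generators
    return torch.mean((generator(z, labels) - target) ** 2)


def generator_loss(adv_loss, aux_loss, align_loss, lambda_ra):
    """Total generator loss; lambda_ra trades plasticity against retention."""
    return adv_loss + aux_loss + lambda_ra * align_loss


# Usage: snapshot after the first experience, then regularize during the second.
generator = TinyConditionalGenerator()
prev_generator = freeze_copy(generator)
align = replay_alignment_loss(generator, prev_generator,
                              past_classes=[0, 1, 2, 3, 4],
                              batch_size=8, latent_dim=100)
# Placeholder adversarial and auxiliary values stand in for the real acGAN losses.
total = generator_loss(torch.tensor(0.7), torch.tensor(0.3), align, lambda_ra=1.0)
```

In this sketch, a larger `lambda_ra` pulls the generator more strongly toward its previous outputs on digits 0 to 4, leaving less freedom to fit digits 5 to 9, while a smaller value does the opposite. That trade-off is one plausible reading of why the earlier digits look sharper in this run.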

In [None]:
generate_classes(trainer.generator, config["num_classes"], rows=10, device=config["device"])