> **University of Pisa** \
> **M.Sc. Computer Science, Artificial Intelligence** \
> **Continual learning 2022/23** \
> **Authors**
* Andrea Iommi - a.iommi2@studenti.unipi.it

# Memory Replay GANs
# Learning to generate images from new categories without forgetting
#### [(original paper)](https://proceedings.neurips.cc/paper/2018/hash/a57e8915461b83adefb011530b711704-Abstract.html)
### Notebooks
*   **Classical acGAN in offline settings**
*   Classical acGAN in online settings
*   acGAN with joint retraining
*   acGAN with replay alignment


In [None]:
import torch
from Trainer import Trainer
from Utils import custom_mnist
from Plot_functions import generate_classes, plot_history

In [None]:
config = dict(
    device="cuda" if torch.cuda.is_available() else "cpu",
    num_classes=10,
    img_size=32,
    channels=1,
    n_epochs=[1],
    batch_size=64,
    embedding=100, # latent dimension of embedding
    lr_g=0.0002, # Learning rate for generator
    lr_d=0.0002 # Learning rate for discriminator
)

## Classical acGAN in offline settings (training)
As a first step, we train a classical acGAN in the offline setting, where all digits are learned at the same time.
This setting has a single experience that contains all ten digits.

In [None]:
experiences = [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]  # list of experiences
exp_generator = custom_mnist(experiences=experiences)
trainer = Trainer(config=config)
history = trainer.fit_classic(experiences=exp_generator, create_gif=True)
# we removed all training logs
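
To make the notion of an experience concrete, here is a minimal sketch of how a labelled dataset could be partitioned by experience. The helper `split_by_experience` and the toy labels are hypothetical and only illustrate the idea; this is not the implementation of `custom_mnist`, whose internals are not shown in this notebook.

```python
from typing import List, Sequence


def split_by_experience(
    labels: Sequence[int], experiences: Sequence[Sequence[int]]
) -> List[List[int]]:
    """Return, for each experience, the indices of the samples whose label belongs to it."""
    splits = []
    for classes in experiences:
        allowed = set(classes)
        splits.append([i for i, y in enumerate(labels) if y in allowed])
    return splits


# Toy labels standing in for the MNIST targets (hypothetical data).
labels = [3, 0, 7, 1, 9, 0, 4]

# Offline setting: one experience with every digit, so every sample is selected.
offline = split_by_experience(labels, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])

# A sequential setting would instead use several disjoint experiences.
sequential = split_by_experience(labels, [[0, 1], [3, 4], [7, 9]])
```

In the offline case the single split covers the whole dataset, which is why the model sees all digits in every epoch.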

## Loss functions and Accuracy

The charts below show the *loss* of both the Generator and the Discriminator. Because a GAN is trained as a min-max game between the two networks, the loss curves are irregular rather than smooth. Finding a good hyperparameter configuration is generally hard for this kind of architecture, so we took the **learning rate** and **batch size** from the original paper (linked above).
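
The opposing objectives can be illustrated with a small scalar sketch of acGAN-style losses. It assumes the commonly used non-saturating form, in which the discriminator has a real/fake head and a class head; the function names are hypothetical, and the loss actually computed inside `Trainer` may differ.

```python
import math
from typing import Sequence

EPS = 1e-12  # avoids log(0)


def bce(p: float, target: float) -> float:
    """Binary cross-entropy for one predicted probability p of the sample being real."""
    return -(target * math.log(p + EPS) + (1.0 - target) * math.log(1.0 - p + EPS))


def ce(class_probs: Sequence[float], label: int) -> float:
    """Cross-entropy of the class head for one sample."""
    return -math.log(class_probs[label] + EPS)


def discriminator_loss(p_real, probs_real, y_real, p_fake, probs_fake, y_fake) -> float:
    """Source loss (real -> 1, fake -> 0) plus class loss on both samples."""
    source = bce(p_real, 1.0) + bce(p_fake, 0.0)
    classification = ce(probs_real, y_real) + ce(probs_fake, y_fake)
    return source + classification


def generator_loss(p_fake, probs_fake, y_fake) -> float:
    """The generator wants its sample judged real and classified as the requested label."""
    return bce(p_fake, 1.0) + ce(probs_fake, y_fake)


probs_real, probs_fake = [0.1, 0.8, 0.1], [0.7, 0.2, 0.1]

# Discriminator undecided about the fake sample (0.5) vs. confident it is fake (0.1).
d_weak = discriminator_loss(0.9, probs_real, 1, 0.5, probs_fake, 0)
d_strong = discriminator_loss(0.9, probs_real, 1, 0.1, probs_fake, 0)
g_vs_weak = generator_loss(0.5, probs_fake, 0)
g_vs_strong = generator_loss(0.1, probs_fake, 0)
```

When the discriminator becomes more confident about the fake sample, its own loss falls while the generator's loss rises. Since each update of one network shifts the loss of the other, the training curves oscillate instead of decreasing monotonically.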

In [None]:
plot_history(history)

## Evaluate the architecture

As we can see, the model generates all digits reasonably well.
The following figure shows example outputs, where *t* denotes the conditional input (the digit class) and *gen* the number of examples to generate.

In [None]:
generate_classes(trainer.generator, config["num_classes"], rows=10, device=config["device"])
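
As a rough illustration of conditional generation, the sketch below builds the two inputs a conditional generator typically receives: one latent noise vector and one class label *t* per image, with *gen* images per class. The function `make_conditional_inputs` is hypothetical and uses plain Python lists; `generate_classes` is not shown here and presumably works with torch tensors on the configured device.

```python
import random
from typing import List, Tuple


def make_conditional_inputs(
    num_classes: int, gen: int, latent_dim: int, seed: int = 0
) -> Tuple[List[List[float]], List[int]]:
    """Build `gen` (noise, label) pairs for every class t in 0..num_classes-1."""
    rng = random.Random(seed)
    # Labels are grouped by class: gen copies of 0, then gen copies of 1, and so on.
    labels = [t for t in range(num_classes) for _ in range(gen)]
    # One standard-normal latent vector per requested image.
    noise = [[rng.gauss(0.0, 1.0) for _ in range(latent_dim)] for _ in labels]
    return noise, labels


# Values mirror the config above: 10 classes and a 100-dimensional embedding.
noise, labels = make_conditional_inputs(num_classes=10, gen=3, latent_dim=100)
```

Each generated image is then determined by its noise vector for style and by its label for the digit identity.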