# Imports

In [2]:
import os
import sys
sys.path.append(os.path.abspath('..'))

In [3]:
import numpy as np
from tqdm import tqdm
import torch
from src.models.early_stoppers import EarlyStopper
from src.models.meta.vanilla_vae import Vanilla_VAE
from src.models.meta.vanilla_gan import Vanilla_GAN
from src.models.meta.meta_vae import Meta_VAE
from src.models.meta.smvae import SMVAE
from src.models.marginals.vae import VAE
from src.dataset.meta_dataloaders import DataModule as metaDm
from src.dataset.marginal_dataloaders import DataModule as marginalDm

In [18]:
# How to use the dataloader: build the marginal DataModule, prepare it, then inspect one training batch.
dm = marginalDm()
dm.setup()
train_dataloader = dm.train_dataloader()
first_batch = next(iter(train_dataloader))
print(first_batch[0].shape)

torch.Size([2, 60])


In [4]:
# Global torch device: CUDA when a GPU is visible, otherwise the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

# Some useful functions

In [5]:
def get_early_stopper(patience=100, delta=0):
    """Create an ``EarlyStopper``.

    Args:
        patience: forwarded unchanged as the stopper's ``patience``.
        delta: forwarded as the stopper's ``min_delta``.

    Returns:
        A new ``EarlyStopper`` instance.
    """
    stopper = EarlyStopper(min_delta=delta, patience=patience)
    return stopper

In [6]:
def get_optimizer(model, lr, with_scheduler=True):
    """Build an Adam optimizer over ``model``'s parameters, plus an optional plateau scheduler.

    Args:
        model: module whose ``parameters()`` are optimized.
        lr: learning rate given to Adam.
        with_scheduler: when True, also build a ``ReduceLROnPlateau`` scheduler
            (mode='min', factor=0.4, patience=5, relative threshold 0.001).

    Returns:
        ``(optimizer, scheduler)``; ``scheduler`` is ``None`` when ``with_scheduler`` is False.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = None
    if with_scheduler:
        # `verbose` is intentionally not passed: it is deprecated in recent PyTorch
        # releases (passing it emits a warning) and False was its default anyway.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.4, patience=5, threshold=0.001,
            threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08,
        )
    return optimizer, scheduler

In [7]:
def not_trainable(model):
    """Freeze ``model`` in place by disabling gradient tracking on every parameter.

    Returns ``None``; the module itself is modified.
    """
    for weight in model.parameters():
        weight.requires_grad_(False)

# Models

In [8]:
# Root folder for all serialized models: `saved_models` one level above the working directory.
_project_root = os.path.abspath('..')
save_path = os.path.join(_project_root, "saved_models")

In [9]:
# One save directory per model family, each located under `save_path`.
# NOTE(review): the notebook's prose mentions "fair" folders for both the Vanilla VAE
# and the Vanilla GAN, but only the VAE entry uses a "fair" sub-folder — confirm intent.
_repo_subdirs = {
    "Marginal VAE": ("marginal_generators",),
    "Meta-VAE": ("meta_vae",),
    "SMVAE": ("smvae",),
    "Vanilla VAE": ("vanilla_vae", "fair"),
    "Vanilla GAN": ("vanilla_gan",),
}
models_repos = {name: os.path.join(save_path, *parts) for name, parts in _repo_subdirs.items()}

# Train models

In [11]:
# The models' default values correspond to the ones used in the paper's experiments.


def _build_data_and_model(model_type, batch_size, is_cylinder):
    """Instantiate the DataModule and the untrained model for ``model_type`` (DataModule first)."""
    if model_type == "Marginal VAE":
        dm = marginalDm(is_cylinder=is_cylinder, batch_size=batch_size)
        # The input size of the marginal VAE depends on which marginal is modelled.
        if is_cylinder:
            model = VAE()
        else:
            # data_size=30 is hard-coded; change it if the density vectors have a different length.
            model = VAE(latent_dim=1, data_size=30)
        return dm, model
    if model_type == "Meta-VAE":
        return metaDm(batch_size=batch_size), Meta_VAE()
    if model_type == "SMVAE":
        # NOTE(review): batch_size is not forwarded here, so the SMVAE always uses
        # metaDm's default batch size — confirm this is intended.
        return metaDm(), SMVAE()
    if model_type == "Vanilla VAE":
        # n_params=100 makes the vanilla VAE comparable in size (number of parameters) to
        # the meta VAE; with fewer parameters (default values) we get comparable results, cf. paper.
        return metaDm(whole_system=True, batch_size=batch_size), Vanilla_VAE(n_params=100)
    # Only "Vanilla GAN" remains (model_type was validated against models_repos).
    return metaDm(whole_system=True, batch_size=batch_size), Vanilla_GAN(n_params=100)


def _load_frozen_marginals(device):
    """Load the saved density and cylinder marginal VAEs (``model_0``) and freeze their weights."""
    marginals_path = models_repos["Marginal VAE"]
    # torch.load unpickles whole modules; only use trusted files. NOTE(review): on recent
    # PyTorch versions whose torch.load defaults to weights_only=True, loading a full
    # module may require weights_only=False — confirm against the installed version.
    density_vae = torch.load(os.path.join(marginals_path, 'densities', 'model_0'), map_location=device)
    cylinder_vae = torch.load(os.path.join(marginals_path, 'cylinders', 'model_0'), map_location=device)
    not_trainable(density_vae)
    not_trainable(cylinder_vae)
    return density_vae, cylinder_vae


def _model_save_path(model_type, unitary_component, index):
    """Return ``<repo>/model_<index>`` for ``model_type``, creating the directory if needed."""
    repo = models_repos[model_type]
    if model_type == "Marginal VAE":
        repo = os.path.join(repo, unitary_component)
    os.makedirs(repo, exist_ok=True)  # torch.save does not create missing directories
    return os.path.join(repo, 'model_{}'.format(index))


def train_models(model_type, n=5, batch_size=2048 * 2, epochs=1000, lr=None, is_cylinder=None, verbose=False,
                 with_scheduler=False, with_early_stopper=False, with_validation=False):
    """Train ``n`` independent models of ``model_type`` and save each one with ``torch.save``.

    Models are saved as ``model_0`` ... ``model_{n-1}`` in ``models_repos[model_type]``.
    Marginal VAEs go into a ``cylinders``/``densities`` sub-folder.

    Args:
        model_type: one of the keys of ``models_repos``.
        n: number of models to train. It is forced to 1 for "Marginal VAE", which is trained
            once and then used as a fixed marginal generator.
        batch_size: forwarded to the DataModule (not used for "SMVAE", see note in the builder).
        epochs: number of training epochs passed to ``model.train``.
        lr: learning rate. Defaults to 0.0002 for "Vanilla GAN" and 0.001 for every other model.
        is_cylinder: required for "Marginal VAE". True trains the cylinder marginal, False the density one.
        verbose: forwarded to ``model.train`` (not used by the GAN).
        with_scheduler: build a ``ReduceLROnPlateau`` scheduler (not used by the GAN).
        with_early_stopper: build an early stopper with ``get_early_stopper`` defaults (not used by the GAN).
        with_validation: pass the validation dataloader to training (not used by the GAN).

    Raises:
        AssertionError: unknown ``model_type``, or ``is_cylinder`` missing for "Marginal VAE".
    """
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    assert model_type in models_repos, 'Model type unrecognized'

    unitary_component = None
    if model_type == "Marginal VAE":
        assert is_cylinder is not None, 'Choose a unitary component for the marginal generator training, True for cylinders, False for densities'
        unitary_component = "cylinders" if is_cylinder else "densities"
        n = 1  # trained once; set a larger n after this line to save several marginal models

    is_gan = model_type == "Vanilla GAN"
    if lr is None:
        lr = 0.0002 if is_gan else 0.001

    for i in range(n):
        dm, model = _build_data_and_model(model_type, batch_size, is_cylinder)
        dm.setup()
        train_dataloader = dm.train_dataloader()

        if is_gan:
            # The GAN uses one plain Adam optimizer per sub-network (no scheduler / early stopping).
            optim_g, _ = get_optimizer(model.gen, lr, with_scheduler=False)
            optim_d, _ = get_optimizer(model.disc, lr, with_scheduler=False)
            model.train(epochs, train_dataloader, device, optim_g, optim_d)
        else:
            val_dataloader = dm.val_dataloader() if with_validation else None
            optimizer, scheduler = get_optimizer(model, lr, with_scheduler)
            early_stopper = get_early_stopper() if with_early_stopper else None
            if model_type == "Meta-VAE":
                # Fresh, frozen copies of the marginal generators for every trained model.
                density_vae, cylinder_vae = _load_frozen_marginals(device)
                model.train(epochs, train_dataloader, density_vae, cylinder_vae, device, optimizer, scheduler,
                            early_stopper, val_dataloader, verbose=verbose)
            else:
                model.train(epochs, train_dataloader, device, optimizer, scheduler, early_stopper,
                            val_dataloader, verbose=verbose)

        torch.save(model, _model_save_path(model_type, unitary_component, i))
            

### Train the Meta-VAE and Baselines

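The default values correspond to the settings in the paper's experiments section, except for the Vanilla VAE and GAN. The saved Vanilla VAE/GAN models in the "fair" folders are models with a number of parameters comparable to the Meta-VAE and the SMVAE. The models trained with the paper's parameters are saved in the "paper_params" folders. The Meta-VAE loads the pre-trained marginal generators from `saved_models/marginal_generators` (`densities/model_0` and `cylinders/model_0`), so train the marginal VAEs (see below) first if those files are missing.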

In [18]:
# GAN baseline, trained with a smaller batch size than train_models' default.
train_models("Vanilla GAN", batch_size=512, epochs=1000)

In [11]:
# Vanilla VAE baseline with the default batch size and learning rate.
train_models(model_type="Vanilla VAE", epochs=1000)

In [12]:
# Meta-VAE; needs the saved marginal generators (densities/ and cylinders/ model_0).
train_models(model_type="Meta-VAE", epochs=1000)

  1%|          | 7/1000 [01:36<3:47:46, 13.76s/it]


KeyboardInterrupt: 

In [11]:
# SMVAE, trained for longer (1500 epochs) than the other models.
train_models(model_type="SMVAE", epochs=1500)

100%|██████████| 1500/1500 [22:30<00:00,  1.11it/s]
100%|██████████| 1500/1500 [22:19<00:00,  1.12it/s]
100%|██████████| 1500/1500 [22:07<00:00,  1.13it/s]
100%|██████████| 1500/1500 [21:50<00:00,  1.14it/s]
100%|██████████| 1500/1500 [21:54<00:00,  1.14it/s]


### Train the marginal VAEs
Larger batch sizes (BS) give worse results; the saved models were trained with BS=1024.

Density VAE

In [12]:
train_models(
    "Marginal VAE",
    is_cylinder=False,  # density marginal
    batch_size=1024,
    epochs=1000,
    with_scheduler=True,
    verbose=True,
)

  0%|          | 1/1000 [00:00<04:14,  3.92it/s]

loss  25.773477840423585


  0%|          | 2/1000 [00:00<04:12,  3.96it/s]

loss  25.378363418579102


  0%|          | 3/1000 [00:00<04:23,  3.78it/s]

loss  24.892875003814698


  0%|          | 4/1000 [00:01<04:24,  3.77it/s]

loss  24.032931327819824


  0%|          | 5/1000 [00:01<04:37,  3.59it/s]

loss  22.229809761047363


  1%|          | 6/1000 [00:01<04:38,  3.57it/s]

loss  16.662148666381835


  1%|          | 7/1000 [00:01<04:37,  3.58it/s]

loss  7.804074501991272


  1%|          | 7/1000 [00:02<05:01,  3.29it/s]


KeyboardInterrupt: 

Cylinder VAE

In [None]:
train_models(
    "Marginal VAE",
    is_cylinder=True,  # cylinder marginal
    batch_size=1024,
    epochs=1000,
    with_scheduler=True,
    verbose=True,
)