In [1]:
%load_ext autoreload
%autoreload 2

import copy
import os

import torch
from torch import optim, nn
import pytorch_lightning as pl

import dataset_utils
import generative_models as gm
import visualization as vis

# Preamble

In [2]:
# MNIST data loader shared by every training run below.
# All ten digit classes (0-9) are requested.
ALL_DIGITS = list(range(10))

LOADER = dataset_utils.create_loader(size=60_000, batch_size=1_000, digits=ALL_DIGITS)

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [3]:
# Default model settings shared by every run.
# "hidden_dims" and "act_fun" are placeholders (None) that the sweep cells
# overwrite on a per-run copy of this dictionary.
config = {
    "hidden_dims": None,
    "latent_dim": 2,
    "ambient_dim": 784,
    "act_fun": None,
    "act_args": [],
    "optimizer_class": optim.Adam,
    "optimizer_kwargs": {"lr": 1e-3},
}

In [4]:
# Keyword arguments forwarded unchanged to every pl.Trainer built in this notebook.
TRAINER_KWARGS = {
    "max_epochs": 100,
    "checkpoint_callback": False,
    "log_every_n_steps": 1,
    # Flush the logs once per pass over LOADER (len(LOADER) steps).
    "flush_logs_every_n_steps": len(LOADER),
    # Set to 0 if CPU use is preferred.
    "gpus": -1,
    "auto_select_gpus": True,
}


# Training

In [5]:
# Sweep definition: candidate hidden-layer layouts and activation choices.
hd_set = [
    [512, 128, 32],
    [128, 32, 8],
]

# Activation classes and their arguments, matched up by position.
act_fun_set = [nn.Tanh, nn.ReLU]
act_args_set = [None, None]
act_pairs = tuple(zip(act_fun_set, act_args_set))

# Every (hidden dims, (activation, activation args)) combination, materialised
# as a tuple so it can be looped over more than once.
iterator = tuple((dims, pair) for dims in hd_set for pair in act_pairs)

In [6]:
# Train one model per (hidden dims, activation, model class) combination and
# write a checkpoint for each run under models/.
# NOTE: act_args is unpacked from the sweep but not forwarded; each run keeps
# the "act_args" value already present in config.
for hd, (act_fun, act_args) in iterator:
  for model_class in (gm.VAE, gm.DAE, gm.GAN):

    # Run identifier, e.g. "VAE_[512, 128, 32]_Tanh"
    name = "_".join([model_class.__name__, str(hd), act_fun.__name__])

    # Shallow copy of the shared defaults with this run's architecture filled in
    CONFIG = {**config, "hidden_dims": hd, "act_fun": act_fun}

    # Fit a fresh model with a fresh trainer
    print("Training", name)
    model = model_class(**CONFIG)
    trainer = pl.Trainer(**TRAINER_KWARGS)
    trainer.fit(model, train_dataloader=LOADER)

    # Persist the trained weights under the run identifier
    trainer.save_checkpoint("models/" + name)

Training VAE_[512, 128, 32]_Tanh
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type       | Params
-----------------------------------------
0 | encoder   | Sequential | 473 K 
1 | mu_layer  | Linear     | 66    
2 | var_layer | Linear     | 66    
3 | decode    | Generator  | 473 K 
-----------------------------------------
947 K     Trainable params
0         Non-trainable params
947 K     Total params
3.788     Total estimated model params size (MB)
Epoch 99: 100%|██████████| 60/60 [00:03<00:00, 16.13it/s, loss=-2.04e+07, v_num=33]
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name   | Type       | Params
--------------------------------------
0 | encode | Sequential | 473 K 
1 | decode | Generator  | 473 K 
--------------------------------------
947 K     Trainable params
0         Non-trainable params
947 K     Total params

# Results

In [7]:
# Reload every checkpoint written by the training sweep and save its figures
# (latent space, image grid, Jacobian and magnitude plots) under figs/.
model_classes = [gm.VAE, gm.DAE, gm.GAN]

# Derive the progress denominator from the sweep itself so it stays correct
# when hd_set, the activation lists or the model list change.
total = len(iterator) * len(model_classes)

# Make sure the output directory exists before anything is written into it;
# figure writers such as matplotlib's savefig do not create missing folders.
os.makedirs("figs", exist_ok=True)

i = 1
for hd, (act_fun, act_args) in iterator:
  for model_class in model_classes:

    print(str(i) + "/" + str(total))
    i = i + 1

    # Name (must match the one used when the checkpoint was saved)
    name = model_class.__name__ + "_" + str(hd) + "_" + act_fun.__name__

    # Copy and modify config
    CONFIG = copy.copy(config)
    CONFIG["hidden_dims"] = hd
    CONFIG["act_fun"] = act_fun

    model = model_class.load_from_checkpoint(checkpoint_path="models/" + name, **CONFIG)
    model.eval()

    # Get the limits on the latent space
    if model_class.__name__ == "DAE":
      # The latent space is not regularized, so the limits are obtained from the model
      xmin, xmax, ymin, ymax = vis.manifold_limits(model, 1000)
    else:  # VAE and GAN
      # Fixed window; presumably the latent prior is a standard normal, so
      # [-3, 3] covers most of its mass -- TODO confirm against the models
      xmin, xmax, ymin, ymax = -3, 3, -3, 3

    if model_class.__name__ != "GAN":
      # What the latent space looks like (skipped for the GAN)
      latent = vis.plot_latent_space(model, 1500, list(range(10)))
      latent.savefig("figs/" + name + "_latent.png")

    # Image grid over the latent window
    p = vis.plot_image_grid(model, 15, 15, xmin=xmin, xmax=xmax, ymin=ymin, ymax=ymax)
    p.savefig("figs/" + name + "_grid.png")

    # Jacobian (second argument True)
    p = vis.jac_plot(model, True, 300, 45, xmin, xmax, ymin, ymax)
    p.savefig("figs/" + name + "_jacobian.png")

    # Magnitude (second argument False)
    p = vis.jac_plot(model, False, 300, 45, xmin, xmax, ymin, ymax)
    p.savefig("figs/" + name + "_magnitude.png")

SyntaxError: invalid syntax (<ipython-input-7-6eac823dc23b>, line 6)