# Train VAE

## Import required packages

In [1]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelSummary

from data_modules import MNIST
from models.naive_vae import NaiveVAE

## Instantiate pl.LightningDataModule & pl.LightningModule

Keep the data directory the same across runs so the dataset is not downloaded again.

In [2]:
dm = MNIST.MNISTDataModule("./downloads")
dm.prepare_data()
dm.setup()
vae = NaiveVAE(in_channels=1, hidden_dim=[8, 16, 32, 64], latent_dim=128, lr=1e-4)
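`MNISTDataModule` is the project's own class, so its internals are not shown here. Assuming its `prepare_data()` follows the usual pattern (torchvision's `MNIST(root, download=True)` skips the download when the files already exist under `root`), the reason for keeping the data directory fixed can be sketched with the standard library alone. `prepare_data`, `fake_download` and `mnist.bin` below are hypothetical stand-ins, not the project's code.

```python
import tempfile
from pathlib import Path


def fake_download(target: Path) -> None:
    # Hypothetical stand-in for a real download: just writes a small file.
    target.write_bytes(b"mnist-placeholder")


def prepare_data(data_dir: str, filename: str = "mnist.bin") -> bool:
    """Download into data_dir only if the file is missing (hypothetical sketch).

    Returns True if a download happened, False if the cached copy was reused.
    """
    root = Path(data_dir)
    root.mkdir(parents=True, exist_ok=True)
    target = root / filename
    if target.exists():
        return False  # cached copy found: skip the download
    fake_download(target)
    return True


with tempfile.TemporaryDirectory() as tmp:
    first = prepare_data(tmp)   # empty directory -> downloads
    second = prepare_data(tmp)  # same directory -> reuses the existing file
    print(first, second)        # True False
```

Pointing the data module at a different directory each run would make every call behave like `first`, which is why `"./downloads"` is kept fixed.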

## Debug model

### fast_dev_run

With PyTorch Lightning's `fast_dev_run` mode, the trainer runs only a single batch of training and validation, which quickly checks that the whole loop executes without errors.  
The `ModelSummary` callback prints the input and output sizes of each submodule, which Lightning obtains by running an example input through the model.
The example input, `example_input_array`, is defined in the model's `__init__` function.

In [None]:
# debug train
trainer = pl.Trainer(
    fast_dev_run=True,
    callbacks=[ModelSummary(max_depth=1)],
)
trainer.fit(model=vae, datamodule=dm)
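The in/out sizes that `ModelSummary` reports for the encoder and decoder follow from standard convolution arithmetic. The sketch below assumes every block uses a 3×3 kernel with stride 2 and padding 1 (plus `output_padding=1` for transposed convolutions). These settings are an assumption about `NaiveVAE`, chosen because they agree with the sizes in the training cell's summary, not read from its source.

```python
def conv_out(size: int, kernel: int = 3, stride: int = 2, padding: int = 1) -> int:
    # Output size of Conv2d along one spatial dimension (dilation 1).
    return (size + 2 * padding - kernel) // stride + 1


def deconv_out(size: int, kernel: int = 3, stride: int = 2,
               padding: int = 1, output_padding: int = 1) -> int:
    # Output size of ConvTranspose2d along one spatial dimension (dilation 1).
    return (size - 1) * stride - 2 * padding + kernel + output_padding


batch, side = 32, 32  # example input of shape [32, 1, 32, 32]
hidden_dim = [8, 16, 32, 64]

# Encoder: one stride-2 conv per entry of hidden_dim halves the resolution.
for _ in hidden_dim:
    side = conv_out(side)
print("encoder out:", [batch, hidden_dim[-1], side, side])  # [32, 64, 2, 2]
print("flattened:", hidden_dim[-1] * side * side)            # 256

# Decoder: one stride-2 transposed conv per channel step (64->32->16->8).
for _ in range(len(hidden_dim) - 1):
    side = deconv_out(side)
print("decoder out:", [batch, hidden_dim[0], side, side])    # [32, 8, 16, 16]

# Final layer: one more transposed conv restores the input resolution.
print("final side:", deconv_out(side))                       # 32
```

The flattened size of 256 is why `fc_mu` and `fc_logSigma` take 256 features in.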

### Overfit on small batches

To check whether the model is expressive enough, train it on only a few batches: a model with enough capacity should overfit them quickly, driving the training loss close to zero.  
Tracking the gradient norm while overfitting also helps tune the learning rate.

In [None]:
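trainer = pl.Trainer(
    default_root_dir="checkpoints/naive_vae",
    # Train repeatedly on the same 10 training batches.
    overfit_batches=10,
    # Log the 2-norm of the gradients (PyTorch Lightning 1.x argument, removed in 2.0).
    track_grad_norm=2,
    max_epochs=500,
)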

trainer.fit(model=vae, datamodule=dm)
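The idea behind this check can be illustrated without Lightning on a toy least-squares problem, a swapped-in stand-in for the VAE: a model that can represent a small fixed dataset should drive its loss toward zero, and the total gradient 2-norm (the kind of quantity `track_grad_norm=2` logs) reveals a learning rate that is too large. The data, the `train` helper and both learning rates are hypothetical.

```python
import math

# Toy "small batch": 10 points on the line y = 2x + 1, which a linear model fits exactly.
xs = [i / 10 for i in range(10)]
ys = [2 * x + 1 for x in xs]


def train(lr: float, steps: int):
    """Full-batch gradient descent on MSE for y_hat = w * x + b.

    Returns the loss at the last step and the total gradient 2-norm at every step.
    """
    w, b = 0.0, 0.0
    n = len(xs)
    grad_norms = []
    loss = 0.0
    for _ in range(steps):
        residuals = [w * x + b - y for x, y in zip(xs, ys)]
        loss = sum(r * r for r in residuals) / n
        dw = 2 * sum(r * x for r, x in zip(residuals, xs)) / n
        db = 2 * sum(residuals) / n
        # Total 2-norm over all parameter gradients: sqrt of the sum of squares.
        grad_norms.append(math.sqrt(dw * dw + db * db))
        w -= lr * dw
        b -= lr * db
    return loss, grad_norms


loss_ok, norms_ok = train(lr=0.5, steps=500)   # stable: overfits the batch
loss_bad, norms_bad = train(lr=1.0, steps=50)  # too large: gradients blow up
print(f"lr=0.5: loss {loss_ok:.2e}, grad norm {norms_ok[0]:.2f} -> {norms_ok[-1]:.2e}")
print(f"lr=1.0: loss {loss_bad:.2e}, grad norm {norms_bad[0]:.2f} -> {norms_bad[-1]:.2e}")
```

A gradient norm that shrinks toward zero suggests the learning rate is workable; one that grows from step to step suggests lowering it.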

## Train the Model

In [3]:
trainer = pl.Trainer(
    default_root_dir="checkpoints/naive_vae",
    track_grad_norm=2,
    max_epochs=5,
    check_val_every_n_epoch=1,
)
trainer.fit(model=vae, datamodule=dm)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name          | Type       | Params | In sizes        | Out sizes      
---------------------------------------------------------------------------------
0 | encode_blocks | Sequential | 24.6 K | [32, 1, 32, 32] | [32, 64, 2, 2] 
1 | decode_blocks | Sequential | 24.4 K | [32, 64, 2, 2]  | [32, 8, 16, 16]
2 | fc_mu         | Linear     | 32.9 K | [32, 256]       | [32, 128]      
3 | fc_logSigma   | Linear     | 32.9 K | [32, 256]       | [32, 128]      
4 | decoder_init  | Linear     | 33.0 K | [32, 128]       | [32, 256]      
5 | final_layer   | Sequential | 673    | [32, 8, 16, 16] | [32, 1, 32, 32]
---------------------------------------------------------------------------------
148 K     Trainable params
0         Non-trainable params
148 K     Total params
0.594     Total estimated model params size (MB)
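The parameter counts in this summary can be reproduced by hand under an assumed layer layout: each encoder block is a 3×3 `Conv2d` followed by `BatchNorm2d`, each decoder block a 3×3 `ConvTranspose2d` followed by `BatchNorm2d`, and the final layer a `ConvTranspose2d` (8 to 8 channels), `BatchNorm2d` and a 3×3 `Conv2d` down to one channel. This layout is inferred from the table, not read from `models/naive_vae.py`. Activations have no parameters, and the size estimate assumes 4 bytes per float32 parameter.

```python
def conv_params(c_in: int, c_out: int, k: int = 3) -> int:
    # Conv2d / ConvTranspose2d: c_in * c_out * k * k weights plus one bias per output channel.
    return c_in * c_out * k * k + c_out


def bn_params(c: int) -> int:
    # BatchNorm2d: learnable scale and shift per channel (running stats are buffers).
    return 2 * c


def linear_params(n_in: int, n_out: int) -> int:
    return n_in * n_out + n_out


hidden_dim, latent_dim, in_channels = [8, 16, 32, 64], 128, 1
flat = hidden_dim[-1] * 2 * 2  # encoder output is 64 x 2 x 2

# Encoder: Conv + BN for 1->8->16->32->64 channels.
chans = [in_channels] + hidden_dim
encode = sum(conv_params(a, b) + bn_params(b) for a, b in zip(chans, chans[1:]))

# Decoder: ConvTranspose + BN for 64->32->16->8 channels.
rev = hidden_dim[::-1]
decode = sum(conv_params(a, b) + bn_params(b) for a, b in zip(rev, rev[1:]))

fc_mu = fc_log_sigma = linear_params(flat, latent_dim)
decoder_init = linear_params(latent_dim, flat)

# Final layer: ConvTranspose(8->8) + BN(8), then Conv(8->1).
final = conv_params(8, 8) + bn_params(8) + conv_params(8, in_channels)

total = encode + decode + fc_mu + fc_log_sigma + decoder_init + final
print(encode, decode, fc_mu, decoder_init, final)  # 24624 24360 32896 33024 673
print(total, f"{total * 4 / 1e6:.3f} MB")          # 148473 0.594 MB
```

The totals match the 148 K parameters and 0.594 MB reported above, which supports (but does not prove) the assumed layout.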

