# Train VAE

## import required packages

In [1]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, ModelSummary
from pytorch_lightning.loggers import WandbLogger

from data_modules import MNIST
from models.naive_vae import NaiveVAE

## Instantiate pl.Datamodule & pl.LightningModule

Data dir should be consistent in order not to re-download the dataset.

In [2]:
dm = MNIST.MNISTDataModule("./downloads")
dm.prepare_data()
dm.setup()
vae = NaiveVAE(in_channels=1, hidden_dim=[8, 16, 32, 64], latent_dim=128, lr=1e-4)

## Debug model

### fast_dev_run

Thanks to pytorch-lightning's fast_dev_run mode, the model would be trained based on a single batch.  
ModelSummary callback prints the dimension of the intermediate results, which is estimated according to the example input.
The example input, `example_input_array`, is define in the model's `__init__` function.

In [None]:
# debug train
trainer = pl.Trainer(
    fast_dev_run=True,
    callbacks=[ModelSummary(max_depth=1)],
)
trainer.fit(model=vae, datamodule=dm)

### Overfit on small batch

In order to check whether the model is expressive enough, it'd be better to train on small batches. The model should overfit on them quickly.  
Learning rate can be adjusted, by tracking gradients while overfitting.

In [None]:
trainer = pl.Trainer(
    default_root_dir="checkpoints/naive_vae",
    overfit_batches=10,
    track_grad_norm=2,
    max_epochs=500,
)

trainer.fit(model=vae, datamodule=dm)

## Train the Model

In [None]:
wandb_logger = WandbLogger(name='Adam-32-0.001',project='NaiveVAE', log_model='all')
wandb_logger.watch(vae, log='all')

checkpoint_callback = ModelCheckpoint(monitor="val_total_loss", mode="min", save_top_k=2)
trainer = pl.Trainer(
    default_root_dir="checkpoints/naive_vae",
    track_grad_norm=2,
    max_epochs=2,
    logger=wandb_logger,
    callbacks=[checkpoint_callback]
)
trainer.fit(model=vae,datamodule=dm)
wandb_logger.finalize()

TypeError: WandbLogger.finalize() missing 1 required positional argument: 'status'

## Load Checkpoint