# Variational Autoencoder for Images

- Small CNN-based network
- Training on CIFAR-10
- Logging via TensorBoard
- [Nice guide](https://towardsdatascience.com/variational-autoencoder-demystified-with-pytorch-implementation-3a06bee395ed)

In [None]:
# %pip install torch torchvision torcheval torchsummary tensorboard einops

## Define Model

In [None]:
from torchsummary import summary
from src.vae import VAE

# Instantiate the model and print both its module tree and a per-layer summary
# for a single 3x32x32 input.
vae = VAE()
print(vae)

# torchsummary defaults to device="cuda" and, when CUDA is available, feeds a
# CUDA tensor through the model. The model above is never moved to the GPU in
# this cell, so that default can fail with an input/weight device mismatch.
# Tell torchsummary which device the weights actually live on instead.
# (Assumes VAE is a torch.nn.Module exposing .parameters() -- TODO confirm.)
_first_param = next(vae.parameters(), None)
_summary_device = 'cuda' if _first_param is not None and _first_param.is_cuda else 'cpu'

# NOTE(review): batch_size here is presumably only used by torchsummary for the
# reported output shapes / size estimate, not for the actual forward pass --
# confirm against the installed torchsummary version.
summary(vae, (3, 32, 32), batch_size=8192, device=_summary_device)

## Train setup

In [None]:
import torch
from torch.utils.tensorboard import SummaryWriter

from src.util import train, test
from src.vae import VAE

def run_training(pretrained_file: str = None,
                 checkpoint_dir: str = '/home/jo/git/vae-playground/data/checkpoints/',
                 overfit: bool = False,
                 num_epochs: int = 40000,
                 val_freq: int = 400,
                 learn_rate: float = 1e-5,
):
    """Build a VAE, train it via ``src.util.train`` and evaluate it via ``src.util.test``.

    Args:
        pretrained_file: File name of a checkpoint located inside
            ``checkpoint_dir``. When falsy (``None`` or ``''``) no pretrained
            path is handed to ``train``.
        checkpoint_dir: Directory forwarded to ``train`` as ``checkpoint_dir``
            and used as the base directory for ``pretrained_file``. A trailing
            path separator is optional.
        overfit: Forwarded unchanged to ``train``.
        num_epochs: Forwarded to ``train`` as ``num_epochs``. For a quick
            smoke test use a small value, e.g. ``num_epochs=2, val_freq=1``.
        val_freq: Forwarded to ``train`` as ``val_freq`` (presumably the
            validation interval in epochs -- confirm in ``src.util``).
        learn_rate: Forwarded to ``train`` as ``learn_rate``.

    Returns:
        None. Results are reported through the TensorBoard ``SummaryWriter``
        passed to ``train`` and ``test``.
    """
    # Local import keeps this notebook cell self-contained.
    import os

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # os.path.join rather than string "+": plain concatenation silently builds
    # a wrong path whenever checkpoint_dir lacks a trailing separator.
    pretrained_path = None
    if pretrained_file:
        pretrained_path = os.path.join(checkpoint_dir, pretrained_file)

    log_writer = SummaryWriter()
    try:
        vae = VAE().to(device)
        vae = train(vae,
                    log_writer=log_writer,
                    checkpoint_dir=checkpoint_dir,
                    pretrained_path=pretrained_path,
                    num_epochs=num_epochs,
                    val_freq=val_freq,
                    # TODO: carried over from the original "todo here" marker
                    # that sat directly above the learning rate.
                    learn_rate=learn_rate,
                    overfit=overfit)
        test(vae, log_writer)
    finally:
        # Flush and release the event file even if training or testing raises,
        # so already-logged scalars are not lost.
        log_writer.close()

## Make sure we can overfit on a small subset of the images

In [None]:
# Checkpoint to resume from, given as a file name that run_training appends to
# its checkpoint_dir. Example of a previously saved checkpoint name:
# pretrained_file = '2023-09-06 11:42:33.149736_vae_1999.pt'
# With None, run_training passes pretrained_path=None to train().
pretrained_file = None

# overfit=True is forwarded to train(); per the heading above, this run is the
# sanity check that the model can overfit a small subset of the images.
run_training(pretrained_file=pretrained_file, overfit=True)

In [None]:
# from torchvision.datasets import CIFAR10
# from PIL import Image
# train_set = CIFAR10(root='./data', download=True, train=True)

# img = Image.fromarray(train_set.data[0])
# img.show()