# Variational Autoencoder for Images

- Small CNN-based network
- Training on CIFAR-10
- Logging via TensorBoard
- [Nice guide](https://towardsdatascience.com/variational-autoencoder-demystified-with-pytorch-implementation-3a06bee395ed)

## Setup Stuff

In [7]:
import torch

# Fail fast when no CUDA device is visible, with a message that says what to check
# instead of a bare AssertionError.
assert torch.cuda.is_available(), (
    'CUDA is not available: this notebook expects a GPU. '
    'Check the GPU driver and that a CUDA-enabled PyTorch build is installed.'
)

## Define Model

In [8]:
from torchsummary import summary
from vae_playground.vae import VAE

# Build an untrained VAE and print its module tree.
vae = VAE()
print(vae)

# Layer-by-layer summary on the CPU for 3x32x32 inputs; batch_size is the
# batch dimension torchsummary uses for this report.
summary(vae, input_size=(3, 32, 32), batch_size=8192, device='cpu')

Linear spatial dims: 4
VAE(
  (encoder): Sequential(
    (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (3): ReLU()
    (4): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (5): ReLU()
  )
  (linear_mu): Linear(in_features=1024, out_features=64, bias=True)
  (linear_log_var): Linear(in_features=1024, out_features=64, bias=True)
  (decoder): Sequential(
    (0): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (1): ReLU()
    (2): ConvTranspose2d(32, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (3): ReLU()
    (4): ConvTranspose2d(16, 3, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (5): Tanh()
  )
  (linear_decoder): Linear(in_features=64, out_features=1024, bias=True)
)
----------------------------------------------------------

## Train setup

In [11]:
from pathlib import Path
from typing import Optional

import torch
from torch.utils.tensorboard import SummaryWriter

from vae_playground.util import train, test
from vae_playground.vae import VAE

def run_training(pretrained_file: Optional[str]=None,
                 checkpoint_dir: str='/home/jo/git/vae-playground/data/checkpoints/',
                 overfit: bool=False,
                 num_epochs: int=40000,
                 batch_size: int=2056,
                 val_freq: int=400,
                 learn_rate: float=1e-5
) -> None:
    """Train a freshly constructed VAE and then run ``test`` on the result.

    A ``SummaryWriter`` is created for the run, shared by ``train`` and
    ``test``, and always closed afterwards (also when training is interrupted).

    Args:
        pretrained_file: File name of a checkpoint located in ``checkpoint_dir``.
            Its joined path is forwarded to ``train`` as ``pretrained_path``;
            ``None`` or an empty string forwards ``pretrained_path=None``.
        checkpoint_dir: Directory that is created if missing and forwarded to
            ``train``. NOTE: the default is an absolute path on the original
            author's machine; pass your own directory elsewhere.
        overfit: Forwarded unchanged to ``train``.
        num_epochs: Forwarded to ``train`` as ``num_epochs``.
        batch_size: Forwarded to ``train`` as ``batch_size``. The default 2056
            is kept from the original notebook (possibly meant to be 2048 --
            TODO confirm).
        val_freq: Forwarded to ``train`` as ``val_freq``.
        learn_rate: Forwarded to ``train`` as ``learn_rate``.
    """
    Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')

    # Join through Path so a checkpoint_dir without a trailing separator still
    # yields a valid file path; train keeps receiving a plain string.
    pretrained_path = (
        str(Path(checkpoint_dir) / pretrained_file) if pretrained_file else None
    )

    log_writer = SummaryWriter()
    try:
        vae = VAE().to(device)
        vae = train(vae,
                    log_writer=log_writer,
                    checkpoint_dir=checkpoint_dir,
                    pretrained_path=pretrained_path,
                    num_epochs=num_epochs,
                    batch_size=batch_size,
                    val_freq=val_freq,
                    learn_rate=learn_rate,
                    overfit=overfit)
        test(vae, log_writer)
    finally:
        # Close the writer even on errors or a KeyboardInterrupt so that
        # already logged events are not left unflushed.
        log_writer.close()

## Make sure we can overfit on a small subset of the images

In [None]:
# Checkpoint to load, given as a file name inside the checkpoint directory,
# e.g. '2023-09-06 11:42:33.149736_vae_1999.pt'. With None, no checkpoint path
# is handed to the training routine.
pretrained_file = None

run_training(overfit=True, pretrained_file=pretrained_file)

Using device: cuda
Linear spatial dims: 4


  entry = pickle.load(f, encoding="latin1")
Training...:   0%|          | 5/40000 [00:38<85:24:16,  7.69s/it, loss=0.106750] 


KeyboardInterrupt: 

In [None]:
# Optional sanity check (disabled): download CIFAR10 and open its first training image.
# from torchvision.datasets import CIFAR10
# from PIL import Image
# train_set = CIFAR10(root='./data', download=True, train=True)

# img = Image.fromarray(train_set.data[0])
# img.show()