# VQ-VAE in PyTorch with pytorch-lightning & wandb
{ [paper](https://arxiv.org/pdf/1711.00937.pdf) }
{ code : [refer1](https://github.com/deepmind/sonnet/tree/master),
[refer2](https://github.com/zalandoresearch/pytorch-vq-vae),
[refer3](https://github.com/anantzoid/Conditional-PixelCNN-decoder),
[refer4](https://github.com/j-min/PixelCNN)}
{ deploy : [flask]() }

The basic model architecture follows DeepMind's implementation: an encoder and decoder built from CNNs with residual blocks.
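
As a rough illustration of the residual-block encoder/decoder design mentioned above, the sketch below assumes the layout used in the DeepMind Sonnet example: ReLU, 3x3 convolution, ReLU, 1x1 convolution, plus a skip connection. The class names `ResidualBlock` and `ResidualStack` are hypothetical and may differ from what `Module/VQVAE.py` actually contains; the argument names are borrowed from the sweep config in this notebook.

```python
import torch
from torch import nn


class ResidualBlock(nn.Module):
    """Hypothetical block: x + Conv1x1(ReLU(Conv3x3(ReLU(x))))."""

    def __init__(self, num_hiddens: int, num_residual_hiddens: int):
        super().__init__()
        self.block = nn.Sequential(
            nn.ReLU(),
            # 3x3 conv squeezes channels down to num_residual_hiddens
            nn.Conv2d(num_hiddens, num_residual_hiddens,
                      kernel_size=3, padding=1, bias=False),
            nn.ReLU(),
            # 1x1 conv restores num_hiddens so the skip connection adds up
            nn.Conv2d(num_residual_hiddens, num_hiddens,
                      kernel_size=1, bias=False),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.block(x)


class ResidualStack(nn.Module):
    """num_residual_layers blocks in sequence, followed by a final ReLU."""

    def __init__(self, num_hiddens: int, num_residual_layers: int,
                 num_residual_hiddens: int):
        super().__init__()
        self.layers = nn.Sequential(*[
            ResidualBlock(num_hiddens, num_residual_hiddens)
            for _ in range(num_residual_layers)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.layers(x))


if __name__ == "__main__":
    stack = ResidualStack(num_hiddens=128, num_residual_layers=2,
                          num_residual_hiddens=32)
    features = torch.randn(2, 128, 8, 8)
    print(stack(features).shape)  # spatial size and channel count are preserved
```

Because every block maps `num_hiddens` channels back to `num_hiddens` channels, the stack can be dropped into either the encoder or the decoder without changing the feature-map shape.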

In addition, a sampler based on a PixelCNN decoder is a work in progress.

## Setup

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
%cd /content/drive/Shareddrives/Colab/Lionroket/VQ-VAE

/content/drive/Shareddrives/Colab/Lionroket/VQ-VAE


In [3]:
%%bash
pip install einops wandb pytorch-lightning adamp omegaconf umap-learn -q
wandb online

W&B online, running your script from this directory will now sync to the cloud.


## Imports

In [4]:
import torch
import numpy as np
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision.datasets import CIFAR10
from torchvision import transforms
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
import wandb
from adamp import AdamP
from omegaconf import OmegaConf
import yaml
import os, sys
from einops import rearrange, reduce, asnumpy, parse_shape
from einops.layers.torch import Rearrange, Reduce
#import umap

## Data

In [5]:
from Module.DataModule import CIFAR10_DataModule
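
The data module is later asked for `get_cifar_train_variance()`, whose implementation is not shown in this notebook. A plausible reading, assumed here, is that it returns the variance of the training pixels after scaling to [0, 1], which is then used to normalize the reconstruction error. The helper names below are hypothetical, and the array is random stand-in data rather than CIFAR-10.

```python
import numpy as np


def train_variance(images_uint8: np.ndarray) -> float:
    """Variance over all pixels after scaling uint8 images to [0, 1]."""
    scaled = images_uint8.astype(np.float64) / 255.0
    return float(scaled.var())


def normalized_reconstruction_loss(x: np.ndarray, x_recon: np.ndarray,
                                   data_variance: float) -> float:
    """Mean squared error divided by the training-set variance."""
    return float(np.mean((x - x_recon) ** 2) / data_variance)


# Random stand-in for the (N, 32, 32, 3) uint8 training array.
rng = np.random.default_rng(0)
fake_train = rng.integers(0, 256, size=(16, 32, 32, 3), dtype=np.uint8)
data_variance = train_variance(fake_train)

x = fake_train[:4] / 255.0
x_recon = np.full_like(x, x.mean())  # trivial "reconstruction": one constant value
print(data_variance, normalized_reconstruction_loss(x, x_recon, data_variance))
```

Under this normalization a model that only predicts the mean pixel value scores roughly 1, so the reconstruction loss becomes easier to compare across datasets with different contrast.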

## Model

In [6]:
from Module.VQVAE import VQVAE
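
The `VQVAE` module itself is imported rather than shown. For orientation, here is a minimal sketch of the quantization step as described in the VQ-VAE paper: nearest-codebook lookup, a codebook loss `||sg[z_e] - e||^2`, a commitment loss `||z_e - sg[e]||^2` weighted by `beta`, and a straight-through gradient. The class name `VectorQuantizerSketch` and its return signature are assumptions, not the contents of `Module/VQVAE.py`.

```python
import torch
from torch import nn
from torch.nn import functional as F


class VectorQuantizerSketch(nn.Module):
    """Hypothetical vector quantizer with a learnable codebook."""

    def __init__(self, num_embeddings: int = 512, embedding_dim: int = 64,
                 beta: float = 0.25):
        super().__init__()
        self.beta = beta
        self.codebook = nn.Embedding(num_embeddings, embedding_dim)
        self.codebook.weight.data.uniform_(-1 / num_embeddings, 1 / num_embeddings)

    def forward(self, z_e: torch.Tensor):
        # z_e: encoder output of shape (B, C, H, W) with C == embedding_dim
        b, c, h, w = z_e.shape
        flat = z_e.permute(0, 2, 3, 1).reshape(-1, c)           # (B*H*W, C)
        distances = torch.cdist(flat, self.codebook.weight)     # (B*H*W, K)
        indices = distances.argmin(dim=1)                       # nearest code per position
        z_q = self.codebook(indices).view(b, h, w, c).permute(0, 3, 1, 2)

        # Codebook loss moves the codes toward the (frozen) encoder output;
        # commitment loss moves the encoder output toward the (frozen) codes.
        codebook_loss = F.mse_loss(z_q, z_e.detach())
        commitment_loss = F.mse_loss(z_e, z_q.detach())

        # Straight-through estimator: forward pass uses z_q,
        # backward pass copies gradients to z_e unchanged.
        z_q_st = z_e + (z_q - z_e).detach()
        return z_q_st, indices.view(b, h, w), codebook_loss, commitment_loss


if __name__ == "__main__":
    vq = VectorQuantizerSketch()
    z_e = torch.randn(2, 64, 8, 8, requires_grad=True)
    z_q, idx, codebook_loss, commitment_loss = vq(z_e)
    vq_loss = codebook_loss + vq.beta * commitment_loss
    print(z_q.shape, idx.shape, vq_loss.item())
```

With `num_embeddings=512` and `embedding_dim=64` the codebook holds 512 × 64 = 32,768 parameters, which agrees with the 32.8 K reported for the `vq` layer in the model summary further down. The two loss terms have the same numerical value and differ only in which side receives gradients, which is consistent with the identical `codebook_loss` and `commitment_loss` values logged in the run.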

## Model config

In [7]:
%pwd

'/content/drive/Shareddrives/Colab/Lionroket/VQ-VAE'

In [8]:
config = OmegaConf.load('Config/sweep_config.json')
print(yaml.dump(OmegaConf.to_container(config)))

method: random
metric:
  goal: minimize
  name: val_loss
parameters:
  batch_size:
    values:
    - 32
  beta:
    values:
    - 0.25
  embedding_dim:
    values:
    - 64
  lr:
    values:
    - 0.001
  num_embeddings:
    values:
    - 512
  num_hiddens:
    values:
    - 128
  num_residual_hiddens:
    values:
    - 32
  num_residual_layers:
    values:
    - 2
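
To make the structure of the config above concrete: each entry under `parameters` lists candidate `values`, and a `random` sweep draws one value per parameter for every run. The sketch below imitates that draw in plain Python. It is an illustration only, not W&B's implementation, and `sample_run_config` is a hypothetical helper.

```python
import random

# Same parameter grid as the config printed above.
sweep_parameters = {
    "batch_size": {"values": [32]},
    "beta": {"values": [0.25]},
    "embedding_dim": {"values": [64]},
    "lr": {"values": [0.001]},
    "num_embeddings": {"values": [512]},
    "num_hiddens": {"values": [128]},
    "num_residual_hiddens": {"values": [32]},
    "num_residual_layers": {"values": [2]},
}


def sample_run_config(parameters: dict, rng: random.Random) -> dict:
    """Pick one candidate value per parameter, uniformly at random."""
    return {name: rng.choice(spec["values"]) for name, spec in parameters.items()}


run_config = sample_run_config(sweep_parameters, random.Random(0))
print(run_config)
```

Since every parameter here has a single candidate, every run receives the same configuration; adding more entries to a `values` list is what turns this into an actual search.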



## Train

In [9]:
sweep_id = wandb.sweep(OmegaConf.to_container(config), entity='nemod-leo', project="VQ-VAE")

wandb: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter: ··········
wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Create sweep with ID: d25ul9zq
Sweep URL: https://wandb.ai/nemod-leo/VQ-VAE/sweeps/d25ul9zq


In [10]:
def sweep_iteration():
    # set up W&B logger
    wandb.init()    # required to have access to `wandb.config`
    wandb_logger = WandbLogger(log_model=True)  # log final model

    # setup data
    cifar = CIFAR10_DataModule(batch_size=wandb.config.batch_size)
    cifar.prepare_data()
    cifar.setup()
    cifar_train_variance = cifar.get_cifar_train_variance()

    # setup model
    model = VQVAE(
        wandb.config.num_hiddens,
        wandb.config.num_residual_layers,
        wandb.config.num_residual_hiddens,
        wandb.config.num_embeddings,
        wandb.config.embedding_dim,
        wandb.config.beta,
        wandb.config.lr,
        cifar_train_variance
    )
    
    # ckpt callback
    checkpoint_callback = ModelCheckpoint(
        dirpath='result/ckpt',
        filename='Cifar10_VQ-VAE_{epoch:03d}-{val_loss:.3f}-{reconstructin_loss:.3f}-{codebook_loss:.3f}-{commitment_loss:.3f}'
    )
    
    # setup trainer
    trainer = pl.Trainer(
        logger=wandb_logger,
        gpus=-1, auto_select_gpus=True,
        precision=16,
        callbacks=[checkpoint_callback],
        max_epochs=1 #30
        )

    # train
    trainer.fit(model, datamodule=cifar)

In [11]:
wandb.agent(sweep_id, function=sweep_iteration, count=1)

wandb: Agent Starting Run: xdgpz5lo with config:
wandb: 	batch_size: 32
wandb: 	beta: 0.25
wandb: 	embedding_dim: 64
wandb: 	lr: 0.001
wandb: 	num_embeddings: 512
wandb: 	num_hiddens: 128
wandb: 	num_residual_hiddens: 32
wandb: 	num_residual_layers: 2
wandb: Currently logged in as: duya (use `wandb login --relogin` to force relogin)


Files already downloaded and verified
Files already downloaded and verified


  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
Using native 16bit precision.
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
  f"DataModule.{name} has already been called, so it will not be called again. "
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type            | Params
---------------------------------------------
0 | encoder  | Encoder         | 355 K 
1 | decoder  | Decoder         | 281 K 
2 | vq       | VectorQuantizer | 32.8 K
3 | pre_conv | Conv2d          | 8.3 K 
---------------------------------------------
678 K     Trainable params
0         Non-trainable params
678 K     Total params
2.714     Total estimated model params size (MB)



  f"The dataloader, {name}, does not have many workers which may be a bottleneck."
  f"The dataloader, {name}, does not have many workers which may be a bottleneck."



| Metric | Value |
| --- | --- |
| global_step | 1406.0 |
| _runtime | 84.0 |
| _timestamp | 1631177234.0 |
| _step | 30.0 |
| reconstructin_loss | 0.15626 |
| codebook_loss | 0.05904 |
| commitment_loss | 0.05904 |
| train_loss | 0.18183 |
| epoch | 0.0 |
| trainer/global_step | 1406.0 |


| Metric | History |
| --- | --- |
| global_step | ▁█ |
| _runtime | ▁▁▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▆▇▇▇▇███ |
| _timestamp | ▁▁▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▆▇▇▇▇███ |
| _step | ▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▇▇▇▇███ |
| reconstructin_loss | █▅▅▅▃▅▄▃▄▃▃▃▂▂▃▂▂▂▂▂▁▂▁▂▂▂▁▁▁ |
| codebook_loss | █▃▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
| commitment_loss | █▃▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
| train_loss | █▃▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
| epoch | ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
| trainer/global_step | ▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇████ |
