In [None]:
import math
from math import sqrt

import torch
from torch.optim import Adam, AdamW
from torch.optim.lr_scheduler import ExponentialLR, CosineAnnealingLR

# vision imports

from torchvision import transforms as T
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid, save_image

# dalle classes

from dalle_pytorch import DiscreteVAE

# import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '5'

In [None]:
# constants

# -- data ---------------------------------------------------------------------
IMAGE_PATH = '/disk/nvme2/report_images/2006-06-28_기업_미래에셋증권_황상연,신지원_LG석유화학(012990)_145795'
IMAGE_SIZE = 224             # target size for the Resize / CenterCrop transforms

# -- optimisation -------------------------------------------------------------
EPOCHS = 1_000
BATCH_SIZE = 1
LEARNING_RATE = 5e-4
LR_DECAY_RATE = 0.99999      # gamma of the ExponentialLR scheduler

# -- model (DiscreteVAE constructor arguments) --------------------------------
NUM_TOKENS = 8192            # passed as num_tokens
NUM_LAYERS = 3
NUM_RESNET_BLOCKS = 3
SMOOTH_L1_LOSS = False
EMB_DIM = 512                # passed as codebook_dim
HID_DIM = 256                # passed as hidden_dim
# KL_LOSS_WEIGHT = 6.6

# -- temperature annealing (value passed as `temp` to the VAE forward) --------
STARTING_TEMP = 1.0
TEMP_MIN = 1e-10             # NOTE(review): extremely low floor for the temperature — confirm intended
ANNEAL_RATE = 1e-3

# -- logging ------------------------------------------------------------------
NUM_IMAGES_SAVE = 1          # number of batch images rendered into each wandb sample grid


In [None]:
# data
# Each image is forced to 3-channel RGB, resized and center-cropped to IMAGE_SIZE, and
# converted to a tensor. ImageFolder expects IMAGE_PATH to contain class subdirectories;
# the labels it yields are discarded by the training loop.
image_transform = T.Compose([
    T.Lambda(lambda picture: picture if picture.mode == 'RGB' else picture.convert('RGB')),
    T.Resize(IMAGE_SIZE),
    T.CenterCrop(IMAGE_SIZE),
    T.ToTensor(),
])

ds = ImageFolder(IMAGE_PATH, image_transform)


In [None]:
# Show the ImageFolder summary as a quick sanity check of the loaded dataset.
ds

In [None]:
# Shuffled (images, labels) batches drawn from the ImageFolder dataset.
dl = DataLoader(dataset=ds, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# DiscreteVAE constructor arguments, collected in one dict so they can be reused
# (e.g. stored as 'hparams' by the commented-out save_model helper).
vae_params = {
    'image_size': IMAGE_SIZE,
    'num_layers': NUM_LAYERS,
    'num_tokens': NUM_TOKENS,
    'codebook_dim': EMB_DIM,
    'hidden_dim': HID_DIM,
    'num_resnet_blocks': NUM_RESNET_BLOCKS,
}

In [None]:
# Echo the DiscreteVAE constructor arguments defined above.
vae_params

In [None]:
# Build the trainable VAE from the shared hyperparameters plus the loss option,
# then move it to the GPU (Module.cuda returns the module itself).
vae = DiscreteVAE(**vae_params, smooth_l1_loss=SMOOTH_L1_LOSS)  # kl_div_loss_weight=KL_LOSS_WEIGHT currently disabled
vae = vae.cuda()

In [None]:
assert len(ds) > 0, 'folder does not contain any images'
print(f'{len(ds)} images found for training')

# Cadence, in iterations within an epoch, for pushing sample images and for logging scalars
# (both were hard-coded as `i % 1 == 0`). Temperature annealing and the LR scheduler step only
# run on sampling iterations, so SAMPLE_EVERY also controls how often those update.
SAMPLE_EVERY = 1
LOG_EVERY = 1

# def save_model(path):
#     save_obj = {
#         'hparams': vae_params,
#         'weights': vae.state_dict()
#     }

#     torch.save(save_obj, path)

# optimizer

opt = Adam(vae.parameters(), lr = LEARNING_RATE)
# opt = AdamW(vae.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-4)
sched = ExponentialLR(optimizer = opt, gamma = LR_DECAY_RATE)
# sched = CosineAnnealingLR(optimizer=opt,)

# weights & biases experiment tracking

import wandb

model_config = dict(
    num_tokens = NUM_TOKENS,
    smooth_l1_loss = SMOOTH_L1_LOSS,
    num_resnet_blocks = NUM_RESNET_BLOCKS,
    # kl_loss_weight = KL_LOSS_WEIGHT
)

run = wandb.init(
    project = 'dalle_train_vae_test',
    job_type = 'train_model',
    config = model_config
)


def _sample_logs(images, recons, k):
    """Build wandb media logs for the first `k` images of a training batch.

    Logs the inputs, the reconstructions returned by the training forward pass,
    "hard" reconstructions obtained by encoding to codebook indices and decoding
    them back, and a histogram of those indices.
    """
    with torch.no_grad():
        codes = vae.get_codebook_indices(images[:k])
        hard_recons = vae.decode(codes)

    sample_images, sample_recons, hard_recons, codes = (
        t.detach().cpu() for t in (images[:k], recons[:k], hard_recons, codes)
    )

    def to_grid(t):
        # `value_range` is the current name of make_grid's former `range` keyword.
        # NOTE(review): ToTensor() inputs are in [0, 1]; the (-1, 1) display range is kept for
        # all three grids as before — confirm it is intended for the original images.
        return make_grid(t.float(), nrow = int(sqrt(k)), normalize = True, value_range = (-1, 1))

    return {
        'sample images':        wandb.Image(to_grid(sample_images), caption = 'original images'),
        'reconstructions':      wandb.Image(to_grid(sample_recons), caption = 'reconstructions'),
        'hard reconstructions': wandb.Image(to_grid(hard_recons), caption = 'hard reconstructions'),
        'codebook_indices':     wandb.Histogram(codes),
    }


# starting temperature

global_step = 0
temp = STARTING_TEMP

try:
    for epoch in range(EPOCHS):
        for i, (images, _) in enumerate(dl):
            images = images.cuda()

            loss, recons = vae(
                images,
                return_loss = True,
                return_recons = True,
                temp = temp
            )

            opt.zero_grad()
            loss.backward()
            opt.step()

            logs = {}

            if i % SAMPLE_EVERY == 0:
                logs = {
                    **logs,
                    **_sample_logs(images, recons, NUM_IMAGES_SAVE),
                    'temperature': temp
                }

                # save_model(f'./vae.pt')
                # wandb.save('./vae.pt')

                # temperature anneal: exponential decay from STARTING_TEMP in the global step.
                # The previous update `temp * exp(-ANNEAL_RATE * global_step)` compounded on the
                # already-decayed value, so temp fell like exp(-ANNEAL_RATE * step**2 / 2) and
                # reached ~1e-3 of STARTING_TEMP after only ~120 steps at these settings.
                temp = max(STARTING_TEMP * math.exp(-ANNEAL_RATE * global_step), TEMP_MIN)

                # lr decay

                sched.step()

            if i % LOG_EVERY == 0:
                loss_value = loss.item()
                lr = sched.get_last_lr()[0]
                print(epoch, i, f'lr - {lr:6f} loss - {loss_value}')

                logs = {
                    **logs,
                    'epoch': epoch,
                    'iter': i,
                    'loss': loss_value,
                    'lr': lr
                }

            wandb.log(logs)
            global_step += 1

        # save trained model to wandb as an artifact every epoch's end

        # model_artifact = wandb.Artifact('trained-vae', type = 'model', metadata = dict(model_config))
        # model_artifact.add_file('vae.pt')
        # run.log_artifact(model_artifact)

    # save final vae and cleanup

    # save_model('./vae-final.pt')
    # wandb.save('./vae-final.pt')

    # model_artifact = wandb.Artifact('trained-vae', type = 'model', metadata = dict(model_config))
    # model_artifact.add_file('vae-final.pt')
    # run.log_artifact(model_artifact)
finally:
    # Close the wandb run even when training is interrupted (e.g. KeyboardInterrupt in the notebook).
    wandb.finish()

In [None]:
# Grab one (images, labels) batch for the interactive checks below; later cells read it as `i`.
# `next(iter(dl), None)` replaces a for/break whose loop variable leaked out of the loop: with
# an empty loader the old cell silently left `i` bound to the integer index from the training loop.
i = next(iter(dl), None)
assert i is not None, 'data loader yielded no batches'
print(i[0].shape)

In [None]:
# Print the module tree of the VAE being trained above.
print(vae)

In [None]:
# Run the peeked batch through the VAE with return_logits=True.
# `vae` was moved to the GPU with `.cuda()` while the loader yields CPU tensors, so the batch
# is moved to the model's device first (the old call failed with a device mismatch).
# no_grad: this is inspection only and should not build an autograd graph.
with torch.no_grad():
    a = vae(i[0].to(next(vae.parameters()).device), return_logits=True)

In [None]:
# Shape of the `return_logits=True` output computed in the previous cell.
a.shape

In [None]:
from math import log2, sqrt
import torch
from torch import nn, einsum
import torch.nn.functional as F
import numpy as np

from axial_positional_embedding import AxialPositionalEmbedding
from einops import rearrange

from dalle_pytorch import distributed_utils
from dalle_pytorch.vae import OpenAIDiscreteVAE, VQGanVAE
from dalle_pytorch.transformer import Transformer, DivideMax

# Scratch codebook: 8192 learnable vectors of size 768.
codebook = nn.Embedding(num_embeddings=8192, embedding_dim=768)

In [None]:
# Two encoder stages of 196 channels each.
enc_chans = [196, 196]

In [None]:
# Decoder channels mirror the encoder channels in reverse order (a new list).
dec_chans = enc_chans[::-1]

In [None]:
# Prepend the 3 input channels. Not idempotent: every re-run of this cell prepends another 3.
enc_chans = [3] + enc_chans

In [None]:
# Inspect the encoder channel list after the leading 3 was prepended.
enc_chans

In [None]:
# Scratch value; equals the 768 embedding size used for `codebook` and the codebook_dim of the
# checkpoint VAE built below. Not referenced by any other visible cell.
dec_init_chan = 768

In [None]:
# First decoder channel count (i.e. the last entry of the encoder list before reversal).
dec_chans[0]

In [None]:
# Load the trained checkpoint: a dict holding the model weights under 'module' plus extra entries
# such as 'hparams' (inspected in the cells below). Presumably a DeepSpeed checkpoint, judging by
# the `mp_rank_00_model_states.pt` file name.
# map_location='cpu' lets it load on a machine without the GPU it was saved from; the `dvae`
# built below lives on the CPU anyway.
# NOTE(review): torch.load unpickles arbitrary objects — only use it on trusted checkpoint files.
custom_vae = torch.load('/dalle/vae-final-ds-cp/global_step121623_reports_7M/mp_rank_00_model_states.pt', map_location='cpu')

In [None]:
# Summarize the checkpoint rather than echoing it: the raw repr includes the whole 'module'
# weight dict, i.e. one tensor repr per parameter, which floods the cell output.
{key: type(value).__name__ for key, value in custom_vae.items()}

In [None]:
# Top-level keys of the loaded checkpoint dict.
custom_vae.keys()

In [None]:
# Hyperparameters stored alongside the weights in the checkpoint.
custom_vae['hparams']

In [None]:
# Rebuild the DiscreteVAE from the hyperparameters stored in the checkpoint instead of re-typing
# them. The previous hard-coded call omitted `image_size`, so the constructor default was used
# whatever resolution the checkpoint was trained at; that value is not a weight, so
# load_state_dict could not catch a mismatch.
# NOTE(review): assumes 'hparams' holds exactly DiscreteVAE constructor kwargs (like the
# `vae_params` dict saved by the commented-out save_model above) — verify against the
# `custom_vae['hparams']` output.
dvae = DiscreteVAE(**custom_vae['hparams'])

In [None]:
# Copy the checkpoint's model weights (stored under 'module') into the freshly built VAE.
checkpoint_weights = custom_vae['module']
dvae.load_state_dict(checkpoint_weights)

In [None]:
# Encode one real batch into discrete codebook indices with the checkpoint VAE.
# get_codebook_indices needs an image batch; the bare call raised a TypeError.
sample_batch, _ = next(iter(dl))
# The loader produces IMAGE_SIZE crops, while DiscreteVAE is constructed for a fixed
# `image_size`; resize when the checkpoint was trained at a different resolution.
if tuple(sample_batch.shape[-2:]) != (dvae.image_size, dvae.image_size):
    sample_batch = F.interpolate(
        sample_batch,
        size=(dvae.image_size, dvae.image_size),
        mode='bilinear',
        align_corners=False,
    )
with torch.no_grad():
    codes = dvae.get_codebook_indices(sample_batch)
codes.shape