In [1]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader

from torchvision import datasets, transforms, utils
import torchvision

from pt_utils import  Embeddings, Trainer, VQVAE, data_sampler, Vqvae2AdaptiveVae, VanillaVAE
from torchsummary import summary
import os
from torch import distributed as dist
from tqdm.notebook import trange, tqdm
from torchvision.datasets import ImageFolder
import numpy as np
from typing import Any, Callable, cast, Dict, List, Optional, Tuple

In [2]:
# Seed the global NumPy and PyTorch generators before the model is built so
# that weight initialisation is repeatable after Restart Kernel -> Run All.
seed = 51
np.random.seed(seed)
torch.manual_seed(seed)

# Prefer the GPU, but fall back to the CPU instead of crashing in `.to(device)`
# on a machine without CUDA. On a CUDA machine this is still "cuda".
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the VQ-VAE and move its parameters to the selected device.
# The keyword values are forwarded unchanged to pt_utils.VQVAE; their exact
# meaning is defined there (presumably: codebook of `n_embed` vectors of size
# `embed_dim`, EMA `decay` -- TODO confirm against pt_utils).
model = VQVAE(
    in_channel=3,
    channel=128,
    n_res_block=2,
    n_res_channel=32,
    embed_dim=2,
    n_embed=512,
    decay=0.99,
).to(device)

In [4]:
# summary(model, input_size=(3, 256, 256))

In [3]:
# Root of the image dataset, read through torchvision's ImageFolder.
# Roots used by earlier experiments, kept for reference:
#   '../data/dataset_512/'
#   '../datasets/bc_right_sub_left_minmax_4x_360'
#   '../datasets/bc_left_4x_360'
#   '../datasets/original/o_bc_left_9x_512_360'
dataset_path = '../../datasets/original/o_bc_left_4x_768_360'

# Size of the centre crop taken from every image. Despite the name, images are
# cropped, not resized (the Resize transform is not part of the pipeline).
# Other sizes tried: (512, 512), (1024, 1024).
resize_shape = (256, 256)

n_gpu = 1
batch_size = 4
val_split = 0.15  # fraction of the dataset held out for evaluation

# Dedicated seed for the train/validation split. Without it, the membership of
# the two subsets changes on every run, so validation numbers are not
# comparable across runs and a resumed run may evaluate on images it trained on.
split_seed = 51

transform = transforms.Compose(
    [
        transforms.CenterCrop(resize_shape),
        transforms.ToTensor(),
        # Per-channel (x - 0.5) / 0.5: maps ToTensor's [0, 1] range to [-1, 1].
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)

dataset = datasets.ImageFolder(dataset_path, transform=transform)

# The test length is computed as the remainder so the two lengths always sum
# to len(dataset), which random_split requires.
train_dataset_len = int(len(dataset) * (1 - val_split))
test_dataset_len = len(dataset) - train_dataset_len

train_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    [train_dataset_len, test_dataset_len],
    generator=torch.Generator().manual_seed(split_seed),
)

# NOTE(review): the test sampler is also created with shuffle=True. Kept as in
# the original; confirm whether the evaluation code relies on a shuffled order.
train_sampler = data_sampler(train_dataset, shuffle=True, distributed=False)
test_sampler = data_sampler(test_dataset, shuffle=True, distributed=False)

# Batch size handed to each loader (global batch divided across GPUs).
per_gpu_batch_size = batch_size // n_gpu

train_loader = DataLoader(
    train_dataset, batch_size=per_gpu_batch_size, sampler=train_sampler, num_workers=2
)
test_loader = DataLoader(
    test_dataset, batch_size=per_gpu_batch_size, sampler=test_sampler, num_workers=2
)

In [None]:
# --- Training configuration -------------------------------------------------
epochs = 100
lr = 1e-4

# Defined here but not passed to Trainer.train below; kept in case other
# cells read them.
latent_loss_weight = 0.25
sample_size = 25

# Directory argument handed to the trainer as `model_path`.
# Paths used by earlier runs, kept for reference:
#   '../data/logs/vq-vae-2/4x/weights'   (samples: '../data/logs/vq-vae-2/4x/samples')
#   'data/logs/emb_dim_1_n_embed_8192_bc_right_sub_left_minmax_4x_360'
model_path = 'runs/emb_dim_2_n_embed_512_bc_left_4x_768_360'

# Adam with the AMSGrad variant and a very small weight decay.
# Alternative tried earlier: optim.RMSprop(model.parameters(), lr=lr,
# weight_decay=1e-6, centered=True).
optimizer = optim.Adam(
    model.parameters(),
    lr=lr,
    weight_decay=1e-7,
    amsgrad=True,
)

# --- Run training -----------------------------------------------------------
Trainer.train(
    model=model,
    optimizer=optimizer,
    train_loader=train_loader,
    test_loader=test_loader,
    model_path=model_path,
    epochs=epochs,
    device=device,
)


  0%|          | 0/1530 [00:00<?, ?it/s]

test elbo: 0.02121


  0%|          | 0/1530 [00:00<?, ?it/s]

test elbo: 0.02047


  0%|          | 0/1530 [00:00<?, ?it/s]