In [1]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader

from torchvision import datasets, transforms, utils

from pt_utils import  Embeddings, Trainer, VQVAE, data_sampler

import numpy as np

In [2]:
# Single seed shared by NumPy, PyTorch and the train/validation split further down.
seed = 51
np.random.seed(seed)
# Last expression of the cell: its return value (a torch Generator) is what the
# notebook echoes as this cell's output.
torch.manual_seed(seed)

<torch._C.Generator at 0x15df0c24410>

In [7]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs on a
# machine without CUDA instead of failing at `.to(device)`.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the VQ-VAE and move it to the selected device. The keyword meanings are
# defined by pt_utils.VQVAE, which is not visible from this notebook.
model = VQVAE(
    in_channel=3,
    channel=128,
    n_res_block=6,
    n_res_channel=32,
    embed_dim=2,
    n_embed=8192,
    decay=0.99,
).to(device)


In [None]:
# summary(model, input_size=(3, 512, 512))

In [8]:
# --- Dataset location --------------------------------------------------------
# Folders used by earlier experiments (kept for reference):
#   '../data/dataset_512/'
#   '../datasets/bc_right_sub_left_minmax_4x_360'
#   '../datasets/bc_left_4x_360'
dataset_path = '../datasets/original/o_bc_left_9x_512_360'

# Only referenced by the disabled Resize/CenterCrop steps below;
# (1024, 1024) was the other size tried.
resize_shape = (512, 512)

# --- Loading parameters ------------------------------------------------------
n_gpu = 1
batch_size = 4
val_split = 0.15  # fraction of the images held out from training

# --- Preprocessing -----------------------------------------------------------
# transforms.Resize(resize_shape) / transforms.CenterCrop(resize_shape) are
# currently disabled, so images are used at their stored size.
normalize = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
transform = transforms.Compose([transforms.ToTensor(), normalize])

dataset = datasets.ImageFolder(dataset_path, transform=transform)

# --- Train / held-out split --------------------------------------------------
n_images = len(dataset)
train_dataset_len = int(n_images * (1 - val_split))
test_dataset_len = n_images - train_dataset_len

# A dedicated generator seeded with `seed` keeps the split independent of how
# much of the global RNG stream has been consumed.
split_rng = torch.Generator().manual_seed(seed)
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_dataset_len, test_dataset_len], generator=split_rng
)

# --- Samplers and loaders ----------------------------------------------------
# NOTE(review): the held-out sampler is also created with shuffle=True —
# confirm this is wanted for evaluation.
train_sampler = data_sampler(train_dataset, shuffle=True, distributed=False)
test_sampler = data_sampler(test_dataset, shuffle=True, distributed=False)

# Both loaders share the same per-GPU batch size and worker count.
loader_options = dict(batch_size=batch_size // n_gpu, num_workers=2)
train_loader = DataLoader(train_dataset, sampler=train_sampler, **loader_options)
test_loader = DataLoader(test_dataset, sampler=test_sampler, **loader_options)

In [None]:
# model_file = 'data/logs/emb_dim_1_n_embed_8192_bc_left_sub_right_minmax_4x_360/vqvae_001_train_0.04914_test_0.04206.pt'
model_file = 'data/logs/emb_dim_1_n_embed_8192_bc_right_sub_left_minmax_4x_360/vqvae_003_train_0.04287_test_0.04129.pt'

# NOTE(review): the checkpoint directory is named emb_dim_1 while the model in
# this notebook is built with embed_dim=2 — confirm the weights match, since
# load_state_dict would reject tensors of a different shape.
# torch.load unpickles the file, so only load checkpoints from a trusted source.
# Tensors are mapped onto `device` (the device the model was moved to) rather
# than a hard-coded 'cuda', so the two can never disagree.
model.load_state_dict(torch.load(model_file, map_location=device))

In [9]:
# --- Optimisation settings ---------------------------------------------------
epochs = 100
lr = 1e-4

# Defined here but not forwarded to Trainer.train in this cell; sample_size is
# read again by the image-preview cells later in the notebook.
latent_loss_weight, sample_size = 0.25, 25

# Alternative tried earlier:
#   optim.RMSprop(model.parameters(), lr=lr, weight_decay=1e-6, centered=True)
optimizer = optim.Adam(
    model.parameters(),
    lr=lr,
    weight_decay=1e-7,
    amsgrad=True,
)

# --- Output location ---------------------------------------------------------
# Locations used by earlier runs (kept for reference):
#   sample_path = '../data/logs/vq-vae-2/4x/samples'
#   model_path  = '../data/logs/vq-vae-2/4x/weights'
#   model_path  = 'data/logs/emb_dim_1_n_embed_8192_bc_right_sub_left_minmax_4x_360'
model_path = 'data/logs/emb_dim_2_n_embed_8192_bc_left_9x_512_360'

# --- Run training ------------------------------------------------------------
train_args = dict(
    model=model,
    optimizer=optimizer,
    train_loader=train_loader,
    test_loader=test_loader,
    model_path=model_path,
    epochs=epochs,
)
Trainer.train(**train_args)


epoch: 1; loss: 0.02124; mse: 0.02843; latent: 0.000; avg mse: 0.02114; lr: 0.00010: 100%|██████████| 3443/3443 [07:18<00:00,  7.86it/s]


test elbo: 0.01857


epoch: 2; loss: 0.01791; mse: 0.01617; latent: 0.000; avg mse: 0.01787; lr: 0.00010: 100%|██████████| 3443/3443 [07:16<00:00,  7.88it/s]


test elbo: 0.01756


epoch: 3; loss: 0.01756; mse: 0.01813; latent: 0.000; avg mse: 0.01754; lr: 0.00010:   2%|▏         | 79/3443 [00:15<10:46,  5.20it/s] 


KeyboardInterrupt: 

In [None]:
# sample_path = '../data/logs/vq-vae-2/samples'

# Upper bound on how many images go into the preview.
sample_size = 25

# Run the forward pass in eval mode and restore the previous mode afterwards.
# Layers with train-time behaviour (dropout, batch-norm, and presumably the
# EMA codebook configured through decay=0.99 — confirm in pt_utils.VQVAE)
# would otherwise behave differently or update state on held-out data.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        # One batch is all the preview needs. The previous loop ran the model
        # over every batch of the loader and kept only the last result.
        img, label = next(iter(test_loader))
        img = img.to(device)
        out, latent_loss = model(img)
finally:
    model.train(was_training)

sample = img[:sample_size]

In [None]:
# Truncate originals and reconstructions to the same length so the two halves
# of the grid line up image-for-image.
sample = img[:sample_size]
recon = out[:sample_size]

utils.save_image(
    torch.cat([sample, recon], 0),
    "_vq_vae_2_test.png",
    # One row per group: originals on top, reconstructions underneath. Using
    # the actual number of images (not sample_size) keeps the rows aligned
    # when the batch holds fewer than sample_size images.
    nrow=sample.size(0),
    normalize=True,
    # range=(-1, 1),
)