In [29]:
import os
import torch
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from torchvision import transforms, datasets
import distributed as dist

from vqvae import VQVAE

In [30]:
# Inference configuration shared by the cells below.
device = 'cpu'  # device the model is moved to and the checkpoint is mapped onto
size = 360  # edge length handed to Resize and CenterCrop in the preprocessing cell
checkpoint = 'vqvae_003.pt'  # file name inside the local 'checkpoint' directory

# map_location remaps every stored tensor onto `device`, so a checkpoint that
# was saved from a GPU still loads on a machine without CUDA.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
ckpt = torch.load(os.path.join('checkpoint', checkpoint), map_location=device)

In [31]:
# Build the model with its default constructor arguments, restore the weights
# from the loaded checkpoint, then move it to `device` and switch it to
# evaluation mode. Assigning the chained result keeps the cell output empty.
vqvae = VQVAE()
vqvae.load_state_dict(ckpt)
vqvae = vqvae.to(device).eval()

In [32]:
# Preprocessing pipeline: resize, center-crop to `size`, convert to a tensor,
# then normalize all three channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize(size=size),
    transforms.CenterCrop(size=size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])

# Images are read from the local 'images' folder (ImageFolder layout).
dataset = datasets.ImageFolder(root='images', transform=transform)

# Project helper; requested here without shuffling and without distributed mode.
sampler = dist.data_sampler(dataset, shuffle=False, distributed=False)

loader = DataLoader(
    dataset=dataset,
    batch_size=5,
    sampler=sampler,
    num_workers=1,
)

In [33]:
# Reconstruct every batch with the VQ-VAE and write one comparison image per
# batch: the input images followed by their reconstructions.

# save_image does not create missing folders, so make sure the target exists.
os.makedirs('outputs', exist_ok=True)

for i, (inputs, label) in enumerate(loader):
    # Keep the batch on the same device as the model so the loop keeps working
    # when `device` is changed away from 'cpu'.
    inputs = inputs.to(device)
    with torch.no_grad():
        quant_t, quant_b, _, _, _ = vqvae.encode(inputs)
        outputs = vqvae.decode(quant_t, quant_b)
        # Clamp reconstructions to the (-1, 1) range used for display below.
        outputs = outputs.clamp(-1, 1)
        # nrow equals the batch size, so inputs fill the first grid row and
        # reconstructions the second.
        # NOTE(review): newer torchvision releases renamed `range` to
        # `value_range`; confirm against the installed version before upgrading.
        save_image(
            torch.cat((inputs, outputs), dim=0),
            f'outputs/batch_{i}.jpg',
            normalize=True,
            range=(-1, 1),
            nrow=inputs.shape[0],
        )



In [34]:
# Sanity check: shape of the last input batch left over from the loader loop.
inputs.shape

torch.Size([5, 3, 360, 360])

In [35]:
# Sanity check: shape of the last reconstructed batch, to compare with inputs.shape.
outputs.shape

torch.Size([5, 3, 360, 360])