## VQGAN

### Download checkpoints
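- Download the pre-trained VQGAN checkpoint from https://heibox.uni-heidelberg.de/d/a7530b09fed84f80a887/ and save it as `../src/pretrained/vqgan.ckpt` (relative to this notebook).
- Download the pre-trained LPIPS checkpoint from https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/ and save it as `../src/pretrained/vgg.pth`.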

### Load model

In [None]:
# Make the project sources under ../src importable (the `vqgan` and `data`
# modules imported in the next cell). The path is relative to the current
# working directory — presumably the notebook's own directory; confirm.
import sys
sys.path.append('../src')

In [None]:
import json

import torch
import lightning.pytorch as pl
from torchvision import transforms
from torch.utils.data import DataLoader

from vqgan.model import VQModel
from data import VQVisualNewsDataset

In [None]:
# Load the model hyper-parameters from JSON; the resulting dict is passed as
# keyword arguments to VQModel.load_from_checkpoint further down.
with open('../src/hparams_vqgan.json', 'r', encoding='utf-8') as f:
    hparams = json.load(f)

In [None]:
# Preprocessing pipeline:
#   1. resize to a fixed 256x256 (aspect ratio is not preserved),
#   2. convert to a float tensor with values in [0, 1],
#   3. per-channel (x - 0.5) / 0.5, i.e. 2x - 1, mapping values into [-1, 1].
# `imshow` inverts step 3 with img / 2 + 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

In [None]:
# 'train' split of the Visual News data under ../src/data/visual_news_mini
# (presumably a small demo subset — confirm), preprocessed with `transform`.
train_set = VQVisualNewsDataset('../src/data/visual_news_mini', 'train', transform)
# shuffle=False: batches come in dataset order, so the "first batch" used for
# visualization below is the same on every call.
# Pinned memory only speeds up host-to-GPU copies; without CUDA it has no
# effect and PyTorch warns about it, so enable it only when a GPU exists.
train_loader = DataLoader(train_set,
                          batch_size=2,
                          shuffle=False,
                          num_workers=0,
                          pin_memory=torch.cuda.is_available())

In [None]:
# NOTE: If using the downloaded checkpoint, process it before using it.
# Step 1: project script that adapts the downloaded VQGAN checkpoint for this
# project's VQModel. NOTE(review): whether it rewrites the file in place and
# is safe to re-run is not visible here — confirm before running twice.
! python ../src/process_pretrained_vqgan.py --pretrained_vqgan ../src/pretrained/vqgan.ckpt
# Step 2: Lightning's checkpoint-upgrade utility, which migrates a checkpoint
# saved by an older Lightning version to the installed version's format.
! python -m lightning.pytorch.utilities.upgrade_checkpoint ../src/pretrained/vqgan.ckpt

In [None]:
# Restore VQModel from the (processed) checkpoint; the entries of `hparams`
# are forwarded as keyword arguments to the model.
model = VQModel.load_from_checkpoint('../src/pretrained/vqgan.ckpt', **hparams)
# Initialise the model's LPIPS component from the downloaded LPIPS weights.
model.init_lpips_from_pretrained('../src/pretrained/vgg.pth')

In [None]:
# Train for 1 epoch, starting from the weights already loaded into `model`.
# ckpt_path is intentionally NOT passed to fit(): it would also restore the
# loop state (current epoch / global step) stored in the checkpoint. Because
# max_epochs is a total count rather than "additional" epochs, a checkpoint
# saved at epoch >= 1 would make fit() return without training. Optimizer
# state therefore starts fresh.
trainer = pl.Trainer(accelerator='cpu', max_epochs=1)
trainer.fit(model, train_dataloaders=train_loader)

In [None]:
import io

import torchvision
import numpy as np
import matplotlib.pyplot as plt
import requests

from PIL import Image


def download_image(url, timeout=30):
    """Download an image over HTTP(S) and open it with PIL.

    Args:
        url: Address of the image file.
        timeout: Seconds to wait for the server (connect and read) before
            giving up. Without a timeout, ``requests`` can block forever.

    Returns:
        A ``PIL.Image.Image`` read from the downloaded bytes, in whatever mode
        the file stores (it may be RGB, RGBA, grayscale, palette, ...).

    Raises:
        requests.HTTPError: If the server responds with an error status.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    resp = requests.get(url, timeout=timeout)
    resp.raise_for_status()
    return Image.open(io.BytesIO(resp.content))


def imshow(img, title=None):
    """Plot a normalized channel-first image tensor (does not call plt.show()).

    Args:
        img: Tensor laid out as (C, H, W) — it is transposed to (H, W, C) for
            matplotlib — with values normalized to about [-1, 1], as produced
            by ``transform`` or by the model's reconstruction.
        title: Optional plot title.
    """
    # Detach and move to CPU so tensors that track gradients or live on a GPU
    # can be converted to NumPy.
    img = img.detach().cpu() / 2 + 0.5  # undo Normalize(mean=0.5, std=0.5)
    # Reconstructions can fall slightly outside [-1, 1]; clamp to the valid
    # float range so matplotlib does not emit clipping warnings.
    npimg = img.clamp(0, 1).numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    if title is not None:
        plt.title(title)
    plt.axis('off')

    
def show_random_batch(data_loader):
    """Display the first batch yielded by ``data_loader`` as one image grid.

    Despite the name, the batch is random only if the loader shuffles; with
    ``shuffle=False`` it is always the first batch in dataset order.
    Assumes each batch is a single image tensor (presumably (B, C, H, W)) —
    TODO confirm against the dataset's ``__getitem__``.
    """
    first_batch = next(iter(data_loader))
    grid = torchvision.utils.make_grid(first_batch)
    imshow(grid)
    plt.show()


def visualize_model_batch(batch):
    """Encode ``batch`` with the global ``model``, decode it, and plot both.

    Shows two image grids: the input batch and its reconstruction decoded from
    the model's quantized latents.

    Args:
        batch: Image tensor normalized like the output of ``transform``;
            presumably (B, C, H, W) — TODO confirm against VQModel.encode.

    The model runs without gradient tracking and in eval mode (batch norm uses
    running statistics, dropout is disabled). Its previous train/eval mode and
    the global grad mode are restored afterwards, even if encoding or plotting
    raises.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # encode() returns a 3-tuple; only the first item is needed by
            # decode() to reconstruct the images.
            quant_states, _, _ = model.encode(batch.to(model.device))
            rec = model.decode(quant_states)
        # Plot from CPU copies so this also works when the model is on a GPU.
        imshow(torchvision.utils.make_grid(batch.cpu()), 'Original')
        plt.show()
        imshow(torchvision.utils.make_grid(rec.cpu()), 'Reconstructed')
        plt.show()
    finally:
        model.train(was_training)


def visualize_model(data_loader):
    """Reconstruct the first batch of ``data_loader`` and plot it.

    Thin wrapper around ``visualize_model_batch``. The batch is random only
    if the loader shuffles.
    """
    visualize_model_batch(next(iter(data_loader)))

In [None]:
# Show the first training batch (train_loader is created with shuffle=False).
show_random_batch(train_loader)

In [None]:
# Plot the first training batch next to the model's reconstruction of it.
visualize_model(train_loader)

In [None]:
# Download an image, encode it, and then reconstruct

# Load the image. Convert to RGB because `transform` normalizes exactly three
# channels and would fail on grayscale, RGBA, or palette images.
sample_img = download_image(
    'https://heibox.uni-heidelberg.de/f/7bb608381aae4539ba7a/?dl=1'
).convert('RGB')

# Preprocess the image using the transformation pipeline
sample_tensor = transform(sample_img)

# Add a batch dimension: (C, H, W) -> (1, C, H, W)
sample_batch = sample_tensor.unsqueeze(0)

visualize_model_batch(sample_batch)