## Evaluate VQGAN


In [None]:
import sys

# Make the project's source tree importable (presumably where the `vqgan` and
# `dataset` modules imported below live). The path is relative to the
# notebook's working directory. The membership check keeps sys.path free of
# duplicate entries when this cell is re-run.
if "../src" not in sys.path:
    sys.path.append("../src")

In [None]:
import torch
import lightning.pytorch as pl
from torchvision import transforms
from torch.utils.data import DataLoader

from vqgan.model import VQModel
from dataset import HMDatasetImages

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torchvision


def imshow(img, title=None):
    """Display one image tensor that was normalized to [-1, 1].

    Args:
        img: Channel-first (C, H, W) image tensor, assumed to be normalized
            like ``transform`` does (mean 0.5, std 0.5), e.g. the output of
            ``torchvision.utils.make_grid``.
        title: Optional title drawn above the image.
    """
    # Undo Normalize(mean=0.5, std=0.5) to get back to [0, 1]. Clamp because
    # reconstructions can overshoot that range, and matplotlib would clip
    # them itself with a warning.
    img = (img.detach().cpu() / 2 + 0.5).clamp(0.0, 1.0)
    # matplotlib expects channel-last (H, W, C) arrays.
    plt.imshow(np.transpose(img.numpy(), (1, 2, 0)))
    if title is not None:
        plt.title(title)
    plt.axis("off")


def visualize_model_batch(model, batch):
    """Reconstruct ``batch`` with ``model`` and plot original vs. reconstruction.

    The batch is encoded with ``model.encode`` and the resulting quantized
    states are decoded back into images with ``model.decode``. Both image
    grids are then shown with ``imshow``.

    Gradients are disabled and the model is put in eval mode only for the
    duration of the call. The caller's grad mode and the model's previous
    train/eval mode are restored afterwards, even if encoding, decoding or
    plotting raises.

    Args:
        model: Object exposing ``encode(batch) -> (quant_states, loss, info)``,
            ``decode(quant_states)`` and the ``nn.Module`` train/eval API,
            e.g. ``VQModel``.
        batch: Image batch normalized to [-1, 1]; assumed to already be on
            the model's device.

    Returns:
        The reconstructed batch, detached from the autograd graph.
    """
    was_training = model.training
    model.eval()
    try:
        # no_grad restores the previous grad mode on exit, unlike a manual
        # set_grad_enabled(False)/(True) pair.
        with torch.no_grad():
            # Only the quantized states are needed to reconstruct the images.
            quant_states, _loss, _info = model.encode(batch)
            rec = model.decode(quant_states)
        imshow(torchvision.utils.make_grid(batch), "Original")
        plt.show()
        imshow(torchvision.utils.make_grid(rec), "Reconstructed")
        plt.show()
    finally:
        model.train(was_training)
    return rec.detach()

In [None]:
# Preprocessing pipeline:
#   1. resize to a fixed 256x256 (the `resolution` in hparams["ddconfig"]),
#   2. convert to a float tensor in [0, 1] (assuming the dataset yields PIL
#      images or uint8 arrays),
#   3. map each RGB channel to [-1, 1] via (x - 0.5) / 0.5.
# `imshow` reverses step 3 for display.
transform = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

In [None]:
# Test split of the image dataset. shuffle=False keeps the batch order fixed,
# so the same images are compared across models. Pinned memory only speeds up
# host-to-GPU copies, and DataLoader warns when it is requested without an
# accelerator, so it is enabled only when CUDA is available.
test_set = HMDatasetImages("./data", "test", transform)
test_loader = DataLoader(
    test_set,
    batch_size=4,
    shuffle=False,
    num_workers=0,
    pin_memory=torch.cuda.is_available(),
)

In [None]:
# Grab the first two test batches. The loader is not shuffled, so both models
# below are evaluated on exactly the same images.
test_iter = iter(test_loader)
batch1, batch2 = next(test_iter), next(test_iter)

In [None]:
# Model hyperparameters, forwarded as keyword arguments to
# `VQModel.load_from_checkpoint` below for both checkpoints. They presumably
# have to match the architecture the checkpoints were trained with.
#   n_embed / embed_dim: presumably codebook size and code vector width
#   ddconfig:            encoder/decoder configuration
#   lossconfig:          loss/discriminator configuration
hparams = dict(
    n_embed=16384,
    embed_dim=256,
    learning_rate=4.5e-06,
    ddconfig=dict(
        double_z=False,
        z_channels=256,
        resolution=256,
        in_channels=3,
        out_ch=3,
        ch=128,
        ch_mult=[1, 1, 2, 2, 4],
        num_res_blocks=2,
        attn_resolutions=[16],
        dropout=0.0,
    ),
    lossconfig=dict(
        disc_conditional=False,
        disc_in_channels=3,
        disc_start=0,
        disc_weight=0.75,
        disc_num_layers=2,
        codebook_weight=1.0,
    ),
)

### Pre-trained VQGAN Model

Checkpoint pre-trained on ImageNet for 12 epochs (~30k steps).


In [None]:
# Pretrained VQGAN: restore it from its checkpoint with the shared
# hyperparameters, load the VGG weights used by `init_lpips_from_pretrained`
# (presumably the perceptual-loss network), then keep the model on the CPU.
model = VQModel.load_from_checkpoint(
    checkpoint_path="./pretrained/vqgan.ckpt",
    **hparams,
)
model.init_lpips_from_pretrained("./pretrained/vgg.pth")
model = model.to(device="cpu")

In [None]:
# Reconstruct and plot both test batches with the pretrained model.
rec_ori_b1, rec_ori_b2 = [visualize_model_batch(model, b) for b in (batch1, batch2)]

### Fine-tuned VQGAN


In [None]:
# Fine-tuned VQGAN: restore the last training checkpoint with the same
# hyperparameters, load the VGG weights used by `init_lpips_from_pretrained`,
# then keep the model on the CPU.
model = VQModel.load_from_checkpoint(
    checkpoint_path="./checkpoints/last.ckpt",
    **hparams,
)
model.init_lpips_from_pretrained("./pretrained/vgg.pth")
model = model.to(device="cpu")

In [None]:
# Reconstruct and plot the same two test batches with the fine-tuned model.
rec_tuned_b1, rec_tuned_b2 = [visualize_model_batch(model, b) for b in (batch1, batch2)]

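### Evaluate

Compare both models' reconstructions of the two test batches above using:

- Peak Signal-to-Noise Ratio (PSNR) — higher is better
- Structural Similarity Index Measure (SSIM) — higher is better
- Learned Perceptual Image Patch Similarity (LPIPS) — lower is better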


In [None]:
from torchmetrics import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
import torchvision.transforms.functional as TF

In [None]:
# The input batches are normalized to [-1, 1] by `transform`, so the nominal
# data range is 2.0. Pass it explicitly:
#   - SSIM's stabilizing constants scale with data_range; 1.0 under-states
#     the range of [-1, 1] data and skews the score.
#   - PSNR with data_range=None infers the range from each batch's min/max,
#     which varies between batches.
psnr = PeakSignalNoiseRatio(data_range=2.0)
ssim = StructuralSimilarityIndexMeasure(data_range=2.0)
# LPIPS with the default normalize=False expects inputs in [-1, 1].
lpips = LearnedPerceptualImagePatchSimilarity(net_type="vgg")

In [None]:
# PSNR of each model's reconstructions against the original batches.
print("Higher PSNR is better")
for model_label, rec_b1, rec_b2 in (
    ("Pre-trained", rec_ori_b1, rec_ori_b2),
    ("Fine-tuned", rec_tuned_b1, rec_tuned_b2),
):
    psnr_b1 = psnr(rec_b1, batch1)
    psnr_b2 = psnr(rec_b2, batch2)
    # .item() prints plain floats rather than tensor reprs, matching the
    # LPIPS cell's output format.
    print(f"\n{model_label} model PSNR:")
    print("  Batch 1:", psnr_b1.item())
    print("  Batch 2:", psnr_b2.item())

In [None]:
# SSIM of each model's reconstructions against the original batches.
print("Higher SSIM is better")
for model_label, rec_b1, rec_b2 in (
    ("Pre-trained", rec_ori_b1, rec_ori_b2),
    ("Fine-tuned", rec_tuned_b1, rec_tuned_b2),
):
    ssim_b1 = ssim(rec_b1, batch1)
    ssim_b2 = ssim(rec_b2, batch2)
    # .item() prints plain floats rather than tensor reprs, matching the
    # LPIPS cell's output format.
    print(f"\n{model_label} model SSIM:")
    print("  Batch 1:", ssim_b1.item())
    print("  Batch 2:", ssim_b2.item())

In [None]:
def normalize_batch(batch):
    """Prepare an image batch for the LPIPS metric.

    ``LearnedPerceptualImagePatchSimilarity`` with its default
    ``normalize=False`` expects inputs in [-1, 1]. Batches produced by
    ``transform`` are already in that range, so no rescaling is needed.
    Model reconstructions can overshoot slightly, and torchmetrics validates
    the input range, so values are clamped.

    Args:
        batch: Float image tensor normalized to [-1, 1] (reconstructions may
            slightly exceed it).

    Returns:
        ``batch`` clamped to [-1, 1].
    """
    # Do not divide by 255: these are float images in [-1, 1], not 8-bit data.
    # Doing so squashes every pixel into [-1, -0.992].
    return torch.clamp(batch, min=-1.0, max=1.0)


# LPIPS of the pretrained model's reconstructions; both sides go through
# normalize_batch before being scored.
lpips_b1, lpips_b2 = (
    lpips(normalize_batch(rec), normalize_batch(ref))
    for rec, ref in ((rec_ori_b1, batch1), (rec_ori_b2, batch2))
)

print("Lower LPIPS is better")

print("\nPre-trained model LPIPS:")
print(f"  Batch 1: {lpips_b1.item()}")
print(f"  Batch 2: {lpips_b2.item()}")

# Same comparison for the fine-tuned model.
lpips_b1, lpips_b2 = (
    lpips(normalize_batch(rec), normalize_batch(ref))
    for rec, ref in ((rec_tuned_b1, batch1), (rec_tuned_b2, batch2))
)

print("\nFine-tuned model LPIPS:")
print(f"  Batch 1: {lpips_b1.item()}")
print(f"  Batch 2: {lpips_b2.item()}")