## Evaluate VQGAN

### Set up

In [None]:
import sys
# Make the project's modules under ../src (e.g. `vqgan`, `data`, imported below) importable.
sys.path.append('../src')

In [None]:
import json

import torch
import lightning.pytorch as pl
from torchvision import transforms
from torch.utils.data import DataLoader

from vqgan.model import VQModel
from data import VQVisualNewsDataset

In [None]:
import io

import torchvision
import numpy as np
import matplotlib.pyplot as plt
import requests

from PIL import Image


def download_image(url, *, timeout=30.0):
    """Download an image and open it with PIL.

    Args:
        url: Address of the image.
        timeout: Seconds to wait for the server (connect and read) before
            giving up. Without a timeout ``requests`` can block indefinitely.

    Returns:
        A ``PIL.Image.Image`` backed by the downloaded bytes.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    resp = requests.get(url, timeout=timeout)
    resp.raise_for_status()
    return Image.open(io.BytesIO(resp.content))


def imshow(img, title=None):
    """Show a single channels-first image tensor with matplotlib.

    Args:
        img: 3-D tensor laid out as (C, H, W), in the [-1, 1] range produced
            by ``transforms.Normalize(mean=0.5, std=0.5)``. It may live on any
            device and may require grad; a detached CPU copy is plotted.
        title: Optional plot title.
    """
    # Undo Normalize(0.5, 0.5): [-1, 1] -> [0, 1].
    img = img.detach().cpu() / 2 + 0.5
    # Model outputs may overshoot the range. matplotlib would clip float RGB
    # data to [0, 1] anyway, but only after emitting a warning.
    img = img.clamp(0.0, 1.0)
    plt.imshow(np.transpose(img.numpy(), (1, 2, 0)))
    if title is not None:
        plt.title(title)
    plt.axis('off')

def show_random_batch(data_loader):
    """Plot, as one grid, the first batch from a fresh iterator over ``data_loader``.

    The batch is only random if the loader shuffles. With ``shuffle=False``
    it is always the same first batch. Assumes the loader yields bare image
    tensors (no labels) -- confirm against the dataset.
    """
    first_batch = next(iter(data_loader))
    grid = torchvision.utils.make_grid(first_batch)
    imshow(grid)
    plt.show()


def visualize_model_batch(model, batch):
    """Encode and decode ``batch`` with ``model``, then plot original vs. reconstruction.

    The forward passes run in eval mode with autograd disabled. The model's
    previous training mode and the global grad-enabled state are restored
    afterwards, even if encoding or decoding raises.

    Args:
        model: Model exposing ``encode(batch) -> (quant_states, loss, info)``
            and ``decode(quant_states)``.
        batch: Image batch in the [-1, 1] range expected by ``imshow``.

    Returns:
        The reconstructed batch, detached from the autograd graph.
    """
    was_training = model.training
    # Disable batchnorm/dropout updates for inference.
    model.eval()
    try:
        # Unlike a bare set_grad_enabled(False)/(True) pair, no_grad restores
        # whatever grad mode was active before, including on exceptions.
        with torch.no_grad():
            # Encode to quantized image tokens (loss/info are not needed here).
            quant_states, _, _ = model.encode(batch)
            # Reconstruct the image from the image tokens.
            rec = model.decode(quant_states)
    finally:
        model.train(was_training)

    imshow(torchvision.utils.make_grid(batch), 'Original')
    plt.show()
    imshow(torchvision.utils.make_grid(rec), 'Reconstructed')
    plt.show()
    return rec.detach()


def visualize_model(model, data_loader):
    """Run ``visualize_model_batch`` on the first batch of ``data_loader``.

    The batch is only random if the loader shuffles. The reconstruction is
    not returned.
    """
    visualize_model_batch(model, next(iter(data_loader)))

In [None]:
# Model hyper-parameters, later forwarded to VQModel.load_from_checkpoint as
# keyword arguments.
with open('../src/hparams_vqgan.json', 'r') as hparams_file:
    hparams = json.load(hparams_file)

In [None]:
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 per channel: maps ToTensor's [0, 1] output
        # (for PIL/uint8 inputs) to [-1, 1].
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

In [None]:
# Fixed-order loader (shuffle=False): each new iterator yields the same
# batches in the same order.
test_set = VQVisualNewsDataset('../src/data/visual_news_mini', 'test', transform)
test_loader = DataLoader(
    dataset=test_set,
    batch_size=4,
    shuffle=False,
    num_workers=0,  # load in the main process
    pin_memory=True,
)

In [None]:
# The loader's first two batches, reused for every model evaluated below.
test_iter = iter(test_loader)
batch1, batch2 = next(test_iter), next(test_iter)

### Pre-trained VQGAN Model

Pre-trained checkpoint, trained on ImageNet for 12 epochs (~30k steps).

In [None]:
# NOTE: Only needed for the raw downloaded checkpoint. First convert it with
# process_pretrained_vqgan.py, then upgrade it to the installed Lightning
# version's checkpoint format. Both commands target ../src/pretrained/vqgan.ckpt
# (whether the first script rewrites it in place is not visible here -- confirm).
# The `!` lines are IPython shell escapes and only run inside Jupyter/IPython.
! python ../src/process_pretrained_vqgan.py --pretrained_vqgan ../src/pretrained/vqgan.ckpt
! python -m lightning.pytorch.utilities.upgrade_checkpoint ../src/pretrained/vqgan.ckpt

In [None]:
# Load the pre-trained VQGAN; the JSON hyper-parameters are passed as keyword arguments.
model = VQModel.load_from_checkpoint('../src/pretrained/vqgan.ckpt', **hparams)
# Load local VGG weights into the model's LPIPS component (inferred from the
# method name -- confirm in vqgan.model).
model.init_lpips_from_pretrained('../src/pretrained/vgg.pth')

In [None]:
# Plot reconstructions of both fixed batches; keep them for the metrics section.
rec_ori_b1 = visualize_model_batch(model, batch1)
rec_ori_b2 = visualize_model_batch(model, batch2)

### Load VQGAN Experiment 1

The original (pre-trained) model, fine-tuned on the news dataset for 1 epoch.

In [None]:
# Load the Experiment 1 checkpoint with the same hyper-parameters (replaces `model`).
model = VQModel.load_from_checkpoint('../src/pretrained/exp1.ckpt', **hparams)
# Load local VGG weights into the model's LPIPS component (inferred from the
# method name -- confirm in vqgan.model).
model.init_lpips_from_pretrained('../src/pretrained/vgg.pth')

In [None]:
# Plot Experiment 1 reconstructions of both fixed batches; keep them for the metrics.
rec_exp1_b1 = visualize_model_batch(model, batch1)
rec_exp1_b2 = visualize_model_batch(model, batch2)

### Load VQGAN Experiment 2

The original (pre-trained) model, fine-tuned on the news dataset for 3 epochs.

In [None]:
# Load the Experiment 2 checkpoint with the same hyper-parameters (replaces `model`).
model = VQModel.load_from_checkpoint('../src/pretrained/exp2.ckpt', **hparams)
# Load local VGG weights into the model's LPIPS component (inferred from the
# method name -- confirm in vqgan.model).
model.init_lpips_from_pretrained('../src/pretrained/vgg.pth')

In [None]:
# Plot Experiment 2 reconstructions of both fixed batches; keep them for the metrics.
rec_exp2_b1 = visualize_model_batch(model, batch1)
rec_exp2_b2 = visualize_model_batch(model, batch2)

### Evaluate

- Peak Signal-to-Noise Ratio (PSNR)
- Structural Similarity Index (SSIM)
- Learned Perceptual Image Patch Similarity (LPIPS)

In [None]:
from torchmetrics import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
import torchvision.transforms.functional as TF

In [None]:
# Inputs are normalised to [-1, 1] by transforms.Normalize(mean=0.5, std=0.5),
# and the reconstructions are presumably in the same range. The dynamic range
# is therefore 2.0. data_range=1.0 mis-scales SSIM's stability constants, and
# leaving PSNR's range unset makes it infer max-min from each target batch.
psnr = PeakSignalNoiseRatio(data_range=2.0)
ssim = StructuralSimilarityIndexMeasure(data_range=2.0)
# With torchmetrics' default normalize=False, LPIPS expects inputs in [-1, 1].
lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg')

In [None]:
# PSNR of each model's reconstructions against the two fixed batches.
print('Higher PSNR is better')

# Each row pairs a model label with that model's own reconstructions.
for label, rec_b1, rec_b2 in (
    ('Pre-trained', rec_ori_b1, rec_ori_b2),
    ('Experiment 1', rec_exp1_b1, rec_exp1_b2),
    ('Experiment 2', rec_exp2_b1, rec_exp2_b2),
):
    psnr_b1 = psnr(rec_b1, batch1)
    psnr_b2 = psnr(rec_b2, batch2)
    print(f'\n{label} model PSNR:')
    # .item() prints plain floats, matching the LPIPS report.
    print('  Batch 1:', psnr_b1.item())
    print('  Batch 2:', psnr_b2.item())

In [None]:
# SSIM of each model's reconstructions against the two fixed batches.
print('Higher SSIM is better')

# Each row pairs a model label with that model's own reconstructions, so
# Experiment 2 is scored on rec_exp2_*, not on Experiment 1's outputs.
for label, rec_b1, rec_b2 in (
    ('Pre-trained', rec_ori_b1, rec_ori_b2),
    ('Experiment 1', rec_exp1_b1, rec_exp1_b2),
    ('Experiment 2', rec_exp2_b1, rec_exp2_b2),
):
    ssim_b1 = ssim(rec_b1, batch1)
    ssim_b2 = ssim(rec_b2, batch2)
    print(f'\n{label} model SSIM:')
    # .item() prints plain floats, matching the LPIPS report.
    print('  Batch 1:', ssim_b1.item())
    print('  Batch 2:', ssim_b2.item())

In [None]:
def normalize_batch(batch):
    """Prepare a [-1, 1]-normalised image batch for LPIPS.

    The batches and reconstructions already use the [-1, 1] range produced by
    ``transforms.Normalize(mean=0.5, std=0.5)``, which is what LPIPS expects
    with torchmetrics' default ``normalize=False``. Only clamping is needed,
    because reconstructions can overshoot slightly. The previous version
    mapped to [0, 1] and then divided by 255 again, which squashed every
    image to roughly [-1, -0.992].

    Args:
        batch: Tensor of images scaled to approximately [-1, 1].

    Returns:
        The batch clamped to [-1, 1].
    """
    return torch.clamp(batch, min=-1.0, max=1.0)

# LPIPS of each model's reconstructions against the two fixed batches.
print('Lower LPIPS is better')

for model_label, recs in (
    ('Pre-trained', (rec_ori_b1, rec_ori_b2)),
    ('Experiment 1', (rec_exp1_b1, rec_exp1_b2)),
    ('Experiment 2', (rec_exp2_b1, rec_exp2_b2)),
):
    # Score batch 1 first, then batch 2, each against its own reference.
    lpips_b1, lpips_b2 = (
        lpips(normalize_batch(rec), normalize_batch(ref))
        for rec, ref in zip(recs, (batch1, batch2))
    )
    print(f'\n{model_label} model LPIPS:')
    print('  Batch 1:', lpips_b1.item())
    print('  Batch 2:', lpips_b2.item())