# Latent Space Optimization Results Metrics

Compute and compare metrics for latent space optimization results.

## Setup

In [None]:
from pathlib import Path

from torch.utils.data import Dataset
from PIL import Image

class ImgDataset(Dataset):
    """
    Dataset over every PNG image stored for one results version.

    Files are collected from
    ``../results/<version>/data/samples/iter_*/<subdir>/*.png`` (relative to
    the current working directory) and served in sorted path order.
    """

    def __init__(self, version, subdir="img_opt", transform=None):
        """
        Args:
            version (str): Name of the run directory under ``../results``.
            subdir (str): Folder inside each ``iter_*`` directory to read PNGs from.
            transform (Optional[Any]): Callable applied to each loaded PIL image;
                when ``None`` the PIL image is returned unchanged.
        """
        results_dir = Path(f"../results/{version}")
        self.root = results_dir.expanduser().resolve()

        # Gather the matching files first, then order them deterministically
        pattern = f"data/samples/iter_*/{subdir}/*.png"
        matches = list(self.root.glob(pattern))
        matches.sort()
        self.files = matches

        self.transform = transform

    def __len__(self):
        """Return the number of image files that were found."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load the image at position ``idx`` as RGB and apply the transform, if any."""
        rgb_image = Image.open(self.files[idx]).convert("RGB")
        if self.transform is None:
            return rgb_image
        return self.transform(rgb_image)

## Fréchet Inception Distance (FID)

In [None]:
import yaml

import torch
from torchvision import transforms

from src.metrics.fid import FIDScore

def get_fid_score(version):
    """
    Compute the Fréchet Inception Distance (FID) score for a given version.

    The optimized images stored under ``../results/<version>`` are scored
    against precomputed real-image Inception statistics loaded from
    ``../data/ffhq/inception_stats``.

    Args:
        version (str): the version identifier for the model.
    Returns:
        The FID score as returned by ``FIDScore.compute_score_from_data``
        (presumably a float -- confirm against ``src.metrics.fid``).
    """

    # Load hparams yaml; the context manager closes the file handle right away
    with open(f"../results/{version}/hparams.yaml", 'r') as hparams_file:
        hparams = yaml.safe_load(hparams_file)

    # Derive min and max property range (that has not been seen during optimization)
    opt_min = int(hparams['max_property_value'])
    # Hard-coded upper bound; a matching statistics file must exist for it.
    # NOTE(review): presumably the largest property value in the data -- confirm.
    opt_max = 5

    # Derive image size
    img_size = 512 if version.startswith("ctrloralter") else 256

    # Load optimized images as dataset
    img_opt_dataset = ImgDataset(
        version=version,
        subdir="img_opt",
        transform=transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
        ])
    )

    # Initialize FID instance
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    fid_instance = FIDScore(img_size=img_size, device=device, num_workers=0)

    # Load real statistics
    fid_instance.load_real_stats(f"../data/ffhq/inception_stats/size_{img_size}_smile_{opt_min}_{opt_max}.pt")

    # Compute FID score for the optimized images
    fid_score = fid_instance.compute_score_from_data(img_opt_dataset)

    return fid_score

In [None]:
# Compute the FID score for the "ctrloralter_gbo_23" run
get_fid_score(version="ctrloralter_gbo_23")

## Perceptual Quality (LPIPS)

In [None]:
from taming.modules.losses.lpips import LPIPS

# Initialize a single LPIPS instance in eval mode; get_lpips_score reuses this
# module-level object on every call instead of constructing a new one
lpips = LPIPS().eval()

def get_lpips_score(version):
    """
    Compute the Learned Perceptual Image Patch Similarity (LPIPS) score for a given version.

    Optimized images (``img_opt``) are paired with original images
    (``img_orig``) by their position in the sorted file lists, and the
    per-pair LPIPS values are averaged.

    Args:
        version (str): the version identifier for the model.
    Returns:
        float: LPIPS score averaged over all image pairs.
    Raises:
        ValueError: if no images are found, or if the number of optimized
            and original images differs.
    """
    # Move LPIPS to the appropriate device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    lpips.to(device)

    # Derive image size
    img_size = 512 if version.startswith("ctrloralter") else 256

    # Both image sets go through the same preprocessing
    # NOTE(review): ToTensor yields values in [0, 1]; confirm this matches the
    # input range the LPIPS module expects.
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
    ])

    # Load optimized images as dataset
    img_opt_dataset = ImgDataset(
        version=version,
        subdir="img_opt",
        transform=transform,
    )

    # Load original images as dataset
    img_orig_dataset = ImgDataset(
        version=version,
        subdir="img_orig",
        transform=transform,
    )

    # Fail early with a clear message instead of an opaque torch error
    num_opt = len(img_opt_dataset)
    num_orig = len(img_orig_dataset)
    if num_opt == 0 or num_opt != num_orig:
        raise ValueError(
            f"Expected the same non-zero number of optimized and original images "
            f"for version '{version}', found {num_opt} optimized and {num_orig} original."
        )

    # Convert datasets to tensors
    img_opt = torch.stack(list(img_opt_dataset), dim=0).to(device)
    img_orig = torch.stack(list(img_orig_dataset), dim=0).to(device)

    # Compute LPIPS score for the optimized images; this is evaluation only,
    # so gradient tracking is disabled
    with torch.no_grad():
        lpips_score = lpips(img_opt, img_orig).mean().cpu().item()

    return lpips_score

In [None]:
# Compute the LPIPS score for the "ctrloralter_gbo_23" run
get_lpips_score(version="ctrloralter_gbo_23")