# Latent Space Optimization Results Metrics

Compute and compare metrics for latent space optimization results.

## Setup

### Results Directory

In [None]:
from pathlib import Path

BASE_DIR = Path("../results").expanduser().resolve()

def get_result_dir(version: str, seed: int, base_dir=None) -> Path:
    """
    Return the result directory of the run identified by version and seed.

    Allowed directory names (directly under the base directory):
      <version>_<seed>
      <version>_<seed>_<anything>

    Args:
        version (str): Version identifier used as the directory-name prefix.
        seed (int): Seed of the run.
        base_dir (Optional[Path or str]): Directory to search in. Defaults to
            the module-level BASE_DIR.
    Returns:
        Path: The exact-match directory if it exists, otherwise the
            lexicographically last directory carrying a suffix.
    Raises:
        FileNotFoundError: If no matching directory exists.
    """
    root = BASE_DIR if base_dir is None else Path(base_dir)

    # exact match first (no job-id)
    exact = root / f"{version}_{seed}"
    if exact.is_dir():
        return exact

    # wildcard for any trailing underscore / job-id; filter explicitly for
    # directories because a trailing "/" in a glob pattern only restricts the
    # result to directories on Python >= 3.11
    candidates = sorted(
        path for path in root.glob(f"{version}_{seed}_*") if path.is_dir()
    )
    if not candidates:
        raise FileNotFoundError(f"No log found for seed {seed} under {root}")

    # sorted() orders by path, so this is the lexicographically last suffix
    # (which equals the newest run only if suffixes sort chronologically)
    return candidates[-1]

### Image Dataset

In [None]:
from pathlib import Path

from torch.utils.data import Dataset
from PIL import Image

class ImgDataset(Dataset):
    """
    Dataset over the PNG images stored inside a run's result directory.
    """

    def __init__(self, version, seed, subdir="img_opt", transform=None):
        """
        Args:
            version (str): Version identifier of the run (forwarded to get_result_dir).
            seed (int): Seed of the run (forwarded to get_result_dir).
            subdir (str): Folder inside every data/samples/iter_* directory that holds the images.
            transform (Optional[Any]): Callable applied to each loaded PIL image, if given.
        """
        self.transform = transform
        self.root = get_result_dir(version, seed=seed)

        # Collect the images of all iterations in a deterministic (sorted) order
        png_pattern = f"data/samples/iter_*/{subdir}/*.png"
        self.files = sorted(self.root.glob(png_pattern))

    def __len__(self):
        """Number of image files found."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load image `idx` as RGB and return it, transformed if a transform was given."""
        rgb = Image.open(self.files[idx]).convert("RGB")
        if self.transform is None:
            return rgb
        return self.transform(rgb)

## Fréchet Inception Distance (FID)

In [None]:
import yaml

import numpy as np
import torch
from torchvision import transforms

from src.metrics.fid import FIDScore

def get_fid_score(version, seeds=(42, 43, 44)):
    """
    Compute the Fréchet Inception Distance (FID) score for a given version.

    One FID score is computed per seed from the images in that run's
    "img_opt" folders, against pre-computed real-image statistics.

    Args:
        version (str): the version identifier for the model.
        seeds (iterable[int]): seeds of the runs to aggregate over.
    Returns:
        float: The mean FID score across seeds.
        float: The standard deviation of the FID score across seeds
            (numpy default, i.e. population std).
    """
    # Seed-independent settings are derived once, outside the loop
    img_size = 512 if version.startswith("ctrloralter") else 256
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    preprocess = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    scores = []
    for seed in seeds:
        # Get the result directory for the given version and seed
        result_dir = get_result_dir(version, seed)

        # Load hparams yaml (context manager so the file handle is closed)
        with open(result_dir / "hparams.yaml", 'r') as f:
            hparams = yaml.safe_load(f)

        # Derive min and max property range (that has not been seen during optimization)
        opt_min = int(hparams['max_property_value'])
        opt_max = 5

        # Load optimized images as dataset
        img_opt_dataset = ImgDataset(
            version=version,
            seed=seed,
            subdir="img_opt",
            transform=preprocess
        )

        # Initialize FID instance (a fresh one per seed, as the real statistics depend on opt_min)
        fid_instance = FIDScore(img_size=img_size, device=device, num_workers=0)

        # Load real statistics
        fid_instance.load_real_stats(f"../data/ffhq/inception_stats/size_{img_size}_smile_{opt_min}_{opt_max}.pt")

        # Compute FID score for the optimized images
        fid_score = fid_instance.compute_score_from_data(img_opt_dataset)

        scores.append(float(fid_score))

    return np.mean(scores), np.std(scores)

In [None]:
# Example: FID mean and standard deviation over seeds 42, 43 and 44 for one version.
get_fid_score(version="ex1_sd35f_dngo_train_trustconstr", seeds=[42, 43, 44])

## Perceptual Quality (LPIPS)

In [None]:
import yaml

import torch
from torchvision import transforms

from taming.modules.losses.lpips import LPIPS

# Initialize LPIPS instance (module-level, reused by every call to get_lpips_score)
lpips = LPIPS().eval()

def get_lpips_score(version, seeds=(42, 43, 44)):
    """
    Compute the Learned Perceptual Image Patch Similarity (LPIPS) score for a given version.

    For every seed, the images in "img_opt" are compared with the images in
    "img_orig"; pairs are formed by position in the sorted file lists.

    Args:
        version (str): the version identifier for the model.
        seeds (iterable[int]): seeds of the runs to aggregate over.
    Returns:
        float: Mean LPIPS score across seeds.
        float: Standard deviation of the LPIPS score across seeds.
    Raises:
        ValueError: If a run has a different number of optimized and original images.
    """
    # Imported locally so this cell does not rely on numpy having been imported by another cell
    import numpy as np

    # Move LPIPS to the appropriate device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    lpips.to(device)

    # Seed-independent preprocessing, shared by both datasets
    img_size = 512 if version.startswith("ctrloralter") else 256
    preprocess = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    scores = []
    for seed in seeds:
        # Load optimized images as dataset
        img_opt_dataset = ImgDataset(
            version=version,
            seed=seed,
            subdir="img_opt",
            transform=preprocess
        )

        # Load original images as dataset
        img_orig_dataset = ImgDataset(
            version=version,
            seed=seed,
            subdir="img_orig",
            transform=preprocess
        )

        # Images are paired by index, so both folders must hold the same number of files
        if len(img_opt_dataset) != len(img_orig_dataset):
            raise ValueError(
                f"Seed {seed}: found {len(img_opt_dataset)} optimized but "
                f"{len(img_orig_dataset)} original images"
            )

        # Convert datasets to tensors
        opt_batch = torch.stack(list(img_opt_dataset), dim=0).to(device)
        orig_batch = torch.stack(list(img_orig_dataset), dim=0).to(device)

        # Compute LPIPS score for the optimized images (evaluation only, no gradients needed)
        with torch.no_grad():
            lpips_score = lpips(opt_batch, orig_batch).mean().cpu().item()
        scores.append(float(lpips_score))

    return np.mean(scores), np.std(scores)

In [None]:
# Example: LPIPS mean and standard deviation over seeds 42, 43 and 44 for one version.
get_lpips_score(version="ex1_sd35f_dngo_train_trustconstr", seeds=[42, 43, 44])

## TopK

In [None]:
import numpy as np

def get_top_k(k, version, seeds=(42, 43, 44), iteration_max=500):
    """
    Compute the Top-K smile score (mean ± std over seeds).

    For each seed the K-th best value among the first `iteration_max` entries
    of "opt_point_properties" is taken; if fewer than K entries exist, the
    worst available entry is used instead.

    Args:
        k (int) : K in “Top-K”; must be >= 1.
        version (str) : Model/version identifier.
        seeds (iterable[int]) : Random seeds to aggregate over.
        iteration_max (int) : Max evaluations to consider.
    Returns:
        tuple(float, float): (mean_topk, std_topk) across the given seeds.
    Raises:
        ValueError: If k < 1 or a run contains no scores.
    """
    # k <= 0 would turn into index -1 below and silently pick the worst score
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k}")

    topk_vals = []

    for seed in seeds:
        # Load the result file
        # NOTE: allow_pickle=True can execute arbitrary code on load; only use trusted result files
        result_file = get_result_dir(version, seed) / "results.npz"
        with np.load(result_file, allow_pickle=True) as results:
            # Smile scores of the first `iteration_max` evaluations
            scores = results["opt_point_properties"][:iteration_max]

        if len(scores) == 0:
            raise ValueError(f"No scores found in {result_file}")

        # Sort descending and pick the K-th best (account for short runs)
        k_idx = min(k, len(scores)) - 1
        topk = np.sort(scores)[::-1][k_idx]
        topk_vals.append(topk)

    mean_topk = float(np.mean(topk_vals))
    std_topk = float(np.std(topk_vals))

    return mean_topk, std_topk

In [None]:
# Example: Top-10 score over seeds 42, 43 and 44, restricted to the first 100 evaluations.
get_top_k(k=10, version="ex1_sd35_gbo_train_pca", seeds=[42, 43, 44], iteration_max=100)

## Mean Smile Score

In [None]:
import numpy as np

def get_smile_score(version, seeds=(42, 43, 44), iteration_max=500):
    """
    Compute the smile score mean and std for a given version and seeds.

    Per seed, the first `iteration_max` entries of "opt_point_properties" are
    averaged along axis 0; mean and std are then taken over these per-seed values.

    Args:
        version (str): the version identifier for the model.
        seeds (iterable[int]): seeds to quantify variability.
        iteration_max (int): Maximum number of iterations to consider for the smile score.
    Returns:
        float: Mean smile score
        float: Std smile score
    """
    # Load the results for the specified version and seeds
    scores = []
    for seed in seeds:
        # Load results dictionary (context manager closes the underlying .npz file)
        # NOTE: allow_pickle=True can execute arbitrary code on load; only use trusted result files
        result_file = get_result_dir(version, seed) / "results.npz"
        with np.load(result_file, allow_pickle=True) as results:
            # Get smile scores, limited to the first `iteration_max` iterations
            opt_point_properties = results['opt_point_properties'][:iteration_max]

        scores.append(opt_point_properties.mean(axis=0))

    # Compute mean and std
    mean_score = np.mean(scores)
    std_score = np.std(scores)

    return mean_score, std_score

In [None]:
# Example: mean smile score over seeds 42, 43 and 44, restricted to the first 100 iterations.
get_smile_score(version="ex1_sd35f_dngo_train_trustconstr", seeds=[42, 43, 44], iteration_max=100)

## Runtime

In [None]:
import re
from pathlib import Path
import numpy as np

def get_log_time(version, operation, seeds=(42, 43, 44)):
    """
    Compute per-run mean time, then the grand mean ± std across runs.

    Each run's main.log is scanned for lines of the form
    "<word> <operation> done in <seconds>s".

    Args:
        version (str): the version identifier for the model.
        operation (str): which timing lines to collect, "train" or "opt".
        seeds (iterable[int]): seeds of the runs to aggregate over.
    Returns:
        float: Mean of the per-run mean times.
        float: Standard deviation of the per-run mean times.
    Raises:
        ValueError: If `operation` is unknown or a log contains no matching line.
    """
    if operation not in ("train", "opt"):
        raise ValueError(f"Unknown operation type: {operation}")

    # Both supported operations share one line format; only the keyword differs
    line_re = re.compile(rf"\b\w+\s+{operation} done in ([\d.]+)s")

    mean_times = []
    for seed in seeds:
        log_path = get_result_dir(version, seed) / "main.log"

        # Stream the log line by line instead of loading it into memory first
        with open(log_path, 'r') as f:
            times = [
                float(match.group(1))
                for match in map(line_re.search, f)
                if match
            ]

        if not times:
            raise ValueError(f"No time found in log for seed {seed}")

        mean_times.append(np.mean(times))

    # Compute mean and std across all seeds
    mean_time = np.mean(mean_times)
    std_time = np.std(mean_times)

    return mean_time, std_time

In [None]:
# Example: mean and standard deviation of the "opt" step time over seeds 42, 43 and 44.
get_log_time(version="ex1_sd35f_dngo_train_lbfgsb", operation="opt", seeds=[42, 43, 44])