In [2]:
%load_ext autoreload
%autoreload 2

import os

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from skimage.metrics import mean_squared_error, structural_similarity as ssim
import torch
import lpips

IMG_DIR = "generated_images_21_ad4"

# Load the images into a numpy array; let's assume they are png
img_paths = [os.path.join(IMG_DIR, f) for f in os.listdir(IMG_DIR) if f.endswith('.png')]
imgs = np.array([np.array(Image.open(p)) for p in img_paths])

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
# Function to compute distances between consecutive images, i.e. image 0->1, 1->2, etc.
# We accept a kwarg specifying the metric to use.
# Acceptable choices are "mse", "lpips", and "ssim".
def compute_distances(images, metric='mse'):
    if metric == 'mse':
        return np.array([mean_squared_error(images[i], images[i + 1]) for i in range(len(images) - 1)])
    elif metric == 'lpips':
        loss_fn = lpips.LPIPS(net='alex')
        images_tensor = torch.tensor(images).permute(0, 3, 1, 2).float() / 255.0  # Convert to [batch, channels, height, width]
        distances = []
        for i in range(len(images) - 1):
            img1 = images_tensor[i].unsqueeze(0)  # Add batch dimension
            img2 = images_tensor[i + 1].unsqueeze(0)  # Add batch dimension
            distance = loss_fn(img1, img2).item()  # Compute LPIPS distance
            distances.append(distance)
        return np.array(distances)
    elif metric == 'ssim':
        return np.array([ssim(images[i], images[i + 1], multichannel=True) for i in range(len(images) - 1)])
    else:
        raise ValueError("Unsupported metric: {}".format(metric))

In [7]:
distances_mse = compute_distances(imgs, metric='mse')
#distances_ssim = compute_distances(imgs, metric='ssim')
distances_lpips = compute_distances(imgs, metric='lpips')

print(f"Mean MSE distance: {distances_mse.mean()}")
#print(f"Mean SSIM distance: {distances_ssim.mean()}")
print(f"Mean LPIPS distance: {distances_lpips.mean()}")

Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]


Downloading: "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth" to /home/mmattb/.cache/torch/hub/checkpoints/alexnet-owt-7be5be79.pth
100%|██████████| 233M/233M [00:21<00:00, 11.3MB/s] 


Loading model from: /home/mmattb/anaconda3/lib/python3.12/site-packages/lpips/weights/v0.1/alex.pth


RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [1, 197, 3, 768, 768]

In [None]:
print(distances_mse)
print(distances_lpips)

plt.plot(distances_lpips, label="lpips")
plt.plot(distances_lpips, label="mse")
plt.legend()
plt.show()