In [None]:
# Install the local package at `..` (presumably pymdma) with its `image` extra; the extra
# index URL lets pip resolve torch from PyTorch's CPU wheel index.
%pip install ..[image] --extra-index-url "https://download.pytorch.org/whl/cpu/torch_stable.html"

# Data Preparation

In [None]:
import random
from pathlib import Path

import cv2 as cv
import numpy as np
from torchvision.datasets import Imagenette
from torchvision.transforms import functional as F

# The fallback below assumes Imagenette raises RuntimeError when download=True and
# the data is already on disk; in that case the local copy is loaded instead.
try:
    imagenette_val = Imagenette(root="data/", download=True, split="val", size="320px")
except RuntimeError:
    print("Dataset already downloaded")
    imagenette_val = Imagenette(root="data/", download=False, split="val", size="320px")

N_IMAGES = 5  # number of validation images used throughout the input-validation demos

random.seed(10)
# Sample indices instead of decoding every image first. CPython's random.sample chooses
# positions from len(population) and k only, so this selects the same images as sampling
# the fully decoded list while loading just N_IMAGES of them.
sample_idx = random.sample(range(len(imagenette_val)), k=N_IMAGES)
dataset = [imagenette_val[i][0] for i in sample_idx]  # keep the image, drop the label


def generate_distortions_dataset(dataset):
    """Return each image followed by four distorted copies, all converted to numpy arrays.

    For every input image the output contains, in this order: the untouched image,
    brightness factor 2, gamma 2, saturation factor 0.1, and a Gaussian blur with
    kernel size 21.
    """
    variants = (
        lambda im: im,
        lambda im: F.adjust_brightness(im, 2),
        lambda im: F.adjust_gamma(im, 2),
        lambda im: F.adjust_saturation(im, 0.1),
        lambda im: F.gaussian_blur(im, 21),
    )
    return [np.asarray(apply(img)) for img in dataset for apply in variants]


images = generate_distortions_dataset(dataset)

In [None]:
import matplotlib.pyplot as plt


def plot_instances_score(images: list[np.ndarray], metric: str, scores: list[float], n_cols: int = 5):
    """Display images in a grid, titling each one with its metric score.

    Args:
        images: images to display, laid out in row-major order.
        metric: metric name used as the title prefix.
        scores: one score per image, aligned with ``images``.
        n_cols: number of grid columns.
    """
    # Round up (keeping at least one row) so a partial last row is still shown and
    # fewer than n_cols images never request a zero-row grid.
    n_rows = max(1, (len(images) + n_cols - 1) // n_cols)
    # squeeze=False keeps `axs` 2-D even for a 1x1 grid, so `.flat` always works.
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols * 4, n_rows * 3), squeeze=False)
    for ax, image, score in zip(axs.flat, images, scores):
        ax.imshow(image)
        ax.set_title(f"{metric}: {score:.2f}")
        ax.axis("off")
        ax.set_aspect("auto")  # after imshow, which otherwise forces an equal aspect
    # Blank out the unused cells of a partial last row.
    for ax in axs.ravel()[len(images):]:
        ax.axis("off")
    plt.show()

# Input Validation

In the image modality, the pymdma package provides two types of input validation:
- **no-reference**: The image is validated without any reference image.
- **reference**: The image is validated with a reference image.

This section will demonstrate how to use the input validation functions from both types.

## No Reference Evaluation

In [None]:
from pymdma.image.measures.input_val import CLIPIQA

# Configure CLIP-IQA with img_size=320 and score every image (no reference image needed).
clip = CLIPIQA(img_size=320)  # instantiate the metric class
clip_result = clip.compute(images)

# `value` unpacks into dataset-level and instance-level results; only per-image scores are plotted.
_dataset_level, instance_level = clip_result.value
plot_instances_score(images, "CLIPIQA", instance_level, n_cols=5)

In [None]:
from pymdma.image.measures.input_val import Tenengrad

# Tenengrad is used as a sharpness score and is applied directly to the RGB images.
tenengrad = Tenengrad()
sharpness = tenengrad.compute(images)
_dataset_level, instance_level = sharpness.value  # keep the per-image scores for plotting
plot_instances_score(images, "Sharpness", instance_level)

In [None]:
from pymdma.image.measures.input_val import Brightness

# Score image brightness; the result object is reused by the plotting cells below.
brightness = Brightness()
brightness_result = brightness.compute(images)

_dataset_level, instance_level = brightness_result.value  # per-image scores for the grid
plot_instances_score(images, "Brightness", instance_level)

In [None]:
from pymdma.image.measures.input_val import Colorfulness

# Score image colorfulness; the result object is reused by the plotting cells below.
colorfulness = Colorfulness()
colorfulness_result = colorfulness.compute(images)

_dataset_level, instance_level = colorfulness_result.value  # per-image scores for the grid
plot_instances_score(images, "Colorfulness", instance_level)

### Plotting the metric results

We provide a simple method in the `MetricResult` class to easily plot the results of the metrics. The method `plot()` will plot the results of the metrics in the format specified by the `plot_params` attribute in the `MetricResult` class. The `plot_params` attribute is a dictionary that contains the parameters to be used in the plot. If this attribute is not set, the method will default to a bar plot.

You can provide a title for the plot when calling this method, as well as an axis in which you wish to plot the results (helpful when plotting multiple metrics in the same figure). In addition, you can provide a set of `plot_params` to be used directly by matplotlib's plotting functions.

> **Note**: You also have access to the values of the metrics via the `value` attribute in the `MetricResult` class (a pair of dataset-level and instance-level results). You can use these values to plot the results with your own plotting functions.


In [None]:
import matplotlib.pyplot as plt

# Call MetricResult.plot() with only a title; other plotting options are left at their defaults.
brightness_result.plot("Brightness")  # plot the results from the result object
plt.show()

In [None]:
# Draw two result plots side by side by passing an explicit axis to MetricResult.plot().
fig, axs = plt.subplots(1, 2, figsize=(10, 4))  # create a figure with two subplots
brightness_result.plot("Brightness", bins=5, ax=axs[0])  # plot the Brightness histogram on the first subplot
colorfulness_result.plot("Colorfulness", bins=5, ax=axs[1])  # plot the Colorfulness histogram on the second subplot
plt.show()

## Full Reference Evaluation

In [None]:
import numpy as np

N_DISTORTIONS = 5  # distorted variants generated per reference image

# reference[i] is the clean image that the i-th distorted image is compared against,
# so each source image is converted once and that array is repeated N_DISTORTIONS times.
reference = [arr for arr in map(np.asarray, dataset) for _ in range(N_DISTORTIONS)]


def generate_full_ref_dataset(dataset, seed=None, n_distortions=None):
    """Add Gaussian noise of increasing strength to each image.

    For every input image, ``n_distortions`` noisy copies are produced whose noise
    standard deviations are evenly spaced between 1 and 10.

    Args:
        dataset: iterable of images convertible with ``np.asarray`` (assumed to hold
            uint8 pixel values in [0, 255]).
        seed: optional seed for the noise generator; pass an int for reproducible noise.
        n_distortions: copies per image; defaults to the notebook-level ``N_DISTORTIONS``.

    Returns:
        List of uint8 arrays, ``n_distortions`` per input image, in input order.
    """
    if n_distortions is None:
        n_distortions = N_DISTORTIONS
    rng = np.random.default_rng(seed)
    distorted = []
    for img in dataset:
        img = np.asarray(img)
        for std in np.linspace(1, 10, n_distortions):
            noisy = img + rng.normal(0, std, img.shape)
            # Clip before casting: out-of-range floats have no valid uint8 value and
            # typically wrap (e.g. -1 -> 255), producing salt-and-pepper artefacts.
            distorted.append(np.clip(noisy, 0, 255).astype(np.uint8))
    return distorted


# Noisy counterparts of `dataset`: N_DISTORTIONS per image, in the same order as `reference`.
distorted = generate_full_ref_dataset(dataset)

In [None]:
from pymdma.image.measures.input_val import PSNR

psnr = PSNR()
psnr_result = psnr.compute(reference, distorted)
_, instance_level = psnr_result.value  # per-pair PSNR scores, aligned with `distorted`

# One figure per reference image, showing its N_DISTORTIONS noisy variants in one row.
for i in range(0, len(instance_level), N_DISTORTIONS):
    plot_instances_score(
        distorted[i : i + N_DISTORTIONS],
        "PSNR",
        instance_level[i : i + N_DISTORTIONS],
        n_cols=N_DISTORTIONS,  # tie the grid width to the number of variants per image
    )

In [None]:
from pymdma.image.measures.input_val import MSSIM


def generate_occlusion_dataset(dataset, n_distortions=None):
    """Occlude each image with filled black rectangles of increasing size.

    For every input image, ``n_distortions`` copies are produced (default: the
    notebook-level ``N_DISTORTIONS``). Each copy gets a filled black rectangle from
    pixel (20, 20) to (width // s, height // s), with s going from
    ``n_distortions + 1`` down to 2, so the occluded area grows across the copies.

    Returns:
        List of uint8 arrays, ``n_distortions`` per input image, in input order.
    """
    if n_distortions is None:
        n_distortions = N_DISTORTIONS
    occluded = []
    for img in dataset:
        img = np.asarray(img)
        height, width = img.shape[:2]
        for size in range(n_distortions + 1, 1, -1):
            # cv.rectangle draws in place, so draw on a fresh copy of the reference image
            # (thickness -1 fills the rectangle).
            dst = cv.rectangle(img.copy(), (20, 20), (width // size, height // size), (0, 0, 0), -1)
            occluded.append(dst.astype(np.uint8))
    return occluded


# Use a distinct name so the Gaussian-noise `distorted` list scored by the PSNR cell is
# not overwritten if that cell is re-run after this one.
occluded = generate_occlusion_dataset(dataset)

mssim = MSSIM(kernel_size=15)
mssim_result = mssim.compute(reference, occluded)
_, instance_level = mssim_result.value  # per-pair MSSIM scores, aligned with `occluded`

# One figure per reference image, showing its N_DISTORTIONS occluded variants in one row.
for i in range(0, len(instance_level), N_DISTORTIONS):
    plot_instances_score(
        occluded[i : i + N_DISTORTIONS],
        "MSSIM",
        instance_level[i : i + N_DISTORTIONS],
        n_cols=N_DISTORTIONS,
    )

# Synthetic Validation

The automatic evaluation of synthetically generated images is a common practice in the field of generative AI, and is crucial for the assessment of the quality of large synthetic datasets. This is usually done by comparing the synthetic images to a set of reference images by considering the similarity between the distributions of the two sets. In this section, we will demonstrate how to use the `pymdma` package to evaluate the quality of synthetic images.

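> **Warning**: The next cells download the [CIFAKE](https://www.kaggle.com/datasets/birdy654/cifake-real-and-ai-generated-synthetic-images) dataset automatically with `kagglehub`, using a cache folder under `data/`. If the automatic download is not possible in your environment, download the dataset manually from the link, extract the files to the `data` folder in the root directory of this notebook, and set `cifake_path` to the extracted folder.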

In [None]:
# Install kagglehub, which the next cell uses to download the CIFAKE dataset from Kaggle.
%pip install kagglehub

In [None]:
import kagglehub

# Point kagglehub's default cache folder at the local data directory before downloading.
kagglehub.config.DEFAULT_CACHE_FOLDER = Path("data/.kagglehub")
cifake_path = Path(kagglehub.dataset_download("birdy654/cifake-real-and-ai-generated-synthetic-images"))
print("Downloaded CIFake dataset to ", str(cifake_path))

In [None]:
import random

from pymdma.image.models.features import ExtractorFactory

N_SAMPLES = 5000  # images drawn from each of the REAL and FAKE test folders

random.seed(10)

# Derive a new name rather than reassigning `cifake_path`, so re-running this cell
# does not keep appending "test" to the path.
cifake_test_path = cifake_path / "test"
test_images_ref = cifake_test_path / "REAL"  # real images
test_images_synth = cifake_test_path / "FAKE"  # synthetic images

# iterdir() order depends on the filesystem; sort so the seeded sample is reproducible.
images_ref = random.sample(sorted(p for p in test_images_ref.iterdir() if p.is_file()), N_SAMPLES)
images_synth = random.sample(sorted(p for p in test_images_synth.iterdir() if p.is_file()), N_SAMPLES)

extractor = ExtractorFactory.model_from_name(name="dino_vits8")
ref_features = extractor.extract_features_from_files(images_ref)
synth_features = extractor.extract_features_from_files(images_synth)

print("Reference features shape:", ref_features.shape)
print("Synthetic features shape:", synth_features.shape)

In [None]:
from umap import UMAP

# Fit the 2-D projection on the real features only, then map the synthetic features
# into that same embedding so both sets share one coordinate system.
umap = UMAP(n_components=2, random_state=10, n_jobs=1)
real_feats_2d = umap.fit_transform(ref_features)
fake_feats_2d = umap.transform(synth_features)

fig, ax = plt.subplots(figsize=(10, 10))
for feats_2d, label in ((real_feats_2d, "Real Samples"), (fake_feats_2d, "Fake Samples")):
    ax.scatter(feats_2d[:, 0], feats_2d[:, 1], s=3, label=label)
ax.set_title("UMAP Features Visualization | Real vs Synthetic")
ax.legend()
plt.show()

In [None]:
from pymdma.image.measures.synthesis_val import ImprovedPrecision, ImprovedRecall

# Both metrics are configured with k=5 and compare the real (reference) features
# against the synthetic ones.
ip = ImprovedPrecision(k=5)
ir = ImprovedRecall(k=5)

ip_result, ir_result = [metric.compute(ref_features, synth_features) for metric in (ip, ir)]

# Each `value` unpacks into a dataset-level score and per-sample scores.
(precision_dataset, precision_instance), (recall_dataset, recall_instance) = ip_result.value, ir_result.value

print(f"Precision: {precision_dataset:.2f} | Recall: {recall_dataset:.2f}")
print(f"Precision: {precision_instance[:20]} | Recall: {recall_instance[:20]}")

In [None]:
def plot_instances_grid(images: list[np.ndarray], n_cols: int = 25):
    """Tile images into a gap-free thumbnail grid and return the figure.

    Args:
        images: images to tile, laid out in row-major order.
        n_cols: number of grid columns; each cell is 1x1 inch.

    Returns:
        The matplotlib Figure, so the caller can add a title before showing it.
    """
    # Round up (keeping at least one row) so a partial last row is still shown and
    # small or empty inputs never request a zero-row grid.
    n_rows = max(1, (len(images) + n_cols - 1) // n_cols)
    # squeeze=False keeps `axs` 2-D even for a single row or column, so `.flat` always works.
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols * 1, n_rows * 1), squeeze=False)
    fig.subplots_adjust(hspace=0, wspace=0)
    for ax, image in zip(axs.flat, images):
        ax.imshow(image)
        ax.axis("off")  # hides ticks, labels and spines
        ax.set_aspect("auto")  # after imshow, which otherwise forces an equal aspect
    # Blank out the unused cells of a partial last row.
    for ax in axs.ravel()[len(images):]:
        ax.axis("off")
    return fig

In [None]:
import random

from PIL import Image

N_GRID_SAMPLES = 200  # thumbnails per grid (8 rows of 25)


def _load_image(path):
    """Read an image file into a numpy array and close the file handle afterwards."""
    with Image.open(path) as img:
        return np.asarray(img)


random.seed(12)
precision_instance = np.asarray(precision_instance)
# NOTE(review): assumes per-sample precision is 1 for "precise" samples and below 1
# otherwise — confirm against ImprovedPrecision's instance-level output.
imprecise_idx = np.flatnonzero(precision_instance < 1)
precise_idx = np.flatnonzero(precision_instance >= 1)

# random.sample raises ValueError when k exceeds the population, so cap k at the
# group size (identical selection whenever a group has at least N_GRID_SAMPLES members).
precise_sample_idx = random.sample(list(precise_idx), min(N_GRID_SAMPLES, len(precise_idx)))
imprecise_sample_idx = random.sample(list(imprecise_idx), min(N_GRID_SAMPLES, len(imprecise_idx)))
precise_samples = [_load_image(images_synth[i]) for i in precise_sample_idx]
imprecise_samples = [_load_image(images_synth[i]) for i in imprecise_sample_idx]

precise_fig = plot_instances_grid(precise_samples, n_cols=25)
precise_fig.suptitle("CIFAKE Precise samples", fontsize=16)
plt.show()

imprecise_fig = plot_instances_grid(imprecise_samples, n_cols=25)
imprecise_fig.suptitle("CIFAKE Imprecise samples", fontsize=16)
plt.show()

In [None]:
from pymdma.image.measures.synthesis_val import GIQA

# GIQA scores the synthetic features against the real (reference) features.
giqa = GIQA()
giqa_result = giqa.compute(ref_features, synth_features)

# Unpack the dataset-level score and the per-sample scores, then preview both.
giqa_dataset, giqa_instance = giqa_result.value
print(f"Dataset level: {giqa_dataset:.2f}")
print(f"Instance level: {giqa_instance[:40]}")

In [None]:
import matplotlib.pyplot as plt

giqa_result.plot("GIQA", bins=50)
# Label the current axes (the ones plot() left active).
ax = plt.gca()
ax.set_xlabel("Score")
ax.set_ylabel("Frequency")
plt.show()

In [None]:
giqa_order = np.argsort(giqa_instance)  # sample indices from lowest to highest GIQA score

# Highest-scoring 200 synthetic samples, best first.
best_idx = giqa_order[::-1][:200]
best_samples = [np.asarray(Image.open(images_synth[i])) for i in best_idx]
best_fig = plot_instances_grid(best_samples, n_cols=25)
best_fig.suptitle("CIFAKE Best samples", fontsize=16)
plt.show()

# Lowest-scoring 200 synthetic samples, worst first.
worst_idx = giqa_order[:200]
worst_samples = [np.asarray(Image.open(images_synth[i])) for i in worst_idx]
worst_fig = plot_instances_grid(worst_samples, n_cols=25)
worst_fig.suptitle("CIFAKE Worst samples", fontsize=16)
plt.show()