## Imports

In [1]:
# Standard library
import os
import pickle

# Third-party libraries
import cv2
import lpips
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
import torch
from skimage import filters
from skimage.feature import graycomatrix, graycoprops
from skimage.metrics import (
    peak_signal_noise_ratio as psnr,
    structural_similarity as ssim,
 )
from tqdm import tqdm

try:
    import seaborn as sns
except ImportError:
    sns = None

In [2]:
class ImagePairLoader:
    """Separated I/O utilities for iterating and aligning LR/HR pairs."""

    @staticmethod
    def iter_pairs(lr_dir, hr_dir):
        """Yield (lr_filename, hr_filename) pairs in lexicographic order."""
        lr_images = sorted([
            f for f in os.listdir(lr_dir)
            if f.lower().endswith(("png", "jpg", "jpeg"))
        ])
        
        hr_images = sorted([
            f for f in os.listdir(hr_dir)
            if f.lower().endswith(("png", "jpg", "jpeg"))
        ])

        assert len(lr_images) == len(hr_images), \
            "Mismatch in LR/HR image counts"

        for lf, hf in zip(lr_images, hr_images):
            yield lf, hf

    @staticmethod
    def load_and_align(lr_path, hr_path, interp_map=None):
        """Load two images and resize LR to HR size if required.

        If interp_map provided, it should map LR filename to the
        interpolation method string used originally (e.g., 'INTER_CUBIC').
        That method is used for upscaling to match EDA fairness.
        """
        lr = cv2.imread(lr_path)
        hr = cv2.imread(hr_path)

        if lr is None or hr is None:
            raise ValueError(f"Failed reading {lr_path} or {hr_path}")

        if lr.shape[:2] != hr.shape[:2]:
            interp_code = cv2.INTER_LINEAR
            if interp_map is not None:
                fname = os.path.basename(lr_path)
                name = interp_map.get(fname)
                name_to_code = {
                    'INTER_LINEAR': cv2.INTER_LINEAR,
                    'INTER_CUBIC': cv2.INTER_CUBIC,
                    'INTER_AREA': cv2.INTER_AREA,
                    'INTER_LANCZOS4': cv2.INTER_LANCZOS4,
                }
                if name in name_to_code:
                    interp_code = name_to_code[name]
            lr = cv2.resize(
                lr,
                (hr.shape[1], hr.shape[0]),
                interpolation=interp_code
            )

        return lr, hr

## Methods for exploratory data analysis

In [3]:
class ImageDatasetAnalyzer:
    """Utility collection to analyze LR/HR image pairs.

    All methods are static so they can be called without
    instantiation.
    """

    @staticmethod
    def loss_fn():
        """Return (singleton) the loaded LPIPS model.

        Returns
        -------
        lpips.LPIPS
            Initialized LPIPS model instance.
        """
        if not hasattr(ImageDatasetAnalyzer, '_loss_fn'):
            ImageDatasetAnalyzer._loss_fn = lpips.LPIPS(net="alex")

        return ImageDatasetAnalyzer._loss_fn

    @staticmethod
    def lpips_score(lr_img, hr_img):
        """Compute LPIPS between two aligned BGR images.

        Parameters
        ----------
        lr_img : np.ndarray
            Aligned LR image (BGR).
        hr_img : np.ndarray
            HR image (BGR).

        Returns
        -------
        float
            LPIPS value.
        """
        def to_tensor(img):
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
            img = 2 * img - 1
            img = np.transpose(img, (2, 0, 1))
            return torch.from_numpy(img).unsqueeze(0).float()
        
        return ImageDatasetAnalyzer.loss_fn()(
            to_tensor(lr_img),
            to_tensor(hr_img)
        ).item()

    @staticmethod
    def rms_noise(gray):
        """Estimate RMS noise using difference from a Gaussian blur.

        Parameters
        ----------
        img : np.ndarray
            Input BGR image.

        Returns
        -------
        float
            RMS noise estimate.
        """
        blurred = cv2.GaussianBlur(gray, (3, 3), 0)
        diff = gray.astype(np.float32) - blurred.astype(np.float32)

        return float(np.sqrt(np.mean(diff ** 2)))

    @staticmethod
    def laplacian_variance(gray):
        """Variance of Laplacian (sharpness proxy).

        Parameters
        ----------
        img : np.ndarray
            BGR image.

        Returns
        -------
        float
            Variance value.
        """
        return float(cv2.Laplacian(gray, cv2.CV_64F).var())

    @staticmethod
    def additional_quality_metrics(lr_img, hr_img):
        """Compute PSNR and SSIM between HR and LR images.

        Parameters
        ----------
        lr_img : np.ndarray
            Aligned LR image.
        hr_img : np.ndarray
            HR image.

        Returns
        -------
        dict
            {'psnr': float, 'ssim': float}
        """
        return {
            "psnr": psnr(hr_img, lr_img, data_range=255),
            "ssim": ssim(
                hr_img,
                lr_img,
                channel_axis=2,
                data_range=255
            )
        }

    @staticmethod
    def glcm_features(gray, angles=None, levels=64, multi_angle=False):
        """Extract GLCM contrast, homogeneity, correlation.

        Defaults trimmed to 64 levels and one angle (0 rad) to
        reduce memory and time. Set multi_angle=True to use
        the 4 standard angles (0, 45, 90, 135 deg) or pass a
        custom iterable via angles.

        Parameters
        ----------
        gray : np.ndarray
            Grayscale image.
        angles : iterable | None
            Angles in radians. If None uses [0] unless
            multi_angle True (then 4 angles).
        levels : int
            Quantization levels (default 64).
        multi_angle : bool
            Average metrics across four angles when True.

        Returns
        -------
        dict
            glcm_contrast, glcm_homogeneity, glcm_correlation.
        """
        if angles is None:
            if multi_angle:
                angles = (0, np.pi / 4, np.pi / 2, 3 * np.pi / 4)
            else:
                angles = (0,)
        # Quantize to requested levels
        if gray.max() == 0:
            norm = np.zeros_like(gray, dtype=np.uint8)
        else:
            norm = (
                (gray.astype(np.float32) / 255.0) * (levels - 1)
            ).astype(np.uint8)
        glcm = graycomatrix(
            norm,
            [1],
            list(angles),
            levels,
            symmetric=True,
            normed=True
        )
        
        return {
            "glcm_contrast": float(
                graycoprops(glcm, "contrast").mean()
            ),
            "glcm_homogeneity": float(
                graycoprops(glcm, "homogeneity").mean()
            ),
            "glcm_correlation": float(
                graycoprops(glcm, "correlation").mean()
            )
        }

    @staticmethod
    def feature_distribution(img, hsv):
        """Basic stats per BGR channel plus HSV saturation/brightness.

        Parameters
        ----------
        img : np.ndarray
            BGR image.

        Returns
        -------
        dict
            Channel stats plus saturation_mean, brightness_mean.
        """
        results = {}

        for idx, channel in enumerate(cv2.split(img)):
            flat = channel.ravel()
            results[f"ch{idx}_mean"] = float(np.mean(flat))
            results[f"ch{idx}_std"] = float(np.std(flat))
            results[f"ch{idx}_skew"] = float(scipy.stats.skew(flat))
            results[f"ch{idx}_kurt"] = float(scipy.stats.kurtosis(flat))
        
        results["saturation_mean"] = float(np.mean(hsv[:, :, 1]))
        results["brightness_mean"] = float(np.mean(hsv[:, :, 2]))

        return results

    @staticmethod
    def detect_artifacts(img, gray):
        """Detect artifacts: blocking, color noise, ringing.

        Parameters
        ----------
        img : np.ndarray
            BGR image.

        Returns
        -------
        dict
            blocking_score, color_noise, ringing_artifact.
        """
        dct = cv2.dct(np.float32(gray))
        horizontal_blocking = np.mean(np.abs(dct[7::8, :]))
        vertical_blocking = np.mean(np.abs(dct[:, 7::8]))
        blocking_score = float(
            (horizontal_blocking + vertical_blocking) / 2
        )
        blur = cv2.GaussianBlur(img, (5, 5), 0)
        color_noise = float(
            np.mean(np.abs(img.astype(float) - blur.astype(float)))
        )
        edges = cv2.Canny(gray, 100, 200)
        kernel = np.ones((5, 5), np.uint8)
        dilated = cv2.dilate(edges, kernel)
        edge_region = dilated & ~edges
        
        if np.any(edge_region):
            ringing = float(np.std(gray[edge_region]))
        else:
            ringing = 0.0

        return {
            "blocking_score": blocking_score,
            "color_noise": color_noise,
            "ringing_artifact": ringing
        }

In [4]:
class ImagePairMetrics:
    """Container of metrics computed for an LR/HR pair.

    Dataclass avoided to remove type hints as requested.
    """

    def __init__(
        self,
        filename,
        lpips,
        psnr,
        ssim,
        glcm_contrast,
        glcm_homogeneity,
        glcm_correlation,
        rms_noise_lr,
        rms_noise_hr,
        lap_var_lr,
        lap_var_hr,
        blocking_lr,
        blocking_hr,
        color_noise_lr,
        color_noise_hr,
        ringing_lr,
        ringing_hr,
        saturation_mean_lr,
        saturation_mean_hr,
        brightness_mean_lr,
        brightness_mean_hr,
        edge_diff
    ):
        self.filename = filename
        self.lpips = lpips
        self.psnr = psnr
        self.ssim = ssim
        self.glcm_contrast = glcm_contrast
        self.glcm_homogeneity = glcm_homogeneity
        self.glcm_correlation = glcm_correlation
        self.rms_noise_lr = rms_noise_lr
        self.rms_noise_hr = rms_noise_hr
        self.lap_var_lr = lap_var_lr
        self.lap_var_hr = lap_var_hr
        self.blocking_lr = blocking_lr
        self.blocking_hr = blocking_hr
        self.color_noise_lr = color_noise_lr
        self.color_noise_hr = color_noise_hr
        self.ringing_lr = ringing_lr
        self.ringing_hr = ringing_hr
        self.saturation_mean_lr = saturation_mean_lr
        self.saturation_mean_hr = saturation_mean_hr
        self.brightness_mean_lr = brightness_mean_lr
        self.brightness_mean_hr = brightness_mean_hr
        self.edge_diff = edge_diff

    def as_dict(self):
        """Return metrics as a dict for DataFrame conversion."""
        return self.__dict__.copy()

In [5]:
class MetricsAggregator:
    """Orchestrates metric extraction for all image pairs."""

    @staticmethod
    def collect(
        lr_dir,
        hr_dir,
        glcm_multi_angle=False,
        glcm_levels=64,
        interp_map=None,
    ):
        """Compute metrics for each pair and accumulate global visual data.

        Always returns a tuple (rows, global_data) where global_data contains
        the accumulators required to build the global advanced visualization panel.

        Parameters
        ----------
        lr_dir : str
            LR directory.
        hr_dir : str
            HR directory.
        glcm_multi_angle : bool
            Use 4 angles for GLCM if True; else single angle.
        glcm_levels : int
            Quantization levels for GLCM (default 64) used for scalar metrics only.
        interp_map : dict | None
            Mapping from LR filename to interpolation method name for
            consistent upscaling.

        Returns
        -------
        (list[ImagePairMetrics], dict)
            rows: list of per-pair metric objects.
            global_data: dict with accumulated arrays and statistics.
        """
        rows = []

        # Structures for global visualization
        sat_bins = np.linspace(0, 256, 51)  # 50 bins 0-255
        global_data = {
            'count': 0,
            'lr_fft_sum': None,
            'hr_fft_sum': None,
            'grad_hr_sum': None,
            'glcm_sum': None,  # shape (256, 256, 1, 1)
            'sat_lr_counts': np.zeros(len(sat_bins) - 1, dtype=np.float64),
            'sat_hr_counts': np.zeros(len(sat_bins) - 1, dtype=np.float64),
            'sat_bins': sat_bins,
            'noise_means_lr': [],
        }

        for lf, hf in ImagePairLoader.iter_pairs(lr_dir, hr_dir):
            lr_img, hr_img = ImagePairLoader.load_and_align(
                os.path.join(lr_dir, lf),
                os.path.join(hr_dir, hf),
                interp_map=interp_map
            )
            
            gray_lr = cv2.cvtColor(lr_img, cv2.COLOR_BGR2GRAY)
            gray_hr = cv2.cvtColor(hr_img, cv2.COLOR_BGR2GRAY)
            hsv_lr = cv2.cvtColor(lr_img, cv2.COLOR_BGR2HSV)
            hsv_hr = cv2.cvtColor(hr_img, cv2.COLOR_BGR2HSV)

            lpips_val = ImageDatasetAnalyzer.lpips_score(lr_img, hr_img)
            q = ImageDatasetAnalyzer.additional_quality_metrics(lr_img, hr_img)
            glcm = ImageDatasetAnalyzer.glcm_features(
                gray_lr,
                levels=glcm_levels,
                multi_angle=glcm_multi_angle
            )
            fd_lr = ImageDatasetAnalyzer.feature_distribution(lr_img, hsv_lr)
            fd_hr = ImageDatasetAnalyzer.feature_distribution(hr_img, hsv_hr)
            art_lr = ImageDatasetAnalyzer.detect_artifacts(lr_img, gray_lr)
            art_hr = ImageDatasetAnalyzer.detect_artifacts(hr_img, gray_hr)
            rms_lr = ImageDatasetAnalyzer.rms_noise(gray_lr)
            rms_hr = ImageDatasetAnalyzer.rms_noise(gray_hr)
            lap_var_lr = ImageDatasetAnalyzer.laplacian_variance(gray_lr)
            lap_var_hr = ImageDatasetAnalyzer.laplacian_variance(gray_hr)

            lr_edges = filters.sobel(gray_lr)
            hr_edges = filters.sobel(gray_hr)
            edge_diff = float(np.mean(hr_edges) - np.mean(lr_edges))

            rows.append(
                ImagePairMetrics(
                    filename=lf,
                    lpips=lpips_val,
                    psnr=q['psnr'],
                    ssim=q['ssim'],
                    glcm_contrast=glcm['glcm_contrast'],
                    glcm_homogeneity=glcm['glcm_homogeneity'],
                    glcm_correlation=glcm['glcm_correlation'],
                    rms_noise_lr=rms_lr,
                    rms_noise_hr=rms_hr,
                    lap_var_lr=lap_var_lr,
                    lap_var_hr=lap_var_hr,
                    blocking_lr=art_lr['blocking_score'],
                    blocking_hr=art_hr['blocking_score'],
                    color_noise_lr=art_lr['color_noise'],
                    color_noise_hr=art_hr['color_noise'],
                    ringing_lr=art_lr['ringing_artifact'],
                    ringing_hr=art_hr['ringing_artifact'],
                    saturation_mean_lr=fd_lr['saturation_mean'],
                    saturation_mean_hr=fd_hr['saturation_mean'],
                    brightness_mean_lr=fd_lr['brightness_mean'],
                    brightness_mean_hr=fd_hr['brightness_mean'],
                    edge_diff=edge_diff
                )
            )

            # Initialize accumulators if first image
            if global_data['lr_fft_sum'] is None:
                lr_fft_mag = np.abs(np.fft.fftshift(np.fft.fft2(gray_lr)))
                hr_fft_mag = np.abs(np.fft.fftshift(np.fft.fft2(gray_hr)))
                global_data['lr_fft_sum'] = lr_fft_mag.astype(np.float64)
                global_data['hr_fft_sum'] = hr_fft_mag.astype(np.float64)
                sobelx = cv2.Sobel(gray_hr, cv2.CV_64F, 1, 0, ksize=5)
                sobely = cv2.Sobel(gray_hr, cv2.CV_64F, 0, 1, ksize=5)
                grad_mag = np.sqrt(sobelx ** 2 + sobely ** 2)
                global_data['grad_hr_sum'] = grad_mag
                lr_glcm_full = graycomatrix(
                    gray_lr,
                    [1],
                    [0],
                    256,
                    symmetric=True,
                    normed=True
                )
                global_data['glcm_sum'] = lr_glcm_full.astype(np.float64)
            else:
                global_data['lr_fft_sum'] += np.abs(
                    np.fft.fftshift(np.fft.fft2(gray_lr))
                )
                global_data['hr_fft_sum'] += np.abs(
                    np.fft.fftshift(np.fft.fft2(gray_hr))
                )
                sobelx = cv2.Sobel(gray_hr, cv2.CV_64F, 1, 0, ksize=5)
                sobely = cv2.Sobel(gray_hr, cv2.CV_64F, 0, 1, ksize=5)
                grad_mag = np.sqrt(sobelx ** 2 + sobely ** 2)
                global_data['grad_hr_sum'] += grad_mag
                lr_glcm_full = graycomatrix(
                    gray_lr,
                    [1],
                    [0],
                    256,
                    symmetric=True,
                    normed=True
                )
                global_data['glcm_sum'] += lr_glcm_full

            # Saturation histograms (LR & HR)
            sat_lr = hsv_lr[:, :, 1]
            sat_hr = hsv_hr[:, :, 1]
            lr_counts, _ = np.histogram(sat_lr, bins=global_data['sat_bins'])
            hr_counts, _ = np.histogram(sat_hr, bins=global_data['sat_bins'])
            global_data['sat_lr_counts'] += lr_counts
            global_data['sat_hr_counts'] += hr_counts
            global_data['noise_means_lr'].append(art_lr['color_noise'])
            global_data['count'] += 1

        return rows, global_data

In [6]:
class StatsReporter:
    """Utilities to convert and summarize metrics to a DataFrame."""

    @staticmethod
    def dataframe(rows):
        """Convert list of ImagePairMetrics into a DataFrame.

        Parameters
        ----------
        rows : list[ImagePairMetrics]
            List of metric objects.

        Returns
        -------
        pandas.DataFrame
            One row per image pair.
        """
        return pd.DataFrame([r.as_dict() for r in rows])

    @staticmethod
    def summary(df):
        """Return basic descriptive statistics.

        Parameters
        ----------
        df : pandas.DataFrame
            Metrics data.

        Returns
        -------
        pandas.DataFrame
            mean, std and quartiles.
        """
        return df.describe().T[['mean', 'std', '25%', '50%', '75%']]

## Visualization methods

In [7]:
class ImageDataVisualization:
    """Visualization utilities for exploratory analysis."""

    @staticmethod
    def save_visual_example(lr_img, hr_img, output_path, lpips_val):
        """Save comparison figure and a difference heatmap.

        Parameters
        ----------
        lr_img : np.ndarray
            LR image.
        hr_img : np.ndarray
            HR image.
        output_path : str
            Output PNG path.
        lpips_val : float
            LPIPS value for title.
        """
        lr_resized = cv2.resize(
            lr_img,
            (hr_img.shape[1], hr_img.shape[0]),
            interpolation=cv2.INTER_CUBIC
        )

        diff_map = cv2.absdiff(lr_resized, hr_img)
        diff_map_color = cv2.applyColorMap(
            cv2.convertScaleAbs(cv2.cvtColor(diff_map, cv2.COLOR_BGR2GRAY)),
            cv2.COLORMAP_JET
        )

        _, axes = plt.subplots(1, 3, figsize=(12, 4))
        axes[0].imshow(cv2.cvtColor(lr_resized, cv2.COLOR_BGR2RGB))
        axes[0].set_title("Rescaled LR")
        axes[0].axis("off")

        axes[1].imshow(cv2.cvtColor(hr_img, cv2.COLOR_BGR2RGB))
        axes[1].set_title("HR")
        axes[1].axis("off")

        axes[2].imshow(diff_map_color)
        axes[2].set_title(f"Difference map\nLPIPS: {lpips_val:.4f}")
        axes[2].axis("off")

        plt.tight_layout()
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        plt.savefig(output_path)
        plt.close()

    @staticmethod
    def create_global_advanced_visualizations(global_data, output_path):
        """Create a single global panel with averaged spectra, gradient,
        averaged GLCM, noise mean distribution, and global saturation distributions.

        Parameters
        ----------
        global_data : dict
            Accumulated structures from MetricsAggregator.collect(..., collect_global_visuals=True).
        output_path : str
            Output path for the figure.
        """
        if global_data is None or global_data.get('count', 0) == 0:
            print("No global data available to create advanced visualization.")
            return

        n = global_data['count']
        eps = 1e-8
        # Average spectra (compute log after averaging magnitudes)
        lr_fft_avg = global_data['lr_fft_sum'] / n
        hr_fft_avg = global_data['hr_fft_sum'] / n
        # Average gradient magnitude
        grad_hr_avg = global_data['grad_hr_sum'] / n
        # Average (unnormalized-sum) GLCM then renormalize
        glcm_sum = global_data['glcm_sum']
        glcm_avg = glcm_sum / glcm_sum.sum()
        glcm_contrast = graycoprops(glcm_avg, "contrast")[0, 0]

        # Saturation histograms (normalize to density)
        sat_bins = global_data['sat_bins']
        bin_width = sat_bins[1] - sat_bins[0]
        sat_lr_density = global_data['sat_lr_counts'] / (global_data['sat_lr_counts'].sum() * bin_width + eps)
        sat_hr_density = global_data['sat_hr_counts'] / (global_data['sat_hr_counts'].sum() * bin_width + eps)
        sat_centers = 0.5 * (sat_bins[:-1] + sat_bins[1:])

        noise_means = np.array(global_data['noise_means_lr'], dtype=np.float64)

        plt.figure(figsize=(20, 10))

        # 1. Averaged LR Spectrum
        plt.subplot(231)
        plt.imshow(np.log(lr_fft_avg + eps), cmap="viridis")
        plt.title("LR Avg Frequency Spectrum")
        plt.colorbar()

        # 2. Averaged HR Spectrum
        plt.subplot(232)
        plt.imshow(np.log(hr_fft_avg + eps), cmap="viridis")
        plt.title("HR Avg Frequency Spectrum")
        plt.colorbar()

        # 3. Averaged HR Gradient Magnitude
        plt.subplot(233)
        plt.imshow(grad_hr_avg, cmap="magma")
        plt.title("HR Avg Gradient Magnitude")
        plt.colorbar()

        # 4. Averaged LR GLCM
        plt.subplot(234)
        plt.imshow(glcm_avg[:, :, 0, 0], cmap="plasma")
        plt.title(f"LR Avg GLCM (Contrast: {glcm_contrast:.2f})")
        plt.colorbar()

        # 5. Noise Mean Distribution (Histogram of per-image means)
        plt.subplot(235)
        plt.hist(noise_means, bins=30, color='tomato', edgecolor='black', alpha=0.8)
        plt.title(f"LR Color Noise Mean Distribution\nMean={noise_means.mean():.2f} Std={noise_means.std():.2f}")
        plt.xlabel("Per-image color noise mean")
        plt.ylabel("Count")

        # 6. Saturation Distribution (global)
        plt.subplot(236)
        plt.plot(sat_centers, sat_lr_density, label='LR', color='steelblue')
        plt.plot(sat_centers, sat_hr_density, label='HR', color='orange')
        plt.title("Global Saturation Density")
        plt.xlabel("Saturation value")
        plt.ylabel("Density")
        plt.legend()

        plt.tight_layout()
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        plt.savefig(output_path, dpi=300, bbox_inches="tight")
        plt.close()

    @staticmethod
    def generate_summary_visualizations(df, output_dir):
        """Histograms and correlation matrix (if seaborn available).

        Parameters
        ----------
        df : pandas.DataFrame
            DataFrame with metrics.
        output_dir : str
            Output directory.
        """
        os.makedirs(output_dir, exist_ok=True)

        metrics = [
            'lpips',
            'psnr',
            'ssim',
            'lap_var_hr',
            'rms_noise_hr',
            'blocking_hr'
        ]

        plt.figure(figsize=(15, 10))
        for i, m in enumerate(metrics, 1):
            plt.subplot(2, 3, i)
            plt.hist(df[m], bins=30, color='steelblue', edgecolor='black')
            plt.title(m)
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, 'distributions.png'))
        plt.close()

        if sns is not None:
            corr = df[metrics].corr()
            plt.figure(figsize=(8, 6))
            sns.heatmap(corr, cmap='coolwarm', annot=True, fmt='.2f', center=0)
            plt.tight_layout()
            plt.savefig(os.path.join(output_dir, 'correlation_matrix.png'))
            plt.close()

    @staticmethod
    def scatter_relations(df, output_dir):
        """Generate scatter plots of key metric relationships.

        Parameters
        ----------
        df : pandas.DataFrame
            Metrics DataFrame.
        output_dir : str
            Output directory.
        """
        pairs = [
            ('lpips', 'psnr'),
            ('lpips', 'ssim'),
            ('lap_var_hr', 'psnr'),
            ('rms_noise_hr', 'psnr')
        ]

        plt.figure(figsize=(16, 10))
        for i, (x, y) in enumerate(pairs, 1):
            plt.subplot(2, 2, i)
            plt.scatter(df[x], df[y], s=12, alpha=0.7)
            plt.xlabel(x)
            plt.ylabel(y)
            plt.title(f'{x} vs {y}')

        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, 'scatter_relations.png'))
        plt.close()

## Dataset processing

In [8]:
def run_eda_pipeline(
    lr_dir,
    hr_dir,
    output_dir="eda_results",
    top_k_examples=3,
    glcm_multi_angle=False,
    glcm_levels=64,
    interp_map_path="images/low_z_offset_interpolation_map.pkl",
):
    """Execute full EDA pipeline and return metrics DataFrame.

    Parameters
    ----------
    lr_dir : str
        Directory containing LR images.
    hr_dir : str
        Directory containing HR images.
    output_dir : str, default 'eda_results'
        Base directory for outputs.
    top_k_examples : int, default 3
        Number of best and worst examples (by LPIPS) to save.
    glcm_multi_angle : bool, default False
        If True use 4 angles for GLCM; else a single angle.
    glcm_levels : int, default 64
        Quantization levels for GLCM computation (for per-pair scalar metrics).
    interp_map_path : str
        Path to pickle mapping LR filename -> interpolation method.

    Returns
    -------
    pandas.DataFrame
        Metrics per image pair.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Load interpolation mapping
    interp_map = None
    if interp_map_path and os.path.exists(interp_map_path):
        try:
            with open(interp_map_path, 'rb') as f:
                interp_map = pickle.load(f)
        except Exception as e:
            print(f"Warning: could not load interpolation map: {e}")
    else:
        print("Interpolation map not found; default interpolation will be used.")
    
    examples_dir = os.path.join(output_dir, "LPIPS_Scenarios")
    best_dir = os.path.join(examples_dir, "best_scenarios")
    worst_dir = os.path.join(examples_dir, "worst_scenarios")

    for d in (best_dir, worst_dir):
        os.makedirs(d, exist_ok=True)
    
    # Collect metrics + global visualization data
    rows, global_data = MetricsAggregator.collect(
        lr_dir,
        hr_dir,
        glcm_multi_angle=glcm_multi_angle,
        glcm_levels=glcm_levels,
        interp_map=interp_map,
    )
    df = StatsReporter.dataframe(rows)

    # Create single global advanced visualization panel
    ImageDataVisualization.create_global_advanced_visualizations(
        global_data,
        os.path.join(output_dir, "advanced_global_panel.png"),
    )

    # Global visualizations (hist + correlation + scatters)
    ImageDataVisualization.generate_summary_visualizations(df, output_dir)
    ImageDataVisualization.scatter_relations(df, output_dir)

    # Best / worst examples based on LPIPS
    df_sorted = df.sort_values("lpips")
    selections = [
        (df_sorted.head(top_k_examples), "best_scenarios", "best"),
        (df_sorted.tail(top_k_examples), "worst_scenarios", "worst"),
    ]

    for subset, subdir, label in selections:
        for rank, name in enumerate(subset["filename"].tolist(), 1):
            lr_img, hr_img = ImagePairLoader.load_and_align(
                os.path.join(lr_dir, name),
                os.path.join(hr_dir, name),
                interp_map=interp_map
            )
            lp_val = subset.loc[
                subset["filename"] == name,
                "lpips"
            ].values[0]
            out_path = os.path.join(
                os.path.join(examples_dir, subdir),
                f"{label}_{rank}_{name}.png"
            )
            ImageDataVisualization.save_visual_example(
                lr_img,
                hr_img,
                out_path,
                lp_val
            )
    
    print("Saved metrics (summary) and global advanced visualization to", output_dir)

    return df

In [9]:
# Example execution
lr_dir = "images/LR/low_z_offset"
hr_dir = "images/HR/low_z_offset"
output_dir = "eda_results"

run_eda_pipeline(lr_dir, hr_dir, output_dir)

Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]




Loading model from: c:\Users\bgmanuel\anaconda3\envs\py310\lib\site-packages\lpips\weights\v0.1\alex.pth
Saved metrics (summary) and global advanced visualization to eda_results


Unnamed: 0,filename,lpips,psnr,ssim,glcm_contrast,glcm_homogeneity,glcm_correlation,rms_noise_lr,rms_noise_hr,lap_var_lr,...,blocking_hr,color_noise_lr,color_noise_hr,ringing_lr,ringing_hr,saturation_mean_lr,saturation_mean_hr,brightness_mean_lr,brightness_mean_hr,edge_diff
0,low_z_offset0.png,0.366334,28.221429,0.701873,0.454962,0.841343,0.989867,0.879475,4.060324,12.613618,...,5.797513,0.961992,3.606722,18.854521,21.444372,11.509357,12.950005,183.433956,184.272290,0.016656
1,low_z_offset1.png,0.265116,27.169442,0.700121,2.347552,0.777473,0.958003,4.726769,6.003793,396.442132,...,7.627305,4.452345,5.668956,18.558220,19.603585,14.939427,15.491859,180.374013,180.991347,0.008946
2,low_z_offset10.png,0.428815,28.779378,0.736538,0.883876,0.729247,0.977193,0.785974,3.479813,10.139952,...,6.531646,0.887972,3.386565,16.505174,20.071437,14.380084,15.641852,177.573384,178.037272,0.017233
3,low_z_offset100.png,0.322418,28.415277,0.770093,1.387595,0.712648,0.966102,2.105365,3.922570,74.137854,...,7.013493,2.411143,3.759353,16.947183,15.812100,11.184954,14.339162,186.064271,187.674375,0.004738
4,low_z_offset101.png,0.387359,27.716992,0.800652,0.790549,0.767847,0.982223,1.189740,3.788779,22.975271,...,7.404454,1.356468,3.552279,7.631252,17.134706,10.067243,13.028755,181.412432,182.483307,0.015734
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
338,low_z_offset95.png,0.234495,28.838332,0.820682,3.994325,0.510739,0.979713,1.863103,2.028015,58.010730,...,5.417804,2.157735,2.268738,29.495470,30.104342,18.205507,21.387568,167.743203,169.499549,-0.003494
339,low_z_offset96.png,0.408952,29.509783,0.882310,6.513451,0.687068,0.950421,5.248876,3.611707,482.949887,...,8.148830,4.593182,3.894900,24.528069,28.405069,20.432363,20.046231,175.686017,175.989067,0.009544
340,low_z_offset97.png,0.472600,27.617079,0.759207,4.246037,0.704739,0.932201,4.711739,4.272199,408.382066,...,8.232294,4.894483,4.128254,15.141235,20.206672,18.902400,13.879125,187.507362,186.101583,0.010250
341,low_z_offset98.png,0.260119,30.448617,0.862947,1.314772,0.687244,0.976494,1.126493,3.904114,21.247846,...,7.380521,1.358431,3.798310,16.375244,22.256849,17.483325,17.969403,175.122044,175.186092,0.017100
