## Imports

In [1]:
import os
import pickle

import cv2
import lpips
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
import torch
from skimage import filters
from skimage.feature import graycomatrix, graycoprops
from skimage.metrics import (
    peak_signal_noise_ratio as psnr,
    structural_similarity as ssim,
 )
from tqdm import tqdm
import seaborn as sns

In [2]:
class ImagePairLoader:
    """File-system helpers that discover LR/HR image pairs and load them."""

    @staticmethod
    def iter_pairs(lr_base, hr_base):
        """Generate ``(lr_relpath, hr_relpath)`` tuples for images in both trees.

        Both directory trees are walked recursively. An image is identified by
        its path relative to the tree root, and only paths that exist in both
        trees are produced, in lexicographic order. Because the match is on
        the relative path, both elements of each tuple are the same string;
        callers join them with the respective base directory.

        Parameters
        ----------
        lr_base : str
            Root directory of the low-resolution images.
        hr_base : str
            Root directory of the high-resolution images.

        Yields
        ------
        tuple[str, str]
            Relative path of the LR file and of the HR file.

        Raises
        ------
        ValueError
            If the two trees have no image path in common. As this is a
            generator, the error surfaces when iteration starts.
        """

        image_suffixes = (".png", ".jpg", ".jpeg")

        def collect(base):
            # Extension check is case-insensitive; the stored path keeps its case.
            return {
                os.path.relpath(os.path.join(folder, name), base)
                for folder, _, names in os.walk(base)
                for name in names
                if name.lower().endswith(image_suffixes)
            }

        shared = collect(lr_base) & collect(hr_base)
        if not shared:
            raise ValueError("No matching LR/HR image pairs were found under the provided directories.")

        for rel_path in sorted(shared):
            yield rel_path, rel_path

    @staticmethod
    def load_and_align(lr_path, hr_path, interp_map=None):
        """Read an LR/HR pair and bring the LR image to the HR resolution.

        Parameters
        ----------
        lr_path : str
            Path of the low-resolution image.
        hr_path : str
            Path of the high-resolution image.
        interp_map : Mapping | None
            Optional lookup from the LR file's base name to an OpenCV
            interpolation name ('INTER_LINEAR', 'INTER_CUBIC', 'INTER_AREA'
            or 'INTER_LANCZOS4'). Missing or unrecognised entries fall back to
            bilinear interpolation.

        Returns
        -------
        tuple[np.ndarray, np.ndarray]
            The (possibly resized) LR image and the HR image, as returned by
            ``cv2.imread``.

        Raises
        ------
        ValueError
            If either file cannot be read.
        """

        lr = cv2.imread(lr_path)
        hr = cv2.imread(hr_path)
        if lr is None or hr is None:
            raise ValueError(f"Failed reading {lr_path} or {hr_path}")

        target_h, target_w = hr.shape[:2]
        if lr.shape[:2] == (target_h, target_w):
            # Already aligned: nothing to resample.
            return lr, hr

        interp_code = cv2.INTER_LINEAR
        if interp_map is not None:
            requested = interp_map.get(os.path.basename(lr_path))
            name_to_code = {
                'INTER_LINEAR': cv2.INTER_LINEAR,
                'INTER_CUBIC': cv2.INTER_CUBIC,
                'INTER_AREA': cv2.INTER_AREA,
                'INTER_LANCZOS4': cv2.INTER_LANCZOS4,
            }
            interp_code = name_to_code.get(requested, cv2.INTER_LINEAR)

        # cv2.resize expects the target size as (width, height).
        lr = cv2.resize(lr, (target_w, target_h), interpolation=interp_code)
        return lr, hr

## Methods for exploratory data analysis

In [3]:
class ImageDatasetAnalyzer:
    """Stateless collection of image measurements for LR/HR pairs.

    Every method is a ``staticmethod``, so the class works as a namespace and
    never needs to be instantiated. The LPIPS network is built on first use
    and cached on the class.
    """

    @staticmethod
    def loss_fn():
        """Return the shared LPIPS model, creating it on the first call.

        Returns
        -------
        lpips.LPIPS
            LPIPS model built with the ``alex`` backbone.
        """

        if not hasattr(ImageDatasetAnalyzer, '_loss_fn'):
            ImageDatasetAnalyzer._loss_fn = lpips.LPIPS(net="alex")

        return ImageDatasetAnalyzer._loss_fn

    @staticmethod
    def lpips_score(lr_img, hr_img):
        """Compute LPIPS between two aligned BGR images.

        Parameters
        ----------
        lr_img : np.ndarray
            Aligned LR image (BGR, values in 0-255).
        hr_img : np.ndarray
            HR image (BGR, values in 0-255).

        Returns
        -------
        float
            LPIPS value.
        """

        def to_tensor(img):
            # BGR -> RGB, scale to [-1, 1] and reorder to (1, C, H, W).
            rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
            chw = np.transpose(2 * rgb - 1, (2, 0, 1))
            return torch.from_numpy(chw).unsqueeze(0).float()

        # Only the scalar value is needed, so skip autograd bookkeeping.
        with torch.no_grad():
            distance = ImageDatasetAnalyzer.loss_fn()(
                to_tensor(lr_img),
                to_tensor(hr_img)
            )

        return distance.item()

    @staticmethod
    def rms_noise(gray):
        """Estimate RMS noise as the residual against a 3x3 Gaussian blur.

        Parameters
        ----------
        gray : np.ndarray
            Grayscale image.

        Returns
        -------
        float
            Root mean square of ``gray - blurred``.
        """

        blurred = cv2.GaussianBlur(gray, (3, 3), 0)
        diff = gray.astype(np.float32) - blurred.astype(np.float32)

        return float(np.sqrt(np.mean(diff ** 2)))

    @staticmethod
    def laplacian_variance(gray):
        """Variance of the Laplacian (sharpness proxy).

        Parameters
        ----------
        gray : np.ndarray
            Grayscale image.

        Returns
        -------
        float
            Variance value.
        """

        return float(cv2.Laplacian(gray, cv2.CV_64F).var())

    @staticmethod
    def psnr_metric(lr_img, hr_img):
        """Compute PSNR of the LR image against the HR reference.

        Parameters
        ----------
        lr_img : np.ndarray
            Aligned LR image.
        hr_img : np.ndarray
            HR image (used as reference).

        Returns
        -------
        float
            PSNR value, computed with ``data_range=255``.
        """

        return psnr(hr_img, lr_img, data_range=255)

    @staticmethod
    def ssim_metric(lr_img, hr_img):
        """Compute SSIM between the HR and LR images.

        Parameters
        ----------
        lr_img : np.ndarray
            Aligned LR image with channels on the last axis.
        hr_img : np.ndarray
            HR image with channels on the last axis.

        Returns
        -------
        float
            SSIM value, computed with ``data_range=255``.
        """

        return ssim(hr_img, lr_img, channel_axis=2, data_range=255)

    @staticmethod
    def glcm_features(gray, angles=None, levels=64, multi_angle=False):
        """Extract GLCM contrast, homogeneity and correlation.

        Defaults are trimmed to 64 levels and one angle (0 rad) to reduce
        memory and time. Set ``multi_angle=True`` to use the four standard
        angles (0, 45, 90, 135 deg) or pass a custom iterable via ``angles``.

        Parameters
        ----------
        gray : np.ndarray
            Grayscale image with values in 0-255.
        angles : iterable | None
            Angles in radians. If None uses (0,) unless ``multi_angle`` is
            True (then four angles).
        levels : int
            Quantization levels (default 64).
        multi_angle : bool
            Average metrics across four angles when True. Ignored when
            ``angles`` is given.

        Returns
        -------
        dict
            glcm_contrast, glcm_homogeneity, glcm_correlation, each the mean
            of the property over all requested angles.
        """

        if angles is None:
            if multi_angle:
                angles = (0, np.pi / 4, np.pi / 2, 3 * np.pi / 4)
            else:
                angles = (0,)

        # Quantize 0-255 intensities to 0..levels-1. uint8 cannot hold values
        # above 255, so a wider dtype is used when more levels are requested.
        quant_dtype = np.uint8 if levels <= 256 else np.uint16
        norm = (
            (gray.astype(np.float32) / 255.0) * (levels - 1)
        ).astype(quant_dtype)

        glcm = graycomatrix(
            norm,
            [1],
            list(angles),
            levels,
            symmetric=True,
            normed=True
        )

        return {
            f"glcm_{prop}": float(graycoprops(glcm, prop).mean())
            for prop in ("contrast", "homogeneity", "correlation")
        }

    @staticmethod
    def feature_distribution(img, hsv):
        """Basic stats per image channel plus HSV saturation/brightness means.

        Parameters
        ----------
        img : np.ndarray
            Image with channels on the last axis (BGR when loaded with
            OpenCV); a 2-D array is treated as a single channel.
        hsv : np.ndarray
            HSV conversion of the same image; channel 1 is read as saturation
            and channel 2 as brightness.

        Returns
        -------
        dict
            ``ch{i}_mean``, ``ch{i}_std``, ``ch{i}_skew``, ``ch{i}_kurt`` for
            every channel ``i``, plus ``saturation_mean`` and
            ``brightness_mean``.
        """

        # Import the submodule explicitly: depending on the SciPy version, a
        # bare ``import scipy`` may not make ``scipy.stats`` available.
        from scipy import stats

        if img.ndim == 2:
            channels = (img,)
        else:
            channels = tuple(img[:, :, c] for c in range(img.shape[2]))

        results = {}
        for idx, channel in enumerate(channels):
            flat = channel.ravel()
            results[f"ch{idx}_mean"] = float(np.mean(flat))
            results[f"ch{idx}_std"] = float(np.std(flat))
            results[f"ch{idx}_skew"] = float(stats.skew(flat))
            results[f"ch{idx}_kurt"] = float(stats.kurtosis(flat))

        results["saturation_mean"] = float(np.mean(hsv[:, :, 1]))
        results["brightness_mean"] = float(np.mean(hsv[:, :, 2]))

        return results

    @staticmethod
    def detect_artifacts(img, gray):
        """Compute blocking, color-noise and ringing indicators.

        Parameters
        ----------
        img : np.ndarray
            BGR image.
        gray : np.ndarray
            8-bit grayscale version of ``img``.

        Returns
        -------
        dict
            blocking_score : mean absolute DCT coefficient on every 8th
            row/column (starting at index 7) of the whole-image DCT.
            color_noise : mean absolute difference to a 5x5 Gaussian blur.
            ringing_artifact : standard deviation of the gray values in the
            band around Canny edges (0.0 when no such band exists).
        """

        # cv2.dct only handles even-sized axes, so drop the last row/column of
        # an odd-sized image instead of failing on it.
        height, width = gray.shape[:2]
        even_h = height - 1 if height > 1 and height % 2 else height
        even_w = width - 1 if width > 1 and width % 2 else width
        dct = cv2.dct(
            np.ascontiguousarray(gray[:even_h, :even_w], dtype=np.float32)
        )
        horizontal_blocking = np.mean(np.abs(dct[7::8, :]))
        vertical_blocking = np.mean(np.abs(dct[:, 7::8]))
        blocking_score = float(
            (horizontal_blocking + vertical_blocking) / 2
        )

        blur = cv2.GaussianBlur(img, (5, 5), 0)
        color_noise = float(
            np.mean(np.abs(img.astype(float) - blur.astype(float)))
        )

        edges = cv2.Canny(gray, 100, 200)
        kernel = np.ones((5, 5), np.uint8)
        dilated = cv2.dilate(edges, kernel)
        # Pixels near an edge but not on it. This must be a boolean mask: a
        # uint8 0/255 array would be interpreted by NumPy as row indices.
        edge_region = (dilated > 0) & (edges == 0)

        if np.any(edge_region):
            ringing = float(np.std(gray[edge_region]))
        else:
            ringing = 0.0

        return {
            "blocking_score": blocking_score,
            "color_noise": color_noise,
            "ringing_artifact": ringing
        }

In [4]:
class ImagePairMetrics:
    """Plain record holding every metric computed for one LR/HR pair.

    The constructor stores each argument under an attribute of the same name.
    The per-channel skewness/kurtosis values are optional and default to None.
    """

    # Attribute names in storage order; as_dict() exposes them in this order.
    _FIELD_ORDER = (
        'filename',
        'lpips',
        'psnr',
        'ssim',
        'glcm_contrast',
        'glcm_homogeneity',
        'glcm_correlation',
        'rms_noise_lr',
        'rms_noise_hr',
        'lap_var_lr',
        'lap_var_hr',
        'blocking_lr',
        'blocking_hr',
        'color_noise_lr',
        'color_noise_hr',
        'ringing_lr',
        'ringing_hr',
        'saturation_mean_lr',
        'saturation_mean_hr',
        'brightness_mean_lr',
        'brightness_mean_hr',
        'edge_diff',
        'ch0_skew_lr',
        'ch0_skew_hr',
        'ch1_skew_lr',
        'ch1_skew_hr',
        'ch2_skew_lr',
        'ch2_skew_hr',
        'ch0_kurt_lr',
        'ch0_kurt_hr',
        'ch1_kurt_lr',
        'ch1_kurt_hr',
        'ch2_kurt_lr',
        'ch2_kurt_hr',
    )

    def __init__(
        self,
        filename,
        lpips,
        psnr,
        ssim,
        glcm_contrast,
        glcm_homogeneity,
        glcm_correlation,
        rms_noise_lr,
        rms_noise_hr,
        lap_var_lr,
        lap_var_hr,
        blocking_lr,
        blocking_hr,
        color_noise_lr,
        color_noise_hr,
        ringing_lr,
        ringing_hr,
        saturation_mean_lr,
        saturation_mean_hr,
        brightness_mean_lr,
        brightness_mean_hr,
        edge_diff,
        ch0_skew_lr=None,
        ch0_skew_hr=None,
        ch1_skew_lr=None,
        ch1_skew_hr=None,
        ch2_skew_lr=None,
        ch2_skew_hr=None,
        ch0_kurt_lr=None,
        ch0_kurt_hr=None,
        ch1_kurt_lr=None,
        ch1_kurt_hr=None,
        ch2_kurt_lr=None,
        ch2_kurt_hr=None,
    ):
        # Snapshot the arguments first so the loop variable below is not part
        # of the mapping, then copy each one onto the instance by name.
        supplied = locals()
        for field in self._FIELD_ORDER:
            setattr(self, field, supplied[field])

    def as_dict(self):
        """Return a shallow copy of the metrics as a dict (one DataFrame row)."""
        return dict(self.__dict__)

class MetricsAggregator:
    """Orchestrates metric extraction for all image pairs."""

    @staticmethod
    def _accumulate(global_data, key, value):
        """Add ``value`` to the running sum stored in ``global_data[key]``.

        The first value creates the accumulator. A later 2-D map whose shape
        differs from the accumulator (pairs of different resolution) is
        resampled to the accumulator's shape before being added, so a dataset
        with mixed image sizes no longer aborts on a broadcast error; the
        resulting average is then only an approximation.
        """
        value = np.array(value, dtype=np.float64)
        total = global_data[key]
        if total is None:
            global_data[key] = value
            return

        if value.shape != total.shape:
            # cv2.resize expects the target size as (width, height).
            value = cv2.resize(
                value,
                (total.shape[1], total.shape[0]),
                interpolation=cv2.INTER_LINEAR
            )
        total += value

    @staticmethod
    def collect(
        lr_dir,
        hr_dir,
        glcm_multi_angle=False,
        glcm_levels=64,
        interp_map=None,
    ):
        """Compute metrics for each pair and accumulate global visual data.

        Parameters
        ----------
        lr_dir : str
            Root directory of the LR images.
        hr_dir : str
            Root directory of the HR images.
        glcm_multi_angle : bool
            Forwarded to ``ImageDatasetAnalyzer.glcm_features``.
        glcm_levels : int
            Quantization levels forwarded to ``glcm_features``.
        interp_map : Mapping | None
            Forwarded to ``ImagePairLoader.load_and_align``.

        Returns
        -------
        tuple[list[ImagePairMetrics], dict]
            ``rows`` holds one metrics object per pair. ``global_data`` holds
            the accumulators needed for the global visualization panel:
            ``count``, the summed LR/HR FFT magnitudes, the summed HR gradient
            magnitude, the summed LR GLCM, saturation histogram counts with
            their bin edges, and the per-image LR color-noise values.

        Raises
        ------
        ValueError
            Propagated from the loader when no pairs are found or an image
            cannot be read.
        """

        rows = []
        sat_bins = np.linspace(0, 256, 51)  # 50 bins 0-255
        global_data = {
            'count': 0,
            'lr_fft_sum': None,
            'hr_fft_sum': None,
            'grad_hr_sum': None,
            'glcm_sum': None,  # shape (256, 256, 1, 1)
            'sat_lr_counts': np.zeros(len(sat_bins) - 1, dtype=np.float64),
            'sat_hr_counts': np.zeros(len(sat_bins) - 1, dtype=np.float64),
            'sat_bins': sat_bins,
            'noise_means_lr': [],
        }
        accumulate = MetricsAggregator._accumulate

        # Materialize the pairs so tqdm knows the total.
        pairs = list(ImagePairLoader.iter_pairs(lr_dir, hr_dir))
        for lf, hf in tqdm(pairs, desc="Computing metrics", unit="img"):
            lr_img, hr_img = ImagePairLoader.load_and_align(
                os.path.join(lr_dir, lf),
                os.path.join(hr_dir, hf),
                interp_map=interp_map
            )

            gray_lr = cv2.cvtColor(lr_img, cv2.COLOR_BGR2GRAY)
            gray_hr = cv2.cvtColor(hr_img, cv2.COLOR_BGR2GRAY)
            hsv_lr = cv2.cvtColor(lr_img, cv2.COLOR_BGR2HSV)
            hsv_hr = cv2.cvtColor(hr_img, cv2.COLOR_BGR2HSV)

            lpips_val = ImageDatasetAnalyzer.lpips_score(lr_img, hr_img)
            psnr_val = ImageDatasetAnalyzer.psnr_metric(lr_img, hr_img)
            ssim_val = ImageDatasetAnalyzer.ssim_metric(lr_img, hr_img)
            glcm = ImageDatasetAnalyzer.glcm_features(
                gray_lr,
                levels=glcm_levels,
                multi_angle=glcm_multi_angle
            )
            fd_lr = ImageDatasetAnalyzer.feature_distribution(lr_img, hsv_lr)
            fd_hr = ImageDatasetAnalyzer.feature_distribution(hr_img, hsv_hr)
            art_lr = ImageDatasetAnalyzer.detect_artifacts(lr_img, gray_lr)
            art_hr = ImageDatasetAnalyzer.detect_artifacts(hr_img, gray_hr)
            rms_lr = ImageDatasetAnalyzer.rms_noise(gray_lr)
            rms_hr = ImageDatasetAnalyzer.rms_noise(gray_hr)
            lap_var_lr = ImageDatasetAnalyzer.laplacian_variance(gray_lr)
            lap_var_hr = ImageDatasetAnalyzer.laplacian_variance(gray_hr)
            lr_edges = filters.sobel(gray_lr)
            hr_edges = filters.sobel(gray_hr)
            edge_diff = float(np.mean(hr_edges) - np.mean(lr_edges))

            # ch{0,1,2}_{skew,kurt}_{lr,hr}; None when a stat is unavailable.
            shape_stats = {
                f"ch{channel}_{stat}_{side}": dist.get(f"ch{channel}_{stat}")
                for stat in ("skew", "kurt")
                for channel in range(3)
                for side, dist in (("lr", fd_lr), ("hr", fd_hr))
            }

            rows.append(
                ImagePairMetrics(
                    filename=lf.replace('\\', '/'),
                    lpips=lpips_val,
                    psnr=psnr_val,
                    ssim=ssim_val,
                    glcm_contrast=glcm['glcm_contrast'],
                    glcm_homogeneity=glcm['glcm_homogeneity'],
                    glcm_correlation=glcm['glcm_correlation'],
                    rms_noise_lr=rms_lr,
                    rms_noise_hr=rms_hr,
                    lap_var_lr=lap_var_lr,
                    lap_var_hr=lap_var_hr,
                    blocking_lr=art_lr['blocking_score'],
                    blocking_hr=art_hr['blocking_score'],
                    color_noise_lr=art_lr['color_noise'],
                    color_noise_hr=art_hr['color_noise'],
                    ringing_lr=art_lr['ringing_artifact'],
                    ringing_hr=art_hr['ringing_artifact'],
                    saturation_mean_lr=fd_lr['saturation_mean'],
                    saturation_mean_hr=fd_hr['saturation_mean'],
                    brightness_mean_lr=fd_lr['brightness_mean'],
                    brightness_mean_hr=fd_hr['brightness_mean'],
                    edge_diff=edge_diff,
                    **shape_stats,
                )
            )

            # Global accumulators: centred FFT magnitudes of both images.
            accumulate(
                global_data, 'lr_fft_sum',
                np.abs(np.fft.fftshift(np.fft.fft2(gray_lr)))
            )
            accumulate(
                global_data, 'hr_fft_sum',
                np.abs(np.fft.fftshift(np.fft.fft2(gray_hr)))
            )

            # HR gradient magnitude from 5x5 Sobel derivatives.
            sobelx = cv2.Sobel(gray_hr, cv2.CV_64F, 1, 0, ksize=5)
            sobely = cv2.Sobel(gray_hr, cv2.CV_64F, 0, 1, ksize=5)
            accumulate(
                global_data, 'grad_hr_sum',
                np.sqrt(sobelx ** 2 + sobely ** 2)
            )

            # Full 256-level LR GLCM (distance 1, angle 0).
            lr_glcm_full = graycomatrix(
                gray_lr,
                [1],
                [0],
                256,
                symmetric=True,
                normed=True
            )
            accumulate(global_data, 'glcm_sum', lr_glcm_full)

            lr_counts, _ = np.histogram(
                hsv_lr[:, :, 1], bins=global_data['sat_bins']
            )
            hr_counts, _ = np.histogram(
                hsv_hr[:, :, 1], bins=global_data['sat_bins']
            )
            global_data['sat_lr_counts'] += lr_counts
            global_data['sat_hr_counts'] += hr_counts
            global_data['noise_means_lr'].append(art_lr['color_noise'])
            global_data['count'] += 1

        return rows, global_data

In [5]:
class StatsReporter:
    """Helpers that tabulate per-pair metrics and summarize them."""

    @staticmethod
    def dataframe(rows):
        """Build a DataFrame with one row per metrics object.

        Parameters
        ----------
        rows : list[ImagePairMetrics]
            Objects exposing ``as_dict()``.

        Returns
        -------
        pandas.DataFrame
            One row per element of ``rows``; columns are the dict keys.
        """

        records = []
        for row in rows:
            records.append(row.as_dict())

        return pd.DataFrame(records)

    @staticmethod
    def summary(df):
        """Summarize the columns covered by ``DataFrame.describe``.

        Parameters
        ----------
        df : pandas.DataFrame
            Metrics data.

        Returns
        -------
        pandas.DataFrame
            One row per described column with its mean, std and quartiles.
        """

        wanted = ['mean', 'std', '25%', '50%', '75%']
        described = df.describe().transpose()

        return described[wanted]

## Visualization methods

In [6]:
class ImageDataVisualization:
    """Visualization utilities for exploratory analysis."""

    @staticmethod
    def save_visual_example(lr_img, hr_img, output_path, lpips_val):
        """Save a three-panel figure: rescaled LR, HR and a difference heatmap.

        Parameters
        ----------
        lr_img : np.ndarray
            LR image (BGR). It is resized to the HR size with bicubic
            interpolation before comparison.
        hr_img : np.ndarray
            HR image (BGR).
        output_path : str
            Output image path; its parent directory is created if needed.
        lpips_val : float
            LPIPS value shown in the heatmap title.
        """

        # cv2.resize expects the target size as (width, height).
        target_size = (hr_img.shape[1], hr_img.shape[0])
        lr_resized = cv2.resize(
            lr_img, target_size, interpolation=cv2.INTER_CUBIC
        )

        diff_gray = cv2.cvtColor(
            cv2.absdiff(lr_resized, hr_img), cv2.COLOR_BGR2GRAY
        )
        diff_map_color = cv2.applyColorMap(
            cv2.convertScaleAbs(diff_gray),
            cv2.COLORMAP_JET
        )

        # OpenCV images (including the colormap output) are BGR while
        # matplotlib expects RGB, so every panel is converted before display.
        panels = (
            (lr_resized, "Rescaled LR"),
            (hr_img, "HR"),
            (diff_map_color, f"Difference map\nLPIPS: {lpips_val:.4f}"),
        )

        fig, axes = plt.subplots(1, 3, figsize=(12, 4))
        try:
            for ax, (bgr, title) in zip(axes, panels):
                ax.imshow(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
                ax.set_title(title)
                ax.axis("off")

            fig.tight_layout()
            # dirname is '' for a bare filename and makedirs('') would raise.
            out_dir = os.path.dirname(output_path)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            fig.savefig(output_path)
        finally:
            # Release the figure even when saving fails.
            plt.close(fig)

    @staticmethod
    def create_advanced_visualizations(lr_img, hr_img, output_path):
        """Create the per-pair advanced panel and save it to ``output_path``.

        The 2x3 figure shows: LR and HR log-magnitude frequency spectra, HR
        gradient magnitude, the LR gray-level co-occurrence matrix, the LR
        color-noise map and the LR/HR saturation histograms.

        Parameters
        ----------
        lr_img : np.ndarray
            LR image (BGR).
        hr_img : np.ndarray
            HR image (BGR).
        output_path : str
            Output image path; its parent directory is created if needed.
        """
        # Convert once up front; several panels share these arrays.
        gray_lr = cv2.cvtColor(lr_img, cv2.COLOR_BGR2GRAY)
        gray_hr = cv2.cvtColor(hr_img, cv2.COLOR_BGR2GRAY)

        fig = plt.figure(figsize=(20, 10))
        try:
            # 1. LR Spectrum
            plt.subplot(231)
            lr_fft = np.fft.fft2(gray_lr)
            plt.imshow(np.log(np.abs(np.fft.fftshift(lr_fft)) + 1e-8),
                       cmap="viridis")
            plt.title("LR Frequency Spectrum")
            plt.colorbar()

            # 2. HR Spectrum
            plt.subplot(232)
            hr_fft = np.fft.fft2(gray_hr)
            plt.imshow(np.log(np.abs(np.fft.fftshift(hr_fft)) + 1e-8),
                       cmap="viridis")
            plt.title("HR Frequency Spectrum")
            plt.colorbar()

            # 3. HR Gradient Magnitude
            plt.subplot(233)
            sobelx = cv2.Sobel(gray_hr, cv2.CV_64F, 1, 0, ksize=5)
            sobely = cv2.Sobel(gray_hr, cv2.CV_64F, 0, 1, ksize=5)
            gradient_magnitude = np.sqrt(sobelx ** 2 + sobely ** 2)
            plt.imshow(gradient_magnitude, cmap="magma")
            plt.title("Gradient Magnitude")
            plt.colorbar()

            # 4. LR GLCM
            plt.subplot(234)
            lr_glcm = graycomatrix(
                gray_lr, [1], [0], 256, symmetric=True, normed=True
            )
            lr_contrast = graycoprops(lr_glcm, "contrast")[0, 0]
            plt.imshow(lr_glcm[:, :, 0, 0], cmap="plasma")
            plt.title(f"LR GLCM (Contrast: {lr_contrast:.2f})")
            plt.colorbar()

            # 5. LR Color Noise Map
            plt.subplot(235)
            blur = cv2.GaussianBlur(lr_img, (5, 5), 0)
            # One residual array feeds both the map and its global mean.
            residual = np.abs(lr_img.astype(float) - blur.astype(float))
            noise_map = np.mean(residual, axis=2)
            color_noise_mean = float(np.mean(residual))
            plt.imshow(noise_map, cmap="hot")
            plt.title(f"Noise Map (Mean: {color_noise_mean:.2f})")
            plt.colorbar()

            # 6. Saturation Distribution LR vs HR
            plt.subplot(236)
            lr_sat = cv2.cvtColor(lr_img, cv2.COLOR_BGR2HSV)[:, :, 1]
            hr_sat = cv2.cvtColor(hr_img, cv2.COLOR_BGR2HSV)[:, :, 1]
            plt.hist(lr_sat.ravel(), bins=50, alpha=0.5, density=True,
                     label="LR", color="steelblue")
            plt.hist(hr_sat.ravel(), bins=50, alpha=0.5, density=True,
                     label="HR", color="orange")
            plt.title("Saturation Distribution")
            plt.legend()

            plt.tight_layout()
            # dirname is '' for a bare filename and makedirs('') would raise.
            out_dir = os.path.dirname(output_path)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            plt.savefig(output_path, dpi=300, bbox_inches="tight")
        finally:
            # Release the figure even when plotting or saving fails.
            plt.close(fig)

    @staticmethod
    def create_global_advanced_visualizations(global_data, output_path):
        """Create the dataset-level panel from accumulated statistics.

        The 2x3 figure shows the averaged LR/HR spectra, the averaged HR
        gradient magnitude, the averaged LR GLCM, the distribution of
        per-image LR color-noise means and the global saturation densities.

        Parameters
        ----------
        global_data : dict | None
            Accumulators with the keys ``count``, ``lr_fft_sum``,
            ``hr_fft_sum``, ``grad_hr_sum``, ``glcm_sum``, ``sat_bins``,
            ``sat_lr_counts``, ``sat_hr_counts`` and ``noise_means_lr``.
            Nothing is drawn when it is None or its count is zero.
        output_path : str
            Output image path; its parent directory is created if needed.
        """

        if global_data is None or global_data.get('count', 0) == 0:
            print("No global data available to create advanced visualization.")
            return

        n = global_data['count']
        eps = 1e-8
        # Average spectra (compute log after averaging magnitudes)
        lr_fft_avg = global_data['lr_fft_sum'] / n
        hr_fft_avg = global_data['hr_fft_sum'] / n
        # Average gradient magnitude
        grad_hr_avg = global_data['grad_hr_sum'] / n
        # The summed GLCM is renormalized so that it sums to 1 again.
        glcm_sum = global_data['glcm_sum']
        glcm_avg = glcm_sum / glcm_sum.sum()
        glcm_contrast = graycoprops(glcm_avg, "contrast")[0, 0]

        # Saturation histograms (normalize counts to a density)
        sat_bins = global_data['sat_bins']
        bin_width = sat_bins[1] - sat_bins[0]
        sat_lr_density = (
            global_data['sat_lr_counts'] /
            (global_data['sat_lr_counts'].sum() * bin_width + eps)
        )
        sat_hr_density = (
            global_data['sat_hr_counts'] /
            (global_data['sat_hr_counts'].sum() * bin_width + eps)
        )
        sat_centers = 0.5 * (sat_bins[:-1] + sat_bins[1:])

        noise_means = np.array(global_data['noise_means_lr'], dtype=np.float64)

        fig = plt.figure(figsize=(20, 10))
        try:
            # 1 LR Spectrum
            plt.subplot(231)
            plt.imshow(np.log(lr_fft_avg + eps), cmap="viridis")
            plt.title("LR Avg Frequency Spectrum")
            plt.colorbar()

            # 2 HR Spectrum
            plt.subplot(232)
            plt.imshow(np.log(hr_fft_avg + eps), cmap="viridis")
            plt.title("HR Avg Frequency Spectrum")
            plt.colorbar()

            # 3 Gradient Magnitude
            plt.subplot(233)
            plt.imshow(grad_hr_avg, cmap="magma")
            plt.title("HR Avg Gradient Magnitude")
            plt.colorbar()

            # 4 GLCM
            plt.subplot(234)
            plt.imshow(glcm_avg[:, :, 0, 0], cmap="plasma")
            plt.title(f"LR Avg GLCM (Contrast: {glcm_contrast:.2f})")
            plt.colorbar()

            # 5 Noise mean distribution
            plt.subplot(235)
            plt.hist(
                noise_means, bins=30, color='tomato', edgecolor='black',
                alpha=0.8
            )
            plt.title(
                f"LR Color Noise Mean Dist\nMean={noise_means.mean():.2f} "
                f"Std={noise_means.std():.2f}"
            )
            plt.xlabel("Per-image color noise mean")
            plt.ylabel("Count")

            # 6 Saturation density
            plt.subplot(236)
            plt.plot(sat_centers, sat_lr_density, label='LR',
                     color='steelblue')
            plt.plot(sat_centers, sat_hr_density, label='HR', color='orange')
            plt.title("Global Saturation Density")
            plt.xlabel("Saturation value")
            plt.ylabel("Density")
            plt.legend()

            plt.tight_layout()
            # dirname is '' for a bare filename and makedirs('') would raise.
            out_dir = os.path.dirname(output_path)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            plt.savefig(output_path, dpi=300, bbox_inches="tight")
        finally:
            # Release the figure even when plotting or saving fails.
            plt.close(fig)

    @staticmethod
    def basic_distributions(df, output_dir):
        """Save ``distributions.png``: one histogram per headline metric.

        Metrics plotted (3x3 grid, 30 bins each): lpips, psnr, ssim,
        lap_var_hr, rms_noise_hr, blocking_hr, glcm_contrast,
        glcm_homogeneity, glcm_correlation.

        Parameters
        ----------
        df : pandas.DataFrame
            Metrics table containing the columns listed above.
        output_dir : str
            Directory receiving the image; created if it does not exist.
        """

        metrics = [
            'lpips', 'psnr', 'ssim', 'lap_var_hr', 'rms_noise_hr',
            'blocking_hr', 'glcm_contrast', 'glcm_homogeneity',
            'glcm_correlation'
        ]
        colors = [
            "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd",
            "#8c564b", "#6baed6", "#9edae5", "#17becf"
        ]
        rows, cols = 3, 3

        fig = plt.figure(figsize=(18, 14))
        try:
            for i, (m, c) in enumerate(zip(metrics, colors), 1):
                plt.subplot(rows, cols, i)
                plt.hist(
                    df[m], bins=30, color=c, edgecolor='black', alpha=0.85
                )
                plt.title(m)
            plt.tight_layout()
            # Like the per-file savers in this class, make sure the target
            # directory exists ('' means the current directory).
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)
            plt.savefig(
                os.path.join(output_dir, 'distributions.png'),
                dpi=300, bbox_inches='tight'
            )
        finally:
            # Release the figure even when a column is missing or saving fails.
            plt.close(fig)

    @staticmethod
    def artifact_color_histograms(df, output_dir):
        """Save ``artifact_color_histograms.png``: LR vs HR overlays + edge diff.

        Panels 1-5 overlay the LR and HR histograms of blocking, ringing,
        saturation mean, brightness mean and color noise. Panel 7 is the
        histogram of ``edge_diff`` and panel 8 its kernel density estimate;
        panels 6 and 9 are left without data.

        Parameters
        ----------
        df : pandas.DataFrame
            Metrics table with the ``*_lr`` / ``*_hr`` columns above and
            ``edge_diff``.
        output_dir : str
            Directory receiving the image; created if it does not exist.
        """
        overlay_pairs = [
            ('blocking_lr', 'blocking_hr', 'Blocking Score'),
            ('ringing_lr', 'ringing_hr', 'Ringing Artifact'),
            ('saturation_mean_lr', 'saturation_mean_hr', 'Saturation Mean'),
            ('brightness_mean_lr', 'brightness_mean_hr', 'Brightness Mean'),
            ('color_noise_lr', 'color_noise_hr', 'Color Noise'),
        ]
        rows, cols = 3, 3

        fig = plt.figure(figsize=(16, 12))
        try:
            for i, (lr_col, hr_col, title) in enumerate(overlay_pairs, 1):
                plt.subplot(rows, cols, i)
                plt.hist(
                    df[lr_col], bins=30, alpha=0.55, label='LR',
                    color='#1f77b4', edgecolor='black', linewidth=0.4
                )
                plt.hist(
                    df[hr_col], bins=30, alpha=0.55, label='HR',
                    color='#ff7f0e', edgecolor='black', linewidth=0.4
                )
                plt.title(title)
                plt.legend(fontsize=8)

            # edge_diff: raw histogram and smoothed density
            plt.subplot(rows, cols, 7)
            plt.hist(
                df['edge_diff'], bins=30, color='#2ca02c', alpha=0.8,
                edgecolor='black'
            )
            plt.title('Edge Mean Diff (HR-LR)')
            plt.subplot(rows, cols, 8)
            sns.kdeplot(df['edge_diff'], fill=True, color='#2ca02c')
            plt.title('Edge Diff Density')
            plt.subplot(rows, cols, 9)
            plt.axis('off')

            plt.tight_layout()
            # Make sure the target directory exists ('' means the current
            # directory).
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)
            plt.savefig(
                os.path.join(output_dir, 'artifact_color_histograms.png'),
                dpi=300, bbox_inches='tight'
            )
        finally:
            # Release the figure even when a column is missing or saving fails.
            plt.close(fig)

    @staticmethod
    def artifact_boxplots(df, output_dir):
        """Save ``artifact_boxplots.png`` with side-by-side LR vs HR boxplots.

        Boxes are drawn for blocking, ringing, saturation mean, brightness
        mean and color noise, LR first and HR second within each group.

        Parameters
        ----------
        df : pandas.DataFrame
            Metrics table with the ``*_lr`` / ``*_hr`` columns above.
        output_dir : str
            Directory receiving the image; created if it does not exist.
        """

        groups = [
            ('blocking_lr', 'blocking_hr', 'Blocking'),
            ('ringing_lr', 'ringing_hr', 'Ringing'),
            ('saturation_mean_lr', 'saturation_mean_hr', 'Saturation'),
            ('brightness_mean_lr', 'brightness_mean_hr', 'Brightness'),
            ('color_noise_lr', 'color_noise_hr', 'ColorNoise'),
        ]
        data = []
        labels = []
        for lr_col, hr_col, name in groups:
            data.extend((df[lr_col], df[hr_col]))
            labels.extend((f'{name} LR', f'{name} HR'))

        fig = plt.figure(figsize=(11, 6))
        try:
            box = plt.boxplot(data, patch_artist=True)
            # Alternate the LR/HR colors across the boxes.
            palette = ['#1f77b4', '#ff7f0e'] * len(groups)
            for patch, col in zip(box['boxes'], palette):
                patch.set_facecolor(col)
                patch.set_alpha(0.55)
            # Tick labels are set here instead of through boxplot(labels=...),
            # which newer Matplotlib releases deprecate. Boxes sit at 1..N.
            plt.xticks(
                range(1, len(labels) + 1), labels, rotation=25, ha='right'
            )
            plt.tight_layout()
            # Make sure the target directory exists ('' means the current
            # directory).
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)
            plt.savefig(
                os.path.join(output_dir, 'artifact_boxplots.png'),
                dpi=300, bbox_inches='tight'
            )
        finally:
            # Release the figure even when plotting or saving fails.
            plt.close(fig)

    @staticmethod
    def channel_shape_bars(df, output_dir):
        """Save ``channel_shape_bars.png``: mean skewness and kurtosis per channel.

        Four bars are drawn for each of the three channels (labelled B, G, R
        on the x axis): skew LR, skew HR, kurtosis LR and kurtosis HR, each
        being the column mean over ``df``. If any of the twelve
        ``ch{0,1,2}_{skew,kurt}_{lr,hr}`` columns is absent, a message is
        printed and nothing is written.
        """

        channel_ids = range(3)
        needed = [
            f'ch{c}_{stat}_{side}'
            for stat in ('skew', 'kurt')
            for c in channel_ids
            for side in ('lr', 'hr')
        ]
        if any(name not in df.columns for name in needed):
            print('Channel shape stats missing; skipping channel_shape_bars.')
            return

        def channel_means(stat, side):
            # Mean of one statistic for channels 0, 1, 2 (in that order).
            return [df[f'ch{c}_{stat}_{side}'].mean() for c in channel_ids]

        # (offset in bar widths, bar heights, legend label, colour)
        bar_series = [
            (-1.5, channel_means('skew', 'lr'), 'Skew LR', '#3182bd'),
            (-0.5, channel_means('skew', 'hr'), 'Skew HR', '#6baed6'),
            (0.5, channel_means('kurt', 'lr'), 'Kurt LR', '#fd8d3c'),
            (1.5, channel_means('kurt', 'hr'), 'Kurt HR', '#fdd0a2'),
        ]

        tick_names = ['B', 'G', 'R']
        positions = np.arange(len(tick_names))
        bar_width = 0.18

        plt.figure(figsize=(12, 6))
        for offset, heights, legend_name, colour in bar_series:
            plt.bar(
                positions + offset * bar_width, heights, bar_width,
                label=legend_name, color=colour
            )
        # Zero line: both statistics can be negative.
        plt.axhline(0, color='black', linewidth=0.8)
        plt.xticks(positions, tick_names)
        plt.ylabel('Value (mean)')
        plt.title('Per-Channel Skew & Kurtosis (LR vs HR)')
        plt.legend(ncol=2)
        plt.tight_layout()

        target = os.path.join(output_dir, 'channel_shape_bars.png')
        plt.savefig(target, dpi=300, bbox_inches='tight')
        plt.close()

    @staticmethod
    def correlation_matrix(df, output_dir):
        """Save ``correlation_matrix.png``, an annotated correlation heatmap.

        Only the candidate metric columns that actually exist in ``df`` are
        correlated. With fewer than three of them present a message is
        printed and no file is written.
        """

        candidates = (
            'lpips', 'psnr', 'ssim', 'lap_var_hr', 'rms_noise_hr',
            'blocking_hr', 'ringing_hr', 'saturation_mean_hr',
            'brightness_mean_hr', 'edge_diff',
        )

        present = []
        for name in candidates:
            if name in df.columns:
                present.append(name)

        n_metrics = len(present)
        if n_metrics < 3:
            print('Not enough columns for correlation matrix.')
            return

        corr_table = df[present].corr()

        heatmap_options = {
            'cmap': 'flare',
            'annot': True,
            'fmt': '.2f',
            'center': 0,
            'square': True,
            'cbar_kws': {'shrink': 0.75},
        }
        # Figure grows with the number of metrics so the annotations stay readable.
        plt.figure(figsize=(1.2 * n_metrics, 0.9 * n_metrics))
        sns.heatmap(corr_table, **heatmap_options)
        plt.tight_layout()

        target = os.path.join(output_dir, 'correlation_matrix.png')
        plt.savefig(target, dpi=300, bbox_inches='tight')
        plt.close()

    @staticmethod
    def scatter_relations(df, output_dir):
        """Save ``scatter_relations.png`` with pairwise metric scatter plots.

        Candidate pairs (x, y):
          (lpips, psnr), (lpips, ssim), (lap_var_hr, psnr),
          (rms_noise_hr, psnr), (blocking_hr, ringing_hr),
          (saturation_mean_hr, brightness_mean_hr),
          (edge_diff, psnr), (edge_diff, ssim)

        A pair is drawn only when both of its columns exist in ``df``. The
        drawable pairs are packed into a two-column grid without gaps; when
        all eight are present the layout is 4 x 2. If no pair is drawable a
        message is printed and no file is written.

        Args:
            df: DataFrame holding the metric columns.
            output_dir: Directory the PNG is written into.
        """

        pairs = [
            ('lpips', 'psnr', 'LPIPS vs PSNR', '#1f77b4'),
            ('lpips', 'ssim', 'LPIPS vs SSIM', '#ff7f0e'),
            ('lap_var_hr', 'psnr', 'LaplacianVar HR vs PSNR', '#2ca02c'),
            ('rms_noise_hr', 'psnr', 'Noise HR vs PSNR', '#d62728'),
            ('blocking_hr', 'ringing_hr', 'Blocking vs Ringing', '#9467bd'),
            ('saturation_mean_hr', 'brightness_mean_hr',
             'Saturation vs Brightness', '#8c564b'),
            ('edge_diff', 'psnr', 'Edge Diff vs PSNR', '#e377c2'),
            ('edge_diff', 'ssim', 'Edge Diff vs SSIM', '#7f7f7f')
        ]

        # Filter before laying out the grid: numbering subplots by position in
        # the full list would leave empty cells for every skipped pair.
        drawable = [
            pair for pair in pairs
            if pair[0] in df.columns and pair[1] in df.columns
        ]
        if not drawable:
            print('No metric pairs available; skipping scatter_relations.')
            return

        cols = 2
        rows = (len(drawable) + cols - 1) // cols  # ceiling division
        plt.figure(figsize=(cols * 5, rows * 3.2))
        for i, (x, y, title, color) in enumerate(drawable, 1):
            plt.subplot(rows, cols, i)
            plt.scatter(
                df[x], df[y], s=14, alpha=0.75, color=color,
                edgecolors='white', linewidths=0.4
            )
            plt.xlabel(x)
            plt.ylabel(y)
            plt.title(title, fontsize=9)
        plt.tight_layout()
        plt.savefig(
            os.path.join(output_dir, 'scatter_relations.png'),
            dpi=300, bbox_inches='tight'
        )
        plt.close()

## Dataset processing

In [7]:
def run_eda_pipeline(
    lr_dir,
    hr_dir,
    output_dir="eda_results",
    top_k_examples=1,
    glcm_multi_angle=False,
    glcm_levels=64,
    interp_map_path="",
):
    """Execute the full EDA pipeline and return the metrics DataFrame.

    Generates inside ``output_dir``:
      - advanced_global_panel.png
      - distributions.png
      - artifact_color_histograms.png
      - artifact_boxplots.png
      - channel_shape_bars.png
      - correlation_matrix.png
      - scatter_relations.png
      - LPIPS_Scenarios/best_scenarios and LPIPS_Scenarios/worst_scenarios,
        each holding a basic and an advanced visualization per selected pair

    Args:
        lr_dir: Base directory of the LR images.
        hr_dir: Base directory of the HR images.
        output_dir: Directory receiving every generated file (created if
            missing).
        top_k_examples: Number of lowest-LPIPS ("best") and highest-LPIPS
            ("worst") pairs to visualize. Values <= 0 produce no examples.
        glcm_multi_angle: Forwarded to ``MetricsAggregator.collect``.
        glcm_levels: Forwarded to ``MetricsAggregator.collect``.
        interp_map_path: Optional path to a pickled interpolation map. The
            loaded object is passed to ``MetricsAggregator.collect`` and
            ``ImagePairLoader.load_and_align``. When the path is empty,
            missing or unreadable, ``None`` is passed instead.

    Returns:
        The DataFrame produced by ``StatsReporter.dataframe``.
    """

    os.makedirs(output_dir, exist_ok=True)

    # Load interpolation mapping
    interp_map = None
    if interp_map_path and os.path.exists(interp_map_path):
        try:
            with open(interp_map_path, 'rb') as f:
                # NOTE(security): unpickling can execute arbitrary code.
                # Only point interp_map_path at files you created yourself.
                interp_map = pickle.load(f)
        except Exception as e:
            # Best effort: a broken map must not abort the whole analysis.
            print(f"Warning: could not load interpolation map: {e}")
    else:
        print("Interpolation map not found; default interpolation will be used.")

    examples_dir = os.path.join(output_dir, "LPIPS_Scenarios")
    best_dir = os.path.join(examples_dir, "best_scenarios")
    worst_dir = os.path.join(examples_dir, "worst_scenarios")

    for d in (best_dir, worst_dir):
        os.makedirs(d, exist_ok=True)

    # Collect metrics + global visualization data
    rows, global_data = MetricsAggregator.collect(
        lr_dir,
        hr_dir,
        glcm_multi_angle=glcm_multi_angle,
        glcm_levels=glcm_levels,
        interp_map=interp_map,
    )
    df = StatsReporter.dataframe(rows)

    # Global data visualizations
    ImageDataVisualization.create_global_advanced_visualizations(
        global_data,
        os.path.join(output_dir, "advanced_global_panel.png"),
    )
    ImageDataVisualization.basic_distributions(df, output_dir)
    ImageDataVisualization.artifact_color_histograms(df, output_dir)
    ImageDataVisualization.artifact_boxplots(df, output_dir)
    ImageDataVisualization.channel_shape_bars(df, output_dir)
    ImageDataVisualization.correlation_matrix(df, output_dir)
    ImageDataVisualization.scatter_relations(df, output_dir)

    # Rank by LPIPS (ascending = perceptually closest first). Rows without an
    # LPIPS value cannot be ranked; sort_values would otherwise place them
    # last and they would be reported as the "worst" examples.
    ranked = df.dropna(subset=["lpips"]).sort_values("lpips")
    # head()/tail() with a negative count return "all but n" rows, so clamp.
    top_k = max(top_k_examples, 0)
    selections = [
        (ranked.head(top_k), best_dir, "best"),
        # Reverse so that rank 1 is the single worst pair, mirroring "best".
        (ranked.tail(top_k).iloc[::-1], worst_dir, "worst"),
    ]

    for subset, scenario_dir, label in selections:
        ranked_pairs = zip(subset["filename"].tolist(), subset["lpips"].to_numpy())
        for rank, (rel_path, lp_val) in enumerate(ranked_pairs, 1):
            lr_full = os.path.join(lr_dir, rel_path)
            hr_full = os.path.join(hr_dir, rel_path)

            lr_img, hr_img = ImagePairLoader.load_and_align(
                lr_full,
                hr_full,
                interp_map=interp_map
            )

            # Normalize relative path and split
            rel_norm = rel_path.replace("\\", "/")
            rel_no_ext = os.path.splitext(rel_norm)[0]
            parent = os.path.dirname(rel_no_ext)
            stem = os.path.basename(rel_no_ext)

            # Create output directory preserving the parent structure
            save_dir = os.path.join(scenario_dir, parent) if parent else scenario_dir
            os.makedirs(save_dir, exist_ok=True)

            # Build output filenames without duplicating extensions
            out_basic = os.path.join(save_dir, f"{label}_{rank}_{stem}.png")
            ImageDataVisualization.save_visual_example(
                lr_img, hr_img, out_basic, lp_val
            )

            out_adv = os.path.join(save_dir, f"{label}_{rank}_advanced_{stem}.png")
            ImageDataVisualization.create_advanced_visualizations(
                lr_img, hr_img, out_adv
            )

    return df

In [8]:
# Input locations and output folder for this EDA run (relative to the
# working directory).
lr_dir = "images/LR"
hr_dir = "images/HR"
output_dir = "eda_results"
interp_map_path = "images/interpolation_map.pkl"

# Kept as the final bare expression so the returned DataFrame remains the
# value of the cell.
run_eda_pipeline(
    lr_dir=lr_dir,
    hr_dir=hr_dir,
    output_dir=output_dir,
    glcm_multi_angle=True,
    glcm_levels=256,
    interp_map_path=interp_map_path,
)



Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: c:\Users\bgmanuel\anaconda3\envs\py310\lib\site-packages\lpips\weights\v0.1\alex.pth


Computing metrics: 100%|██████████| 313/313 [06:12<00:00,  1.19s/img]
  box = plt.boxplot(data, labels=labels, patch_artist=True)


Unnamed: 0,filename,lpips,psnr,ssim,glcm_contrast,glcm_homogeneity,glcm_correlation,rms_noise_lr,rms_noise_hr,lap_var_lr,...,ch1_skew_lr,ch1_skew_hr,ch2_skew_lr,ch2_skew_hr,ch0_kurt_lr,ch0_kurt_hr,ch1_kurt_lr,ch1_kurt_hr,ch2_kurt_lr,ch2_kurt_hr
0,high_z_offset/high_z_offset0.png,0.364281,25.366058,0.803856,137.410715,0.149594,0.934708,2.499807,7.405629,108.620061,...,0.036239,-0.030008,-0.056324,-0.109049,-0.064233,-0.078027,-0.101885,-0.142501,-0.115562,-0.141359
1,high_z_offset/high_z_offset1.png,0.251316,27.754356,0.847467,135.974401,0.161373,0.941820,2.532183,5.740295,107.357917,...,-0.242923,-0.255645,-0.419483,-0.414943,0.167955,0.057892,0.143544,-0.011911,0.158280,-0.028525
2,high_z_offset/high_z_offset10.png,0.349757,24.708093,0.725036,225.869160,0.430218,0.887708,6.394023,6.363694,751.291854,...,-0.324179,-0.352662,-0.423683,-0.460285,0.299866,0.121729,0.419595,0.147142,0.493415,0.164363
3,high_z_offset/high_z_offset100.png,0.317964,26.041058,0.748882,82.977939,0.161440,0.944976,2.002901,6.041940,69.841931,...,-0.079562,-0.209162,-0.347382,-0.432156,0.391550,0.152999,0.459532,0.143770,0.626193,0.207427
4,high_z_offset/high_z_offset101.png,0.415194,26.146266,0.759185,89.467270,0.160775,0.956186,2.164097,6.394686,82.893296,...,-0.129830,-0.109321,-0.380877,-0.322283,0.179148,0.221096,0.304085,0.258230,0.500662,0.378809
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
308,low_z_offset/low_z_offset95.png,0.398480,27.146761,0.767963,130.535761,0.470106,0.883706,4.739694,4.937069,378.520439,...,-0.708274,-0.753500,-0.787584,-0.841872,0.191959,0.171960,0.684935,0.758831,0.903825,1.084455
309,low_z_offset/low_z_offset96.png,0.230179,31.569480,0.907952,49.799471,0.249285,0.957178,1.312079,4.030325,29.217389,...,-0.760076,-0.784344,-0.937650,-0.946355,-0.350854,-0.279414,0.358212,0.445815,0.855934,0.951866
310,low_z_offset/low_z_offset97.png,0.191707,35.451120,0.942800,50.536258,0.507053,0.962478,2.999149,1.779312,157.944040,...,-0.401852,-0.401222,-0.527594,-0.529595,-0.490090,-0.487183,-0.579616,-0.561587,-0.440199,-0.405477
311,low_z_offset/low_z_offset98.png,0.305755,29.097566,0.824472,85.593845,0.479782,0.905295,3.869637,3.417618,258.198508,...,-0.663705,-0.816551,-0.740988,-0.953585,0.030360,0.056749,0.276913,0.613602,0.294812,0.875763
