# Self-DACE Results on Selected Images

Name: Yunpei Gu (Team: Richard Zhao, Oliver Fritsche, Yunpei Gu)

Class: CS 7180 Advanced Perception

Date: 2025-09-22

Purpose: Visualize and evaluate the Self-DACE low-light enhancement model on images with high contrast, multiple illuminant sources, or a single non-white illuminant.

Image selection:
- Pic 1 & 2 — High-contrast scenes.
- Pic 3 & 4 — Multiple illuminant sources.
- Pic 5 & 6 — Single non-white illuminant.

In [11]:
# === Config & model load ===
import os, glob, re
import numpy as np
from PIL import Image
import torch

# Use a non-interactive backend when running without GUI; comment it out if you want interactive figures locally.
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

from model import light_net
from evaluate import calculate_metrics  # Use the project's metric functions

# Input folders: both are expected to hold numerically named files (1.jpg, 2.png, ...)
LOW_DIR  = "data/random/low_num"
HIGH_DIR = "data/random/high_num"   # Ground truth is optional, so this folder may be empty
MODEL_PTH = "epoch_118_model.pth"   # Checkpoint is expected at the repo root
RESIZE = (512, 512)                 # Default size every loaded image is resized to
ALPHA_SCALE = 0.9                   # Forwarded to the model as alpha_scale

# Prefer CUDA when torch reports it as available, otherwise stay on the CPU
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print("Device:", DEVICE)

# Build the network on DEVICE, restore the checkpoint, and call eval() before use
model = light_net().to(DEVICE)
ckpt = torch.load(MODEL_PTH, map_location=DEVICE)
model.load_state_dict(ckpt)
model.eval()
print("Model loaded from", MODEL_PTH)

# Create every output folder up front so later savefig() calls cannot fail on a missing path
for out_dir in (
    "outputs/random/triptychs_num",
    "outputs/random/iterations",
    "outputs/random/composites",
):
    os.makedirs(out_dir, exist_ok=True)

Device: cpu
Model loaded from epoch_118_model.pth


In [12]:
# === Helper functions (no torchvision dependency) ===
# Extensions probed, in priority order, when resolving a numeric index to a file
EXTS = [".jpg", ".png", ".jpeg", ".JPG", ".PNG", ".JPEG"]

def find_file_by_index(folder, idx):
    """Resolve a numeric index to an existing image path inside `folder`.

    Args:
        folder: Directory to look in.
        idx: File stem (e.g. 3 for "3.png"); it is formatted with an f-string.

    Returns:
        The path `<folder>/<idx><ext>` for the first extension in EXTS whose
        file exists, or None when no candidate exists.
    """
    candidates = (os.path.join(folder, f"{idx}{ext}") for ext in EXTS)
    return next((path for path in candidates if os.path.exists(path)), None)

def load_image_as_tensor(path, size=RESIZE, device=DEVICE):
    """Load an image file as a resized RGB PIL image plus a batched float tensor.

    Args:
        path: Image file to open.
        size: Size passed to PIL's resize(); defaults to RESIZE.
        device: Torch device the tensor is moved to; defaults to DEVICE.

    Returns:
        (pil, ten): `pil` is the resized RGB PIL image; `ten` is a float32
        tensor of the same pixels divided by 255, laid out channel-first with
        a leading batch dimension of 1 (shape 1 x 3 x H x W).
    """
    # Open inside a context manager so the file handle is released right away;
    # convert() and resize() return new images that do not depend on the open file.
    with Image.open(path) as src:
        pil = src.convert("RGB").resize(size)
    arr = np.array(pil).astype(np.float32) / 255.0
    # HWC -> CHW, then add the batch dimension
    ten = torch.from_numpy(arr).permute(2,0,1).unsqueeze(0).to(device)
    return pil, ten

def tensor_to_uint8_img(t):
    """Convert a channel-first image tensor to an H x W x C uint8 array.

    Args:
        t: Tensor shaped C x H x W, or 1 x C x H x W (the leading batch
           dimension is squeezed away). Values are clamped to [0, 1].

    Returns:
        A numpy uint8 array shaped H x W x C with values in 0..255.
    """
    if t.ndim == 4:
        t = t.squeeze(0)
    t = t.detach().clamp(0,1).cpu().permute(1,2,0).numpy()
    # Round to the nearest level instead of truncating: a plain astype() floors,
    # which darkens the output slightly and can turn k/255 back into k-1
    # because of float32 error.
    return np.rint(t * 255).astype(np.uint8)

def collect_indices(folder):
    """List the numeric file stems found in `folder`.

    Globs `folder` once per extension in EXTS and keeps the files whose name
    is digits followed by a dot and a purely alphabetic extension.

    Returns:
        The stems converted to int, de-duplicated and sorted ascending.
    """
    numeric_name = re.compile(r"(\d+)\.[A-Za-z]+$")
    found = set()
    for ext in EXTS:
        hits = glob.glob(os.path.join(folder, f"*{ext}"))
        matches = (numeric_name.match(os.path.basename(hit)) for hit in hits)
        found.update(int(m.group(1)) for m in matches if m)
    return sorted(found)


In [13]:
# === Batch enhancement + optional metrics + save triptychs ===
# For each numeric image in LOW_DIR: run the model, score the result against
# the matching HIGH_DIR image when one exists, and save a side-by-side figure.
indices = collect_indices(LOW_DIR)
print(f"Found {len(indices)} numeric images in {LOW_DIR}: {indices[:10]}{' ...' if len(indices)>10 else ''}")

# Running sums; divided by n_metric after the loop to report averages
avg_psnr = 0.0
avg_ssim = 0.0
n_metric = 0

# Per-image results are kept so the composite cell can plot without re-running the model
_cache = {}  # idx -> dict(low_pil, enhanced_np, gt_pil or None, metrics or None)

for idx in indices:
    low_path  = find_file_by_index(LOW_DIR,  idx)
    high_path = find_file_by_index(HIGH_DIR, idx)  # None when there is no ground truth

    if not low_path:
        continue

    low_pil, low_t = load_image_as_tensor(low_path)

    with torch.no_grad():
        enhanced_t, alphas, betas = model(
            low_t,
            output_intermediate_images=False,
            alpha_scale=ALPHA_SCALE,
        )

    # Metrics are only computed when a ground-truth image is available
    gt_pil = None
    metrics = None
    if not high_path:
        print(f"[{idx}] (no GT) visualized only")
    else:
        gt_pil, gt_t = load_image_as_tensor(high_path)
        metrics = calculate_metrics(enhanced_t, gt_t)
        avg_psnr += metrics["psnr"]
        avg_ssim += metrics["ssim"]
        n_metric += 1
        print(f"[{idx}] PSNR: {metrics['psnr']:.2f} dB | SSIM: {metrics['ssim']:.4f}")

    # Two panels without ground truth (diptych), three with it (triptych)
    enh_np = tensor_to_uint8_img(enhanced_t)
    panels = [
        (low_pil, f"{idx} - Original"),
        (enh_np,  f"{idx} - Self-DACE"),
    ]
    if gt_pil is not None:
        panels.append((gt_pil, f"{idx} - Ground Truth"))
    cols = len(panels)

    fig, axes = plt.subplots(1, cols, figsize=(4*cols, 4))
    axes = np.atleast_1d(axes)
    for ax, (img, title) in zip(axes, panels):
        ax.imshow(img)
        ax.set_title(title)
        ax.axis("off")
    fig.tight_layout()
    out_path = os.path.join("outputs", "random", "triptychs_num", f"{idx}_triptych.png")
    fig.savefig(out_path, dpi=220)
    plt.close(fig)

    _cache[idx] = dict(low_pil=low_pil, enhanced_np=enh_np, gt_pil=gt_pil, metrics=metrics)

if not n_metric:
    print("\nNo GT found in high_num; only visualizations were saved.")
else:
    print(f"\nAverage over {n_metric} images — PSNR: {avg_psnr/n_metric:.2f} dB | SSIM: {avg_ssim/n_metric:.4f}")


Found 7 numeric images in data/random/low_num: [1, 2, 3, 4, 5, 6, 7]
[1] PSNR: 20.97 dB | SSIM: 0.9203
[2] PSNR: 22.67 dB | SSIM: 0.9511
[3] PSNR: 22.50 dB | SSIM: 0.9585
[4] PSNR: 21.85 dB | SSIM: 0.9256
[5] PSNR: 23.44 dB | SSIM: 0.9388
[6] PSNR: 21.60 dB | SSIM: 0.9238
[7] (no GT) visualized only

Average over 6 images — PSNR: 22.17 dB | SSIM: 0.9363


In [None]:
# === Iteration visualization (Original → Iter1..7 → Final) ===
# List the indices to visualize; an empty list skips the work and prints a hint.
target_indices = [1,2,3,4,5,6,7]  # e.g., [1, 2, 3]
if target_indices:
    for idx in target_indices:
        low_path = find_file_by_index(LOW_DIR, idx)
        if not low_path:
            print(f"[skip] {idx} not found in {LOW_DIR}")
            continue

        low_pil, low_t = load_image_as_tensor(low_path)

        # With output_intermediate_images=True the call returns a fourth value:
        # the list of intermediate images
        with torch.no_grad():
            enhanced_final, alphas, betas, inter_list = model(
                low_t,
                output_intermediate_images=True,
                alpha_scale=ALPHA_SCALE,
            )

        # One tile per stage: the input, every intermediate image, then the final output
        tiles = [
            np.array(low_pil),
            *(tensor_to_uint8_img(step_t) for step_t in inter_list),
            tensor_to_uint8_img(enhanced_final),
        ]
        titles = ["Original", *(f"Iter {i+1}" for i in range(len(inter_list))), "Final"]
        n = len(tiles)

        fig, axes = plt.subplots(1, n, figsize=(3*n, 3))
        for col, ax in enumerate(axes):
            ax.imshow(tiles[col])
            ax.set_title(titles[col])
            ax.axis("off")
        fig.tight_layout()
        out_path = os.path.join("outputs", "random", "iterations", f"{idx}_iters.png")
        fig.savefig(out_path, dpi=220)
        plt.close(fig)
        print(f"[saved] {out_path}")
else:
    print("No target indices set for iteration visualization. Set target_indices = [1,2,3] to enable.")


[saved] outputs/random/iterations/1_iters.png
[saved] outputs/random/iterations/2_iters.png
[saved] outputs/random/iterations/3_iters.png
[saved] outputs/random/iterations/4_iters.png
[saved] outputs/random/iterations/5_iters.png
[saved] outputs/random/iterations/6_iters.png
[saved] outputs/random/iterations/7_iters.png
[skip] 8 not found in data/random/low_num
[skip] 9 not found in data/random/low_num
[skip] 10 not found in data/random/low_num
[skip] 11 not found in data/random/low_num
[skip] 12 not found in data/random/low_num


In [15]:
# === Final composite figure: Original vs Self-DACE vs Ground Truth ===
# Choose which indices to include. If empty, we auto-pick up to the first 6 indices that have GT.
composite_indices = []  # e.g., [1,2,3]; leave empty to auto-pick

# Reuse the cache built by the batch cell when it exists; otherwise start with
# an empty one so this cell can run on its own. Checking globals() avoids the
# previous `_ = _cache` probe, which overwrote IPython's `_` (last output).
if "_cache" not in globals():
    _cache = {}


def ensure_cache(indices):
    """Enhance and cache every index in `indices` that is not in `_cache` yet.

    For each missing index the low-light image is loaded, passed through the
    model, and stored in `_cache` together with its ground-truth image when
    one exists in HIGH_DIR. Metrics are not computed here (stored as None).
    Indices with no file in LOW_DIR are reported and left out of the cache.
    """
    for idx in indices:
        if idx in _cache:
            continue
        low_path = find_file_by_index(LOW_DIR, idx)
        if not low_path:
            # Report the miss instead of skipping silently
            print(f"[skip] {idx} not found in {LOW_DIR}")
            continue
        high_path = find_file_by_index(HIGH_DIR, idx)

        low_pil, low_t = load_image_as_tensor(low_path)
        with torch.no_grad():
            enhanced_t, alphas, betas = model(
                low_t,
                output_intermediate_images=False,
                alpha_scale=ALPHA_SCALE
            )
        enh_np = tensor_to_uint8_img(enhanced_t)

        # Only the PIL image is needed for plotting; the tensor is discarded
        gt_pil = load_image_as_tensor(high_path)[0] if high_path else None

        _cache[idx] = dict(low_pil=low_pil, enhanced_np=enh_np, gt_pil=gt_pil, metrics=None)


if not composite_indices:
    # Auto-pick indices that have GT
    with_gt = sorted(idx for idx, pack in _cache.items() if pack.get("gt_pil") is not None)
    if not with_gt:
        print("No GT images available for composite; please set composite_indices to any indices you want to visualize.")
    else:
        composite_indices = with_gt[:6]
        print("Auto-picked indices with GT:", composite_indices)
else:
    ensure_cache(composite_indices)
    # Keep only indices that could actually be loaded, so the figure has no
    # blank rows (an untouched subplot would still draw its ticks and frame).
    composite_indices = [idx for idx in composite_indices if idx in _cache]
    if not composite_indices:
        print("None of the requested composite_indices could be loaded; no composite figure was produced.")

if composite_indices:
    rows = len(composite_indices)
    # squeeze=False keeps `axes` two-dimensional even when there is a single row
    fig, axes = plt.subplots(rows, 3, figsize=(15, 4*rows), squeeze=False)

    for i, idx in enumerate(composite_indices):
        pack = _cache[idx]
        low_pil = pack["low_pil"]
        enh_np  = pack["enhanced_np"]
        gt_pil  = pack["gt_pil"]

        axes[i, 0].imshow(low_pil)
        axes[i, 0].set_title(f"{idx} - Original")
        axes[i, 1].imshow(enh_np)
        axes[i, 1].set_title(f"{idx} - Self-DACE")
        if gt_pil is not None:
            axes[i, 2].imshow(gt_pil)
            axes[i, 2].set_title(f"{idx} - Ground Truth")
        else:
            axes[i, 2].set_title(f"{idx} - (No GT)")
        for ax in axes[i]:
            ax.axis("off")

    fig.tight_layout()
    out_path = os.path.join("outputs", "random", "composites", "comparison_original_selfdace_gt.png")
    fig.savefig(out_path, dpi=220)
    # The config cell forces the non-interactive Agg backend, under which
    # plt.show() cannot display anything and only emits a warning. Show the
    # figure only when another backend is active (i.e. that line was commented out).
    if matplotlib.get_backend().lower() != "agg":
        plt.show()
    plt.close(fig)  # release the figure either way
    print(f"Saved composite figure to: {out_path}")


Auto-picked indices with GT: [1, 2, 3, 4, 5, 6]
Saved composite figure to: outputs/random/composites/comparison_original_selfdace_gt.png


  plt.show()
