# 🌗 Dynamic Range Alignment

This notebook implements dynamic range alignment for each pair of digital and film images. The goal is to line up the grey levels (luminance) of each pair of images so that the model does not need to learn to compensate for differences in luminance.


## Setup

---

Let's import the necessary dependencies.

In [None]:
import autorootcwd

In [None]:
import cv2 as cv
import numpy as np
from matplotlib import pyplot as plt
from skimage.exposure import cumulative_distribution

In [None]:
from typing import List, Tuple

## Luminance Alignment (Dummy Images)

---

An explanation of how histogram matching is done for images can be found on [this page by Paul Bourke](https://paulbourke.net/miscellaneous/equalisation/)

Implemented based on code from [this StackOverflow post](https://stackoverflow.com/questions/32655686/histogram-matching-of-two-images-in-python-2-x)
and the [scikit-image tutorials](https://scikit-image.org/docs/stable/auto_examples/color_exposure/plot_histogram_matching.html).

In [None]:
# Load image example.
# cv.imread returns None (rather than raising) when the file cannot be read,
# so fail loudly here instead of with an obscure cvtColor error below.
template_path = "imgs/cinestill-800t.jpg"
template = cv.imread(template_path)
if template is None:
    raise FileNotFoundError(f"Could not read image: {template_path}")
template = cv.cvtColor(template, cv.COLOR_BGR2RGB)  # OpenCV loads BGR; convert to RGB

# Simulate a darker exposure by halving every pixel value (integer division)
source = template // 2

# Display images
fig, axs = plt.subplots(ncols=2, figsize=(15, 10))
axs[0].imshow(template)
axs[1].imshow(source)
axs[0].set_title("Template")
axs[1].set_title("Source");

Nice, we have two images: a template image, and a darker source image whose luminance we want to match to the template. Let's plot the cumulative distribution functions (CDFs) of the channels of both images to see the change in luminance.

In [None]:
def cdf(channel: np.ndarray):
    """
    Computes the CDF of an image channel, extended to the full 0-255 range.

    ``cumulative_distribution`` only reports CDF values between the first and
    last bin centre it finds in the channel. The result is therefore padded
    with 0s for the levels below the first bin centre and 1s for the levels
    above the last one. This assumes an 8-bit integer channel, so that bin
    centres are consecutive integer levels and the output has 256 entries
    (TODO confirm for other dtypes).

    Args:
        channel (np.ndarray): An image channel

    Returns:
        np.ndarray: The padded CDF of the image channel
    """
    values, centres = cumulative_distribution(channel)

    # Number of intensity levels outside the range covered by the bin centres
    below, above = centres[0], 255 - centres[-1]

    return np.concatenate([[0] * below, values, [1] * above])

In [None]:
# CDFs of the RGB channels (one curve per channel, template vs. source)
fig, axs = plt.subplots(ncols=2, figsize=(15, 5))
fig.suptitle("CDF of RGB channels")
for i, channel in enumerate("RGB"):
    axs[0].plot(cdf(template[:, :, i]), label=f"Channel {channel}")
    axs[1].plot(cdf(source[:, :, i]), label=f"Channel {channel}")
for ax, title in zip(axs, ("Template", "Source")):
    ax.set_title(title)
    ax.set_xlabel("Pixel value")
    ax.set_ylabel("Cumulative fraction of pixels")
    ax.legend()

In [None]:
# CDFs of the LAB channels (one curve per channel, template vs. source).
# template_lab / source_lab are reused by the histogram-matching cell below.
template_lab = cv.cvtColor(template, cv.COLOR_RGB2LAB)
source_lab = cv.cvtColor(source, cv.COLOR_RGB2LAB)

fig, axs = plt.subplots(ncols=2, figsize=(15, 5))
fig.suptitle("CDF of LAB channels")
for i, channel in enumerate("LAB"):
    axs[0].plot(cdf(template_lab[:, :, i]), label=f"Channel {channel}")
    axs[1].plot(cdf(source_lab[:, :, i]), label=f"Channel {channel}")
for ax, title in zip(axs, ("Template", "Source")):
    ax.set_title(title)
    ax.set_xlabel("Channel value (8-bit LAB)")
    ax.set_ylabel("Cumulative fraction of pixels")
    ax.legend()

Nice, we can clearly see that the source image is darker, as its CDF is shifted to the left. Now, let's implement the histogram matching algorithm.

In [None]:
def histogram_matching(
    template_cdf: np.ndarray, source_cdf: np.ndarray, channel: np.ndarray
) -> np.ndarray:
    """
    Matches the histogram of a channel to the histogram of another channel.

    For every intensity level ``k``, the source CDF value ``source_cdf[k]`` is
    looked up on the template CDF (inverted by linear interpolation). This
    gives the template level with the same cumulative frequency,
    ``new(k) = H_t^-1(H_s(k))``. The resulting 256-entry lookup table is then
    applied to ``channel``.

    Args:
        template_cdf (np.ndarray): The 256-entry CDF of the template image
        source_cdf (np.ndarray): The 256-entry CDF of the source image
        channel (np.ndarray): The channel to match (of source image). It must
            hold integer levels in [0, 255], because its values are used as
            indices into the lookup table.

    Returns:
        np.ndarray: The channel with the matched histogram (uint8, same shape
            as ``channel``)
    """
    pixels = np.arange(256)
    # Invert the template CDF: for each source CDF value c_s, find the level p
    # with H_t(p) = c_s. NOTE: np.interp assumes template_cdf is increasing;
    # flat CDF segments (unused levels) make the inverse ambiguous there.
    lut = np.interp(source_cdf, template_cdf, pixels)
    # Round to the nearest level. A bare astype() truncates, which biases the
    # matched channel darker by up to one level. interp output is already
    # bounded by pixels[0]..pixels[-1], so no clipping is needed.
    lut = np.rint(lut).astype(np.uint8)

    # Apply the lookup table; fancy indexing keeps the channel's shape.
    return lut[channel]

In [None]:
# Perform histogram matching on the L (lightness) channel of the LAB images
source_cdf = cdf(source_lab[:, :, 0])
template_cdf = cdf(template_lab[:, :, 0])

# Match the histograms. The mapping is built from the L-channel CDFs, so it
# must be applied to the source's L channel (not to an RGB channel).
matched_source_l = histogram_matching(template_cdf, source_cdf, source_lab[:, :, 0])

# Display the results
fig, axs = plt.subplots(ncols=2, figsize=(15, 5))
fig.suptitle("Histogram matching")
axs[0].plot(source_cdf, label="Source L-CDF")
axs[0].plot(template_cdf, label="Template L-CDF")
axs[1].plot(cdf(matched_source_l), label="Matched L-CDF")
axs[1].plot(template_cdf, label="Template L-CDF")
axs[0].set_title("Before matching")
axs[1].set_title("After matching")
axs[0].legend()
axs[1].legend();

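Nice, the histogram matching seems to work well. Let's now wrap this into a function that matches the L channel of the source's LAB representation to the template's L channel, keeps the source's A and B channels, and converts the result back to RGB.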

In [None]:
def luminance_align(template: np.ndarray, source: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """
    Matches the luminance (LAB L channel) of the source image to the template.

    Only the source's L channel is remapped via histogram matching; its A and
    B channels are merged back unchanged. The template is converted
    RGB -> LAB -> RGB before being returned, but is otherwise not modified.
    NOTE(review): the 8-bit LAB round trip is presumably lossy; confirm
    whether returning the original template would be preferable.

    Args:
        template (np.ndarray): The template image (RGB) whose luminance
            distribution is the target
        source (np.ndarray): The source image (RGB) whose luminance is adjusted

    Returns:
        Tuple[np.ndarray, np.ndarray]: ``(template, source)``, in that order:
            the round-tripped template, then the luminance-matched source
    """
    # RGB -> LAB for both images
    template_lab = cv.cvtColor(template, cv.COLOR_RGB2LAB)
    source_lab = cv.cvtColor(source, cv.COLOR_RGB2LAB)

    # Only the lightness channel of the template is needed
    template_lum = cv.split(template_lab)[0]
    source_lum, source_a, source_b = cv.split(source_lab)

    # Remap source lightness so its CDF follows the template's
    matched_lum = histogram_matching(cdf(template_lum), cdf(source_lum), source_lum)
    matched_lab = cv.merge((matched_lum, source_a, source_b))

    # LAB -> RGB for both outputs
    return (
        cv.cvtColor(template_lab, cv.COLOR_LAB2RGB),
        cv.cvtColor(matched_lab, cv.COLOR_LAB2RGB),
    )

In [None]:
# Match the luminance of the darkened source to the template
matched_template, matched_source = luminance_align(template, source)

# Top row: inputs; bottom row: outputs of luminance_align
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(15, 10))
panels = [
    ((0, 0), template, "Original Template"),
    ((0, 1), source, "Original Source"),
    ((1, 0), matched_template, "Matched Template"),
    ((1, 1), matched_source, "Matched Source"),
]
for (row, col), image, title in panels:
    axs[row, col].imshow(image)
    axs[row, col].set_title(title)

## Luminance Alignment (Actual Data)

---

### Example Image

In [None]:
from src.utils.load import load_image_pair

# Load example pair 13 as arrays (note the return order: film first, then digital).
# `meta` is presumably the pair's metadata; it is not used below.
film, digital, meta = load_image_pair(13, processing_state="raw", as_array=True)

print(f"Digital: {digital.shape}, Film: {film.shape}")
_, axs = plt.subplots(ncols=2, figsize=(15, 10))
for ax, image, title in zip(axs, (digital, film), ("Digital", "Film")):
    ax.imshow(image)
    ax.set_title(title)

In [None]:
# Display images
from src.utils.preprocess import luminance_align

fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(15, 10))
axs[0, 0].imshow(digital)
axs[0, 1].imshow(film)
axs[0, 0].set_title("Digital (Template)")
axs[0, 1].set_title("Film (Source)")

# Match the luminance
matched_digital, matched_film = luminance_align(digital, film)
axs[1, 0].imshow(matched_digital)
axs[1, 1].imshow(matched_film)
axs[1, 0].set_title("Matched Digital (Template)")
axs[1, 1].set_title("Matched Film (Source)");

### All Data

In [None]:
from src.utils.load import load_metadata

# Load dataset metadata; its keys are passed to load_image_pair as pair indices
meta = load_metadata()
image_indices = list(meta.keys())

n_images = len(meta)
print(f"There are {n_images} images in the dataset")

In [None]:
# Align luminance of all images: film (source) is matched to digital (template)
for idx in image_indices:
    # Load image pair (note the return order: film first, then digital)
    film, digital, _ = load_image_pair(idx, processing_state="raw", as_array=True)

    # Keep the raw arrays intact; store the aligned results under new names
    matched_digital, matched_film = luminance_align(template=digital, source=film)

    fig, axs = plt.subplots(ncols=2, figsize=(15, 5))
    fig.suptitle(f"Image Pair {idx} (film matched to digital)", fontsize=16)
    axs[0].imshow(matched_digital)
    axs[1].imshow(matched_film)
    axs[0].set_title("Digital Image (Template)")
    axs[1].set_title("Matched Film Image")

    plt.show()
    plt.close(fig)  # one figure per pair; release it once shown

In [None]:
# Align luminance of all images in the opposite direction:
# digital (source) is matched to film (template)
for idx in image_indices:
    # Load image pair (note the return order: film first, then digital)
    film, digital, _ = load_image_pair(idx, processing_state="raw", as_array=True)

    # Keep the raw arrays intact; store the aligned results under new names
    matched_film, matched_digital = luminance_align(template=film, source=digital)

    fig, axs = plt.subplots(ncols=2, figsize=(15, 5))
    fig.suptitle(f"Image Pair {idx} (digital matched to film)", fontsize=16)
    axs[0].imshow(matched_digital)
    axs[1].imshow(matched_film)
    axs[0].set_title("Matched Digital Image")
    axs[1].set_title("Film Image (Template)")

    plt.show()
    plt.close(fig)  # one figure per pair; release it once shown