## Chromaticity Spaces and Illuminate Spectral Ratio  
This notebook implements various aspects of Maxwell's RoadVision Paper

___ 
### Packages

In [None]:
import cv2  
import pandas as pd  
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
from pathlib import Path

___
### Set Up Paths

In [None]:
HOME = Path.cwd()
PATH_TO_DATA_FOLDER = HOME / "data" / "folder_8" / "done"
PATH_TO_ANNOTATIONS_CSV =  HOME / "data" / "folder_8"
img_file =  "wenqing_fan_078.tif" # "wenqing_fan_078.tif" #"wenqing_fan_0101.tif" # "walsh_john_035.tif" 
annotations_csv_file = "annotation_folder_8.csv"
epsilon = 1e-10 # to be used to avoid dividing by zero
equalize_hist = False # This will apply histogram equalization to log chroma imgs
normalize_img = False

### Helper Functions

In [None]:
def display_image(chromaticity_image:np.array, title:str) -> None:
    """ 
    Displays an image using mayplotlib
    """
    plt.imshow(chromaticity_image)
    plt.title(title)
    plt.axis('off')
    plt.show();

def histogram_equalize(img: np.array) -> np.array:
    """
    Apply histogram equalization to an image.
    
    Parameters:
        img (np.array): Input image.
        
    Returns:
        np.array: Hist-equalized img.
    """
    # Extract the r and g channels
    r = img[:, :, 0]
    g = img[:, :, 1]
    b = img[:, :, 2]

    # Rescale log_r and log_g to the range [0, 255] for equalization
    r_rescaled = ((r - r.min()) / (r.max() - r.min()) * 255).astype(np.uint8)
    g_rescaled = ((g - g.min()) / (g.max() - g.min()) * 255).astype(np.uint8)
    b_rescaled = ((b - b.min()) / (b.max() - b.min()) * 255).astype(np.uint8)

    # Apply histogram equalization
    r_equalized = cv2.equalizeHist(r_rescaled)
    g_equalized = cv2.equalizeHist(g_rescaled)
    b_equalized = cv2.equalizeHist(b_rescaled)

    # Normalize equalized values back to [0, 1]
    r = r_equalized / 255.0
    g = g_equalized / 255.0
    b = b_equalized / 255.0
    
    img_equalized = np.stack((r, g, b), axis = 2)
    return img_equalized

def normalize_image(img: np.array) -> np.array:
    """
    Normalize the image to have pixel values in the range [0, 1].
    
    Parameters:
        img (np.array): Input image array.
        
    Returns:
        np.array: Normalized image array with values in the range [0, 1].
    """
    # Calculate min and max pixel values
    img_min = np.min(img)
    img_max = np.max(img)
    
    # Avoid division by zero
    if img_max - img_min == 0:
        return np.zeros_like(img)

    # Normalize to 0-1 range
    normalized_img = (img - img_min) / (img_max - img_min)
    
    return normalized_img

___ 
### Read In Test Image

In [None]:
img = cv2.imread(str(PATH_TO_DATA_FOLDER) + "/" + img_file)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
if normalize_img:
    img = img / 255.0

In [None]:
display_image(img, "Original Image")

In [None]:
img.dtype

___
### Convert ```img``` to Standard Chromaticity Space

$$ (\hat{r}, \hat{g}) = \bigg(\frac{R}{(R + G + B)}, \frac{G}{(R + G + B)}\bigg) $$

In [None]:
def convert_img_to_rg_chromaticity(img:np.array, epsilon:float = 1e-10) -> np.array:
    """  
    Generates rg chromaticity image from a RGB image
    """
    # Compute the sum of pixels
    rbg_sum = img.sum(axis = 2, keepdims = True)

    # Divide each channel by pixel sum
    standard_chromaticity = img / (rbg_sum + epsilon)

    # Extract the r and g channels
    r_channel = standard_chromaticity[:, :, 0]
    g_channel = standard_chromaticity[:, :, 1]

    # Add a blue channel of zeros
    chromaticity_image = np.stack((r_channel, g_channel, np.zeros_like(r_channel)), axis = 2)
    
    return chromaticity_image

chromaticity_image = convert_img_to_rg_chromaticity(img)
display_image(chromaticity_image, "Standard Chromaticity")

___
### Convert ```img``` to Log Chromaticity Space

$$ \log(\hat{r}, \hat{g}) = \bigg[\log\bigg(\frac{R}{(R + G + B)}\bigg), \log\bigg(\frac{G}{(R + G + B)}\bigg)\bigg]$$

$$ \log(\hat{r}, \hat{g}) = \bigg[\bigg(\log(R) - \log(R + G + B)\bigg), \bigg(\log(G) - \log(R + G + B)\bigg)\bigg]$$

In [None]:
def convert_img_to_log_rg_chromaticity(img:np.array, epsilon:float = 1e-10, equalize_hist:bool = False) -> np.array:
    """  
    Generates log_rg chromaticity image from a RGB image
    Uses the the mean of the 3 pixel channels.
    """
    # Compute the log of the img
    log_rgb = np.log(img + epsilon) # Epsilon to avoid taking the log of zero

    # Compute log chromaticity of each channel, use mean rather sum to normalize
    # If you change this to exactly replicate the above equation you end up with a black image.
    # If you then histEqual, you end up with essentially the exact same image as produced here using np.mean (as below)
    log_chromaticity = log_rgb - log_rgb.mean(axis = 2, keepdims = True)  # Division becomes subtraction in log space

    # Extract the r and g channels
    log_r = log_chromaticity[:, :, 0]
    log_g = log_chromaticity[:, :, 1]

    if equalize_hist:
        
        # Rescale log_r and log_g to the range [0, 255] for equalization
        log_r_rescaled = ((log_r - log_r.min()) / (log_r.max() - log_r.min()) * 255).astype(np.uint8)
        log_g_rescaled = ((log_g - log_g.min()) / (log_g.max() - log_g.min()) * 255).astype(np.uint8)

        # Apply histogram equalization
        log_r_equalized = cv2.equalizeHist(log_r_rescaled)
        log_g_equalized = cv2.equalizeHist(log_g_rescaled)

        # Normalize equalized values back to [0, 1]
        log_r = log_r_equalized / 255.0
        log_g = log_g_equalized / 255.0

    # Create rg image
    log_chromaticity_image = np.stack((log_r, log_g, np.zeros_like(log_r)), axis = 2)

    return log_chromaticity_image

log_chromaticity_image = convert_img_to_log_rg_chromaticity(img, epsilon, equalize_hist = False)
display_image(log_chromaticity_image, "Log Chromaticity")


___
### Estimate Illuminate Spectral Direction  
This requires lit and shadow pixels

___ 
### Read In CSV of Lit & Shadowed Pixels

In [None]:
annotations_df = pd.read_csv(str(PATH_TO_ANNOTATIONS_CSV) + "/" + annotations_csv_file, index_col = "filename")
annotations_df.head()


### Get Pixel Coordinates  
This will extract the lit and shadow pixel coordinates from the annotations csv

In [None]:
def get_lit_shadow_pixel_coordinates(image_annotations:pd.DataFrame, image_file_name:str) -> tuple:
    """  
    Extracts the lit and shadow pixel locations from the annoations df.
    Note: filename has to be set as the index for annotation df
    """
    # Extract annotations for image of interest
    image_annotations = annotations_df.loc[image_file_name]

    # Extract lit and shadow pixels
    lit_pixels = [
        (image_annotations[f"lit_row{i}"], image_annotations[f"lit_col{i}"])
        for i in range(1, 7)
    ]
    shadow_pixels = [
        (image_annotations[f"shad_row{i}"], image_annotations[f"shad_col{i}"])
        for i in range(1, 7)
    ]
    return lit_pixels, shadow_pixels

lit_pixels, shadow_pixels = get_lit_shadow_pixel_coordinates(annotations_df, img_file)

print(f"Lit pixels: {lit_pixels}")
print(f"Shadow Pixels: {shadow_pixels}")


### Inspect the Annotations

In [None]:
def inspect_annotations(img:np.array, lit_pixels:list, shadow_pixels:list) -> None:
    """  
    Displays the image with lit and shadow pixels highlighted
    """
    fig, ax = plt.subplots(figsize = (7, 12))
    ax.imshow(img)

    # Add circles around lit pixels
    for (y, x) in lit_pixels:
        circle = patches.Circle((x, y), radius=5, edgecolor='yellow', facecolor='none', linewidth=1, label='Lit Pixel')
        ax.add_patch(circle)

    # Add circles arond the shadow pixels
    for (y, x) in shadow_pixels:
        circle = patches.Circle((x, y), radius=5, edgecolor='red', facecolor='none', linewidth=1, label='Shadow Pixel')
        ax.add_patch(circle)

    # Add the legend
    handles = [
        patches.Patch(color='yellow', label='Lit Pixel'),
        patches.Patch(color='red', label='Shadow Pixel')
    ]
    plt.legend(handles = handles)
    plt.axis('off')
    plt.show()

inspect_annotations(img, lit_pixels, shadow_pixels)


In [None]:
print(lit_pixels)
print(shadow_pixels)

___
### Compute the ISD  
This is from 4.1 of Lightbrush:  

4.1 Base Constraints and Linear System Representation
4.1.1 Illumination Spectral Direction. The only required user input
is an estimate of the illumination spectral direction. The ISD
defines the log chromaticity space for the log-chromaticity clustering
step. While automated schemes for detecting the ISD are
possible, we opt to allow simple user control. Lightbrush includes a
lit-dark pair tool that consists of two linked squares (Figure 4(a)).
The user places one square in a lit area of a material and the other
square in a shadowed area of the same material. The system uses
the average value under each block to estimate the ISD. If the user
specifies multiple lit-dark pairs, they can tell Lightbrush to average
the estimated ISDs or use a local adaptation scheme to deal with
changing illumination conditions within the scene.

### Check Lit ~ Shadow Pixel Difference

Ensure that the difference is greater than $0.3$ accorss all channels

In [None]:
# Covert Image to log space
log_rgb = np.log(img + epsilon)

# Get the lit and shadow pixels
# Iterate over the pairs
lit_pixel_values = []
shadow_pixel_values = []
for (lit_pix, shadow_pix) in zip(lit_pixels, shadow_pixels):

    # Get the pixel values
    lit_pixel = log_rgb[lit_pix[0], lit_pix[1]]
    shadow_pixel = log_rgb[shadow_pix[0], shadow_pix[1]]
    lit_pixel_values.append(lit_pixel)
    shadow_pixel_values.append(shadow_pixel)

lit_pixel_values = np.array(lit_pixel_values)
shadow_pixel_values = np.array(shadow_pixel_values)

assert np.all((lit_pixel_values - shadow_pixel_values) > 0.3)
print("Assertion Passed")

### Compute the ISD ~ Per Pair and Mean

In [None]:
# Define Neutral ISD
neutral_isd = np.array([0.577, 0.577, 0.577]) 

# Covert Image to log space
log_rgb = np.log(img + epsilon)

# Compute the isd for each lit-shadow pair
isd_list = []

# Iterate over the pairs
for (lit_pix, shadow_pix) in zip(lit_pixels, shadow_pixels):

    # Get the pixel values
    lit_pixel = log_rgb[lit_pix[0], lit_pix[1]]
    shadow_pixel = log_rgb[shadow_pix[0], shadow_pix[1]]

    # Compute and store the ISD
    isd = lit_pixel - shadow_pixel
    isd = isd / np.linalg.norm(isd)
    isd_list.append(isd)

print(f"Illuminate Spectral Direction for Each Pair:")
for isd in isd_list:
    print(f"\t{isd}")

# Compute the mean ISD, Bruce mentions this in LightBrush
mean_isd = np.mean(np.array(isd_list), axis = 0)
print()
print(f"Mean ISD:\n\t {mean_isd}")
print()
print(f'Dot product of ISD and neutral ISD: \n\t {np.dot(mean_isd, neutral_isd)}')

### Inspect ISD

In [None]:
def plot_isd_vector_3D(mean_isd:np.array) -> None:
    """  
    Displays the mean ISD vector
    """
    # Set up 3D fig
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    # Plot the ISD vector - This is through the origin, idk for sure about this
    ax.quiver(0, 0, 0, mean_isd[0], mean_isd[1], mean_isd[2], color='r', arrow_length_ratio=0.3)

    # Set limits for the axes
    ax.set_xlim([-2, 2])
    ax.set_ylim([-2, 2])
    ax.set_zlim([-2, 2])

    # Set labels
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z');

plot_isd_vector_3D(mean_isd)

### Inspect A Plane Orthogonal to ISD

In [None]:
def plot_plane_orthogonal_to_isd(mean_isd:np.array, img:np.array = None, projection:np.array = None) -> None:
    """
    Displays the plane orthogonal to the ISD vector.
    Can also display projected points
    """
    # Set up the figure and 3D axes
    fig = plt.figure(figsize = (10, 10))
    ax = fig.add_subplot(111, projection = '3d')

    # Plot the vector using mean_isd
    ax.quiver(0, 0, 0, mean_isd[0], mean_isd[1], mean_isd[2], color='r', arrow_length_ratio=0.3, label = "ISD")

    # Create a grid for the plane centered at the origin, again not sure about the origin
    # But it needs to be anchored somewhere, no?
    xx, yy = np.meshgrid(np.linspace(-2, 2, 10), np.linspace(-2, 2, 10))

    # Compute the corresponding z values for the plane
    # Plane equation: mean_isd[0]*x + mean_isd[1]*y + mean_isd[2]*z = 0
    # Solve for z: z = -(mean_isd[0]*x + mean_isd[1]*y) / mean_isd[2]
    zz = -(mean_isd[0] * xx + mean_isd[1] * yy) / mean_isd[2]

    # Plot the plane
    ax.plot_surface(xx, yy, zz, alpha=0.5, color='cyan', label = "Plane Orthogonal to ISD")

    # Add the image points
    if img is not None:

            # Plot the orginal points
            ax.scatter(img[:, :, 0], img[:, :, 1], img[:, :, 2], color='b', alpha=0.1, s=1, label='Original Points')
    
    # Set the limits
    ax.set_xlim([-6, 6])
    ax.set_ylim([-6, 6])
    ax.set_zlim([-6, 6])

    # Set labels
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')

    # Add a legend
    ax.legend()

    # Show the plot
    plt.show()

plot_plane_orthogonal_to_isd(mean_isd)

### Show Image Points In Relation to ISD 

In [None]:
plot_plane_orthogonal_to_isd(mean_isd, img = log_rgb)

### Project Log RGB onto Plane Orthogonal to ISD  
$$\log(chromaticityIMG)_{i,j} = \log(IMG)_{i,j} - \bigg(\frac{\vec{ISD} \cdot \log(IMG)_{i,j}}{\vec{ISD} \cdot \vec{ISD}}\bigg) \cdot \vec{ISD}$$  

This equation projects the pixels to the plane orothogonal to ISD, i.e., 2D chromaticity

In [60]:
# Because we computed the mean, confirm it is still a unit vector
isd = isd / np.linalg.norm(isd)

# Compute the dot product of ISD with ISD
dot_isd_isd = np.dot(isd, isd)

# Chromaticity image container
img_shape = log_rgb.shape
log_chromaticity_image = np.zeros(img_shape)


# Compute above equation; this can obviously be vectorized.
# It is done this way for clarity
for i in range(log_rgb.shape[0]):
    for j in range(log_rgb.shape[1]):

        # Compute the dot product of each pixel and isd
        dot_log_rbg_isd = np.dot(log_rgb[i, j], isd)

        # Compute the parallel component
        parallel_component = (dot_log_rbg_isd / dot_isd_isd) * isd

        # Compute the orthogonal component
        orthogonal_component = log_rgb[i, j] - parallel_component

        # Verify orthogonality: dot product of orthogonal component with ISD
        orthogonality_check = np.dot(orthogonal_component, isd)

        # Update new image
        log_chromaticity_image[i, j] = orthogonal_component

        print(f"log_rgb[{i}, {j}]: {log_rgb[i, j]}")
        print(f"Parallel component: {parallel_component}")
        print(f"Orthogonal component: {orthogonal_component}")
        print(f"Dot product (orthogonal component, ISD): {orthogonality_check}\n")

log_rgb[266, 1197]: [4.72738782 4.76217393 4.81218436]
Parallel component: [5.0313744  4.51847025 4.72147057]
Orthogonal component: [-0.30398658  0.24370368  0.09071378]
Dot product (orthogonal component, ISD): 1.5511914925962174e-15

log_rgb[266, 1198]: [4.46590812 4.59511985 4.69134788]
Parallel component: [4.83603563 4.34304454 4.53816355]
Orthogonal component: [-0.37012751  0.25207531  0.15318433]
Dot product (orthogonal component, ISD): 7.848794147809187e-16

log_rgb[266, 1199]: [4.39444915 4.52178858 4.61512052]
Parallel component: [4.758314   4.27324595 4.46522913]
Orthogonal component: [-0.36386484  0.24854263  0.14989139]
Dot product (orthogonal component, ISD): 6.182758056454415e-16

log_rgb[266, 1200]: [4.40671925 4.48863637 4.58496748]
Parallel component: [4.74127032 4.25793972 4.44923524]
Orthogonal component: [-0.33455107  0.23069665  0.13573224]
Dot product (orthogonal component, ISD): -2.086531828577411e-16

log_rgb[266, 1201]: [4.51085951 4.56434819 4.68213123]
Paralle

### Generate Projection Matrix

In [None]:
# Get the dims of the image
H, W, _ = log_rgb.shape

# Container for projection vector that matches image dims
projection_matrix = np.zeros((H, W, 3))

# Multiply the scalar projection with each channel of the ISD vector
for i in range(3):  # Channels (R, G, B)
    for x in range(H):  # Height
        for y in range(W):  # Width

            # Multiply the scalar projection by the i-th ISD component
            projection_matrix[x, y, i] = projection[x, y] * isd[i]

# Display results
print(f"Projection vector shape: {projection_matrix.shape}")
print(f"Projection vector example (first pixel): {projection_matrix[0, 0]}")


### Subtract Projection Matrix

In [None]:
projected_log_rgb = log_rgb - projection_matrix

### Inspect Results

In [None]:
display_image(projected_log_rgb, "Log RBG Projected")

In [None]:
plot_plane_orthogonal_to_isd(mean_isd, projected_log_rgb)

### Compute Log Chromoaticity on ```log_rgb_prjected```  

The steps above do not actually end up with a chromaticity space.  
So I tried to do it here right as I was leaving.  
THIS DOES NOT WORK

In [None]:
# Does not currently work
log_rbg_projected_chroma = convert_img_to_log_rg_chromaticity(projected_log_rgb, True)

In [None]:
display_image(log_rbg_projected_chroma, "Log Chromaticity Projected")

___
### This is the other way to do it

In [None]:
# Covert Image to log space
log_rgb = np.log(img + epsilon)

# Unpack the tuples of pixel coordinates, i.e., create tuple of x's and tuple of y's
# Done to index using numpy
lit_x, lit_y = zip(*lit_pixels) 
shadow_x, shadow_y = zip(*shadow_pixels)

# Extract the lit and shadow pixel values
lit_values = log_rgb[lit_x, lit_y]
shadow_values = log_rgb[shadow_x, shadow_y]

print(f"Lit Values: {lit_values}")
print(f"Shadow Values: {shadow_values}")
print(f"Lit value - shawdow values: {lit_values - shadow_values}")

# Calculate the mean of the log rbg values for lit and shadow pixels
mean_lit = np.mean(lit_values, axis = 0)
print("Mean Lit", mean_lit)
mean_shadow = np.mean(shadow_values, axis = 0)

# Use equation #4 from Bruce's RoadVision Paper
isd = mean_lit - mean_shadow
isd = isd / np.linalg.norm(isd)

# Project out the ISD from the log RGB to create illumination invariant log chromaticity
projection = (log_rgb @ isd)[:, :, np.newaxis] * isd[np.newaxis, np.newaxis, :]
log_chromaticity = log_rgb - projection

# Extract r and g channels of the log chromaticity
log_r = log_chromaticity[:, :, 0]
log_g = log_chromaticity[:, :, 1]

if equalize_hist:
    
    # Rescale log_r and log_g to the range [0, 255] for equalization
    log_r_rescaled = ((log_r - log_r.min()) / (log_r.max() - log_r.min()) * 255).astype(np.uint8)
    log_g_rescaled = ((log_g - log_g.min()) / (log_g.max() - log_g.min()) * 255).astype(np.uint8)

    # Apply histogram equalization
    log_r_equalized = cv2.equalizeHist(log_r_rescaled)
    log_g_equalized = cv2.equalizeHist(log_g_rescaled)

    # Normalize equalized values back to [0, 1]
    log_r = log_r_equalized / 255.0
    log_g = log_g_equalized / 255.0

# Create a visualization for the log chromaticity (using r, g channels)
log_chromaticity_with_isd_image = np.stack((log_r, log_g, np.zeros_like(log_r)), axis=2)

# Display the log chromaticity image
plt.imshow(log_chromaticity_with_isd_image)
plt.title("Log Chromaticity with Estimated ISD")
plt.axis('off')
plt.show()


In [None]:
fig, axes = plt.subplots(2,2, figsize = (15, 10))
ax = axes.ravel()
ax[0].imshow(img)
ax[0].axis("off")
ax[0].set_title("Original")
ax[1].imshow(chromaticity_image)
ax[1].axis("off")
ax[1].set_title("Standard Chromoaticity")
ax[2].imshow(log_chromaticity_image)
ax[2].axis("off")
ax[2].set_title("Log Chromoaticity")
ax[3].imshow(projected_log_rgb)
ax[3].axis("off")
ax[3].set_title("Projected Log RBG")

In [None]:
projected_log_rgb