In [None]:
import math
import numpy as np
import pyexr
import rawpy

from PIL import Image
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from cnn_demosaic import transform

In [None]:
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential

# Prevent TensorFlow from allocating all GPU memory: enable memory growth on
# every GPU that TensorFlow reports.
gpus = tf.config.experimental.list_physical_devices("GPU")
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

In [None]:
# Input files, all relative to the notebook directory.
# The path to the camera processed color card (used as the sRGB reference).
SRGB_PATH = "../training/DSCF5652_card_srgb.exr"
# The path to the demosaiced but not color corrected color card
EXR_PATH = "../training/DSCF5652_card.exr"
# The path to the raw image
RAW_PATH = "../training/DSCF5652.RAF"

In [None]:
# Load both EXR images and keep only the first three channels of each.
srgb_img_arr = pyexr.read(SRGB_PATH)[:, :, :3]
exr_img_arr = pyexr.read(EXR_PATH)[:, :, :3]

In [None]:
# Sanity check on the top-left pixel: the pixel format and values should be a
# floating point triple in the range 0.0...1.0
srgb_img_arr[0, 0]


## Set up functional helpers

In [None]:
def display_image(image):
    """Show an image array with Matplotlib.

    Args:
        image: Array accepted by ``plt.imshow``.
    """
    # NOTE(review): Matplotlib documents vmin/vmax as ignored for RGB(A)
    # input - confirm whether they have any effect for the images shown here.
    plt.imshow(image, vmin=0, vmax=1)
    plt.show()


def to_color_arr(img_arr):
    """Collapse the two leading image axes into one list of 3-value colors.

    Args:
        img_arr: Array whose first two axes are rows and columns and whose
            total size equals ``rows * cols * 3``.

    Returns:
        Array of shape ``(rows * cols, 3)`` in row-major pixel order.
    """
    n_rows, n_cols = img_arr.shape[0], img_arr.shape[1]
    flat_shape = (n_rows * n_cols, 3)
    return img_arr.reshape(flat_shape)


def random_sampling(target_arr, n_samples=1000, rng=None):
    """Return ``n_samples`` rows drawn at random (with replacement).

    Args:
        target_arr: Array indexed along its first axis, e.g. an ``(N, 3)``
            list of colors from one swatch.
        n_samples: Number of rows to draw.
        rng: Optional seed or ``numpy.random.Generator``. When given, the draw
            is reproducible, and passing equal seeds for two arrays of equal
            length selects the same row positions from both. When ``None``
            the legacy global NumPy random state is used, as before.

    Returns:
        Array of shape ``(n_samples, ...)`` holding the selected rows.

    Raises:
        ValueError: If ``target_arr`` has no rows to sample from.
    """
    n_available = target_arr.shape[0]
    if n_available == 0:
        raise ValueError("cannot sample from an empty array")

    if rng is None:
        rand_index = np.random.randint(0, n_available, n_samples)
    else:
        # default_rng() passes an existing Generator through unchanged and
        # turns an integer seed into a fresh Generator.
        rand_index = np.random.default_rng(rng).integers(0, n_available, n_samples)
    return target_arr[rand_index]


def plot_color_3d(col_arr):
    """Scatter-plot colors in RGB space, drawing each point in its own color.

    Args:
        col_arr: Array of shape ``(N, 3)``; the columns are used as the red,
            green and blue coordinates and the rows as the point colors.
    """
    red, green, blue = (col_arr[:, ch].flatten() for ch in range(3))

    axis = plt.figure().add_subplot(1, 1, 1, projection="3d")
    axis.scatter(red, green, blue, c=col_arr, marker="o")

    labelled_axes = (
        (axis.set_xlabel, "Red"),
        (axis.set_ylabel, "Green"),
        (axis.set_zlabel, "Blue"),
    )
    for set_label, text in labelled_axes:
        set_label(text)

    plt.show()


def plot_color_scope(img_arr):
    """Plot a per-column "scope" of the channels, normalized to red.

    Every pixel is drawn at its column index. The y value of each channel is
    that channel divided by the pixel's red value, so red is always drawn at
    1.0 and green/blue show their ratio to red. Point opacity is 1% of the
    channel value, so dense regions build up visually.

    Args:
        img_arr: Image array of shape ``(rows, cols, channels)`` with at
            least three channels; only channels 0-2 are plotted.
    """
    n_rows, n_cols = img_arr.shape[:2]

    # Row-major flattening: pixel k sits in column k % n_cols.
    c_points = np.asarray(img_arr, dtype=float).reshape(-1, img_arr.shape[-1])
    x_points = np.tile(np.arange(n_cols), n_rows)

    # Matplotlib rejects RGBA components outside [0, 1]; clamp so negative or
    # over-range pixel values cannot break the plot.
    alpha = np.clip(np.nan_to_num(c_points[:, :3]) * 0.01, 0.0, 1.0)

    fig = plt.figure()
    axis = fig.add_subplot(1, 1, 1)

    # Values are normalized on the red axis.
    red = c_points[:, 0]
    has_red = red != 0
    for ch in range(3):
        rgba = np.zeros((c_points.shape[0], 4))
        rgba[:, ch] = 1.0
        rgba[:, 3] = alpha[:, ch]

        # Pixels with a zero red value have no defined ratio; leave them as
        # NaN so they are not drawn instead of dividing by zero.
        ratio = np.full(red.shape, np.nan)
        np.divide(c_points[:, ch], red, out=ratio, where=has_red)

        axis.scatter(x_points, ratio, c=rgba, s=3)

    axis.set_xlabel("Column")
    axis.set_ylabel("Channel / red")
    plt.show()


## Find a color conversion matrix using least squares

In [None]:
# This finds the diagonal matrix which best matches the color transformation
# between two swatches. Note: This does not find a transformation which
# matches multiple color chips.


def find_diagonal_transformation(raw_img, target_img):
    """Fit an independent gain per channel mapping raw pixels onto target pixels.

    Each channel is solved on its own as a one-unknown least-squares problem
    (``target[:, ch] ~= gain * raw[:, ch]``), so the result never mixes
    channels.

    Args:
        raw_img: Array reshapeable to ``(N, 3)`` holding the source colors.
        target_img: Array reshapeable to ``(N, 3)`` holding the colors to
            match, row-aligned with ``raw_img``.

    Returns:
        A ``(3, 3)`` float64 diagonal matrix with the three gains on the
        diagonal.
    """
    raw_px = raw_img.reshape(-1, 3)
    target_px = target_img.reshape(-1, 3)

    def _channel_gain(ch):
        # lstsq needs a 2-D design matrix, hence the ch:ch+1 column slice.
        solution = np.linalg.lstsq(raw_px[:, ch : ch + 1], target_px[:, ch], rcond=None)[0]
        return solution[0]

    gains = np.array([_channel_gain(ch) for ch in range(3)], dtype=np.float64)
    return np.diag(gains)


def apply_color_transformation(img, transformation_matrix):
    """Apply a color matrix to every pixel and return an image like ``img``.

    Pixels are treated as row vectors: ``out_pixel = pixel @ matrix``.

    Args:
        img: Image array whose last axis has length 3. Integer images are
            assumed to use the full range of their dtype; any other dtype is
            treated as normalized to ``[0, 1]``.
        transformation_matrix: ``(3, 3)`` matrix, or anything ``np.dot``
            accepts as the right-hand operand that yields three values per
            pixel.

    Returns:
        Array with the shape and dtype of ``img``, clipped to the valid range
        of that dtype.
    """
    # Flatten to (n_pixels, 3), transform, then restore the image shape.
    pixels = img.reshape(-1, 3)
    transformed = np.dot(pixels, transformation_matrix).reshape(img.shape)

    if np.issubdtype(img.dtype, np.integer):
        # Use the real maximum of the integer dtype (255 for uint8, 65535 for
        # uint16, ...) and round to nearest; a bare astype() would truncate.
        upper = np.iinfo(img.dtype).max
        transformed = np.rint(np.clip(transformed, 0, upper))
    else:
        transformed = np.clip(transformed, 0, 1.0)
    return transformed.astype(img.dtype)

In [None]:
# Cut the first three color chips out of the top row of both images. Each
# chip is a CHIP_PX x CHIP_PX square; chip k starts at column k * CHIP_PX.
CHIP_PX = 160

srgb_chip_0, srgb_chip_1, srgb_chip_2 = (
    srgb_img_arr[:CHIP_PX, k * CHIP_PX : (k + 1) * CHIP_PX] for k in range(3)
)
raw_chip_0, raw_chip_1, raw_chip_2 = (
    exr_img_arr[:CHIP_PX, k * CHIP_PX : (k + 1) * CHIP_PX] for k in range(3)
)

# An isolated color swatch should have consistent levels across the x axis.
for chip in (raw_chip_1, srgb_chip_1):
    plot_color_scope(chip)

In [None]:
# Sample the tile at the *same* pixel positions in both images. The
# least-squares fit compares the two arrays row by row, so each raw sample
# must be paired with the target pixel from the same location; independent
# random draws would pair unrelated pixels.
raw_colors = to_color_arr(raw_chip_0)
target_colors = to_color_arr(srgb_chip_0)
sample_idx = np.random.randint(0, raw_colors.shape[0], 1000)
raw_samples = raw_colors[sample_idx]
target_samples = target_colors[sample_idx]

# Find the matrix which transforms the raw pixels to target pixels.
transform_arr = find_diagonal_transformation(raw_samples, target_samples)

# Apply the transformation and display the transformed raw tile next to the
# original sRGB camera-processed tile.
t_chip_0 = apply_color_transformation(raw_chip_0, transform_arr)
plot_color_scope(t_chip_0)
display_image(t_chip_0)
display_image(srgb_chip_0)

# Display the transformation matrix, which shows that it only scales each
# channel independently (no cross-channel mixing).
display(transform_arr)

In [None]:
# Perform the same operations as above for chip_1. Raw and target samples are
# taken at the same pixel positions so the fit compares matching pixels.
raw_colors = to_color_arr(raw_chip_1)
target_colors = to_color_arr(srgb_chip_1)
sample_idx = np.random.randint(0, raw_colors.shape[0], 1000)
raw_samples = raw_colors[sample_idx]
target_samples = target_colors[sample_idx]

transform_arr = find_diagonal_transformation(raw_samples, target_samples)

t_chip_1 = apply_color_transformation(raw_chip_1, transform_arr)
plot_color_scope(t_chip_1)
display_image(t_chip_1)
display_image(srgb_chip_1)

# The resulting transform for this tile is very different from the
# transformation for another tile of the same image.
display(transform_arr)

In [None]:
# Perform the same operations as above for chip_2. Raw and target samples are
# taken at the same pixel positions so the fit compares matching pixels.
raw_colors = to_color_arr(raw_chip_2)
target_colors = to_color_arr(srgb_chip_2)
sample_idx = np.random.randint(0, raw_colors.shape[0], 1000)
raw_samples = raw_colors[sample_idx]
target_samples = target_colors[sample_idx]

transform_arr = find_diagonal_transformation(raw_samples, target_samples)

# The resulting transform for this tile is very different from the
# transformation for another tile of the same image.
display(transform_arr)
t_chip_2 = apply_color_transformation(raw_chip_2, transform_arr)
plot_color_scope(t_chip_2)
display_image(t_chip_2)

# The camera-processed sRGB tile, for visual comparison.
display_image(srgb_chip_2)

In [None]:
# Sample all three chips and find the single diagonal transformation that
# best fits all of them. Within each chip, the raw and target samples are
# taken at the same pixel positions so the fit compares matching pixels.
raw_parts = []
target_parts = []
for raw_chip, srgb_chip in (
    (raw_chip_0, srgb_chip_0),
    (raw_chip_1, srgb_chip_1),
    (raw_chip_2, srgb_chip_2),
):
    raw_colors = to_color_arr(raw_chip)
    target_colors = to_color_arr(srgb_chip)
    sample_idx = np.random.randint(0, raw_colors.shape[0], 1000)
    raw_parts.append(raw_colors[sample_idx])
    target_parts.append(target_colors[sample_idx])

raw_samples = np.concatenate(raw_parts)
target_samples = np.concatenate(target_parts)

transform_arr = find_diagonal_transformation(raw_samples, target_samples)
t_chip_2 = apply_color_transformation(raw_chip_2, transform_arr)

# Display the results for chip 2. The diagonal transformation matrix no longer
# produces a result which visually matches the sRGB image.
plot_color_scope(t_chip_2)
display_image(t_chip_2)
display_image(srgb_chip_2)
display(transform_arr)


## Create a color conversion model using gradient descent

In [None]:
# This is a layer with a trainable weight matrix (3x3 for RGB in, RGB out)
# which is applied to every input pixel as a dot product.


class ColorTransformLayer(layers.Layer):
    """Bias-free linear color transform: ``outputs = inputs @ w``.

    Args:
        color_ch: Number of output color channels (columns of ``w``).
        **kwargs: Standard Keras ``Layer`` arguments such as ``name``.
    """

    def __init__(self, color_ch=3, **kwargs):
        super().__init__(**kwargs)
        self.color_ch = color_ch

    def build(self, input_shape):
        # Size the matrix from the channel (last) axis only. Using the whole
        # input shape as the weight shape only works when the model happens
        # to be built with a (3, 3) shape and fails for e.g. (None, 3).
        in_ch = int(input_shape[-1])
        self.w = self.add_weight(
            shape=(in_ch, self.color_ch),
            initializer="random_normal",
            trainable=True,
        )

    def call(self, inputs):
        # Contract the last axis of the inputs with the first axis of w.
        return tf.tensordot(inputs, self.w, 1)

    def get_config(self):
        # Include color_ch so the layer can be re-created from its config.
        config = super().get_config()
        config.update({"color_ch": self.color_ch})
        return config

In [None]:
# The model consists of the single ColorTransformLayer only, as learning one
# color matrix is all we want to do for this experiment.

model = Sequential(
    [
        ColorTransformLayer(),
    ]
)

In [None]:
# Train with Adam on mean squared error. Building with a (3, 3) input shape
# makes the layer create its 3x3 weight matrix, so summary() can report it.
model.compile(optimizer="adam", loss="mse")
model.build((3, 3))
model.summary()

In [None]:
X_data = to_color_arr(exr_img_arr)
y_data = to_color_arr(srgb_img_arr)

BATCH_SIZE = 32
TRAIN_FRACTION = 0.8

# Split at the pixel level *before* batching, using a full random permutation.
# A shuffle buffer smaller than the pixel count only mixes neighbouring
# pixels, and splitting an already batched dataset does not give a clean
# 80/20 pixel split.
# NOTE(review): tf.keras.utils.split_dataset appears to unbatch its input and
# return unbatched datasets - confirm against the installed Keras version.
pixel_order = np.random.permutation(X_data.shape[0])
n_train = int(TRAIN_FRACTION * pixel_order.shape[0])
train_idx, test_idx = pixel_order[:n_train], pixel_order[n_train:]

# The full shuffled, batched dataset is kept for inspection.
dataset = tf.data.Dataset.from_tensor_slices((X_data, y_data)).shuffle(50000).batch(BATCH_SIZE)

train_dataset = (
    tf.data.Dataset.from_tensor_slices((X_data[train_idx], y_data[train_idx]))
    .shuffle(50000)
    .batch(BATCH_SIZE)
)
test_dataset = tf.data.Dataset.from_tensor_slices((X_data[test_idx], y_data[test_idx])).batch(
    BATCH_SIZE
)

In [None]:
# Sanity check: show the first element yielded by the dataset.
for i in dataset.take(1):
    display(i)

In [None]:
# Train for 8 epochs, validating against the held-out split after each one.
model.fit(train_dataset, epochs=8, validation_data=test_dataset)

In [None]:
# Inspect the weights learned by the color transform layer.
model.layers[0].weights

In [None]:
# Using the predict method is *much* slower than just applying the dot product
# from the array.


def apply_model(img_arr, model):
    """Run ``model.predict`` on every pixel and return an image-shaped result.

    Args:
        img_arr: Image array; its first two axes are rows and columns.
        model: Object with a ``predict`` method that maps the flattened
            pixels to an array holding three values per pixel.

    Returns:
        Array of shape ``(rows, cols, 3)``.
    """
    height, width = img_arr.shape[:2]
    flat_pixels = to_color_arr(img_arr)
    predicted = model.predict(flat_pixels)
    return predicted.reshape((height, width, 3))

In [None]:
# `layer.weights` is a *list* of variables, not a matrix. Take the single
# 3x3 kernel as a NumPy array so it can be used directly with np.dot.
transform_arr = model.layers[0].get_weights()[0]

t_chip_0 = apply_color_transformation(raw_chip_0, transform_arr)

# The scopes should be proportional, or at least close, if the training was
# successful.
plot_color_scope(t_chip_0)
plot_color_scope(srgb_chip_0)
display_image(t_chip_0)
display_image(srgb_chip_0)

In [None]:
t_chip_1 = apply_color_transformation(raw_chip_1, transform_arr)

# Scopes first, then the images; in each pair the transformed raw chip comes
# before the sRGB reference chip.
for show in (plot_color_scope, display_image):
    for chip in (t_chip_1, srgb_chip_1):
        show(chip)

In [None]:
# Apply the model to the whole sample image (this goes through
# model.predict) and show the result.

display_image(apply_model(exr_img_arr, model))

In [None]:
# The same conversion applied directly as a dot product with the learned
# weights, followed by the camera-processed sRGB image for comparison.
display_image(apply_color_transformation(exr_img_arr, transform_arr))
display_image(srgb_img_arr)

In [None]:
# An experiment with two Dense layers placed in front of the color
# correction layer.

dense_model = Sequential(
    [
        layers.Dense(3, activation="elu"),
        layers.Dense(3, activation="sigmoid"),
        ColorTransformLayer(),
        # XyzToSrgbLayer(),
    ]
)

# Same optimizer, loss and build shape as the single-layer model.
dense_model.compile(optimizer="adam", loss="mse")
dense_model.build((3, 3))
dense_model.summary()

In [None]:
# Train the dense variant on the same data and schedule as the first model.
dense_model.fit(train_dataset, epochs=8, validation_data=test_dataset)

In [None]:
# Dense model output, followed by the camera-processed sRGB image.
display_image(apply_model(exr_img_arr, dense_model))
display_image(srgb_img_arr)


### Conclusions

Using training data from a specific lighting scenario (e.g. *indoor*) will
produce a model which properly converts the color for that scenario.


## Convert colors using the dcraw color matrix

This roughly follows the process outlined here, with some questions and experimentation:

https://www.numbercrunch.de/blog/2020/12/from-numbers-to-images-raw-image-processing-with-python/

Following the approach outlined above would be ideal; however, I was never able to achieve the same results in prior experimentation.

### Outstanding Questions

* Can white balance be applied after demosaicing? Conceivably it can, as the pixel values still represent the same thing, both before and after conversion.

* The approach above indicates that `rgb_xyz_matrix` is *the camera specific matrix that turns XYZ color into camera primaries* but given the name, it seems like this is the matrix which converts camera primaries into XYZ.

* What would happen if we applied the `rgb_xyz_matrix` and `camera_whitebalance` to the image prior to training the model - would the result be generalized? What would the model weights look like?

In [None]:
def get_raw_whitebalance(raw_path):
    """Read white balance and color matrix metadata from a raw file.

    The camera (as-shot) white balance is rescaled so that its RGB
    coefficients sum to the same total as the daylight RGB coefficients.

    Args:
        raw_path: Path of a raw file readable by ``rawpy.imread``.

    Returns:
        Tuple ``(cam_whitebalance, daylight_whitebalance, rgb_xyz_matrix)``:
        the rescaled camera white balance as a 3-element array, then
        ``daylight_whitebalance`` and ``rgb_xyz_matrix`` exactly as reported
        by rawpy.
    """
    with rawpy.imread(raw_path) as raw_img:
        camera_whitebalance = raw_img.camera_whitebalance
        daylight_whitebalance = raw_img.daylight_whitebalance
        rgb_xyz_matrix = raw_img.rgb_xyz_matrix

    # Only the first three (RGB) coefficients are returned, so compute the
    # scale factor from those same three on both sides. Any further
    # coefficient would otherwise skew the gain applied to R, G and B.
    camera_rgb = np.asarray(camera_whitebalance, dtype=float)[:3]
    daylight_rgb = np.asarray(daylight_whitebalance, dtype=float)[:3]

    w_a = math.fsum(daylight_rgb) / math.fsum(camera_rgb)
    cam_whitebalance = camera_rgb * w_a
    return cam_whitebalance, daylight_whitebalance, rgb_xyz_matrix

In [None]:
# Read the sensor data and the color metadata straight from the raw file.
with rawpy.imread(RAW_PATH) as raw_img:
    # astype() already produces a new array; the extra copy() is presumably
    # defensive so nothing refers to the file's buffer once it is closed.
    raw_img_array = raw_img.raw_image.astype(np.float32).copy()
    camera_whitebalance = raw_img.camera_whitebalance
    daylight_whitebalance = raw_img.daylight_whitebalance
    rgb_xyz_matrix = raw_img.rgb_xyz_matrix

In [None]:
# Rebind camera_whitebalance (previously the unscaled value read directly
# from the raw file) to the rescaled gains from get_raw_whitebalance.
camera_whitebalance, _, _ = get_raw_whitebalance(RAW_PATH)
camera_whitebalance

In [None]:
# Keep only the first three rows of the matrix reported by rawpy.
cam_matrix = rgb_xyz_matrix[:3]

In [None]:
# Matrix taking linear sRGB values to XYZ.
srgb_to_xyz = np.array(
    [
        [0.4124564, 0.3575761, 0.1804375],
        [0.2126729, 0.7151522, 0.0721750],
        [0.0193339, 0.1191920, 0.9503041],
    ],
    dtype=np.float32,
)

# The inverse takes XYZ back to linear sRGB.
# This roughly matches the sRGB D65 matrix shown here:
# http://www.brucelindbloom.com/index.html?Eqn_RGB_XYZ_Matrix.html
xyz_to_srgb = np.linalg.inv(srgb_to_xyz)

In [None]:
# this assumes the cam_matrix is xyz_to_cam
srgb_to_cam = np.dot(cam_matrix, srgb_to_xyz)

# Normalize every row to sum to 1 so that sRGB white (1, 1, 1) maps to
# camera (1, 1, 1). dcraw applies this step before inverting; without it the
# inverse matrix adds a per-channel gain on top of the white balance.
srgb_to_cam = srgb_to_cam / srgb_to_cam.sum(axis=1, keepdims=True)

cam_to_srgb = np.linalg.inv(srgb_to_cam)

In [None]:
# What is the difference between the dot product and the einsum?
#   einsum("ij,...j", M, X)[..., i] = sum_j M[i, j] * X[..., j]  ==  X @ M.T
#   np.dot(X, M)[..., i]            = sum_j X[..., j] * M[j, i]  ==  X @ M
# So the two results differ by a transpose of the matrix. The einsum form
# applies M to each pixel as a column vector (out = M @ pixel).

# Use the rescaled camera white balance computed above; `cam_whitebalance`
# only exists as a local inside get_raw_whitebalance. The results get their
# own names so the camera-processed reference in `srgb_img_arr` is not
# overwritten.
wb_img_arr = camera_whitebalance * exr_img_arr
wb_srgb_img_arr = np.einsum("ij,...j", cam_to_srgb, wb_img_arr)
wb_srgb_img_dot = np.dot(wb_img_arr, cam_to_srgb)

display_image(transform.normalize_arr(wb_srgb_img_arr))
display_image(transform.normalize_arr(wb_srgb_img_dot))

In [None]:
# Trying this as the inverse, i.e. if the cam_matrix is cam_to_xyz.
# NOTE(review): this rebinds `srgb_img_arr`, which until now held the
# camera-processed reference image; re-run the loading cell before re-running
# any earlier cell that relies on it.

xyz_img_arr = np.einsum("ij,...j", cam_matrix, exr_img_arr)
xyz_img_dot = np.dot(exr_img_arr, cam_matrix)

# Both paths convert XYZ -> sRGB, so both use xyz_to_srgb, and the dot path
# continues from its own intermediate result (xyz_img_dot).
srgb_img_arr = np.einsum("ij,...j", xyz_to_srgb, xyz_img_arr)
srgb_img_dot = np.dot(xyz_img_dot, xyz_to_srgb)
display_image(transform.normalize_arr(srgb_img_arr))
display_image(transform.normalize_arr(srgb_img_dot))

### Conclusions?

It is difficult to draw a conclusion from the operations above, other than that there is something significant about them that I don't yet understand.