# Color-Levels Model

This model introduces two additional concepts that were not present in the color model.

* Overall image level parameters should be trained independently of the color curves so that the model is not specifically trained to translate from a color at a specific level, saturation, or intensity. (Although different "looks" may have different saturation curves at different levels.)
* The model could be trained in such a way that the core color correction layers of a model could be frozen and aesthetic and color temperature parameters of an image can be left unfrozen to differentiate different styles or looks for an image. The parameters of these layers can be blended or merged to apply different types of color correction to an image.

The challenge with both of these approaches is to isolate the fundamental parameters into specific layers while the aesthetic parameters are trained independently.

In [None]:
import math
import numpy as np
import pyexr
import rawpy

import tensorflow as tf

# Prevent TensorFlow from allocating all GPU memory up front: enable memory
# growth on every visible GPU so memory is claimed as it is needed.
gpus = tf.config.experimental.list_physical_devices("GPU")
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential

from PIL import Image
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from cnn_demosaic import exposure_model
from cnn_demosaic import transform

## Configure common functions

In [None]:
def display_image(image):
    """Render `image` with matplotlib and show the figure.

    Values are displayed on a fixed 0..1 scale; single-channel images use the
    gray colormap.
    """
    plt.imshow(image, cmap="gray", vmin=0, vmax=1)
    plt.show()


def to_color_arr(img_arr):
    """Reshape an image array into a flat list of RGB pixels.

    Any NumPy array whose size is a multiple of 3 (typically (H, W, 3)) is
    returned as an (N, 3) array with one row per pixel, in row-major order.
    """
    return img_arr.reshape(-1, 3)


def random_sampling(target_arr, n_samples=1000):
    """Pick `n_samples` rows of `target_arr` uniformly at random, with replacement."""
    picks = np.random.randint(target_arr.shape[0], size=n_samples)
    return target_arr[picks]

In [None]:
def get_raw_properties(raw_path):
    """Read white-balance and color-matrix metadata from a raw file.

    Returns:
        Tuple of:
          * the first three camera (as-shot) white-balance multipliers, rescaled
            so that the camera multipliers' total matches the daylight total;
          * the daylight white-balance multipliers as a NumPy array;
          * the first three rows of the raw file's rgb_xyz_matrix.
    """
    with rawpy.imread(raw_path) as raw:
        as_shot_wb = raw.camera_whitebalance
        daylight_wb = np.asarray(raw.daylight_whitebalance)
        xyz_matrix = raw.rgb_xyz_matrix[:3]

    # Normalize the as-shot multipliers to the same overall magnitude as the
    # daylight multipliers (math.fsum for an accurate float sum).
    scale = math.fsum(daylight_wb) / math.fsum(as_shot_wb)
    scaled_wb = np.asarray(as_shot_wb)[:3] * scale
    return scaled_wb, daylight_wb, xyz_matrix


def get_wb_params(paths):
    """Lazily yield the rescaled camera white-balance multipliers of each raw file.

    Args:
        paths: iterable of raw file paths; each is read with get_raw_properties.
    """
    yield from (get_raw_properties(raw_path)[0] for raw_path in paths)


# Note: using the predict method is *much* slower than applying the dot product
# directly to the array.


def apply_model(img_arr, model):
    """Run `model.predict` on every pixel of an image and restore the image layout.

    Args:
        img_arr: (H, W, 3) pixel array.
        model: object with a `predict` method mapping (N, 3) pixels to (N, 3) pixels.

    Returns:
        (H, W, 3) array of model outputs.
    """
    height, width = img_arr.shape[:2]
    pixels = to_color_arr(img_arr)
    predictions = model.predict(pixels)
    return predictions.reshape((height, width, 3))


@tf.function()
def tf_rgb_luma(rgb_ch):
    """Collapse RGB pixels to luma with BT.601 weights: 0.299R + 0.587G + 0.114B.

    Args:
        rgb_ch: tensor whose second axis holds the R, G, B channels, e.g. a
            (pixels, 3) array of pixels.

    Returns:
        Tensor with one luma value per pixel (the channel axis removed).
    """
    red, green, blue = rgb_ch[:, 0], rgb_ch[:, 1], rgb_ch[:, 2]
    return red * 0.299 + green * 0.587 + blue * 0.114

In [None]:
# sRGB -> CIE XYZ conversion matrix (D65 white point), stored as float32.
srgb_to_xyz = np.asarray(
    (
        (0.4124564, 0.3575761, 0.1804375),
        (0.2126729, 0.7151522, 0.0721750),
        (0.0193339, 0.1191920, 0.9503041),
    ),
    dtype=np.float32,
)

# Inverse transform, XYZ -> sRGB. The result roughly matches the sRGB D65
# matrix published at:
# http://www.brucelindbloom.com/index.html?Eqn_RGB_XYZ_Matrix.html
xyz_to_srgb = np.linalg.inv(srgb_to_xyz)

# TensorFlow constant copy of the XYZ -> sRGB matrix for use inside TF graphs.
xyz_to_srgb_tf = tf.constant(xyz_to_srgb)

## Set up dataset paths

In [None]:
import os

# Folder that holds the training images. Defaults to the original author's
# path; set the FUJI_DATASET_PREFIX environment variable to point at a local
# copy without editing the notebook.
DATASET_PREFIX = os.environ.get(
    "FUJI_DATASET_PREFIX", "/media/jake/Media/datasets/fuji_raw/xe2/125_FUJI"
)

# (feature, target) file pairs, relative to DATASET_PREFIX. The EXR provides
# the model's input pixels; the PNG/JPG provides the target pixels (read as
# 8-bit and scaled to 0..1 by load_dataset).
TRAINING_PATHS = [
    ("DSCF5752_card.exr", "DSCF5752_card.png"),
    ("DSCF5759_card.exr", "DSCF5759_card.png"),
    ("DSCF5760_card.exr", "DSCF5760_card.png"),
    ("DSCF5761_card.exr", "DSCF5761_card.png"),
    ("DSCF5731_card.exr", "DSCF5731_card_srgb.png"),  # Outdoor, partly cloudy.
    ("DSCF5782.exr", "DSCF5782.JPG"),
    ("DSCF5783.exr", "DSCF5783.JPG"),
    ("DSCF5796.exr", "DSCF5796.JPG"),
]

## Set up datasets for a model including WB params

In [None]:
def sample_pairs(a_arr, b_arr, n_samples=100000):
    """Draw the same random rows from two aligned arrays.

    Rows are sampled uniformly with replacement, and the same indices are used
    for both arrays so the returned pairs stay aligned (row i of the first
    result corresponds to row i of the second).

    Args:
        a_arr: array whose first axis indexes samples (e.g. feature pixels).
        b_arr: array with the same first-axis length as `a_arr` (e.g. targets).
        n_samples: number of rows to draw.

    Returns:
        Tuple (a_samples, b_samples), each with `n_samples` rows.

    Raises:
        ValueError: if the two arrays have different lengths.
    """
    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, which would let mismatched arrays be sampled silently.
    if len(a_arr) != len(b_arr):
        raise ValueError(
            f"sample_pairs: arrays must have the same length, "
            f"got {len(a_arr)} and {len(b_arr)}"
        )
    rand_index = np.random.randint(0, a_arr.shape[0], n_samples)
    return a_arr[rand_index], b_arr[rand_index]


def load_dataset(
    feature_image_path,
    srgb_image_path,
    n_samples=128 * 128 * 128,
    batch_size=128 * 128,
):
    """Build a tf.data pipeline of (input RGB, target luma) pixel batches.

    Random pixel pairs are drawn (with the same indices) from the feature EXR
    and the target image. The target pixels are converted to luma.

    Args:
        feature_image_path: EXR file name relative to DATASET_PREFIX; its first
            three channels are the model input.
        srgb_image_path: target image file name relative to DATASET_PREFIX; its
            first three channels are scaled by 1/255 (assumes 8-bit data). It
            must contain the same number of pixels as the EXR.
        n_samples: number of random pixel pairs to draw.
        batch_size: number of pixels per dataset element.

    Returns:
        A tf.data.Dataset whose elements are (rgb, luma) with rgb of shape
        (batch_size, 3) and luma of shape (batch_size,); the last element may
        be smaller when n_samples is not a multiple of batch_size.
    """
    feat_img_arr = pyexr.read(f"{DATASET_PREFIX}/{feature_image_path}")[:, :, :3]

    # Close the image file once its pixels have been copied into NumPy.
    with Image.open(f"{DATASET_PREFIX}/{srgb_image_path}") as srgb_img:
        srgb_pixels = np.asarray(srgb_img)
    srgb_img_arr = srgb_pixels[:, :, :3] / 255

    feat_rgb_samples, targ_rgb_samples = sample_pairs(
        to_color_arr(feat_img_arr), to_color_arr(srgb_img_arr), n_samples
    )

    # TODO(): Augment data by changing the quality of the exposure and with different images.
    dataset = tf.data.Dataset.from_tensor_slices((feat_rgb_samples, targ_rgb_samples))
    # Targets are compared in luma space, so convert each target batch here.
    dataset = dataset.batch(batch_size).map(lambda i, j: (i, tf_rgb_luma(j)))
    return dataset

In [None]:
# One dataset per (feature, target) image pair ...
levels_datasets = [load_dataset(*image_pair) for image_pair in TRAINING_PATHS]

# ... interleaved by drawing each element from a randomly chosen image dataset
# (no weights given, so the default equal weighting applies).
levels_datasets_combined = tf.data.Dataset.sample_from_datasets(levels_datasets)

In [None]:
# Sanity check: print the shapes of one (input RGB batch, target luma batch) element.
for i, j in levels_datasets_combined.take(1):
    print(i.shape, j.shape, sep="\n")

## Define Correction Model

In [None]:
class LogNormalizationLayer(layers.Layer):
    """Compute log(x + epsilon) and divide it by its maximum along `axis`.

    Args:
        axis: axis along which the maximum log value is taken.
        epsilon: small constant added inside the log and to the divisor.
        **kwargs: standard Keras layer arguments.
    """

    def __init__(self, axis=1, epsilon=1e-6, **kwargs):
        super().__init__(**kwargs)
        self.axis = axis
        self.epsilon = epsilon

    def build(self, input_shape):
        # Keep the build-time input shape on the layer.
        self.input_shape = input_shape

    def call(self, inputs):
        eps = self.epsilon
        logs = tf.math.log(tf.cast(inputs, dtype=tf.float32) + eps)

        # NOTE(review): for inputs <= 1 the maximum log is <= 0, so this
        # divisor can be negative or very close to zero — confirm intended.
        divisor = tf.math.reduce_max(logs, axis=self.axis, keepdims=True) + eps
        return logs / divisor

    def get_config(self):
        base_config = super().get_config()
        base_config.update({"axis": self.axis, "epsilon": self.epsilon})
        return base_config

    def compute_output_shape(self, input_shape):
        # Element-wise normalization: the shape is unchanged.
        return input_shape


@tf.function()
def levels_fn(img_arr, in_min, in_max):
    """Linearly remap values so that in_min maps to 0 and in_max maps to 1.

    Computes (img_arr - in_min) / (in_max - in_min); the result is not clipped.
    """
    dyn_range = in_max - in_min
    # Avoid dividing by an exactly-zero range. When dyn_range is a tensor this
    # Python conditional presumably relies on tf.function's AutoGraph conversion.
    # NOTE(review): tiny or negative ranges are not guarded and produce very
    # large or inverted outputs — confirm this is intended.
    dyn_range = dyn_range if dyn_range != 0.0 else 0.00001

    scale_ratio = 1.0 / dyn_range
    return (img_arr - in_min) * scale_ratio


class LevelsLayer(layers.Layer):
    """Apply a per-sample levels remap predicted by another model.

    The layer is called with `[images, weights]`. For each batch element,
    `levels_fn(image_i, weights_i[0], weights_i[1])` is applied, i.e.
    weights[0] is the input minimum (mapped to 0) and weights[1] the input
    maximum (mapped to 1).

    Args:
        color_ch: number of color channels (stored and serialized; not used by
            the computation).
        **kwargs: standard Keras layer arguments (e.g. `name`, `dtype`);
            accepted so the layer can be rebuilt from its config.
    """

    def __init__(self, color_ch=3, **kwargs):
        super().__init__(**kwargs)
        self.color_ch = color_ch

    def build(self, input_shape):
        # print(f"LevelsLayer input shape: {input_shape}")
        self.input_shape = input_shape

    def call(self, inputs):
        # tf.map_fn iterates over the first (batch) dimension so each sample
        # is remapped with its own predicted levels.

        @tf.function()
        def apply_levels(parameters):
            img_arr, weights = parameters
            adjusted_image = levels_fn(img_arr, weights[0], weights[1])

            # Without fn_output_signature, map_fn needs the output structure
            # to match the input, so the weights are passed through.
            return [adjusted_image, weights]

        return tf.map_fn(apply_levels, inputs)[0]

    def get_config(self):
        config = super().get_config()
        config.update({"color_ch": self.color_ch})
        return config

    def compute_output_shape(self, input_shape):
        # Output has the shape of the image input (the first list entry).
        return input_shape[0]


class GammaAdjLayer(layers.Layer):
    """Raise each sample's pixels to a per-sample exponent predicted by another model.

    The layer is called with `[images, exponents]` and computes
    `tf.pow(image_i, exponent_i)` for each batch element.

    Args:
        color_ch: number of color channels (stored and serialized; not used by
            the computation).
        **kwargs: standard Keras layer arguments (e.g. `name`, `dtype`);
            accepted so the layer can be rebuilt from its config.
    """

    def __init__(self, color_ch=3, **kwargs):
        super().__init__(**kwargs)
        self.color_ch = color_ch

    def build(self, input_shape):
        # print(f"GammaAdjLayer input shape: {input_shape}")
        self.input_shape = input_shape

    def call(self, inputs):
        # Parameters list must be iterated / mapped to apply gamma exponent to
        # each sub array in the first dimension.

        @tf.function()
        def apply_pow(parameters):
            img_arr, exp = parameters
            # NOTE(review): negative pixel values raised to a non-integer
            # exponent yield NaN — confirm inputs are non-negative.
            # Without fn_output_signature, map_fn needs the output structure
            # to match the input, so the exponent is passed through.
            return [tf.pow(img_arr, exp), exp]

        return tf.map_fn(apply_pow, inputs)[0]

    def get_config(self):
        config = super().get_config()
        config.update({"color_ch": self.color_ch})
        return config

    def compute_output_shape(self, input_shape):
        # Output has the shape of the image input (the first list entry).
        return input_shape[0]


@tf.function()
def s_curve_fn(img_arr, offset, contrast, slope):
    """Apply a logistic S-curve: 1 / (1 + exp(slope * (img_arr + offset) * contrast)).

    Evaluated as sigmoid(-(slope * x)), which is mathematically identical. The
    naive form overflows exp() to inf for large slope * x, and its gradient then
    becomes NaN (0 * inf). sigmoid keeps both values and gradients finite.

    Args:
        img_arr: pixel values.
        offset: shift added to the pixels before scaling.
        contrast: multiplier applied after the offset.
        slope: steepness (and, via its sign, direction) of the curve.
    """
    output = (img_arr + offset) * contrast
    return tf.math.sigmoid(-(slope * output))


class ColorCurveAdjLayer(layers.Layer):
    """Apply a per-sample S-curve predicted by another model.

    The layer is called with `[images, weights]`. For each batch element,
    `s_curve_fn(image_i, weights_i[0], weights_i[1], weights_i[2])` is applied,
    i.e. the weights are (offset, contrast, slope).

    Args:
        color_ch: number of color channels (stored and serialized; not used by
            the computation).
        **kwargs: standard Keras layer arguments (e.g. `name`, `dtype`);
            accepted so the layer can be rebuilt from its config.
    """

    def __init__(self, color_ch=3, **kwargs):
        super().__init__(**kwargs)
        self.color_ch = color_ch

    def build(self, input_shape):
        # print(f"ColorCurveAdjLayer input shape: {input_shape}")
        self.input_shape = input_shape

    def call(self, inputs):
        # tf.map_fn iterates over the first (batch) dimension so each sample
        # gets its own curve.

        @tf.function()
        def apply_curve(parameters):
            rgb_arr, weights = parameters
            # weights should contain 3 parameters: offset, contrast, slope.
            output = s_curve_fn(rgb_arr, weights[0], weights[1], weights[2])
            # Output structure must match the input for map_fn.
            return [output, weights]

        return tf.map_fn(apply_curve, inputs)[0]

    def get_config(self):
        config = super().get_config()
        config.update({"color_ch": self.color_ch})
        return config

    def compute_output_shape(self, input_shape):
        # Output has the shape of the image input (the first list entry).
        return input_shape[0]


class LumaLayer(layers.Layer):
    """Convert each batch element of RGB pixels to luma with tf_rgb_luma.

    Input is (batch, pixels, 3); output is (batch, pixels) float32.

    Args:
        **kwargs: standard Keras layer arguments (e.g. `name`, `dtype`);
            accepted so the layer can be rebuilt from its config.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        self.input_shape = input_shape

    def call(self, rgb_ch):
        # Each mapped element yields a 1-D luma vector whose length is the
        # build-time pixel dimension (None for variable-length input).
        output_shape = (self.input_shape[1],)
        y_ch = tf.map_fn(
            tf_rgb_luma,
            rgb_ch,
            fn_output_signature=tf.TensorSpec(shape=output_shape, dtype=tf.float32),
        )
        return y_ch

    def compute_output_shape(self, input_shape):
        return (None, input_shape[1])


# XLA compilation must be off for this op to avoid graph compilation errors.
# Setting jit_compile=False on this function alone was observed to have no
# effect; it only worked when disabled globally for the model.
@tf.function(jit_compile=False)
def tf_histogram(inputs, bins=16):
    """Count `inputs` into `bins` equal-width buckets spanning [0.0, 1.0]."""
    value_range = [0.0, 1.0]
    return tf.histogram_fixed_width(inputs, value_range, nbins=bins)


class HistogramLayer(layers.Layer):
    """Compute a fixed-width [0, 1] histogram for each batch element.

    Input is (batch, values); output is (batch, bins) int32 counts.

    Args:
        bins: number of histogram buckets.
        **kwargs: standard Keras layer arguments (e.g. `name`, `dtype`).
            get_config emits these, so __init__ must accept them for
            from_config to rebuild the layer.
    """

    def __init__(self, bins=16, **kwargs):
        super().__init__(**kwargs)
        self.bins = bins

    def build(self, input_shape):
        self.input_shape = input_shape

    def call(self, hist_input):
        output = tf.map_fn(
            lambda i: tf_histogram(i, self.bins),
            hist_input,
            fn_output_signature=tf.TensorSpec(shape=(self.bins,), dtype=tf.int32),
        )
        return output

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "bins": self.bins,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        return (None, self.bins)

### Model Structure

The model inputs are RGB image pixel data and outputs are RGB pixel data with corrected luminance levels. For the purpose of this model, we don't want to pass in information about the target image levels; we want a generalized approach, which means that for any image we pass in, we get the "typical image processing" version of that image. The result may not be a LUT, i.e., different input images may map to different curves based on the exposure. Right now, here's the theoretical process.

* Input - RGB image array
* Create variable containing luminance histogram
* Sub Model - Compute Processing Parameters
    - Input luminance histogram (16 buckets)
    - Dense Layers
    - Output parameters: (gamma, midpoint, contrast, slope)
* Gamma layer - (gamma)
* S-Curve Layer - (midpoint, contrast, slope)
* Convert output to luma channel
* Output luma channel

Train against luma channel image training data.

In [None]:
def create_test_model(exp_model, hist_bins=32):
    """Build the training harness around an exposure-parameter model.

    The harness computes a luma histogram of the input pixels and feeds it to
    `exp_model` to predict (levels, gamma, curve) parameters. It then applies
    gamma, levels and S-curve corrections to the RGB input and outputs the luma
    of the result, so the parameter model can be trained on grayscale targets.

    Args:
        exp_model: callable mapping a (batch, hist_bins) histogram to
            (levels_weights, gamma_weights, curve_weights).
        hist_bins: number of histogram buckets passed to `exp_model`.

    Returns:
        A keras.Model mapping (batch, pixels, 3) RGB input to (batch, pixels) luma.
    """
    rgb_input = keras.Input(shape=(None, 3), name="rgb_input")

    corrected = layers.Identity()(rgb_input)

    # The histogram of the input's luma drives the parameter prediction.
    input_luma = LumaLayer()(rgb_input)
    luma_hist = HistogramLayer(bins=hist_bins)(input_luma)

    levels_weights, gamma_weights, curve_weights = exp_model(luma_hist)

    # Apply the predicted corrections: gamma, then levels, then S-curve.
    corrected = GammaAdjLayer()([corrected, gamma_weights])
    corrected = LevelsLayer()([corrected, levels_weights])
    corrected = ColorCurveAdjLayer()([corrected, curve_weights])

    # Training compares against luma targets.
    luma_output = LumaLayer()(corrected)

    return keras.Model(inputs=[rgb_input], outputs=luma_output)

In [None]:
# Build a parameter model with its default settings and wrap it in the
# training harness.
exp_model = exposure_model.exposure_params_model()
test_model = create_test_model(exp_model)
# XLA (jit) compilation is disabled for the whole model; the histogram op
# fails to compile with it enabled.
test_model.compile(optimizer="adam", loss="mse", jit_compile=False)
# NOTE(review): the functional model is already defined from
# keras.Input(shape=(None, 3)); this build call looks redundant — confirm.
test_model.build((None, 3))
test_model.summary()

In [None]:
# Smoke test: push one random batch of 128 * 128 RGB pixels through the harness.
v = test_model.predict(np.random.rand(1, 128 * 128, 3))

In [None]:
# Short trial fit: each dataset element (128 * 128 pixels) is one sample and
# 32 elements form one training batch.
history = test_model.fit(levels_datasets_combined.batch(32), epochs=2)

In [None]:
# Per-run summaries (hist_bins, dense_layers, train_loss) from the hyperparameter search.
results = []

In [None]:
from IPython.display import clear_output

# Random-restart search: for every hyperparameter combination, train
# `num_runs` freshly initialised models and keep the one with the lowest
# final loss. A summary of every run is appended to `results`.
best_model = None
best_exp_model = None
best_val_loss = float("inf")
epochs = 10
num_runs = 20

# [dense_layers, hist_bins] combinations to evaluate.
hyperparams = [[i_l, i_b] for i_l in [16] for i_b in [32]]


for i_l, i_b in hyperparams:
    for i in range(num_runs):
        print(f"Training model {i + 1}/{num_runs}: hist_bins: [{i_b}] dense_layers: [{i_l}]")

        exp_model = exposure_model.exposure_params_model(hist_size=i_b, dense_layers=i_l)
        model = create_test_model(exp_model, hist_bins=i_b)

        model.compile(optimizer="adam", loss="mse", jit_compile=False)
        model.build((None, 3))

        history = model.fit(levels_datasets_combined.batch(32), epochs=epochs, verbose=0)

        # There is no validation split yet, so the final training loss stands
        # in for the validation loss when ranking runs.
        train_loss = history.history["loss"][-1]
        print(f"Finished model {i + 1}/{num_runs}: {train_loss}")
        val_loss = train_loss

        if val_loss < best_val_loss:
            best_val_loss, best_model, best_exp_model = val_loss, model, exp_model

        results.append({"hist_bins": i_b, "dense_layers": i_l, "train_loss": train_loss})
    clear_output(wait=True)

print(f"Best loss: {best_val_loss}")

In [None]:
# Inspect the config of the layer at index 2 of the best harness model.
# NOTE(review): the positional index depends on Keras' layer ordering —
# confirm which layer this refers to.
best_model.layers[2].get_config()

In [None]:

# Perform additional training on the best model (10 more epochs) to see if additional performance can be squeezed out of it.
best_model.fit(levels_datasets_combined.batch(32), epochs=10)

In [None]:
import pandas as pd

# Group per-run results by their (hist_bins, dense_layers) combination.
results_df = pd.DataFrame(results).groupby(["hist_bins", "dense_layers"])

plt.figure(figsize=(10, 6))

# One scatter series per combination: final training loss vs. layer count.
for (hist_bins, dense_layers), group in results_df:
    series_label = f"B={hist_bins}, L={dense_layers}"
    plt.scatter(group["train_loss"], group["dense_layers"], label=series_label)

# Losses can span orders of magnitude, so put the x-axis on a log scale.
plt.xscale("log")

plt.xlabel("Training Loss")
plt.ylabel("Layers")
plt.legend()
plt.show()

In [None]:

def apply_correction_model_predict(img_arr, model, img_r=256, img_c=256):
    """Run the correction harness on a resized copy of an image and return its luma.

    The image is resized to (img_r, img_c), shown and flattened into a single
    batch of img_r * img_c RGB pixels. The batch goes through `model.predict`.
    Intermediate shapes are displayed for debugging.

    Args:
        img_arr: (H, W, 3) RGB image array.
        model: harness model such as the one returned by create_test_model,
            mapping (1, pixels, 3) RGB to (1, pixels) luma.
        img_r: number of rows to resize to (default 256).
        img_c: number of columns to resize to (default 256).

    Returns:
        (img_r, img_c) NumPy array of predicted luma values.

    NOTE(review): training samples hold 128 * 128 pixels, while the default
    here is 256 * 256. HistogramLayer produces raw counts, so unless the
    parameter model normalizes them, its input is 4x larger than during
    training. Consider img_r=img_c=128 — confirm.
    """
    img_resized = tf.image.resize(img_arr, [img_r, img_c])
    display_image(img_resized)
    display(img_resized.shape)

    img_linear_arr = tf.reshape(img_resized, (1, img_r * img_c, 3))
    display(img_linear_arr.shape)

    output_arr = model.predict(img_linear_arr)
    display(output_arr.shape)

    # Currently the output we're looking at is a 1 channel (luma) image.
    return output_arr.reshape((img_r, img_c))

In [None]:
# Load a test EXR (first three channels) and run it through the best harness model.
test_exr_img_path = f"{DATASET_PREFIX}/DSCF5782.exr"
test_exr_img_arr = pyexr.read(test_exr_img_path)[:, :, :3]

output_img_predict = apply_correction_model_predict(test_exr_img_arr, best_model)

In [None]:
# Show the predicted luma image.
display_image(output_img_predict)

In [None]:
# Repeat the prediction on a second test EXR (first three channels).
test_exr_img_path = f"{DATASET_PREFIX}/DSCF5796.exr"
test_exr_img_arr = pyexr.read(test_exr_img_path)[:, :, :3]

output_img_predict = apply_correction_model_predict(test_exr_img_arr, best_model)

In [None]:
# Show the predicted luma image.
display_image(output_img_predict)

In [None]:
# Save the latest predicted (single-channel) luma image to an EXR in the working directory.
pyexr.write("test_output_levels_1.exr", output_img_predict)

In [None]:
# Save the best parameter model's weights. The file name presumably encodes
# hist_bins=32, dense_layers=16 and a run/version index — keep it in sync with
# the search settings.
best_exp_model.save_weights("exp_model_32_16_0.weights.h5")