In [None]:
import math
import numpy as np
import pyexr
import rawpy

from PIL import Image
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from cnn_demosaic import transform

In [None]:
import tensorflow as tf

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential

# Prevent TensorFlow from allocating all GPU memory up front: enable on-demand
# memory growth for every visible GPU. `tf.config.list_physical_devices` is the
# stable (non-experimental) device-listing API.
gpus = tf.config.list_physical_devices("GPU")
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

## Configure common functions

In [None]:
def display_image(image, ax=None):
    """Show an image with matplotlib.

    For RGB(A) data matplotlib ignores ``vmin``/``vmax`` and clips float values
    outside [0, 1] on its own (with a warning). Model outputs can leave that
    range, so float images are clipped explicitly here. Integer images (e.g.
    uint8 in [0, 255]) are passed through unchanged.

    Args:
        image: array-like image; presumably (H, W, 3) floats in roughly
            [0, 1] — TODO confirm for all callers.
        ax: optional matplotlib Axes to draw into. When omitted, a new figure
            is created and shown immediately.
    """
    image_arr = np.asarray(image)
    if np.issubdtype(image_arr.dtype, np.floating):
        image_arr = np.clip(image_arr, 0.0, 1.0)

    if ax is not None:
        # vmin/vmax still matter for single-channel (2-D) images.
        ax.imshow(image_arr, vmin=0, vmax=1)
        return

    _, new_ax = plt.subplots()
    new_ax.imshow(image_arr, vmin=0, vmax=1)
    plt.show()


def to_color_arr(img_arr):
    """Collapse an image into one row per pixel.

    Assumes the last axis holds 3 colour channels, so an (H, W, 3) image
    becomes an (H * W, 3) array.
    """
    n_pixels = img_arr.size // 3
    return img_arr.reshape(n_pixels, 3)


def random_sampling(target_arr, n_samples=1000):
    """Pick ``n_samples`` rows of ``target_arr`` uniformly at random, with replacement.

    Draws from the global ``np.random`` state, so results follow ``np.random.seed``.
    """
    row_count = target_arr.shape[0]
    picked_rows = np.random.randint(0, row_count, n_samples)
    return target_arr[picked_rows]

In [None]:
def get_raw_properties(raw_path):
    """Read white-balance and colour-matrix metadata from a RAW file via rawpy.

    Returns:
        A tuple ``(cam_whitebalance, daylight_whitebalance, rgb_xyz_matrix)``:
        - cam_whitebalance: the first three as-shot multipliers, rescaled so
          that the sum of *all* as-shot multipliers equals the sum of the
          daylight multipliers.
        - daylight_whitebalance: the daylight multipliers as a NumPy array.
        - rgb_xyz_matrix: the first three rows of rawpy's ``rgb_xyz_matrix``.
    """
    # Metadata must be read while the RAW file is still open.
    with rawpy.imread(raw_path) as raw_img:
        as_shot_wb = raw_img.camera_whitebalance
        daylight_wb = np.asarray(raw_img.daylight_whitebalance)
        xyz_matrix = raw_img.rgb_xyz_matrix[:3]

    # Bring the as-shot multipliers onto the same overall scale as daylight.
    scale = math.fsum(daylight_wb) / math.fsum(as_shot_wb)
    return np.asarray(as_shot_wb)[:3] * scale, daylight_wb, xyz_matrix


def get_wb_params(paths):
    """Lazily yield the rescaled as-shot white balance of each RAW file in ``paths``."""
    yield from (get_raw_properties(raw_path)[0] for raw_path in paths)


# Using the predict method is *much* slower than just applying the dot product
# from the array.


def apply_model(img_arr, model):
    """Run ``model.predict`` on every pixel and return an (rows, cols, 3) image."""
    rows, cols = img_arr.shape[:2]
    pixels = to_color_arr(img_arr)
    predicted = model.predict(pixels, batch_size=8192)
    return predicted.reshape((rows, cols, 3))

In [None]:
# Linear sRGB -> CIE XYZ. Each row sums to the D65 white point
# (~0.9505, 1.0, ~1.0888), i.e. the matrix maps a column vector:
# xyz = srgb_to_xyz @ rgb.
srgb_to_xyz = np.asarray(
    (
        (0.4124564, 0.3575761, 0.1804375),  # X
        (0.2126729, 0.7151522, 0.0721750),  # Y
        (0.0193339, 0.1191920, 0.9503041),  # Z
    ),
    dtype=np.float32,
)

# Inverse transform (XYZ -> linear sRGB). This roughly matches the sRGB D65
# matrix shown here:
# http://www.brucelindbloom.com/index.html?Eqn_RGB_XYZ_Matrix.html
xyz_to_srgb = np.linalg.inv(srgb_to_xyz)

## Set up datasets

### Creating datasets

RAW images are used as the training inputs. The dataset is defined as pairs of input and output images, together with the RAW file that holds the as-shot parameters for each pair. Each pair consists of the floating-point demosaiced image (X) and the sRGB, camera color-corrected output image (y). Depending on the model, the camera's white balance information may also be used as an input. There are a number of considerations when choosing and preparing the images. Each pixel in the output image must correspond to the pixel at the same position in the input image, because the model is evaluated on its ability to predict each output pixel value from the corresponding input value using color transformations. Lens and other geometric corrections applied to the output image move pixels away from their original locations and break this correspondence. When using a color card, one approach is to slice the image into tiles so that all the pixels in one tile correspond to all the pixels in the matching tile of the same color, and to trim the image down to this essential shape.

An easier approach is to disable lens correction in the camera if possible, or to use a manual lens, which does not provide the camera with information for image processing. Adjustments such as pincushion and barrel distortion correction will then not be applied to the image, and each input pixel should correspond to the output pixel at the same position. One caveat is that the output EXR image is not yet cropped identically to the camera's output JPG, so manual alignment and cropping is currently required.

Pixel pairs are then randomly sampled from each input/output image pair (`sample_pairs` draws 128 × 128 = 16,384 pairs per image by default). The distribution of colors and lighting in the images affects the result. Training a model on a wide variety of images and lighting situations produces output with dull colors, while training on a very similar set of images produces a model that generates more accurate, vivid colors. Experimentation is important. By default, levels are not normalized consistently across images (the `normalize_batches` step below is an experiment in that direction), and this can have a big impact on the output.


In [None]:
import os

# Root folder of the RAW/EXR/PNG/JPG files. Set the FUJI_RAW_DATASET_DIR
# environment variable to use another location without editing the notebook.
DATASET_PREFIX = os.environ.get(
    "FUJI_RAW_DATASET_DIR", "/media/jake/Media/datasets/fuji_raw/xe2/125_FUJI"
)

# Cherrypick training images to batch similar results. These should really be
# clustered by white balance parameters.
#
# RAW_PATHS[i] supplies the white balance for TRAINING_PATHS[i], so entries in
# the two lists must be commented/uncommented together.

RAW_PATHS = [
    "DSCF5752.RAF",
    "DSCF5759.RAF",
    # "DSCF5760.RAF",
    "DSCF5761.RAF",
    # "DSCF5731.RAF",
    # "DSCF5782.RAF",
    # "DSCF5783.RAF",
    "DSCF5796.RAF",
]

TRAINING_PATHS = [
    ("DSCF5752_card.exr", "DSCF5752_card.png"),
    ("DSCF5759_card.exr", "DSCF5759_card.png"),
    # ("DSCF5760_card.exr", "DSCF5760_card.png"),
    ("DSCF5761_card.exr", "DSCF5761_card.png"),
    # ("DSCF5731_card.exr", "DSCF5731_card_srgb.png"),  # Outdoor, partly cloudy.
    # ("DSCF5782.exr", "DSCF5782.JPG"),
    # ("DSCF5783.exr", "DSCF5783.JPG"),
    ("DSCF5796.exr", "DSCF5796.JPG"),
]

# Fail early if the lists drift out of sync: each training pair's EXR name must
# start with the stem of the RAW file at the same index.
assert len(RAW_PATHS) == len(TRAINING_PATHS), (
    "RAW_PATHS and TRAINING_PATHS must have the same length"
)
assert all(
    exr_name.startswith(os.path.splitext(raw_name)[0])
    for raw_name, (exr_name, _target_name) in zip(RAW_PATHS, TRAINING_PATHS)
), "RAW_PATHS and TRAINING_PATHS are not in the same order"

In [None]:
from cnn_demosaic import exposure_model
from cnn_demosaic import exposure

# Exposure ("levels") model that load_dataset applies to each EXR input before
# sampling. The weights file path is relative to the working directory.
EXPOSURE_WEIGHTS_PATH = './params_model_32_16_0.weights.h5'
exp_model = exposure_model.create_exposure_model(EXPOSURE_WEIGHTS_PATH)
exp = exposure.Exposure(exp_model)

In [None]:
def sample_pairs(a_arr, b_arr, n_samples=128*128, rng=None):
    """Draw the same random rows from two aligned arrays.

    Rows are sampled uniformly with replacement. The same indices are used for
    both arrays, so row k of the first result still pairs with row k of the
    second (e.g. input pixel with target pixel).

    Args:
        a_arr: array of rows (e.g. flattened input pixels).
        b_arr: array with the same number of rows as ``a_arr``.
        n_samples: number of rows to draw.
        rng: optional ``np.random.Generator`` for reproducible sampling. When
            omitted, the global ``np.random`` state is used.

    Returns:
        Tuple ``(a_samples, b_samples)``, each with ``n_samples`` rows.

    Raises:
        ValueError: if the arrays have different row counts or are empty.
    """
    n_rows = len(a_arr)
    if n_rows != len(b_arr):
        raise ValueError(
            f"sample_pairs: row counts differ ({n_rows} vs {len(b_arr)})"
        )
    if n_rows == 0:
        raise ValueError("sample_pairs: cannot sample from empty arrays")

    if rng is None:
        rand_index = np.random.randint(0, n_rows, n_samples)
    else:
        rand_index = rng.integers(0, n_rows, n_samples)
    return a_arr[rand_index], b_arr[rand_index]


def load_dataset(feature_image_path, srgb_image_path, wb_array=None):
    """Build a tf.data.Dataset of randomly sampled (input, target) pixel pairs.

    Args:
        feature_image_path: EXR file, relative to DATASET_PREFIX, holding the
            demosaiced input image. Only its first three channels are used.
        srgb_image_path: PNG/JPG file, relative to DATASET_PREFIX, holding the
            camera's sRGB rendering of the same scene. It must be pixel-aligned
            with the EXR.
        wb_array: optional white-balance values, broadcast to 3 columns (e.g.
            a length-3 vector) and appended to every sampled input row, so each
            row becomes [R, G, B, wb0, wb1, wb2].

    Returns:
        An unbatched ``tf.data.Dataset`` of ``(features, target_rgb)`` rows.
        8-bit target values are scaled to [0, 1].

    Raises:
        ValueError: if the processed input and the target image differ in
            height/width, since pixels are paired by position.
    """
    feat_img_arr = pyexr.read(f"{DATASET_PREFIX}/{feature_image_path}")[:, :, :3]
    # convert("RGB") also accepts RGBA/palette/greyscale files (alpha is
    # dropped, as slicing the first three channels would). The context manager
    # closes the file once the pixels have been copied out.
    with Image.open(f"{DATASET_PREFIX}/{srgb_image_path}") as srgb_img:
        srgb_img_arr = np.asarray(srgb_img.convert("RGB")) / 255

    # Apply the levels model to the image.
    proc_img_arr = np.asarray(exp.process(feat_img_arr))

    # sample_pairs only compares flattened lengths, so e.g. a rotated target
    # with the same pixel count would otherwise silently pair the wrong pixels.
    if proc_img_arr.shape[:2] != srgb_img_arr.shape[:2]:
        raise ValueError(
            f"Image size mismatch: {feature_image_path} is "
            f"{proc_img_arr.shape[:2]} but {srgb_image_path} is "
            f"{srgb_img_arr.shape[:2]}; crop/align them first."
        )

    feat_rgb_samples, targ_rgb_samples = sample_pairs(
        to_color_arr(proc_img_arr), to_color_arr(srgb_img_arr)
    )

    if wb_array is not None:
        # Repeat the white balance on every sampled row (same dtype as inputs).
        wb_matrix_full = np.full_like(feat_rgb_samples, wb_array)
        feat_rgb_samples = np.concatenate((feat_rgb_samples, wb_matrix_full), axis=1)

    return tf.data.Dataset.from_tensor_slices((feat_rgb_samples, targ_rgb_samples))


In [None]:
@tf.function
def normalize_dyn_range(array_input, dyn_range=0.98, min_range=0.001):
    """Linearly rescale a tensor into ``[offset, offset + dyn_range]``.

    ``offset = (1 - dyn_range) / 2`` centres the result inside [0, 1]; with the
    default 0.98 the output spans [0.01, 0.99]. A single min/max is taken over
    the whole tensor (all rows and channels), so every channel is shifted and
    scaled by the same amounts.

    Args:
        array_input: numeric tensor to rescale.
        dyn_range: width of the output range (default 0.98).
        min_range: lower bound on ``max - min``, so (near-)constant inputs do
            not divide by zero (default 0.001).
    """
    offset = (1.0 - dyn_range) / 2.0
    arr_min = tf.math.reduce_min(array_input)
    arr_max = tf.math.reduce_max(array_input)
    denom = tf.math.maximum(arr_max - arr_min, min_range)
    return (array_input - arr_min) * dyn_range / denom + offset


@tf.function
def normalize_batches(X_inp, y_inp):
    """Rescale a batch with normalize_dyn_range: X's RGB columns and y, independently.

    Columns 3: of X (the white-balance features) pass through unchanged.
    """
    rgb_part = X_inp[:, 0:3]
    extra_part = X_inp[:, 3:]
    X_out = tf.concat([normalize_dyn_range(rgb_part), extra_part], 1)
    return X_out, normalize_dyn_range(y_inp)

In [None]:
# Resolve the RAW files and read their as-shot white balance.
# NOTE(review): the next cell repeats this exact computation, so this cell is
# only useful for inspecting wb_params on its own.
full_raw_paths = [DATASET_PREFIX + "/" + raw_name for raw_name in RAW_PATHS]
wb_params = [*get_wb_params(full_raw_paths)]

In [None]:
# Load each training pair together with the white balance of its RAW file and
# interleave them into one training dataset.
full_raw_paths = [f"{DATASET_PREFIX}/{p}" for p in RAW_PATHS]
wb_params = list(get_wb_params(full_raw_paths))

# TRAINING_PATHS[i] is paired with RAW_PATHS[i]. Fail loudly instead of raising
# a bare IndexError or silently ignoring surplus RAW files.
if len(wb_params) != len(TRAINING_PATHS):
    raise ValueError(
        f"Got {len(wb_params)} RAW white balances for {len(TRAINING_PATHS)} "
        "training pairs; RAW_PATHS and TRAINING_PATHS must be aligned."
    )

wb_datasets = []

for (feature_path, target_path), wb_param in zip(TRAINING_PATHS, wb_params):
    # DEFAULT: No normalization of batches. Produces good results for consistent levels across training images.
    # wb_datasets.append(load_dataset(feature_path, target_path, wb_param).batch(32))
    # EXPERIMENT: Normalize image input levels so the model is not performing the bulk of the adjustment.
    # Normalizing the dynamic range increases the overall loss but improves color accuracy.
    wb_datasets.append(
        load_dataset(feature_path, target_path, wb_param).batch(32).map(normalize_batches)
    )

wb_datasets_combined = tf.data.Dataset.sample_from_datasets(wb_datasets)

In [None]:
class CameraCorrectionLayer(layers.Layer):
    """Fixed (non-trainable) layer applying the inverse of ``camera_matrix``.

    Pixels are row vectors of shape (..., 3), and ``tf.tensordot(x, m, 1)``
    computes ``x @ m``. ``camera_matrix`` is assumed to be written for column
    vectors (``cam = M @ v``, as with rawpy's ``rgb_xyz_matrix[:3]`` — TODO
    confirm), so the row-vector form of its inverse is ``x @ inv(M).T``.
    """

    def __init__(self, camera_matrix, **kwargs):
        super().__init__(**kwargs)
        # Invert in float64 for accuracy, then cast to float32 (Keras' default
        # compute dtype) so tensordot does not mix dtypes with the inputs.
        inverse = np.linalg.inv(np.asarray(camera_matrix, dtype=np.float64))
        self.camera_matrix = tf.constant(inverse.T.astype(np.float32))

    def build(self, input_shape):
        # No weights: the matrix is fixed at construction time.
        return

    def call(self, inputs):
        return tf.tensordot(inputs, self.camera_matrix, 1)


class XyzToSrgbLayer(layers.Layer):
    """Fixed (non-trainable) layer converting CIE XYZ pixels to linear sRGB.

    The module-level ``xyz_to_srgb`` is the inverse of ``srgb_to_xyz``, whose
    rows sum to the D65 white point, i.e. it is written for column vectors:
    ``rgb = xyz_to_srgb @ xyz``. Pixels here are row vectors (..., 3), and
    ``tf.tensordot(x, m, 1)`` computes ``x @ m``, so the row-vector form is
    ``x @ xyz_to_srgb.T``. Applying ``x @ xyz_to_srgb`` would use the
    transposed (wrong) transform.
    """

    def __init__(self, color_ch=3, **kwargs):
        super().__init__(**kwargs)
        self.color_ch = color_ch  # kept for API compatibility; not used
        # NOTE: models trained with the untransposed matrix learned weights for
        # that transform; retrain rather than reuse those weights.
        self.xyz_to_srgb_rows = tf.constant(np.transpose(xyz_to_srgb))

    def build(self, input_shape):
        # No weights: the conversion matrix is fixed.
        return

    def call(self, inputs):
        return tf.tensordot(inputs, self.xyz_to_srgb_rows, 1)


class ColorTransformLayer(layers.Layer):
    """Trainable linear colour transform: ``outputs = inputs @ w``.

    ``w`` has shape (input channels, ``color_ch``). With the default
    ``color_ch=3`` and 3-channel inputs this is a 3x3 matrix.
    """

    def __init__(self, color_ch=3, **kwargs):
        super().__init__(**kwargs)
        self.color_ch = color_ch

    def build(self, input_shape):
        self.w = self.add_weight(
            shape=(input_shape[-1], self.color_ch),
            initializer="random_normal",
            trainable=True,
        )

    def call(self, inputs):
        # Contracts the last axis of `inputs` with the first axis of `w`.
        return tf.tensordot(inputs, self.w, 1)


class MultiColorTransformLayer(layers.Layer):
    """Trainable 3x3 colour transform applied along the last axis of the input.

    The result equals mapping ``tf.tensordot(x, w, 1)`` over the leading axis,
    but it is computed as a single vectorised op. This avoids ``tf.map_fn`` and
    avoids creating a new ``tf.function`` on every ``call``, which forced
    retracing.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        self.w = self.add_weight(shape=(3, 3), initializer="random_normal", trainable=True)

    def call(self, inputs):
        # Contracts the last axis of `inputs` with the first axis of `w`.
        return tf.tensordot(inputs, self.w, 1)

    def compute_output_shape(self, input_shape):
        return input_shape


@tf.function()
def tf_s_curve_fn(img_arr, offset, contrast, slope):
    """Applies the s-curve ``1 / (1 + exp(slope * (img_arr + offset) * contrast))``.

    This is computed as ``sigmoid(-slope * x)``, which is mathematically
    identical but numerically stable. In the naive form, ``exp`` overflows to
    inf for large arguments, and its gradient then becomes NaN (inf * 0)
    during training.
    """
    output = (img_arr + offset) * contrast
    return tf.math.sigmoid(-slope * output)


# Applies a channel independent s-curve to the image.
class ColorCurveAdjLayer(layers.Layer):
    """Trainable per-channel s-curve built on ``tf_s_curve_fn``.

    The weight ``w`` has shape (3, channels). Row 0 holds the per-channel
    offsets, row 1 the contrasts and row 2 the slopes.
    """

    def __init__(self, color_ch=3, **kwargs):
        super().__init__(**kwargs)
        # Kept for API compatibility; the channel count comes from the input
        # shape in build().
        self.color_ch = color_ch

    def build(self, input_shape):
        # One (offset, contrast, slope) triple per input channel; (3, 3) for
        # RGB inputs. The shape is not stored on `self.input_shape`, which
        # may collide with a read-only Keras property.
        self.w = self.add_weight(
            shape=(3, input_shape[-1]), initializer="random_normal", trainable=True
        )

    def call(self, inputs):
        # tf_s_curve_fn is already a tf.function. Wrapping it in a fresh
        # tf.function on every call only forces retracing.
        return tf_s_curve_fn(inputs, self.w[0], self.w[1], self.w[2])

    def compute_output_shape(self, input_shape):
        return input_shape

## Set up datasets for a model including WB params

## Define Correction Model

In [None]:
def create_model():
    """Build the (uncompiled) functional colour-correction model.

    Input: a (batch, 6) tensor. Columns 0:3 are the pixel RGB values and
    columns 3: are the per-image white-balance values (see load_dataset).
    Output: a (batch, 3) tensor of corrected RGB values.

    NOTE(review): statement order determines layer creation order, and
    presumably Keras' auto-generated layer names and the random-initialisation
    sequence. Reordering may affect loading previously saved weights.
    """
    # Merged input contains RGB values for the pixel as [0:3] and white balance
    # values for the pixel as [3:]. The purpose of this approach is to create
    # a general model which can color correct for images of any light source,
    # however it doesn't seem to be working as intended, even though it is
    # doing *something*.
    merged_input = keras.Input(shape=(6,), name="merged_input")
    rgb_input = layers.Lambda(lambda x: x[:, 0:3])(merged_input)
    wb_input = layers.Lambda(lambda x: x[:, 3:])(merged_input)

    # Apply a set of weights to the white balance. The first white balance weights
    # will be used on the camera rgb input. The second will be used on the
    # converted RGB input.
    # NOTE(review): these Dense layers have no activation argument, so the
    # stack is linear and composes to a single affine map.
    wb_transform = layers.Dense(6)(wb_input)
    wb_transform = layers.Dense(6)(wb_transform)
    wb_transform = layers.Dense(6)(wb_transform)

    # Branch 1: per-channel multipliers for the camera RGB input.
    # wb_transform = layers.Dense(3, activation="sigmoid")(wb_transform)
    wb_transform_1 = layers.Dense(3)(wb_transform)
    wb_transform_1 = layers.Dense(3, activation="sigmoid")(wb_transform_1)
    wb_transform_1 = ColorTransformLayer()(wb_transform_1)

    # Branch 2: per-channel multipliers for the converted RGB output.
    wb_transform_2 = layers.Dense(6)(wb_transform)
    wb_transform_2 = layers.Dense(6)(wb_transform_2)
    wb_transform_2 = layers.Dense(3)(wb_transform_2)
    wb_transform_2 = layers.Dense(3, activation="sigmoid")(wb_transform_2)
    wb_transform_2 = ColorTransformLayer()(wb_transform_2)

    rgb_layers = layers.Identity()(rgb_input)

    # Apply a transformed white balance.
    rgb_layers = layers.Multiply()([rgb_layers, wb_transform_1])

    # Apply per channel color curves to the input.
    rgb_layers = ColorCurveAdjLayer()(rgb_layers)

    # Apply the color transformation to the RGB image.
    # (The learned transform feeds XyzToSrgbLayer, so it presumably learns a
    # camera-RGB -> XYZ mapping — TODO confirm.)
    rgb_layers = ColorTransformLayer()(rgb_layers)
    rgb_layers = XyzToSrgbLayer()(rgb_layers)

    # Apply a second RGB color curve to the input.
    rgb_layers = ColorCurveAdjLayer()(rgb_layers)

    # EXP: previous best, this enabled!!
    # NOTE(review): wb_transform_2 is multiplied in twice (here and below),
    # i.e. the output is scaled by its element-wise square. Confirm this
    # experiment is intended.
    rgb_layers = layers.Multiply()([rgb_layers, wb_transform_2])

    # Apply a second transformation of the white balance to the output.
    # This has a minimal but measurable improvement.
    rgb_layers = layers.Multiply()([rgb_layers, wb_transform_2])

    model = keras.Model(inputs=merged_input, outputs=rgb_layers)

    return model


In [None]:
# Keras MSE loss object, combined with color_relative_loss in
# balanced_relative_loss.
mse_loss = tf.keras.losses.MeanSquaredError()

# Per-sample Euclidean distance between true and predicted colours.
# NOTE(review): not used for training in this notebook; its usefulness is
# unverified. The sqrt has an undefined gradient where y_true == y_pred.
def level_loss(y_true, y_pred):
    """Return the L2 norm of ``y_true - y_pred`` over the last axis."""
    return tf.sqrt(tf.reduce_sum(tf.square(y_true - y_pred), axis=-1))

def color_relative_loss(y_true, y_pred):
    """Squared error between the channel proportions of true and predicted colours.

    Each colour is divided by the sum of its channels, so only the relative
    channel proportions (not brightness) are compared. Returns one value per
    sample, summed over the last axis.

    With plain division, a colour whose channels sum to exactly zero (e.g. a
    black target pixel when training on unnormalised data) gives 0/0 = NaN,
    which poisons the whole loss. ``tf.math.divide_no_nan`` returns 0 for those
    ratios and leaves every other value unchanged. ``keepdims=True`` keeps the
    sums broadcastable for inputs of any rank.
    """
    sum_true = tf.reduce_sum(y_true, axis=-1, keepdims=True)
    sum_pred = tf.reduce_sum(y_pred, axis=-1, keepdims=True)
    ratio_true = tf.math.divide_no_nan(y_true, sum_true)
    ratio_pred = tf.math.divide_no_nan(y_pred, sum_pred)
    return tf.reduce_sum(tf.square(ratio_true - ratio_pred), axis=-1)

def balanced_relative_loss(y_true, y_pred):
    """Per-sample loss: MSE plus one third of color_relative_loss."""
    relative_term = color_relative_loss(y_true, y_pred) / 3
    mse_term = mse_loss.call(y_true, y_pred)
    return mse_term + relative_term


In [None]:
# Sanity-check the losses on two hand-computed samples. In both rows the
# prediction is a uniform scaling of the target, so the channel proportions
# match and color_relative_loss must be 0. level_loss is the absolute
# Euclidean distance: sqrt(3 * 0.5**2) and sqrt(0.3**2 + 0.1**2 + 0.2**2).
y_true_test = np.asarray([[1.0, 1.0, 1.0],
                          [0.3, 0.1, 0.2]])
y_pred_test = np.asarray([[0.5, 0.5, 0.5],
                          [0.6, 0.2, 0.4]])

color_loss_test = color_relative_loss(y_true_test, y_pred_test)
level_loss_test = level_loss(y_true_test, y_pred_test)
np.testing.assert_allclose(color_loss_test, [0.0, 0.0], atol=1e-12)
np.testing.assert_allclose(level_loss_test, [math.sqrt(0.75), math.sqrt(0.14)], rtol=1e-9)
display(color_loss_test, level_loss_test)


In [None]:
# Stand-alone instance used to inspect the architecture. Note that it is never
# trained: the training loop below builds and fits its own instances.
color_correction_model = create_model()
color_correction_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss=balanced_relative_loss,
    metrics=['mse'],
)
color_correction_model.build((None, 6))
color_correction_model.summary()

In [None]:
# One summary dict per training run. This lives in its own cell so results
# accumulate across re-runs of the training cell.
results = list()

In [None]:
from IPython.display import clear_output

best_model = None
best_loss = float("inf")
epochs = 10
num_runs = 20

hyperparams = [1]

# Train several randomly initialised models and keep the one with the lowest
# final training loss, since individual runs vary with the random init.
for h in hyperparams:
    for i in range(num_runs):
        print(f"Training model {i + 1}/{num_runs}.")

        model = create_model()
        model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), loss=balanced_relative_loss, metrics=['mse'])
        model.build((None, 6))

        history = model.fit(wb_datasets_combined, epochs=epochs, verbose=0)

        # Final-epoch *training* loss. Switch to "val_loss" once validation
        # data is passed to fit().
        train_loss = history.history["loss"][-1]
        print(f"Finished model {i + 1}/{num_runs}: {train_loss}")

        # NaN compares False here, so diverged runs are never selected.
        if train_loss < best_loss:
            best_loss = train_loss
            best_model = model

        results.append(
            {
                "hyperparam": h,
                "run": i,
                "train_loss": train_loss,
            }
        )
    clear_output(wait=True)

# Later cells use best_model; fail here with a clear message instead of an
# AttributeError on None.
if best_model is None:
    raise RuntimeError("No training run produced a finite loss; best_model is unset.")

print(f"Best loss: {best_loss}")

In [None]:
# Continue training the best model from the search for another `epochs` epochs.
best_model.fit(x=wb_datasets_combined, epochs=epochs)

In [None]:
# Save the trained model. `color_correction_model` is never fitted in this
# notebook, so saving it would only store randomly initialised weights.
best_model.save_weights("color_correction_model_0_7.weights.h5")

In [None]:

def apply_correction_model_predict(img_arr, model, wb_matrix, batch_size=8192):
    """Colour-correct an image with a model whose input rows are [R, G, B, wb...].

    Args:
        img_arr: image array-like, presumably (H, W, 3) — TODO confirm.
        model: Keras model taking (N, 6) rows, as built by create_model.
        wb_matrix: white-balance values broadcast to 3 columns (e.g. a
            length-3 vector) and appended to every pixel.
        batch_size: batch size passed to ``model.predict``.

    Returns:
        The predicted image reshaped to (H, W, 3).

    NOTE(review): the training inputs were passed through normalize_batches
    (per-batch dynamic-range normalisation), but no normalisation is applied
    here. Inference inputs may therefore be on a different scale than the
    training data.
    """
    r_orig, c_orig = img_arr.shape[:2]

    img_linear_arr = to_color_arr(np.asarray(img_arr))

    # Repeat the white balance for every pixel and append it as extra columns.
    wb_matrix_full = np.full_like(img_linear_arr, wb_matrix)
    process_arr = np.concatenate((img_linear_arr, wb_matrix_full), axis=1)

    output_arr = model.predict(process_arr, batch_size=batch_size)
    return output_arr.reshape((r_orig, c_orig, 3))

In [None]:
# Run the trained model on one test image. Change the stem to try another
# capture, e.g. "DSCF5783".
test_image_stem = "DSCF5796"
test_exr_img_path = f"{DATASET_PREFIX}/{test_image_stem}.exr"
test_raw_img_path = f"{DATASET_PREFIX}/{test_image_stem}.RAF"

test_exr_img_arr = exp.process(pyexr.read(test_exr_img_path)[:, :, :3])
# Index instead of unpacking into `_`: assigning `_` at notebook top level
# shadows IPython's last-output variable.
test_wb = get_raw_properties(test_raw_img_path)[0]

output_img = apply_correction_model_predict(test_exr_img_arr, best_model, test_wb)

In [None]:
# Preview the corrected test image.
display_image(image=output_img)

In [None]:
# Export the corrected image as a half-precision EXR.
# NOTE: darktable has trouble reading this EXR (cause unknown). GIMP opens it,
# but fixing the colours there with the eyedropper/curves tools is exactly the
# manual step this model is meant to replace.
corrected_exr_path = "./corrected_output_img_4.exr"
pyexr.write(corrected_exr_path, np.asarray(output_img), precision=pyexr.HALF)

## Conclusions

The approach above still does not work well. I suspect that the model is not
making use of the white-balance layers. That would explain the muddy results,
which come out too warm or too cool depending on the lighting scenario the
images were taken in.

### Alternative approaches

* Create a color correction model which we co-train to produce a white balance
  transform, using as input either:
    - the white / black point
    - the white balance type
    - color correction weights
* Create a bunch of color correction sample images using color temperature controlled panel lights.