In [None]:
import math
import os

import numpy as np
import pyexr
import rawpy

from PIL import Image
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from cnn_demosaic import transform

In [None]:
import tensorflow as tf

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential

# Ask TensorFlow to grow GPU memory on demand instead of reserving all of it.
gpus = tf.config.experimental.list_physical_devices("GPU")
for gpu_device in gpus:
    tf.config.experimental.set_memory_growth(gpu_device, True)

## Configure common functions

In [None]:

def display_image(image):
    """Render `image` with matplotlib and show the figure.

    vmin=0 / vmax=1 are passed to imshow; matplotlib only applies them to
    single-channel data (RGB input is displayed on its own 0-1 scale).
    """
    plt.imshow(image, vmax=1, vmin=0)
    plt.show()

def to_color_arr(img_arr):
    """Flatten an image (e.g. H x W x 3) into an (N, 3) array of pixel colors."""
    n_pixels = img_arr.size // 3
    return img_arr.reshape((n_pixels, 3))

def random_sampling(target_arr, n_samples=1000):
    """Pick n_samples rows of target_arr uniformly at random, with replacement.

    Uses numpy's global RNG, so results follow np.random.seed.
    """
    picks = np.random.randint(low=0, high=len(target_arr), size=n_samples)
    return target_arr[picks]

In [None]:

def get_raw_properties(raw_path):
    """Read white balance and color matrix metadata from a raw file via rawpy.

    Returns:
        A tuple ``(cam_whitebalance, daylight_whitebalance, rgb_xyz_matrix)``:
        - cam_whitebalance: the first three camera white balance multipliers,
          scaled by sum(daylight) / sum(camera) computed over *all* channels.
        - daylight_whitebalance: the daylight multipliers as an ndarray
          (not truncated to three channels).
        - rgb_xyz_matrix: the first three rows of rawpy's ``rgb_xyz_matrix``.
    """
    with rawpy.imread(raw_path) as raw:
        camera_wb = raw.camera_whitebalance
        daylight_wb = np.asarray(raw.daylight_whitebalance)
        xyz_matrix = raw.rgb_xyz_matrix[:3]

    # Put the camera multipliers on the same overall scale as the daylight ones.
    scale = math.fsum(daylight_wb) / math.fsum(camera_wb)
    scaled_cam_wb = np.asarray(camera_wb)[:3] * scale
    return scaled_cam_wb, daylight_wb, xyz_matrix

def get_wb_params(paths):
    """Lazily yield the scaled camera white balance for each raw file in `paths`.

    Only the first element returned by get_raw_properties is kept; each file is
    read as the generator is consumed.
    """
    for raw_path in paths:
        wb_multipliers, _daylight, _matrix = get_raw_properties(raw_path)
        yield wb_multipliers

# Note: using the `predict` method is *much* slower than applying the
# equivalent dot product to the array directly.

def apply_model(img_arr, model):
    """Run a per-pixel color model over an image and restore its layout.

    The image is flattened with to_color_arr into rows of 3 values, passed
    through ``model.predict``, and the result reshaped to (rows, cols, 3).
    """
    height, width = img_arr.shape[:2]
    pixel_rows = to_color_arr(img_arr)
    predictions = model.predict(pixel_rows)
    return predictions.reshape((height, width, 3))

In [None]:
# Linear sRGB -> CIE XYZ (D65). Per the Lindbloom reference below, each row
# gives X, Y or Z as a weighted sum of linear R, G, B (xyz = M @ rgb for
# column vectors).
srgb_to_xyz = np.asarray(
    [[0.4124564, 0.3575761, 0.1804375],
     [0.2126729, 0.7151522, 0.0721750],
     [0.0193339, 0.1191920, 0.9503041]],
    dtype=np.float32,
)

# The inverse (XYZ -> linear sRGB). It roughly matches the sRGB D65 matrix
# shown here:
# http://www.brucelindbloom.com/index.html?Eqn_RGB_XYZ_Matrix.html
xyz_to_srgb = np.linalg.inv(srgb_to_xyz)

## Set up datasets

In [None]:
# Directory holding the training EXRs (and local copies of the raw files).
TRAINING_DIR = '../training'
# Directory holding the original Fuji X-E2 raw files used for white balance.
# Set the FUJI_RAW_DIR environment variable on machines with another layout
# instead of editing this absolute path.
RAW_DIR = os.environ.get(
    'FUJI_RAW_DIR', '/media/jake/Media/datasets/fuji_raw/xe2/125_FUJI')

# The path to the camera processed color card
SRGB_PATH = f'{TRAINING_DIR}/DSCF5652_card_srgb.exr'
# The path to the demosaiced but not color corrected color card
EXR_PATH = f'{TRAINING_DIR}/DSCF5652_card.exr'
# The path to the raw image
RAW_PATH = f'{TRAINING_DIR}/DSCF5652.RAF'

TEST_SRGB_PATH = f'{TRAINING_DIR}/DSCF5731_card_srgb.exr'
TEST_EXR_PATH = f'{TRAINING_DIR}/DSCF5731_card.exr'
TEST_RAW_PATH = f'{TRAINING_DIR}/DSCF5731.RAF'

# The full (uncropped) demosaiced test frame.
TEST_FULL_EXR_PATH = f'{TRAINING_DIR}/DSCF5731.exr'

# Frames used to train the white-balance-aware model. RAW_PATHS[i] and
# TRAINING_PATHS[i] must describe the same frame, so both lists are derived
# from this single source of truth.
TRAINING_FRAMES = ['DSCF5652', 'DSCF5711', 'DSCF5731']

RAW_PATHS = [os.path.join(RAW_DIR, f'{frame}.RAF') for frame in TRAINING_FRAMES]

# (demosaiced card EXR, camera-processed sRGB card EXR) per training frame.
TRAINING_PATHS = [
    (f'{TRAINING_DIR}/{frame}_card.exr', f'{TRAINING_DIR}/{frame}_card_srgb.exr')
    for frame in TRAINING_FRAMES
]

In [None]:
def _read_rgb_exr(path):
    """Read an EXR file and keep only its first three channels."""
    return pyexr.read(path)[:, :, :3]

# Training card: camera-processed sRGB target and demosaiced source.
srgb_img_arr = _read_rgb_exr(SRGB_PATH)
exr_img_arr = _read_rgb_exr(EXR_PATH)

# Test card, read in the same order as before (source first, then target).
test_exr_img_arr = _read_rgb_exr(TEST_EXR_PATH)
test_srgb_img_arr = _read_rgb_exr(TEST_SRGB_PATH)

In [None]:
# Flatten each card into (N, 3) pixel rows. The demosaiced colors are the
# features and the camera-processed sRGB colors are the targets.
X_indoor, y_indoor = to_color_arr(exr_img_arr), to_color_arr(srgb_img_arr)

X_outdoor, y_outdoor = to_color_arr(test_exr_img_arr), to_color_arr(test_srgb_img_arr)

In [None]:
# One (feature, target) pixel dataset per lighting scenario.
dataset_indoor = tf.data.Dataset.from_tensor_slices((X_indoor, y_indoor)).shuffle(1000)
# The outdoor dataset must use the outdoor arrays; reusing X_indoor/y_indoor here
# would make the 50/50 mix below contain only indoor pixels.
dataset_outdoor = tf.data.Dataset.from_tensor_slices((X_outdoor, y_outdoor)).shuffle(1000)

# Draw from both scenarios with equal probability, then batch.
dataset_combined = tf.data.Dataset.sample_from_datasets(
    [dataset_indoor, dataset_outdoor], weights=[0.5, 0.5]).batch(32)

# Hold out 20% of the combined data for validation.
train_dataset, test_dataset = tf.keras.utils.split_dataset(dataset_combined, left_size=0.8)

In [None]:

class CameraCorrectionLayer(layers.Layer):
    """Fixed (non-trainable) camera RGB -> XYZ color transform.

    Args:
        camera_matrix: 3x3 matrix that is inverted here to map camera RGB to
            XYZ, e.g. ``rgb_xyz_matrix`` from ``get_raw_properties``. It is
            assumed to be in column-vector form (camera = M @ xyz), the layout
            of rawpy's (channels, 3) ``rgb_xyz_matrix`` -- TODO confirm for
            other matrix sources.
        **kwargs: Forwarded to ``layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, camera_matrix, **kwargs):
        super().__init__(**kwargs)
        cam_to_xyz = np.linalg.inv(np.asarray(camera_matrix, dtype=np.float64))
        # Inputs are row vectors of shape (..., 3) and tensordot(inputs, B, 1)
        # computes inputs @ B, so the column-form matrix must be transposed.
        # Cast to float32: a float64 tensor here makes tensordot fail against
        # float32 model activations.
        self.camera_matrix = tf.convert_to_tensor(cam_to_xyz.T, dtype=tf.float32)

    def build(self, input_shape):
        if input_shape[-1] is not None and input_shape[-1] != 3:
            raise ValueError(f'input shape is not valid for this layer: {input_shape}')
        super().build(input_shape)

    def call(self, inputs):
        return tf.tensordot(inputs, self.camera_matrix, 1)

class XyzToSrgbLayer(layers.Layer):
    """Fixed (non-trainable) CIE XYZ -> linear sRGB transform.

    Multiplies each 3-channel input row by the module-level ``xyz_to_srgb``
    matrix. No sRGB gamma encoding is applied.

    Args:
        color_ch: Stored on the layer; the transform itself is always 3x3.
        **kwargs: Forwarded to ``layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, color_ch=3, **kwargs):
        super().__init__(**kwargs)
        self.color_ch = color_ch

    def build(self, input_shape):
        if input_shape[-1] is not None and input_shape[-1] != 3:
            raise ValueError(f'input shape is not valid for this layer: {input_shape}')
        super().build(input_shape)

    def call(self, inputs):
        # xyz_to_srgb is the inverse of the column-vector matrix srgb_to_xyz
        # (srgb = M @ xyz). tensordot(inputs, B, 1) computes inputs @ B for row
        # vectors, so B must be M transposed; passing M itself applies the
        # transpose of the intended conversion.
        return tf.tensordot(inputs, xyz_to_srgb.T, 1)

class ColorTransformLayer(layers.Layer):
    """Trainable linear color transform: ``outputs = inputs @ W``.

    Args:
        color_ch: Number of output channels (columns of W). Defaults to 3.
        **kwargs: Forwarded to ``layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, color_ch=3, **kwargs):
        super().__init__(**kwargs)
        self.color_ch = color_ch

    def build(self, input_shape):
        # W maps the incoming channel count to color_ch outputs. For the
        # default 3-channel input with color_ch=3 this is the original (3, 3).
        in_ch = input_shape[-1] if input_shape[-1] is not None else self.color_ch
        self.w = self.add_weight(shape=(in_ch, self.color_ch),
                                 initializer='random_normal',
                                 trainable=True)
        super().build(input_shape)

    def call(self, inputs):
        return tf.tensordot(inputs, self.w, 1)


## Define model

In [None]:

# Color correction model for the Fuji X-E2: two small Dense layers (elu, then
# sigmoid), a trainable 3x3 color transform, and the fixed XYZ -> sRGB matrix.
# This model can only fit one lighting scenario.
# A CameraCorrectionLayer(xe2_rgb_matrix) could be inserted right after the
# Input; it is currently disabled.

corrected_model = Sequential(
    [
        keras.Input(shape=(3,), batch_size=32),
        layers.Dense(3, activation='elu'),
        layers.Dense(3, activation='sigmoid'),
        ColorTransformLayer(),
        XyzToSrgbLayer(),
    ]
)

corrected_model.compile(loss='mse', optimizer='adam')
corrected_model.build()
corrected_model.summary()

In [None]:
corrected_model.fit(train_dataset, validation_data=test_dataset, epochs=8)

## Evaluate output

In [None]:
# Read the test frame's scaled camera white balance (and the X-E2 color matrix)
# from its raw file. The function is get_raw_properties; no `get_raw` exists.
pt_cloudy_wb, _, xe2_rgb_matrix = get_raw_properties(TEST_RAW_PATH)

# Use the existing config constant for the full test frame instead of
# redefining TEST_EXR_PATH. Note: this overwrites the test *card* array loaded
# earlier; later cells that use test_exr_img_arr operate on this full frame.
test_exr_img_arr = pyexr.read(TEST_FULL_EXR_PATH)[:,:,:3]

display_image(apply_model(test_exr_img_arr * pt_cloudy_wb, corrected_model))

In [None]:
corrected_test_img = apply_model(test_exr_img_arr * pt_cloudy_wb, corrected_model)
display_image(corrected_test_img)

## Set up datasets for a model including WB params

In [None]:
def sample_pairs(a_arr, b_arr, n_samples=100000):
    """Draw the same random rows from two aligned arrays.

    Rows are chosen uniformly with replacement using numpy's global RNG, and
    the same indices are applied to both arrays, so feature/target pairs stay
    aligned.

    Args:
        a_arr: Array whose first axis indexes samples (e.g. feature pixels).
        b_arr: Array aligned row-for-row with ``a_arr`` (e.g. target pixels).
        n_samples: Number of rows to draw.

    Returns:
        ``(a_samples, b_samples)``, each with ``n_samples`` rows.

    Raises:
        ValueError: If the arrays have different lengths. This was an
            ``assert``, which is stripped when Python runs with -O.
    """
    if len(a_arr) != len(b_arr):
        raise ValueError(
            f'sample_pairs needs arrays of equal length, got {len(a_arr)} and {len(b_arr)}')
    rand_index = np.random.randint(0, a_arr.shape[0], n_samples)
    return a_arr[rand_index], b_arr[rand_index]


def load_dataset(feature_image_path, target_image_path, wb_array, n_samples=100000):
    """Build a tf.data dataset of ``(rgb + white balance, target rgb)`` pixel pairs.

    Args:
        feature_image_path: EXR of the demosaiced, uncorrected image (features).
        target_image_path: EXR of the matching camera-processed sRGB image (targets).
        wb_array: 1-D white balance vector appended to every sampled feature pixel.
        n_samples: Number of random (feature, target) pixel pairs to draw.

    Returns:
        ``tf.data.Dataset`` of ``(features, target)`` where each feature row is
        the pixel's RGB followed by ``wb_array``.
    """
    feat_img_arr = pyexr.read(feature_image_path)[:,:,:3]
    # Targets come from target_image_path. Reading the feature path here makes
    # the targets identical to the inputs.
    targ_img_arr = pyexr.read(target_image_path)[:,:,:3]

    feat_rgb_samples, targ_rgb_samples = sample_pairs(
        to_color_arr(feat_img_arr), to_color_arr(targ_img_arr), n_samples=n_samples)

    # Append the same white balance vector to every sampled pixel in one
    # vectorized step instead of concatenating row by row in Python.
    wb_row = np.asarray(wb_array)
    wb_cols = np.broadcast_to(wb_row, (len(feat_rgb_samples), wb_row.shape[-1]))
    feat_rgb_with_wb = np.concatenate((feat_rgb_samples, wb_cols), axis=1)

    return tf.data.Dataset.from_tensor_slices((feat_rgb_with_wb, targ_rgb_samples))


# Scaled camera white balance per training frame. RAW_PATHS[i] must be the raw
# file for TRAINING_PATHS[i].
wb_params = list(get_wb_params(RAW_PATHS))

if len(wb_params) != len(TRAINING_PATHS):
    raise ValueError(
        f'Got {len(wb_params)} white balance entries for {len(TRAINING_PATHS)} training pairs; '
        'RAW_PATHS and TRAINING_PATHS must list the same frames in the same order.')

wb_datasets = [
    load_dataset(feature_path, target_path, wb)
    for (feature_path, target_path), wb in zip(TRAINING_PATHS, wb_params)
]

# Sampling weights, one per training frame (same order as TRAINING_PATHS).
wb_datasets_combined = tf.data.Dataset.sample_from_datasets(wb_datasets, weights=[0.4, 0.3, 0.3])


## Define Correction Model

In [None]:

def create_model():
    """Build the white-balance-aware color correction model (functional API).

    Input: a (batch, 6) tensor named 'merged_input'. The first three columns
    are a pixel's RGB value and the last three its white balance values (the
    layout produced by ``load_dataset``).

    Returns:
        An uncompiled ``keras.Model`` mapping that input to a 3-value output.
    """
    # Merged input contains RGB values for the pixel as [0:3] and white balance
    # values for the pixel as [3:]. The purpose of this approach is to create
    # a general model which can color correct for images of any light source,
    # however it doesn't seem to be working as intended, even though it is
    # doing *something*.
    merged_input = keras.Input(shape=(6,), name='merged_input')
    rgb_input = layers.Lambda(lambda x: x[:, 0:3])(merged_input)
    wb_input = layers.Lambda(lambda x: x[:, 3:])(merged_input)

    # Apply a set of weights to the white balance. The first white balance weights
    # will be used on the camera rgb input. The second will be used on the 
    # converted RGB input.
    # NOTE(review): ColorTransformLayer followed by two Dense(3) layers with no
    # activation composes to a single affine map, so the extra layers add
    # parameters but no expressive power.
    wb_transform_1 = ColorTransformLayer()(wb_input)
    wb_transform_1 = layers.Dense(3)(wb_transform_1)
    wb_transform_1 = layers.Dense(3)(wb_transform_1)
    
    wb_transform_2 = ColorTransformLayer()(wb_input)
    wb_transform_2 = layers.Dense(3)(wb_transform_2)
    wb_transform_2 = layers.Dense(3)(wb_transform_2)

    # Identity is a pass-through; the RGB branch starts from rgb_input unchanged.
    rgb_layers = layers.Identity()(rgb_input)

    # Apply a transformed white balance (element-wise per channel).
    rgb_layers = layers.Multiply()([rgb_layers, wb_transform_1])

    # Apply some gamma and color curves to the input.
    # Using multiple curves or applying curves after color correction
    # does not improve performance. Increasing the dimensions may or
    # may not improve performance.
    rgb_layers = layers.Dense(3)(rgb_layers)
    rgb_layers = layers.Dense(6, activation='elu', name='gamma_1')(rgb_layers)
    rgb_layers = layers.Dense(6)(rgb_layers)
    rgb_layers = layers.Dense(6, activation='sigmoid', name='s_curve_1')(rgb_layers)
    rgb_layers = layers.Dense(3)(rgb_layers)
    
    # Apply the color transformation to the RGB image.
    rgb_layers = ColorTransformLayer()(rgb_layers)
   
    # Apply a second transformation of the white balance to the output.
    # This has a minimal but measurable improvement.
    rgb_layers = layers.Multiply()([rgb_layers, wb_transform_2])

    model = keras.Model(inputs=merged_input, outputs=rgb_layers)

    return model

In [None]:
color_correction_model = create_model()
color_correction_model.compile(loss='mse', optimizer='adam')
color_correction_model.build(input_shape=(None, 6))
color_correction_model.summary()

In [None]:
wb_train_batches = wb_datasets_combined.batch(32)
color_correction_model.fit(wb_train_batches, epochs=20)

In [None]:
COLOR_CORRECTION_WEIGHTS_PATH = 'color_correction_model_0_1.weights.h5'
color_correction_model.save_weights(COLOR_CORRECTION_WEIGHTS_PATH)

In [None]:

# Building the (pixel RGB + white balance) rows is vectorized below. The
# previous per-pixel Python list comprehension was what made this impractically
# slow on full-size images.

def apply_correction_model(img_arr, model, wb_matrix, batch_size=None):
    """Color-correct an image with the white-balance-aware model.

    Every pixel's RGB value is concatenated with the same white balance vector
    (the rgb-then-wb layout used by ``load_dataset``), and the model is run
    over all pixels at once.

    Args:
        img_arr: (H, W, 3) image.
        model: Object with a Keras-style ``predict`` that maps (N, 3 + len(wb))
            rows to (N, 3) outputs.
        wb_matrix: 1-D white balance vector appended to every pixel.
        batch_size: Forwarded to ``model.predict``. Larger values can speed up
            inference on full-size images; ``None`` keeps predict's default.

    Returns:
        (H, W, 3) array of model outputs.

    Raises:
        ValueError: If ``img_arr`` is not an (H, W, 3) array.
    """
    if img_arr.ndim != 3 or img_arr.shape[2] != 3:
        raise ValueError(f'expected an (H, W, 3) image, got shape {img_arr.shape}')
    r_orig, c_orig = img_arr.shape[:2]

    img_linear_arr = img_arr.reshape((-1, 3))
    wb_row = np.asarray(wb_matrix)
    wb_cols = np.broadcast_to(wb_row, (img_linear_arr.shape[0], wb_row.shape[-1]))
    process_arr = np.concatenate((img_linear_arr, wb_cols), axis=1)

    output_arr = model.predict(process_arr, batch_size=batch_size)
    return output_arr.reshape((r_orig, c_orig, 3))

In [None]:

# Correct the test image using the white balance read from its raw file.
output_img = apply_correction_model(
    test_exr_img_arr, color_correction_model, wb_matrix=pt_cloudy_wb)

In [None]:
display_image(image=output_img)

In [None]:
# `indoor_wb` was never defined. Read the scaled camera white balance of the
# indoor training frame (RAW_PATH, the raw for exr_img_arr) and show the
# corrected card rather than dumping the raw output array.
indoor_wb = get_raw_properties(RAW_PATH)[0]
display_image(apply_correction_model(exr_img_arr, color_correction_model, indoor_wb))

In [None]:
# Reference: the camera-processed sRGB rendering of the test card.
display_image(image=test_srgb_img_arr)

In [None]:

# DarkTable has trouble opening this EXR. GIMP gives usable results when its
# eyedropper is used to adjust color curves, but that manual step is exactly
# what we are trying to avoid.

CORRECTED_OUTPUT_PATH = "./corrected_output_img.exr"
pyexr.write(CORRECTED_OUTPUT_PATH, np.asarray(output_img), precision=pyexr.HALF)

## Conclusions

The approach above still does not work well. I suspect the model is not making
use of the white balancing layers. That would explain the muddy results: the
output comes out too warm or too cool depending on the lighting scenario the
images were taken under.

### Alternative approaches

* Create a color correction model which we co-train to produce a white balance
  transform, using as input either:
    - the white / black point
    - the white balance type
    - color correction weights
* Capture a larger set of color correction sample images under color-temperature-controlled panel lights.