# Joint denoising and demosaicing

Install TensorFlow:

Create a Python virtual environment and activate it (run this from the directory that contains this notebook):
```
python3 -m venv tensorflow
source tensorflow/bin/activate
```
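then install the dependencies, register the environment as a Jupyter kernel and open the notebook (the quotes keep shells such as zsh from expanding the brackets):
```
python3 -m pip install 'tensorflow[and-cuda]'
pip install jupyter matplotlib
ipython kernel install --user --name=venv
jupyter notebook this-notebook.ipynb
```
and then select the `venv` kernel in Jupyter.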

In [None]:
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import mixed_precision
from tensorflow.python.client import device_lib
import numpy as np
from matplotlib import pyplot as plt
from PIL import Image
import os
import struct

# Seeded generator shared by all tile sampling / augmentation below, so runs
# draw the same tiles, flips and noise.
rng = np.random.default_rng(666)

# List the local devices to verify that TensorFlow sees the GPU
# (the list is the cell's displayed result).
local_device_protos = device_lib.list_local_devices()
[device.name for device in local_device_protos]

In [None]:
# Side length, in output pixels, of the square training tiles. The network
# input for a tile is TILE_SIZE//2 x TILE_SIZE//2 (one sample per 2x2 Bayer
# cell), and that has to survive five 2x2 poolings, so keep this a multiple of 64.
TILE_SIZE = 128
# Random tiles cut out of every reference image.
TILES_PER_IMAGE = 1000
# Number of reference images (data/img_0000.pfm, data/img_0001.pfm, ...) to load.
NUM_TRAINING_IMG = 7

# probably not much more to gain for the training data we currently have
EPOCHS_COUNT = 100

# Network input channels: rggb as four planes, plus a 5th channel carrying
# the noise estimation.
INPUT_CHANNELS_COUNT = 5

# Base name of the exported weight file (<MODEL_NAME>.dat); currently just
# the epoch count.
MODEL_NAME = f'{EPOCHS_COUNT}'

In [None]:
def read_reference_pfm(filename):
    """Read a PFM (portable float map) reference image into a float16 RGB tensor.

    Colour ('PF', 3 channels) and greyscale ('Pf', 1 channel) files are
    accepted. Greyscale data is replicated into three identical channels, so
    the result always has 3 channels. The image is cropped to an even width
    and height so it can later be cut into whole 2x2 Bayer cells.

    NOTE(review): the PFM specification stores scanlines bottom-to-top. This
    reader keeps the file's row order and does not flip vertically. Confirm
    that this matches how the reference images are written.

    Args:
        filename: path of the .pfm file.

    Returns:
        float16 ``tf.Tensor`` of shape ``[ht, wd, 3]``, where ``ht``/``wd`` are
        the file's height/width rounded down to even numbers.

    Raises:
        ValueError: if the magic line is not ``PF``/``Pf`` or the pixel
            payload size does not match the size declared in the header.
    """
    with open(filename, 'rb') as pfm_file:
        magic, dims, scale_line = (pfm_file.readline().decode('latin-1').strip() for _ in range(3))
        payload = pfm_file.read()

    if magic not in ('PF', 'Pf'):
        raise ValueError(f'{filename}: not a PFM file (header {magic!r})')
    channels = 3 if magic == 'PF' else 1
    width, height = (int(s) for s in dims.split())
    # The sign of the scale field encodes the byte order (positive: big
    # endian). Its magnitude is a scale factor that is not applied here.
    big_endian = float(scale_line) > 0

    samples = width * height * channels
    if len(payload) != samples * 4:
        raise ValueError(
            f'{filename}: expected {samples * 4} bytes of pixel data, got {len(payload)}')

    # Decode all 32-bit floats at once instead of one Python float per sample.
    pixels = np.frombuffer(payload, dtype='>f4' if big_endian else '<f4')
    pixels = pixels.reshape(height, width, channels)
    if channels == 1:
        pixels = np.repeat(pixels, 3, axis=2)

    # make sure extent is multiple of 2 (whole 2x2 Bayer cells)
    ht = (height // 2) * 2
    wd = (width // 2) * 2
    return tf.constant(pixels[:ht, :wd, :].astype(np.float16))

def display_img(img):
    """Plot the first three channels of an image tensor as RGB, without axes.

    Args:
        img: tensor with a ``.numpy()`` method, indexed as [y, x, channel].
    """
    pixels = img.numpy().astype(np.float32)
    plt.imshow(pixels[:, :, :3])
    plt.axis('off')
    plt.show()

In [None]:
def generate_input_tiles(img, n, ox, oy, flip, noise_a, noise_b):
    """Cut ``n`` noisy, mosaiced network-input tiles out of a reference image.

    For every tile ``i``:
    - the TILE_SIZE x TILE_SIZE crop at ``(oy[i], ox[i])`` is flipped according
      to ``flip[i]`` (0: none, 1: axis 0, 2: axis 1, 3: both axes);
    - the crop is sub-sampled into four half-resolution mosaic planes;
    - simulated gaussian/poissonian noise with variance
      ``noise_a[i] + noise_b[i] * signal`` is added to every plane;
    - a fifth channel holds the noise standard deviation predicted from the
      clean ``green0`` plane.

    Random numbers come from the module-level ``rng``.

    Args:
        img: reference RGB image, indexable as [y, x, channel].
        n: number of tiles to generate.
        ox, oy: per-tile top-left x / y offsets (length >= n).
        flip: per-tile flip code in {0, 1, 2, 3}.
        noise_a, noise_b: per-tile constant and signal-proportional noise
            variance terms.

    Returns:
        List of ``n`` float16 ``tf.Tensor`` objects of shape
        ``[TILE_SIZE//2, TILE_SIZE//2, 5]`` with channels
        (red, green0, green1, blue, noise).
    """
    half = TILE_SIZE // 2
    # flip code -> axes argument for np.flip (code 0 means "no flip")
    flip_axes = {1: 0, 2: 1, 3: (0, 1)}

    def _add_noise(plane, var_const, var_slope):
        # Variance var_const + var_slope * signal, clamped at 0 so negative
        # samples do not produce NaN.
        sigma = np.sqrt(np.maximum(var_const + plane * var_slope, 0.0))
        return plane + sigma * rng.normal(0, 1, (half, half))

    res = []
    for i in range(n):
        # do some augmentation shenannigans: flip, add noise, mosaic, add noise estimation as channel
        tile = np.asarray(img[oy[i]:oy[i] + TILE_SIZE, ox[i]:ox[i] + TILE_SIZE, :])
        if flip[i] in flip_axes:
            tile = np.flip(tile, flip_axes[flip[i]])

        # cut into mosaic planes: red at (even row, even col), blue at
        # (odd, odd), green0 at (odd, even), green1 at (even, odd)
        red = tile[0::2, 0::2, 0]
        green0 = tile[1::2, 0::2, 1]
        green1 = tile[0::2, 1::2, 1]
        blue = tile[1::2, 1::2, 2]

        # noise channel: expected noise std. dev. at the clean green0 level,
        # clamped like the simulated noise so it cannot become NaN
        noise = np.sqrt(np.maximum(noise_a[i] + noise_b[i] * green0, 0.0))

        # draw order (red, green0, green1, blue) fixes the rng sequence
        red = _add_noise(red, noise_a[i], noise_b[i])
        green0 = _add_noise(green0, noise_a[i], noise_b[i])
        green1 = _add_noise(green1, noise_a[i], noise_b[i])
        blue = _add_noise(blue, noise_a[i], noise_b[i])

        stacked = np.stack((red, green0, green1, blue, noise), axis=2)
        res.append(tf.constant(stacked, dtype=np.float16))
    return res

def generate_output_tiles(img, n, ox, oy, flip):
    """Cut the clean target tiles matching ``generate_input_tiles``.

    Tile ``k`` is the TILE_SIZE x TILE_SIZE crop of ``img`` at
    ``(oy[k], ox[k])``. It is flipped along axis 0 when ``flip[k] == 1``,
    along axis 1 when it is 2, along both axes for any other non-zero code,
    and left untouched for 0.

    Returns:
        List of ``n`` full-resolution RGB crops.
    """
    axes_for_flip = {1: 0, 2: 1}
    tiles = []
    for k in range(n):
        crop = img[oy[k]:oy[k] + TILE_SIZE, ox[k]:ox[k] + TILE_SIZE, :]
        if flip[k] != 0:
            crop = np.flip(crop, axes_for_flip.get(flip[k], (0, 1)))
        tiles.append(crop)
    return tiles

def display_tile_grid(tiles, lines_count=4, columns_count=2, size=2):
    """Show the leading tiles of several tile sets next to each other.

    The figure has ``lines_count`` rows and ``columns_count`` column groups.
    Each group holds one panel per entry of ``tiles``, so tile number
    ``i * columns_count + j`` of every set is shown side by side (e.g.
    input | prediction | reference). At most the first three channels of
    each tile are drawn as RGB.

    Args:
        tiles: sequence of tile sets. Each set supports ``len()`` and indexing
            by tile number, yielding an array-like (tf tensor or numpy array)
            of shape [h, w, c].
        lines_count: number of grid rows.
        columns_count: number of tiles per row taken from each set.
        size: edge length of one panel, in inches.
    """
    sets_count = len(tiles)
    fig, axes = plt.subplots(
        lines_count, columns_count * sets_count,
        figsize=(size * columns_count * sets_count, size * lines_count),
        # always get a 2-D axes array, even for a single row or panel
        squeeze=False,
    )

    for i in range(lines_count):
        for j in range(columns_count):
            tile_index = i * columns_count + j
            for k, tile_set in enumerate(tiles):
                ax = axes[i, j * sets_count + k]
                # leave the panel empty when a set has fewer tiles than the grid
                if tile_index < len(tile_set):
                    tile = np.asarray(tile_set[tile_index]).astype(np.float32)
                    ax.imshow(tile[:, :, :3], interpolation='nearest')
                ax.axis('off')

    fig.tight_layout()
    plt.show()

In [None]:
# Load the reference (ground-truth) images data/img_0000.pfm, data/img_0001.pfm, ...
images = [
    read_reference_pfm(f'data/img_{index:04d}.pfm')
    for index in range(NUM_TRAINING_IMG)
]

In [None]:
# quick visual check of the first reference image
display_img(img=images[0])


In [None]:
# Build the training set: TILES_PER_IMAGE random tiles from every reference
# image, each with a random flip and random noise parameters. Inputs are the
# noisy mosaiced planes, targets are the matching clean RGB crops.
tiles_input = []
tiles_expected = []

for i in range(NUM_TRAINING_IMG):
    img_expected = images[i]
    n = TILES_PER_IMAGE
    img_height = int(np.shape(img_expected)[0])
    img_width = int(np.shape(img_expected)[1])
    # rng.integers excludes `high`. The +1 makes the last valid top-left
    # position (extent - TILE_SIZE) reachable and allows images that are
    # exactly TILE_SIZE wide/high.
    ox = rng.integers(0, img_width - TILE_SIZE + 1, n)
    oy = rng.integers(0, img_height - TILE_SIZE + 1, n)
    # flip code per tile: 0 none, 1 axis 0, 2 axis 1, 3 both axes
    flip = rng.integers(0, 4, n)
    # Noise variance terms, scaled by 1/65535 (presumably to match 16-bit
    # raw noise levels). Earlier uniform alternative:
    # noise_a = rng.uniform(0, 1000.0/65535.0, n)
    # noise_b = rng.uniform(0, 20.0/65535.0, n)
    noise_a = rng.exponential(100.0/65535.0, n)
    noise_b = rng.exponential(2.0/65535.0, n)

    tiles_input += generate_input_tiles(img_expected, n, ox, oy, flip, noise_a, noise_b)
    tiles_expected += generate_output_tiles(img_expected, n, ox, oy, flip)

tiles_input, tiles_expected = tf.convert_to_tensor(tiles_input), tf.convert_to_tensor(tiles_expected)

display_tile_grid([tiles_input, tiles_expected], size=3)

In [None]:
# Validation set: the last 10 generated tiles.
# TODO make sure this does *not* overlap with the training input..
# NOTE(review): these tiles are currently also part of tiles_input, which
# model.fit trains on, so the validation loss is not an independent measure.
validation_tiles_input = tf.convert_to_tensor(tiles_input[-10:])
validation_tiles_expected = tf.convert_to_tensor(tiles_expected[-10:])


## CNN

In [None]:
# From here on, Keras layers compute in float16 and keep their variables in
# float32 ('mixed_float16' policy). This must run before the model is built.
mixed_precision.set_global_policy(policy='mixed_float16')

In [None]:
# U-Net style encoder/decoder for joint denoising + demosaicing.
# Input: half-resolution planes with INPUT_CHANNELS_COUNT channels (as
# produced by generate_input_tiles). Output: 3 RGB channels at twice the
# input resolution (final UpSampling2D).
# Height/width are left as None so the model accepts any image size, but both
# must be divisible by 32 (five 2x2 max-pools). Otherwise the upsampled
# tensors and the skip tensors differ in size and Concatenate fails.
# NOTE(review): model.get_weights() is later dumped raw to MODEL_NAME.dat, so
# that file's layout follows this graph's layer order. Changing the
# structure here changes the file format for whatever reads it.
input_shape = (None, None, INPUT_CHANNELS_COUNT)

inputs = tf.keras.Input(shape=input_shape)

# Encoder. The x_<N> suffixes appear to name the resolution in terms of the
# TILE_SIZE=128 output tile; the input planes are actually TILE_SIZE//2 wide.
x_128 = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(inputs)

x_64 = layers.MaxPooling2D((2, 2))(x_128)
x_64 = layers.Conv2D(43, (3, 3), activation='relu', padding='same')(x_64)

x_32 = layers.MaxPooling2D((2, 2))(x_64)
x_32 = layers.Conv2D(57, (3, 3), activation='relu', padding='same')(x_32)

x_16 = layers.MaxPooling2D((2, 2))(x_32)
x_16 = layers.Conv2D(76, (3, 3), activation='relu', padding='same')(x_16)

x_8 = layers.MaxPooling2D((2, 2))(x_16)
x_8 = layers.Conv2D(101, (3, 3), activation='relu', padding='same')(x_8)

# Bottleneck (1/32 of the input resolution).
x_4 = layers.MaxPooling2D((2, 2))(x_8)
x_4 = layers.Conv2D(101, (3, 3), activation='relu', padding='same')(x_4)

# Decoder: upsample, concatenate the encoder features of the same
# resolution, convolve.
x = layers.UpSampling2D(size=(2, 2),interpolation='nearest')(x_4)
x = layers.Concatenate()([x, x_8]) # Skip connection
# x = layers.Conv2D(101, (3, 3), activation='relu', padding='same')(x)
x = layers.Conv2D(101, (3, 3), activation='relu', padding='same')(x)

x = layers.UpSampling2D(size=(2, 2),interpolation='nearest')(x)
x = layers.Concatenate()([x, x_16]) # Skip connection
# x = layers.Conv2D(76, (3, 3), activation='relu', padding='same')(x)
x = layers.Conv2D(76, (3, 3), activation='relu', padding='same')(x)

x = layers.UpSampling2D(size=(2, 2),interpolation='nearest')(x)
x = layers.Concatenate()([x, x_32]) # Skip connection
# x = layers.Conv2D(57, (3, 3), activation='relu', padding='same')(x)
x = layers.Conv2D(57, (3, 3), activation='relu', padding='same')(x)

x = layers.UpSampling2D(size=(2, 2),interpolation='nearest')(x)
x = layers.Concatenate()([x, x_64]) # Skip connection
# x = layers.Conv2D(43, (3, 3), activation='relu', padding='same')(x)
x = layers.Conv2D(43, (3, 3), activation='relu', padding='same')(x)

x = layers.UpSampling2D(size=(2, 2),interpolation='nearest')(x)
x = layers.Concatenate()([x, x_128]) # Skip connection
# x = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(x)
# the number of output channels here makes the next layer insanely expensive. can we do with less?
x = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(x)

# Concatenate the raw input planes (incl. noise estimate), then upsample
# to full resolution for the final RGB reconstruction.
x = layers.Concatenate()([x, inputs]) # Skip connection
x = layers.UpSampling2D(size=(2, 2),interpolation='nearest')(x)
# x = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(x)
# dtype='float32' overrides the global mixed_float16 policy for this layer
# only. The Keras mixed-precision guide recommends a float32 output layer so
# the model outputs (and the loss computed on them) are numerically stable.
x = layers.Conv2D(3, (3, 3), activation='relu', padding='same', dtype='float32')(x)

outputs = x

# Create the model, trained with mean absolute error (L1) against the
# clean reference tiles.
model = tf.keras.Model(inputs=inputs, outputs=outputs)
model.compile(optimizer='adam', loss='mean_absolute_error')

model.summary()

In [None]:
# Train the model
# Train the model. The returned History object records the per-epoch losses.
hist = model.fit(
    tiles_input,
    tiles_expected,
    epochs=EPOCHS_COUNT,
    validation_data=(validation_tiles_input, validation_tiles_expected),
)

# per-epoch training / validation loss, used for the loss plot
training_loss, validation_loss = hist.history['loss'], hist.history['val_loss']

In [None]:
# write raw f16 coefficients of the model into a file.
# probably in the future also write some information about training data/loss/network configuration?
# Write the raw f16 coefficients of the model into <MODEL_NAME>.dat: all
# weight arrays back to back, in model.get_weights() order, without a header.
# probably in the future also write some information about training data/loss/network configuration?
with open(MODEL_NAME + '.dat', 'wb') as f:
    for weights in model.get_weights():
        # '<f2' pins little-endian half floats so the file is host-independent
        f.write(np.asarray(weights, dtype='<f2').tobytes())


In [None]:
# Plot training vs. validation loss per epoch. The first 10 epochs are skipped
# (presumably so the large early losses do not dominate the y range), but
# only when there are more than 10 epochs, so short runs still show a curve.
skip_epochs = 10 if len(training_loss) > 10 else 0
xs = range(skip_epochs, len(training_loss))
plt.plot(xs, training_loss[skip_epochs:], label='Training loss')
plt.plot(xs, validation_loss[skip_epochs:], label='Validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
# plt.yscale('log')
plt.legend()
plt.show()

In [None]:
# Make predictions on the training data
# Run the trained model over the training tiles and show
# input | prediction | reference for the first few of them.
predictions = model.predict(tiles_input)
predictions = tf.convert_to_tensor(predictions)

display_tile_grid(tiles=[tiles_input, predictions, tiles_expected], size=3)

## Evaluation

In [None]:
# Make predictions on the evaluation data
# Run the model on the validation tiles and show
# input | prediction | reference side by side.
test_predictions = model.predict(validation_tiles_input)
test_predictions = tf.convert_to_tensor(test_predictions)

display_tile_grid(
    tiles=[validation_tiles_input, test_predictions, validation_tiles_expected],
    lines_count=4,
    size=3,
)