# Joint denoising and demosaicing

## install pytorch

create a python virtual environment:
```
python3 -m venv torch
cd torch
source bin/activate
```
then install
```
python3 -m pip install torch-cuda-installer
torch-cuda-installer --torch
pip install jupyter matplotlib
ipython kernel install --user --name=venv
jupyter notebook this-notebook.ipynb
```
and then select the venv kernel

## prepare training data

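make a subdirectory `data/` and copy a bunch of `img_XXXX.pfm` training images into it, numbered consecutively with four digits starting at `img_0000.pfm` (the loading cell reads exactly these names). these should be noise free and free of demosaicing artifacts. i usually use highres raws and export at 1080p.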

export as linear rec2020 pfm, set colour matrix in the `colour` module to `rec2020` for all images. this makes sure the raw camera rgb values will be passed on to the output.

set `NUM_TRAINING_IMG` in the configuration cell below to the number of training images.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from matplotlib import pyplot as plt
from PIL import Image
import os
import struct

# seeded numpy generator: tile offsets, flips and simulated noise are all drawn from it,
# so the generated training set is reproducible between runs
rng = np.random.default_rng(666)
# use the gpu when one is available, fall back to the cpu otherwise
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# tensors and module parameters created without an explicit device end up on `device` from here on
torch.set_default_device(device)

# echo the selected device as the cell output
device

In [None]:
# edge length of a training tile in full resolution pixels.
# the network works on the mosaic planes, i.e. on TILE_SIZE//2 x TILE_SIZE//2 samples per channel.
TILE_SIZE = 128
# number of random tiles cut out of every training image
TILES_PER_IMAGE = 1000
# number of data/img_XXXX.pfm files that will be loaded
NUM_TRAINING_IMG = 12

# probably not much more to gain for the training data we currently have
EPOCHS_COUNT = 50

# this is going to be rggb in planes. 5th channel is noise estimation
INPUT_CHANNELS_COUNT = 5

# used as the file name stem when the trained weights are written out
MODEL_NAME = str(EPOCHS_COUNT)

# actually maybe this is a bad idea and might make the weights depend on the matrix:
# colour matrix camera to rec2020 for my telephone:
c2rec2020 = np.array(
[[2.792677, -0.134533, 0.263296],
 [-0.110118, 0.991432, 0.071795],
 [0.117527, -0.650657, 2.678170]], dtype=np.float16)
# rgb to yuv matrix: first row is luma, rows two and three are the chroma channels
# NOTE(review): the coefficients look like the classic 0.299/0.587/0.114 yuv definition -- confirm
rgb2yuv = np.array(
[[0.299, 0.587, 0.114],
 [-0.14713, -0.28886, 0.436],
 [0.615, -0.51499, -0.10001]], dtype=np.float16)
# combined matrix camera rgb -> rec2020 -> yuv, used by the colour loss.
# note that both factors and therefore the product are stored in float16.
M = rgb2yuv @ c2rec2020

In [None]:
def read_reference_pfm(filename):
    """Read a PFM image and return it as a float16 array with even extent.

    The header consists of three text lines: the magic (`PF` for three
    channels, `Pf` for a single channel), `width height`, and a scale factor
    whose sign encodes the byte order of the samples (positive: big endian,
    otherwise little endian). The magnitude of the scale factor is not
    applied to the data.

    Args:
        filename: path of the `.pfm` file to read.

    Returns:
        numpy array of dtype float16 and shape (height, width, channels),
        where height and width are rounded down to a multiple of 2 and
        channels is 3 for `PF` and 1 for `Pf`. Rows are returned in the order
        they are stored in the file.

    Raises:
        ValueError: if the magic is unknown or the size of the payload does
            not match the extent given in the header.
    """
    with open(filename, 'rb') as pfm_file:
        magic, extent, scale_endianness = (
            pfm_file.readline().decode('latin-1').strip() for _ in range(3))
        if magic not in ('PF', 'Pf'):
            raise ValueError(f'{filename}: not a pfm file (magic {magic!r})')

        channels = 3 if magic == 'PF' else 1
        width, height = (int(s) for s in extent.split())
        byte_order = '>' if float(scale_endianness) > 0 else '<'
        buffer = pfm_file.read()

    samples = width * height * channels
    if len(buffer) != samples * 4:
        raise ValueError(
            f'{filename}: expected {samples * 4} bytes of pixel data, found {len(buffer)}')

    # decode straight from the byte buffer, no detour through python floats
    decoded = np.frombuffer(buffer, dtype=byte_order + 'f4').reshape(height, width, channels)
    # make sure extent is multiple of 2
    wd = (width//2)*2
    ht = (height//2)*2
    # astype copies, so the result is a compact, writable array in native byte order
    return decoded[:ht, :wd, :].astype(np.float16)

def display_img(img):
    """Show the first three channels of `img` with matplotlib, without axes."""
    # imshow can't deal with float16, so convert before plotting
    rgb = img.astype(np.float32)[:, :, :3]
    plt.imshow(rgb)
    plt.axis('off')
    plt.show()

# index maps between the two flat layouts of one tile:
#   image order:   (row, column, colour),                 TILE_SIZE x TILE_SIZE x 3
#   network order: (block row, block column, 12 values),  one entry per 2x2 block, where the
#                  12 values are the 4 pixels of the block (row major) times 3 colours.
# fwd map from network output to image: image_flat = network_flat[N2I]
N2I = np.fromfunction(
    lambda j, i, c: 12*((TILE_SIZE//2)*(j//2) + (i//2)) + c + 3*((j%2)*2 + (i%2)),
    (TILE_SIZE, TILE_SIZE, 3), dtype=int).reshape(TILE_SIZE*TILE_SIZE*3)
# inverse map from image to network order: network_flat = image_flat[I2N]
I2N = np.fromfunction(
    lambda j, i, c: 3*(TILE_SIZE*(2*j + ((c//3)//2)) + (2*i + ((c//3)%2))) + (c%3),
    (TILE_SIZE//2, TILE_SIZE//2, 12), dtype=int).reshape((TILE_SIZE//2)*(TILE_SIZE//2)*12)


In [None]:
def _noise_sigma(plane, a, b):
    """Standard deviation of the gaussian/poissonian noise model: sqrt(max(a + b*plane, 0))."""
    return np.sqrt(np.maximum(a + plane*b, 0.0))


def generate_input_tiles(img, n, ox, oy, flip, noise_a, noise_b):
    """Cut `n` tiles out of `img` and turn them into noisy mosaic network inputs.

    For every tile: crop TILE_SIZE x TILE_SIZE pixels, optionally flip,
    subsample the colour channels into four half resolution mosaic planes,
    add simulated noise to each plane and append the noise estimate as an
    extra plane.

    Args:
        img: reference image, indexed as [row, column, colour].
        n: number of tiles to generate.
        ox, oy: per tile column and row offsets of the top left corner.
        flip: per tile augmentation, 1 flips rows, 2 flips columns,
            3 flips both, anything else leaves the tile alone.
        noise_a, noise_b: per tile parameters of the noise model with
            variance `noise_a + noise_b * signal`.

    Returns:
        list of `n` float16 tensors of shape
        (1, INPUT_CHANNELS_COUNT, TILE_SIZE//2, TILE_SIZE//2) with the planes
        red, green0, green1, blue, noise estimate.
    """
    half = TILE_SIZE//2
    res = []
    for i in range(n):
        # do some augmentation shenanigans: flip, add noise, mosaic, add noise estimation as channel
        b = img[oy[i]:oy[i]+TILE_SIZE,ox[i]:ox[i]+TILE_SIZE,:]
        if flip[i] == 1:
            b = np.flip(b, 0)
        elif flip[i] == 2:
            b = np.flip(b, 1)
        elif flip[i] == 3:
            b = np.flip(b, (0,1))
        # cut into mosaic planes
        # FIXME: argh i swapped the two greens here. remember to do the same in glsl later or change it in both places!
        red    = b[0::2,0::2,0]
        green0 = b[1::2,0::2,1]
        green1 = b[0::2,1::2,1]
        blue   = b[1::2,1::2,2]
        # compute noise channel from the clean green plane. the variance is clamped like it is
        # for the simulated noise below, negative reference values would give nan otherwise.
        noise  = _noise_sigma(green0, noise_a[i], noise_b[i])
        # simulate additive gaussian/poissonian noise. the draw order red, green0, green1, blue
        # determines which random numbers each plane gets.
        red    = red    + _noise_sigma(red,    noise_a[i], noise_b[i])*rng.normal(0, 1, (half, half))
        green0 = green0 + _noise_sigma(green0, noise_a[i], noise_b[i])*rng.normal(0, 1, (half, half))
        green1 = green1 + _noise_sigma(green1, noise_a[i], noise_b[i])*rng.normal(0, 1, (half, half))
        blue   = blue   + _noise_sigma(blue,   noise_a[i], noise_b[i])*rng.normal(0, 1, (half, half))
        deg = np.stack((red,green0,green1,blue,noise),axis=2).astype(np.float16)
        # the reshape doubles as a check that the tile had the full extent, then go from
        # (row, column, channel) to the (batch, channel, row, column) order the network expects
        res.append(torch.permute(torch.reshape(torch.from_numpy(deg),
                                 (1,half,half,INPUT_CHANNELS_COUNT)), (0,3,1,2)))
    return res

def generate_output_tiles(img, n, ox, oy, flip):
    """Cut `n` reference tiles out of `img` and reorder them to network output layout.

    Args:
        img: reference image, indexed as [row, column, colour].
        n: number of tiles to generate.
        ox, oy: per tile column and row offsets of the top left corner.
        flip: per tile augmentation, 0 leaves the tile alone, 1 flips rows,
            2 flips columns, any other value flips both.

    Returns:
        list of `n` float16 tensors of shape (1, 12, TILE_SIZE//2, TILE_SIZE//2).
    """
    half = TILE_SIZE//2
    tiles = []
    for i in range(n):
        tile = img[oy[i]:oy[i]+TILE_SIZE, ox[i]:ox[i]+TILE_SIZE, :]
        if flip[i] != 0:
            if flip[i] == 1:
                axes = 0
            elif flip[i] == 2:
                axes = 1
            else:
                axes = (0, 1)
            tile = np.flip(tile, axes)
        # sort from image to network order
        blocks = np.reshape(tile, (TILE_SIZE * TILE_SIZE * 3))[I2N]
        blocks = np.reshape(blocks, (half, half, 12)).astype(np.float16)
        # (row, column, channel) -> (batch, channel, row, column)
        nhwc = torch.reshape(torch.from_numpy(blocks), (1, half, half, 12))
        tiles.append(torch.permute(nhwc, (0, 3, 1, 2)))
    return tiles

def display_tile_grid(tiles, lines_count=4, columns_count=2, size=2):
    """Plot several sets of tiles side by side for visual comparison.

    The first set in `tiles` is treated as network input: the first three of
    its planes are shown directly at half resolution. All other sets are
    treated as 12 channel network output (or reference): they are reordered
    to image layout with `N2I` and multiplied by the `c2rec2020` matrix
    before display.

    Args:
        tiles: list of tile sets. Each set is a sequence of tensors in
            (1, channels, row, column) layout and needs at least
            `lines_count * columns_count` entries.
        lines_count: number of grid rows.
        columns_count: number of tiles per set shown in every row.
        size: edge length of one grid cell in inches.
    """
    sets_count = len(tiles)
    # squeeze=False keeps `axes` two dimensional, also for a single row or column
    fig, axes = plt.subplots(lines_count, columns_count*sets_count,
                             figsize=(size*columns_count*sets_count, size*lines_count),
                             squeeze=False)

    for i in range(lines_count):
        for j in range(columns_count):
            for k, tile_set in enumerate(tiles):
                # detach + cpu so this works for gpu tensors and tensors with gradients alike
                nhwc = tile_set[i*columns_count + j].permute(0,2,3,1).detach().cpu().numpy().astype(np.float32)
                if k == 0:
                    # network input: show the first three mosaic planes as they are
                    picture = nhwc[0,:,:,:3]
                else:
                    # network order -> image order, then apply the colour matrix per pixel
                    rgb = np.reshape(nhwc, (TILE_SIZE*TILE_SIZE*3))[N2I]
                    rgb = np.reshape(rgb, (TILE_SIZE*TILE_SIZE,3)) @ c2rec2020.T
                    picture = np.reshape(rgb, (TILE_SIZE,TILE_SIZE,3))
                ax = axes[i, j*sets_count + k]
                ax.imshow(picture, interpolation='nearest')
                ax.axis('off')

    fig.tight_layout()
    plt.show()

In [None]:
# load the reference images data/img_0000.pfm, data/img_0001.pfm, ...
images = [read_reference_pfm(f'data/img_{i:04d}.pfm') for i in range(NUM_TRAINING_IMG)]

In [None]:
# sanity check: show the first reference image
display_img(images[0])

In [None]:
# network inputs (noisy mosaic planes) and the matching reference outputs, in the same order
tiles_input = []
tiles_expected = []

for i in range(NUM_TRAINING_IMG):
    img_expected = images[i]
    height, width = img_expected.shape[:2]
    n = TILES_PER_IMAGE

    # random top left corners such that the whole tile is inside the image
    ox = rng.integers(0, width - TILE_SIZE, n)
    oy = rng.integers(0, height - TILE_SIZE, n)
    # augmentation: 0 none, 1 flip rows, 2 flip columns, 3 flip both
    flip = rng.integers(0, 4, n)
    # per tile parameters of the noise model, variance = noise_a + noise_b * signal.
    # exponentially distributed; an earlier variant drew them uniformly from
    # [0, 1000.0/65535.0) and [0, 20.0/65535.0) instead.
    noise_a = rng.exponential(100.0/65535.0, n)
    noise_b = rng.exponential(2.0/65535.0, n)

    # both generators get the same offsets and flips so inputs and references line up
    tiles_input.extend(generate_input_tiles(img_expected, n, ox, oy, flip, noise_a, noise_b))
    tiles_expected.extend(generate_output_tiles(img_expected, n, ox, oy, flip))

display_tile_grid([tiles_input, tiles_expected], size=3)

In [None]:
# Validation
# TODO make sure this does *not* overlap with the training input..
# NOTE: the slices below take the last 10 tiles but leave them in tiles_input / tiles_expected,
# and the training loop iterates over these lists in full. so this is not an independent set yet.
validation_tiles_input = tiles_input[-10:]
validation_tiles_expected = tiles_expected[-10:]

## CNN

In [None]:
class Net(nn.Module):
    """U-Net style encoder/decoder made of 3x3 convolutions.

    The encoder applies six convolutions with a 2x2 max pooling in front of
    all but the first. The decoder upsamples by a factor of two (nearest
    neighbour), concatenates the encoder output of the same resolution and
    convolves, five times in a row. Every convolution is followed by a relu,
    including the last one.

    Input is (batch, INPUT_CHANNELS_COUNT, rows, columns), output is
    (batch, 12, rows, columns). Rows and columns are halved five times, so
    they should be multiples of 32 for the skip connections to line up.
    """

    def __init__(self):
        super().__init__()

        def conv3x3(channels_in, channels_out):
            # padding_mode 'reflect' would be preferred, but: not implemented you suckers
            return nn.Conv2d(channels_in, channels_out, 3, padding='same', padding_mode='zeros')

        # encoder, the channel count grows towards the coarse levels
        self.enc0 = conv3x3(INPUT_CHANNELS_COUNT, 32)
        self.enc1 = conv3x3(32, 43)
        self.enc2 = conv3x3(43, 57)
        self.enc3 = conv3x3(57, 76)
        self.enc4 = conv3x3(76, 101)
        self.enc5 = conv3x3(101, 101)

        # decoder, input channels are upsampled features + skip connection
        self.dec0 = conv3x3(101 + 101, 101)
        self.dec1 = conv3x3(101 + 76, 76)
        self.dec2 = conv3x3(76 + 57, 57)
        self.dec3 = conv3x3(57 + 43, 43)
        # XXX and this last layer might want to get +INPUT_CHANNELS_COUNT
        self.dec4 = conv3x3(43 + 32, 12)

        self.pool = nn.MaxPool2d(2, 2)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, x):
        """Run the network on `x` and return the 12 channel prediction."""
        # encoder: skip0 is at input resolution, every further level is half the previous one
        skip0 = F.relu(self.enc0(x))
        skip1 = F.relu(self.enc1(self.pool(skip0)))
        skip2 = F.relu(self.enc2(self.pool(skip1)))
        skip3 = F.relu(self.enc3(self.pool(skip2)))
        skip4 = F.relu(self.enc4(self.pool(skip3)))
        coarsest = F.relu(self.enc5(self.pool(skip4)))

        # decoder: upsample, concatenate the skip connection along the channel axis, convolve
        y = F.relu(self.dec0(torch.cat((self.upsample(coarsest), skip4), dim=1)))
        y = F.relu(self.dec1(torch.cat((self.upsample(y), skip3), dim=1)))
        y = F.relu(self.dec2(torch.cat((self.upsample(y), skip2), dim=1)))
        y = F.relu(self.dec3(torch.cat((self.upsample(y), skip1), dim=1)))
        # TODO and maybe the input image
        return F.relu(self.dec4(torch.cat((self.upsample(y), skip0), dim=1)))

# instantiate the network. no device is passed, so the parameters are created on the
# default device selected at the top of the notebook.
model = Net()
# compile the module in place; later calls to model(...) go through the compiled forward
model.compile()

In [None]:
# train the model. re-run this cell to train it some more.
import torch.optim as optim

# increase until GPU oom
batch_size = 300

# camera rgb -> yuv matrix for the colour loss.
# transpose because we'll multiply it from the left.
# moved to `device` instead of calling .cuda(), so the cpu fallback keeps working.
cmat = torch.from_numpy(M.T).to(device)

class ColourLoss(nn.Module):
    """L1 penalty on the chroma of the block averaged colour.

    Both tensors are expected in network layout (batch, 12, rows, columns),
    where the 12 channels of one position are the 4 pixels of a 2x2 block
    times 3 colours, colour being the fastest index. The four pixels of a
    block are summed up, the resulting colour is transformed by the colour
    matrix, and the absolute differences of the second and third component
    (the two chroma channels for a yuv matrix) are averaged.

    Args:
        colour_matrix: optional 3x3 tensor which is multiplied from the right
            onto row vectors of colour. If omitted, the notebook level `cmat`
            is looked up when the loss is evaluated.
    """

    def __init__(self, colour_matrix=None):
        super().__init__()
        self.colour_matrix = colour_matrix

    @staticmethod
    def _block_colour(planes):
        """Sum the 4 pixels of every 2x2 block, returns (batch*rows*columns, 3)."""
        # the channel axis has to become the fastest one before splitting it into
        # (pixel, colour). reshaping the (batch, channel, row, column) tensor directly would
        # group 12 neighbouring samples of a single channel instead.
        # -1 rather than a fixed batch size: the last batch of an epoch may be smaller.
        return torch.sum(torch.reshape(planes.permute(0, 2, 3, 1), (-1, 4, 3)), 1)

    def forward(self, output, target):
        """Return the chroma loss between `output` and `target` as a scalar tensor."""
        # average blocks of 4 colours by summing them:
        o3 = self._block_colour(output)
        t3 = self._block_colour(target)
        # apply linear transform to get yuv:
        matrix = cmat if self.colour_matrix is None else self.colour_matrix
        matrix = matrix.to(device=o3.device, dtype=o3.dtype)
        o3 = torch.matmul(o3, matrix)
        t3 = torch.matmul(t3, matrix)
        # probably need colour matrix here + Lab and do L1 L and L2 ab or some such
        # TODO: something red - green and blue - green and compare that with uh, MSE?
        # the factor 4 accounts for summing instead of averaging the four pixels above
        chroma_error = torch.abs(o3[:, 1:3] - t3[:, 1:3])
        return torch.sum(chroma_error).div(4*t3.size()[0])



# summed batch losses, one entry per epoch
training_loss = []
optimiser = optim.Adam(model.parameters())
# total loss is plain L1 on all 12 output channels plus the colour term
criterion0 = torch.nn.L1Loss()
criterion1 = ColourLoss()
# alternative that was tried: torch.nn.MSELoss()

for epoch in range(EPOCHS_COUNT):
    running_loss = 0.0
    # walk over the tiles in consecutive chunks of batch_size, the last one may be shorter
    for start in range(0, len(tiles_input), batch_size):
        # the tiles live on the cpu, move one batch at a time to the selected device
        inputs  = torch.cat(tiles_input[start:start+batch_size]).to(device)
        targets = torch.cat(tiles_expected[start:start+batch_size]).to(device)

        # zero the parameter gradients
        optimiser.zero_grad()

        # forward + backward + optimise
        # NOTE(review): float16 autocast is used without a gradient scaler -- check whether
        # small gradients underflow.
        with torch.autocast(device_type=device, dtype=torch.float16):
            outputs = model(inputs)
            loss = criterion0(outputs, targets) + criterion1(outputs, targets)
        loss.backward()
        optimiser.step()

        running_loss += loss.item()

    # this is the sum over all batches of the epoch, not the mean
    print(f'[{epoch + 1}] loss: {running_loss:.3f}')
    training_loss.append(running_loss)

print('finished training')

In [None]:
# dump all parameters of the model as raw float16 numbers into MODEL_NAME.dat, one tensor after
# the other in the order model.parameters() yields them. the shapes are only printed, not stored.
# probably in the future also write some information about training data/loss/network configuration? like a hash?
with open(f'{MODEL_NAME}.dat', 'wb') as weights_file:
    for tensor in model.parameters():
        coefficients = tensor.detach().cpu().numpy().astype(np.float16)
        print(coefficients.shape)
        weights_file.write(coefficients.tobytes())

In [None]:
# plot the loss per epoch, leaving out the first 10 epochs
# (presumably because their much larger values would squash the rest of the curve)
first_epoch = 10
xs = range(first_epoch, len(training_loss))
plt.plot(xs, training_loss[first_epoch:], label='training loss')
plt.xlabel('epochs')
plt.ylabel('loss')
# plt.yscale('log')
plt.legend()
plt.show()

In [None]:
# to free gpu memory before this cell:
# import gc
# del model
# gc.collect()
# torch.cuda.empty_cache()

# make predictions on the training data and show input / prediction / reference side by side.
# the tiles starting at 10000 show some crazy aliasing.
for offset in (4000, 10000):
    # same autocast context as in training: the tiles are stored as float16 while the weights
    # are float32. no_grad because nothing is going to be backpropagated here.
    with torch.no_grad(), torch.autocast(device_type=device, dtype=torch.float16):
        predictions = [model(tiles_input[offset+i].to(device)) for i in range(8)]
    display_tile_grid([tiles_input[offset:], predictions, tiles_expected[offset:]], size=3)

## Evaluation

In [None]:
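# TODO validation data
# the commented-out lines below look like a leftover from a tensorflow version of this notebook:
# `tf` is not imported here and the torch model has no `predict`, so they need porting before use.

# Make predictions on the evaluation data
# test_predictions = tf.convert_to_tensor(model.predict(validation_tiles_input))
# display_tile_grid([validation_tiles_input, test_predictions, validation_tiles_expected], lines_count=4, size=3)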