In [None]:
import time
import random

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from keras.models import load_model

In [None]:
from imagegenerator import HDR2SDRImageGenerator
from multicsvreader import MultiCSVReader
from lutmaker import predict, write_lut_fast

In [None]:
# Ask TensorFlow which GPU devices it can see; the returned list is this cell's output.
tf.config.list_physical_devices(device_type='GPU')

## Create the data generator

In [None]:
# image_map = []
# imapf = open("../pyfilemap.txt", "r")

# # s07e01 = 1091612160
# # s07e02 = 1075912320
# # s07e03 = 828664320

# hdr = None
# for line in imapf:
#     if hdr is None:
#         hdr = line.strip()
#     else:
#         image_map.append((hdr, line.strip()))
#         hdr = None

# imapf.close()

# random.shuffle(image_map)

In [None]:
# image_gen = HDR2SDRImageGenerator(
#     image_map,
#     image_size=(1920, 1080),
#     batch_size=2048,
#     crop=(0, 0, 132, 132),
#     buffer_size=4
# )

In [None]:
# Columns fed to the model as inputs (x) and as targets (y).
input_cols = ['hr', 'hg', 'hb']
target_cols = ['sr', 'sg', 'sb']

# Per-column dtypes handed to the reader: 16-bit for the input columns,
# 8-bit for the target columns.
csv_types = {
    **{col: np.uint16 for col in input_cols},
    **{col: np.uint8 for col in target_cols},
}

# Each data file is listed next to its size so the two values cannot drift apart
# when an entry is added or disabled.
csv_sources = [
    ('s07e01_1', 1091612160),
    ('s07e02_1', 1075912320),
    ('s07e03_1', 828664320),
    ('s07e04_1', 1058580480),
    # 's07e05_1' is currently disabled (no size is listed for it).
]

csv_files = [f"../data/{s}.xz" for s, _ in csv_sources]
csv_file_sizes = [size for _, size in csv_sources]

# NOTE(review): the generator is named `image_gen` even though it reads CSV files;
# the name is kept because the training cell refers to it.
image_gen = MultiCSVReader(
    csv_list=csv_files,
    csv_sizes=csv_file_sizes,
    batch_size=2048,
    csv_dtypes=csv_types,
    x_cols=input_cols,
    y_cols=target_cols,
)

## Train the model

In [None]:
def timestamp():
    """Return the current Unix time, truncated to whole seconds, as a string."""
    seconds_since_epoch = int(time.time())
    return f"{seconds_since_epoch}"


def build_model(hidden_units=32, learning_rate=0.001, loss='mean_absolute_error'):
    """Build and compile the 3-input / 3-output dense regression network.

    The network is two ReLU hidden layers of equal width followed by a linear
    3-unit output layer, optimised with Nadam.

    Args:
        hidden_units: Width of each of the two hidden layers.
        learning_rate: Learning rate passed to the Nadam optimizer.
        loss: Loss passed to ``model.compile``, e.g. ``'mean_absolute_error'``
            or ``'mean_squared_error'``.

    Returns:
        The compiled ``keras.Sequential`` model.

    Calling ``build_model()`` with no arguments gives the same model as before
    these parameters were introduced.
    """
    model = keras.Sequential([
        layers.Dense(hidden_units, activation=tf.nn.relu, input_shape=[3]),
        layers.Dense(hidden_units, activation=tf.nn.relu),
        # No activation: the output layer is linear.
        layers.Dense(3)
    ])

    optimizer = tf.keras.optimizers.Nadam(learning_rate=learning_rate)

    # Both error measures are always tracked as metrics, whichever one is the loss.
    model.compile(loss=loss,
                  optimizer=optimizer,
                  metrics=['mean_absolute_error', 'mean_squared_error'])
    return model

In [None]:
model = build_model()
# summary() prints the layer table itself and returns None, so it is called
# directly; wrapping it in display() would add a stray "None" to the output.
model.summary()

In [None]:
# Checkpoint path: a per-run timestamp followed by a "{loss:.2f}" placeholder that is
# deliberately left unformatted here and handed to ModelCheckpoint as part of the path.
run_id = timestamp()
cp_filepath = "".join(["../checkpoints/weights-", run_id, "-{loss:.2f}.hdf5"])
checkpoint = keras.callbacks.ModelCheckpoint(filepath=cp_filepath, monitor='loss')

# Single pass over the generator with the checkpoint callback attached.
model.fit(x=image_gen, epochs=1, callbacks=[checkpoint], shuffle=False)

## Test and output LUT

In [None]:
# Spot-check the trained model on a single known colour.
# NOTE(review): the three numbers are presumably the input r, g, b values — confirm against lutmaker.predict.
predict(model, 65535, 61937, 771) # The yellow of the CW logo - should be (133, 132, 81)

In [None]:
# Export a LUT built from the trained model to ../generated_lut.cube.
# NOTE(review): 65 is presumably the number of LUT grid points per axis — confirm against lutmaker.write_lut_fast.
write_lut_fast("../generated_lut.cube", model, 65)