In [None]:
# import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

In [None]:
import gc

# gc.collect()

In [None]:
import os
import random
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from keras.models import load_model

In [None]:
# Sanity check: list the GPUs TensorFlow can see (an empty list means training runs on CPU).
tf.config.list_physical_devices('GPU')

In [None]:
# pandas dtypes used when reading the training CSVs: the model inputs
# (hr, hg, hb) are parsed as 16-bit unsigned ints and the targets
# (sr, sg, sb) as 8-bit unsigned ints.
csv_types = {
    **dict.fromkeys(('hr', 'hg', 'hb'), np.uint16),
    **dict.fromkeys(('sr', 'sg', 'sb'), np.uint8),
}

In [None]:
def build_model():
    """Create and compile the colour-mapping network.

    Architecture: 3 inputs -> Dense(32, ReLU) -> Dense(32, ReLU) -> Dense(3, linear).
    Compiled with Nadam (learning_rate=0.001), mean-absolute-error loss, and
    MAE / MSE tracked as metrics.

    Returns:
        A freshly initialised, compiled ``keras.Sequential`` model.
    """
    model = keras.Sequential()
    model.add(layers.Dense(32, activation=tf.nn.relu, input_shape=[3]))
    model.add(layers.Dense(32, activation=tf.nn.relu))
    model.add(layers.Dense(3))

    # Loss choice: 'mean_absolute_error' (used) or 'mean_squared_error'.
    model.compile(
        loss='mean_absolute_error',
        optimizer=tf.keras.optimizers.Nadam(learning_rate=0.001),
        metrics=['mean_absolute_error', 'mean_squared_error'],
    )
    return model

In [None]:
def timestamp():
    """Return the current Unix time in whole seconds (fraction truncated) as a string."""
    # %d converts the float via int(), i.e. truncation, same as str(int(...)).
    return "%d" % time.time()

def _load_training_dataset(file_path, batch_size=2048):
    """Read one CSV file and return it as a shuffled, batched tf.data.Dataset.

    Features are the (hr, hg, hb) columns and targets the (sr, sg, sb) columns,
    parsed with the dtypes in ``csv_types``. The pandas objects are local to
    this helper, so they become unreachable as soon as it returns.
    """
    csv_data = pd.read_csv(file_path, dtype=csv_types)
    x_values = csv_data[['hr', 'hg', 'hb']].values
    y_values = csv_data[['sr', 'sg', 'sb']].values
    # Shuffle buffer of ~1/48 of the rows. Clamp to 1: Dataset.shuffle rejects
    # buffer_size == 0, which the plain division yields for files < 48 rows.
    shuffle_buffer_size = max(1, len(csv_data) // 48)
    dataset = tf.data.Dataset.from_tensor_slices((x_values, y_values))
    return dataset.shuffle(shuffle_buffer_size).batch(batch_size)


def train_model_on(model, file_path):
    """Train ``model`` for one epoch on a single CSV file.

    Args:
        model: a compiled Keras model (e.g. from ``build_model``).
        file_path: path readable by ``pd.read_csv`` containing the columns
            hr, hg, hb (inputs) and sr, sg, sb (targets).
    """
    train_dataset = _load_training_dataset(file_path)

    # Checkpoint at the end of the epoch; the filename embeds a timestamp and
    # the epoch's loss. NOTE(review): save_weights_only is left at its default
    # (False), so despite the "weights-" prefix this saves the full model —
    # confirm that is intended.
    cp_filepath = "../checkpoints/weights-" + timestamp() + "-{loss:.2f}.hdf5"
    checkpoint = keras.callbacks.ModelCheckpoint(cp_filepath, monitor='loss')

    model.fit(train_dataset, epochs=1, callbacks=[checkpoint])

    # Drop the dataset (which holds this file's data) and collect eagerly so
    # memory is released before the next file is loaded.
    del train_dataset
    gc.collect()

In [None]:
# Training files (xz-compressed CSVs under ../data), visited in random order.
csv_files = [
    's07e01_1', 's07e01_2', 's07e01_3', 's07e01_4',
    's07e02_1', 's07e02_2', 's07e02_3', 's07e02_4',
#     's07e03_1', 's07e03_2', 's07e03_3', 's07e03_4',
#     's07e04_1', 's07e04_2', 's07e04_3', 's07e04_4',
]
csv_files = [f"../data/{s}.xz" for s in csv_files]
random.shuffle(csv_files)

TEMP_WEIGHTS_FILE = '../checkpoints/temp.hdf5'
# ModelCheckpoint and save_weights both write into ../checkpoints; make sure
# it exists so a fresh checkout doesn't fail after the first epoch.
os.makedirs(os.path.dirname(TEMP_WEIGHTS_FILE), exist_ok=True)

model = build_model()
# summary() prints the table itself and returns None; wrapping it in
# display() would additionally render a stray "None".
model.summary()

for csv_file in csv_files:
    print(f"Training on: {csv_file}")
    train_model_on(model, csv_file)
    # Round-trip the weights through disk so clear_session() can drop the
    # state accumulated by the previous fit. NOTE: build_model() compiles a
    # new Nadam optimizer, so optimizer state does not carry over between files.
    model.save_weights(TEMP_WEIGHTS_FILE)
    model = None
    tf.keras.backend.clear_session()
    model = build_model()
    model.load_weights(TEMP_WEIGHTS_FILE)


In [None]:
def predict(hr, hg, hb):
    """Predict the SDR colour for one HDR (hr, hg, hb) triple using the global ``model``.

    Returns:
        Tuple ``(sr, sg, sb)`` with each channel clipped to [0, 255].
    """
    tf_in = tf.convert_to_tensor([[hr, hg, hb]])
    # verbose=0: recent Keras versions default predict() to verbose='auto',
    # which prints a progress bar on every call.
    sdr_out = model.predict(tf_in, verbose=0)
    sr, sg, sb = np.clip(sdr_out[0], 0.0, 255.0)
    return sr, sg, sb

def batch_predict(lst):
    """Predict SDR colours for a list of HDR [r, g, b] triples.

    The list is converted to a tensor and handed to ``batch_predict_tf``;
    returns its clipped prediction array.
    """
    return batch_predict_tf(tf.convert_to_tensor(lst))

def batch_predict_tf(tf_in):
    """Run the global ``model`` on a batch tensor of HDR triples.

    Returns:
        numpy array of predictions with every value clipped to [0, 255].
    """
    # verbose=0: this is called many times while building the LUT, and recent
    # Keras defaults (verbose='auto') print a progress bar per call, flooding
    # the notebook output.
    sdr_out = model.predict(tf_in, verbose=0)
    return np.clip(sdr_out, 0.0, 255.0)

In [None]:
# Edge length of the generated 3D LUT. Common .cube sizes are 17, 33 and 65;
# 129 doesn't work in some ffmpeg builds.
lut_size = 129

if lut_size < 2:
    raise ValueError(f"lut_size must be at least 2, got {lut_size}")

# Grid index i maps to HDR value i * lut_step_size. A .cube grid spans the
# input domain inclusively, so dividing by (lut_size - 1) puts the last index
# exactly on 65535 (uint16 max). Dividing by lut_size left the top grid point
# one step short of full scale.
lut_step_size = 65535.0 / (lut_size - 1)


def luti_to_hdr(i):
    """Convert a LUT grid index (0 .. lut_size - 1) to an HDR value in [0, 65535]."""
    return lut_step_size * i


def sdr_to_lutv(c):
    """Scale an 8-bit SDR value (0-255) to the normalised [0, 1] range used in the .cube file."""
    return c / 255.0

def write_lut_fast(path="../generated_lut.cube"):
    """Sample the model on a lut_size^3 HDR grid and write the result as a .cube 3D LUT.

    Args:
        path: output file path (defaults to the original hard-coded location).

    The file gets a TITLE line and a LUT_3D_SIZE line, then one "r g b" line per
    grid point. Each value is an SDR prediction scaled to [0, 1] and written with
    6 decimals. Red varies fastest, then green, then blue, which is the
    ordering the .cube format expects.
    """
    indices = range(lut_size)
    # `with` guarantees the file is closed even if prediction fails midway.
    with open(path, "w") as lut_file:
        lut_file.write("TITLE \"HDR_2_SDR_generated_lut\"\n")
        lut_file.write(f"LUT_3D_SIZE {lut_size}\n")
        for bi in indices:
            # One predict call per blue plane (lut_size^2 points) instead of per
            # row: lut_size model calls in total rather than lut_size^2. The
            # comprehension keeps green as the outer and red as the inner index,
            # so write order is unchanged.
            hdr_plane = [
                [luti_to_hdr(ri), luti_to_hdr(gi), luti_to_hdr(bi)]
                for gi in indices
                for ri in indices
            ]
            for sr, sg, sb in batch_predict(hdr_plane):
                lr, lg, lb = sdr_to_lutv(sr), sdr_to_lutv(sg), sdr_to_lutv(sb)
                lut_file.write(f"{lr:.6f} {lg:.6f} {lb:.6f}\n")

In [None]:
# Bake the trained model into a .cube LUT file (written to ../generated_lut.cube).
write_lut_fast()