diff --git a/tensorflow/Train_Model.ipynb b/tensorflow/Train_Model.ipynb index 90cc6ef3..50eaf9be 100644 --- a/tensorflow/Train_Model.ipynb +++ b/tensorflow/Train_Model.ipynb @@ -471,6 +471,7 @@ "source": [ "MSE_Loss = tf.keras.losses.MeanSquaredError()\n", "\n", + "#YUV loss to weigh in favour of luminance (2 to 1), as humans are less sensitive to chroma degradation\n", "def YUV_Error(y_true, y_pred):\n", " true_yuv = tf.image.rgb_to_yuv(y_true)\n", " pred_yuv = tf.image.rgb_to_yuv(y_pred)\n", @@ -515,6 +516,7 @@ } ], "source": [ + "#Super-convergence with clipping followed by fine tuning with Adam allows somewhat fair convergence within a few minutes\n", "model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=10000.0, clipvalue=0.00000001, momentum=0.9, decay=0.0, nesterov=True), loss=YUV_Error)\n", "model.fit(dataset_train.repeat(), epochs=1, steps_per_epoch=4096, validation_data=dataset_valid)\n", "model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), loss=YUV_Error)\n", @@ -606,6 +608,7 @@ } ], "source": [ + "#Show results\n", "d_pred = next(iter(dataset_valid))\n", "show_images(d_pred[1], val_range=[0, 1], scale=8)\n", "show_images(d_pred[0], val_range=[0, 1], scale=8)\n",