## ViCTORIA training

In [None]:
import os
import pickle
import platform
import sys

import matplotlib.pyplot as plt
import tensorflow as tf

sys.path.insert(1, "lib/")
sys.path.insert(1, "model/")
from ViCTORIA_network import ViCTORIA_Network, coeff_determination
from dataset_utils import read_many_hdf5

if platform.system() == "Darwin":
    %config InlineBackend.figure_format="retina"  # For high DPI display

if tf.test.gpu_device_name(): 
  print(f"Default GPU Device: {tf.test.gpu_device_name()}")

We load the training dataset (positions and their target scores) from HDF5 files.

In [None]:
# The dataset location defaults to the author's local path. Set the
# VICTORIA_DATASET_DIR environment variable to read it from elsewhere
# without editing the notebook.
directory = os.environ.get("VICTORIA_DATASET_DIR", "E:/IA/Deep_ViCTORIA/Datasets/SE_ResNet/")
# First argument of read_many_hdf5 — presumably the number of training
# samples stored on disk; TODO confirm it matches the files in `directory`.
nb_train_samples = 397182
positions_train, scores_train = read_many_hdf5(nb_train_samples, directory, "_train")

We load the model's optimal hyperparameters, which the `choose_hyperparameters` notebook saved to `model/hyperparameters.pickle`.

In [None]:
# `with` guarantees the file is closed even if unpickling fails.
# NOTE: pickle.load can execute arbitrary code — only load a trusted file
# produced by the `choose_hyperparameters` notebook.
with open("model/hyperparameters.pickle", "rb") as input_file:
    hyperparams = pickle.load(input_file)

# Fail early with a clear message if the pickle lacks the keys used below.
missing_keys = {"filters", "nb_blocks"} - set(hyperparams)
assert not missing_keys, f"hyperparameters.pickle is missing keys: {sorted(missing_keys)}"

We create a model with these hyperparameters.

In [None]:
# Instantiate the SE-ResNet with the tuned width and depth.
model = ViCTORIA_Network(
    filters=hyperparams["filters"],
    nb_blocks=hyperparams["nb_blocks"],
)
# Build on an explicit input shape so the summary can list the layers.
# Shape (1, 8, 8, 15): presumably a batch of one 8x8 board with 15 feature
# planes — TODO confirm against the dataset encoding.
model.build((1, 8, 8, 15))
model.compile(
    optimizer="adam",
    loss='mean_absolute_error',
    metrics=[coeff_determination],
)
model.summary()

It's time to train! Go grab a coffee or two (maybe more).

In [None]:
# Number of passes over the training set (also read by plot_history).
nb_epochs = 20
history = model.fit(
    positions_train,
    scores_train,
    epochs=nb_epochs,
    verbose=1,
)

We save the model's weights.

In [None]:
# Persist only the trained weights; loading them back requires rebuilding
# the same architecture (same hyperparameters) first.
weights_path = "model/weights/weights"
model.save_weights(weights_path)

In [None]:
def plot_history(history, model, path=None):
  """Plot the per-epoch training loss and R^2 score of a Keras training run.

  Args:
    history: object returned by `model.fit`; its `history` dict must hold
      the "loss" and "coeff_determination" series.
    model: the trained network; only its `nb_blocks` attribute is read,
      to label the curves.
    path: optional output file. When given, the figure is also saved there,
      and missing parent directories are created first.
  """
  loss = history.history["loss"]
  score = history.history["coeff_determination"]
  # Derive the x-axis from the recorded history rather than the global
  # `nb_epochs`, so the plot stays correct if training stopped early or the
  # training cell was re-run with a different epoch count.
  epochs = range(1, len(loss) + 1)
  label = f"{model.nb_blocks} SE-ResNet blocks"

  fig, axs = plt.subplots(1, 2, figsize=(20, 5))

  axs[0].plot(epochs, loss, "r-.", label=label)
  axs[0].set_xlabel("Epoch")
  axs[0].set_ylabel("(Mean Absolute Error)")
  axs[0].set_title('Training loss')
  axs[0].legend()  # without this the curve labels are never shown

  axs[1].plot(epochs, score, "g-.", label=label)
  axs[1].set_xlabel("Epoch")
  axs[1].set_ylabel("($R^2$)")
  axs[1].set_title('Training score')
  axs[1].legend()

  if path:
    out_dir = os.path.dirname(path)
    if out_dir:
      # savefig raises FileNotFoundError if the target directory is missing.
      os.makedirs(out_dir, exist_ok=True)
    fig.savefig(path)

We plot the training loss (mean absolute error) and the coefficient of determination ($R^2$) for each epoch.

In [None]:
# Render the training curves inline and keep a PDF copy alongside results.
history_figure_path = "results/ViCTORIA_history.pdf"
plot_history(history, model, path=history_figure_path)