In [1]:
import tensorflow as tf
import numpy as np

from neural_networks.tf_model_construction import train_model, test_model

In [2]:
binary_class_model_params = dict(
    # ----- Data / experiment setup -----
    # NOTE(review): absolute, machine-specific path; consider making it configurable.
    data_path="D:/Datasets/one_step_retrosynth_ai_output/",   # Folder holding the input data.
    fold_ind=1,                                               # Which data fold is currently in use.
    fp_type="hsfp",                                           # Fingerprint dataset to load.
    oversample=False,                                         # Toggle for SMOTE oversampling.
    num_layers=3,                                             # Total layer count of the network.
    random_seed=101,                                          # Seed for reproducible runs.
    classification_type="binary",                             # Kind of classification problem.

    # ----- Input layer -----
    input_size=1024,                                          # Width of the input layer.
    input_activation=tf.nn.relu,                              # Activation applied by the input layer.
    input_init=tf.initializers.glorot_normal(),               # Weight initializer for the input layer.
    input_dropout=0.2,                                        # Dropout value for the input layer.

    # ----- Hidden layer(s) -----
    hidden_size=1024,                                         # Width of each hidden layer.
    hidden_activation=tf.nn.relu,                             # Activation applied by hidden layers.
    hidden_init=tf.initializers.glorot_normal(),              # Weight initializer for hidden layers.
    hidden_dropout=0.2,                                       # Dropout value for hidden layers.

    # ----- Output layer -----
    output_size=1,                                            # Width of the output layer.
    output_activation=tf.nn.sigmoid,                          # Activation applied by the output layer.
    output_init=tf.initializers.glorot_normal(),              # Weight initializer for the output layer.

    # ----- Optimisation -----
    learning_rate=0.001,                                      # Learning rate for the Adam optimizer.
    max_num_epochs=200,                                       # Upper bound on training epochs.
    batch_size=32,                                            # Mini-batch size.
    early_stopping_interval=10,                               # Epoch window used to detect early stopping.

    # Folder where the TensorFlow summaries of the trained network are written.
    model_log_folder="neural_networks/constructed_models/binary_classification/",
)

In [3]:
# Cross-validation: train and evaluate one model per data fold, then report
# each metric averaged over all folds.
num_folds = 5

# Per-fold metric values (one entry per fold); the "avg_" prefix refers to the
# averages printed below, not to the list contents themselves.
avg_tr_loss, avg_val_loss, avg_tst_loss, avg_val_acc, avg_tst_acc = [], [], [], [], []

for fold_ind in range(1, num_folds + 1):
    # Use a per-fold shallow copy instead of mutating the shared config dict,
    # so re-running this cell (or later cells) does not see a leftover fold_ind.
    fold_params = {**binary_class_model_params, "fold_ind": fold_ind}

    _, temp_tr_loss, temp_val_loss, temp_val_accuracy = train_model(fold_params, verbose=False)
    test_labels, test_output, temp_tst_loss, temp_tst_accuracy = test_model(fold_params, verbose=False)

    avg_tr_loss.append(temp_tr_loss)
    avg_val_loss.append(temp_val_loss)
    avg_tst_loss.append(temp_tst_loss)
    avg_val_acc.append(temp_val_accuracy)
    avg_tst_acc.append(temp_tst_accuracy)

# Labels now match the metric being averaged. Previously the test loss was
# printed as "training loss" and the test accuracy as "test loss".
print("Average training loss: {}".format(np.average(avg_tr_loss)))
print("Average validation loss: {}".format(np.average(avg_val_loss)))
print("Average test loss: {}".format(np.average(avg_tst_loss)))
print("Average validation accuracy: {}".format(np.average(avg_val_acc)))
print("Average test accuracy: {}".format(np.average(avg_tst_acc)))

INFO:tensorflow:Restoring parameters from neural_networks/constructed_models/binary_classification/fold_1/checkpoints/-40
INFO:tensorflow:Restoring parameters from neural_networks/constructed_models/binary_classification/fold_2/checkpoints/-24
INFO:tensorflow:Restoring parameters from neural_networks/constructed_models/binary_classification/fold_3/checkpoints/-34
INFO:tensorflow:Restoring parameters from neural_networks/constructed_models/binary_classification/fold_4/checkpoints/-10
INFO:tensorflow:Restoring parameters from neural_networks/constructed_models/binary_classification/fold_5/checkpoints/-14
Average training loss: 0.3980202376842499
Average validation loss: 0.39461448788642883
Average training loss: 0.39838963747024536
Average validation accuracy: 0.9188070297241211
Average test loss: 0.9119033813476562
