In [1]:
import numpy as np
import matplotlib as mpl
mpl.use('Agg')
import time
import os
import tensorflow as tf
from kgcnn.data.utils import save_pickle_file, load_pickle_file
from datetime import timedelta
from tensorflow_addons import optimizers
from kgcnn.data.transform.scaler.standard import StandardLabelScaler
from kgcnn.data.transform.scaler.molecule import QMGraphLabelScaler
import kgcnn.training.scheduler
from kgcnn.training.history import save_history_score, load_history_list
from kgcnn.metrics.metrics import ScaledMeanAbsoluteError, ScaledRootMeanSquaredError
from sklearn.model_selection import KFold
from kgcnn.training.hyper import HyperParameter
from kgcnn.data.serial import deserialize as deserialize_dataset
from kgcnn.model.serial import deserialize as deserialize_model
from kgcnn.utils.plots import plot_train_test_loss, plot_predict_true
from kgcnn.utils.devices import set_devices_gpu

2024-03-03 02:22:17.749549: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-03-03 02:22:17.924131: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.

TensorFlow Addons (TFA) has ended development and introduction of new features.
TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.
Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). 

For more information see: https://github.com/tensorflow/add

In [2]:
class GNNTrainConfig:
    """Notebook replacement for the command-line arguments of the training script.

    Each constructor argument is stored unchanged as an attribute of the same name.
    The rest of the notebook consumes the values through the mapping returned by
    :meth:`to_dict` (``args["hyper"]``, ``args["seed"]``, ...).

    Args:
        hyper: Path of the hyperparameter file, passed on as ``hyper_info``.
        category: Hyperparameter category, passed on as ``hyper_category``.
        model: Model name, passed on as ``model_name``.
        dataset: Dataset class name, passed on as ``dataset_class``.
        make: Name of the model factory, passed on as ``model_class``.
        gpu: Value handed to ``set_devices_gpu``.
        fold: Folds to execute; used as default for ``execute_folds``.
        seed: Seed handed to the numpy and tensorflow seeding functions.
    """

    def __init__(self,
                 hyper="hyper/hyper_mp_e_form.py",
                 category=None,
                 model=None,
                 dataset=None,
                 make=None,
                 gpu=None,
                 fold=None,
                 seed=42):
        # Assignment order fixes the key order of `to_dict` (seed before fold).
        self.hyper = hyper
        self.category = category
        self.model = model
        self.dataset = dataset
        self.make = make
        self.gpu = gpu
        self.seed = seed
        self.fold = fold

    def to_dict(self):
        """Return the configuration as a new ``dict`` keyed by attribute name.

        A shallow copy is returned instead of the instance ``__dict__`` itself, so
        adding or replacing keys in the result does not alter this object.
        """
        return dict(vars(self))

# Usage: build the configuration that stands in for parsed command-line arguments.
config = GNNTrainConfig(
    hyper="training/hyper/hyper_mp_e_form.py",
    model='DenseGNN',
    make='make_model_asu',
    dataset='MatProjectEFormDataset',
    seed=42,
)
# `args` is the plain mapping that the training cell indexes.
args = config.to_dict()
print("Input of argparse:", args)
args


# python training/train_crystal.py --dataset MatProjectGapDataset --model GATv2 --category GATv2 --hyper training/hyper/hyper_mp_gap.py
# python training/train_crystal.py --dataset MatProjectEFormDataset --model GIN --category GIN --hyper training/hyper/hyper_mp_e_form.py

Input of argparse: {'hyper': 'training/hyper/hyper_mp_e_form.py', 'category': None, 'model': 'DenseGNN', 'dataset': 'MatProjectEFormDataset', 'make': 'make_model_asu', 'gpu': None, 'seed': 42, 'fold': None}


{'hyper': 'training/hyper/hyper_mp_e_form.py',
 'category': None,
 'model': 'DenseGNN',
 'dataset': 'MatProjectEFormDataset',
 'make': 'make_model_asu',
 'gpu': None,
 'seed': 42,
 'fold': None}

In [None]:
# Seed numpy and tensorflow with the configured seed.
for seed_setter in (np.random.seed, tf.random.set_seed, tf.keras.utils.set_random_seed):
    seed_setter(args["seed"])

# Select the GPU device(s) to use.
set_devices_gpu(args["gpu"])

# `HyperParameter` exposes and verifies the hyperparameter.
# The hyperparameter is stored as a dictionary with sections 'model', 'dataset' and 'training'.
hyper = HyperParameter(
    hyper_info=args["hyper"],
    hyper_category=args["category"],
    model_name=args["model"],
    model_class=args["make"],
    dataset_class=args["dataset"],
)
hyper.verify()

# Load a specific pre-defined dataset from a module in kgcnn.data.datasets.
# The sub-classed dataset classes are named after the dataset, e.g. `MatProjectEFormDataset`.
# If no name is given, a general `CrystalDataset` is constructed; its construction must then be fully
# defined in the data section of the hyperparameter, including all methods to run on the dataset
# (for example 'file_path', 'data_directory' etc.).
# Writing a custom training script can be more convenient than configuring the dataset via hyperparameter.
dataset = deserialize_dataset(hyper["dataset"])

# Verify that the dataset provides the properties required by the model input, including a quick shape
# comparison. The name of each keras `Input` layer is directly connected to a property of the dataset,
# e.g. 'edge_indices' or 'node_attributes', which couples the keras model to the dataset.
dataset.assert_valid_model_input(hyper["model"]["config"]["inputs"])

# Drop invalid graphs. At the moment these are graphs that lack a property required by the model's
# input layers, or whose tensor-like property has zero length.
dataset.clean(hyper["model"]["config"]["inputs"])
# Length of the cleaned dataset.
data_length = len(dataset)

# Train on graph labels, which must be defined by the dataset.
labels = np.array(dataset.obtain_property("graph_labels"))
label_names = dataset.label_names
label_units = dataset.label_units
if labels.ndim <= 1:
    # Add a trailing target axis so that labels are at least two-dimensional.
    labels = np.expand_dims(labels, axis=-1)

# Optional selection of target columns when training on multiple regression targets.
multi_target_indices = None
if "multi_target_indices" in hyper["training"]:
    multi_target_indices = hyper["training"]["multi_target_indices"]
if multi_target_indices is not None:
    labels = labels[:, multi_target_indices]
    if label_names is not None:
        label_names = [label_names[i] for i in multi_target_indices]
    if label_units is not None:
        label_units = [label_units[i] for i in multi_target_indices]
print("Labels '%s' in '%s' have shape '%s'." % (label_names, label_units, labels.shape))

# Output directory and file-name postfix for all result files.
filepath = hyper.results_file_path()
postfix_file = hyper["info"]["postfix_file"]

# For crystals, the atomic numbers are also required to properly pre-scale extensive quantities
# like total energy.
atoms = dataset.obtain_property("node_number")

# Cross-validation via random KFold split from `sklearn.model_selection`.
kf = KFold(**hyper["training"]["cross_validation"]["config"])

# Since training on crystal datasets can be expensive, the folds to train can be restricted.
# An 'execute_folds' entry in the training hyperparameter takes precedence over the `fold` argument.
execute_folds = hyper["training"]["execute_folds"] if "execute_folds" in hyper["training"] else args["fold"]

# Placeholders for the per-fold objects created inside the training loop.
model = hist = x_test = y_test = scaler = atoms_test = None

# Materialise all (train, test) index pairs up front.
train_test_indices = list(kf.split(X=np.zeros((data_length, 1)), y=labels))
num_folds = len(train_test_indices)

# Bookkeeping that is filled while iterating over the folds.
splits_done = 0
time_list = []
train_indices_all = []
test_indices_all = []
# Train, evaluate and store one model per selected cross-validation fold.
for current_fold, (train_index, test_index) in enumerate(train_test_indices):
    # Indices are recorded for every fold, including folds that are skipped below.
    test_indices_all.append(test_index)
    train_indices_all.append(train_index)

    # Only run the folds listed in `execute_folds`; an empty or missing selection runs every fold.
    if execute_folds and current_fold not in execute_folds:
        continue
    print("Running training on fold: %s" % current_fold)

    # Make the model for current split using model kwargs from hyperparameter.
    # They are always updated on top of the models default kwargs.
    model = deserialize_model(hyper["model"])

    # First select training and test graphs from indices, then convert them into tensorflow tensor
    # representation. Which property of the dataset and whether the tensor will be ragged is retrieved from the
    # kwargs of the keras `Input` layers ('name' and 'ragged').
    x_train, y_train = dataset[train_index].tensor(hyper["model"]["config"]["inputs"]), labels[train_index]
    x_test, y_test = dataset[test_index].tensor(hyper["model"]["config"]["inputs"]), labels[test_index]
    # Also keep the same information for atomic numbers of the structures.
    atoms_test = [atoms[i] for i in test_index]
    atoms_train = [atoms[i] for i in train_index]

    # Normalize training and test targets with the label scaler named in the hyperparameter
    # (`QMGraphLabelScaler`, otherwise `StandardLabelScaler`). Only use for regression.
    # Without a scaler, or without a scaling factor, no extra metrics are compiled.
    metrics = None
    if "scaler" in hyper["training"]:
        print("Using StandardScaler.")
        if hyper["training"]["scaler"]["class_name"] == "QMGraphLabelScaler":
            scaler = QMGraphLabelScaler(**hyper["training"]["scaler"]["config"])
        else:
            scaler = StandardLabelScaler(**hyper["training"]["scaler"]["config"])

        # Fit on the training targets only and apply the same transform to the test targets.
        y_train = scaler.fit_transform(y=y_train, atomic_number=atoms_train)
        y_test = scaler.transform(y=y_test, atomic_number=atoms_test)
        scaler_scale = scaler.get_scaling()

        # If scaler was used we add rescaled standard metrics to compile, since otherwise the keras history will not
        # directly log the original target values, but the scaled ones.
        # The metrics need the shape of the scaling, so they are only built when a scaling is available;
        # previously `scaler_scale.shape` was read before the `None` check.
        if scaler_scale is not None:
            mae_metric = ScaledMeanAbsoluteError(scaler_scale.shape, name="scaled_mean_absolute_error")
            rms_metric = ScaledRootMeanSquaredError(scaler_scale.shape, name="scaled_root_mean_squared_error")
            mae_metric.set_scale(scaler_scale)
            rms_metric.set_scale(scaler_scale)
            metrics = [mae_metric, rms_metric]

        # Save scaler to file
        scaler.save(os.path.join(filepath, f"scaler{postfix_file}_fold_{current_fold}"))
    else:
        print("TRAINING: Not using StandardScaler for regression.")

    # Compile model with optimizer and loss
    model.compile(**hyper.compile(loss="mean_absolute_error", metrics=metrics))
    # `summary()` prints by itself and returns None, so it is not wrapped in `print`.
    model.summary()

    # Start and time training
    start = time.time()
    hist = model.fit(x_train, y_train,
                     validation_data=(x_test, y_test),
                     **hyper.fit())
    stop = time.time()
    elapsed = str(timedelta(seconds=stop - start))
    print("Print Time for training: ", elapsed)
    time_list.append(elapsed)

    # Store the keras history of this fold for later plotting and scoring.
    save_pickle_file(hist.history, os.path.join(filepath, f"history{postfix_file}_fold_{current_fold}.pickle"))

    # Plot prediction
    predicted_y = model.predict(x_test)
    true_y = y_test

    # Undo the label scaling so prediction and truth are compared in the original label values.
    if scaler is not None:
        predicted_y = scaler.inverse_transform(y=predicted_y, atomic_number=atoms_test)
        true_y = scaler.inverse_transform(y=true_y, atomic_number=atoms_test)

    plot_predict_true(predicted_y, true_y,
                      filepath=filepath, data_unit=label_units,
                      model_name=hyper.model_name, dataset_name=hyper.dataset_class, target_names=label_names,
                      file_name=f"predict{postfix_file}_fold_{current_fold}.png", show_fig=False)

    # Save keras-model to output-folder.
    model.save(os.path.join(filepath, f"model{postfix_file}_fold_{current_fold}"))

    splits_done += 1

# Load the histories that were pickled per fold during training.
# NOTE(review): "(i)" is presumably a fold-index placeholder resolved by `load_history_list` - confirm.
history_file_pattern = os.path.join(filepath, f"history{postfix_file}_fold_(i).pickle")
history_list = load_history_list(history_file_pattern, num_folds)

# Unit used to annotate plot and score; empty when the hyperparameter does not define one.
data_unit = ""
if "data_unit" in hyper["data"]:
    data_unit = hyper["data"]["data_unit"]

# Plot training- and test-loss vs epochs for all splits.
plot_train_test_loss(
    history_list,
    loss_name=None,
    val_loss_name=None,
    model_name=hyper.model_name,
    data_unit=data_unit,
    dataset_name=hyper.dataset_class,
    filepath=filepath,
    file_name=f"loss{postfix_file}.png",
)

# Save original data indices of the splits, one array per fold.
test_indices_file = os.path.join(filepath, f"{hyper.model_name}_test_indices_{postfix_file}.npz")
train_indices_file = os.path.join(filepath, f"{hyper.model_name}_train_indices_{postfix_file}.npz")
np.savez(test_indices_file, *test_indices_all)
np.savez(train_indices_file, *train_indices_all)

# Save hyperparameter again, which were used for this fit.
hyper_file = os.path.join(filepath, f"{hyper.model_name}_hyper{postfix_file}.json")
hyper.save(hyper_file)

# Save the score of the fit result as a text file.
save_history_score(
    history_list,
    loss_name=None,
    val_loss_name=None,
    model_name=hyper.model_name,
    data_unit=data_unit,
    dataset_name=hyper.dataset_class,
    model_class=hyper.model_class,
    multi_target_indices=multi_target_indices,
    execute_folds=execute_folds,
    seed=args["seed"],
    filepath=filepath,
    file_name=f"score{postfix_file}.yaml",
    time_list=time_list,
)



## TEST

In [4]:
# hyper_mp_jdft2d
# 39.43 cgcnn
# 39.53, gogn  44-->43.1--->42.3(GIN+NMPN)  45(GIN+AttentiveFP)
# 44 GIN   
# 47.68 schnet
# 50.57 megnet
# 59.2 Nmpn
# 60.31 MEGAN

# GIN  Schnet  GATv2  GraphSAGE  INorp  


# e_form 
# 0.0255 cgcnn
# 0.0272 Schnet
# 0.027 megnet
# 0.0328 PAiNN
# 0.034 GraphSAGE
# 0.035 INorp
# 0.04 MEGAN
# 0.0459 HamNet
# 0.0527 GATv2
# 0.064(GIN+AttentiveFP)
# 0.069 GIN
# 0.1483 AttentiveFP
# 1.47 Unet


# gap
# 0.1797 cgcnn
# 0.20 GraphSAGE
# 0.21 megnet
# 0.22 INorp
# 0.2317 MEGAN
# 0.243 GATv2
# 0.2487 Schnet
# 0.275(GIN+NMPN)   0.288(GIN+AttentiveFP)
# 0.32 GIN
# 0.64 AttentiveFP


# MatProjectPerovskitesDataset
# 0.031 coGN
# 0.0354  GATv2
# 0.038 Schnet
# 0.039 megnet
# 0.0396 HamNet
# 0.0405 NMPN 
# 0.0408 GraphSAGE
# 0.0418 AttentiveFP
# 0.0428 MEGAN    0.044(GIN+NMPN)  0.048 GIN+AttentiveFP
# 0.0435 INorp
# 0.0508 PAiNN
# 0.0539 GIN
# 0.0536 MoGAT
#   0.08 MAT



# hyper_mp_dielectric
# 0.303 GIN  0.309  0.3000  0.297  0.30(GIN+NMPN)  0.297(GIN+AttentiveFP)
# 0.307 Schnet      0.321(Hamnet+Schnet)---0.308
# 0.327 DimeNetPP
# 0.324 megnet
# 0.3247 MEGAN
# 0.3341  NMPN
# 0.3468 GraphSAGE
# 0.351 GATv2
# 0.362  PAiNN
# 0.41 AttentiveFP
# 0.4563 INorp
# 0.50 HamNet
# 0.4597 MoGAT



# hyper_mp_phonons
# 29.19 Megnet
# 36.68  HamNet  36.97(GIN+NMPN)
# 37 NMPN
# 37.25  AttentiveFP  GIN+AttentiveFP ---> 40  40.54  36.5  37.08  37
# 40.8 Schnet           50(Hamnet+Schnet)---44.45 ---- 42  ---- 42.3 
# 45.53  GATv2
# 49.45 MEGAN
# 50.62 GIN
# 51.59 PAiNN
# 52.51 MoGAT
# 84 GraphSAGE
# 94 INorp


# NMPN(3), HamNet(4), GIN(3), Schnet,     AttentiveFP(1), GATv2(1), MEGAN(1)



# MatProjectLogKVRHDataset
# 0.057 GIN+AttentiveFP  0.0588   0.06 (GIN+NMPN)
# 0.0583 GIN     0.058(GIN+GATv2)  
# 0.0588 PAiNN
# 0.0606 Schnet
# 0.068 HamNet
# 0.067 NMPN
# 0.0687 INorp(charge)  0.0699(no charge) 0.0626  0.0625  0.0598  0.0591  0.0578 
# 0.0668 megnet
# 0.069 MEGAN
# 0.0693 GraphSAGE
# 0.0734 GATv2 0.0754(delete n_in)
# 0.1044 AttentiveFP



# MatProjectLogGVRHDataset
# 0.0804 (GIN+NMPN)   
# 0.081 GIN+AttentiveFP  ---> 0.079 
# 0.0812 GIN
# 0.0816 Schnet
# 0.0852 HamNet
# 0.0851 megnet
# 0.0868 INorp
# 0.0919 MEGAN
# 0.0937 GraphSAGE
# 0.096 GATv2
# 0.0944 NMPN
# 0.1329 AttentiveFP