# This notebook trains the MPNN_model for dGbyG.

In [2]:
import os
import torch
import numpy as np
import pandas as pd

from dGbyG.data import TrainDataset

from dGbyG.model import MPNN_model, Trainer

# 1. Load the training data

In [5]:
# Read the training data into a pd.DataFrame and pull out the reaction
# equations and their standard transformed Gibbs energies (labels).
train_data_path = '../data/TrainingData.csv'
TrainingData_df = pd.read_csv(train_data_path)
equation = TrainingData_df.loc[:, 'reaction'].to_numpy()
standard_dG_prime = TrainingData_df.loc[:, 'standard_dg_prime'].to_numpy()

# Per-point uncertainty `Scale`:
#   - rows whose 'SEM' is NaN fall back to mean_std (the mean of the 'std' column);
#   - otherwise sqrt(SEM**2 + mean_std**2 / n).
# Vectorised with np.where instead of a Python loop. The unused branch is still
# evaluated for every row, but np.where discards it, so the result is unchanged.
mean_std = TrainingData_df.loc[:, 'std'].mean()
sem = TrainingData_df.loc[:, 'SEM'].to_numpy(dtype=float)
n_obs = TrainingData_df.loc[:, 'n'].to_numpy(dtype=float)
Scale = np.where(np.isnan(sem), mean_std, (sem**2 + mean_std**2 / n_obs) ** 0.5)

# Training weight of each point: inversely proportional to its uncertainty.
# NOTE(review): this evaluates as 1 / (Scale * median(Scale)), so the median
# weight is 1 / median(Scale)**2 rather than 1. If the intent was to normalise
# the weights to a median of 1, it should read np.median(Scale) / Scale.
# Kept unchanged so previously trained models stay reproducible. Confirm intent.
weight = 1 / Scale / np.median(Scale)

# 2. Set hyperparameters

In [None]:
# Model hyperparameters chosen by hand: hidden embedding width and the
# number of layers passed to MPNN_model.
embedding_dim = 256
num_layers = 2

# Build the dataset from the unperturbed labels. In this notebook it is used to
# read the input feature widths off the first sample; the training loop builds
# its own datasets.
TrainSet = TrainDataset(equations=equation, dGs=standard_dG_prime, weights=weight)
atom_feature_size = TrainSet[0].x.size(1)          # columns of the atom feature matrix
bond_feature_size = TrainSet[0].edge_attr.size(1)  # columns of the edge attribute matrix

# 3. Set the directory path to save the model weights

In [19]:
# Encode the feature sizes and architecture in the folder name so models
# trained with different settings are saved to different directories.
model_weights_folder = f'../models/mpnn_A{atom_feature_size}_B{bond_feature_size}_E{embedding_dim}_L{num_layers}'

# exist_ok=True replaces the isdir()-then-makedirs() check, which could race
# with another process creating the same directory in between.
os.makedirs(model_weights_folder, exist_ok=True)

# 4. Train the network

In [20]:
# Train an ensemble of N networks. Each member is trained on the same reactions,
# but with standard_dG_prime perturbed by Gaussian noise whose per-point standard
# deviation is `Scale`. Member n's weights are saved to <model_weights_folder>/<n>.pt.
N = 5  # ensemble size (the run captured below used N=5; raise it for a larger ensemble)

# Explicit generator for the label noise. seed=None keeps the previous
# non-deterministic behaviour; set an int to make the noise reproducible.
# Network initialisation (torch) is not seeded here.
seed = None
rng = np.random.default_rng(seed)

for n in range(N):
    # Fresh noisy copy of the labels for this ensemble member.
    dG = standard_dG_prime + rng.standard_normal(standard_dG_prime.shape[0]) * Scale
    TrainSet = TrainDataset(equations=equation, dGs=dG, weights=weight)

    network = MPNN_model(atom_dim=atom_feature_size, bond_dim=bond_feature_size,
                         emb_dim=embedding_dim, num_layer=num_layers)
    trainer = Trainer()
    trainer.network = network

    # trainer.train returns three values. They are kept (last member only) for
    # interactive inspection; only the trained network is persisted.
    loss_history, Result_df, i = trainer.train(TrainSet, epochs=9000, lr=1e-4, weight_decay=1e-6)

    torch.save(trainer.network.state_dict(), os.path.join(model_weights_folder, f'{n}.pt'))
    print(f'{n} done')

print('All done')

train on: cuda:0
2025-09-11 00:46:13 start preparing data
2025-09-11 00:46:14 start training
2025-09-11 00:46:45 training have done
0 done
train on: cuda:0
2025-09-11 00:46:46 start preparing data
2025-09-11 00:46:47 start training
2025-09-11 00:47:16 training have done
1 done
train on: cuda:0
2025-09-11 00:47:17 start preparing data
2025-09-11 00:47:19 start training
2025-09-11 00:47:49 training have done
2 done
train on: cuda:0
2025-09-11 00:47:50 start preparing data
2025-09-11 00:47:51 start training
2025-09-11 00:48:21 training have done
3 done
train on: cuda:0
2025-09-11 00:48:22 start preparing data
2025-09-11 00:48:23 start training
2025-09-11 00:48:53 training have done
4 done
All done
