Copyright 2021-2024 @ Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd

This code is a part of MindSPONGE:
MindSpore Simulation Package tOwards Next Generation molecular modelling.

MindSPONGE is open-source software based on the AI-framework:
MindSpore (https://www.mindspore.cn/)

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

See the License for the specific language governing permissions and
limitations under the License.

Training forcefield model with Cybertron.

配置环境, 导入必要的包

In [1]:
import sys
import os
# NOTE(review): machine-specific absolute path — point this at the local
# MindSPONGE `src` directory before running the notebook elsewhere.
os.environ['MINDSPONGE_HOME']='/home/mindspore/work/summerschool/mindscience/MindSPONGE/src'
# Log verbosity for MindSpore's glog backend; set before `import mindspore`
# below. Presumably level 4 keeps only the most severe messages — confirm
# against the installed MindSpore version.
os.environ['GLOG_v']=str(4)
# Add the directory two levels above the working directory to the import
# path (presumably where the `cybertron` package lives — verify for your layout).
sys.path.append('../..')


import time
import numpy as np
import mindspore as ms
from mindspore import nn
from mindspore import Tensor
from mindspore import dataset as ds
from mindspore import context
from mindspore.train import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

In [2]:
from cybertron import Cybertron
from cybertron.model import MolCT
from cybertron.embedding import MolEmbedding
from cybertron.readout import AtomwiseReadout
from cybertron.train import MolWithLossCell, MolWithEvalCell
from cybertron.train.lr import TransformerLR
from cybertron.train.loss import MSELoss
from cybertron.train.metric import MAE, RMSE, Loss
from cybertron.train.callback import TrainMonitor

In [3]:
# Root directory of the notebook's data; the dataset, configuration and
# checkpoint paths used further down are all built from it.
data_dir = './data'

# Execute MindSpore in graph mode, targeting an Ascend device.
context.set_context(
    device_target="Ascend",
    mode=context.GRAPH_MODE,
)

加载数据集

In [4]:
# Common path prefix of the dataset archives (file names start with
# `data_normed_`).
sys_name = f'{data_dir}/dataset/data_normed_'

train_file = f'{sys_name}trainset_83197_64_64.npz'
valid_file = f'{sys_name}validset_128.npz'

# Open the training archive first, then the validation archive. Both stay
# open because the later dataset cells index into them by key.
train_data, valid_data = np.load(train_file), np.load(valid_file)

# `scale` and `shift` are kept exactly as stored in the archive; the atom
# types are converted to an int32 MindSpore tensor.
shift = train_data['shift']
scale = train_data['scale']
atom_type = Tensor(train_data['atom_type'], ms.int32)

设置训练网络参数

In [5]:
# Settings shared by the embedding, the model and the readout.
activation = 'silu'
dim_feature = 128

# Embedding layer: distance embedding switched on, bond embedding off.
# NOTE(review): `cutoff=1` together with `length_unit='nm'` presumably
# means a 1 nm cutoff — confirm against the MolEmbedding documentation.
emb = MolEmbedding(
    dim_node=dim_feature,
    activation=activation,
    emb_dis=True,
    emb_bond=False,
    rbf_fn='log_gaussian',
    cutoff_fn='smooth',
    cutoff=1,
    length_unit='nm',
)

# MolCT model with 3 interaction layers and 8 attention heads; its edge
# embedding width is taken from the embedding layer built above.
mod = MolCT(
    dim_feature=dim_feature,
    dim_edge_emb=emb.dim_edge,
    activation=activation,
    n_heads=8,
    n_interaction=3,
)

# Atom-wise readout producing a single output value.
readout = AtomwiseReadout(
    dim_node_rep=dim_feature,
    dim_output=1,
    activation=activation,
)

# Assemble the full network and hand it the scale/shift loaded from the
# training archive. The result is bound to `_`, presumably to keep the
# notebook from echoing the return value.
net = Cybertron(
    embedding=emb,
    model=mod,
    readout=readout,
    atom_type=atom_type,
    length_unit='nm',
)
_ = net.set_scaleshift(shift=shift, scale=scale)

# Save the network configuration as YAML under `<data_dir>/conf`.
conf_dir = f'{data_dir}/conf'
_ = net.save_configure('configure_MolCT.yaml', conf_dir)

打印网络参数情况

In [6]:
# List every trainable parameter (index, name, shape), then report the
# total number of elements across all of them.
trainable = net.trainable_params()
for index, parameter in enumerate(trainable):
    print(index, parameter.name, parameter.shape)

tot_params = sum(parameter.size for parameter in trainable)
print('Total parameters: ', tot_params)

# Let the network print its own summary as well.
net.print_info()

0 embedding.atom_embedding.embedding_table (64, 128)
1 model.filter_net.linear.weight (128, 64)
2 model.filter_net.linear.bias (128,)
3 model.filter_net.residual.nonlinear.mlp.0.weight (128, 128)
4 model.filter_net.residual.nonlinear.mlp.0.bias (128,)
5 model.filter_net.residual.nonlinear.mlp.1.weight (128, 128)
6 model.filter_net.residual.nonlinear.mlp.1.bias (128,)
7 model.interaction.0.positional_embedding.norm.gamma (128,)
8 model.interaction.0.positional_embedding.norm.beta (128,)
9 model.interaction.0.positional_embedding.x2q.weight (128, 128)
10 model.interaction.0.positional_embedding.x2k.weight (128, 128)
11 model.interaction.0.positional_embedding.x2v.weight (128, 128)
12 model.interaction.0.multi_head_attention.output.weight (128, 128)
13 model.interaction.1.positional_embedding.norm.gamma (128,)
14 model.interaction.1.positional_embedding.norm.beta (128,)
15 model.interaction.1.positional_embedding.x2q.weight (128, 128)
16 model.interaction.1.positional_embedding.x2k.weight

训练数据集分batch, 包装训练网络和loss function

In [7]:
# Training schedule.
BATCH_SIZE = 32
REPEAT_TIME = 1
N_EPOCH = 10

# Columns of the training dataset. The key order is kept as-is because the
# resulting column names are passed on as `data_keys`; the printed summary
# lists `coordinate` as the input and `energy` / `force` as labels 0 / 1.
train_columns = {
    'coordinate': train_data['coordinate'],
    'energy': train_data['label'],
    'force': train_data['force'],
}
ds_train = ds.NumpySlicesDataset(train_columns, shuffle=True)

# Read the column names before batching, then batch and repeat.
data_keys = ds_train.column_names
ds_train = ds_train.batch(BATCH_SIZE).repeat(REPEAT_TIME)

# One MSE loss per label; the force loss also receives the `avg_force_dis`
# entry of the training archive.
force_dis = train_data['avg_force_dis']
energy_loss = MSELoss()
force_loss = MSELoss(force_dis=force_dis)

# Wrap network and losses into one training cell. Forces are obtained by
# automatic differentiation (`calc_force=True`). The weights [1, 100] are
# reported by print_info() as 0.00990099 and 0.990099.
loss_network = MolWithLossCell(
    data_keys=data_keys,
    network=net,
    loss_fn=[energy_loss, force_loss],
    calc_force=True,
    loss_weights=[1, 100],
)
loss_network.print_info()

Cell wrapper: MolWithLossCell
    Input arguments:
       Argument 0: coordinate
    Labels, loss function and weights:
       Label 0: energy,  loss: MSELoss,  weight: 0.00990099.
       Label 1: force,  loss: MSELoss,  weight: 0.990099.
    Calculate force using automatic differentiation: True
--------------------------------------------------------------------------------


验证数据集分batch, 包装验证网络和evaluation loss function

In [8]:
# Columns of the validation dataset, in the same key order as the
# training dataset.
# NOTE(review): `shuffle=True` is kept from the original; shuffling is
# normally unnecessary for validation data — confirm whether it is intended.
valid_columns = {
    'coordinate': valid_data['coordinate'],
    'energy': valid_data['label'],
    'force': valid_data['force'],
}
ds_valid = ds.NumpySlicesDataset(valid_columns, shuffle=True)

# `data_keys` is rebound here to the validation dataset's column names.
data_keys = ds_valid.column_names
ds_valid = ds_valid.batch(128).repeat(1)

# Evaluation wrapper around the same network with the same loss layout.
# NOTE(review): `force_dis` is the value read from the *training* archive
# in the training-dataset cell — verify that reusing it here is intended.
eval_energy_loss = MSELoss()
eval_force_loss = MSELoss(force_dis=force_dis)
eval_network = MolWithEvalCell(
    data_keys=data_keys,
    network=net,
    loss_fn=[eval_energy_loss, eval_force_loss],
    calc_force=True,
    loss_weights=[1, 100],
    normed_evaldata=True,
)
eval_network.print_info()

Cell wrapper: MolWithEvalCell
    Input arguments:
       Argument 0: coordinate
    Labels, loss function and weights:
       Label 0: energy,  loss: MSELoss,  weight: 0.00990099.
       Label 1: force,  loss: MSELoss,  weight: 0.990099.
    Calculate force using automatic differentiation: True
    Using normalized dataset: True
--------------------------------------------------------------------------------


设置学习率和优化器, 包装训练模型, 设置训练情况的输出和checkpoint文件的保存

In [9]:
# Names under which the evaluation metrics are reported.
eval_loss = 'EvalLoss'
energy_mae = 'EnergyMAE'
forces_mae = 'ForcesMAE'
forces_rmse = 'ForcesRMSE'

# Learning-rate schedule with 4000 warm-up steps, driving an Adam
# optimizer over all trainable parameters of the network.
lr = TransformerLR(learning_rate=1., warmup_steps=4000, dimension=dim_feature)
optim = nn.Adam(params=net.trainable_params(), learning_rate=lr)

# Evaluation metrics. The integer argument presumably selects the label
# index (0 = energy, 1 = force, as listed by print_info()) — confirm
# against the cybertron metric API.
eval_metrics = {
    eval_loss: Loss(),
    energy_mae: MAE(0),
    forces_mae: MAE(1),
    forces_rmse: RMSE(1),
}
model = Model(
    loss_network,
    eval_network=eval_network,
    optimizer=optim,
    metrics=eval_metrics,
)

# Checkpoint files go to `<data_dir>/ckpt`, prefixed with the lower-cased
# model name.
ckpt_name = f'cybertron-{net.model_name.lower()}'
ckpt_dir = f'{data_dir}/ckpt'

# Training monitor: evaluates on the validation dataset and uses the force
# RMSE as the metric for the best checkpoint.
record_cb = TrainMonitor(
    model,
    ckpt_name,
    per_epoch=1,
    avg_steps=32,
    directory=ckpt_dir,
    eval_dataset=ds_valid,
    best_ckpt_metrics=forces_rmse,
)

# Regular checkpoints: one every 32 steps, keeping at most 64 files.
config_ck = CheckpointConfig(save_checkpoint_steps=32, keep_checkpoint_max=64)
ckpoint_cb = ModelCheckpoint(prefix=ckpt_name, directory=ckpt_dir, config=config_ck)

开始训练

In [10]:
# Run the training loop and report the elapsed time as HH:MM:SS.
print("Start training ...")

# time.perf_counter() is monotonic, so the measured duration cannot be
# distorted by system clock adjustments during a long run; time.time() can.
beg_time = time.perf_counter()
model.train(
    N_EPOCH,
    ds_train,
    callbacks=[record_cb, ckpoint_cb],
    dataset_sink_mode=False,
)
end_time = time.perf_counter()

# Split the elapsed seconds into hours, minutes and seconds; `%02d`
# truncates the fractional part when formatting.
used_time = end_time - beg_time
m, s = divmod(used_time, 60)
h, m = divmod(m, 60)
print("Training Finished!")
print("Training Time: %02d:%02d:%02d" % (h, m, s))



Start training ...
Epoch: 1, Step: 128, Learning_rate: 4.4371976e-05, Last_Loss: 0.74542654, Avg_loss: 0.9044692646712065, EvalLoss: 0.6588590741157532, EnergyMAE: 48.16619110107422, ForcesMAE: 694.361328125, ForcesRMSE: 1757.4310797297287
Epoch: 2, Step: 256, Learning_rate: 8.909333e-05, Last_Loss: 0.18310946, Avg_loss: 0.20136635564267635, EvalLoss: 0.18085232377052307, EnergyMAE: 42.73646926879883, ForcesMAE: 372.437109375, ForcesRMSE: 912.9966045939054
Epoch: 3, Step: 384, Learning_rate: 0.0001338147, Last_Loss: 0.14448962, Avg_loss: 0.11873087473213673, EvalLoss: 0.11840511858463287, EnergyMAE: 15.34036636352539, ForcesMAE: 308.76402994791664, ForcesRMSE: 746.9594812750332
Epoch: 4, Step: 512, Learning_rate: 0.00017853607, Last_Loss: 0.082165144, Avg_loss: 0.09984835772775114, EvalLoss: 0.10890037566423416, EnergyMAE: 76.10836791992188, ForcesMAE: 284.77265625, ForcesRMSE: 676.0275142329638
Epoch: 5, Step: 640, Learning_rate: 0.0002232574, Last_Loss: 0.08014881, Avg_loss: 0.082336