Copyright 2021-2024 @ Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd

This code is a part of MindSPONGE:
MindSpore Simulation Package tOwards Next Generation molecular modelling.

MindSPONGE is open-source software based on the AI-framework:
MindSpore (https://www.mindspore.cn/)

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

See the License for the specific language governing permissions and
limitations under the License.


Training discriminator.

配置环境, 导入必要的包

In [1]:
import os
import sys
import time

# Environment must be configured BEFORE MindSpore is imported: MindSpore picks up
# its log level (GLOG_v) when the library is loaded, so setting it after
# `import mindspore` (as this cell previously did) does not take effect -- which is
# presumably why [ERROR] GetRealPath lines still appear in the outputs below.
# Keep a user-provided MINDSPONGE_HOME instead of unconditionally overwriting it;
# the hard-coded path is only a fallback for the original author's machine.
os.environ.setdefault('MINDSPONGE_HOME', '/home/mindspore/work/summerschool/mindscience/MindSPONGE/src')
os.environ['GLOG_v'] = str(4)
path = os.getenv('MINDSPONGE_HOME')
if path:
    # Make the MindSPONGE sources (which provide the `cybertron` package) importable first.
    sys.path.insert(0, path)
sys.path.append('../..')
data_dir = './data'

import numpy as np
import h5py
import mindspore as ms
from mindspore import nn
from mindspore import Tensor
from mindspore import dataset as ds
from mindspore.ops import functional as F
from mindspore.train import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore import context

from cybertron.model import MolCT
from cybertron.embedding import MolEmbedding
from cybertron.readout import AtomwiseReadout
from cybertron.cybertron import Cybertron
from cybertron.train import TransformerLR
from cybertron.train import TrainMonitor
from cybertron.train import WithAdversarialLossCell
from cybertron.train import CrossEntropyLoss
from cybertron.train.metric import MAE, RMSE, Loss

定义二分类loss

In [2]:
class BCELossForDiscriminator(nn.Cell):
    """Binary cross-entropy loss for training a discriminator on logits.

    Positive predictions are scored against a target of 1 and negative
    predictions against a target of 0; the two losses are summed.

    Args:
        reduction (str): Reduction passed to ``nn.BCEWithLogitsLoss``
            (default ``'mean'``).
    """

    def __init__(self, reduction: str = 'mean'):
        super().__init__()
        self.cross_entropy = nn.BCEWithLogitsLoss(reduction=reduction)

    def construct(self, pos_pred: Tensor, neg_pred: Tensor):
        """Compute the summed discriminator loss.

        Args:
            pos_pred (Tensor): Logits for positive (real) samples.
            neg_pred (Tensor): Logits for negative (generated) samples.

        Returns:
            tuple: ``(loss, None, None, None)``. ``loss`` is reduced according
            to ``reduction``. The trailing ``None`` slots are presumably required
            by the wrapper cell (``WithAdversarialLossCell``) -- verify there.
        """
        real_target = F.ones_like(pos_pred)
        fake_target = F.zeros_like(neg_pred)
        total = self.cross_entropy(pos_pred, real_target) + self.cross_entropy(neg_pred, fake_target)
        return total, None, None, None

In [3]:
# Execute in PyNative mode on an Ascend device (same call as mindspore.context.set_context).
ms.set_context(mode=ms.PYNATIVE_MODE, device_target="Ascend")

加载数据集：正样本取自训练模型所用的数据集，负样本为模拟轨迹中的构象。
这样选择是为了让判别器识别出训练数据集中没有见过的分子构象。

In [4]:
# Positive samples: conformations from the dataset the model was trained on.
ori_train_set = data_dir + '/dataset/data_normed_trainset_83197_64_64.npz'
ori_train_set = np.load(ori_train_set)
Z = ori_train_set['atom_type']
pos_data = ori_train_set['coordinate']

atom_type = Tensor(Z, ms.int32)
num_atom = int(atom_type.shape[-1])

data_num = pos_data.shape[0]

# Negative samples: `data_num` frames drawn without replacement from a simulation
# trajectory. Open the file explicitly read-only (older h5py defaulted to
# append/create mode) and close it once the frames have been copied out.
traj_file = data_dir + '/traj/PES_4-100000-800K-bias-NORMAL2.h5md'
with h5py.File(traj_file, 'r') as traj_h5:
    traj = traj_h5['particles/trajectory0/position/value']
    n_frames = traj.shape[0]
    if n_frames < data_num:
        raise ValueError(
            f'trajectory {traj_file} has {n_frames} frames, '
            f'but {data_num} negative samples are required'
        )
    # Copy the whole trajectory into memory before fancy-indexing (as before);
    # h5py point selection with many unsorted indices is not supported/slow.
    neg_data = np.array(traj, dtype=np.float32)[
        np.random.choice(n_frames, data_num, replace=False)
    ]

设置模型超参数

In [5]:
dim_feature = 128
activation = 'silu'

# Atom + distance embedding (bond embedding disabled). Per the printed model
# info, the cutoff of 1 is interpreted in nm.
emb_config = dict(
    dim_node=dim_feature,
    emb_dis=True,
    emb_bond=False,
    cutoff=1,
    cutoff_fn='smooth',
    rbf_fn='log_gaussian',
    activation=activation,
    length_unit='nm',
)
emb = MolEmbedding(**emb_config)

# Representation network; its edge input width must match the embedding's edge width.
mod_config = dict(
    dim_feature=dim_feature,
    dim_edge_emb=emb.dim_edge,
    n_interaction=3,
    n_heads=8,
    activation=activation,
)
mod = MolCT(**mod_config)

# Single-output readout: its scalar output is used as the discriminator logit.
readout_config = dict(
    dim_output=1,
    dim_node_rep=dim_feature,
    activation=activation,
)
readout = AtomwiseReadout(**readout_config)

net = Cybertron(
    embedding=emb,
    model=mod,
    readout=readout,
    atom_type=atom_type,
    length_unit='nm',
)

打印模型情况

In [6]:
net.print_info()

# List every trainable parameter with its index, name and shape, then the total count.
trainable = net.trainable_params()
for i, param in enumerate(trainable):
    print(i, param.name, param.shape)
tot_params = sum(p.size for p in trainable)
print('Total parameters: ', tot_params)

Cybertron Engine, Ride-on!
--------------------------------------------------------------------------------
    Using fixed atom type index:
       Atom 0      : 8
       Atom 1      : 6
       Atom 2      : 6
       Atom 3      : 6
       Atom 4      : 6
       Atom 5      : 6
       Atom 6      : 6
       Atom 7      : 1
       Atom 8      : 1
       Atom 9      : 1
       Atom 10     : 1
       Atom 11     : 1
       Atom 12     : 1
       Atom 13     : 1
       Atom 14     : 1
--------------------------------------------------------------------------------
    Graph Embedding: MolEmbedding
--------------------------------------------------------------------------------
       Length unit: nm
       Atom embedding size: 64
       Cutoff distance: 1.0 nm
       Cutoff function: SmoothCutoff
       Radical basis functions: LogGaussianBasis
          Minimum distance: 0.04 nm
          Maximum distance: 1.0 nm
          Reference distance: 1.0 nm
          Log Gaussian begin: -3.218876

划分训练集与验证集，并用同一个 loss 函数分别包装训练网络和评估网络

In [7]:
n_epoch = 40
repeat_time = 1
batch_size = 64
val_size = 32

# Hold out `val_size` random samples for validation; the rest is for training.
all_idx = np.arange(data_num)
idx = np.random.choice(all_idx, val_size, replace=False)
train_idx = np.setdiff1d(all_idx, idx)


def _paired_dataset(sel):
    """Shuffled, batched, repeated dataset of (pos, neg) conformations at indices `sel`."""
    paired = ds.NumpySlicesDataset({'pos': pos_data[sel], 'neg': neg_data[sel]}, shuffle=True)
    return paired.batch(batch_size).repeat(repeat_time)


ds_train_dsc = _paired_dataset(train_idx)
ds_val_dsc = _paired_dataset(idx)

# Training and evaluation both wrap the same network with the same loss.
loss_fn = BCELossForDiscriminator()
loss_network = WithAdversarialLossCell(net, loss_fn)
eval_network = WithAdversarialLossCell(net, loss_fn)

设置学习率和优化器, 包装训练模型

In [8]:
# Metric names used as keys in the evaluation results.
eval_loss = 'EvalLoss'
neg_loss = 'NegLoss'

# Transformer-style warm-up learning-rate schedule (a smaller peak than usual) with Adam.
lr = TransformerLR(learning_rate=0.3, warmup_steps=8000, dimension=128)
optim = nn.Adam(params=net.trainable_params(), learning_rate=lr)

model = Model(
    loss_network,
    eval_network=eval_network,
    optimizer=optim,
    metrics={eval_loss: Loss()},
)
# Baseline validation loss before training (displayed as the cell output).
model.eval(ds_val_dsc, dataset_sink_mode=False)

[ERROR] CORE(3881892,7f115b98c740,python):2024-08-19-16:40:36.196.866 [mindspore/core/utils/file_utils.cc:253] GetRealPath] Get realpath failed, path[/tmp/ipykernel_3881892/368653846.py]
[ERROR] CORE(3881892,7f115b98c740,python):2024-08-19-16:40:36.196.889 [mindspore/core/utils/file_utils.cc:253] GetRealPath] Get realpath failed, path[/tmp/ipykernel_3881892/368653846.py]
[ERROR] CORE(3881892,7f115b98c740,python):2024-08-19-16:40:36.196.920 [mindspore/core/utils/file_utils.cc:253] GetRealPath] Get realpath failed, path[/tmp/ipykernel_3881892/368653846.py]
[ERROR] CORE(3881892,7f115b98c740,python):2024-08-19-16:40:36.196.932 [mindspore/core/utils/file_utils.cc:253] GetRealPath] Get realpath failed, path[/tmp/ipykernel_3881892/368653846.py]
[ERROR] CORE(3881892,7f115b98c740,python):2024-08-19-16:40:36.243.493 [mindspore/core/utils/file_utils.cc:253] GetRealPath] Get realpath failed, path[/tmp/ipykernel_3881892/368653846.py]
[ERROR] CORE(3881892,7f115b98c740,python):2024-08-19-16:40:36.243

{'EvalLoss': 18.43325424194336}

保存网络配置，并设置模型 checkpoint 的保存策略和训练过程的监控回调

In [9]:
# Save the network hyper-parameters so the discriminator can be rebuilt later.
conf_dir = f'{data_dir}/conf'
net.save_configure('configure_discr.yaml', conf_dir)

outname = 'discr_' + net.model_name
ckpt_dir = f'{data_dir}/ckpt'

# Per-epoch training monitor: evaluates on the validation set and tracks the
# best checkpoint by the `EvalLoss` metric.
record_cb = TrainMonitor(
    model,
    outname,
    per_epoch=1,
    avg_steps=32,
    directory=ckpt_dir,
    eval_dataset=ds_val_dsc,
    best_ckpt_metrics=eval_loss,
)

# Periodic checkpoints: every 8 steps, keeping at most 8 files.
config_ck = CheckpointConfig(
    save_checkpoint_steps=8,
    keep_checkpoint_max=8,
)
ckpoint_cb = ModelCheckpoint(prefix=outname, directory=ckpt_dir, config=config_ck)

开始训练

In [10]:
# Train the discriminator and report the wall-clock duration as HH:MM:SS.
print("Start training ...")
# perf_counter is monotonic, so the measured duration is immune to system clock changes.
beg_time = time.perf_counter()
model.train(n_epoch, ds_train_dsc, callbacks=[record_cb, ckpoint_cb], dataset_sink_mode=False)
end_time = time.perf_counter()
used_time = end_time - beg_time
m, s = divmod(used_time, 60)
h, m = divmod(m, 60)
print("Training Finished!")
print(f"Training Time: {int(h):02d}:{int(m):02d}:{int(s):02d}")



Start training ...


[ERROR] CORE(3881892,7f115b98c740,python):2024-08-19-16:40:52.956.439 [mindspore/core/utils/file_utils.cc:253] GetRealPath] Get realpath failed, path[/tmp/ipykernel_3881892/368653846.py]
[ERROR] CORE(3881892,7f115b98c740,python):2024-08-19-16:40:52.956.455 [mindspore/core/utils/file_utils.cc:253] GetRealPath] Get realpath failed, path[/tmp/ipykernel_3881892/368653846.py]
[ERROR] CORE(3881892,7f115b98c740,python):2024-08-19-16:40:52.956.478 [mindspore/core/utils/file_utils.cc:253] GetRealPath] Get realpath failed, path[/tmp/ipykernel_3881892/368653846.py]
[ERROR] CORE(3881892,7f115b98c740,python):2024-08-19-16:40:52.956.493 [mindspore/core/utils/file_utils.cc:253] GetRealPath] Get realpath failed, path[/tmp/ipykernel_3881892/368653846.py]
[ERROR] CORE(3881892,7f115b98c740,python):2024-08-19-16:40:52.956.510 [mindspore/core/utils/file_utils.cc:253] GetRealPath] Get realpath failed, path[/tmp/ipykernel_3881892/368653846.py]
[ERROR] CORE(3881892,7f115b98c740,python):2024-08-19-16:40:52.959

Epoch: 1, Step: 64, Learning_rate: 2.3346504e-06, Last_Loss: 9.793301, Avg_loss: 13.351794136895073, EvalLoss: 9.605378150939941
Epoch: 2, Step: 128, Learning_rate: 4.706359e-06, Last_Loss: 1.5816145, Avg_loss: 1.7781564129723444, EvalLoss: 1.59995436668396
Epoch: 3, Step: 192, Learning_rate: 7.078067e-06, Last_Loss: 1.3639992, Avg_loss: 1.4122117947018336, EvalLoss: 1.3273067474365234
Epoch: 4, Step: 256, Learning_rate: 9.4497755e-06, Last_Loss: 1.0592549, Avg_loss: 1.1430923825218564, EvalLoss: 1.0671653747558594
Epoch: 5, Step: 320, Learning_rate: 1.1821484e-05, Last_Loss: 0.90319586, Avg_loss: 0.9496476006886315, EvalLoss: 0.864309549331665
Epoch: 6, Step: 384, Learning_rate: 1.4193192e-05, Last_Loss: 0.7442018, Avg_loss: 0.8159498430433727, EvalLoss: 0.7321667671203613
Epoch: 7, Step: 448, Learning_rate: 1.65649e-05, Last_Loss: 0.6547539, Avg_loss: 0.7479302731771318, EvalLoss: 0.6558916568756104
Epoch: 8, Step: 512, Learning_rate: 1.8936607e-05, Last_Loss: 0.6546364, Avg_loss: 0.