## 环境准备

从GitHub上获取DMFF，切换到所需分支，然后安装。

In [17]:
# Remove any previous checkout and any previously installed dmff package.
# NOTE(review): the site-packages path below is specific to this image's mamba
# Python 3.10 installation — adjust it for other environments.
! rm -rf DMFF
! rm -rf /opt/mamba/lib/python3.10/site-packages/dmff*
# Clone DMFF and register the checkout as a git safe.directory.
! git clone https://github.com/deepmodeling/DMFF.git
! git config --global --add safe.directory `pwd`/DMFF
# Switch to the development branch used by this tutorial and install it with pip.
! cd DMFF && git checkout wangxy/v1.0.0-devel && pip install .

Cloning into 'DMFF'...
remote: Enumerating objects: 3507, done.
remote: Counting objects: 100% (956/956), done.
remote: Compressing objects: 100% (340/340), done.
remote: Total 3507 (delta 633), reused 912 (delta 608), pack-reused 2551
Receiving objects: 100% (3507/3507), 18.81 MiB | 2.17 MiB/s, done.
Resolving deltas: 100% (2243/2243), done.
Updating files: 100% (273/273), done.
Branch 'wangxy/v1.0.0-devel' set up to track remote branch 'wangxy/v1.0.0-devel' from 'origin'.
Switched to a new branch 'wangxy/v1.0.0-devel'
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Processing /data/DMFF
  Preparing metadata (setup.py) ... done
Building wheels for collected packages: dmff
  Building wheel for dmff (setup.py) ... done
  Created wheel for dmff: filename=dmff-0.2.1.dev222+g8efbe63-py3-none-any.whl size=93258 sha256=f20f32e539489412cad4c9a61a0e90b4c18f0c361fba252181ecc0ca6a337222
  Stored in directory: /tmp/pip-ephem-wheel-cache-cjaq0s7q/wheels/f

安装依赖库，比较耗时，需要稍等一会儿。

In [3]:
# Install OpenMM 7.7.0 and RDKit from conda-forge, then the remaining Python dependencies.
# NOTE(review): parmed/mdtraj/pymbar/networkx are unpinned — consider pinning versions for reproducibility.
! mamba install openmm=7.7.0 rdkit -c conda-forge -y
! pip install parmed mdtraj pymbar networkx


                  __    __    __    __
                 /  \  /  \  /  \  /  \
                /    \/    \/    \/    \
███████████████/  /██/  /██/  /██/  /████████████████████████
              /  / \   / \   / \   / \  \____
             /  /   \_/   \_/   \_/   \    o \__,
            / _/                       \_____/  `
            |/
        ███╗   ███╗ █████╗ ███╗   ███╗██████╗  █████╗
        ████╗ ████║██╔══██╗████╗ ████║██╔══██╗██╔══██╗
        ██╔████╔██║███████║██╔████╔██║██████╔╝███████║
        ██║╚██╔╝██║██╔══██║██║╚██╔╝██║██╔══██╗██╔══██║
        ██║ ╚═╝ ██║██║  ██║██║ ╚═╝ ██║██████╔╝██║  ██║
        ╚═╝     ╚═╝╚═╝  ╚═╝╚═╝     ╚═╝╚═════╝ ╚═╝  ╚═╝

        mamba (0.27.0) supported by @QuantStack

        GitHub:  https://github.com/mamba-org/mamba
        Twitter: https://twitter.com/QuantStack

█████████████████████████████████████████████████████████████


Looking for: ['openmm=7.7.0', 'rdkit']

[?25l[2K[0G[+] 0.0s
[2K[1A[2K[0G[+] 0.1s
conda-forge/linux-64 [9

将示例文件拷贝到当前工作目录

In [6]:
# Copy the example force-field XML and structure PDB into the current working directory.
! cp DMFF/tests/data/bond1.xml .
! cp DMFF/tests/data/bond1.pdb .

## 引入必要的库

In [1]:
from typing import Optional, Tuple
import numpy as np
import jax.numpy as jnp
import jax
from dmff.api.topology import DMFFTopology
from dmff.api.paramset import ParamSet
from dmff.api.xmlio import XMLIO
from dmff.api.hamiltonian import _DMFFGenerators
from dmff.classical.intra import HarmonicBondJaxForce
from dmff.utils import DMFFException, isinstance_jnp

2023-09-23 18:08:32.369984: W external/org_tensorflow/tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: /usr/lib/x86_64-linux-gnu/libcuda.so.1: file too short; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
2023-09-23 18:08:32.370046: W external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:263] failed call to cuInit: UNKNOWN ERROR (303)


## 创建Generator类

In [2]:
class HarmonicBondGenerator:
    """
    Generator for the ``HarmonicBondForce`` section of a DMFF force field.

    It reads the ``<Bond>`` entries of the parsed force-field dict, registers
    their parameters ("length" and "k") in a ``ParamSet``, builds the energy
    function for a given topology, and can write updated parameter values back
    into the force-field dict.

    Attributes:
    -----------
    name : str
        Name of the force handled by this generator ("HarmonicBondForce");
        also used as the field name inside the ParamSet.
    ffinfo : dict
        The parsed force-field information; ``overwrite`` modifies it in place.
    key_type : str or None
        Atom attribute used to match bonds, either "type" or "class".
        Stays None when the section contains no ``<Bond>`` entries.
    bond_keys : list of tuple of str
        The (atom1, atom2) type/class pair of every ``<Bond>`` entry, in file
        order. Index i corresponds to element i of the registered parameter arrays.
    _jaxPotential : function
        The potential function most recently built by ``createPotential``.
    """

    def __init__(self, ffinfo: dict, paramset: ParamSet):
        """
        Parse the ``<Bond>`` entries and register their parameters in ``paramset``.

        Parameters:
        -----------
        ffinfo : dict
            The parsed force-field information (e.g. from ``XMLIO.parseXML``).
        paramset : ParamSet
            The parameter set. A field named ``self.name`` holding the
            parameters "length" and "k" (with their masks) is added to it.

        Raises:
        -------
        ValueError
            If the first ``<Bond>`` entry has neither "type1" nor "class1", or if
            later entries do not use the same keyword ("type" vs "class") as the first.
        """
        self.name = "HarmonicBondForce"  # name of the potential this generator is bound to
        self.ffinfo = ffinfo  # force-field information this generator reads from / writes to
        # Register a field in the parameter set to hold this potential's parameters.
        paramset.addField(self.name)
        self.key_type = None

        # bond_keys[i], bond_params[i] and bond_mask[i] all describe the same <Bond> entry.
        bond_keys, bond_params, bond_mask = [], [], []
        for node in self.ffinfo["Forces"][self.name]["node"]:
            # Only <Bond> children carry parameters. This is the same filter overwrite() uses,
            # so parameter index i always maps back to the i-th <Bond> node.
            if node.get("name", "Bond") != "Bond":
                continue
            attribs = node["attrib"]

            # Bonds are matched either by atom "type" or by atom "class" (no mixing).
            # The first entry decides the keyword; every later entry must use the same one.
            if self.key_type is None:
                if "type1" in attribs:
                    self.key_type = "type"
                elif "class1" in attribs:
                    self.key_type = "class"
                else:
                    raise ValueError("Cannot find key type for HarmonicBondForce.")
            elif f"{self.key_type}1" not in attribs:
                raise ValueError("Keyword 'class' or 'type' cannot be used together.")
            key = (attribs[self.key_type + "1"], attribs[self.key_type + "2"])
            bond_keys.append(key)

            k = float(attribs["k"])
            r0 = float(attribs["length"])
            bond_params.append([k, r0])

            # When the node has mask="true", the parameter is not trainable:
            # a mask of 0.0 makes its gradient zero.
            mask = 1.0
            if "mask" in attribs and attribs["mask"].upper() == "TRUE":
                mask = 0.0
            bond_mask.append(mask)

        self.bond_keys = bond_keys
        bond_length = jnp.array([params[1] for params in bond_params])
        bond_k = jnp.array([params[0] for params in bond_params])
        bond_mask = jnp.array(bond_mask)

        # Register the parameters in the ParamSet. After initialisation they are accessed
        # through the ParamSet rather than through the generator, which keeps them independent
        # of it. Keeping the optimisable parameters separate from the function (not captured
        # in a closure) is what makes differentiating with respect to them correct.
        paramset.addParameter(bond_length, "length", field=self.name, mask=bond_mask)
        paramset.addParameter(bond_k, "k", field=self.name, mask=bond_mask)

    def getName(self) -> str:
        """
        Return the name of the force handled by this generator.

        Returns:
        --------
        str
            The force name ("HarmonicBondForce").
        """
        return self.name

    def overwrite(self, paramset: ParamSet) -> None:
        """
        Write the parameter values held in ``paramset`` back into ``self.ffinfo``.

        ``self.ffinfo`` is the dict obtained by parsing the XML force-field file, so
        once it is updated it can be rendered directly to a new XML file
        (e.g. with ``XMLIO.writeXML``). The attributes of every ``<Bond>`` node are
        replaced by its key pair, ``k`` and ``length``, plus ``mask="true"`` for
        masked (non-trainable) entries. Any other attributes on those nodes are dropped.
        The signature of this method is fixed by the generator interface.

        Parameters:
        -----------
        paramset : ParamSet
            The parameter set holding the (possibly updated) values.
        """
        all_nodes = self.ffinfo["Forces"][self.name]["node"]
        # Same filter as __init__, so bond_nodes[i] is the node that produced parameter i.
        bond_nodes = [node for node in all_nodes if node.get("name", "Bond") == "Bond"]

        bond_length = paramset[self.name]["length"]
        bond_k = paramset[self.name]["k"]
        # "length" and "k" were registered with the same mask, so reading one is enough.
        bond_msks = paramset.mask[self.name]["length"]
        for nnode, key in enumerate(self.bond_keys):
            new_attrib = {
                f"{self.key_type}1": key[0],
                f"{self.key_type}2": key[1],
                "k": str(bond_k[nnode]),
                "length": str(bond_length[nnode]),
            }
            # The mask is stored as a float; anything clearly below 1 means "not trainable".
            if bond_msks[nnode] < 0.999:
                new_attrib["mask"] = "true"
            bond_nodes[nnode]["attrib"] = new_attrib

    def _find_key_index(self, key: Tuple[str, str]) -> Optional[int]:
        """
        Find the ``<Bond>`` entry matching an atom type/class pair.

        The pair is matched in either order, because a bond (a, b) is the same as (b, a).

        Parameters:
        -----------
        key : tuple of str
            The (atom1, atom2) type or class pair.

        Returns:
        --------
        int or None
            Index of the first matching entry in ``self.bond_keys``, or None if none matches.
        """
        for i, (first, second) in enumerate(self.bond_keys):
            if (first, second) == (key[0], key[1]) or (first, second) == (key[1], key[0]):
                return i
        return None

    def createPotential(self, topdata: DMFFTopology, nonbondedMethod,
                        nonbondedCutoff, args):
        """
        Build the harmonic-bond energy function for ``topdata``.

        Each topology bond whose atom type/class pair matches a ``<Bond>`` entry
        (in either order) contributes to the energy; unmatched bonds are skipped.
        The generator therefore depends only on force-field parameters, while the
        topology-specific bookkeeping lives in the returned function. That function
        receives the parameters as an argument instead of capturing them, so it can
        be differentiated with respect to them. The signature is fixed by the
        generator interface.

        Parameters:
        -----------
        topdata : DMFFTopology
            The topology; its atoms must carry the matching "type"/"class" entry in ``atom.meta``.
        nonbondedMethod :
            Not used by this generator.
        nonbondedCutoff :
            Not used by this generator.
        args :
            Not used by this generator.

        Returns:
        --------
        function
            ``potential_fn(positions, box, pairs, params)`` returning the bond energy.
        """
        # Collect the atom indices of every matched bond and the index of its parameters.
        bond_a1, bond_a2, bond_indices = [], [], []
        # With no <Bond> entries, key_type is None and no bond can match.
        if self.key_type is not None:
            for bond in topdata.bonds():
                atom1, atom2 = bond.atom1, bond.atom2
                key = (atom1.meta[self.key_type], atom2.meta[self.key_type])
                idx = self._find_key_index(key)
                if idx is None:
                    continue
                bond_a1.append(atom1.index)
                bond_a2.append(atom2.index)
                bond_indices.append(idx)
        # Explicit integer dtype: these are index arrays, and an empty list would
        # otherwise become a float array.
        bond_a1 = jnp.array(bond_a1, dtype=int)
        bond_a2 = jnp.array(bond_a2, dtype=int)
        bond_indices = jnp.array(bond_indices, dtype=int)

        harmonic_bond_force = HarmonicBondJaxForce(bond_a1, bond_a2, bond_indices)
        harmonic_bond_energy = harmonic_bond_force.generate_get_energy()

        # Wrap the energy in the common potential-function form with four inputs:
        # positions, box, pairs and the parameter set.
        def potential_fn(positions: jnp.ndarray, box: jnp.ndarray, pairs: jnp.ndarray, params: ParamSet) -> jnp.ndarray:
            isinstance_jnp(positions, box, params)
            energy = harmonic_bond_energy(positions, box, pairs, params[self.name]["k"], params[self.name]["length"])
            return energy

        self._jaxPotential = potential_fn
        return potential_fn


## 注册Generator到DMFF，与XML文件中特定Force绑定

In [3]:
# Register the generator in DMFF's generator registry under the XML force tag it handles,
# binding <HarmonicBondForce> sections of the force-field file to HarmonicBondGenerator.
# NOTE(review): this replaces any generator previously registered under the same name.
_DMFFGenerators["HarmonicBondForce"] = HarmonicBondGenerator

## 测试

### OpenMM计算测试体系能量

In [4]:
import openmm as mm
import openmm.app as app
import openmm.unit as unit


# Reference energy: build an OpenMM System from the same PDB structure and XML force field,
# then evaluate the potential energy at the PDB coordinates.
pdb = app.PDBFile("bond1.pdb")
ff = app.ForceField("bond1.xml")
system = ff.createSystem(pdb.topology)
# An integrator is required to construct a Context; this cell never steps it,
# so the step size has no effect on the energy.
integ = mm.VerletIntegrator(1e-10)
context = mm.Context(system, integ)
context.setPositions(pdb.getPositions())
energy = context.getState(getEnergy=True).getPotentialEnergy()
print("OpenMM:", energy)

OpenMM: 1389.1622953572387 kJ/mol


### DMFF计算测试体系能量

In [20]:
from dmff.operators import TemplateATypeOperator

# System coordinates, converted to nanometers and then to a jax array.
pos = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)
pos = jnp.array(pos)

# System topology in DMFF format; it can be initialised directly from an OpenMM Topology object.
dmfftop = DMFFTopology(from_top=pdb.topology)

# Simulation box (presumably in nm, like the positions); it plays no role in this example.
box = np.eye(3) * 10.0
box = jnp.array(box)

# XML force-field reader/writer.
xmlio = XMLIO()
xmlio.loadXML("bond1.xml")
# Parse the XML file into a dict, named ffinfo.
ffinfo = xmlio.parseXML()

# Assign an atom type to every atom of the topology by graph-isomorphism matching against the
# residue templates in the force-field file; the result is stored in Atom.meta.
tempOP = TemplateATypeOperator(ffinfo)
top_atype = tempOP(dmfftop)
for atom in top_atype.atoms():
    print("Meta data:", atom.meta)
print()
    
# Initialise the ParamSet.
# ParamSet is a PyTree class. It resembles a dictionary, but its depth is restricted to two levels.
# The first level is the Field, grouped by potential name; in this example, HarmonicBondForce.
# The second level holds the parameters of each potential; in this example, length and k.
# Masks are also stored in the ParamSet; they are not demonstrated in this example.
paramset = ParamSet()
# Initialise the Generator; this registers its field and parameters in paramset.
generator = HarmonicBondGenerator(ffinfo, paramset)

# Inspect the ParamSet.
print(paramset.parameters)
print()

# init potential
# The result is expected to agree with the OpenMM reference energy up to float32 precision.
potential = generator.createPotential(top_atype, app.NoCutoff, 1.0, {})
energy = potential(pos, box, [], paramset)
print(f"DMFF: {energy} kJ/mol")

Meta data: {'element': 'N', 'external_bond': False, 'type': 'n1', 'class': 'n1'}
Meta data: {'element': 'N', 'external_bond': False, 'type': 'n2', 'class': 'n2'}

{'HarmonicBondForce': {'length': DeviceArray([0.09572], dtype=float32), 'k': DeviceArray([462750.4], dtype=float32)}}

DMFF: 1389.1622314453125 kJ/mol


### XML力场文件更新示例

In [27]:
print(">>>> Before updating <<<<")
! cat bond1.xml

# Change the first bond length in the ParamSet; jax arrays are immutable, hence .at[].set().
paramset["HarmonicBondForce"]["length"] = paramset["HarmonicBondForce"]["length"].at[0].set(0.1)

# Write the ParamSet values back into generator.ffinfo, which is the same dict object as ffinfo,
# then render that dict to a new XML force-field file.
generator.overwrite(paramset)
xmlio.writeXML("bond_update.xml", ffinfo)
print("\n\n>>>> After updating <<<<")
! cat bond_update.xml


>>>> Before updating <<<<
<ForceField>
    <AtomTypes>
        <Type element="N" name="n1" class="n1" mass="14.01"/>
        <Type element="N" name="n2" class="n2" mass="14.01"/>
    </AtomTypes>
    <Residues>
        <Residue name="LIG">
            <Atom name="N1" type="n1"/>
            <Atom name="N2" type="n2"/>
            <Bond atomName1="N1" atomName2="N2"/>
        </Residue>
    </Residues>
    <HarmonicBondForce>
        <Bond type1="n1" type2="n2" length="0.09572" k="462750.4"/>
    </HarmonicBondForce>
</ForceField>

>>>> After updating <<<<
<?xml version="1.0" ?>
<ForceField>
   <Operators/>
   <AtomTypes>
      <Type element="N" name="n1" class="n1" mass="14.01"/>
      <Type element="N" name="n2" class="n2" mass="14.01"/>
   </AtomTypes>
   <Residues>
      <Residue name="LIG">
         <Atom name="N1" type="n1"/>
         <Atom name="N2" type="n2"/>
         <Bond atomName1="N1" atomName2="N2"/>
      </Residue>
   </Residues>
   <HarmonicBondForce>
      <Bond type1=