From 35b5131e2f8b670fdddd052f1468c186aebe720e Mon Sep 17 00:00:00 2001 From: Yingze Wang Date: Sun, 5 Jun 2022 04:12:34 +0800 Subject: [PATCH 01/17] add(CI/CD): unittest workflows --- .github/workflows/ut.yml | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/.github/workflows/ut.yml b/.github/workflows/ut.yml index e69de29bb..e219fdcad 100644 --- a/.github/workflows/ut.yml +++ b/.github/workflows/ut.yml @@ -0,0 +1,34 @@ +name: DMFF's python tests. + +on: + push: + pull_request: + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [3.8] + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install Dependencies + run: | + source $CONDA/bin/activate + $CONDA/bin/conda update -n base -c defaults conda + conda install pip + conda update pip + conda install numpy openmm pytest -c conda-forge + pip install jax jax_md + - name: Install DMFF + run: | + source $CONDA/bin/activate + pip install . + - name: Run Tests + run: | + source $CONDA/bin/activate + pytest -vs tests/ From 4af913cdad7d713fba1065df552464d70daf38ef Mon Sep 17 00:00:00 2001 From: Yingze Wang Date: Sun, 5 Jun 2022 04:12:50 +0800 Subject: [PATCH 02/17] add(requirements): dependencies list --- requirements.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/requirements.txt b/requirements.txt index e69de29bb..7360452fb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -0,0 +1,3 @@ +numpy>=1.18 +jax>=0.3.7 +jax-md>=0.1.28 From f5eb2d56ea6706edc6761f5bcf34c76b51aa9c1f Mon Sep 17 00:00:00 2001 From: Roy Kid Date: Sun, 5 Jun 2022 11:05:04 +0800 Subject: [PATCH 03/17] add test_utils.py --- tests/test_utils.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 tests/test_utils.py diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 000000000..699a0ba2e --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,32 @@ +import jax.numpy as jnp +import pytest +from dmff.utils import regularize_pairs, pair_buffer_scales +import numpy.testing as npt + +class TestUtils: + + @pytest.fixture(scope='class', name='pairs') + def test_init_pair(self): + + pairs = jnp.array([ + [0, 0], + [1, 0], + [0, 1], + [1, 1] + ], dtype=int) + yield pairs + + def test_regularize_pairs(self, pairs): + + ans = regularize_pairs(pairs) + npt.assert_array_equal(ans, jnp.array([ + [-1, -2], + [0, -2], + [0, 1], + [0, -1] + ])) + + def test_pair_buffer_scales(self, pairs): + + ans = pair_buffer_scales(pairs) + npt.assert_array_equal(ans, jnp.array([0, 0, 1, 0])) \ No newline at end of file From fa27e1d4328cd3a936f5149147be486e4957d535 Mon Sep 17 00:00:00 2001 From: Roy Kid Date: Sun, 5 Jun 2022 11:10:49 +0800 Subject: [PATCH 04/17] fix: modified unit test results --- tests/test_admp/test_sptial.py | 4 ++-- tests/test_common/test_nblist.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_admp/test_sptial.py b/tests/test_admp/test_sptial.py index 055f179ae..f0a3e93e9 100644 --- a/tests/test_admp/test_sptial.py +++ b/tests/test_admp/test_sptial.py @@ -38,7 +38,7 @@ class TestSpatial: def test_build_quasi_internal(self, r1, r2, dr, norm_dr, expected): local_frames = build_quasi_internal(r1, r2, dr, norm_dr) - npt.assert_allclose(local_frames, expected) + npt.assert_allclose(local_frames, expected, rtol=1e-5) @pytest.mark.parametrize( "drvecs, box, box_inv, expected", @@ -138,6 +138,6 @@ def test_generate_construct_local_frames( ) assert construct_local_frame_fn npt.assert_allclose( - construct_local_frame_fn(positions, box), expected_local_frames, rtol=1e-6 + construct_local_frame_fn(positions, box), expected_local_frames, rtol=1e-5 ) diff --git a/tests/test_common/test_nblist.py b/tests/test_common/test_nblist.py index a8a04724b..3da3eac86 100644 --- a/tests/test_common/test_nblist.py +++ b/tests/test_common/test_nblist.py @@ -33,18 +33,18 @@ def test_update(self, nblist): def test_pairs(self, nblist): pairs = nblist.pairs - assert pairs.shape == (15, 2) + assert pairs.shape == (18, 2) def test_pair_mask(self, nblist): pair, mask = nblist.pair_mask - assert mask.shape == (15, ) + assert mask.shape == (18, ) def test_dr(self, nblist): dr = nblist.dr - assert dr.shape == (15, 3) + assert dr.shape == (18, 3) def test_distance(self, nblist): - assert nblist.distance.shape == (15, ) + assert nblist.distance.shape == (18, ) From ba16e3bf99e64ea8a5f75d5a23ded480188fc626 Mon Sep 17 00:00:00 2001 From: Roy Kid Date: Sun, 5 Jun 2022 12:30:11 +0800 Subject: [PATCH 05/17] add `r` to avoid latex being recognized as an escape character --- dmff/admp/pairwise.py | 2 +- dmff/admp/pme.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dmff/admp/pairwise.py b/dmff/admp/pairwise.py index d8b7b80d2..1cb396a4c 100755 --- a/dmff/admp/pairwise.py +++ b/dmff/admp/pairwise.py @@ -130,7 +130,7 @@ def TT_damping_qq_kernel(dr, m, bi, bj, qi, qj): @vmap @jit_condition(static_argnums=()) def slater_disp_damping_kernel(dr, m, bi, bj, c6i, c6j, c8i, c8j, c10i, c10j): - ''' + r''' Slater-ISA type damping for dispersion: f(x) = -e^{-x} * \sum_{k} x^k/k! x = Br - \frac{2*(Br)^2 + 3Br}{(Br)^2 + 3*Br + 3} diff --git a/dmff/admp/pme.py b/dmff/admp/pme.py index 2cde47a1c..06ab88c16 100755 --- a/dmff/admp/pme.py +++ b/dmff/admp/pme.py @@ -332,7 +332,7 @@ def energy_pme(positions, box, pairs, @jit_condition(static_argnums=(3)) def calc_e_perm(dr, mscales, kappa, lmax=2): - ''' + r''' This function calculates the ePermCoefs at once ePermCoefs is basically the interaction tensor between permanent multipole components Everything should be done in the so called quasi-internal (qi) frame @@ -453,7 +453,7 @@ def trim_val_infty(x): @jit_condition(static_argnums=(7)) def calc_e_ind(dr, thole1, thole2, dmp, pscales, dscales, kappa, lmax=2): - ''' + r''' This function calculates the eUindCoefs at once ## compute the Thole damping factors for energies eUindCoefs is basically the interaction tensor between permanent multipole components and induced dipoles From 57b65d498732e62b945965bc09be9a6264a4eeca Mon Sep 17 00:00:00 2001 From: Yingze Wang Date: Sun, 5 Jun 2022 16:24:17 +0800 Subject: [PATCH 06/17] fix(ut): wrong number in test_nblist --- tests/test_common/test_nblist.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_common/test_nblist.py b/tests/test_common/test_nblist.py index 3da3eac86..a8a04724b 100644 --- a/tests/test_common/test_nblist.py +++ b/tests/test_common/test_nblist.py @@ -33,18 +33,18 @@ def test_update(self, nblist): def test_pairs(self, nblist): pairs = nblist.pairs - assert pairs.shape == (18, 2) + assert pairs.shape == (15, 2) def test_pair_mask(self, nblist): pair, mask = nblist.pair_mask - assert mask.shape == (18, ) + assert mask.shape == (15, ) def test_dr(self, nblist): dr = nblist.dr - assert dr.shape == (18, 3) + assert dr.shape == (15, 3) def test_distance(self, nblist): - assert nblist.distance.shape == (18, ) + assert nblist.distance.shape == (15, ) From 2b04192e37b3bb7ade94dc8309028fb079364341 Mon Sep 17 00:00:00 2001 From: Yingze Wang Date: Sun, 5 Jun 2022 16:35:48 +0800 Subject: [PATCH 07/17] refine(ut): code prettify in test_nblist --- tests/test_common/test_nblist.py | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/tests/test_common/test_nblist.py b/tests/test_common/test_nblist.py index a8a04724b..dd00f6d5b 100644 --- a/tests/test_common/test_nblist.py +++ b/tests/test_common/test_nblist.py @@ -6,13 +6,14 @@ class TestNeighborList: @pytest.fixture(scope="class", name='nblist') def test_nblist_init(self): - positions = jnp.array([[12.434, 3.404, 1.540], - [13.030, 2.664, 1.322], - [12.312, 3.814, 0.660], - [14.216, 1.424, 1.103], - [14.246, 1.144, 2.054], - [15.155, 1.542, 0.910]]) - + positions = jnp.array([ + [12.434, 3.404, 1.540], + [13.030, 2.664, 1.322], + [12.312, 3.814, 0.660], + [14.216, 1.424, 1.103], + [14.246, 1.144, 2.054], + [15.155, 1.542, 0.910] + ]) box = jnp.array([31.289, 31.289, 31.289]) r_cutoff = 4.0 nbobj = NeighborList(box, r_cutoff) @@ -21,13 +22,14 @@ def test_nblist_init(self): def test_update(self, nblist): - positions = jnp.array([[12.434, 3.404, 1.540], - [13.030, 2.664, 1.322], - [12.312, 3.814, 0.660], - [14.216, 1.424, 1.103], - [14.246, 1.144, 2.054], - [15.155, 1.542, 0.910]]) - + positions = jnp.array([ + [12.434, 3.404, 1.540], + [13.030, 2.664, 1.322], + [12.312, 3.814, 0.660], + [14.216, 1.424, 1.103], + [14.246, 1.144, 2.054], + [15.155, 1.542, 0.910] + ]) nblist.update(positions) def test_pairs(self, nblist): From 234d1dd45998abe29c4ffc507743a1a502cd72b1 Mon Sep 17 00:00:00 2001 From: Roy Kid Date: Sun, 5 Jun 2022 21:43:22 +0800 Subject: [PATCH 08/17] Chore: clean admp module up --- dmff/admp/mbpol_intra.py | 44 ---- dmff/admp/multipole.py | 5 +- dmff/admp/pairwise.py | 103 --------- dmff/admp/parser.py | 477 --------------------------------------- dmff/admp/pme.py | 83 ------- dmff/admp/recip.py | 90 -------- 6 files changed, 2 insertions(+), 800 deletions(-) delete mode 100644 dmff/admp/parser.py diff --git a/dmff/admp/mbpol_intra.py b/dmff/admp/mbpol_intra.py index 7b1ffff56..dc69f32e3 100755 --- a/dmff/admp/mbpol_intra.py +++ b/dmff/admp/mbpol_intra.py @@ -6,12 +6,7 @@ from dmff.settings import DO_JIT from dmff.utils import jit_condition from dmff.admp.spatial import v_pbc_shift -from dmff.admp.pme import ADMPPmeForce -from dmff.admp.parser import * from jax import vmap -import time -#from admp.multipole import convert_cart2harm -#from jax_md import partition, space #const f5z = 0.999677885 @@ -488,42 +483,3 @@ def onebody_kernel(x1, x2, x3, Va, Vb, efac): e1 *= cm1_kcalmol e1 *= cal2joule # conver cal 2 j return e1 - - -def validation(pdb): - xml = 'mpidwater.xml' - pdbinfo = read_pdb(pdb) - serials = pdbinfo['serials'] - names = pdbinfo['names'] - resNames = pdbinfo['resNames'] - resSeqs = pdbinfo['resSeqs'] - positions = pdbinfo['positions'] - box = pdbinfo['box'] # a, b, c, α, β, γ - charges = pdbinfo['charges'] - positions = jnp.asarray(positions) - lx, ly, lz, _, _, _ = box - box = jnp.eye(3)*jnp.array([lx, ly, lz]) - - mScales = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0]) - pScales = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0]) - dScales = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0]) - - rc = 4 # in Angstrom - ethresh = 1e-4 - - n_atoms = len(serials) - - # compute intra - - - grad_E1 = value_and_grad(onebodyenergy,argnums=(0)) - ene, force = grad_E1(positions, box) - print(ene,force) - return - - -# below is the validation code -if __name__ == '__main__': - validation(sys.argv[1]) - - diff --git a/dmff/admp/multipole.py b/dmff/admp/multipole.py index 7f957a823..64085e83d 100644 --- a/dmff/admp/multipole.py +++ b/dmff/admp/multipole.py @@ -1,4 +1,3 @@ -import sys import jax.numpy as jnp from jax import vmap from dmff.utils import jit_condition @@ -48,7 +47,7 @@ def convert_cart2harm(Theta, lmax): n * (l+1)^2, stores the spherical multipoles ''' if lmax > 2: - sys.exit('l > 2 (beyond quadrupole) not supported') + raise ValueError('l > 2 (beyond quadrupole) not supported') Q_mono = Theta[0:1] @@ -90,7 +89,7 @@ def convert_harm2cart(Q, lmax): ''' if lmax > 2: - sys.exit('l > 2 (beyond quadrupole) not supported') + raise ValueError('l > 2 (beyond quadrupole) not supported') T_mono = Q[0:1] diff --git a/dmff/admp/pairwise.py b/dmff/admp/pairwise.py index 1cb396a4c..8a3431a6f 100755 --- a/dmff/admp/pairwise.py +++ b/dmff/admp/pairwise.py @@ -167,106 +167,3 @@ def slater_sr_kernel(dr, m, ai, aj, bi, bj): P = 1/3 * br2 + br + 1 return a * P * jnp.exp(-br) * m - -def validation(pdb): - xml = 'mpidwater.xml' - pdbinfo = read_pdb(pdb) - serials = pdbinfo['serials'] - names = pdbinfo['names'] - resNames = pdbinfo['resNames'] - resSeqs = pdbinfo['resSeqs'] - positions = pdbinfo['positions'] - box = pdbinfo['box'] # a, b, c, α, β, γ - charges = pdbinfo['charges'] - positions = jnp.asarray(positions) - lx, ly, lz, _, _, _ = box - box = jnp.eye(3)*jnp.array([lx, ly, lz]) - - mScales = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0]) - pScales = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0]) - dScales = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0]) - - rc = 4 # in Angstrom - ethresh = 1e-4 - - n_atoms = len(serials) - - atomTemplate, residueTemplate = read_xml(xml) - atomDicts, residueDicts = init_residues(serials, names, resNames, resSeqs, positions, charges, atomTemplate, residueTemplate) - - covalent_map = assemble_covalent(residueDicts, n_atoms) - displacement_fn, shift_fn = space.periodic_general(box, fractional_coordinates=False) - neighbor_list_fn = partition.neighbor_list(displacement_fn, box, rc, 0, format=partition.OrderedSparse) - nbr = neighbor_list_fn.allocate(positions) - pairs = nbr.idx.T - - pmax = 10 - kappa, K1, K2, K3 = setup_ewald_parameters(rc, ethresh, box) - kappa = 0.657065221219616 - - # construct the C list - c_list = np.zeros((3, n_atoms)) - a_list = np.zeros(n_atoms) - q_list = np.zeros(n_atoms) - b_list = np.zeros(n_atoms) - nmol=int(n_atoms/3) - for i in range(nmol): - a = i*3 - b = i*3+1 - c = i*3+2 - # dispersion coeff - c_list[0][a]=37.199677405 - c_list[0][b]=7.6111103 - c_list[0][c]=7.6111103 - c_list[1][a]=85.26810658 - c_list[1][b]=11.90220148 - c_list[1][c]=11.90220148 - c_list[2][a]=134.44874488 - c_list[2][b]=15.05074749 - c_list[2][c]=15.05074749 - # q - q_list[a] = -0.741706 - q_list[b] = 0.370853 - q_list[c] = 0.370853 - # b, Bohr^-1 - b_list[a] = 2.00095977 - b_list[b] = 1.999519942 - b_list[c] = 1.999519942 - # a, Hartree - a_list[a] = 458.3777 - a_list[b] = 0.0317 - a_list[c] = 0.0317 - - - c_list = jnp.array(c_list) - -# @partial(vmap, in_axes=(0, 0, 0, 0), out_axes=(0)) -# @jit_condition(static_argnums=()) -# def disp6_pme_real_kernel(dr, m, ci, cj): -# # unpack static arguments -# kappa = static_args['kappa'] -# # calculate distance -# dr2 = dr ** 2 -# dr6 = dr2 ** 3 -# # do calculation -# x2 = kappa**2 * dr2 -# exp_x2 = jnp.exp(-x2) -# x4 = x2 * x2 -# g = (1 + x2 + 0.5*x4) * exp_x2 -# return (m + g - 1) * ci * cj / dr6 - -# static_args = {'kappa': kappa} -# disp6_pme_real = generate_pairwise_interaction(disp6_pme_real_kernel, covalent_map, static_args) -# print(disp6_pme_real(positions, box, pairs, mScales, c_list[0, :])) - - TT_damping_qq_c6 = generate_pairwise_interaction(TT_damping_qq_c6_kernel, covalent_map, static_args={}) - - TT_damping_qq_c6(positions, box, pairs, mScales, a_list, b_list, q_list, c_list[0]) - print('ok') - print(TT_damping_qq_c6(positions, box, pairs, mScales, a_list, b_list, q_list, c_list[0])) - return - - -# below is the validation code -if __name__ == '__main__': - validation(sys.argv[1]) diff --git a/dmff/admp/parser.py b/dmff/admp/parser.py deleted file mode 100644 index 9ba07239b..000000000 --- a/dmff/admp/parser.py +++ /dev/null @@ -1,477 +0,0 @@ -# jichen: deprecated -from xml.dom import minidom -import numpy as np -import warnings -from collections import defaultdict - -def read_atom_line(line_full): - """ - Read atom line from pdb format - HETATM 1 H14 ORTE 0 6.301 0.693 1.919 1.00 0.00 H - - 1-6 7-11 13-16 17 18-20 22 23-26 27 28-30 31-38 39-46 47-54 55-60 61-66 67-72 73-76 77-78 79-80 - ATOM serial name altLoc resName chainID resSeq iCode _ x y z occupancy tempFactor _ segID element charge - """ - - line = line_full.rstrip("\n") - type_atm = line[0:6] - if type_atm == "ATOM " or type_atm == "HETATM": - - # Roy - serial = line[7:12].strip() - - name = line[12:16].strip() - - altLoc = line[16] - resName = line[17:21] - chainID = line[21] # Not used - - resSeq = int(line[22:26].split()[0]) # sequence identifier - iCode = line[26] # insertion code, not used - - # atomic coordinates - try: - coord = np.array( - [float(line[30:38]), float(line[38:46]), float(line[46:54])], - dtype=np.float64, - ) - except ValueError: - raise ValueError("Invalid or missing coordinate(s)") - - # occupancy & B factor - try: - occupancy = float(line[54:60]) - except ValueError: - occupancy = None # Rather than arbitrary zero or one - - if occupancy is not None and occupancy < 0: - warnings.warn("Negative occupancy in one or more atoms") - - try: - bfactor = float(line[60:66]) - except ValueError: - # The PDB use a default of zero if the data is missing - bfactor = 0.0 - - segid = line[72:76] # not used - element = line[76:78].strip().upper() - charge = line[79:81] - - else: - raise ValueError("Only ATOM and HETATM supported") - - return ( - type_atm, - serial, - name, - altLoc, - resName.strip(), - chainID, - resSeq, - iCode, - coord, - occupancy, - bfactor, - segid, - element, - charge, - ) - -def read_pdb(file): - """Read PDB files.""" - fileobj = open(file, 'r') - orig = np.identity(3) - trans = np.zeros(3) - serials = [] - names = [] - altLocs = [] - resNames = [] - chainIDs = [] - resSeqs = [] - iCodes = [] - positions = [] - occupancies = [] - tempFactors = [] - segId = [] - elements = [] - charges = [] - cell = None - pbc = None - cellpar = [] - conects = {} - # make sure that only one frame is read - continue_read_atoms_flag = True - # serial starts at 1 and we need to discard it and just keep align with positions - id = 0 - - for line in fileobj.readlines(): - if line.startswith('CRYST1'): - cellpar = [float(line[6:15]), # a - float(line[15:24]), # b - float(line[24:33]), # c - float(line[33:40]), # alpha - float(line[40:47]), # beta - float(line[47:54])] # gamma - - for c in range(3): - if line.startswith('ORIGX' + '123'[c]): - orig[c] = [float(line[10:20]), - float(line[20:30]), - float(line[30:40])] - trans[c] = float(line[45:55]) - - if ( - line.startswith("ATOM") - or line.startswith("HETATM") - and continue_read_atoms_flag - ): - # Atom name is arbitrary and does not necessarily - # contain the element symbol. The specification - # requires the element symbol to be in columns 77+78. - # Fall back to Atom name for files that do not follow - # the spec, e.g. packmol. - - # line_info = type_atm, serial, name, altLoc, resName, chainID, resSeq, iCode, coord, occupancy, tempFactor, segid, element, charge - line_info = read_atom_line(line) - - # serials.append(int(line_info[1])) - serials.append(id) - id += 1 - names.append(line_info[2]) - resNames.append(line_info[4]) - resSeqs.append(line_info[6]) - position = np.dot(orig, line_info[8]) + trans - positions.append(position) - if line_info[9] is not None: - occupancies.append(line_info[9]) - tempFactors.append(line_info[10]) - elements.append(line_info[-2]) - charges.append(line_info[-1] or 0) - - if line.startswith("END"): - # End of configuration reached - # According to the latest PDB file format (v3.30), - # this line should start with 'ENDMDL' (not 'END'), - # but in this way PDB trajectories from e.g. CP2K - # are supported (also VMD supports this format). - continue_read_atoms_flag = False - pass - - if line.startswith("CONECT"): - l = line.split() - center_atom_idx = int(l[1]) - bonded_atom_idx = [int(i) for i in l[2:]] - - conects[center_atom_idx] = bonded_atom_idx - fileobj.close() - - return {'serials': serials, - 'names': names, - 'resNames': resNames, - 'resSeqs': resSeqs, - 'positions': np.vstack(positions), - 'charges': charges, - 'connects': conects, - 'box': cellpar} - -def set_axis_type(atoms): - ZThenX = 0 - Bisector = 1 - ZBisect = 2 - ThreeFold = 3 - Zonly = 4 - NoAxisType = 5 - LastAxisTypeIndex = 6 - kStrings = ['kz', 'kx', 'ky'] - - for atom in atoms: - atomType = atom['type'] - kIndices = [atomType] - - for kString in kStrings: - - if kString in atom and atom[kString] != '': - kIndices.append(atom[kString]) - atom['axis_indices'] = kIndices - - # set axis type - - kIndicesLen = len(kIndices) - - if (kIndicesLen > 3): - ky = kIndices[3] - kyNegative = False - if ky.startswith('-'): - ky = kIndices[3] = ky[1:] - kyNegative = True - else: - ky = "" - - if (kIndicesLen > 2): - kx = kIndices[2] - kxNegative = False - if kx.startswith('-'): - kx = kIndices[2] = kx[1:] - kxNegative = True - else: - kx = "" - - if (kIndicesLen > 1): - kz = kIndices[1] - kzNegative = False - if kz.startswith('-'): - kz = kIndices[1] = kz[1:] - kzNegative = True - else: - kz = "" - - while(len(kIndices) < 4): - kIndices.append("") - - axisType = ZThenX - if (not kz): - axisType = NoAxisType - if (kz and not kx): - axisType = Zonly - if (kz and kzNegative or kx and kxNegative): - axisType = Bisector - if (kx and kxNegative and ky and kyNegative): - axisType = ZBisect - if (kz and kzNegative and kx and kxNegative and ky and kyNegative): - axisType = ThreeFold - - atom['axisType'] = axisType - - return atoms - -def read_xml(fileobj): - - fileobj = minidom.parse(fileobj) - - multipoles = fileobj.getElementsByTagName("Multipole") - - residueTemplates = [] - atomTemplates = [] - - for r in fileobj.getElementsByTagName('Residue'): - - resName = r.getAttribute("name") - residueTemplate = {'resName': resName, 'atoms': [], } - - - for a in r.getElementsByTagName('Atom'): - atomName = a.getAttribute('name') - atomType = a.getAttribute('type') - atomTemplate = {'name': atomName, 'type': atomType} - - residueTemplate['atoms'].append(atomTemplate) - atomTemplates.append(atomTemplate) - - topo = defaultdict(list) - for b in r.getElementsByTagName('Bond'): - - from_ = b.getAttribute('from') - to_ = b.getAttribute('to') - topo[from_].append(to_) - # topo[to_].append(from_) - - residueTemplate['topo'] = dict(topo) - residueTemplates.append(residueTemplate) - - for i, multipole in enumerate(multipoles): - - multiDict = { - "c0": float(multipole.getAttribute("c0")), - "dX": float(multipole.getAttribute("dX")), - "dY": float(multipole.getAttribute("dY")), - "dZ": float(multipole.getAttribute("dZ")), - "qXX": float(multipole.getAttribute("qXX")), - "qXY": float(multipole.getAttribute("qXY")), - "qYY": float(multipole.getAttribute("qYY")), - "qXZ": float(multipole.getAttribute("qXZ")), - "qYZ": float(multipole.getAttribute("qYZ")), - "qZZ": float(multipole.getAttribute("qZZ")), - "oXXX": float(multipole.getAttribute("oXXX")), - "oXXY": float(multipole.getAttribute("oXXY")), - "oXYY": float(multipole.getAttribute("oXYY")), - "oYYY": float(multipole.getAttribute("oYYY")), - "oXXZ": float(multipole.getAttribute("oXXZ")), - "oXYZ": float(multipole.getAttribute("oXYZ")), - "oYYZ": float(multipole.getAttribute("oYYZ")), - "oXZZ": float(multipole.getAttribute("oXZZ")), - "oYZZ": float(multipole.getAttribute("oYZZ")), - "oZZZ": float(multipole.getAttribute("oZZZ")), - "kx": multipole.getAttribute("kx"), - "kz": multipole.getAttribute("kz"), - "ky": multipole.getAttribute("ky") - } - - for template in atomTemplates: - if template['type'] == multipole.getAttribute("type"): - template.update(multiDict) - - - for p in fileobj.getElementsByTagName('Polarize'): - - pxx = p.getAttribute('polarizabilityXX') - pyy = p.getAttribute('polarizabilityYY') - pzz = p.getAttribute('polarizabilityZZ') - thole = p.getAttribute('thole') - polarDict = {'polarizabilityXX': pxx, 'polarizabilityYY': pyy, 'polarizabilityZZ':pzz, 'thole': thole} - - for template in atomTemplates: - if template['type'] == p.getAttribute('type'): - template.update(polarDict) - - set_axis_type(atomTemplates) - - return atomTemplates, residueTemplates - -class Atom: - - def __init__(self, serial, name, resName, resSeq, position, charge, ) -> None: - self.serial = serial - self.name = name - self.position = position - self.charge = charge - self.resName = resName - self.charge = charge - self.linkAtom = [] - self.resSeq = resSeq - - def __eq__(self, o): - return o.serial == self.serial - - def link(self, atom): - if atom not in self.linkAtom: - self.linkAtom.append(atom) - if self not in atom.linkAtom: - atom.linkAtom.append(self) - - def __repr__(self) -> str: - return f'< Atom{self.serial}: {self.name} >' - -class Residue: - - def __init__(self, resName, resSeq) -> None: - self.resName = resName - self.resSeq = resSeq - self.atoms = {} - self.topo = [] - self.covalent_map = {} - - def add(self, serial, atom): - self.atoms[serial] = atom - - def __next__(self): - return next(self.atoms) - - def __getitem__(self, name): - for atom in self.atoms.values(): - if atom.name == name: - return atom - - def __repr__(self) -> str: - return f'< Residue{self.resSeq}: {self.resName} >' - -def init_residues(serials, names, resNames, resSeqs, positions, charges, atomTemplates, residueTemplates): - - residueDicts = {} - atomDicts = {} - - for name, seq in zip(resNames, resSeqs): - if seq not in residueDicts: - residueDicts[seq] = Residue(name, seq) - - - # build up residue - for serial, name, resName, resSeq, position, charge in zip(serials, names, resNames, resSeqs, positions, charges): - - atom = Atom(serial, name, resName, resSeq, position, charge) - - for a in atomTemplates: - if name == a['name']: - for k, v in a.items(): - setattr(atom, k, v) - atomDicts[serial] = atom - - residueDicts[resSeq].add(atom.serial, atom) - - - # build up topo - for residue in residueDicts.values(): - - for residueTemplate in residueTemplates: - if residueTemplate['resName'] == residue.resName: - template = residueTemplate - - for c, p in template['topo'].items(): - ctemp = template['atoms'][int(c)] - catom = residue[ctemp['name']] - - for pp in p: - ptemp = template['atoms'][int(pp)] - patom = residue[ptemp['name']] - catom.link(patom) - - # build up axis indices - for residue in residueDicts.values(): - - for atom in residue.atoms.values(): - indices = [index if index != '' else -1 for index in atom.axis_indices[1: ]] - - for patom in residue.atoms.values(): - if patom.serial == atom.serial: - continue - for i in range(len(indices)): - if indices[i] == patom.type: - indices[i] = patom.serial - break - - atom.axis_indices = indices - - - # build up covalent map in residue - for i in residue.atoms.values(): - visited = [i.serial] - residue.covalent_map[i.serial] = {} - for j in i.linkAtom: - residue.covalent_map[i.serial][j.serial] = 1 - visited.append(j.serial) - for k in j.linkAtom: - if k.serial not in visited: - residue.covalent_map[i.serial][k.serial] = 2 - visited.append(k.serial) - else: - continue - for l in k.linkAtom: - if l.serial not in visited: - residue.covalent_map[i.serial][l.serial] = 3 - visited.append(l.serial) - else: - continue - for m in l.linkAtom: - if m.serial not in visited: - residue.covalent_map[i.serial][m.serial] = 4 - visited.append(m.serial) - else: - continue - - return atomDicts, residueDicts - -def assemble_covalent(residueDicts, natoms): - - covalents = [c.covalent_map for c in residueDicts.values()] - - covalent_map = np.zeros((natoms, natoms), dtype=int) - - for covalent in covalents: - - for c, p in covalent.items(): - - for pp, dr in p.items(): - - covalent_map[c][pp] = dr - - return covalent_map - \ No newline at end of file diff --git a/dmff/admp/pme.py b/dmff/admp/pme.py index 06ab88c16..23b3a7d6c 100755 --- a/dmff/admp/pme.py +++ b/dmff/admp/pme.py @@ -18,13 +18,6 @@ from dmff.admp.recip import generate_pme_recip, Ck_1 -# for debugging use only -# from jax_md import partition, space -# from admp.parser import * - -# from jax.config import config -# config.update("jax_enable_x64", True) - # Functions that are related to electrostatic pme class ADMPPmeForce: @@ -853,79 +846,3 @@ def pol_penalty(U_ind, pol): pol_pi = trim_val_0(pol) # pol_pi = pol/(jnp.exp((-pol+1e-08)*1e10)+1) + 1e-08/(jnp.exp((pol-1e-08)*1e10)+1) return jnp.sum(0.5/pol_pi*(U_ind**2).T) * DIELECTRIC - - -def validation(pdb): - xml = 'mpidwater.xml' - pdbinfo = read_pdb(pdb) - serials = pdbinfo['serials'] - names = pdbinfo['names'] - resNames = pdbinfo['resNames'] - resSeqs = pdbinfo['resSeqs'] - positions = pdbinfo['positions'] - box = pdbinfo['box'] # a, b, c, α, β, γ - charges = pdbinfo['charges'] - positions = jnp.asarray(positions) - lx, ly, lz, _, _, _ = box - box = jnp.eye(3)*jnp.array([lx, ly, lz]) - - mScales = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0]) - pScales = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0]) - dScales = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0]) - - rc = 4 # in Angstrom - ethresh = 1e-4 - - n_atoms = len(serials) - - atomTemplate, residueTemplate = read_xml(xml) - atomDicts, residueDicts = init_residues(serials, names, resNames, resSeqs, positions, charges, atomTemplate, residueTemplate) - - Q = np.vstack( - [(atom.c0, atom.dX*10, atom.dY*10, atom.dZ*10, atom.qXX*300, atom.qYY*300, atom.qZZ*300, atom.qXY*300, atom.qXZ*300, atom.qYZ*300) for atom in atomDicts.values()] - ) - Q = jnp.array(Q) - Q_local = convert_cart2harm(Q, 2) - axis_type = np.array( - [atom.axisType for atom in atomDicts.values()] - ) - axis_indices = np.vstack( - [atom.axis_indices for atom in atomDicts.values()] - ) - covalent_map = assemble_covalent(residueDicts, n_atoms) - - - displacement_fn, shift_fn = space.periodic_general(box, fractional_coordinates=False) - neighbor_list_fn = partition.neighbor_list(displacement_fn, box, rc, 0, format=partition.OrderedSparse) - nbr = neighbor_list_fn.allocate(positions) - pairs = nbr.idx.T - # pairs = pairs[pairs[:, 0] < pairs[:, 1]] - - lmax = 2 - - - # Finish data preparation - # ------------------------------------------------------------------------------------- - # kappa, K1, K2, K3 = setup_ewald_parameters(rc, ethresh, box) - # # for debugging - # kappa = 0.657065221219616 - # construct_local_frames_fn = generate_construct_local_frames(axis_type, axis_indices) - # energy_force_pme = value_and_grad(energy_pme) - # e, f = energy_force_pme(positions, box, pairs, Q_local, mScales, pScales, dScales, covalent_map, construct_local_frames_fn, kappa, K1, K2, K3, lmax) - # print('ok') - # e, f = energy_force_pme(positions, box, pairs, Q_local, mScales, pScales, dScales, covalent_map, construct_local_frames_fn, kappa, K1, K2, K3, lmax) - # print(e) - - pme_force = ADMPPmeForce(box, axis_type, axis_indices, covalent_map, rc, ethresh, lmax) - pme_force.update_env('kappa', 0.657065221219616) - - E, F = pme_force.get_forces(positions, box, pairs, Q_local, mScales) - print('ok') - E, F = pme_force.get_forces(positions, box, pairs, Q_local, mScales) - print(E) - return - - -# below is the validation code -if __name__ == '__main__': - validation(sys.argv[1]) diff --git a/dmff/admp/recip.py b/dmff/admp/recip.py index d0932586f..ed18eff0e 100755 --- a/dmff/admp/recip.py +++ b/dmff/admp/recip.py @@ -458,93 +458,3 @@ def Ck_10(ksq, kappa, V): exp_x2 = jnp.exp(-x2) f = (15 - 6*x2 + 4*x4 - 8*x6)*exp_x2 + 8*x7*sqrt_pi*jsp.special.erfc(x) return sqrt_pi*jnp.pi/2/V*kappa**7 * f / 1260 - - -# def validation(pdb): -# jnp.set_printoptions(precision=32, suppress=True) -# xml = 'mpidwater.xml' -# pdbinfo = read_pdb(pdb) -# serials = pdbinfo['serials'] -# names = pdbinfo['names'] -# resNames = pdbinfo['resNames'] -# resSeqs = pdbinfo['resSeqs'] -# positions = pdbinfo['positions'] -# box = pdbinfo['box'] # a, b, c, α, β, γ -# charges = pdbinfo['charges'] -# positions = jnp.asarray(positions) -# lx, ly, lz, _, _, _ = box -# box = jnp.eye(3)*jnp.array([lx, ly, lz]) - -# rc = 4 # in Angstrom -# ethresh = 1e-4 - -# n_atoms = len(serials) - -# atomTemplate, residueTemplate = read_xml(xml) -# atomDicts, residueDicts = init_residues(serials, names, resNames, resSeqs, positions, charges, atomTemplate, residueTemplate) - -# Q = np.vstack( -# [(atom.c0, atom.dX*10, atom.dY*10, atom.dZ*10, atom.qXX*300, atom.qYY*300, atom.qZZ*300, atom.qXY*300, atom.qXZ*300, atom.qYZ*300) for atom in atomDicts.values()] -# ) -# Q = jnp.array(Q) -# Q_local = convert_cart2harm(Q, 2) -# axis_type = np.array( -# [atom.axisType for atom in atomDicts.values()] -# ) -# axis_indices = np.vstack( -# [atom.axis_indices for atom in atomDicts.values()] -# ) -# covalent_map = assemble_covalent(residueDicts, n_atoms) - - -# # invoke pme energy calculator -# # energy_pme(positions, box, Q_local, axis_type, axis_indices, nbr.idx, 2) -# lmax = 2 -# kappa, K1, K2, K3 = setup_ewald_parameters(rc, ethresh, box) -# # for debugging -# kappa = 0.657065221219616 -# construct_local_frames_fn = generate_construct_local_frames(axis_type, axis_indices) -# local_frames = construct_local_frames_fn(positions, box) -# Q_global = rot_local2global(Q_local, local_frames, lmax) - -# pme_order = 6 -# energy_force_pme_recip = value_and_grad(generate_pme_recip(Ck_1, kappa, False, pme_order, K1, K2, K3, lmax)) -# energy_force_pme_recip(positions, box, Q_global) -# print('ok') -# E, F = energy_force_pme_recip(positions, box, Q_global) -# print(E) - -# # construct the C list -# c_list = np.zeros((3,n_atoms)) -# nmol=int(n_atoms/3) -# for i in range(nmol): -# a = i*3 -# b = i*3+1 -# c = i*3+2 -# c_list[0][a]=37.19677405 -# c_list[0][b]=7.6111103 -# c_list[0][c]=7.6111103 -# c_list[1][a]=85.26810658 -# c_list[1][b]=11.90220148 -# c_list[1][c]=11.90220148 -# c_list[2][a]=134.44874488 -# c_list[2][b]=15.05074749 -# c_list[2][c]=15.05074749 -# energy_force_d6_recip = value_and_grad(generate_pme_recip(Ck_6, kappa, True, pme_order, K1, K2, K3, 0)) -# energy_force_d8_recip = value_and_grad(generate_pme_recip(Ck_8, kappa, True, pme_order, K1, K2, K3, 0)) -# energy_force_d10_recip = value_and_grad(generate_pme_recip(Ck_10, kappa, True, pme_order, K1, K2, K3, 0)) -# E6, F6 = energy_force_d6_recip(positions, box, c_list[0, :, jnp.newaxis]) -# E8, F6 = energy_force_d8_recip(positions, box, c_list[1, :, jnp.newaxis]) -# E10, F10 = energy_force_d10_recip(positions, box, c_list[2, :, jnp.newaxis]) -# print('ok') -# E6, F6 = energy_force_d6_recip(positions, box, c_list[0, :, jnp.newaxis]) -# E8, F6 = energy_force_d8_recip(positions, box, c_list[1, :, jnp.newaxis]) -# E10, F10 = energy_force_d10_recip(positions, box, c_list[2, :, jnp.newaxis]) -# print(E6, E8, E10) -# print(E6 + E8 + E10) -# return - - -# # validation code -# if __name__ == '__main__': -# validation(sys.argv[1]) From 38ff0e56ca4da9adb4f6ce0b22e93e77e4df1f68 Mon Sep 17 00:00:00 2001 From: Roy Kid Date: Sun, 5 Jun 2022 21:53:41 +0800 Subject: [PATCH 09/17] chore: clean up classical and api.py --- dmff/admp/disp_pme.py | 15 +- dmff/admp/mbpol_intra.py | 24 +- dmff/admp/multipole.py | 5 +- dmff/admp/pairwise.py | 8 +- dmff/admp/pme.py | 28 +- dmff/admp/recip.py | 14 +- dmff/api.py | 548 ++++++++++++++++++++------------------- dmff/classical/inter.py | 11 +- dmff/classical/intra.py | 6 +- dmff/common/nblist.py | 6 +- 10 files changed, 334 insertions(+), 331 deletions(-) diff --git a/dmff/admp/disp_pme.py b/dmff/admp/disp_pme.py index 11c30eab0..722bf22fb 100755 --- a/dmff/admp/disp_pme.py +++ b/dmff/admp/disp_pme.py @@ -1,11 +1,14 @@ +from functools import partial + import jax.numpy as jnp -from jax import vmap, value_and_grad -from dmff.utils import jit_condition, regularize_pairs, pair_buffer_scales -from dmff.admp.spatial import pbc_shift +from dmff.admp.pairwise import (distribute_dispcoeff, distribute_scalar, + distribute_v3) from dmff.admp.pme import setup_ewald_parameters -from dmff.admp.recip import generate_pme_recip, Ck_6, Ck_8, Ck_10 -from dmff.admp.pairwise import distribute_scalar, distribute_v3, distribute_dispcoeff -from functools import partial +from dmff.admp.recip import Ck_6, Ck_8, Ck_10, generate_pme_recip +from dmff.admp.spatial import pbc_shift +from dmff.utils import jit_condition, pair_buffer_scales, regularize_pairs +from jax import value_and_grad, vmap + class ADMPDispPmeForce: ''' diff --git a/dmff/admp/mbpol_intra.py b/dmff/admp/mbpol_intra.py index dc69f32e3..5fa70e83e 100755 --- a/dmff/admp/mbpol_intra.py +++ b/dmff/admp/mbpol_intra.py @@ -1,11 +1,7 @@ - -import sys -import numpy as np import jax.numpy as jnp -from jax import grad, value_and_grad -from dmff.settings import DO_JIT -from dmff.utils import jit_condition +import numpy as np from dmff.admp.spatial import v_pbc_shift +from dmff.utils import jit_condition from jax import vmap #const @@ -13,14 +9,14 @@ fbasis = 0.15860145369897 fcore = -1.6351695982132 frest = 1.0 -reoh = 0.958649; -thetae = 104.3475; -b1 = 2.0; -roh = 0.9519607159623009; -alphaoh = 2.587949757553683; -deohA = 42290.92019288289; -phh1A = 16.94879431193463; -phh2 = 12.66426998162947; +reoh = 0.958649 +thetae = 104.3475 +b1 = 2.0 +roh = 0.9519607159623009 +alphaoh = 2.587949757553683 +deohA = 42290.92019288289 +phh1A = 16.94879431193463 +phh2 = 12.66426998162947 c5zA = jnp.array([4.2278462684916e+04, 4.5859382909906e-02, 9.4804986183058e+03, 7.5485566680955e+02, 1.9865052511496e+03, 4.3768071560862e+02, diff --git a/dmff/admp/multipole.py b/dmff/admp/multipole.py index 64085e83d..e863f9e72 100644 --- a/dmff/admp/multipole.py +++ b/dmff/admp/multipole.py @@ -1,7 +1,8 @@ +from functools import partial + import jax.numpy as jnp -from jax import vmap from dmff.utils import jit_condition -from functools import partial +from jax import vmap # This module deals with the transformations and rotations of multipoles diff --git a/dmff/admp/pairwise.py b/dmff/admp/pairwise.py index 8a3431a6f..4fd31fe10 100755 --- a/dmff/admp/pairwise.py +++ b/dmff/admp/pairwise.py @@ -1,9 +1,9 @@ -import sys -from jax import vmap +from functools import partial + import jax.numpy as jnp -from dmff.utils import jit_condition, regularize_pairs, pair_buffer_scales from dmff.admp.spatial import v_pbc_shift -from functools import partial +from dmff.utils import jit_condition, pair_buffer_scales, regularize_pairs +from jax import vmap DIELECTRIC = 1389.35455846 diff --git a/dmff/admp/pme.py b/dmff/admp/pme.py index 23b3a7d6c..e60e6702b 100755 --- a/dmff/admp/pme.py +++ b/dmff/admp/pme.py @@ -1,22 +1,22 @@ -import sys -import numpy as np -import jax +from functools import partial + import jax.numpy as jnp -from jax import grad, value_and_grad, vmap, jit -from jax.scipy.special import erf +import numpy as np +from dmff.admp.multipole import (C1_c2h, convert_cart2harm, rot_global2local, + rot_ind_global2local, rot_local2global) +from dmff.admp.pairwise import (distribute_multipoles, distribute_scalar, + distribute_v3) +from dmff.admp.settings import MAX_N_POL, POL_CONV +from dmff.admp.spatial import (build_quasi_internal, + generate_construct_local_frames, v_pbc_shift) from dmff.settings import DO_JIT -from dmff.admp.settings import POL_CONV, MAX_N_POL -from dmff.utils import jit_condition, regularize_pairs, pair_buffer_scales -from dmff.admp.multipole import C1_c2h, convert_cart2harm -from dmff.admp.multipole import rot_ind_global2local, rot_global2local, rot_local2global -from dmff.admp.spatial import v_pbc_shift, generate_construct_local_frames, build_quasi_internal -from dmff.admp.pairwise import distribute_scalar, distribute_v3, distribute_multipoles -from functools import partial +from dmff.utils import jit_condition, pair_buffer_scales, regularize_pairs +from jax import grad, jit, value_and_grad, vmap DIELECTRIC = 1389.35455846 DEFAULT_THOLE_WIDTH = 5.0 - -from dmff.admp.recip import generate_pme_recip, Ck_1 +import jax +from dmff.admp.recip import Ck_1, generate_pme_recip # Functions that are related to electrostatic pme diff --git a/dmff/admp/recip.py b/dmff/admp/recip.py index ed18eff0e..3adb4e700 100755 --- a/dmff/admp/recip.py +++ b/dmff/admp/recip.py @@ -1,18 +1,10 @@ -import numpy as np import jax.numpy as jnp import jax.scipy as jsp -from jax import jit -from dmff.settings import DO_JIT +import numpy as np from dmff.admp.pme import DIELECTRIC - -# for debug -# from admp.parser import * -# from admp.multipole import * -# from admp.spatial import * -# from admp.pme import * -# from jax.config import config -# config.update("jax_enable_x64", True) +from dmff.settings import DO_JIT +from jax import jit sqrt_pi = 1.7724538509055159 diff --git a/dmff/api.py b/dmff/api.py index 3431843a1..ea3d2dd06 100644 --- a/dmff/api.py +++ b/dmff/api.py @@ -1,44 +1,48 @@ -#!/usr/bin/env python +import itertools +import linecache +import sys +import xml.etree.ElementTree as ET +from collections import defaultdict +from copy import deepcopy + +import jax.numpy as jnp +import numpy as np import openmm as mm import openmm.app as app -import openmm.unit as unit import openmm.app.element as elem -import numpy as np -import jax.numpy as jnp -from collections import defaultdict -import xml.etree.ElementTree as ET +import openmm.unit as unit +from jax import grad +from jax_md import partition, space from dmff.utils import isinstance_jnp + from .admp.disp_pme import ADMPDispPmeForce from .admp.multipole import convert_cart2harm, convert_harm2cart -from .admp.pairwise import TT_damping_qq_c6_kernel, generate_pairwise_interaction -from .admp.pairwise import slater_disp_damping_kernel, slater_sr_kernel, TT_damping_qq_kernel -from .admp.pme import ADMPPmeForce -from .classical.intra import ( - HarmonicBondJaxForce, - HarmonicAngleJaxForce, - PeriodicTorsionJaxForce, +from .admp.pairwise import ( + TT_damping_qq_c6_kernel, + TT_damping_qq_kernel, + generate_pairwise_interaction, + slater_disp_damping_kernel, + slater_sr_kernel, ) -from jax_md import space, partition -from jax import grad -import linecache -import itertools +from .admp.pme import ADMPPmeForce from .classical.inter import ( - CoulombPMEForce, - LennardJonesForce, CoulNoCutoffForce, + CoulombPMEForce, CoulReactionFieldForce, + LennardJonesForce, +) +from .classical.intra import ( + HarmonicAngleJaxForce, + HarmonicBondJaxForce, + PeriodicTorsionJaxForce, ) -import sys -from copy import deepcopy class XMLNodeInfo: - @staticmethod - def to_str(value)->str: - """ convert value to string if it can - """ + def to_str(value) -> str: + """convert value to string if it can""" if isinstance(value, str): return value elif isinstance(value, (jnp.ndarray, np.ndarray)): @@ -52,54 +56,48 @@ def to_str(value)->str: return str(value) class XMLElementInfo: - def __init__(self, name): self.name = name self.attributes = {} - + def addAttribute(self, key, value): self.attributes[key] = XMLNodeInfo.to_str(value) - + def __repr__(self): return f'<{self.name} {" ".join([f"{k}={v}" for k, v in self.attributes.items()])}>' - + def __getitem__(self, name): return self.attributes[name] - def __init__(self, name): self.name = name self.attributes = {} self.elements = [] - + def __getitem__(self, name): if isinstance(name, str): return self.attributes[name] elif isinstance(name, int): return self.elements[name] - def addAttribute(self, key, value): self.attributes[key] = XMLNodeInfo.to_str(value) - def addElement(self, name, info): element = self.XMLElementInfo(name) for k, v in info.items(): element.addAttribute(k, v) self.elements.append(element) - def modResidue(self, residue, atom, key, value): pass def __repr__(self): # tricy string formatting left = f'<{self.name} {" ".join([f"{k}={v}" for k, v in self.attributes.items()])}> \n\t' - right = f'<\\{self.name}>' - content = '\n\t'.join([repr(e) for e in self.elements]) - return left + content + '\n' + right - + right = f"<\\{self.name}>" + content = "\n\t".join([repr(e) for e in self.elements]) + return left + content + "\n" + right def get_line_context(file_path, line_number): @@ -129,13 +127,14 @@ def build_covalent_map(data, max_neighbor): def findAtomTypeTexts(attribs, num): typetxt = [] - for n in range(1, num+1): - for key in ["type%i"%n, "class%i"%n]: + for n in range(1, num + 1): + for key in ["type%i" % n, "class%i" % n]: if key in attribs: typetxt.append((key, attribs[key])) break return typetxt + class ADMPDispGenerator: def __init__(self, hamiltonian): self.ff = hamiltonian @@ -205,8 +204,9 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args): if "ethresh" in args: self.ethresh = args["ethresh"] - Force_DispPME = ADMPDispPmeForce(box, covalent_map, rc, self.ethresh, - self.pmax, lpme=self.lpme) + Force_DispPME = ADMPDispPmeForce( + box, covalent_map, rc, self.ethresh, self.pmax, lpme=self.lpme + ) self.disp_pme_force = Force_DispPME pot_fn_lr = Force_DispPME.get_energy pot_fn_sr = generate_pairwise_interaction( @@ -239,36 +239,41 @@ def getJaxPotential(self): def renderXML(self): # generate xml force field file - finfo = XMLNodeInfo('ADMPDispForce') - finfo.addAttribute('mScale12', self.params["mScales"][0]) - finfo.addAttribute('mScale13', self.params["mScales"][1]) - finfo.addAttribute('mScale14', self.params["mScales"][2]) - finfo.addAttribute('mScale15', self.params["mScales"][3]) - finfo.addAttribute('mScale16', self.params["mScales"][4]) - + finfo = XMLNodeInfo("ADMPDispForce") + finfo.addAttribute("mScale12", self.params["mScales"][0]) + finfo.addAttribute("mScale13", self.params["mScales"][1]) + finfo.addAttribute("mScale14", self.params["mScales"][2]) + finfo.addAttribute("mScale15", self.params["mScales"][3]) + finfo.addAttribute("mScale16", self.params["mScales"][4]) + for i in range(len(self.types)): - ainfo = {'type': self.types[i], 'A': self.params["A"][i], 'B': self.params["B"][i], 'Q': self.params["Q"][i], 'C6': self.params["C6"][i], 'C8': self.params["C8"][i], 'C10': self.params["C10"][i]} - finfo.addElement('Atom', ainfo) - + ainfo = { + "type": self.types[i], + "A": self.params["A"][i], + "B": self.params["B"][i], + "Q": self.params["Q"][i], + "C6": self.params["C6"][i], + "C8": self.params["C8"][i], + "C10": self.params["C10"][i], + } + finfo.addElement("Atom", ainfo) + return finfo + # register all parsers app.forcefield.parsers["ADMPDispForce"] = ADMPDispGenerator.parseElement class ADMPDispPmeGenerator: - r''' + r""" This one computes the undamped C6/C8/C10 interactions u = \sum_{ij} c6/r^6 + c8/r^8 + c10/r^10 - ''' + """ def __init__(self, hamiltonian): self.ff = hamiltonian - self.params = { - "C6": [], - "C8": [], - "C10": [] - } + self.params = {"C6": [], "C8": [], "C10": []} self._jaxPotential = None self.types = [] self.ethresh = 5e-4 @@ -297,8 +302,7 @@ def parseElement(element, hamiltonian): generator.params[k] = jnp.array(generator.params[k]) generator.types = np.array(generator.types) - def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, - args): + def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args): methodMap = { app.CutoffPeriodic: "CutoffPeriodic", app.NoCutoff: "NoCutoff", @@ -328,17 +332,18 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, rc = nonbondedCutoff.value_in_unit(unit.angstrom) # get calculator - if 'ethresh' in args: - self.ethresh = args['ethresh'] + if "ethresh" in args: + self.ethresh = args["ethresh"] - disp_force = ADMPDispPmeForce(box, covalent_map, rc, self.ethresh, - self.pmax, self.lpme) + disp_force = ADMPDispPmeForce( + box, covalent_map, rc, self.ethresh, self.pmax, self.lpme + ) self.disp_force = disp_force pot_fn_lr = disp_force.get_energy def potential_fn(positions, box, pairs, params): mScales = params["mScales"] - C6_list = params["C6"][map_atomtype] * 1e6 # to kj/mol * A**6 + C6_list = params["C6"][map_atomtype] * 1e6 # to kj/mol * A**6 C8_list = params["C8"][map_atomtype] * 1e8 C10_list = params["C10"][map_atomtype] * 1e10 c6_list = jnp.sqrt(C6_list) @@ -346,7 +351,7 @@ def potential_fn(positions, box, pairs, params): c10_list = jnp.sqrt(C10_list) c_list = jnp.vstack((c6_list, c8_list, c10_list)) E_lr = pot_fn_lr(positions, box, pairs, c_list.T, mScales) - return - E_lr + return -E_lr self._jaxPotential = potential_fn # self._top_data = data @@ -358,14 +363,17 @@ def renderXML(self): # generate xml force field file pass + # register all parsers app.forcefield.parsers["ADMPDispPmeForce"] = ADMPDispPmeGenerator.parseElement + class QqTtDampingGenerator: - r''' + r""" This one calculates the tang-tonnies damping of charge-charge interaction E = \sum_ij exp(-B*r)*(1+B*r)*q_i*q_j/r - ''' + """ + def __init__(self, hamiltonian): self.ff = hamiltonian self.params = { @@ -398,8 +406,7 @@ def parseElement(element, hamiltonian): generator.types = np.array(generator.types) # on working - def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, - args): + def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args): n_atoms = len(data.atoms) # build index map @@ -411,13 +418,13 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, # build covalent map covalent_map = build_covalent_map(data, 6) - pot_fn_sr = generate_pairwise_interaction(TT_damping_qq_kernel, - covalent_map, - static_args={}) + pot_fn_sr = generate_pairwise_interaction( + TT_damping_qq_kernel, covalent_map, static_args={} + ) def potential_fn(positions, box, pairs, params): mScales = params["mScales"] - b_list = params["B"][map_atomtype] / 10 # convert to A^-1 + b_list = params["B"][map_atomtype] / 10 # convert to A^-1 q_list = params["Q"][map_atomtype] E_sr = pot_fn_sr(positions, box, pairs, mScales, b_list, q_list) @@ -433,17 +440,19 @@ def renderXML(self): # generate xml force field file pass + # register all parsers app.forcefield.parsers["QqTtDampingForce"] = QqTtDampingGenerator.parseElement class SlaterDampingGenerator: - r''' + r""" This one computes the slater-type damping function for c6/c8/c10 dispersion E = \sum_ij (f6-1)*c6/r6 + (f8-1)*c8/r8 + (f10-1)*c10/r10 fn = f_tt(x, n) x = br - (2*br2 + 3*br) / (br2 + 3*br + 3) - ''' + """ + def __init__(self, hamiltonian): self.ff = hamiltonian self.params = { @@ -479,8 +488,7 @@ def parseElement(element, hamiltonian): generator.params[k] = jnp.array(generator.params[k]) generator.types = np.array(generator.types) - def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, - args): + def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args): n_atoms = len(data.atoms) # build index map @@ -493,22 +501,24 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, covalent_map = build_covalent_map(data, 6) # WORKING - pot_fn_sr = generate_pairwise_interaction(slater_disp_damping_kernel, - covalent_map, - static_args={}) + pot_fn_sr = generate_pairwise_interaction( + slater_disp_damping_kernel, covalent_map, static_args={} + ) def potential_fn(positions, box, pairs, params): mScales = params["mScales"] - b_list = params["B"][map_atomtype] / 10 # convert to A^-1 - c6_list = jnp.sqrt(params["C6"][map_atomtype] * 1e6) # to kj/mol * A**6 + b_list = params["B"][map_atomtype] / 10 # convert to A^-1 + c6_list = jnp.sqrt(params["C6"][map_atomtype] * 1e6) # to kj/mol * A**6 c8_list = jnp.sqrt(params["C8"][map_atomtype] * 1e8) c10_list = jnp.sqrt(params["C10"][map_atomtype] * 1e10) - E_sr = pot_fn_sr(positions, box, pairs, mScales, b_list, c6_list, c8_list, c10_list) + E_sr = pot_fn_sr( + positions, box, pairs, mScales, b_list, c6_list, c8_list, c10_list + ) return E_sr self._jaxPotential = potential_fn # self._top_data = data - + def getJaxPotential(self): return self._jaxPotential @@ -516,21 +526,22 @@ def renderXML(self): # generate xml force field file pass + app.forcefield.parsers["SlaterDampingForce"] = SlaterDampingGenerator.parseElement class SlaterExGenerator: - r''' + r""" This one computes the Slater-ISA type exchange interaction u = \sum_ij A * (1/3*(Br)^2 + Br + 1) - ''' + """ def __init__(self, hamiltonian): self.ff = hamiltonian self.params = { - "A": [], - "B": [], - } + "A": [], + "B": [], + } self._jaxPotential = None self.types = [] @@ -556,8 +567,7 @@ def parseElement(element, hamiltonian): generator.params[k] = jnp.array(generator.params[k]) generator.types = np.array(generator.types) - def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, - args): + def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args): n_atoms = len(data.atoms) # build index map @@ -569,14 +579,14 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, # build covalent map covalent_map = build_covalent_map(data, 6) - pot_fn_sr = generate_pairwise_interaction(slater_sr_kernel, - covalent_map, - static_args={}) + pot_fn_sr = generate_pairwise_interaction( + slater_sr_kernel, covalent_map, static_args={} + ) def potential_fn(positions, box, pairs, params): mScales = params["mScales"] a_list = params["A"][map_atomtype] - b_list = params["B"][map_atomtype] / 10 # nm^-1 to A^-1 + b_list = params["B"][map_atomtype] / 10 # nm^-1 to A^-1 return pot_fn_sr(positions, box, pairs, mScales, a_list, b_list) @@ -590,6 +600,7 @@ def renderXML(self): # generate xml force field file pass + app.forcefield.parsers["SlaterExForce"] = SlaterExGenerator.parseElement @@ -598,16 +609,23 @@ def renderXML(self): class SlaterSrEsGenerator(SlaterExGenerator): def __init__(self): super().__init__(self) + + class SlaterSrPolGenerator(SlaterExGenerator): def __init__(self): super().__init__(self) + + class SlaterSrDispGenerator(SlaterExGenerator): def __init__(self): super().__init__(self) + + class SlaterDhfGenerator(SlaterExGenerator): def __init__(self): super().__init__(self) + # register all parsers app.forcefield.parsers["SlaterSrEsForce"] = SlaterSrEsGenerator.parseElement app.forcefield.parsers["SlaterSrPolForce"] = SlaterSrPolGenerator.parseElement @@ -679,33 +697,33 @@ def registerAtomType(self, atom: dict): @staticmethod def parseElement(element, hamiltonian): - r""" parse admp related parameters in XML file - - example: - - - - - - - - - - - - - - + r"""parse admp related parameters in XML file + + example: + + + + + + + + + + + + + + """ generator = ADMPPmeGenerator(hamiltonian) @@ -1055,7 +1073,7 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args): self.ethresh, self.lmax, self.lpol, - lpme=self.lpme + lpme=self.lpme, ) self.pme_force = pme_force @@ -1091,39 +1109,50 @@ def getJaxPotential(self): def renderXML(self): # - - finfo = XMLNodeInfo('ADMPPmeForce') - finfo.addAttribute('lmax', str(self.lmax)) + + finfo = XMLNodeInfo("ADMPPmeForce") + finfo.addAttribute("lmax", str(self.lmax)) outputparams = deepcopy(self.params) - mScales = outputparams.pop('mScales') - pScales = outputparams.pop('pScales') - dScales = outputparams.pop('dScales') + mScales = outputparams.pop("mScales") + pScales = outputparams.pop("pScales") + dScales = outputparams.pop("dScales") for i in range(len(mScales)): - finfo.addAttribute(f'mScale1{i+2}', str(mScales[i])) + finfo.addAttribute(f"mScale1{i+2}", str(mScales[i])) for i in range(len(pScales)): - finfo.addAttribute(f'pScale{i+1}', str(pScales[i])) + finfo.addAttribute(f"pScale{i+1}", str(pScales[i])) for i in range(len(dScales)): - finfo.addAttribute(f'dScale{i+1}', str(dScales[i])) - - Q = outputparams['Q_local'] + finfo.addAttribute(f"dScale{i+1}", str(dScales[i])) + + Q = outputparams["Q_local"] Q_global = convert_harm2cart(Q, self.lmax) - + # for atom in range(self.n_atoms): - info = {'type': self.map_atomtype[atom]} - info.update({ktype:self.kStrings[ktype][atom] for ktype in ['kz', 'kx', 'ky']}) - for i, key in enumerate(['c0', 'dX', 'dY', 'dZ', 'qXX', 'qXY', 'qXZ', 'qYY', 'qYZ', 'qZZ']): + info = {"type": self.map_atomtype[atom]} + info.update( + {ktype: self.kStrings[ktype][atom] for ktype in ["kz", "kx", "ky"]} + ) + for i, key in enumerate( + ["c0", "dX", "dY", "dZ", "qXX", "qXY", "qXZ", "qYY", "qYZ", "qZZ"] + ): info[key] = "%.8f" % Q_global[atom][i] - finfo.addElement('Atom', info) - + finfo.addElement("Atom", info) + # for t in range(len(self.types)): - info = { - 'type': self.types[t] - } - info.update({p: "%.8f" % self.params['pol'][t] for p in ['polarizabilityXX', 'polarizabilityYY', 'polarizabilityZZ']}) - finfo.addElement('Polarize', info) - + info = {"type": self.types[t]} + info.update( + { + p: "%.8f" % self.params["pol"][t] + for p in [ + "polarizabilityXX", + "polarizabilityYY", + "polarizabilityZZ", + ] + } + ) + finfo.addElement("Polarize", info) + return finfo @@ -1150,14 +1179,14 @@ def registerBondType(self, bond): def parseElement(element, hamiltonian): r"""parse section in XML file - - example: - - - - - <\HarmonicBondForce> - + + example: + + + + + <\HarmonicBondForce> + """ generator = HarmonicBondJaxGenerator(hamiltonian) hamiltonian.registerGenerator(generator) @@ -1216,12 +1245,11 @@ def renderXML(self): binfo[k1] = v1 binfo[k2] = v2 for key in self.params.keys(): - binfo[key] = "%.8f"%self.params[key][ntype] + binfo[key] = "%.8f" % self.params[key][ntype] finfo.addElement("Bond", binfo) return finfo - # register all parsers app.forcefield.parsers["HarmonicBondForce"] = HarmonicBondJaxGenerator.parseElement @@ -1241,13 +1269,13 @@ def registerAngleType(self, angle): @staticmethod def parseElement(element, hamiltonian): - r""" parse section in XML file + r"""parse section in XML file - example: - - - - <\HarmonicAngleForce> + example: + + + + <\HarmonicAngleForce> """ generator = HarmonicAngleJaxGenerator(hamiltonian) @@ -1310,9 +1338,15 @@ def renderXML(self): finfo = XMLNodeInfo("HarmonicAngleForce") for i, type in enumerate(self.types): t1, t2, t3 = type - ainfo = {'type1': t1, 'type2': t2, 'type3': t3, 'k': self.params['k'][i], 'angle': self.params['angle'][i]} - finfo.addElement('Angle', ainfo) - + ainfo = { + "type1": t1, + "type2": t2, + "type3": t3, + "k": self.params["k"][i], + "angle": self.params["angle"][i], + } + finfo.addElement("Angle", ainfo) + return finfo @@ -1320,7 +1354,6 @@ def renderXML(self): app.forcefield.parsers["HarmonicAngleForce"] = HarmonicAngleJaxGenerator.parseElement - def _matchImproper(data, torsion, generator): type1 = data.atomType[data.atoms[torsion[0]]] type2 = data.atomType[data.atoms[torsion[1]]] @@ -1523,14 +1556,14 @@ def registerImproperTorsion(self, parameters, ordering="default"): @staticmethod def parseElement(element, ff): - """ parse section in XML file - - example: - - - - - + """parse section in XML file + + example: + + + + + """ existing = [f for f in ff._forces if isinstance(f, PeriodicTorsionJaxGenerator)] @@ -1867,50 +1900,58 @@ def getJaxPotential(self): def renderXML(self): params = self.params # generate xml force field file - finfo = XMLNodeInfo('PeriodicTorsionForce') + finfo = XMLNodeInfo("PeriodicTorsionForce") for i in range(len(self.proper)): proper = self.proper[i] - - finfo.addElement('Proper', - {'type1': proper.types1, 'type2': proper.types2, - 'type3': proper.types3, 'type4': proper.types4, - 'periodicity1': proper.periodicity[0], - 'phase1': params['psi1_p'][i], - 'k1': params['k1_p'][i], - 'periodicity2': proper.periodicity[1], - 'phase2': params['psi2_p'][i], - 'k2': params['k2_p'][i], - 'periodicity3': proper.periodicity[2], - 'phase3': params['psi3_p'][i], - 'k3': params['k3_p'][i], - 'periodicity4': proper.periodicity[3], - 'phase4': params['psi4_p'][i], - 'k4': params['k4_p'][i], - } + + finfo.addElement( + "Proper", + { + "type1": proper.types1, + "type2": proper.types2, + "type3": proper.types3, + "type4": proper.types4, + "periodicity1": proper.periodicity[0], + "phase1": params["psi1_p"][i], + "k1": params["k1_p"][i], + "periodicity2": proper.periodicity[1], + "phase2": params["psi2_p"][i], + "k2": params["k2_p"][i], + "periodicity3": proper.periodicity[2], + "phase3": params["psi3_p"][i], + "k3": params["k3_p"][i], + "periodicity4": proper.periodicity[3], + "phase4": params["psi4_p"][i], + "k4": params["k4_p"][i], + }, ) - + for i in range(len(self.improper)): - + improper = self.improper[i] - - finfo.addElement('Improper', - {'type1': improper.types1, 'type2': improper.types2, - 'type3': improper.types3, 'type4': improper.types4, - 'periodicity1': improper.periodicity[0], - 'phase1': params['psi1_i'][i], - 'k1': params['k1_i'][i], - 'periodicity2': proper.periodicity[1], - 'phase2': params['psi2_i'][i], - 'k2': params['k2_i'][i], - 'periodicity3': proper.periodicity[2], - 'phase3': params['psi3_i'][i], - 'k3': params['k3_i'][i], - 'periodicity4': proper.periodicity[3], - 'phase4': params['psi4_i'][i], - 'k4': params['k4_i'][i], - } + + finfo.addElement( + "Improper", + { + "type1": improper.types1, + "type2": improper.types2, + "type3": improper.types3, + "type4": improper.types4, + "periodicity1": improper.periodicity[0], + "phase1": params["psi1_i"][i], + "k1": params["k1_i"][i], + "periodicity2": proper.periodicity[1], + "phase2": params["psi2_i"][i], + "k2": params["k2_i"][i], + "periodicity3": proper.periodicity[2], + "phase3": params["psi3_i"][i], + "k3": params["k3_i"][i], + "periodicity4": proper.periodicity[3], + "phase4": params["psi4_i"][i], + "k4": params["k4_i"][i], + }, ) - + return finfo @@ -1957,9 +1998,9 @@ def registerAtom(self, atom): @staticmethod def parseElement(element, ff): """parse section in XML file - + example: - + @@ -1999,9 +2040,9 @@ def parseElement(element, ff): generator.useAttributeFromResidue.append(eprm) for atom in element.findall("Atom"): generator.registerAtom(atom.attrib) - + generator.n_atoms = len(element.findall("Atom")) - + # jax it! for k in generator.params.keys(): generator.params[k] = jnp.array(generator.params[k]) @@ -2034,7 +2075,6 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args): mscales_lj = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0, 1.0]) # mscale for LJ mscales_lj = mscales_lj.at[2].set(self.params["lj14scale"][0]) - # Coulomb: only support PME for now # set PBC if nonbondedMethod not in [app.NoCutoff, app.CutoffNonPeriodic]: @@ -2086,7 +2126,6 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args): map_nbfix = [] # implement it later map_nbfix = np.array(map_nbfix, dtype=int).reshape((-1, 2)) - colv_map = build_covalent_map(data, 6) @@ -2105,11 +2144,11 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args): else: r_switch = r_cut ifSwitch = False - + map_lj = jnp.array(map_lj) map_nbfix = jnp.array(map_nbfix) - map_charge = jnp.array(map_charge) - + map_charge = jnp.array(map_charge) + ljforce = LennardJonesForce( r_switch, r_cut, @@ -2126,7 +2165,9 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args): # do not use PME if nonbondedMethod in [app.CutoffPeriodic, app.CutoffNonPeriodic]: # use Reaction Field - coulforce = CoulReactionFieldForce(r_cut, map_charge, colv_map, isPBC=ifPBC) + coulforce = CoulReactionFieldForce( + r_cut, map_charge, colv_map, isPBC=ifPBC + ) if nonbondedMethod is app.NoCutoff: # use NoCutoff coulforce = CoulNoCutoffForce(map_charge, colv_map) @@ -2136,12 +2177,12 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args): coulenergy = coulforce.generate_get_energy() def potential_fn(positions, box, pairs, params): - + # check whether args passed into potential_fn are jnp.array and differentiable # note this check will be optimized away by jit # it is jit-compatiable isinstance_jnp(positions, box, params) - + ljE = ljenergy( positions, box, @@ -2150,7 +2191,7 @@ def potential_fn(positions, box, pairs, params): params["sigma"], params["epsfix"], params["sigfix"], - mscales_lj + mscales_lj, ) coulE = coulenergy(positions, box, pairs, params["charge"], mscales_coul) @@ -2162,23 +2203,27 @@ def getJaxPotential(self): return self._jaxPotential def renderXML(self): - + # - finfo = XMLNodeInfo('NonbondedForce') - finfo.addAttribute('coulomb14scale', str(self.coulomb14scale)) - finfo.addAttribute('lj14scale', str(self.lj14scale)) - + finfo = XMLNodeInfo("NonbondedForce") + finfo.addAttribute("coulomb14scale", str(self.coulomb14scale)) + finfo.addAttribute("lj14scale", str(self.lj14scale)) + for atom in range(self.n_atoms): - info = {'type': self.types[atom], 'charge': self.params['charge'][atom], 'sigma': self.params['sigma'][atom], 'epsilon': self.params['epsilon'][atom]} - finfo.addElement('Atom', info) - + info = { + "type": self.types[atom], + "charge": self.params["charge"][atom], + "sigma": self.params["sigma"][atom], + "epsilon": self.params["epsilon"][atom], + } + finfo.addElement("Atom", info) + return finfo app.forcefield.parsers["NonbondedForce"] = NonbondJaxGenerator.parseElement - class Hamiltonian(app.forcefield.ForceField): def __init__(self, *xmlnames): super().__init__(*xmlnames) @@ -2192,13 +2237,13 @@ def createPotential( topology, nonbondedMethod=app.NoCutoff, nonbondedCutoff=1.0 * unit.nanometer, - **args + **args, ): system = self.createSystem( topology, nonbondedMethod=nonbondedMethod, nonbondedCutoff=nonbondedCutoff, - **args + **args, ) # load_constraints_from_system_if_needed # create potentials @@ -2226,30 +2271,3 @@ def render(self, filename): tree = ET.ElementTree(root) tree.write(filename) - - -if __name__ == "__main__": - H = Hamiltonian("forcefield.xml") - generator = H.getGenerators()[0] - app.Topology.loadBondDefinitions("residues.xml") - pdb = app.PDBFile("../water1024.pdb") - rc = 4.0 - potentials = H.createPotential(pdb.topology, nonbondedCutoff=rc * unit.angstrom) - pot_disp = potentials[0] - - positions = jnp.array(pdb.positions._value) * 10 - a, b, c = pdb.topology.getPeriodicBoxVectors() - box = jnp.array([a._value, b._value, c._value]) * 10 - - # neighbor list - displacement_fn, shift_fn = space.periodic_general( - box, fractional_coordinates=False - ) - neighbor_list_fn = partition.neighbor_list( - displacement_fn, box, rc, 0, format=partition.OrderedSparse - ) - nbr = neighbor_list_fn.allocate(positions) - pairs = nbr.idx.T - - param_grad = grad(pot_disp, argnums=3)(positions, box, pairs, generator.params) - print(param_grad) diff --git a/dmff/classical/inter.py b/dmff/classical/inter.py index 849b781df..847aee9b2 100644 --- a/dmff/classical/inter.py +++ b/dmff/classical/inter.py @@ -1,13 +1,10 @@ -from dmff.utils import pair_buffer_scales, regularize_pairs import jax.numpy as jnp -from dmff.admp.pme import energy_pme, setup_ewald_parameters -from dmff.admp.recip import generate_pme_recip -from dmff.admp.spatial import v_pbc_shift import numpy as np -import jax.numpy as jnp +from dmff.admp.pme import DIELECTRIC, energy_pme, setup_ewald_parameters +from dmff.admp.recip import Ck_1, generate_pme_recip +from dmff.admp.spatial import v_pbc_shift +from dmff.utils import pair_buffer_scales, regularize_pairs from jax import grad -from dmff.admp.recip import generate_pme_recip, Ck_1 -from dmff.admp.pme import DIELECTRIC ONE_4PI_EPS0 = DIELECTRIC * 0.1 diff --git a/dmff/classical/intra.py b/dmff/classical/intra.py index 48327d961..24a79562f 100644 --- a/dmff/classical/intra.py +++ b/dmff/classical/intra.py @@ -1,9 +1,5 @@ -import sys -import numpy as np -import jax import jax.numpy as jnp -from jax import grad, value_and_grad, vmap, jit -from jax.scipy.special import erf +from jax import value_and_grad, vmap def distance(p1v, p2v): diff --git a/dmff/common/nblist.py b/dmff/common/nblist.py index 74adcd967..631441f61 100644 --- a/dmff/common/nblist.py +++ b/dmff/common/nblist.py @@ -1,7 +1,7 @@ -from jax_md import space, partition import jax.numpy as jnp from dmff.utils import regularize_pairs -import jax.numpy as jnp +from jax_md import partition, space + class NeighborList: @@ -92,4 +92,4 @@ def distance(self): jnp.ndarray: (nPairs, ) """ - return jnp.linalg.norm(self.dr, axis=1) \ No newline at end of file + return jnp.linalg.norm(self.dr, axis=1) From 8a8694685cb9e707a2c26e8f830e7261cce85b92 Mon Sep 17 00:00:00 2001 From: Roy Kid Date: Sun, 5 Jun 2022 22:01:09 +0800 Subject: [PATCH 10/17] fix(ut): withdraw last commit about fix wrong number in test_nblist --- tests/test_common/test_nblist.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_common/test_nblist.py b/tests/test_common/test_nblist.py index a8a04724b..3da3eac86 100644 --- a/tests/test_common/test_nblist.py +++ b/tests/test_common/test_nblist.py @@ -33,18 +33,18 @@ def test_update(self, nblist): def test_pairs(self, nblist): pairs = nblist.pairs - assert pairs.shape == (15, 2) + assert pairs.shape == (18, 2) def test_pair_mask(self, nblist): pair, mask = nblist.pair_mask - assert mask.shape == (15, ) + assert mask.shape == (18, ) def test_dr(self, nblist): dr = nblist.dr - assert dr.shape == (15, 3) + assert dr.shape == (18, 3) def test_distance(self, nblist): - assert nblist.distance.shape == (15, ) + assert nblist.distance.shape == (18, ) From 7c71961cfd48f295d2d888b8daa27eb7301bbe09 Mon Sep 17 00:00:00 2001 From: Roy Kid Date: Thu, 9 Jun 2022 18:54:27 +0800 Subject: [PATCH 11/17] update mkdocs.yml and a simple test of api generator --- docs/refs/common/nblist.md | 3 +++ mkdocs.yml | 14 +++++++++----- 2 files changed, 12 insertions(+), 5 deletions(-) create mode 100644 docs/refs/common/nblist.md diff --git a/docs/refs/common/nblist.md b/docs/refs/common/nblist.md new file mode 100644 index 000000000..475594727 --- /dev/null +++ b/docs/refs/common/nblist.md @@ -0,0 +1,3 @@ +# Reference + +::: dmff.common.nblist.NeighborList \ No newline at end of file diff --git a/mkdocs.yml b/mkdocs.yml index 6833919a7..80743097a 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -4,19 +4,20 @@ nav: - User Guide: - 1. Introduction: user_guide/introduction.md - 2. Installation: user_guide/installation.md - - 3. Compute energy and force: user_guide/compute.md - - 4. Auto-diff: user_guide/auto_diff.md - - 5. Theory: user_guide/theory.md - - 6. couple with MD: user_guide/couple.md - - 7. Introduction to force field xml file: user_guide/xml_spec.md + - 3. Theory: user_guide/theory.md + - 4. Get Started: user_guide/tutorial.md + - 5. HOW TO - xml file: user_guide/xml_spec.md - Developer Guide: - 1. Introduction: dev_guide/introduction.md - 2. Architecture: dev_guide/arch.md - 3. Convention: dev_guide/convention.md + - 4. Write Docs: dev_guide/write_docs.md + - Modules: - ADMP: - Introduction: admp/readme.md + - Theory: user_guide/multipole_pme.md - Frontends: admp/frontend.md - About: about.md @@ -26,6 +27,9 @@ markdown_extensions: - pymdownx.arithmatex: generic: true +plugins: +- mkdocstrings: + extra_javascript: - javascripts/mathjax.js - https://polyfill.io/v3/polyfill.min.js?features=es6 From 9cd3b5307ff884ea72a07e2cf734b399d35dca79 Mon Sep 17 00:00:00 2001 From: Roy Kid Date: Sat, 11 Jun 2022 18:29:30 +0800 Subject: [PATCH 12/17] feat: auto gen docs refs --- docs/gen_ref_pages.py | 38 ++++++++++++++++++++++++++++++++++++++ docs/refs/common/nblist.md | 3 --- mkdocs.yml | 20 +++++++++++++------- 3 files changed, 51 insertions(+), 10 deletions(-) create mode 100644 docs/gen_ref_pages.py delete mode 100644 docs/refs/common/nblist.md diff --git a/docs/gen_ref_pages.py b/docs/gen_ref_pages.py new file mode 100644 index 000000000..77fbb0a3c --- /dev/null +++ b/docs/gen_ref_pages.py @@ -0,0 +1,38 @@ +"""Generate the code reference pages.""" + +from pathlib import Path + +import mkdocs_gen_files + +nav = mkdocs_gen_files.Nav() + +for path in sorted(Path("dmff").rglob("*.py")): # + + module_path = path.relative_to('dmff').with_suffix("") # + + doc_path = path.relative_to('dmff').with_suffix(".md") # + + full_doc_path = Path("refs", doc_path) # + + parts = list(module_path.parts) + + if parts[-1] == "__init__": # + continue + elif parts[-1] == "__main__": + continue + + nav[parts] = doc_path.as_posix() + print(full_doc_path) + with mkdocs_gen_files.open(full_doc_path, "w") as fd: # + + identifier = ".".join(parts) # + + print("::: dmff." + identifier, file=fd) # + + + mkdocs_gen_files.set_edit_path(full_doc_path, path) # + +with mkdocs_gen_files.open("refs/SUMMARY.md", "w") as nav_file: # + + nav_file.writelines(nav.build_literate_nav()) # + diff --git a/docs/refs/common/nblist.md b/docs/refs/common/nblist.md deleted file mode 100644 index 475594727..000000000 --- a/docs/refs/common/nblist.md +++ /dev/null @@ -1,3 +0,0 @@ -# Reference - -::: dmff.common.nblist.NeighborList \ No newline at end of file diff --git a/mkdocs.yml b/mkdocs.yml index 80743097a..86da979ea 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -1,6 +1,7 @@ site_name: DMFF nav: - - Home: index.md + - Home: + - Introduction: index.md - User Guide: - 1. Introduction: user_guide/introduction.md - 2. Installation: user_guide/installation.md @@ -14,12 +15,10 @@ nav: - 3. Convention: dev_guide/convention.md - 4. Write Docs: dev_guide/write_docs.md - - Modules: - - ADMP: - - Introduction: admp/readme.md - - Theory: user_guide/multipole_pme.md - - Frontends: admp/frontend.md - - About: about.md + - Modules: refs/ + + - About: + - Maintainer: about.md theme: readthedocs @@ -28,8 +27,15 @@ markdown_extensions: generic: true plugins: +- search +- gen-files: + scripts: + - docs/gen_ref_pages.py +- literate-nav: + nav_file: SUMMARY.md - mkdocstrings: + extra_javascript: - javascripts/mathjax.js - https://polyfill.io/v3/polyfill.min.js?features=es6 From f4fec8f00cc2be41a906f8f777e3e1398c3fd79a Mon Sep 17 00:00:00 2001 From: Roy Kid Date: Sat, 11 Jun 2022 18:46:21 +0800 Subject: [PATCH 13/17] fix: typo in requirements.txt --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index f3a674b4a..343447502 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ numpy>=1.18 jax>=0.3.7 jax-md>=0.1.28 -mkdocs=1.3.0 +mkdocs>=1.3.0 mkdocs-autorefs>=0.4.1 mkdocs-gen-files>=0.3.4 mkdocs-literate-nav>=0.4.1 From f37673356b3c2cd2bdb111c20a9dab840016d798 Mon Sep 17 00:00:00 2001 From: Roy Kid Date: Tue, 14 Jun 2022 17:04:49 +0800 Subject: [PATCH 14/17] feat: update docs configs --- dmff/admp/mbpol_intra.py | 7 +------ docs/assets/js/mathjax.js | 16 ++++++++++++++++ docs/dev_guide/profile.md | 2 -- docs/user_guide/xml_spec.md | 2 +- mkdocs.yml | 20 ++++++++++---------- 5 files changed, 28 insertions(+), 19 deletions(-) create mode 100644 docs/assets/js/mathjax.js delete mode 100644 docs/dev_guide/profile.md diff --git a/dmff/admp/mbpol_intra.py b/dmff/admp/mbpol_intra.py index 8452cf6e9..3a9966e13 100755 --- a/dmff/admp/mbpol_intra.py +++ b/dmff/admp/mbpol_intra.py @@ -1,11 +1,10 @@ -import sys + import numpy as np import jax.numpy as jnp import numpy as np from dmff.admp.spatial import v_pbc_shift from dmff.utils import jit_condition from jax import vmap -import time #const f5z = 0.999677885 @@ -484,7 +483,3 @@ def onebody_kernel(x1, x2, x3, Va, Vb, efac): e1 *= cm1_kcalmol e1 *= cal2joule # conver cal 2 j return e1 -<<<<<<< HEAD - -======= ->>>>>>> cicd diff --git a/docs/assets/js/mathjax.js b/docs/assets/js/mathjax.js new file mode 100644 index 000000000..e648674dd --- /dev/null +++ b/docs/assets/js/mathjax.js @@ -0,0 +1,16 @@ +window.MathJax = { + tex: { + inlineMath: [["\\(", "\\)"]], + displayMath: [["\\[", "\\]"]], + processEscapes: true, + processEnvironments: true + }, + options: { + ignoreHtmlClass: ".*|", + processHtmlClass: "arithmatex" + } + }; + + document$.subscribe(() => { + MathJax.typesetPromise() + }) \ No newline at end of file diff --git a/docs/dev_guide/profile.md b/docs/dev_guide/profile.md deleted file mode 100644 index 14a981b77..000000000 --- a/docs/dev_guide/profile.md +++ /dev/null @@ -1,2 +0,0 @@ -# How to profile - diff --git a/docs/user_guide/xml_spec.md b/docs/user_guide/xml_spec.md index 5304e3141..7386c6345 100644 --- a/docs/user_guide/xml_spec.md +++ b/docs/user_guide/xml_spec.md @@ -151,7 +151,7 @@ The `` node of the residue part defines all the atoms involved in the resi - <\displaylines{Bond atomName1="CA" atomName2="HA"/> + diff --git a/mkdocs.yml b/mkdocs.yml index fe232417a..8dee9e9eb 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -5,22 +5,22 @@ nav: - User Guide: - 1. Introduction: user_guide/introduction.md - 2. Installation: user_guide/installation.md - - 3. Compute energy and force: user_guide/compute.md - - 4. Auto-diff: user_guide/auto_diff.md - - 5. Theory: user_guide/theory.md - - 6. couple with MD: user_guide/couple.md - - 7. Introduction to force field xml file: user_guide/xml_spec.md - + - 3. Basic usage: user_guide/usage.md + - 4. XML format force field: user_guide/xml_spec.md + - 5. Theory: user_guide/theory.md + - Developer Guide: - 1. Introduction: dev_guide/introduction.md - - 2. Architecture: dev_guide/arch.md - - 3. Convention: dev_guide/convention.md + - 2. Software architecture: dev_guide/arch.md + - 3. Coding conventions: dev_guide/convention.md + - 4. Document Writing: dev_guide/write_docs.md + - Modules: - ADMP: - Introduction: admp/readme.md - Frontends: admp/frontend.md - - References: refs/ + - API: refs/ - About: - Maintainer: about.md @@ -42,6 +42,6 @@ plugins: extra_javascript: - - javascripts/mathjax.js + - assets/js/mathjax.js - https://polyfill.io/v3/polyfill.min.js?features=es6 - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js From 5f35e79605f8ff81f5005813a9ec681b3b7afbd6 Mon Sep 17 00:00:00 2001 From: Roy Kid Date: Tue, 14 Jun 2022 17:14:19 +0800 Subject: [PATCH 15/17] fix: fix test_nblist bug --- dmff/api.py | 1 + tests/test_common/test_nblist.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/dmff/api.py b/dmff/api.py index baa68bcb5..dfd7fa6de 100644 --- a/dmff/api.py +++ b/dmff/api.py @@ -11,6 +11,7 @@ import openmm as mm import openmm.app as app import openmm.app.element as elem +import openmm.unit as unit from dmff.admp.disp_pme import ADMPDispPmeForce from dmff.admp.multipole import convert_cart2harm, convert_harm2cart diff --git a/tests/test_common/test_nblist.py b/tests/test_common/test_nblist.py index 1ee86d030..d5f709103 100644 --- a/tests/test_common/test_nblist.py +++ b/tests/test_common/test_nblist.py @@ -37,17 +37,17 @@ def test_update(self, nblist): def test_pairs(self, nblist): pairs = nblist.pairs - assert pairs.shape == (18, 2) + assert pairs.shape == (15, 2) def test_pair_mask(self, nblist): pair, mask = nblist.pair_mask - assert mask.shape == (18, ) + assert mask.shape == (15, ) def test_dr(self, nblist): dr = nblist.dr - assert dr.shape == (18, 3) + assert dr.shape == (15, 3) def test_distance(self, nblist): From 2495943bb97fbdc7f4a1d50175106e9227209a00 Mon Sep 17 00:00:00 2001 From: Roy Kid Date: Tue, 14 Jun 2022 17:16:04 +0800 Subject: [PATCH 16/17] update: license in docs --- docs/about.md | 3 --- docs/license.md | 3 +++ mkdocs.yml | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) delete mode 100644 docs/about.md create mode 100644 docs/license.md diff --git a/docs/about.md b/docs/about.md deleted file mode 100644 index 835deeb6f..000000000 --- a/docs/about.md +++ /dev/null @@ -1,3 +0,0 @@ -# Lisense - -# Contributor \ No newline at end of file diff --git a/docs/license.md b/docs/license.md new file mode 100644 index 000000000..d2af63711 --- /dev/null +++ b/docs/license.md @@ -0,0 +1,3 @@ +# Lisense + +The project DeePMD-kit is licensed under [GNU LGPLv3.0](https://github.com/deepmodeling/deepmd-kit/blob/master/LICENSE). \ No newline at end of file diff --git a/mkdocs.yml b/mkdocs.yml index 8dee9e9eb..8a751b209 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -23,7 +23,7 @@ nav: - API: refs/ - About: - - Maintainer: about.md + - License: license.md theme: readthedocs From d3854c4780068b47ad8f3ce3df3d8c9c45d50afc Mon Sep 17 00:00:00 2001 From: Roy Kid Date: Tue, 14 Jun 2022 17:22:56 +0800 Subject: [PATCH 17/17] fix: import missing in api.py (may caused by formatter) --- dmff/admp/recip.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dmff/admp/recip.py b/dmff/admp/recip.py index 0867872ae..3987f23d8 100755 --- a/dmff/admp/recip.py +++ b/dmff/admp/recip.py @@ -1,4 +1,4 @@ - +import numpy as np import jax.numpy as jnp import jax.scipy as jsp from jax import jit