Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
125 commits
Select commit Hold shift + click to select a range
b9a0551
Merge branch 'devel' of https://github.com/deepmodeling/DMFF into devel
Ericwang6 Jun 11, 2022
caa27e3
Merge branch 'devel' of https://github.com/deepmodeling/DMFF into devel
Ericwang6 Jun 13, 2022
bcee6d6
feat(classical): offer tools to organize xml forcefield and match par…
WangXinyan940 Jun 25, 2022
713f31e
fix(tests): change API to be consistent with dmff
WangXinyan940 Jun 25, 2022
c46d141
feat(dmff): Add class Potential to keep OpenMM and DMFF potentials
WangXinyan940 Jun 25, 2022
084c19d
fix(api): auto-update parameter before writing new xml file.
WangXinyan940 Jun 27, 2022
77bed0e
refactor(api): Refactor DMFF API module to support multiple features
WangXinyan940 Jun 27, 2022
8b2b031
refactor: admp api.py using old xmlholder api
Jun 30, 2022
5e8f224
Merge branch 'api-refactor' of https://github.com/deepmodeling/DMFF i…
Jun 30, 2022
da25009
fix(fftree): make the name of API easier to be understand
WangXinyan940 Jun 30, 2022
a2b636e
fix(fftree): make the name of API easier to be understand
WangXinyan940 Jun 30, 2022
c1be361
update
Jul 2, 2022
e316c29
Merge branch 'api-refactor' of https://github.com/deepmodeling/DMFF i…
Jul 2, 2022
0a60819
update: admp related api
Jul 2, 2022
09e4b2d
fix: fix bug in admp api
Jul 4, 2022
e6346c9
doc: programming style convention about typing and numpy style docstring
Jul 4, 2022
28642bb
Merge pull request #48 from Roy-Kid/roy/xmlholder
WangXinyan940 Jul 4, 2022
0a0f70d
fix(tests): Add attribute to choose improper order
WangXinyan940 Jul 5, 2022
e75620e
fix(classical): remove double check of cutoff distance in potential f…
WangXinyan940 Jul 5, 2022
cfa4eac
Merge pull request #49 from WangXinyan940/wangxy/xmlholder
WangXinyan940 Jul 5, 2022
2956d89
fix(classical): avoid creating empty dict in impr parameters
WangXinyan940 Jul 6, 2022
4055f11
fix(inter): add 1e-12 as eps value in jnp.sqrt function
WangXinyan940 Jul 6, 2022
19d8583
Merge pull request #51 from WangXinyan940/wangxy/xmlholder
WangXinyan940 Jul 6, 2022
2c5595b
fix(inter): use more tiny eps for numerical consistent
WangXinyan940 Jul 6, 2022
fdbeeb2
Merge pull request #52 from WangXinyan940/wangxy/xmlholder
WangXinyan940 Jul 6, 2022
c417711
update: frontend docs
Jul 10, 2022
5b4d82e
fix typo: fix typo in dev_guide
Jul 11, 2022
c491266
Merge pull request #53 from Roy-Kid/roy/xmlholder
WangXinyan940 Jul 11, 2022
3a29cad
Temprary commit before test
KuangYu Aug 10, 2022
38bbe5e
Debugging new api for admp
KuangYu Aug 11, 2022
fc9ad93
peg system benchmark checked out
KuangYu Aug 11, 2022
d39df34
Update examples
KuangYu Aug 12, 2022
1d60ed3
Fix bug: was not reading the XZ component of multipole
KuangYu Aug 12, 2022
34ab999
remove jit for nbl.update
KuangYu Aug 12, 2022
ae0875c
Modify setup.py
KuangYu Aug 12, 2022
eeb1b69
Require jax version smaller than 0.3.16, which is incompatible with
KuangYu Aug 12, 2022
77ea11c
Update demo notebook in water example
KuangYu Aug 13, 2022
4b6626e
Merge pull request #56 from KuangYu/devel
KuangYu Aug 13, 2022
73d6984
Update documents and the notebook in classical FF example
KuangYu Aug 13, 2022
fcd321c
Merge branch 'devel' of github.com:deepmodeling/DMFF into devel
KuangYu Aug 13, 2022
9118aaa
Merge pull request #57 from KuangYu/devel
KuangYu Aug 13, 2022
0ce4406
Fix bug in nblist module
KuangYu Aug 23, 2022
4932660
change allocate
KuangYu Aug 23, 2022
8d5c30e
Merge pull request #58 from KuangYu/devel
KuangYu Aug 23, 2022
52d8830
fix: fix nblist jit-related bug
Aug 23, 2022
149d6ca
fix: test sequence in test_nblist.py
Sep 1, 2022
c749bab
Fix bug in Slater ISA short range interaction
KuangYu Sep 3, 2022
a675599
update: all tests pass
Sep 3, 2022
c2fac3b
clean up code
Sep 3, 2022
0e109db
Merge remote-tracking branch 'origin/feat-covmap' into feat-nbfix
WangXinyan940 Sep 3, 2022
c5ae843
Merge pull request #60 from Roy-Kid/feat-covmap
Ericwang6 Sep 4, 2022
5f26f78
feat: brutal Freud-based nblist
Sep 18, 2022
2e15a57
Bug fix in parsing internal xmls (#62)
Ericwang6 Sep 19, 2022
63ce8ae
settle down dependencies version
Sep 19, 2022
b7cb0f2
merge Roy FreudNeighborList with Yingzes
Sep 19, 2022
8d02b9a
rename `nmax` to `capacity_multiplier`(consistent with jax_md); remov…
Sep 19, 2022
0fe3c73
add freud-analysis require
Sep 19, 2022
1b019c2
Merge pull request #63 from Roy-Kid/feat-covmap
Roy-Kid Sep 19, 2022
399a524
remove overflow judgement so update in nblist can be jitted
KuangYu Sep 19, 2022
67615aa
:Merge branch 'devel' into feat-covmap
KuangYu Sep 19, 2022
39146f2
Adapt examples to new nblist api
KuangYu Sep 19, 2022
a4e27c0
Merge pull request #64 from KuangYu/feat-covmap
KuangYu Sep 19, 2022
6137072
add md_ipi in examples to run classical MD for bulk water
LanYang430 Sep 25, 2022
6dcf2ea
modified md_ipi
LanYang430 Sep 29, 2022
eefc4cc
modified md_ipi
LanYang430 Sep 29, 2022
fd799d7
new
LanYang430 Sep 29, 2022
fe9b16e
new
LanYang430 Sep 29, 2022
780568e
new_2
LanYang430 Sep 29, 2022
46b1c08
Merge pull request #65 from Humourist/feat-covmap
KuangYu Sep 29, 2022
d1fe0e5
Merge remote-tracking branch 'origin/feat-covmap' into feat-nbfix
WangXinyan940 Oct 12, 2022
26dc954
Merge pull request #59 from Roy-Kid/devel
Roy-Kid Oct 12, 2022
f17ae2b
Merge remote-tracking branch 'origin/feat-covmap' into feat-nbfix
WangXinyan940 Oct 12, 2022
9348487
Refactor: decompose api.py and generators to separate files
WangXinyan940 Oct 13, 2022
1ef38e2
Merge branch 'devel' into feat-nbfix
WangXinyan940 Oct 13, 2022
66562ab
feat(MBAR): Add differentiable MBAR impl
WangXinyan940 Oct 14, 2022
1b6473e
bug fix: record the name of matched template
WangXinyan940 Oct 19, 2022
13aa8cb
Add unit test for MBAR estimator
WangXinyan940 Oct 26, 2022
41bc501
feat: Estimate free energy of an extra state
WangXinyan940 Oct 26, 2022
c30ca0d
set specific version no. for jax in installation guide
WangXinyan940 Nov 2, 2022
e597d6b
Update MBAR Estimator API
WangXinyan940 Nov 4, 2022
97583c5
Update unit test for the latest MBAR API
WangXinyan940 Nov 4, 2022
e7f4ed7
Update github workflow to support new unit test.
WangXinyan940 Nov 4, 2022
bcd1390
Update requirement of pymbar
WangXinyan940 Nov 4, 2022
15dbdbd
fix package including problem
WangXinyan940 Nov 4, 2022
20f0d3f
Change cell vector to prevent numerical problem (ceil(12.0 / 1.0) wou…
WangXinyan940 Nov 4, 2022
dd50e16
Update settings.py
WangXinyan940 Nov 4, 2022
7609886
remove unused imports
Ericwang6 Nov 5, 2022
a1f5fbe
fix "==" in requirements
Ericwang6 Nov 5, 2022
55452aa
update (gitignore): hmtff cache
Ericwang6 Nov 6, 2022
75ecc84
update (api): docstring in createPotential
Ericwang6 Nov 6, 2022
fa5ca20
Update optax transforms for force field parameters
WangXinyan940 Nov 7, 2022
d5a1aea
Update __init__.py
WangXinyan940 Nov 7, 2022
5d77215
Merge branch 'feat-nbfix' into feat-optimizer
WangXinyan940 Nov 7, 2022
1bf8710
Update LJ jax force API in LennardJonesForce generator
WangXinyan940 Nov 7, 2022
932d4ef
Merge branch 'feat-nbfix' into feat-optimizer
WangXinyan940 Nov 7, 2022
b75682e
Let energy function in TargetState use the whole trajectory instead o…
WangXinyan940 Nov 7, 2022
585ba81
Update MBAR UT to fit API change
WangXinyan940 Nov 7, 2022
adcfcc2
Merge branch 'feat-nbfix' into feat-optimizer
WangXinyan940 Nov 7, 2022
fab194f
Remove the using of numpy to make free energy & weight differentiable
WangXinyan940 Nov 7, 2022
8358fc7
Merge branch 'feat-nbfix' into feat-optimizer
WangXinyan940 Nov 7, 2022
d1e6d79
Update __init__.py
WangXinyan940 Nov 7, 2022
6ca6664
Set precision again in dmff.mbar after importing pymbar
WangXinyan940 Nov 7, 2022
9e0bc28
Merge branch 'feat-nbfix' into feat-optimizer
WangXinyan940 Nov 7, 2022
e4e6b7a
Increase MBAR numerical stability
WangXinyan940 Nov 7, 2022
6eedcfe
Merge branch 'feat-nbfix' into feat-optimizer
WangXinyan940 Nov 7, 2022
b2eb539
Update an example of MBAR-based optimization
WangXinyan940 Nov 7, 2022
5c45204
Correct the typo in document
WangXinyan940 Nov 9, 2022
e0cc8fc
Add openmm sample state.
WangXinyan940 Nov 14, 2022
2e7adaa
Update genOptimizer to support multiple optimizers
WangXinyan940 Nov 14, 2022
f3e220b
Set default pressure to be 0 (NVT)
WangXinyan940 Nov 14, 2022
8395f68
Merge pull request #66 from deepmodeling/feat-nbfix
WangXinyan940 Nov 14, 2022
5d2aa5f
Add Gitee_mirror (#67)
AnguseZhang Nov 15, 2022
3bc59fd
Fix Mirror CI/CD (#68)
AnguseZhang Nov 15, 2022
6193968
Add tutorial_utils for demo usage (#69)
WangXinyan940 Nov 15, 2022
533981f
Hook jax force to generator for intra potentials
WangXinyan940 Nov 15, 2022
13df1e0
Remove the requirement of mdtraj
WangXinyan940 Nov 15, 2022
9fcd2fc
Let NeighborList.update return (N, 3) pairs with colv_map information
WangXinyan940 Nov 22, 2022
d0cd7a3
Update nblist.py
WangXinyan940 Nov 22, 2022
92b35ef
bugfix: correctly recognize if the LJForce card uses type or class
WangXinyan940 Nov 22, 2022
a48cd26
Merge pull request #70 from deepmodeling/bugfix-nbfix
Roy-Kid Nov 22, 2022
8acdaf8
SMIRKS-based typing scheme (#72)
Ericwang6 Nov 29, 2022
ee7ae2a
Hotfix in BCC parametrization (#73)
Ericwang6 Nov 29, 2022
f13a90a
Merge branch 'devel' of https://github.com/deepmodeling/DMFF into devel
Ericwang6 Nov 30, 2022
a55874c
hotfix: incorrect bcc matching
Ericwang6 Nov 30, 2022
46d9ef1
Save covalent map to Potential object & make energy function generato…
WangXinyan940 Dec 1, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions .github/workflows/mirror_gitee.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
name: Mirror to Gitee Repo

on: [ push, delete, create ]

# Ensures that only one mirror task will run at a time.
concurrency:
group: git-mirror

jobs:
git-mirror:
if: github.repository_owner == 'deepmodeling'
runs-on: ubuntu-latest
steps:
- uses: wearerequired/git-mirror-action@v1
env:
SSH_PRIVATE_KEY: ${{ secrets.SYNC_GITEE_PRIVATE_KEY }}
with:
source-repo: "https://github.com/deepmodeling/dmff.git"
destination-repo: "git@gitee.com:deepmodeling/DMFF.git"
3 changes: 2 additions & 1 deletion .github/workflows/ut.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ jobs:
$CONDA/bin/conda update -n base -c defaults conda
conda install pip
conda update pip
conda install numpy openmm pytest -c conda-forge
conda install numpy openmm pytest rdkit biopandas openbabel -c conda-forge
pip install jax jax_md
pip install mdtraj==1.9.7 pymbar==4.0.1
- name: Install DMFF
run: |
source $CONDA/bin/activate
Expand Down
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -781,4 +781,7 @@ FodyWeavers.xsd
*.acpype/

*/_date.py
*/_version.py
*/_version.py

# hmtff cache
*.hmtff/
26 changes: 24 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
# DMFF

[![doi:10.26434/chemrxiv-2022-2c7gv](https://img.shields.io/badge/DOI-10.26434%2Fchemrxiv--2022--2c7gv-blue)](https://doi.org/10.26434/chemrxiv-2022-2c7gv)

## About DMFF

**DMFF** (**D**ifferentiable **M**olecular **F**orce **F**ield) is a Jax-based python package that provides a full differentiable implementation of molecular force field models. This project aims to establish an extensible codebase to minimize the efforts in force field parameterization, and to ease the force and virial tensor evaluations for advanced complicated potentials (e.g., polarizable models with geometry-dependent atomic parameters). Currently, this project mainly focuses on the molecular systems such as: water, biological macromolecules (peptides, proteins, nucleic acids), organic polymers, and small organic molecules (organic electrolyte, drug-like molecules) etc. We support both the conventional point charge models (OPLS and AMBER like) and multipolar polarizable models (AMOEBA and MPID like). The entire project is backed by the XLA technique in JAX, thus can be "jitted" and run in GPU devices much more efficiently compared to normal python codes.

The behavior of organic molecular systems (e.g., protein folding, polymer structure, etc.) is often determined by a complex effect of many different types of interactions. The existing organic molecular force fields are mainly empirically fitted and their performance relies heavily on error cancellation. Therefore, the transferability and the prediction power of these force fields are insufficient. For new molecules, the parameter fitting process requires essential manual intervention and can be quite cumbersome. In order to automate the parametrization process and increase the robustness of the model, it is necessary to apply modern AI techniques in conventional force field development. This project serves for this purpose by utilizing the automatic differentiable programming technique to develop a codebase, which allows a more convenient incorporation of modern AI optimization techniques. It also helps the realization of many exciting functions including (but not limited to): hybrid machine learning/force field models and parameter optimization based on trajectory.

### License and credits

The project DMFF is licensed under [GNU LGPL v3.0](LICENSE). If you use this code in any future publications, please cite this using `Wang X, Li J, Yang L, Chen F, Wang Y, Chang J, et al. DMFF: An Open-Source Automatic
Differentiable Platform for Molecular Force Field
Development and Molecular Dynamics
Simulation. ChemRxiv. Cambridge: Cambridge Open Engage; 2022; This content is a preprint and has not been peer-reviewed.`

## User Guide

+ [1. Introduction](docs/user_guide/introduction.md)
Expand All @@ -18,9 +29,20 @@ The behavior of organic molecular systems (e.g., protein folding, polymer struct
+ [3. Coding conventions](docs/dev_guide/convention.md)
+ [4. Document writing](docs/dev_guide/write_docs.md)

## Modules
+ [1. ADMP](docs/modules/admp.md)
## Code Structure

The code is organized as follows:

+ `examples`: demos presented in Jupyter Notebook.
+ `docs`: documentation.
+ `package`: files for constructing packages or images, such as conda recipe and docker files.
+ `tests`: unit tests.
+ `dmff`: DMFF python codes
+ `dmff/admp`: source code of automatic differentiable multipolar polarizable (ADMP) force field module.
+ `dmff/classical`: source code of classical force field module.
+ `dmff/common`: source code of common functions, such as neighbor list.
+ `dmff/generators`: source code of force generators.
+ `dmff/sgnn`: source of subgragh neural network force field model.

## Support and Contribution

Expand Down
5 changes: 3 additions & 2 deletions dmff/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .settings import *
from .common.nblist import NeighborList
from .api import Hamiltonian
from .common.nblist import NeighborList, NeighborListFreud
from .api import Hamiltonian
from .generators import *
22 changes: 11 additions & 11 deletions dmff/admp/disp_pme.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ class ADMPDispPmeForce:
The so called "environment paramters" means parameters that do not need to be differentiable
'''

def __init__(self, box, covalent_map, rc, ethresh, pmax, lpme=True):
self.covalent_map = covalent_map
def __init__(self, box, rc, ethresh, pmax, lpme=True):

self.rc = rc
self.ethresh = ethresh
self.pmax = pmax
Expand All @@ -44,7 +44,7 @@ def __init__(self, box, covalent_map, rc, ethresh, pmax, lpme=True):
def generate_get_energy(self):
def get_energy(positions, box, pairs, c_list, mScales):
return energy_disp_pme(positions, box, pairs,
c_list, mScales, self.covalent_map,
c_list, mScales,
self.kappa, self.K1, self.K2, self.K3, self.pmax,
self.d6_recip, self.d8_recip, self.d10_recip, lpme=self.lpme)
return get_energy
Expand Down Expand Up @@ -78,7 +78,7 @@ def refresh_calculators(self):


def energy_disp_pme(positions, box, pairs,
c_list, mScales, covalent_map,
c_list, mScales,
kappa, K1, K2, K3, pmax,
recip_fn6, recip_fn8, recip_fn10, lpme=True):
'''
Expand All @@ -90,7 +90,7 @@ def energy_disp_pme(positions, box, pairs,
box:
3 * 3: box, axes arranged in row
pairs:
Np * 2: interacting pair indices
Np * 3: interacting pair indices and topology distance
c_list:
Na * (pmax-4)/2: atomic dispersion coefficients
mScales:
Expand All @@ -115,7 +115,7 @@ def energy_disp_pme(positions, box, pairs,
if lpme is False:
kappa = 0

ene_real = disp_pme_real(positions, box, pairs, c_list, mScales, covalent_map, kappa, pmax)
ene_real = disp_pme_real(positions, box, pairs, c_list, mScales, kappa, pmax)

if lpme:
ene_recip = recip_fn6(positions, box, c_list[:, 0, jnp.newaxis])
Expand All @@ -132,7 +132,7 @@ def energy_disp_pme(positions, box, pairs,

def disp_pme_real(positions, box, pairs,
c_list,
mScales, covalent_map,
mScales,
kappa, pmax):
'''
This function calculates the dispersion real space energy
Expand All @@ -144,7 +144,7 @@ def disp_pme_real(positions, box, pairs,
box:
3 * 3: box, axes arranged in row
pairs:
Np * 2: interacting pair indices
Np * 3: interacting pair indices and topology distance
c_list:
Na * (pmax-4)/2: atomic dispersion coefficients
mScales:
Expand All @@ -162,16 +162,16 @@ def disp_pme_real(positions, box, pairs,

# expand pairwise parameters
# pairs = pairs[pairs[:, 0] < pairs[:, 1]]
pairs = regularize_pairs(pairs)
pairs = pairs.at[:, :2].set(regularize_pairs(pairs[:, :2]))

box_inv = jnp.linalg.inv(box)

ri = distribute_v3(positions, pairs[:, 0])
rj = distribute_v3(positions, pairs[:, 1])
nbonds = covalent_map[pairs[:, 0], pairs[:, 1]]
nbonds = pairs[:, 2]
mscales = distribute_scalar(mScales, nbonds-1)

buffer_scales = pair_buffer_scales(pairs)
buffer_scales = pair_buffer_scales(pairs[:, :2])
mscales = mscales * buffer_scales

ci = distribute_dispcoeff(c_list, pairs[:, 0])
Expand Down
13 changes: 5 additions & 8 deletions dmff/admp/pairwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def distribute_dispcoeff(c_list, index):
def distribute_matrix(multipoles,index1,index2):
return multipoles[index1,index2]

def generate_pairwise_interaction(pair_int_kernel, covalent_map, static_args):
def generate_pairwise_interaction(pair_int_kernel, static_args):
'''
This is a calculator generator for pairwise interaction

Expand All @@ -53,9 +53,6 @@ def generate_pairwise_interaction(pair_int_kernel, covalent_map, static_args):
function type (dr, m, p1i, p1j, p2i, p2j) -> energy : the vectorized kernel function,
dr is the distance, m is the topological scaling factor, p1i, p1j, p2i, p2j are pairwise parameters

covalent_map:
Na * Na, int: the covalent_map matrix that marks the topological distances between atoms

static_args:
dict: a dictionary that stores all static global parameters (such as lmax, kappa, etc)

Expand All @@ -67,13 +64,14 @@ def generate_pairwise_interaction(pair_int_kernel, covalent_map, static_args):
'''

def pair_int(positions, box, pairs, mScales, *atomic_params):
pairs = regularize_pairs(pairs)
# pairs = regularize_pairs(pairs)
pairs = pairs.at[:, :2].set(regularize_pairs(pairs[:, :2]))

ri = distribute_v3(positions, pairs[:, 0])
rj = distribute_v3(positions, pairs[:, 1])
# ri = positions[pairs[:, 0]]
# rj = positions[pairs[:, 1]]
nbonds = covalent_map[pairs[:, 0], pairs[:, 1]]
nbonds = pairs[:, 3]
mscales = distribute_scalar(mScales, nbonds-1)

buffer_scales = pair_buffer_scales(pairs)
Expand Down Expand Up @@ -114,8 +112,7 @@ def TT_damping_qq_c6_kernel(dr, m, ai, aj, bi, bj, qi, qj, ci, cj):
exp_br = jnp.exp(-br)
f = 2625.5 * a * exp_br \
+ (-2625.5) * exp_br * (1+br) * q / r \
+ exp_br*(1+br+br2/2+br3/6+br4/24+br5/120+br6/720) * c / dr**6

+ exp_br*(1+br+br2/2+br3/6+br4/24+br5/120+br6/720) * c / dr**6
return f * m


Expand Down
32 changes: 14 additions & 18 deletions dmff/admp/pme.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class ADMPPmeForce:
The so called "environment paramters" means parameters that do not need to be differentiable
'''

def __init__(self, box, axis_type, axis_indices, covalent_map, rc, ethresh, lmax, lpol=False, lpme=True, steps_pol=None):
def __init__(self, box, axis_type, axis_indices, rc, ethresh, lmax, lpol=False, lpme=True, steps_pol=None):
'''
Initialize the ADMPPmeForce calculator.

Expand All @@ -51,8 +51,6 @@ def __init__(self, box, axis_type, axis_indices, covalent_map, rc, ethresh, lmax
(3, 3) float, box size in row
axis_type:
(na,) int, types of local axis (bisector, z-then-x etc.)
covalent_map:
(na, na) int, covalent map matrix, labels the topological distances between atoms
rc:
float: cutoff distance
ethresh:
Expand Down Expand Up @@ -91,10 +89,10 @@ def __init__(self, box, axis_type, axis_indices, covalent_map, rc, ethresh, lmax
self.K2 = K2
self.K3 = K3
self.pme_order = 6
self.covalent_map = covalent_map
self.lpol = lpol
self.steps_pol = steps_pol
self.n_atoms = int(covalent_map.shape[0]) # len(axis_type)
# self.n_atoms = int(covalent_map.shape[0]) # len(axis_type)
self.n_atoms = len(axis_type)

# setup calculators
self.refresh_calculators()
Expand All @@ -107,7 +105,7 @@ def generate_get_energy(self):
def get_energy(positions, box, pairs, Q_local, mScales):
return energy_pme(positions, box, pairs,
Q_local, None, None, None,
mScales, None, None, self.covalent_map,
mScales, None, None,
self.construct_local_frames, self.pme_recip,
self.kappa, self.K1, self.K2, self.K3, self.lmax, False, lpme=self.lpme)
return get_energy
Expand All @@ -116,7 +114,7 @@ def get_energy(positions, box, pairs, Q_local, mScales):
def energy_fn(positions, box, pairs, Q_local, Uind_global, pol, tholes, mScales, pScales, dScales):
return energy_pme(positions, box, pairs,
Q_local, Uind_global, pol, tholes,
mScales, pScales, dScales, self.covalent_map,
mScales, pScales, dScales,
self.construct_local_frames, self.pme_recip,
self.kappa, self.K1, self.K2, self.K3, self.lmax, True, lpme=self.lpme)
self.energy_fn = energy_fn
Expand Down Expand Up @@ -284,7 +282,7 @@ def setup_ewald_parameters(
# @jit_condition(static_argnums=())
def energy_pme(positions, box, pairs,
Q_local, Uind_global, pol, tholes,
mScales, pScales, dScales, covalent_map,
mScales, pScales, dScales,
construct_local_frame_fn, pme_recip_fn, kappa, K1, K2, K3, lmax, lpol, lpme=True):
'''
This is the top-level wrapper for multipole PME
Expand All @@ -306,7 +304,7 @@ def energy_pme(positions, box, pairs,
(Nexcl,): multipole-multipole interaction exclusion scalings: 1-2, 1-3 ...
for permanent-permanent, permanent-induced, induced-induced interactions
pairs:
Np * 2: interacting pair indices
Np * 3: interacting pair indices and topology distance
covalent_map:
Na * Na: topological distances between atoms, if i, j are topologically distant, then covalent_map[i, j] == 0
construct_local_frame_fn:
Expand Down Expand Up @@ -353,26 +351,24 @@ def energy_pme(positions, box, pairs,

if lpol:
ene_real = pme_real(positions, box, pairs, Q_global, U_ind, pol, tholes,
mScales, pScales, dScales, covalent_map, kappa, lmax, True)
mScales, pScales, dScales, kappa, lmax, True)
else:
ene_real = pme_real(positions, box, pairs, Q_global, None, None, None,
mScales, None, None, covalent_map, kappa, lmax, False)
mScales, None, None, kappa, lmax, False)

if lpme:
ene_recip = pme_recip_fn(positions, box, Q_global_tot)
ene_self = pme_self(Q_global_tot, kappa, lmax)

if lpol:
ene_self += pol_penalty(U_ind, pol)

return ene_real + ene_recip + ene_self

else:
if lpol:
ene_self = pol_penalty(U_ind, pol)
else:
ene_self = 0.0

return ene_real + ene_self


Expand Down Expand Up @@ -747,7 +743,7 @@ def pme_real_kernel(dr, qiQI, qiQJ, qiUindI, qiUindJ, thole1, thole2, dmp, mscal
# @jit_condition(static_argnums=(7))
def pme_real(positions, box, pairs,
Q_global, Uind_global, pol, tholes,
mScales, pScales, dScales, covalent_map,
mScales, pScales, dScales,
kappa, lmax, lpol):
'''
This is the real space PME calculate function
Expand All @@ -763,7 +759,7 @@ def pme_real(positions, box, pairs,
box:
3 * 3: box, axes arranged in row
pairs:
Np * 2: interacting pair indices
Np * 3: interacting pair indices and topology distance
Q_global:
Na * (l+1)**2: harmonics multipoles of each atom, in global frame
Uind_global:
Expand All @@ -786,14 +782,14 @@ def pme_real(positions, box, pairs,
Output:
ene: pme realspace energy
'''
pairs = regularize_pairs(pairs)
buffer_scales = pair_buffer_scales(pairs)
pairs = pairs.at[:, :2].set(regularize_pairs(pairs[:, :2]))
buffer_scales = pair_buffer_scales(pairs[:, :2])
box_inv = jnp.linalg.inv(box)
r1 = distribute_v3(positions, pairs[:, 0])
r2 = distribute_v3(positions, pairs[:, 1])
Q_extendi = distribute_multipoles(Q_global, pairs[:, 0])
Q_extendj = distribute_multipoles(Q_global, pairs[:, 1])
nbonds = distribute_matrix(covalent_map,pairs[:, 0],pairs[:, 1])
nbonds = pairs[:, 2]
#nbonds = covalent_map[pairs[:, 0], pairs[:, 1]]
indices = nbonds-1
mscales = distribute_scalar(mScales, indices)
Expand Down
Loading