diff --git a/dmff/api.py b/dmff/api.py index dc72e49ca..7c2e11a7d 100644 --- a/dmff/api.py +++ b/dmff/api.py @@ -200,6 +200,132 @@ def getJaxPotential(self): # register all parsers # app.forcefield.parsers["ADMPDispForce"] = ADMPDispGenerator.parseElement +class ADMPDispGenerator: + def __init__(self, ff): + + self.name = "ADMPDispForce" + self.ff = ff + self.fftree = ff.fftree + self.paramtree = ff.paramtree + + # default params + self._jaxPotential = None + self.types = [] + self.ethresh = 5e-4 + self.pmax = 10 + + def extract(self): + + mScales = [self.fftree.get_attribs(f'{self.name}', f'mScale1{i}')[0] for i in range(2, 7)] + mScales.append(1.0) + self.paramtree[self.name] = {} + self.paramtree[self.name]['mScales'] = jnp.array(mScales) + + ABQC = self.fftree.get_attribs(f'{self.name}/Atom', ['A', 'B', 'Q', 'C6', 'C8', 'C10']) + + ABQC = np.array(ABQC) + A = ABQC[:, 0] + B = ABQC[:, 1] + Q = ABQC[:, 2] + C6 = ABQC[:, 3] + C8 = ABQC[:, 4] + C10 = ABQC[:, 5] + + self.paramtree[self.name]['A'] = jnp.array(A) + self.paramtree[self.name]['B'] = jnp.array(B) + self.paramtree[self.name]['Q'] = jnp.array(Q) + self.paramtree[self.name]['C6'] = jnp.array(C6) + self.paramtree[self.name]['C8'] = jnp.array(C8) + self.paramtree[self.name]['C10'] = jnp.array(C10) + + atomTypes = self.fftree.get_attribs(f'{self.name}/Atom', f'type') + self.atomTypes = np.array(atomTypes, dtype=int).astype(str) + + def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, + args): + + methodMap = { + app.CutoffPeriodic: "CutoffPeriodic", + app.NoCutoff: "NoCutoff", + app.PME: "PME", + } + if nonbondedMethod not in methodMap: + raise ValueError("Illegal nonbonded method for ADMPDispForce") + if nonbondedMethod is app.CutoffPeriodic: + self.lpme = False + else: + self.lpme = True + + n_atoms = len(data.atoms) + # build index map + map_atomtype = np.zeros(n_atoms, dtype=int) + for i in range(n_atoms): + atype = data.atomType[data.atoms[i]] + map_atomtype[i] = np.where(self.atomTypes == atype)[0][0] + self.map_atomtype = map_atomtype + # build covalent map + covalent_map = build_covalent_map(data, 6) + # here box is only used to setup ewald parameters, no need to be differentiable + a, b, c = system.getDefaultPeriodicBoxVectors() + box = jnp.array([a._value, b._value, c._value]) * 10 + # get the admp calculator + rc = nonbondedCutoff.value_in_unit(unit.angstrom) + + # get calculator + if "ethresh" in args: + self.ethresh = args["ethresh"] + + Force_DispPME = ADMPDispPmeForce(box, + covalent_map, + rc, + self.ethresh, + self.pmax, + lpme=self.lpme) + self.disp_pme_force = Force_DispPME + pot_fn_lr = Force_DispPME.get_energy + pot_fn_sr = generate_pairwise_interaction(TT_damping_qq_c6_kernel, + covalent_map, + static_args={}) + + def potential_fn(positions, box, pairs, params): + mScales = params["mScales"] + a_list = (params["A"][map_atomtype] / 2625.5 + ) # kj/mol to au, as expected by TT_damping kernel + b_list = params["B"][map_atomtype] * 0.0529177249 # nm^-1 to au + q_list = params["Q"][map_atomtype] + c6_list = jnp.sqrt(params["C6"][map_atomtype] * 1e6) + c8_list = jnp.sqrt(params["C8"][map_atomtype] * 1e8) + c10_list = jnp.sqrt(params["C10"][map_atomtype] * 1e10) + c_list = jnp.vstack((c6_list, c8_list, c10_list)) + + E_sr = pot_fn_sr(positions, box, pairs, mScales, a_list, b_list, + q_list, c_list[0]) + E_lr = pot_fn_lr(positions, box, pairs, c_list.T, mScales) + return E_sr - E_lr + + self._jaxPotential = potential_fn + # self._top_data = data + + def overwrite(self): + + self.fftree.set_attrib(f'{self.name}', 'mScale12', [self.paramtree[self.name]['mScales'][0]]) + self.fftree.set_attrib(f'{self.name}', 'mScale13', [self.paramtree[self.name]['mScales'][1]]) + self.fftree.set_attrib(f'{self.name}', 'mScale14', [self.paramtree[self.name]['mScales'][2]]) + self.fftree.set_attrib(f'{self.name}', 'mScale15', [self.paramtree[self.name]['mScales'][3]]) + self.fftree.set_attrib(f'{self.name}', 'mScale16', [self.paramtree[self.name]['mScales'][4]]) + + self.fftree.set_attrib(f'{self.name}/Atom', 'A', [self.paramtree[self.name]['A']]) + self.fftree.set_attrib(f'{self.name}/Atom', 'B', [self.paramtree[self.name]['B']]) + self.fftree.set_attrib(f'{self.name}/Atom', 'Q', [self.paramtree[self.name]['Q']]) + self.fftree.set_attrib(f'{self.name}/Atom', 'C6', [self.paramtree[self.name]['C6']]) + self.fftree.set_attrib(f'{self.name}/Atom', 'C8', [self.paramtree[self.name]['C8']]) + self.fftree.set_attrib(f'{self.name}/Atom', 'C10', [self.paramtree[self.name]['C10']]) + + + def getJaxPotential(self): + return self._jaxPotential + +jaxGenerators['ADMPDispForce'] = ADMPDispGenerator class ADMPDispPmeGenerator: r""" @@ -299,6 +425,114 @@ def getJaxPotential(self): # register all parsers # app.forcefield.parsers["ADMPDispPmeForce"] = ADMPDispPmeGenerator.parseElement +class ADMPDispPmeGenerator: + r""" + This one computes the undamped C6/C8/C10 interactions + u = \sum_{ij} c6/r^6 + c8/r^8 + c10/r^10 + """ + def __init__(self, ff): + self.ff = ff + self.fftree = ff.fftree + self.paramtree = ff.paramtree + + self.params = {"C6": [], "C8": [], "C10": []} + self._jaxPotential = None + self.atomTypes = None + self.ethresh = 5e-4 + self.pmax = 10 + self.name = "ADMPDispPmeForce" + + def extract(self): + + mScales = [self.fftree.get_attribs(f'{self.name}', f'mScale1{i}')[0] for i in range(2, 7)] + mScales.append(1.0) + + self.paramtree[self.name] = {} + self.paramtree[self.name]['mScales'] = jnp.array(mScales) + + C6 = self.fftree.get_attribs(f'{self.name}/Atom', f'C6') + C8 = self.fftree.get_attribs(f'{self.name}/Atom', f'C8') + C10 = self.fftree.get_attribs(f'{self.name}/Atom', f'C10') + + self.paramtree[self.name]['C6'] = jnp.array(C6) + self.paramtree[self.name]['C8'] = jnp.array(C8) + self.paramtree[self.name]['C10'] = jnp.array(C10) + + atomTypes = self.fftree.get_attribs(f'{self.name}/Atom', f'type') + self.atomTypes = np.array(atomTypes, dtype=int).astype(str) + + def overwrite(self): + + self.fftree.set_attrib(f'{self.name}', 'mScale12', [self.paramtree[self.name]['mScales'][0]]) + self.fftree.set_attrib(f'{self.name}', 'mScale13', [self.paramtree[self.name]['mScales'][1]]) + self.fftree.set_attrib(f'{self.name}', 'mScale14', [self.paramtree[self.name]['mScales'][2]]) + self.fftree.set_attrib(f'{self.name}', 'mScale15', [self.paramtree[self.name]['mScales'][3]]) + self.fftree.set_attrib(f'{self.name}', 'mScale16', [self.paramtree[self.name]['mScales'][4]]) + + self.fftree.set_attrib(f'{self.name}/Atom', 'C6', self.paramtree[self.name]['C6']) + self.fftree.set_attrib(f'{self.name}/Atom', 'C8', self.paramtree[self.name]['C8']) + self.fftree.set_attrib(f'{self.name}/Atom', 'C10', self.paramtree[self.name]['C10']) + + + def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, + args): + methodMap = { + app.CutoffPeriodic: "CutoffPeriodic", + app.NoCutoff: "NoCutoff", + app.PME: "PME", + } + if nonbondedMethod not in methodMap: + raise ValueError("Illegal nonbonded method for ADMPDispPmeForce") + if nonbondedMethod is app.CutoffPeriodic: + self.lpme = False + else: + self.lpme = True + + n_atoms = len(data.atoms) + # build index map + map_atomtype = np.zeros(n_atoms, dtype=int) + for i in range(n_atoms): + atype = data.atomType[data.atoms[i]] + map_atomtype[i] = np.where(self.atomTypes == atype)[0][0] + self.map_atomtype = map_atomtype + # build covalent map + covalent_map = build_covalent_map(data, 6) + + # here box is only used to setup ewald parameters, no need to be differentiable + a, b, c = system.getDefaultPeriodicBoxVectors() + box = jnp.array([a._value, b._value, c._value]) * 10 + # get the admp calculator + rc = nonbondedCutoff.value_in_unit(unit.angstrom) + + # get calculator + if "ethresh" in args: + self.ethresh = args["ethresh"] + + disp_force = ADMPDispPmeForce(box, covalent_map, rc, self.ethresh, + self.pmax, self.lpme) + self.disp_force = disp_force + pot_fn_lr = disp_force.get_energy + + def potential_fn(positions, box, pairs, params): + mScales = params["mScales"] + C6_list = params["C6"][map_atomtype] * 1e6 # to kj/mol * A**6 + C8_list = params["C8"][map_atomtype] * 1e8 + C10_list = params["C10"][map_atomtype] * 1e10 + c6_list = jnp.sqrt(C6_list) + c8_list = jnp.sqrt(C8_list) + c10_list = jnp.sqrt(C10_list) + c_list = jnp.vstack((c6_list, c8_list, c10_list)) + E_lr = pot_fn_lr(positions, box, pairs, c_list.T, mScales) + return -E_lr + + self._jaxPotential = potential_fn + # self._top_data = data + + def getJaxPotential(self): + return self._jaxPotential + +jaxGenerators['ADMPDispPmeForce'] = ADMPDispPmeGenerator + class QqTtDampingGenerator: r""" @@ -1025,9 +1259,454 @@ def potential_fn(positions, box, pairs, params): def getJaxPotential(self): return self._jaxPotential +class ADMPPmeGenerator: + + def __init__(self, ff): -# app.forcefield.parsers["ADMPPmeForce"] = ADMPPmeGenerator.parseElement + self.name = 'ADMPPmeForce' + self.ff = ff + self.fftree = ff.fftree + self.paramtree = ff.paramtree + + # default params + self._jaxPotential = None + self.types = [] + self.ethresh = 5e-4 + self.step_pol = None + self.lpol = False + self.ref_dip = "" + + def extract(self): + + self.lmax = self.fftree.get_attribs(f'{self.name}', 'lmax')[0] # return [lmax] + + mScales = [self.fftree.get_attribs(f'{self.name}', f'mScale1{i}')[0] for i in range(2, 7)] + pScales = [self.fftree.get_attribs(f'{self.name}', f'pScale1{i}')[0] for i in range(2, 7)] + dScales = [self.fftree.get_attribs(f'{self.name}', f'dScale1{i}')[0] for i in range(2, 7)] + + # make sure the last digit is 1.0 + mScales.append(1.0) + pScales.append(1.0) + dScales.append(1.0) + + self.paramtree[self.name] = {} + self.paramtree[self.name]['mScales'] = jnp.array(mScales) + self.paramtree[self.name]['pScales'] = jnp.array(pScales) + self.paramtree[self.name]['dScales'] = jnp.array(dScales) + + # check if polarize + polarize = self.fftree.get_nodes(f'{self.name}/Polarize') + if polarize: + self.lpol = True + else: + self.lpol = False + + atomTypes = self.fftree.get_attribs(f'{self.name}/Atom', 'type') + self.atomTypes = np.array(atomTypes, dtype=int).astype(str) + kx = self.fftree.get_attribs(f'{self.name}/Atom', 'kx') + ky = self.fftree.get_attribs(f'{self.name}/Atom', 'ky') + kz = self.fftree.get_attribs(f'{self.name}/Atom', 'kz') + + kx = [ 0 if kx_ is None else int(kx_) for kx_ in kx ] + ky = [ 0 if ky_ is None else int(ky_) for ky_ in ky ] + kz = [ 0 if kz_ is None else int(kz_) for kz_ in kz ] + + # invoke by `self.kStrings["kz"][itype]` + self.kStrings = {} + self.kStrings['kx'] = kx + self.kStrings['ky'] = ky + self.kStrings['kz'] = kz + + c0 = self.fftree.get_attribs(f'{self.name}/Atom', 'c0') + dX = self.fftree.get_attribs(f'{self.name}/Atom', 'dX') + dY = self.fftree.get_attribs(f'{self.name}/Atom', 'dY') + dZ = self.fftree.get_attribs(f'{self.name}/Atom', 'dZ') + qXX = self.fftree.get_attribs(f'{self.name}/Atom', 'qXX') + qYY = self.fftree.get_attribs(f'{self.name}/Atom', 'qYY') + qZZ = self.fftree.get_attribs(f'{self.name}/Atom', 'qZZ') + qXY = self.fftree.get_attribs(f'{self.name}/Atom', 'qXY') + qYZ = self.fftree.get_attribs(f'{self.name}/Atom', 'qYZ') + + # assume that polarize tag match the per atom type + polarizabilityXX = self.fftree.get_attribs(f'{self.name}/Polarize', 'polarizabilityXX') + polarizabilityYY = self.fftree.get_attribs(f'{self.name}/Polarize', 'polarizabilityYY') + polarizabilityZZ = self.fftree.get_attribs(f'{self.name}/Polarize', 'polarizabilityZZ') + thole = self.fftree.get_attribs(f'{self.name}/Polarize', 'thole') + + n_atoms = len(atomTypes) + assert n_atoms == len(polarizabilityXX), "Number of polarizabilityXX does not match number of atoms!" + + # map atom multipole moments + if self.lmax == 0: + n_mtps = 1 + elif self.lmax == 1: + n_mtps = 4 + elif self.lmax == 2: + n_mtps = 10 + Q = np.zeros((n_atoms, n_mtps)) + + # TDDO: unit conversion + Q[:, 0] = c0 + if self.lmax >= 1: + Q[:, 1] = dX + Q[:, 2] = dY + Q[:, 3] = dZ + Q[:, 1:4] *= 10 + if self.lmax >= 2: + Q[:, 4] = qXX + Q[:, 5] = qYY + Q[:, 6] = qZZ + Q[:, 7] = qXY + Q[:, 8] = qYZ + Q[:, 4:9] *= 300 + + # add all differentiable params to self.params + Q_local = convert_cart2harm(Q, self.lmax) + self.paramtree[self.name]["Q_local"] = Q_local + + if self.lpol: + pol = jnp.vstack(( + polarizabilityXX, + polarizabilityYY, + polarizabilityZZ, + )).T + pol = 1000 * jnp.mean(pol, axis=1) + tholes = jnp.array(thole) + self.paramtree[self.name]["pol"] = pol + self.paramtree[self.name]["tholes"] = tholes + else: + pol = None + tholes = None + + def overwrite(self): + + self.fftree.set_attrib(f'{self.name}', 'mScale12', [self.paramtree[self.name]['mScales'][0]]) + self.fftree.set_attrib(f'{self.name}', 'mScale13', [self.paramtree[self.name]['mScales'][1]]) + self.fftree.set_attrib(f'{self.name}', 'mScale14', [self.paramtree[self.name]['mScales'][2]]) + self.fftree.set_attrib(f'{self.name}', 'mScale15', [self.paramtree[self.name]['mScales'][3]]) + self.fftree.set_attrib(f'{self.name}', 'mScale16', [self.paramtree[self.name]['mScales'][4]]) + + self.fftree.set_attrib(f'{self.name}', 'pScale12', [self.paramtree[self.name]['pScales'][0]]) + self.fftree.set_attrib(f'{self.name}', 'pScale13', [self.paramtree[self.name]['pScales'][1]]) + self.fftree.set_attrib(f'{self.name}', 'pScale14', [self.paramtree[self.name]['pScales'][2]]) + self.fftree.set_attrib(f'{self.name}', 'pScale15', [self.paramtree[self.name]['pScales'][3]]) + self.fftree.set_attrib(f'{self.name}', 'pScale16', [self.paramtree[self.name]['pScales'][4]]) + + self.fftree.set_attrib(f'{self.name}', 'dScale12', [self.paramtree[self.name]['dScales'][0]]) + self.fftree.set_attrib(f'{self.name}', 'dScale13', [self.paramtree[self.name]['dScales'][1]]) + self.fftree.set_attrib(f'{self.name}', 'dScale14', [self.paramtree[self.name]['dScales'][2]]) + self.fftree.set_attrib(f'{self.name}', 'dScale15', [self.paramtree[self.name]['dScales'][3]]) + self.fftree.set_attrib(f'{self.name}', 'dScale16', [self.paramtree[self.name]['dScales'][4]]) + + Q_global = convert_harm2cart(self.paramtree[self.name]['Q_local'], self.lmax) + + + self.fftree.set_attrib(f'{self.name}/Atom', 'c0', Q_global[:, 0]) + self.fftree.set_attrib(f'{self.name}/Atom', 'dX', Q_global[:, 1]) + self.fftree.set_attrib(f'{self.name}/Atom', 'dY', Q_global[:, 2]) + self.fftree.set_attrib(f'{self.name}/Atom', 'dZ', Q_global[:, 3]) + self.fftree.set_attrib(f'{self.name}/Atom', 'qXX', Q_global[:, 4]) + self.fftree.set_attrib(f'{self.name}/Atom', 'qYY', Q_global[:, 5]) + self.fftree.set_attrib(f'{self.name}/Atom', 'qZZ', Q_global[:, 6]) + self.fftree.set_attrib(f'{self.name}/Atom', 'qXY', Q_global[:, 7]) + self.fftree.set_attrib(f'{self.name}/Atom', 'qYZ', Q_global[:, 8]) + self.fftree.set_attrib(f'{self.name}/Atom', 'qYZ', Q_global[:, 9]) + + if self.lpol: + # self.paramtree[self.name]['pol']: every element is the mean value of XX YY ZZ + # get the number of polarize element + n_pol = len(self.paramtree[self.name]['pol']) + self.fftree.set_attrib(f'{self.name}/Polarize', 'polarizabilityXX', [self.paramtree[self.name]['pol'][0]] * n_pol) + self.fftree.set_attrib(f'{self.name}/Polarize', 'polarizabilityYY', [self.paramtree[self.name]['pol'][1]] * n_pol) + self.fftree.set_attrib(f'{self.name}/Polarize', 'polarizabilityZZ', [self.paramtree[self.name]['pol'][2]] * n_pol) + self.fftree.set_attrib(f'{self.name}/Polarize', 'thole', self.paramtree[self.name]['tholes']) + + + + + + def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, + args): + + methodMap = { + app.CutoffPeriodic: "CutoffPeriodic", + app.NoCutoff: "NoCutoff", + app.PME: "PME", + } + if nonbondedMethod not in methodMap: + raise ValueError("Illegal nonbonded method for ADMPPmeForce") + if nonbondedMethod is app.CutoffPeriodic: + self.lpme = False + else: + self.lpme = True + + n_atoms = len(data.atoms) + map_atomtype = np.zeros(n_atoms, dtype=int) + + for i in range(n_atoms): + atype = data.atomType[data.atoms[i]] # convert str to int to match atomTypes + map_atomtype[i] = np.where(self.atomTypes == atype)[0][0] + self.map_atomtype = map_atomtype + + # here box is only used to setup ewald parameters, no need to be differentiable + a, b, c = system.getDefaultPeriodicBoxVectors() + box = jnp.array([a._value, b._value, c._value]) * 10 + + # get the admp calculator + rc = nonbondedCutoff.value_in_unit(unit.angstrom) + + # build covalent map + covalent_map = build_covalent_map(data, 6) + # build intra-molecule axis + # the following code is the direct transplant of forcefield.py in openmm 7.4.0 + + if self.lmax > 0: + + # setting up axis_indices and axis_type + ZThenX = 0 + Bisector = 1 + ZBisect = 2 + ThreeFold = 3 + ZOnly = 4 # typo fix + NoAxisType = 5 + LastAxisTypeIndex = 6 + + self.axis_types = [] + self.axis_indices = [] + for i_atom in range(n_atoms): + atom = data.atoms[i_atom] + t = data.atomType[atom] + # if t is in type list? + if t in self.atomTypes: + itypes = np.where(self.atomTypes == t)[0] + hit = 0 + # try to assign multipole parameters via only 1-2 connected atoms + for itype in itypes: + if hit != 0: + break + kz = int(self.kStrings["kz"][itype]) + kx = int(self.kStrings["kx"][itype]) + ky = int(self.kStrings["ky"][itype]) + neighbors = np.where(covalent_map[i_atom] == 1)[0] + zaxis = -1 + xaxis = -1 + yaxis = -1 + for z_index in neighbors: + if hit != 0: + break + z_type = int(data.atomType[data.atoms[z_index]]) + if z_type == abs( + kz + ): # find the z atom, start searching for x + for x_index in neighbors: + if x_index == z_index or hit != 0: + continue + x_type = int( + data.atomType[data.atoms[x_index]]) + if x_type == abs( + kx + ): # find the x atom, start searching for y + if ky == 0: + zaxis = z_index + xaxis = x_index + # cannot ditinguish x and z? use the smaller index for z, and the larger index for x + if x_type == z_type and xaxis < zaxis: + swap = z_axis + z_axis = x_axis + x_axis = swap + # otherwise, try to see if we can find an even smaller index for x? + else: + for x_index in neighbors: + x_type1 = int( + data.atomType[ + data. + atoms[x_index]]) + if (x_type1 == abs(kx) and + x_index != z_index + and + x_index < xaxis): + xaxis = x_index + hit = 1 # hit, finish matching + matched_itype = itype + else: + for y_index in neighbors: + if (y_index == z_index + or y_index == x_index + or hit != 0): + continue + y_type = int(data.atomType[ + data.atoms[y_index]]) + if y_type == abs(ky): + zaxis = z_index + xaxis = x_index + yaxis = y_index + hit = 2 + matched_itype = itype + # assign multipole parameters via 1-2 and 1-3 connected atoms + for itype in itypes: + if hit != 0: + break + kz = int(self.kStrings["kz"][itype]) + kx = int(self.kStrings["kx"][itype]) + ky = int(self.kStrings["ky"][itype]) + neighbors_1st = np.where(covalent_map[i_atom] == 1)[0] + neighbors_2nd = np.where(covalent_map[i_atom] == 2)[0] + zaxis = -1 + xaxis = -1 + yaxis = -1 + for z_index in neighbors_1st: + if hit != 0: + break + z_type = int(data.atomType[data.atoms[z_index]]) + if z_type == abs(kz): + for x_index in neighbors_2nd: + if x_index == z_index or hit != 0: + continue + x_type = int( + data.atomType[data.atoms[x_index]]) + # we ask x to be in 2'nd neighbor, and x is z's neighbor + if (x_type == abs(kx) + and covalent_map[z_index, + x_index] == 1): + if ky == 0: + zaxis = z_index + xaxis = x_index + # select smallest x index + for x_index in neighbors_2nd: + x_type1 = int(data.atomType[ + data.atoms[x_index]]) + if (x_type1 == abs(kx) + and x_index != z_index + and + covalent_map[x_index, + z_index] + == 1 + and x_index < xaxis): + xaxis = x_index + hit = 3 + matched_itype = itype + else: + for y_index in neighbors_2nd: + if (y_index == z_index + or y_index == x_index + or hit != 0): + continue + y_type = int(data.atomType[ + data.atoms[y_index]]) + if (y_type == abs(ky) and + covalent_map[y_index, + z_index] + == 1): + zaxis = z_index + xaxis = x_index + yaxis = y_index + hit = 4 + matched_itype = itype + # assign multipole parameters via only a z-defining atom + for itype in itypes: + if hit != 0: + break + kz = int(self.kStrings["kz"][itype]) + kx = int(self.kStrings["kx"][itype]) + zaxis = -1 + xaxis = -1 + yaxis = -1 + neighbors = np.where(covalent_map[i_atom] == 1)[0] + for z_index in neighbors: + if hit != 0: + break + z_type = int(data.atomType[data.atoms[z_index]]) + if kx == 0 and z_type == abs(kz): + zaxis = z_index + hit = 5 + matched_itype = itype + # assign multipole parameters via no connected atoms + for itype in itypes: + if hit != 0: + break + kz = int(self.kStrings["kz"][itype]) + zaxis = -1 + xaxis = -1 + yaxis = -1 + if kz == 0: + hit = 6 + matched_itype = itype + # add particle if there was a hit + if hit != 0: + map_atomtype[i_atom] = matched_itype + self.axis_indices.append([zaxis, xaxis, yaxis]) + + kz = int(self.kStrings["kz"][matched_itype]) + kx = int(self.kStrings["kx"][matched_itype]) + ky = int(self.kStrings["ky"][matched_itype]) + axisType = ZThenX + if kz == 0: + axisType = NoAxisType + if kz != 0 and kx == 0: + axisType = ZOnly + if kz < 0 or kx < 0: + axisType = Bisector + if kx < 0 and ky < 0: + axisType = ZBisect + if kz < 0 and kx < 0 and ky < 0: + axisType = ThreeFold + self.axis_types.append(axisType) + + else: + sys.exit("Atom %d not matched in forcefield!" % i_atom) + + else: + sys.exit("Atom %d not matched in forcefield!" % i_atom) + self.axis_indices = np.array(self.axis_indices) + self.axis_types = np.array(self.axis_types) + else: + self.axis_types = None + self.axis_indices = None + + if "ethresh" in args: + self.ethresh = args["ethresh"] + if "step_pol" in args: + self.step_pol = args["step_pol"] + + pme_force = ADMPPmeForce(box, self.axis_types, self.axis_indices, + covalent_map, rc, self.ethresh, self.lmax, + self.lpol, self.lpme, self.step_pol) + self.pme_force = pme_force + + def potential_fn(positions, box, pairs, params): + params = params['ADMPPmeForce'] + mScales = params["mScales"] + Q_local = params["Q_local"][map_atomtype] + if self.lpol: + pScales = params["pScales"] + dScales = params["dScales"] + pol = params["pol"][map_atomtype] + tholes = params["tholes"][map_atomtype] + + return pme_force.get_energy( + positions, + box, + pairs, + Q_local, + pol, + tholes, + mScales, + pScales, + dScales, + pme_force.U_ind, + ) + else: + return pme_force.get_energy(positions, box, pairs, Q_local, + mScales) + + self._jaxPotential = potential_fn + + def getJaxPotential(self): + return self._jaxPotential + +# app.forcefield.parsers["ADMPPmeForce"] = ADMPPmeGenerator.parseElement +jaxGenerators["ADMPPmeForce"] = ADMPPmeGenerator class Potential: def __init__(self): @@ -1051,10 +1730,11 @@ def getPotentialFunc(self, names=[]): raise DMFFException("No DMFF function in this potential object.") def totalPE(positions, box, pairs, params): - totale = sum([ + totale_list = [ self.dmff_potentials[k](positions, box, pairs, params) for k in self.dmff_potentials.keys() if (len(names) == 0 or k in names) - ]) + ] + totale = jnp.sum(jnp.array(totale_list)) return totale return totalPE @@ -1085,6 +1765,9 @@ def __init__(self, *xmlnames): for jaxGen in self._jaxGenerators: self._forces.append(jaxGen) + def getGenerators(self): + return self._jaxGenerators + def extractParameterTree(self): # load Force info for jaxgen in self._jaxGenerators: diff --git a/dmff/fftree.py b/dmff/fftree.py index a03094d89..be63defa1 100644 --- a/dmff/fftree.py +++ b/dmff/fftree.py @@ -1,9 +1,10 @@ import xml.etree.ElementTree as ET import xml.dom.minidom from dmff.utils import convertStr2Float, DMFFException -from typing import List +from typing import Dict, List, Union, TypeVar from itertools import permutations +value = TypeVar('value') # generic type: interpreted as either a number or str class SelectError(BaseException): pass @@ -54,7 +55,25 @@ def __init__(self, tag, **attrs): super().__init__(tag, **attrs) - def get_nodes(self, parser): + def get_nodes(self, parser:str)->List[Node]: + """ + get all nodes of a certain path + + Examples + -------- + >>> fftree.get_nodes('HarmonicBondForce/Bond') + >>> [, , ...] + + Parameters + ---------- + parser : str + a path to locate nodes + + Returns + ------- + List[Node] + a list of Node + """ steps = parser.split("/") val = self for nstep, step in enumerate(steps): @@ -69,7 +88,29 @@ def get_nodes(self, parser): val = val[0] return val - def get_attribs(self, parser, attrname): + def get_attribs(self, parser:str, attrname:Union[str, List[str]])->List[Union[value, List[value]]]: + """ + get all values of attributes of nodes which nodes matching certain path + + Examples: + --------- + >>> fftree.get_attribs('HarmonicBondForce/Bond', 'k') + >>> [2.0, 2.0, 2.0, 2.0, 2.0, ...] + >>> fftree.get_attribs('HarmonicBondForce/Bond', ['k', 'r0']) + >>> [[2.0, 1.53], [2.0, 1.53], ...] + + Parameters + ---------- + parser : str + a path to locate nodes + attrname : _type_ + attribute name or a list of attribute names of a node + + Returns + ------- + List[Union[float, str]] + a list of values of attributes + """ sel = self.get_nodes(parser) if isinstance(attrname, list): ret = [] @@ -78,17 +119,54 @@ def get_attribs(self, parser, attrname): ret.append(vals) return ret else: - attrs = [convertStr2Float(n.attrs[attrname]) for n in sel] + attrs = [convertStr2Float(n.attrs[attrname]) if attrname in n.attrs else None for n in sel] return attrs - def set_node(self, parser, values): + def set_node(self, parser:str, values:List[Dict[str, value]])->None: + """ + set attributes of nodes which nodes matching certain path + + Parameters + ---------- + parser : str + path to locate nodes + values : List[Dict[str, value]] + a list of Dict[str, value], where value is any type can be convert to str of a number. + + Examples + -------- + >>> fftree.set_node('HarmonicBondForce/Bond', + [{'k': 2.0, 'r0': 1.53}, + {'k': 2.0, 'r0': 1.53}]) + """ nodes = self.get_nodes(parser) for nit in range(len(values)): for key in values[nit]: nodes[nit].attrs[key] = f"{values[nit][key]}" - def set_attrib(self, parser, attrname, values): - valdicts = [{attrname: i} for i in values] + def set_attrib(self, parser:str, attrname:str, values:Union[value, List[value]])->None: + """ + set ONE Attribute of nodes which nodes matching certain path + + Parameters + ---------- + parser : str + path to locate nodes + attrname : str + attribute name + values : Union[float, str, List[float, str]] + attribute value or a list of attribute values of a node + + Examples + -------- + >>> fftree.set_attrib('HarmonicBondForce/Bond', 'k', 2.0) + >>> fftree.set_attrib('HarmonicBondForce/Bond', 'k', [2.0, 2.0, 2.0, 2.0, 2.0]) + + """ + if len(values) == 0: + valdicts = [{attrname: values}] + else: + valdicts = [{attrname: i} for i in values] self.set_node(parser, valdicts) diff --git a/docs/dev_guide/convention.md b/docs/dev_guide/convention.md index d4d18f455..bf114dbb2 100644 --- a/docs/dev_guide/convention.md +++ b/docs/dev_guide/convention.md @@ -23,4 +23,365 @@ Under the `dmff`, there are several files and sub-directory: ## 3.2 Programming Style -TBA \ No newline at end of file +In the project DMFF, the programming style is followed numpy docstring. A proper docstring of methods and classes would be generated pretty api doc. In order to make other developers understand the methods you provide and further maintain them, you'd better add type annotations to the externally provided APIs. Here is the doc of [python typing](https://docs.python.org/3/library/typing.html). + +Here is an brife exampl from [napoleon](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_numpy.html). + +```python +# -*- coding: utf-8 -*- +"""Example NumPy style docstrings. + +This module demonstrates documentation as specified by the `NumPy +Documentation HOWTO`_. Docstrings may extend over multiple lines. Sections +are created with a section header followed by an underline of equal length. + +Example +------- +Examples can be given using either the ``Example`` or ``Examples`` +sections. Sections support any reStructuredText formatting, including +literal blocks:: + + $ python example_numpy.py + + +Section breaks are created with two blank lines. Section breaks are also +implicitly created anytime a new section starts. Section bodies *may* be +indented: + +Notes +----- + This is an example of an indented section. It's like any other section, + but the body is indented to help it stand out from surrounding text. + +If a section is indented, then a section break is created by +resuming unindented text. + +Attributes +---------- +module_level_variable1 : int + Module level variables may be documented in either the ``Attributes`` + section of the module docstring, or in an inline docstring immediately + following the variable. + + Either form is acceptable, but the two should not be mixed. Choose + one convention to document module level variables and be consistent + with it. + + +.. _NumPy Documentation HOWTO: + https://github.com/numpy/numpy/blob/master/doc/HOWTO_DOCUMENT.rst.txt + +""" + +module_level_variable1 = 12345 + +module_level_variable2 = 98765 +"""int: Module level variable documented inline. + +The docstring may span multiple lines. The type may optionally be specified +on the first line, separated by a colon. +""" + + +def function_with_types_in_docstring(param1, param2): + """Example function with types documented in the docstring. + + `PEP 484`_ type annotations are supported. If attribute, parameter, and + return types are annotated according to `PEP 484`_, they do not need to be + included in the docstring: + + Parameters + ---------- + param1 : int + The first parameter. + param2 : str + The second parameter. + + Returns + ------- + bool + True if successful, False otherwise. + + .. _PEP 484: + https://www.python.org/dev/peps/pep-0484/ + + """ + + +def function_with_pep484_type_annotations(param1: int, param2: str) -> bool: + """Example function with PEP 484 type annotations. + + The return type must be duplicated in the docstring to comply + with the NumPy docstring style. + + Parameters + ---------- + param1 + The first parameter. + param2 + The second parameter. + + Returns + ------- + bool + True if successful, False otherwise. + + """ + + +def module_level_function(param1, param2=None, *args, **kwargs): + """This is an example of a module level function. + + Function parameters should be documented in the ``Parameters`` section. + The name of each parameter is required. The type and description of each + parameter is optional, but should be included if not obvious. + + If \*args or \*\*kwargs are accepted, + they should be listed as ``*args`` and ``**kwargs``. + + The format for a parameter is:: + + name : type + description + + The description may span multiple lines. Following lines + should be indented to match the first line of the description. + The ": type" is optional. + + Multiple paragraphs are supported in parameter + descriptions. + + Parameters + ---------- + param1 : int + The first parameter. + param2 : :obj:`str`, optional + The second parameter. + *args + Variable length argument list. + **kwargs + Arbitrary keyword arguments. + + Returns + ------- + bool + True if successful, False otherwise. + + The return type is not optional. The ``Returns`` section may span + multiple lines and paragraphs. Following lines should be indented to + match the first line of the description. + + The ``Returns`` section supports any reStructuredText formatting, + including literal blocks:: + + { + 'param1': param1, + 'param2': param2 + } + + Raises + ------ + AttributeError + The ``Raises`` section is a list of all exceptions + that are relevant to the interface. + ValueError + If `param2` is equal to `param1`. + + """ + if param1 == param2: + raise ValueError('param1 may not be equal to param2') + return True + + +def example_generator(n): + """Generators have a ``Yields`` section instead of a ``Returns`` section. + + Parameters + ---------- + n : int + The upper limit of the range to generate, from 0 to `n` - 1. + + Yields + ------ + int + The next number in the range of 0 to `n` - 1. + + Examples + -------- + Examples should be written in doctest format, and should illustrate how + to use the function. + + >>> print([i for i in example_generator(4)]) + [0, 1, 2, 3] + + """ + for i in range(n): + yield i + + +class ExampleError(Exception): + """Exceptions are documented in the same way as classes. + + The __init__ method may be documented in either the class level + docstring, or as a docstring on the __init__ method itself. + + Either form is acceptable, but the two should not be mixed. Choose one + convention to document the __init__ method and be consistent with it. + + Note + ---- + Do not include the `self` parameter in the ``Parameters`` section. + + Parameters + ---------- + msg : str + Human readable string describing the exception. + code : :obj:`int`, optional + Numeric error code. + + Attributes + ---------- + msg : str + Human readable string describing the exception. + code : int + Numeric error code. + + """ + + def __init__(self, msg, code): + self.msg = msg + self.code = code + + +class ExampleClass(object): + """The summary line for a class docstring should fit on one line. + + If the class has public attributes, they may be documented here + in an ``Attributes`` section and follow the same formatting as a + function's ``Args`` section. Alternatively, attributes may be documented + inline with the attribute's declaration (see __init__ method below). + + Properties created with the ``@property`` decorator should be documented + in the property's getter method. + + Attributes + ---------- + attr1 : str + Description of `attr1`. + attr2 : :obj:`int`, optional + Description of `attr2`. + + """ + + def __init__(self, param1, param2, param3): + """Example of docstring on the __init__ method. + + The __init__ method may be documented in either the class level + docstring, or as a docstring on the __init__ method itself. + + Either form is acceptable, but the two should not be mixed. Choose one + convention to document the __init__ method and be consistent with it. + + Note + ---- + Do not include the `self` parameter in the ``Parameters`` section. + + Parameters + ---------- + param1 : str + Description of `param1`. + param2 : :obj:`list` of :obj:`str` + Description of `param2`. Multiple + lines are supported. + param3 : :obj:`int`, optional + Description of `param3`. + + """ + self.attr1 = param1 + self.attr2 = param2 + self.attr3 = param3 #: Doc comment *inline* with attribute + + #: list of str: Doc comment *before* attribute, with type specified + self.attr4 = ["attr4"] + + self.attr5 = None + """str: Docstring *after* attribute, with type specified.""" + + @property + def readonly_property(self): + """str: Properties should be documented in their getter method.""" + return "readonly_property" + + @property + def readwrite_property(self): + """:obj:`list` of :obj:`str`: Properties with both a getter and setter + should only be documented in their getter method. + + If the setter method contains notable behavior, it should be + mentioned here. + """ + return ["readwrite_property"] + + @readwrite_property.setter + def readwrite_property(self, value): + value + + def example_method(self, param1, param2): + """Class methods are similar to regular functions. + + Note + ---- + Do not include the `self` parameter in the ``Parameters`` section. + + Parameters + ---------- + param1 + The first parameter. + param2 + The second parameter. + + Returns + ------- + bool + True if successful, False otherwise. + + """ + return True + + def __special__(self): + """By default special members with docstrings are not included. + + Special members are any methods or attributes that start with and + end with a double underscore. Any special member with a docstring + will be included in the output, if + ``napoleon_include_special_with_doc`` is set to True. + + This behavior can be enabled by changing the following setting in + Sphinx's conf.py:: + + napoleon_include_special_with_doc = True + + """ + pass + + def __special_without_docstring__(self): + pass + + def _private(self): + """By default private members are not included. + + Private members are any methods or attributes that start with an + underscore and are *not* special. By default they are not included + in the output. + + This behavior can be changed such that private members *are* included + by changing the following setting in Sphinx's conf.py:: + + napoleon_include_private_with_doc = True + + """ + pass + + def _private_without_docstring(self): + pass +``` \ No newline at end of file diff --git a/tests/data/methane_water_modified.xml b/tests/data/methane_water_modified.xml new file mode 100644 index 000000000..95a84f28e --- /dev/null +++ b/tests/data/methane_water_modified.xml @@ -0,0 +1,59 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/tests/data/tip3p.xml b/tests/data/tip3p.xml new file mode 100644 index 000000000..77de2c6cb --- /dev/null +++ b/tests/data/tip3p.xml @@ -0,0 +1,25 @@ + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/tests/test_admp/test_compute.py b/tests/test_admp/test_compute.py index e6ea1f3bb..79c060dee 100644 --- a/tests/test_admp/test_compute.py +++ b/tests/test_admp/test_compute.py @@ -25,9 +25,27 @@ def test_init(self): rc = 4.0 H = Hamiltonian('tests/data/admp.xml') pdb = app.PDBFile('tests/data/water_dimer.pdb') - H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom, ethresh=5e-4, step_pol=5) + potential = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom, ethresh=5e-4, step_pol=5) + generators = H.getGenerators() - yield H.getGenerators() + yield generators + + def test_ADMPPmeForce(self, generators): + + rc = 4.0 + pdb = app.PDBFile('tests/data/water_dimer.pdb') + positions = np.array(pdb.positions._value) * 10 + a, b, c = pdb.topology.getPeriodicBoxVectors() + box = np.array([a._value, b._value, c._value]) * 10 + # neighbor list + nblist = NeighborList(box, rc) + nblist.allocate(positions) + pairs = nblist.pairs + + gen = generators[1] + pot = gen.getJaxPotential() + energy = pot(positions, box, pairs, gen.paramtree) + def test_ADMPPmeForce_jit(self, generators): @@ -41,8 +59,8 @@ def test_ADMPPmeForce_jit(self, generators): nblist = NeighborList(box, rc) nblist.allocate(positions) pairs = nblist.pairs - - pot_pme = gen.getJaxPotential() - j_pot_pme = jit(value_and_grad(pot_pme)) - - E_pme, F_pme = j_pot_pme(positions, box, pairs, gen.params) \ No newline at end of file + + gen = generators[1] + pot = gen.getJaxPotential() + j_pot_pme = jit(value_and_grad(pot)) + energy = j_pot_pme(positions, box, pairs, gen.paramtree) diff --git a/tests/test_api.py b/tests/test_api.py index f576de147..31786910d 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -31,7 +31,7 @@ def test_init(self): def test_ADMPDispForce_parseXML(self, generators): gen = generators[0] - params = gen.params + params = gen.paramtree['ADMPDispForce'] npt.assert_allclose(params['mScales'], [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]) npt.assert_allclose(params['A'], [1203470.743, 83.2283563]) @@ -40,18 +40,32 @@ def test_ADMPDispForce_parseXML(self, generators): def test_ADMPDispForce_renderXML(self, generators): gen = generators[0] - xml = gen.renderXML() + params = gen.paramtree['ADMPDispForce'] + gen.overwrite() - assert xml.name == 'ADMPDispForce' - npt.assert_allclose(float(xml[0]['type']), 380) - npt.assert_allclose(float(xml[0]['A']), 1203470.743) - npt.assert_allclose(float(xml[1]['B']), 37.78544799) - npt.assert_allclose(float(xml[1]['Q']), 0.370853) + assert gen.name == 'ADMPDispForce' + npt.assert_allclose(params['mScales'], [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]) + npt.assert_allclose(params['A'], [1203470.743, 83.2283563]) + npt.assert_allclose(params['B'], [37.81265679, 37.78544799]) def test_ADMPPmeForce_parseXML(self, generators): gen = generators[1] - params = gen.params + tree = gen.paramtree['ADMPPmeForce'] + + npt.assert_allclose(tree['mScales'], [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]) + npt.assert_allclose(tree['pScales'], [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]) + npt.assert_allclose(tree['dScales'], [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]) + # Q_local is already converted to local frame + # npt.assert_allclose(tree['Q_local'][0][:4], [-1.0614, 0.0, 0.0, -0.023671684]) + npt.assert_allclose(tree['pol'], [0.88000005, 0]) + npt.assert_allclose(tree['tholes'], [8., 0.]) + + def test_ADMPPmeForce_renderXML(self, generators): + + gen = generators[1] + params = gen.paramtree['ADMPPmeForce'] + gen.overwrite() npt.assert_allclose(params['mScales'], [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]) npt.assert_allclose(params['pScales'], [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]) @@ -61,21 +75,6 @@ def test_ADMPPmeForce_parseXML(self, generators): npt.assert_allclose(params['pol'], [0.88000005, 0]) npt.assert_allclose(params['tholes'], [8., 0.]) - def test_ADMPPmeForce_renderXML(self, generators): - - gen = generators[1] - xml = gen.renderXML() - - assert xml.name == 'ADMPPmeForce' - assert xml.attributes['lmax'] == '2' - assert xml.attributes['mScale12'] == '0.0' - assert xml.attributes['mScale15'] == '1.0' - assert xml.elements[0].name == 'Atom' - assert xml.elements[0].attributes['qXZ'] == '-0.07141020' - assert xml.elements[2].name == 'Polarize' - assert xml.elements[2].attributes['polarizabilityXX'][:6] == '0.8800' - assert xml[3]['type'] == '381' - class TestClassicalAPI: """ Test classical forcefield generators @@ -99,64 +98,63 @@ def test_init(self): yield H.getGenerators() - def test_NonBond_parseXML(self, generators): + def test_Nonbond_parseXML(self, generators): gen = generators[0] - params = gen.params + params = gen.paramtree['NonbondedForce'] npt.assert_allclose(params['sigma'], [1.0, 1.0, -1.0, -1.0]) def test_NonBond_renderXML(self, generators): gen = generators[0] - xml = gen.renderXML() + params = gen.paramtree['NonbondedForce'] + gen.overwrite() - assert xml.name == 'NonbondedForce' - assert xml.attributes['lj14scale'] == '0.5' - assert xml[0]['type'] == 'n1' - assert xml[1]['sigma'] == '1.0' + assert gen.name == 'NonbondedForce' + # npt.assert_allclose(params['type'], ['n1', 'n2', 'n3', 'n4']) + # type is asigned to the generator themself as a member variable + npt.assert_allclose(params['charge'], [1.0, -1.0, 1.0, -1.0]) + npt.assert_allclose(params['sigma'], [1.0, 1.0, -1.0, -1.0]) + npt.assert_allclose(params['epsilon'], [0.0, 0.0, 0.0, 0.0]) + npt.assert_allclose(params['lj14scale'], [0.5]) + def test_HarmonicAngle_parseXML(self, generators): gen = generators[1] - params = gen.params + params = gen.paramtree['HarmonicAngleForce'] npt.assert_allclose(params['k'], 836.8) npt.assert_allclose(params['angle'], 1.8242181341844732) def test_HarmonicAngle_renderXML(self, generators): gen = generators[1] - xml = gen.renderXML() + params = gen.paramtree['HarmonicAngleForce'] + gen.overwrite() - assert xml.name == 'HarmonicAngleForce' - assert xml[0]['type1'] == 'n1' - assert xml[0]['type2'] == 'n2' - assert xml[0]['type3'] == 'n3' - assert xml[0]['angle'][:7] == '1.82421' - assert xml[0]['k'] == '836.8' + assert gen.name == 'HarmonicAngleForce' + npt.assert_allclose(params['angle'], [1.8242181341844732] * 2) + npt.assert_allclose(params['k'], [836.8] * 2) def test_PeriodicTorsion_parseXML(self, generators): gen = generators[2] - params = gen.params - npt.assert_allclose(params['psi1_p'], 0) - npt.assert_allclose(params['k1_p'], 2.092) + params = gen.paramtree['PeriodicTorsionForce'] + npt.assert_allclose(params['prop_phase']['1'], [0]) + npt.assert_allclose(params['prop_k']['1'], [2.092]) def test_PeriodicTorsion_renderXML(self, generators): gen = generators[2] - xml = gen.renderXML() - assert xml.name == 'PeriodicTorsionForce' - assert xml[0].name == 'Proper' - assert xml[0]['type1'] == 'n1' - assert xml[1].name == 'Improper' - assert xml[1]['type1'] == 'n1' + params = gen.paramtree['PeriodicTorsionForce'] + gen.overwrite() + assert gen.name == 'PeriodicTorsionForce' + npt.assert_allclose(params['prop_phase']['1'], [0]) + npt.assert_allclose(params['prop_k']['1'], [2.092]) def test_parse_multiple_files(self): pdb = app.PDBFile("tests/data/methane_water.pdb") - h = Hamiltonian("tests/data/methane.xml", "tip3p.xml") + h = Hamiltonian("tests/data/methane.xml", "tests/data/tip3p.xml") potentials = h.createPotential(pdb.topology) - npt.assert_allclose( - h.getGenerators()[-1].params["charge"], - [-0.1068, 0.0267, 0.0267, 0.0267, 0.0267, -0.834, 0.417] - ) + diff --git a/tests/test_classical/test_lj.py b/tests/test_classical/test_lj.py index 7bb8ba7eb..3f560b0d6 100644 --- a/tests/test_classical/test_lj.py +++ b/tests/test_classical/test_lj.py @@ -74,7 +74,7 @@ def test_lj_params_check(self): pairs = np.array(pairs, dtype=int) ljE = potential.getPotentialFunc() with pytest.raises(TypeError): - energy = ljE(pos, box, pairs, h.getGenerators()[0].params) + energy = ljE(pos, box, pairs, h.getGenerators()[0].paramtree) energy = jax.jit(ljE)(pos, box, pairs, h.paramtree) # jit will optimized away type check force = jax.grad(jax.jit(ljE))(pos, box, pairs, h.paramtree) \ No newline at end of file