Skip to content

Commit

Permalink
Include detailed tests for forcefield metadata (#435)
Browse files Browse the repository at this point in the history
* Include detailed tests for forcefield metadata

Previously, the tests in test_forcefield.py were assumed to test that
forcefield metadata like `version`, `name`, etc. were properly loaded
into the `Forcefield` object.

This is only partially true, as there are cases where the files either
go out of scope before that information is gathered, etc. This PR
includes additional tests to ensure that this information is not being
lost prematurely. Big thanks to: @mattwthompson, @ahy3nz for discovering
this issue, creating MWE's for some of these cases, and taking my
investigation further and most likely pinpointing the issue.

* Change ff load logic to make options consistent

When a forcefield file is provided by using the internal name:
`foyer.Forcefield(name='oplsaa)`, the final file processing would happen
in a separate part of the logic and use variables that were only
accessed if the user used `forcefield_files` in the `Forcefield`
constructor method.

Now the same final processing occurs for all input types to load the
forcefields. All tests currently pass with this fix in place.

* Remove named temp files after final parsing
  • Loading branch information
justinGilmer committed Jun 29, 2021
1 parent 3b91082 commit fc44f80
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 12 deletions.
20 changes: 9 additions & 11 deletions foyer/forcefield.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,21 +522,19 @@ def __init__(
if validation:
for ff_file_name in preprocessed_files:
Validator(ff_file_name, debug)
try:
super(Forcefield, self).__init__(*preprocessed_files)
finally:
for ff_file_name in preprocessed_files:
os.remove(ff_file_name)
super(Forcefield, self).__init__(*preprocessed_files)

if isinstance(forcefield_files, str):
self._version = self._parse_version_number(forcefield_files)
self._name = self._parse_name(forcefield_files)
elif isinstance(forcefield_files, list):
if len(preprocessed_files) == 1:
self._version = self._parse_version_number(preprocessed_files[0])
self._name = self._parse_name(preprocessed_files[0])
elif len(preprocessed_files) > 1:
self._version = [
self._parse_version_number(f) for f in forcefield_files
self._parse_version_number(f) for f in preprocessed_files
]
self._name = [self._parse_name(f) for f in forcefield_files]
self._name = [self._parse_name(f) for f in preprocessed_files]

for fp in preprocessed_files:
os.remove(fp)
self.parser = smarts.SMARTS(self.non_element_types)
self._system_data = None

Expand Down
27 changes: 26 additions & 1 deletion foyer/tests/test_forcefield.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
import glob
import itertools as it
import os
from typing import List

import parmed as pmd
import pytest
from lxml import etree as ET
from pkg_resources import resource_filename

from foyer import Forcefield
from foyer import Forcefield, forcefields
from foyer.exceptions import FoyerError, ValidationWarning
from foyer.forcefield import (
_check_independent_residues,
Expand Down Expand Up @@ -639,6 +640,30 @@ def test_load_metadata(self):
assert lj_ff.version == ["0.4.1", "4.8.2"]
assert lj_ff.name == ["LJ", "JL"]

def test_load_metadata_single_xml(self):
from_xml_ff = Forcefield(forcefield_files=get_fn("lj.xml"))
assert from_xml_ff.version == "0.4.1"
assert from_xml_ff.name == "LJ"

def test_load_metadata_list_xml(self):
from_xml_ff = Forcefield(
forcefield_files=[get_fn("lj.xml"), get_fn("lj2.xml")]
)
assert isinstance(from_xml_ff.version, List)
assert isinstance(from_xml_ff.name, List)
assert all([x in from_xml_ff.version for x in ["0.4.1", "4.8.2"]])
assert all([x in from_xml_ff.name for x in ["JL", "LJ"]])

def test_load_metadata_from_internal_forcefield_plugin_loader(self):
from_xml_ff = forcefields.load_OPLSAA()
assert from_xml_ff.version == "0.0.1"
assert from_xml_ff.name == "OPLS-AA"

def test_load_metadata_from_internal_name(self):
from_xml_ff = Forcefield(name="oplsaa")
assert from_xml_ff.version == "0.0.1"
assert from_xml_ff.name == "OPLS-AA"

@pytest.mark.skipif(not has_mbuild, reason="mbuild is not installed")
def test_no_overlap_residue_atom_overlap(self):
import mbuild as mb
Expand Down

0 comments on commit fc44f80

Please sign in to comment.