diff --git a/foyer/forcefield.py b/foyer/forcefield.py index 69c2f3e1..e055ddbf 100755 --- a/foyer/forcefield.py +++ b/foyer/forcefield.py @@ -522,21 +522,19 @@ def __init__( if validation: for ff_file_name in preprocessed_files: Validator(ff_file_name, debug) - try: - super(Forcefield, self).__init__(*preprocessed_files) - finally: - for ff_file_name in preprocessed_files: - os.remove(ff_file_name) + super(Forcefield, self).__init__(*preprocessed_files) - if isinstance(forcefield_files, str): - self._version = self._parse_version_number(forcefield_files) - self._name = self._parse_name(forcefield_files) - elif isinstance(forcefield_files, list): + if len(preprocessed_files) == 1: + self._version = self._parse_version_number(preprocessed_files[0]) + self._name = self._parse_name(preprocessed_files[0]) + elif len(preprocessed_files) > 1: self._version = [ - self._parse_version_number(f) for f in forcefield_files + self._parse_version_number(f) for f in preprocessed_files ] - self._name = [self._parse_name(f) for f in forcefield_files] + self._name = [self._parse_name(f) for f in preprocessed_files] + for fp in preprocessed_files: + os.remove(fp) self.parser = smarts.SMARTS(self.non_element_types) self._system_data = None diff --git a/foyer/tests/test_forcefield.py b/foyer/tests/test_forcefield.py index 222e2820..171f38a1 100644 --- a/foyer/tests/test_forcefield.py +++ b/foyer/tests/test_forcefield.py @@ -2,13 +2,14 @@ import glob import itertools as it import os +from typing import List import parmed as pmd import pytest from lxml import etree as ET from pkg_resources import resource_filename -from foyer import Forcefield +from foyer import Forcefield, forcefields from foyer.exceptions import FoyerError, ValidationWarning from foyer.forcefield import ( _check_independent_residues, @@ -639,6 +640,30 @@ def test_load_metadata(self): assert lj_ff.version == ["0.4.1", "4.8.2"] assert lj_ff.name == ["LJ", "JL"] + def test_load_metadata_single_xml(self): + from_xml_ff = Forcefield(forcefield_files=get_fn("lj.xml")) + assert from_xml_ff.version == "0.4.1" + assert from_xml_ff.name == "LJ" + + def test_load_metadata_list_xml(self): + from_xml_ff = Forcefield( + forcefield_files=[get_fn("lj.xml"), get_fn("lj2.xml")] + ) + assert isinstance(from_xml_ff.version, List) + assert isinstance(from_xml_ff.name, List) + assert all([x in from_xml_ff.version for x in ["0.4.1", "4.8.2"]]) + assert all([x in from_xml_ff.name for x in ["JL", "LJ"]]) + + def test_load_metadata_from_internal_forcefield_plugin_loader(self): + from_xml_ff = forcefields.load_OPLSAA() + assert from_xml_ff.version == "0.0.1" + assert from_xml_ff.name == "OPLS-AA" + + def test_load_metadata_from_internal_name(self): + from_xml_ff = Forcefield(name="oplsaa") + assert from_xml_ff.version == "0.0.1" + assert from_xml_ff.name == "OPLS-AA" + @pytest.mark.skipif(not has_mbuild, reason="mbuild is not installed") def test_no_overlap_residue_atom_overlap(self): import mbuild as mb