diff --git a/.github/workflows/gh-ci.yaml b/.github/workflows/gh-ci.yaml
index 08d8da3d..b7b82fc4 100644
--- a/.github/workflows/gh-ci.yaml
+++ b/.github/workflows/gh-ci.yaml
@@ -33,9 +33,9 @@ jobs:
         os: [macOS-latest, ubuntu-latest]
         python-version: ["3.9", "3.10", "3.11", "3.12"]
         pydantic-version: ["1", "2"]
-        include-rdkit: [true, false]
-        include-openeye: [true, false]
-        include-dgl: [true, false]
+        include-rdkit: [false, true]
+        include-openeye: [false, true]
+        include-dgl: [false, true]
         exclude:
           - include-rdkit: false
             include-openeye: false
diff --git a/openff/nagl/_base/array.py b/openff/nagl/_base/array.py
deleted file mode 100644
index 4a68f429..00000000
--- a/openff/nagl/_base/array.py
+++ /dev/null
@@ -1,23 +0,0 @@
-from typing import Any
-
-import numpy as np
-
-
-class ArrayMeta(type):
-    def __getitem__(cls, T):
-        return type("Array", (Array,), {"__dtype__": T})
-
-
-class Array(np.ndarray, metaclass=ArrayMeta):
-    """A typeable numpy array"""
-
-    @classmethod
-    def __get_validators__(cls):
-        yield cls.validate_type
-
-    @classmethod
-    def validate_type(cls, val):
-        dtype = getattr(cls, "__dtype__", Any)
-        if dtype is Any:
-            dtype = None
-        return np.asanyarray(val, dtype=dtype)
diff --git a/openff/nagl/_base/base.py b/openff/nagl/_base/base.py
index 5b41c48e..de15942f 100644
--- a/openff/nagl/_base/base.py
+++ b/openff/nagl/_base/base.py
@@ -1,22 +1,16 @@
 import enum
-import hashlib
-import inspect
 import pathlib
 import json
 import yaml
-from typing import Any, ClassVar, Dict, List, Optional, Type, no_type_check

 import numpy as np
 from openff.units import unit
-from ..utils._utils import round_floats

 try:
     from pydantic.v1 import BaseModel
-    from pydantic.v1.errors import DictError
 except ImportError:
     from pydantic import BaseModel
-    from pydantic.errors import DictError


 class MutableModel(BaseModel):
     """
@@ -38,12 +32,6 @@ class Config:
             pathlib.Path: str,
         }

-    _hash_fields: ClassVar[Optional[List[str]]] = None
-    _float_fields: ClassVar[List[str]] = []
-    _float_decimals: ClassVar[int] = 8
-    _hash_int: Optional[int] = None
-    _hash_str: Optional[str] = None
-
     def __init__(self, *args, **kwargs):
         self.__pre_init__(*args, **kwargs)
         super(MutableModel, self).__init__(*args, **kwargs)
@@ -55,80 +43,6 @@ def __pre_init__(self, *args, **kwargs):
     def __post_init__(self, *args, **kwargs):
         pass

-    @classmethod
-    def _get_properties(cls) -> Dict[str, property]:
-        return dict(
-            inspect.getmembers(cls, predicate=lambda x: isinstance(x, property))
-        )
-
-    @no_type_check
-    def __setattr__(self, attr, value):
-        try:
-            super().__setattr__(attr, value)
-        except ValueError as e:
-            properties = self._get_properties()
-            if attr in properties:
-                if properties[attr].fset is not None:
-                    return properties[attr].fset(self, value)
-            raise e
-
-    def _clsname(self):
-        return type(self).__name__
-
-    def _clear_hash(self):
-        self._hash_int = None
-        self._hash_str = None
-
-    def _set_attr(self, attrname, value):
-        self.__dict__[attrname] = value
-        self._clear_hash()
-
-    def __hash__(self):
-        if self._hash_int is None:
-            mash = self.get_hash()
-            self._hash_int = int(mash, 16)
-        return self._hash_int
-
-    def get_hash(self) -> str:
-        """Returns string hash of the object"""
-        if self._hash_str is None:
-            dumped = self.dumps(decimals=self._float_decimals)
-            mash = hashlib.sha1()
-            mash.update(dumped.encode("utf-8"))
-            self._hash_str = mash.hexdigest()
-        return self._hash_str
-
-    # def __eq__(self, other):
-    #     return hash(self) == hash(other)
-
-    def hash_dict(self) -> Dict[str, Any]:
-        """Create dictionary from hash fields and sort alphabetically"""
-        if self._hash_fields:
-            hashdct = self.dict(include=set(self._hash_fields))
-        else:
-            hashdct = self.dict()
-        data = {k: hashdct[k] for k in sorted(hashdct)}
-        return data
-
-    def dumps(self, decimals: Optional[int] = None):
-        """Serialize object to a JSON formatted string
-
-        Unlike json(), this method only includes hashable fields,
-        sorts them alphabetically, and optionally rounds floats.
-        """
-        data = self.hash_dict()
-        dump = self.__config__.json_dumps
-        if decimals is not None:
-            for field in self._float_fields:
-                if field in data:
-                    data[field] = round_floats(data[field], decimals=decimals)
-            with np.printoptions(precision=16):
-                return dump(data, default=self.__json_encoder__)
-        return dump(data, default=self.__json_encoder__)
-
-    def _round(self, obj):
-        return round_floats(obj, decimals=self._float_decimals)
-
     def to_json(self):
         return self.json(
             sort_keys=True,
@@ -136,28 +50,6 @@ def to_json(self):
             separators=(",", ": "),
         )

-    @classmethod
-    def _from_dict(cls, **kwargs):
-        dct = {k: kwargs[k] for k in kwargs if k in cls.__fields__}
-        return cls(**dct)
-
-    @classmethod
-    def validate(cls: Type["MutableModel"], value: Any) -> "MutableModel":
-        if isinstance(value, dict):
-            return cls(**value)
-        elif isinstance(value, cls):
-            return value
-        elif cls.__config__.orm_mode:
-            return cls.from_orm(value)
-        elif cls.__custom_root_type__:
-            return cls.parse_obj(value)
-        else:
-            try:
-                value_as_dict = dict(value)
-            except (TypeError, ValueError) as e:
-                raise DictError() from e
-            return cls(**value_as_dict)
-
     @classmethod
     def from_json(cls, string_or_file):
         try:
@@ -165,17 +57,11 @@ def from_json(cls, string_or_file):
             string_or_file = f.read()
         except (OSError, FileNotFoundError):
             pass
-        return cls.parse_raw(string_or_file)
-
-    def copy(self, *, include=None, exclude=None, update=None, deep: bool = False):
-        obj = super().copy(include=include, exclude=exclude, update=update, deep=deep)
-        obj.__post_init__()
-        return obj
-
-    def _replace_from_mapping(self, attr_name, mapping_values={}):
-        current_value = getattr(self, attr_name)
-        if current_value in mapping_values:
-            self._set_attr(attr_name, mapping_values[current_value])
+        try:
+            validator = cls.model_validate_json
+        except AttributeError:
+            validator = cls.parse_raw
+        return validator(string_or_file)

     def to_yaml(self, filename):
         data = json.loads(self.json())
diff --git a/openff/nagl/tests/_base/test_base.py b/openff/nagl/tests/_base/test_base.py
new file mode 100644
index 00000000..adf34e4c
--- /dev/null
+++ b/openff/nagl/tests/_base/test_base.py
@@ -0,0 +1,157 @@
+from openff.nagl._base.base import MutableModel
+from openff.units import unit
+import numpy as np
+import json
+import textwrap
+
+try:
+    from pydantic.v1 import Field, validator
+except ImportError:
+    from pydantic import Field, validator
+
+class TestMutableModel:
+
+    class Model(MutableModel):
+        int_type: int
+        float_type: float
+        list_type: list
+        np_array_type: np.ndarray
+        tuple_type: tuple
+        unit_type: unit.Quantity
+
+        @validator("np_array_type", pre=True)
+        def _validate_np_array_type(cls, v):
+            return np.asarray(v)
+
+        @validator("unit_type", pre=True)
+        def _validate_unit_type(cls, v):
+            if not isinstance(v, unit.Quantity):
+                return unit.Quantity.from_tuple(v)
+            return v
+
+
+    def test_init(self):
+        model = self.Model(int_type=1, float_type=1.0, list_type=[1, 2, 3], np_array_type=np.array([1, 2, 3]), tuple_type=(1, 2, 3), unit_type=unit.Quantity(1.0, "angstrom"))
+        assert model.int_type == 1
+        assert model.float_type == 1.0
+        assert model.list_type == [1, 2, 3]
+        assert np.array_equal(model.np_array_type, np.array([1, 2, 3]))
+        assert model.tuple_type == (1, 2, 3)
+        assert model.unit_type == unit.Quantity(1.0, "angstrom")
+
+    def test_to_json(self):
+        arr = np.arange(10).reshape(2, 5)
+        model = self.Model(int_type=1, float_type=1.0, list_type=[1, 2, 3], np_array_type=arr, tuple_type=(1, 2, 3), unit_type=unit.Quantity(1.0, "angstrom"))
+        json_dict = json.loads(model.to_json())
+        expected = {
+            "int_type": 1,
+            "float_type": 1.0,
+            "list_type": [1, 2, 3],
+            "np_array_type": [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]],
+            "tuple_type": [1, 2, 3],
+            "unit_type": [1.0, [["angstrom", 1]]]
+        }
+        assert json_dict == expected
+
+    def test_from_json_string(self):
+        input_text = """
+        {
+            "int_type": 4,
+            "float_type": 10.0,
+            "list_type": [1, 2, 3],
+            "np_array_type": [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]],
+            "tuple_type": [1, 2, 3],
+            "unit_type": [1.0, [["angstrom", 1]]]
+        }
+        """
+        model = self.Model.from_json(input_text)
+        assert model.int_type == 4
+        assert model.float_type == 10.0
+        assert model.list_type == [1, 2, 3]
+        arr = np.array([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
+        assert np.array_equal(model.np_array_type, arr)
+        assert model.tuple_type == (1, 2, 3)
+        assert model.unit_type == unit.Quantity(1.0, "angstrom")
+
+    def test_from_json_file(self, tmp_path):
+        input_text = """
+        {
+            "int_type": 4,
+            "float_type": 10.0,
+            "list_type": [1, 2, 3],
+            "np_array_type": [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]],
+            "tuple_type": [1, 2, 3],
+            "unit_type": [1.0, [["angstrom", 1]]]
+        }
+        """
+        file_path = tmp_path / "test.json"
+        with open(file_path, "w") as f:
+            f.write(input_text)
+        model = self.Model.from_json(file_path)
+        assert model.int_type == 4
+        assert model.float_type == 10.0
+        assert model.list_type == [1, 2, 3]
+        arr = np.array([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
+        assert np.array_equal(model.np_array_type, arr)
+        assert model.tuple_type == (1, 2, 3)
+        assert model.unit_type == unit.Quantity(1.0, "angstrom")
+
+    def test_to_yaml(self, tmp_path):
+        model = self.Model(int_type=1, float_type=1.0, list_type=[1, 2, 3], np_array_type=np.array([1, 2, 3]), tuple_type=(1, 2, 3), unit_type=unit.Quantity(1.0, "angstrom"))
+        file_path = tmp_path / "test.yaml"
+        model.to_yaml(file_path)
+        with open(file_path, "r") as f:
+            yaml_text = f.read()
+        expected = textwrap.dedent("""
+        float_type: 1.0
+        int_type: 1
+        list_type:
+        - 1
+        - 2
+        - 3
+        np_array_type:
+        - 1
+        - 2
+        - 3
+        tuple_type:
+        - 1
+        - 2
+        - 3
+        unit_type:
+        - 1.0
+        - - - angstrom
+            - 1
+        """)
+        assert yaml_text.strip() == expected.strip()
+
+    def test_from_yaml(self, tmp_path):
+        input_text = textwrap.dedent("""
+        float_type: 1.0
+        int_type: 1
+        list_type:
+        - 1
+        - 2
+        - 3
+        np_array_type:
+        - 1
+        - 2
+        - 3
+        tuple_type:
+        - 1
+        - 2
+        - 3
+        unit_type:
+        - 1.0
+        - - - angstrom
+            - 1
+        """)
+        file_path = tmp_path / "test.yaml"
+        with open(file_path, "w") as f:
+            f.write(input_text)
+        model = self.Model.from_yaml(file_path)
+        assert model.int_type == 1
+        assert model.float_type == 1.0
+        assert model.list_type == [1, 2, 3]
+        assert np.array_equal(model.np_array_type, np.array([1, 2, 3]))
+        assert model.tuple_type == (1, 2, 3)
+        assert model.unit_type == unit.Quantity(1.0, "angstrom")
diff --git a/openff/nagl/tests/_base/test_metaregistry.py b/openff/nagl/tests/_base/test_metaregistry.py
new file mode 100644
index 00000000..b1f6bf10
--- /dev/null
+++ b/openff/nagl/tests/_base/test_metaregistry.py
@@ -0,0 +1,47 @@
+import pytest
+
+from openff.nagl._base.metaregistry import create_registry_metaclass
+
+class TestMetaRegistry:
+
+    Registry = create_registry_metaclass(ignore_case=False)
+    RegistryIgnoreCase = create_registry_metaclass(ignore_case=True)
+
+    class TestClass(metaclass=Registry):
+        name = "TestKey"
+
+    class TestClassIgnoreCase(metaclass=RegistryIgnoreCase):
+        name = "TestKey"
+
+    def test_create_registry_metaclass(self):
+        assert "TestKey" in self.Registry.registry
+        assert self.Registry.registry["TestKey"] is self.TestClass
+        assert "testkey" in self.RegistryIgnoreCase.registry
+        assert self.RegistryIgnoreCase.registry["testkey"] is self.TestClassIgnoreCase
+
+    def test_key_transform(self):
+        assert "TestKey" in self.Registry.registry
+        assert self.Registry._get_by_key("TestKey") is self.TestClass
+        with pytest.raises(KeyError):
+            self.Registry._get_by_key("testkey")
+
+    def test_key_transform_ignore_case(self):
+        assert "testkey" in self.RegistryIgnoreCase.registry
+        assert "TestKey" not in self.RegistryIgnoreCase.registry
+        assert self.RegistryIgnoreCase._get_by_key("TestKey") is self.TestClassIgnoreCase
+        assert self.RegistryIgnoreCase._get_by_key("testkey") is self.TestClassIgnoreCase
+
+    def test_get_class(self):
+        assert self.Registry._get_class(self.TestClass) is self.TestClass
+        assert self.Registry._get_class(self.TestClass()) is self.TestClass
+        assert self.Registry._get_class("TestKey") is self.TestClass
+        with pytest.raises(KeyError):
+            self.Registry._get_class("testkey")
+
+    def test_get_object(self):
+        assert isinstance(self.Registry._get_object(self.TestClass), self.TestClass)
+        assert isinstance(self.Registry._get_object(self.TestClass()), self.TestClass)
+        assert isinstance(self.Registry._get_object("TestKey"), self.TestClass)
+        with pytest.raises(KeyError):
+            self.Registry._get_object("testkey")
+
diff --git a/openff/nagl/tests/training/test_loss.py b/openff/nagl/tests/training/test_loss.py
index 4a7fbe9d..bc53f77c 100644
--- a/openff/nagl/tests/training/test_loss.py
+++ b/openff/nagl/tests/training/test_loss.py
@@ -1,7 +1,13 @@
+import typing
+from typing import Dict, List
 import pytest
+from openff.nagl.nn._containers import ReadoutModule
 import torch
 import numpy as np
+
+from openff.nagl.training.metrics import RMSEMetric
 from openff.nagl.training.loss import (
+    _BaseTarget,
     MultipleDipoleTarget,
     SingleDipoleTarget,
     HeavyAtomReadoutTarget,
@@ -9,6 +15,28 @@
     MultipleESPTarget
 )

+class TestBaseTarget:
+
+    class BaseTarget(_BaseTarget):
+        name: typing.Literal["base"]
+        def get_required_columns(self) -> List[str]:
+            return []
+
+        def evaluate_target(self, molecules, labels, predictions, readout_modules) -> "torch.Tensor":
+            return torch.tensor([0.0])
+
+    def test_validate_metric(self):
+        input_text = '{"metric": "rmse", "name": "readout", "prediction_label": "charges", "target_label": "charges"}'
+        target = ReadoutTarget.parse_raw(input_text)
+        assert isinstance(target.metric, RMSEMetric)
+
+    def test_non_implemented_methods(self):
+        target = self.BaseTarget(name="base", metric="rmse", target_label="charges")
+        with pytest.raises(NotImplementedError):
+            target.compute_reference(None)
+        with pytest.raises(NotImplementedError):
+            target.report_artifact(None, None, None, None)
+
 class TestReadoutTarget:

     def test_single_molecule(self, dgl_methane):