Skip to content

Commit

Permalink
Add tests and clean things up (#100)
Browse files Browse the repository at this point in the history
* remove unused methods

* remove array file

* add some typing tests

* add yaml tests

* update model

* rm unused imports

* add meta tests

* add validate metric

* add more base tests

* swap order of test matrix?
  • Loading branch information
lilyminium committed Mar 28, 2024
1 parent 7ee92ac commit 0a3a9c7
Show file tree
Hide file tree
Showing 6 changed files with 240 additions and 145 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/gh-ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ jobs:
os: [macOS-latest, ubuntu-latest]
python-version: ["3.9", "3.10", "3.11", "3.12"]
pydantic-version: ["1", "2"]
include-rdkit: [true, false]
include-openeye: [true, false]
include-dgl: [true, false]
include-rdkit: [false, true]
include-openeye: [false, true]
include-dgl: [false, true]
exclude:
- include-rdkit: false
include-openeye: false
Expand Down
23 changes: 0 additions & 23 deletions openff/nagl/_base/array.py

This file was deleted.

124 changes: 5 additions & 119 deletions openff/nagl/_base/base.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,16 @@
import enum
import hashlib
import inspect
import pathlib
import json
import yaml
from typing import Any, ClassVar, Dict, List, Optional, Type, no_type_check

import numpy as np
from openff.units import unit

from ..utils._utils import round_floats

try:
from pydantic.v1 import BaseModel
from pydantic.v1.errors import DictError
except ImportError:
from pydantic import BaseModel
from pydantic.errors import DictError

class MutableModel(BaseModel):
"""
Expand All @@ -38,12 +32,6 @@ class Config:
pathlib.Path: str,
}

_hash_fields: ClassVar[Optional[List[str]]] = None
_float_fields: ClassVar[List[str]] = []
_float_decimals: ClassVar[int] = 8
_hash_int: Optional[int] = None
_hash_str: Optional[str] = None

def __init__(self, *args, **kwargs):
self.__pre_init__(*args, **kwargs)
super(MutableModel, self).__init__(*args, **kwargs)
Expand All @@ -55,127 +43,25 @@ def __pre_init__(self, *args, **kwargs):
def __post_init__(self, *args, **kwargs):
pass

@classmethod
def _get_properties(cls) -> Dict[str, property]:
return dict(
inspect.getmembers(cls, predicate=lambda x: isinstance(x, property))
)

@no_type_check
def __setattr__(self, attr, value):
try:
super().__setattr__(attr, value)
except ValueError as e:
properties = self._get_properties()
if attr in properties:
if properties[attr].fset is not None:
return properties[attr].fset(self, value)
raise e

def _clsname(self):
return type(self).__name__

def _clear_hash(self):
self._hash_int = None
self._hash_str = None

def _set_attr(self, attrname, value):
self.__dict__[attrname] = value
self._clear_hash()

def __hash__(self):
if self._hash_int is None:
mash = self.get_hash()
self._hash_int = int(mash, 16)
return self._hash_int

def get_hash(self) -> str:
"""Returns string hash of the object"""
if self._hash_str is None:
dumped = self.dumps(decimals=self._float_decimals)
mash = hashlib.sha1()
mash.update(dumped.encode("utf-8"))
self._hash_str = mash.hexdigest()
return self._hash_str

# def __eq__(self, other):
# return hash(self) == hash(other)

def hash_dict(self) -> Dict[str, Any]:
"""Create dictionary from hash fields and sort alphabetically"""
if self._hash_fields:
hashdct = self.dict(include=set(self._hash_fields))
else:
hashdct = self.dict()
data = {k: hashdct[k] for k in sorted(hashdct)}
return data

def dumps(self, decimals: Optional[int] = None):
"""Serialize object to a JSON formatted string
Unlike json(), this method only includes hashable fields,
sorts them alphabetically, and optionally rounds floats.
"""
data = self.hash_dict()
dump = self.__config__.json_dumps
if decimals is not None:
for field in self._float_fields:
if field in data:
data[field] = round_floats(data[field], decimals=decimals)
with np.printoptions(precision=16):
return dump(data, default=self.__json_encoder__)
return dump(data, default=self.__json_encoder__)

def _round(self, obj):
return round_floats(obj, decimals=self._float_decimals)

def to_json(self):
return self.json(
sort_keys=True,
indent=2,
separators=(",", ": "),
)

@classmethod
def _from_dict(cls, **kwargs):
dct = {k: kwargs[k] for k in kwargs if k in cls.__fields__}
return cls(**dct)

@classmethod
def validate(cls: Type["MutableModel"], value: Any) -> "MutableModel":
if isinstance(value, dict):
return cls(**value)
elif isinstance(value, cls):
return value
elif cls.__config__.orm_mode:
return cls.from_orm(value)
elif cls.__custom_root_type__:
return cls.parse_obj(value)
else:
try:
value_as_dict = dict(value)
except (TypeError, ValueError) as e:
raise DictError() from e
return cls(**value_as_dict)

@classmethod
def from_json(cls, string_or_file):
try:
with open(string_or_file, "r") as f:
string_or_file = f.read()
except (OSError, FileNotFoundError):
pass
return cls.parse_raw(string_or_file)

def copy(self, *, include=None, exclude=None, update=None, deep: bool = False):
obj = super().copy(include=include, exclude=exclude, update=update, deep=deep)
obj.__post_init__()
return obj

def _replace_from_mapping(self, attr_name, mapping_values={}):
current_value = getattr(self, attr_name)
if current_value in mapping_values:
self._set_attr(attr_name, mapping_values[current_value])
try:
validator = cls.model_validate_json
except AttributeError:
validator = cls.parse_raw
return validator(string_or_file)

def to_yaml(self, filename):
data = json.loads(self.json())
Expand Down
157 changes: 157 additions & 0 deletions openff/nagl/tests/_base/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
from openff.nagl._base.base import MutableModel
from openff.units import unit
import numpy as np
import json
import textwrap

try:
from pydantic.v1 import Field, validator
except ImportError:
from pydantic import Field, validator

class TestMutableModel:

class Model(MutableModel):
int_type: int
float_type: float
list_type: list
np_array_type: np.ndarray
tuple_type: tuple
unit_type: unit.Quantity

@validator("np_array_type", pre=True)
def _validate_np_array_type(cls, v):
return np.asarray(v)

@validator("unit_type", pre=True)
def _validate_unit_type(cls, v):
if not isinstance(v, unit.Quantity):
return unit.Quantity.from_tuple(v)
return v


def test_init(self):
model = self.Model(int_type=1, float_type=1.0, list_type=[1, 2, 3], np_array_type=np.array([1, 2, 3]), tuple_type=(1, 2, 3), unit_type=unit.Quantity(1.0, "angstrom"))
assert model.int_type == 1
assert model.float_type == 1.0
assert model.list_type == [1, 2, 3]
assert np.array_equal(model.np_array_type, np.array([1, 2, 3]))
assert model.tuple_type == (1, 2, 3)
assert model.unit_type == unit.Quantity(1.0, "angstrom")

def test_to_json(self):
arr = np.arange(10).reshape(2, 5)
model = self.Model(int_type=1, float_type=1.0, list_type=[1, 2, 3], np_array_type=arr, tuple_type=(1, 2, 3), unit_type=unit.Quantity(1.0, "angstrom"))
json_dict = json.loads(model.to_json())
expected = {
"int_type": 1,
"float_type": 1.0,
"list_type": [1, 2, 3],
"np_array_type": [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9] ],
"tuple_type": [1, 2, 3],
"unit_type": [1.0, [["angstrom", 1]]]
}
assert json_dict == expected

def test_from_json_string(self):
input_text = """
{
"int_type": 4,
"float_type": 10.0,
"list_type": [1, 2, 3],
"np_array_type": [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9] ],
"tuple_type": [1, 2, 3],
"unit_type": [1.0, [["angstrom", 1]]]
}
"""
model = self.Model.from_json(input_text)
assert model.int_type == 4
assert model.float_type == 10.0
assert model.list_type == [1, 2, 3]
arr = np.array([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
assert np.array_equal(model.np_array_type, arr)
assert model.tuple_type == (1, 2, 3)
assert model.unit_type == unit.Quantity(1.0, "angstrom")

def test_from_json_file(self, tmp_path):
input_text = """
{
"int_type": 4,
"float_type": 10.0,
"list_type": [1, 2, 3],
"np_array_type": [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9] ],
"tuple_type": [1, 2, 3],
"unit_type": [1.0, [["angstrom", 1]]]
}
"""
file_path = tmp_path / "test.json"
with open(file_path, "w") as f:
f.write(input_text)
model = self.Model.from_json(file_path)
assert model.int_type == 4
assert model.float_type == 10.0
assert model.list_type == [1, 2, 3]
arr = np.array([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
assert np.array_equal(model.np_array_type, arr)
assert model.tuple_type == (1, 2, 3)
assert model.unit_type == unit.Quantity(1.0, "angstrom")

def test_to_yaml(self, tmp_path):
model = self.Model(int_type=1, float_type=1.0, list_type=[1, 2, 3], np_array_type=np.array([1, 2, 3]), tuple_type=(1, 2, 3), unit_type=unit.Quantity(1.0, "angstrom"))
file_path = tmp_path / "test.yaml"
model.to_yaml(file_path)
with open(file_path, "r") as f:
yaml_text = f.read()
expected = textwrap.dedent("""
float_type: 1.0
int_type: 1
list_type:
- 1
- 2
- 3
np_array_type:
- 1
- 2
- 3
tuple_type:
- 1
- 2
- 3
unit_type:
- 1.0
- - - angstrom
- 1
""")
assert yaml_text.strip() == expected.strip()

def test_from_yaml(self, tmp_path):
input_text = textwrap.dedent("""
float_type: 1.0
int_type: 1
list_type:
- 1
- 2
- 3
np_array_type:
- 1
- 2
- 3
tuple_type:
- 1
- 2
- 3
unit_type:
- 1.0
- - - angstrom
- 1
""")
file_path = tmp_path / "test.yaml"
with open(file_path, "w") as f:
f.write(input_text)
model = self.Model.from_yaml(file_path)
assert model.int_type == 1
assert model.float_type == 1.0
assert model.list_type == [1, 2, 3]
assert np.array_equal(model.np_array_type, np.array([1, 2, 3]))
assert model.tuple_type == (1, 2, 3)
assert model.unit_type == unit.Quantity(1.0, "angstrom")

0 comments on commit 0a3a9c7

Please sign in to comment.