Skip to content

Commit

Permalink
Merge pull request #642 from DanielYang59/type
Browse files Browse the repository at this point in the history
Add type annotations and reformat docstring to Google style
  • Loading branch information
shyuep committed Apr 11, 2024
2 parents dbaba61 + 9e8992a commit a11cd9f
Show file tree
Hide file tree
Showing 54 changed files with 597 additions and 367 deletions.
2 changes: 2 additions & 0 deletions monty/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
useful design patterns such as singleton and cached_class, and many more.
"""

from __future__ import annotations

__author__ = "Shyue Ping Ong"
__copyright__ = "Copyright 2014, The Materials Virtual Lab"
__version__ = "2024.3.31"
Expand Down
27 changes: 15 additions & 12 deletions monty/bisect.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@
lists.
"""

from __future__ import annotations

import bisect as bs
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from typing import Optional

__author__ = "Matteo Giantomassi"
__copyright__ = "Copyright 2013, The Materials Virtual Lab"
Expand All @@ -17,44 +23,41 @@
__date__ = "11/09/14"


def index(a, x, atol=None):
def index(a: list[float], x: float, atol: Optional[float] = None) -> int:
"""Locate the leftmost value exactly equal to x."""
i = bs.bisect_left(a, x)
if i != len(a):
if atol is None:
if a[i] == x:
return i
else:
if abs(a[i] - x) < atol:
return i
elif abs(a[i] - x) < atol:
return i
raise ValueError


def find_lt(a, x):
def find_lt(a: list[float], x: float) -> int:
"""Find rightmost value less than x."""
i = bs.bisect_left(a, x)
if i:
if i := bs.bisect_left(a, x):
return i - 1
raise ValueError


def find_le(a, x):
def find_le(a: list[float], x: float) -> int:
"""Find rightmost value less than or equal to x."""
i = bs.bisect_right(a, x)
if i:
if i := bs.bisect_right(a, x):
return i - 1
raise ValueError


def find_gt(a, x):
def find_gt(a: list[float], x: float) -> int:
"""Find leftmost value greater than x."""
i = bs.bisect_right(a, x)
if i != len(a):
return i
raise ValueError


def find_ge(a, x):
def find_ge(a: list[float], x: float) -> int:
"""Find leftmost item greater than or equal to x."""
i = bs.bisect_left(a, x)
if i != len(a):
Expand Down
105 changes: 59 additions & 46 deletions monty/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,18 @@
Useful collection classes, e.g., tree, frozendict, etc.
"""

from __future__ import annotations

import collections
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from typing import Any, Iterable

from typing_extensions import Self


def tree():
def tree() -> collections.defaultdict:
"""
A tree object, which is effectively a recursive defaultdict that
adds tree as members.
Expand All @@ -26,44 +34,48 @@ class frozendict(dict):
violates PEP8 to be consistent with standard Python's "frozenset" naming.
"""

def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
"""
:param args: Passthrough arguments for standard dict.
:param kwargs: Passthrough keyword arguments for standard dict.
Args:
args: Passthrough arguments for standard dict.
kwargs: Passthrough keyword arguments for standard dict.
"""
dict.__init__(self, *args, **kwargs)

def __setitem__(self, key, val):
raise KeyError(f"Cannot overwrite existing key: {key!s}")
def __setitem__(self, key: Any, val: Any) -> None:
raise KeyError(f"Cannot overwrite existing key: {str(key)}")

def update(self, *args, **kwargs):
def update(self, *args, **kwargs) -> None:
"""
:param args: Passthrough arguments for standard dict.
:param kwargs: Passthrough keyword arguments for standard dict.
Args:
args: Passthrough arguments for standard dict.
kwargs: Passthrough keyword arguments for standard dict.
"""
raise KeyError(f"Cannot update a {self.__class__.__name__}")


class Namespace(dict):
"""A dictionary that does not permit to redefine its keys."""

def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
"""
:param args: Passthrough arguments for standard dict.
:param kwargs: Passthrough keyword arguments for standard dict.
Args:
args: Passthrough arguments for standard dict.
kwargs: Passthrough keyword arguments for standard dict.
"""
self.update(*args, **kwargs)

def __setitem__(self, key, val):
def __setitem__(self, key: Any, val: Any) -> None:
if key in self:
raise KeyError(f"Cannot overwrite existent key: {key!s}")

dict.__setitem__(self, key, val)

def update(self, *args, **kwargs):
def update(self, *args, **kwargs) -> None:
"""
:param args: Passthrough arguments for standard dict.
:param kwargs: Passthrough keyword arguments for standard dict.
Args:
args: Passthrough arguments for standard dict.
kwargs: Passthrough keyword arguments for standard dict.
"""
for k, v in dict(*args, **kwargs).items():
self[k] = v
Expand All @@ -74,24 +86,26 @@ class AttrDict(dict):
Allows to access dict keys as obj.foo in addition
to the traditional way obj['foo']"
Example:
Examples:
>>> d = AttrDict(foo=1, bar=2)
>>> assert d["foo"] == d.foo
>>> d.bar = "hello"
>>> assert d.bar == "hello"
"""

def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
"""
:param args: Passthrough arguments for standard dict.
:param kwargs: Passthrough keyword arguments for standard dict.
Args:
args: Passthrough arguments for standard dict.
kwargs: Passthrough keyword arguments for standard dict.
"""
super().__init__(*args, **kwargs)
self.__dict__ = self

def copy(self):
def copy(self) -> Self:
"""
:return: Copy of AttrDict
Returns:
Copy of AttrDict
"""
newd = super().copy()
return self.__class__(**newd)
Expand All @@ -105,14 +119,15 @@ class FrozenAttrDict(frozendict):
to the traditional way obj['foo']
"""

def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
"""
:param args: Passthrough arguments for standard dict.
:param kwargs: Passthrough keyword arguments for standard dict.
Args:
args: Passthrough arguments for standard dict.
kwargs: Passthrough keyword arguments for standard dict.
"""
super().__init__(*args, **kwargs)

def __getattribute__(self, name):
def __getattribute__(self, name: str) -> Any:
try:
return super().__getattribute__(name)
except AttributeError:
Expand All @@ -121,7 +136,7 @@ def __getattribute__(self, name):
except KeyError as exc:
raise AttributeError(str(exc))

def __setattr__(self, name, value):
def __setattr__(self, name: str, value: Any) -> None:
raise KeyError(
f"You cannot modify attribute {name} of {self.__class__.__name__}"
)
Expand All @@ -142,32 +157,32 @@ class MongoDict:
>>> m["a"]
{'b': 1}
.. note::
Notes:
Cannot inherit from ABC collections.Mapping because otherwise
dict.keys and dict.items will pollute the namespace.
e.g MongoDict({"keys": 1}).keys would be the ABC dict method.
"""

def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
"""
:param args: Passthrough arguments for standard dict.
:param kwargs: Passthrough keyword arguments for standard dict.
Args:
args: Passthrough arguments for standard dict.
kwargs: Passthrough keyword arguments for standard dict.
"""
self.__dict__["_mongo_dict_"] = dict(*args, **kwargs)

def __repr__(self):
def __repr__(self) -> str:
return str(self)

def __str__(self):
def __str__(self) -> str:
return str(self._mongo_dict_)

def __setattr__(self, name, value):
def __setattr__(self, name: str, value: Any) -> None:
raise NotImplementedError(
f"You cannot modify attribute {name} of {self.__class__.__name__}"
)

def __getattribute__(self, name):
def __getattribute__(self, name: str) -> Any:
try:
return super().__getattribute__(name)
except AttributeError:
Expand All @@ -180,37 +195,35 @@ def __getattribute__(self, name):
except Exception as exc:
raise AttributeError(str(exc))

def __getitem__(self, slice_):
def __getitem__(self, slice_) -> Any:
return self._mongo_dict_.__getitem__(slice_)

def __iter__(self):
def __iter__(self) -> Iterable:
return iter(self._mongo_dict_)

def __len__(self):
def __len__(self) -> int:
return len(self._mongo_dict_)

def __dir__(self):
def __dir__(self) -> list:
"""
For Ipython tab completion.
See http://ipython.org/ipython-doc/dev/config/integrating.html
"""
return sorted(k for k in self._mongo_dict_ if not callable(k))


def dict2namedtuple(*args, **kwargs):
def dict2namedtuple(*args, **kwargs) -> tuple:
"""
Helper function to create a :class:`namedtuple` from a dictionary.
Example:
Helper function to create a class `namedtuple` from a dictionary.
Examples:
>>> t = dict2namedtuple(foo=1, bar="hello")
>>> assert t.foo == 1 and t.bar == "hello"
>>> t = dict2namedtuple([("foo", 1), ("bar", "hello")])
>>> assert t[0] == t.foo and t[1] == t.bar
.. warning::
Warnings:
- The order of the items in the namedtuple is not deterministic if
kwargs are used.
namedtuples, however, should always be accessed by attribute hence
Expand Down
11 changes: 3 additions & 8 deletions monty/design_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@ def singleton(cls):
"""
This decorator can be used to create a singleton out of a class.
Usage::
Usage:
@singleton
class MySingleton():
Expand Down Expand Up @@ -63,10 +62,7 @@ class _decorated(klass): # type: ignore

def __new__(cls, *args, **kwargs):
"""
Pass through...
:param args:
:param kwargs:
:return:
Pass through.
"""
key = (cls, *args, *tuple(kwargs.items()))
try:
Expand Down Expand Up @@ -107,7 +103,7 @@ class NullFile:

def __new__(cls):
"""
Pass through
Pass through.
"""
return open(os.devnull, "w") # pylint: disable=R1732

Expand All @@ -121,5 +117,4 @@ class NullStream:
def write(self, *args): # pylint: disable=E0211
"""
Does nothing...
:param args:
"""
Loading

0 comments on commit a11cd9f

Please sign in to comment.