Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.2.0
hooks:
- id: ruff
args: [ --fix ]
- id: ruff-format
# numpydoc
- repo: https://github.com/Carreau/velin
Expand Down
2 changes: 1 addition & 1 deletion dargs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .dargs import Argument, Variant, ArgumentEncoder
from .dargs import Argument, ArgumentEncoder, Variant

__all__ = ["Argument", "Variant", "ArgumentEncoder"]
41 changes: 24 additions & 17 deletions dargs/dargs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
r"""
Some (ocaml) pseudo-code here to show the intended type structure::
r"""Some (ocaml) pseudo-code here to show the intended type structure.

type args = {key: str; value: data; optional: bool; doc: str} list
and data =
Expand Down Expand Up @@ -40,7 +39,11 @@
HookArgKType = Callable[["Argument", dict, List[str]], None]
HookArgVType = Callable[["Argument", Any, List[str]], None]
HookVrntType = Callable[["Variant", dict, List[str]], None]
_DUMMYHOOK = lambda a, x, p: None # for doing nothing in traversing


def _DUMMYHOOK(a, x, p):
# for doing nothing in traversing
pass


class _Flags(Enum):
Expand Down Expand Up @@ -68,19 +71,19 @@ def __str__(self) -> str:


class ArgumentKeyError(ArgumentError):
"""Error class for missing or invalid argument keys"""
"""Error class for missing or invalid argument keys."""

pass


class ArgumentTypeError(ArgumentError):
"""Error class for invalid argument data types"""
"""Error class for invalid argument data types."""

pass


class ArgumentValueError(ArgumentError):
"""Error class for missing or invalid argument values"""
"""Error class for missing or invalid argument values."""

pass

Expand Down Expand Up @@ -169,8 +172,12 @@ def __init__(
def __eq__(self, other: "Argument") -> bool:
# do not compare doc and default
# since they do not enter to the type checking
fkey = lambda f: f.name
vkey = lambda v: v.flag_name
def fkey(f):
return f.name

def vkey(v):
return v.flag_name

return (
self.name == other.name
and set(self.dtype) == set(other.dtype)
Expand Down Expand Up @@ -204,7 +211,7 @@ def __getitem__(self, key: str) -> "Argument":
return self[skey][rkey]

@property
def I(self):
def I(self): # noqa:E743
# return a dummy argument that only has self as a sub field
# can be used in indexing
return Argument("_", dict, [self])
Expand All @@ -227,7 +234,7 @@ def _reorg_dtype(self):
if (
self.optional
and self.default is not _Flags.NONE
and all([not isinstance_annotation(self.default, tt) for tt in self.dtype])
and all(not isinstance_annotation(self.default, tt) for tt in self.dtype)
):
self.dtype.add(type(self.default))
# and make it compatible with `isinstance`
Expand Down Expand Up @@ -473,7 +480,7 @@ def normalize(
do_alias: bool = True,
trim_pattern: Optional[str] = None,
):
"""Modify `argdict` so that it meets the Argument structure
"""Modify `argdict` so that it meets the Argument structure.

Normalization can add default values to optional args,
substitute alias by its standard names, and discard unnecessary
Expand Down Expand Up @@ -526,7 +533,7 @@ def normalize_value(
do_alias: bool = True,
trim_pattern: Optional[str] = None,
):
"""Modify the value so that it meets the Argument structure
"""Modify the value so that it meets the Argument structure.

Same as `normalize({self.name: value})[self.name]`.

Expand Down Expand Up @@ -611,7 +618,7 @@ def gen_doc_head(self, path: Optional[List[str]] = None, **kwargs) -> str:
if self.optional:
typesig += ", optional"
if self.default == "":
typesig += f", default: (empty string)"
typesig += ", default: (empty string)"
elif self.default is not _Flags.NONE:
typesig += f", default: ``{self.default}``"
if self.alias:
Expand Down Expand Up @@ -733,7 +740,7 @@ def set_default(self, default_tag: Union[bool, str]):
self.default_tag = default_tag

def extend_choices(self, choices: Optional[Iterable["Argument"]]):
"""Add a list of choice Arguments to the current Variant"""
"""Add a list of choice Arguments to the current Variant."""
# choices is a list of arguments
# whose name is treated as the switch tag
# we convert it into a dict for better reference
Expand All @@ -760,7 +767,7 @@ def add_choice(
*args,
**kwargs,
) -> "Argument":
"""Add a choice Argument to the current Variant"""
"""Add a choice Argument to the current Variant."""
if isinstance(tag, Argument):
newarg = tag
else:
Expand Down Expand Up @@ -835,7 +842,7 @@ def gen_doc(
if kwargs.get("make_link"):
if not kwargs.get("make_anchor"):
raise ValueError("`make_link` only works with `make_anchor` set")
fnstr, target = make_ref_pair(path + [self.flag_name], fnstr, "flag")
fnstr, target = make_ref_pair([*path, self.flag_name], fnstr, "flag")
body_list.append(target + "\n")
for choice in self.choice_dict.values():
body_list.append("")
Expand Down Expand Up @@ -986,7 +993,7 @@ def isinstance_annotation(value, dtype) -> bool:


class ArgumentEncoder(json.JSONEncoder):
"""Extended JSON Encoder to encode Argument object:
"""Extended JSON Encoder to encode Argument object.

Examples
--------
Expand Down
50 changes: 23 additions & 27 deletions dargs/sphinx.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
where `_test_argument` returns an :class:`Argument <dargs.Argument>`. A :class:`list` of :class:`Argument <dargs.Argument>` is also accepted.
"""
import sys
from typing import List
from typing import ClassVar, List

from docutils.parsers.rst import Directive
from docutils.parsers.rst.directives import unchanged
Expand All @@ -34,13 +34,13 @@


class DargsDirective(Directive):
"""dargs directive"""
"""dargs directive."""

has_content = True
option_spec = dict(
module=unchanged,
func=unchanged,
)
has_content: ClassVar[bool] = True
option_spec: ClassVar[dict] = {
"module": unchanged,
"func": unchanged,
}

def run(self):
if "module" in self.options and "func" in self.options:
Expand All @@ -58,11 +58,8 @@ def run(self):

if not hasattr(mod, attr_name):
raise self.error(
(
'Module "%s" has no attribute "%s"\n'
"Incorrect argparse :module: or :func: values?"
)
% (module_name, attr_name)
f'Module "{module_name}" has no attribute "{attr_name}"\n'
"Incorrect argparse :module: or :func: values?"
)
func = getattr(mod, attr_name)
arguments = func()
Expand All @@ -78,7 +75,7 @@ def run(self):
make_anchor=True, make_link=True, use_sphinx_domain=True
)
rsts.extend(rst.split("\n"))
self.state_machine.insert_input(rsts, "%s:%s" % (module_name, attr_name))
self.state_machine.insert_input(rsts, f"{module_name}:{attr_name}")
return []


Expand All @@ -88,17 +85,17 @@ class DargsObject(ObjectDescription):
This directive creates a signature node for an argument.
"""

option_spec = dict(
path=unchanged,
)
option_spec: ClassVar[dict] = {
"path": unchanged,
}

def handle_signature(self, sig, signode):
signode += addnodes.desc_name(sig, sig)
return sig

def add_target_and_index(self, name, sig, signode):
path = self.options["path"]
targetid = "%s:%s" % (self.objtype, path)
targetid = f"{self.objtype}:{path}"
if targetid not in self.state.document.ids:
signode["names"].append(targetid)
signode["ids"].append(targetid)
Expand All @@ -108,16 +105,15 @@ def add_target_and_index(self, name, sig, signode):
inv = self.env.domaindata["dargs"]["arguments"]
if targetid in inv:
self.state.document.reporter.warning(
'Duplicated argument "%s" described in "%s".'
% (targetid, self.env.doc2path(inv[targetid][0])),
f'Duplicated argument "{targetid}" described in "{self.env.doc2path(inv[targetid][0])}".',
line=self.lineno,
)
inv[targetid] = (self.env.docname, self.objtype)

self.indexnode["entries"].append(
(
"pair",
"%s ; %s (%s) " % (name, path, self.objtype.title()),
f"{name}; {path} ({self.objtype.title()})",
targetid,
"main",
None,
Expand All @@ -133,25 +129,25 @@ class DargsDomain(Domain):
- dargs::argument role
"""

name = "dargs"
label = "dargs"
object_types = {
name: ClassVar[str] = "dargs"
label: ClassVar[str] = "dargs"
object_types: ClassVar[dict] = {
"argument": ObjType("argument", "argument"),
}
directives = {
directives: ClassVar[dict] = {
"argument": DargsObject,
}
roles = {
roles: ClassVar[dict] = {
"argument": XRefRole(),
}

initial_data = {
initial_data: ClassVar[dict] = {
"arguments": {}, # fullname -> docname, objtype
}

def resolve_xref(self, env, fromdocname, builder, typ, target, node, contnode):
"""Resolve cross-references."""
targetid = "%s:%s" % (typ, target)
targetid = f"{typ}:{target}"
obj = self.data["arguments"].get(targetid)
if obj is None:
return None
Expand Down
1 change: 0 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
#
# Configuration file for the Sphinx documentation builder.
#
Expand Down
31 changes: 31 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,34 @@ include = ["dargs*"]

[tool.setuptools_scm]
write_to = "dargs/_version.py"

[tool.ruff.lint]
select = [
"E", # errors
"F", # pyflakes
"D", # pydocstyle
"UP", # pyupgrade
"C4", # flake8-comprehensions
"RUF", # ruff
"I", # isort
]

ignore = [
"E501", # line too long
"F841", # local variable is assigned to but never used
"E741", # ambiguous variable name
"E402", # module level import not at top of file
"D100", # TODO: missing docstring in public module
"D101", # TODO: missing docstring in public class
"D102", # TODO: missing docstring in public method
"D103", # TODO: missing docstring in public function
"D104", # TODO: missing docstring in public package
"D105", # TODO: missing docstring in magic method
"D205", # 1 blank line required between summary line and description
"D401", # TODO: first line should be in imperative mood
"D404", # TODO: first word of the docstring should not be This
]
ignore-init-module-imports = true

[tool.ruff.lint.pydocstyle]
convention = "numpy"
4 changes: 4 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
import os
import sys

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
4 changes: 0 additions & 4 deletions tests/context.py
Original file line number Diff line number Diff line change
@@ -1,4 +0,0 @@
import sys, os

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import dargs
12 changes: 6 additions & 6 deletions tests/dpmdargs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from .context import dargs
from dargs import dargs, Argument, Variant
from dargs import Argument, Variant, dargs

ACTIVATION_FN_DICT = {
"relu": None,
Expand Down Expand Up @@ -140,7 +139,8 @@ def descrpt_se_a_tpe_args():
doc_type_nlayer = "number of hidden layers of type embedding net"
doc_numb_aparam = "dimension of atomic parameter. if set to a value > 0, the atomic parameters are embedded."

return descrpt_se_a_args() + [
return [
*descrpt_se_a_args(),
Argument("type_nchanl", int, optional=True, default=4, doc=doc_type_nchanl),
Argument("type_nlayer", int, optional=True, default=2, doc=doc_type_nlayer),
Argument("numb_aparam", int, optional=True, default=0, doc=doc_numb_aparam),
Expand Down Expand Up @@ -202,7 +202,7 @@ def descrpt_se_ar_args():


def descrpt_hybrid_args():
doc_list = f"A list of descriptor definitions"
doc_list = "A list of descriptor definitions"

return [
Argument(
Expand Down Expand Up @@ -243,7 +243,7 @@ def descrpt_variant_type_args():
link_se_a_3be = make_link("se_a_3be", "model/descriptor[se_a_3be]")
link_se_a_tpe = make_link("se_a_tpe", "model/descriptor[se_a_tpe]")
link_hybrid = make_link("hybrid", "model/descriptor[hybrid]")
doc_descrpt_type = f"The type of the descritpor. See explanation below. \n\n\
doc_descrpt_type = "The type of the descritpor. See explanation below. \n\n\
- `loc_frame`: Defines a local frame at each atom, and the compute the descriptor as local coordinates under this frame.\n\n\
- `se_a`: Used by the smooth edition of Deep Potential. The full relative coordinates are used to construct the descriptor.\n\n\
- `se_r`: Used by the smooth edition of Deep Potential. Only the distance between atoms is used to construct the descriptor.\n\n\
Expand Down Expand Up @@ -391,7 +391,7 @@ def modifier_dipole_charge():
doc_model_name = "The name of the frozen dipole model file."
doc_model_charge_map = f"The charge of the WFCC. The list length should be the same as the {make_link('sel_type', 'model/fitting_net[dipole]/sel_type')}. "
doc_sys_charge_map = f"The charge of real atoms. The list length should be the same as the {make_link('type_map', 'model/type_map')}"
doc_ewald_h = f"The grid spacing of the FFT grid. Unit is A"
doc_ewald_h = "The grid spacing of the FFT grid. Unit is A"
doc_ewald_beta = f"The splitting parameter of Ewald sum. Unit is A^{-1}"

return [
Expand Down
4 changes: 2 additions & 2 deletions tests/test_checker.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import List
from .context import dargs
import unittest
from typing import List

from dargs import Argument, Variant
from dargs.dargs import ArgumentKeyError, ArgumentTypeError, ArgumentValueError

Expand Down
2 changes: 1 addition & 1 deletion tests/test_creation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .context import dargs
import unittest

from dargs import Argument, Variant


Expand Down
Loading