Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions dargs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .dargs import Argument, Variant
from .dargs import Argument, Variant, ArgumentEncoder

__all__ = ["Argument", "Variant"]
__all__ = ["Argument", "Variant", "ArgumentEncoder"]
50 changes: 49 additions & 1 deletion dargs/dargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from copy import deepcopy
from enum import Enum
import fnmatch, re
import json


INDENT = " " # doc is indented by four spaces
Expand Down Expand Up @@ -632,4 +633,51 @@ def trim_by_pattern(argdict: dict, pattern: str,
f"following reserved names: {', '.join(conflict)}")
unrequired = list(filter(rem.match, argdict.keys()))
for key in unrequired:
argdict.pop(key)
argdict.pop(key)


class ArgumentEncoder(json.JSONEncoder):
"""Extended JSON Encoder to encode Argument object:

Examples
--------
>>> json.dumps(some_arg, cls=ArgumentEncoder)
"""
def default(self, obj) -> Dict[str, Union[str, bool, List]]:
"""Generate a dict containing argument information, making it ready to be encoded
to JSON string.

Note
----
All object in the dict should be JSON serializable.

Returns
-------
dict: Dict
a dict containing argument information
"""
if isinstance(obj, Argument):
return {
"object": "Argument",
"name": obj.name,
"type": obj.dtype,
"optional": obj.optional,
"alias": obj.alias,
"doc": obj.doc,
"repeat": obj.repeat,
"sub_fields": obj.sub_fields,
"sub_variants": obj.sub_variants,
}
elif isinstance(obj, Variant):
return {
"object": "Variant",
"flag_name": obj.flag_name,
"optional": obj.optional,
"default_tag": obj.default_tag,
"choice_dict": obj.choice_dict,
"choice_alias": obj.choice_alias,
"doc": obj.doc,
}
elif isinstance(obj, type):
return obj.__name__
return json.JSONEncoder.default(self, obj)
7 changes: 6 additions & 1 deletion tests/test_docgen.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from context import dargs
import unittest
from dargs import Argument, Variant
import json
from dargs import Argument, Variant, ArgumentEncoder


class TestDocgen(unittest.TestCase):
Expand All @@ -16,6 +17,7 @@ def test_sub_fields(self):
], doc="sub doc." * 5)
], doc="Base doc. " * 10)
docstr = ca.gen_doc()
jsonstr = json.dumps(ca, cls=ArgumentEncoder)
# print("\n\n"+docstr)

def test_sub_repeat(self):
Expand All @@ -29,6 +31,7 @@ def test_sub_repeat(self):
], doc="sub doc." * 5)
], doc="Base doc. " * 10, repeat=True)
docstr = ca.gen_doc()
jsonstr = json.dumps(ca, cls=ArgumentEncoder)
# print("\n\n"+docstr)

def test_sub_variants(self):
Expand Down Expand Up @@ -66,6 +69,7 @@ def test_sub_variants(self):
], optional=True, default_tag="type1", doc="another vnt")
])
docstr = ca.gen_doc(make_anchor=True)
jsonstr = json.dumps(ca, cls=ArgumentEncoder)
# print("\n\n"+docstr)

def test_multi_variants(self):
Expand Down Expand Up @@ -110,6 +114,7 @@ def test_multi_variants(self):
])
])
docstr = ca.gen_doc()
jsonstr = json.dumps(ca, cls=ArgumentEncoder)
# print("\n\n"+docstr)

def test_dpmd(self):
Expand Down