Skip to content

Commit d934085

Browse files
feat: allow repeat with dtype=dict (#67)
This will be helpful to normalize `model_dict`, `loss_dict`, etc, in DeePMD-kit. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Enhanced handling of data types based on the `repeat` flag, supporting both lists of dicts and dicts of dicts. - Improved documentation generation reflecting updated data type handling. - **Bug Fixes** - Corrected traversal logic for values to ensure proper handling of dictionaries and lists based on the `repeat` flag. - **Tests** - Updated and added new test cases to cover changes in data type handling and traversal logic. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu> Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
1 parent a02640c commit d934085

File tree

5 files changed

+127
-12
lines changed

5 files changed

+127
-12
lines changed

dargs/dargs.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ class Argument:
109109
If given, `dtype` is assumed to be dict, and its items are determined
110110
by the `Variant`s in the given list and the value of their flag keys.
111111
repeat: bool, optional
112-
If true, `dtype` is assume to be list of dict and each dict consists
112+
If true, `dtype` is assume to be list of dict or dict of dict, and each dict consists
113113
of sub fields and sub variants described above. Defaults to false.
114114
optional: bool, optional
115115
If true, consider the current argument to be optional in checking.
@@ -235,7 +235,17 @@ def _reorg_dtype(
235235
}
236236
# check conner cases
237237
if self.sub_fields or self.sub_variants:
238-
dtype.add(list if self.repeat else dict)
238+
if not self.repeat:
239+
dtype.add(dict)
240+
else:
241+
# convert dtypes to unsubscripted types
242+
unsubscripted_dtype = {
243+
get_origin(dt) if get_origin(dt) is not None else dt for dt in dtype
244+
}
245+
if dict not in unsubscripted_dtype:
246+
# only add list (compatible with old behaviors) if no dict in dtype
247+
dtype.add(list)
248+
239249
if (
240250
self.optional
241251
and self.default is not _Flags.NONE
@@ -347,11 +357,11 @@ def traverse_value(
347357
# in the condition where there is no leading key
348358
if path is None:
349359
path = []
350-
if isinstance(value, dict):
360+
if not self.repeat and isinstance(value, dict):
351361
self._traverse_sub(
352362
value, key_hook, value_hook, sub_hook, variant_hook, path
353363
)
354-
if isinstance(value, list) and self.repeat:
364+
elif self.repeat and isinstance(value, list):
355365
for idx, item in enumerate(value):
356366
self._traverse_sub(
357367
item,
@@ -361,6 +371,16 @@ def traverse_value(
361371
variant_hook,
362372
[*path, str(idx)],
363373
)
374+
elif self.repeat and isinstance(value, dict):
375+
for kk, item in value.items():
376+
self._traverse_sub(
377+
item,
378+
key_hook,
379+
value_hook,
380+
sub_hook,
381+
variant_hook,
382+
[*path, kk],
383+
)
364384

365385
def _traverse_sub(
366386
self,
@@ -653,9 +673,22 @@ def gen_doc_body(self, path: list[str] | None = None, **kwargs) -> str:
653673
body_list.append(self.doc + "\n")
654674
if not self.fold_subdoc:
655675
if self.repeat:
676+
unsubscripted_dtype = {
677+
get_origin(dt) if get_origin(dt) is not None else dt
678+
for dt in self.dtype
679+
}
680+
allowed_types = []
681+
allowed_element = []
682+
if list in unsubscripted_dtype or dict in unsubscripted_dtype:
683+
if list in unsubscripted_dtype:
684+
allowed_types.append("list")
685+
allowed_element.append("element")
686+
if dict in unsubscripted_dtype:
687+
allowed_types.append("dict")
688+
allowed_element.append("key-value pair")
656689
body_list.append(
657-
"This argument takes a list with "
658-
"each element containing the following: \n"
690+
f"This argument takes a {' or '.join(allowed_types)} with "
691+
f"each {' or '.join(allowed_element)} containing the following: \n"
659692
)
660693
if self.sub_fields:
661694
# body_list.append("") # genetate a blank line

dargs/sphinx.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def _test_argument() -> Argument:
190190
doc=doc_test,
191191
sub_fields=[
192192
Argument(
193-
"test_repeat",
193+
"test_repeat_list",
194194
dtype=list,
195195
repeat=True,
196196
doc=doc_test,
@@ -199,7 +199,18 @@ def _test_argument() -> Argument:
199199
"test_repeat_item", dtype=bool, doc=doc_test
200200
),
201201
],
202-
)
202+
),
203+
Argument(
204+
"test_repeat_dict",
205+
dtype=dict,
206+
repeat=True,
207+
doc=doc_test,
208+
sub_fields=[
209+
Argument(
210+
"test_repeat_item", dtype=bool, doc=doc_test
211+
),
212+
],
213+
),
203214
],
204215
),
205216
],

tests/test_checker.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,9 @@ def test_sub_fields(self):
100100
with self.assertRaises(ValueError):
101101
Argument("base", dict, [Argument("sub1", int), Argument("sub1", int)])
102102

103-
def test_sub_repeat(self):
103+
def test_sub_repeat_list(self):
104104
ca = Argument(
105-
"base", dict, [Argument("sub1", int), Argument("sub2", str)], repeat=True
105+
"base", list, [Argument("sub1", int), Argument("sub2", str)], repeat=True
106106
)
107107
test_dict1 = {
108108
"base": [{"sub1": 10, "sub2": "hello"}, {"sub1": 11, "sub2": "world"}]
@@ -124,6 +124,39 @@ def test_sub_repeat(self):
124124
with self.assertRaises(ArgumentTypeError):
125125
ca.check(err_dict2)
126126

127+
def test_sub_repeat_dict(self):
128+
ca = Argument(
129+
"base", dict, [Argument("sub1", int), Argument("sub2", str)], repeat=True
130+
)
131+
test_dict1 = {
132+
"base": {
133+
"item1": {"sub1": 10, "sub2": "hello"},
134+
"item2": {"sub1": 11, "sub2": "world"},
135+
}
136+
}
137+
ca.check(test_dict1)
138+
ca.check_value(test_dict1["base"])
139+
err_dict1 = {
140+
"base": {
141+
"item1": {"sub1": 10, "sub2": "hello"},
142+
"item2": {"sub1": 11, "sub3": "world"},
143+
}
144+
}
145+
with self.assertRaises(ArgumentKeyError):
146+
ca.check(err_dict1)
147+
err_dict1["base"]["item2"]["sub2"] = "world too"
148+
ca.check(err_dict1) # now should pass
149+
with self.assertRaises(ArgumentKeyError):
150+
ca.check(err_dict1, strict=True) # but should fail when strict
151+
err_dict2 = {
152+
"base": {
153+
"item1": {"sub1": 10, "sub2": "hello"},
154+
"item2": {"sub1": 11, "sub2": None},
155+
}
156+
}
157+
with self.assertRaises(ArgumentTypeError):
158+
ca.check(err_dict2)
159+
127160
def test_sub_variants(self):
128161
ca = Argument(
129162
"base",

tests/test_docgen.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def test_sub_fields(self):
4141
jsonstr = json.dumps(ca, cls=ArgumentEncoder)
4242
# print("\n\n"+docstr)
4343

44-
def test_sub_repeat(self):
44+
def test_sub_repeat_list(self):
4545
ca = Argument(
4646
"base",
4747
list,
@@ -70,6 +70,34 @@ def test_sub_repeat(self):
7070
jsonstr = json.dumps(ca, cls=ArgumentEncoder)
7171
# print("\n\n"+docstr)
7272

73+
def test_sub_repeat_dict(self):
74+
ca = Argument(
75+
"base",
76+
dict,
77+
[
78+
Argument("sub1", int, doc="sub doc." * 5),
79+
Argument(
80+
"sub2",
81+
[None, str, dict],
82+
[
83+
Argument("subsub1", int, doc="subsub doc." * 5, optional=True),
84+
Argument(
85+
"subsub2",
86+
dict,
87+
[Argument("subsubsub1", int, doc="subsubsub doc." * 5)],
88+
doc="subsub doc." * 5,
89+
repeat=True,
90+
),
91+
],
92+
doc="sub doc." * 5,
93+
),
94+
],
95+
doc="Base doc. " * 10,
96+
repeat=True,
97+
)
98+
docstr = ca.gen_doc()
99+
jsonstr = json.dumps(ca, cls=ArgumentEncoder)
100+
73101
def test_sub_variants(self):
74102
ca = Argument(
75103
"base",

tests/test_normalizer.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,13 @@ def test_complicated(self):
9393
repeat=True,
9494
alias=["sub2a"],
9595
),
96+
Argument(
97+
"sub2_dict",
98+
dict,
99+
[Argument("ss1", int, optional=True, default=21, alias=["ss1a"])],
100+
repeat=True,
101+
alias=["sub2a_dict"],
102+
),
96103
],
97104
[
98105
Variant(
@@ -145,11 +152,12 @@ def test_complicated(self):
145152
)
146153
],
147154
)
148-
beg1 = {"base": {"sub2": [{}, {}]}}
155+
beg1 = {"base": {"sub2": [{}, {}], "sub2_dict": {"item1": {}, "item2": {}}}}
149156
ref1 = {
150157
"base": {
151158
"sub1": 1,
152159
"sub2": [{"ss1": 21}, {"ss1": 21}],
160+
"sub2_dict": {"item1": {"ss1": 21}, "item2": {"ss1": 21}},
153161
"vnt_flag": "type1",
154162
"shared": -1,
155163
"vnt1": 111,
@@ -161,6 +169,7 @@ def test_complicated(self):
161169
"base": {
162170
"sub1a": 2,
163171
"sub2a": [{"ss1a": 22}, {"_comment1": None}],
172+
"sub2a_dict": {"item1": {"ss1a": 22}, "item2": {"_comment1": None}},
164173
"vnt_flag": "type3",
165174
"sharedb": -3,
166175
"vnt2a": 223,
@@ -172,6 +181,7 @@ def test_complicated(self):
172181
"base": {
173182
"sub1": 2,
174183
"sub2": [{"ss1": 22}, {"ss1": 21}],
184+
"sub2_dict": {"item1": {"ss1": 22}, "item2": {"ss1": 21}},
175185
"vnt_flag": "type2",
176186
"shared": -3,
177187
"vnt2": 223,

0 commit comments

Comments
 (0)