Skip to content

Commit

Permalink
feat: did you know for variant chocie (#47)
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
njzjz and pre-commit-ci[bot] committed Mar 27, 2024
1 parent c7f79a3 commit d557ca0
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 2 deletions.
27 changes: 26 additions & 1 deletion dargs/dargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
We also need to pay special attention to flat the keys of its choices.
"""

import difflib
import fnmatch
import json
import re
Expand Down Expand Up @@ -800,7 +801,12 @@ def get_choice(self, argdict: dict, path=None) -> "Argument":
return self.choice_dict[self.choice_alias[tag]]
else:
raise ArgumentValueError(
path, f"get invalid choice `{tag}` for flag key `{self.flag_name}`."
path,
f"get invalid choice `{tag}` for flag key `{self.flag_name}`."
+ did_you_mean(
tag,
list(self.choice_dict.keys()) + list(self.choice_alias.keys()),
),
)
elif self.optional:
return self.choice_dict[self.default_tag]
Expand Down Expand Up @@ -1042,3 +1048,22 @@ def default(self, obj) -> Dict[str, Union[str, bool, List]]:
elif isinstance(obj, type):
return obj.__name__
return json.JSONEncoder.default(self, obj)


def did_you_mean(choice: str, choices: List[str]) -> str:
"""Get did you mean message.
Parameters
----------
choice : str
the user's wrong choice
choices : list[str]
all the choices
Returns
-------
str
did you mean error message
"""
matches = difflib.get_close_matches(choice, choices)
return f"Did you mean: {matches[0]}?" if matches else ""
3 changes: 2 additions & 1 deletion tests/test_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,9 @@ def test_sub_variants(self):
"vnt2_1": 21,
}
}
with self.assertRaises(ArgumentValueError):
with self.assertRaises(ArgumentValueError) as cm:
ca.check(err_dict2)
self.assertIn("Did you mean: type3?", str(cm.exception))
# test optional choice
test_dict1["base"].pop("vnt_flag")
with self.assertRaises(ArgumentKeyError):
Expand Down

0 comments on commit d557ca0

Please sign in to comment.