Skip to content
Permalink
Browse files
Improve command suggestions (new algorithm)
Signed-off-by: Chris Warrick <kwpolska@gmail.com>
  • Loading branch information
Kwpolska committed Jul 8, 2015
1 parent c3952b1 commit 2466123d33faa37dbcea2961728f99b090392586
Showing with 34 additions and 9 deletions.
  1. +34 −9 nikola/__main__.py
@@ -317,10 +317,15 @@ def run(self, cmd_args):
if args[0] not in sub_cmds.keys():
LOGGER.error("Unknown command {0}".format(args[0]))
sugg = defaultdict(list)
for c in sub_cmds.keys():
d = lev(c, args[0])
sub_filtered = (i for i in sub_cmds.keys() if i != 'run')
for c in sub_filtered:
d = levenshtein(c, args[0])
sugg[d].append(c)
LOGGER.info('Did you mean "{}"?', '" or "'.join(sugg[min(sugg.keys())]))
best_sugg = sugg[min(sugg.keys())]
if len(best_sugg) == 1:
LOGGER.info('Did you mean "{}"?'.format(best_sugg[0]))
else:
LOGGER.info('Did you mean "{}" or "{}"?'.format('", "'.join(best_sugg[:-1]), best_sugg[-1]))
return 3
if sub_cmds[args[0]] is not Help and not isinstance(sub_cmds[args[0]], Command): # Is a doit command
if not self.nikola.configured:
@@ -334,12 +339,32 @@ def print_version():
print("Nikola v" + __version__)


# Stolen from http://stackoverflow.com/questions/4173579/implementing-levenshtein-distance-in-python
def lev(a, b):
if not a or not b:
return max(len(a), len(b))
return min(lev(a[1:], b[1:]) + (a[0] != b[0]), lev(a[1:], b) + 1, lev(a, b[1:]) + 1)

def levenshtein(s1, s2):
"""Calculate the Levenshtein distance of two strings.
Implementation from Wikibooks:
https://en.wikibooks.org/w/index.php?title=Algorithm_Implementation/Strings/Levenshtein_distance&oldid=2974448#Python
Copyright © The Wikibooks contributors (CC BY-SA/fair use citation); edited to match coding style and add an exception.
"""
if len(s1) < len(s2):
return levenshtein(s2, s1)

# len(s1) >= len(s2)
if len(s2) == 0:
return len(s1)

previous_row = range(len(s2) + 1)
for i, c1 in enumerate(s1):
current_row = [i + 1]
for j, c2 in enumerate(s2):
# j+1 instead of j since previous_row and current_row are one character longer than s2
insertions = previous_row[j + 1] + 1
deletions = current_row[j] + 1
substitutions = previous_row[j] + (c1 != c2)
current_row.append(min(insertions, deletions, substitutions))
previous_row = current_row

return previous_row[-1]

if __name__ == "__main__":
sys.exit(main(sys.argv[1:]))

0 comments on commit 2466123

Please sign in to comment.