Skip to content

Commit

Permalink
Listify names and remove sort argument
Browse files Browse the repository at this point in the history
  • Loading branch information
bgyori committed Mar 17, 2020
1 parent 3aab21e commit 677620f
Showing 1 changed file with 15 additions and 9 deletions.
24 changes: 15 additions & 9 deletions gilda/grounder.py
Expand Up @@ -87,7 +87,7 @@ def _generate_lookups(self, raw_str):
', '.join(lookups))
return lookups

def ground(self, raw_str, context=None, sort=False):
def ground(self, raw_str, context=None):
"""Return scored groundings for a given raw string.
Parameters
Expand All @@ -99,13 +99,12 @@ def ground(self, raw_str, context=None, sort=False):
Any additional text that serves as context for disambiguating the
given entity text, used if a model exists for disambiguating the
given text.
sort: bool
Should the matches be sorted by score before being returned?
Returns
-------
list[gilda.grounder.ScoredMatch]
A list of ScoredMatch objects representing the groundings.
A list of ScoredMatch objects representing the groundings sorted
by decreasing score.
"""
entries = self.lookup(raw_str)
logger.info('Comparing %s with %d entries' %
Expand All @@ -121,15 +120,22 @@ def ground(self, raw_str, context=None, sort=False):
sc = score(match, term)
scored_match = ScoredMatch(term, sc, match)
scored_matches.append(scored_match)

# Return early if we don't have anything to avoid calling other
# functions with no matches
if not scored_matches:
return scored_matches

# Merge equivalent matches
unique_scores = self._merge_equivalent_matches(scored_matches)

# If there's context available, disambiguate based on that
if context:
unique_scores = self.disambiguate(raw_str, unique_scores, context)

if sort:
unique_scores = sorted(unique_scores,
key=lambda x: x.score,
reverse=True)
# Then sort by decreasing score
unique_scores = sorted(unique_scores, key=lambda x: x.score,
reverse=True)

return unique_scores

Expand Down Expand Up @@ -258,7 +264,7 @@ def get_names(self, db, id, status=None, source=None):
(not status or entry.status == status) and \
(not source or entry.source == source):
names.add(entry.text)
return sorted(names)
return sorted(list(names))


class ScoredMatch(object):
Expand Down

0 comments on commit 677620f

Please sign in to comment.