Skip to content

Commit

Permalink
refactor: add found synonyms to output
Browse files Browse the repository at this point in the history
  • Loading branch information
iwpnd committed Sep 13, 2021
1 parent e57150f commit fb06232
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 37 deletions.
52 changes: 27 additions & 25 deletions flashgeotext/geotext.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,26 +38,30 @@ class GeoText(LookupDataPool):
to cut tariffs on $75 billion worth of goods that the country
imports from the US. Washington welcomes the decision.'''
geotext.extract(input_text=input_text, span_info=True)
geotext.extract(input_text=input_text)
>> {
'cities': {
'Shanghai': {
'count': 2,
'span_info': [(0, 8), (45, 53)],
'found_as': ['Shanghai', 'Shanghai']
},
'Washington, D.C.': {
'count': 1,
'span_info': [(175, 185)],
'found_as': ['Washington']
}
},
'countries': {
'China': {
'count': 1,
'span_info': [(64, 69)],
'found_as': ['China']
},
'United States': {
'count': 1,
'span_info': [(171, 173)],
'found_as': ['US']
}
}
}
Expand Down Expand Up @@ -100,45 +104,43 @@ def extract(self, input_text: str, span_info: bool = True) -> dict:
extract = self.pool[lookup].extract_keywords(
input_text, span_info=span_info
)
output[lookup] = self._parse_extract(extract, span_info=span_info)
output[lookup] = self._parse_extract(extract, input_text)

return output

def _parse_extract(self, extract_data: list, span_info: bool = True) -> dict:
def _parse_extract(self, extract_data: list, input_text: str) -> dict:
"""Parse flashtext.KeywordProcessor.extract_keywords() output to count occurances
Parse flashtext.KeywordProcessor.extract_keywords() output to count occurrences,
and optionally span_info.
Args:
extract_data (list): flashtext.KeywordProcessor.extract_keywords() return value
span_info (bool): optionally, parse span_info
input_text (str): input text
Returns:
parsed_extract (dict): parsed extract_data to include count, optionally span_info
"""
parsed_extract: dict = {}

if span_info:
for entry in extract_data:
if entry[0] not in parsed_extract:
parsed_extract[entry[0]] = {
"count": 1,
"span_info": [(entry[1], entry[2])],
}
else:
parsed_extract[entry[0]]["count"] = (
parsed_extract[entry[0]]["count"] + 1
)
parsed_extract[entry[0]]["span_info"] = parsed_extract[entry[0]][
"span_info"
] + [(entry[1], entry[2])]

else:
for entry in extract_data:
if entry not in parsed_extract:
parsed_extract[entry] = {"count": 1}
else:
parsed_extract[entry]["count"] = parsed_extract[entry]["count"] + 1
for entry in extract_data:
keyword = entry[0]
span_start = entry[1]
span_end = entry[2]

if keyword not in parsed_extract:
parsed_extract[keyword] = {
"count": 1,
"span_info": [(span_start, span_end)],
"found_as": [input_text[span_start:span_end]],
}
else:
parsed_extract[keyword]["count"] = parsed_extract[keyword]["count"] + 1
parsed_extract[keyword]["span_info"] = parsed_extract[keyword][
"span_info"
] + [(span_start, span_end)]
parsed_extract[keyword]["found_as"] = parsed_extract[keyword][
"found_as"
] + [input_text[span_start:span_end]]

return parsed_extract
19 changes: 7 additions & 12 deletions tests/integration/test_geotext_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,25 +27,20 @@ def test_geotext_raises_on_empty_pool():


def test_geotext_extract_with_count_span_info_true(geotext):
output = geotext.extract(input_text=text, span_info=True)
output = geotext.extract(input_text=text)
assert output["cities"]["Berlin"]["count"] == 2
assert output["cities"]["Berlin"]["span_info"] == [(0, 6), (43, 49)]


def test_geotext_extract_with_count_span_info_false(geotext):
output = geotext.extract(input_text=text, span_info=False)

with pytest.raises(KeyError):
assert output["cities"]["Berlin"]["span_info"] == [(0, 6), (43, 49)]
assert output["cities"]["Berlin"]["found_as"] == ["Berlin", "Berlin"]


def test_geotext_case_sensitive_demo_data():
config = GeoTextConfiguration(**{"use_demo_data": True, "case_sensitive": False})
geotext = GeoText(config)
text = "berlin ist ne tolle stadt"
output = geotext.extract(input_text=text, span_info=True)
output = geotext.extract(input_text=text)

assert output["cities"]["Berlin"]["span_info"] == [(0, 6)]
assert output["cities"]["Berlin"]["found_as"] == ["berlin"]


# tests used in geotext (https://github.com/elyase/geotext)
Expand Down Expand Up @@ -124,7 +119,7 @@ def test_geotext_case_sensitive_demo_data():
],
)
def test_geotext_extract_cities(nr, text, expected_cities, geotext):
output = geotext.extract(input_text=text, span_info=False)
output = geotext.extract(input_text=text)

assert all([city in output["cities"] for city in expected_cities])

Expand Down Expand Up @@ -161,7 +156,7 @@ def test_geotext_extract_cities(nr, text, expected_cities, geotext):
],
)
def test_geotext_extract_countries(nr, text, expected_countries, geotext):
output = geotext.extract(input_text=text, span_info=False)
output = geotext.extract(input_text=text)

assert all([country in output["countries"] for country in expected_countries])

Expand All @@ -186,5 +181,5 @@ def test_geotext_with_script_added_to_non_word_boundaries():
что традиционно в середине апреля закрываются для движения автотранспорта все ледовые переправы.
"""

result = geotext.extract(text, span_info=False)
result = geotext.extract(text)
result["test_1"]["Нижневартовск"]["count"] == 1

0 comments on commit fb06232

Please sign in to comment.