Skip to content

Commit

Permalink
improve incident association on comma delimited data
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanathan committed Nov 16, 2018
1 parent 785575e commit 722b5a7
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 8 deletions.
52 changes: 51 additions & 1 deletion epitator/annotier.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,61 @@ def __iter__(self):
def __getitem__(self, idx):
return self.spans[idx]

def subtract_overlaps(self, other_tier):
    """
    Cut the territory covered by other_tier out of the spans in this tier.

    Each span in this tier is walked from its start to its end; the
    stretches that are not covered by a span from other_tier are emitted
    as new spans. A span that is interrupted in the middle is therefore
    split into several pieces, and a span that is completely covered
    produces no output at all.

    :param other_tier: The spans to be removed from the territory of this tier
    :type other_tier: AnnoTier
    :return: A copy of this tier with spans truncated and split so that
             none of the new spans overlap a span in other_tier
    :rtype: AnnoTier

    >>> from .annospan import AnnoSpan
    >>> from .annodoc import AnnoDoc
    >>> doc = AnnoDoc('one two three four')
    >>> tier_a = AnnoTier([AnnoSpan(0, 18, doc)])
    >>> tier_b = AnnoTier([AnnoSpan(3, 8, doc), AnnoSpan(13, 18, doc)])
    >>> tier_a.subtract_overlaps(tier_b)
    AnnoTier([AnnoSpan(0-3, one), AnnoSpan(8-13, three)])
    """
    def fragment(source, start, end):
        # Every piece keeps the doc, label and metadata of the span it
        # was cut from (the metadata object is passed through, not copied).
        return AnnoSpan(start, end, source.doc, source.label, source.metadata)

    remaining = []
    grouped = self.group_spans_by_containing_span(
        other_tier, allow_partial_containment=True)
    for own_span, blockers in grouped:
        # cursor marks the first offset of own_span not yet known to be covered.
        # NOTE(review): this walk assumes each group of blockers is ordered
        # by start offset - confirm against group_spans_by_containing_span.
        cursor = own_span.start
        for blocker in blockers:
            if blocker.start > cursor:
                # There is an uncovered gap before this blocker: keep it.
                remaining.append(fragment(own_span, cursor, blocker.start))
                cursor = blocker.end
            else:
                # The blocker begins at or before the cursor; only move
                # forward, never backward, in case it ends before the cursor.
                cursor = max(blocker.end, cursor)
            if cursor >= own_span.end:
                # The rest of own_span is covered; nothing more to emit.
                break
        if cursor < own_span.end:
            # Trailing stretch after the last blocker.
            remaining.append(fragment(own_span, cursor, own_span.end))
    return AnnoTier(remaining)

def group_spans_by_containing_span(self,
other_tier,
allow_partial_containment=False):
"""
Group spans in the other tier by the spans that contain them.
Group spans in other_tier by the spans that contain them in this one.
:param other_tier: The spans to be grouped together
:type other_tier: AnnoTier
:param allow_partial_containment: Include spans in groups for spans that partially overlap them.
:return: An iterator that returns pairs of values, the first of which is
the containing span from this tier, the second is an array of
spans from other_tier that the span from this tier contains.
>>> from .annospan import AnnoSpan
>>> from .annodoc import AnnoDoc
Expand Down
27 changes: 21 additions & 6 deletions epitator/incident_annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,15 @@ def format_geoname(geoname):
return result


def get_territories(spans, sent_spans):
def get_territories(spans, sent_spans, phrase_spans):
"""
An annotation's territory is the sentence containing it,
and all the following sentences until the next annotation.
Annotations in the same sentence are grouped.
If sub-sentence phrase spans are provided, only the spans within each phrase
are assigned to its territory, while outside of it the usual rules apply.
This is intended to improve associations when multiple counts for multiple
locations appear in the same sentence.
"""
doc = sent_spans[0].doc
territories = []
Expand Down Expand Up @@ -79,7 +83,14 @@ def get_territories(spans, sent_spans):
territories.append(AnnoSpan(
sent_span.start, sent_span.end, doc,
metadata=last_doc_scope_spans))
return AnnoTier(territories)
phrase_territories = []
for phrase_span, span_group in phrase_spans.group_spans_by_containing_span(spans, allow_partial_containment=True):
if len(span_group) > 0:
phrase_territories.append(AnnoSpan(
phrase_span.start, phrase_span.end, doc,
metadata=span_group))
phrase_territories = AnnoTier(phrase_territories, presorted=True)
return AnnoTier(territories).subtract_overlaps(phrase_territories) + phrase_territories


class IncidentAnnotator(Annotator):
Expand Down Expand Up @@ -123,14 +134,18 @@ def annotate(self, doc, case_counts=None):
continue
if datetime_range[1].date() > publish_date.date():
# Truncate ranges that extend into the future
datetime_range[1] = datetime.datetime(publish_date.year, publish_date.month, publish_date.day)
datetime_range[1] = datetime.datetime(publish_date.year, publish_date.month, publish_date.day + 1)
dates_out.append(AnnoSpan(span.start, span.end, span.doc, metadata={
'datetime_range': datetime_range
}))
date_tier = AnnoTier(dates_out, presorted=True)
date_territories = get_territories(date_tier, sent_spans)
geoname_territories = get_territories(geonames, sent_spans)
disease_territories = get_territories(disease_tier, sent_spans)
phrase_spans = []
for sent_span, comma_group in sent_spans.group_spans_by_containing_span(doc.create_regex_tier(",")):
phrase_spans += AnnoTier([sent_span]).subtract_overlaps(comma_group).spans
phrase_spans = AnnoTier(phrase_spans)
date_territories = get_territories(date_tier, sent_spans, phrase_spans)
geoname_territories = get_territories(geonames, sent_spans, phrase_spans)
disease_territories = get_territories(disease_tier, sent_spans, phrase_spans)
# Only include the sentence the word appears in for species territories since
# the species is implicitly human in most of the articles we're analyzing.
species_territories = []
Expand Down
2 changes: 1 addition & 1 deletion epitator/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.2.3'
__version__ = '1.2.4'
26 changes: 26 additions & 0 deletions tests/annotator/test_incident_annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,14 @@
import datetime


def find(li, func):
    """
    Return the first item of ``li`` for which ``func(item)`` is truthy.

    :param li: The items to search, in order. Any iterable is accepted.
    :param func: A one-argument predicate applied to each item in turn.
    :return: The first matching item, or None when ``li`` is empty or no
             item matches.
    """
    # A single iterative pass: recursing on li[1:] would copy the remainder
    # of the sequence at every step and raise RecursionError when the first
    # match lies deeper than the interpreter's recursion limit.
    return next((item for item in li if func(item)), None)


class TestIncidentAnnotator(unittest.TestCase):

def setUp(self):
Expand Down Expand Up @@ -177,3 +185,21 @@ def test_sentence_segmentation(self):
doc.add_tier(self.annotator)
self.assertEqual(len(doc.tiers['incidents'].spans[0].metadata['locations']), 2)
self.assertEqual(len(doc.tiers['incidents'].spans[-1].metadata['locations']), 1)

def test_location_grouping(self):
    """
    Check how many locations are attached to each incident when several
    counts and locations are listed in one comma delimited sentence.

    The incidents with values 6, 1 and 3 are expected to carry 3, 1 and 3
    locations respectively.
    """
    doc = AnnoDoc("""
In an update yesterday, officials reported 6 more cases,
which include 1 case in Paris, 1 case in London, and 4 cases in Australia,
which is in a security "red zone" area.
Today, the health ministry reported 3 more Ebola deaths.
""")
    doc.add_tier(self.annotator)
    incidents = doc.tiers['incidents']

    def location_count(value):
        # Number of locations on the first incident whose count equals value.
        incident = find(incidents, lambda span: span.metadata['value'] == value)
        return len(incident.metadata['locations'])

    self.assertEqual(location_count(6), 3)
    self.assertEqual(location_count(1), 1)
    self.assertEqual(location_count(3), 3)

0 comments on commit 722b5a7

Please sign in to comment.