-
Notifications
You must be signed in to change notification settings - Fork 0
/
annotations.py
264 lines (220 loc) · 11.7 KB
/
annotations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
import logging
import os
import re
from collections import Counter, namedtuple
from math import ceil
EntTuple = namedtuple('EntTuple', ['tag', 'start', 'end', 'text'])


class Annotations:
    """
    Stores the entity annotations of a single document and offers set-like comparisons between them.

    The object wraps a list of :code:`EntTuple(tag, start, end, text)` tuples that is kept sorted by
    :code:`(start, end)`.

    :ivar ann_path: the path to the .ann file; None when constructed from a list of tuples, and the
        string 'None' for objects produced by the ``|`` / ``|=`` operators
    :ivar source_text_path: path to the related .txt file
    :ivar annotations: a list of annotation tuples, sorted by (start, end)
    """

    # One text-bound BRAT line: "T<id>\t<tag> <start> <end>[;<start> <end>...]\t<mention>"
    brat_pattern = re.compile(r'T(\d+)\t(\S+) ((\d+ \d+;)*\d+ \d+)\t(.+)')

    def __init__(self, annotation_data, source_text_path=None):
        """
        :param annotation_data: a file path to an annotation file, or a list of annotation tuples.
            Construction from a list of tuples is intended for internal use.
        :param source_text_path: optional; path to the text file from which the annotations were derived.
        :raises FileNotFoundError: if annotation_data is neither a list of tuples nor an existing file path
        """
        self.source_text_path = source_text_path

        if isinstance(annotation_data, list) and all(isinstance(e, tuple) for e in annotation_data):
            # No backing file exists; define the attribute anyway so that reading it never raises.
            self.ann_path = None
            self.annotations = annotation_data
            return

        # Check the type first: os.path.isfile raises TypeError on arbitrary objects (and treats ints
        # as file descriptors), which would hide the explanatory error below.
        is_path_like = isinstance(annotation_data, (str, bytes, os.PathLike))
        if not is_path_like or not os.path.isfile(annotation_data):
            raise FileNotFoundError(
                "annotation_data must be a list of tuples or a valid file path, but is %s" % repr(annotation_data)
            )

        self.ann_path = annotation_data
        self.annotations = self._init_from_file(annotation_data)

    @staticmethod
    def _init_from_file(file_path):
        """
        Creates a list of annotation tuples from a file path.

        Only lines matching :code:`brat_pattern` are read. For discontinuous spans the first start offset
        and the last end offset are used.

        :param file_path: the path to an ann file
        :return: a list of annotation tuples
        """
        with open(file_path, 'r', encoding='utf-8') as f:
            data = f.read()

        annotations = []
        for match in Annotations.brat_pattern.finditer(data):
            offsets = re.findall(r'\d+', match.group(3))
            annotations.append(EntTuple(match.group(2), int(offsets[0]), int(offsets[-1]), match.group(5)))
        return annotations

    @property
    def annotations(self):
        return self._annotations

    @annotations.setter
    def annotations(self, value):
        """Ensures that annotations are always EntTuples sorted by (start, end)"""
        self._annotations = sorted([EntTuple(*e) for e in value], key=lambda x: (x.start, x.end))

    def get_labels(self, as_list=False):
        """
        Get the set of labels from this collection of annotations.

        :param as_list: bool for if to return the results as a list; defaults to False
        :return: The set of labels.
        """
        labels = {e.tag for e in self.annotations}
        if as_list:
            return list(labels)
        return labels

    def add_entity(self, label, start, end, text=""):
        """
        Adds an entity to the Annotations, keeping the collection sorted.

        :param label: the label of the annotation you are appending
        :param start: the start index in the document of the annotation you are appending.
        :param end: the end index in the document of the annotation you are appending
        :param text: the raw text of the annotation you are appending
        """
        # Assign through the property setter so the list is re-sorted; appending to the backing list
        # directly would break the sorted invariant.
        self.annotations = self._annotations + [EntTuple(label, start, end, text)]

    def to_ann(self, write_location=None):
        """
        Formats the Annotations object into a string representing a valid ANN file. Optionally writes the
        formatted string to a destination.

        :param write_location: path of location to write ann file to
        :return: returns string formatted as an ann file, if write_location is given also writes to that path.
        """
        lines = []
        for num, tup in enumerate(self.annotations, 1):
            # One annotation per line, so newlines inside a mention are flattened to spaces.
            mention = tup.text.replace('\n', ' ')
            lines.append(f"T{num}\t{tup.tag} {tup.start} {tup.end}\t{mention}\n")
        ann_string = "".join(lines)

        if write_location is not None:
            if os.path.isfile(write_location):
                logging.warning("Overwriting file at: %s", write_location)
            # Write UTF-8 explicitly to match the encoding _init_from_file reads with.
            with open(write_location, 'w', encoding='utf-8') as f:
                f.write(ann_string)

        return ann_string

    def _lenient_matches(self, other, leniency):
        """
        Finds the annotations of this object that have a same-label counterpart in other under leniency.

        For each annotation the span is widened by :code:`ceil(leniency * (end - start))` characters on
        both sides; it matches if some annotation in other has the same label and lies entirely within
        the widened span.

        :param other: Another Annotations object.
        :param leniency: a value in [0,1]
        :return: the set of matching annotations from this object
        :raises ValueError: if leniency is outside [0,1]
        """
        if not 0 <= leniency <= 1:
            raise ValueError("Leniency must be a floating point between [0,1]")

        matches = set()
        for ann in self.annotations:
            window = ceil(leniency * (ann.end - ann.start))
            low, high = ann.start - window, ann.end + window
            if any(c.tag == ann.tag and low <= c.start and c.end <= high for c in other.annotations):
                matches.add(ann)
        return matches

    def difference(self, other, leniency=0):
        """
        Identifies the difference between two Annotations objects. Useful for checking if an unverified
        annotation matches an annotation known to be accurate. Returns all annotations in the operated on
        Annotations object that do not exist in the passed in Annotations object. This is a set difference.

        :param other: Another Annotations object.
        :param leniency: a floating point value between [0,1] defining the leniency of the character spans to
            count as different. A value of zero considers only exact tuple matches, while a positive value
            treats an annotation as present in other if other has an annotation with the same label lying
            within this annotation's span widened by :code:`ceil(leniency * (end - start))` on either side.
        :return: A set of tuples of non-matching annotations.
        :raises ValueError: if other is not an Annotations object or leniency is outside [0,1]
        """
        if not isinstance(other, Annotations):
            raise ValueError("Annotations.difference() can only accept another Annotations object as an argument.")
        if leniency == 0:
            return set(self.annotations) - set(other.annotations)
        return set(self.annotations) - self._lenient_matches(other, leniency)

    def intersection(self, other, leniency=0):
        """
        Computes the intersection of the operated annotation object with the operand annotation object.

        :param other: Another Annotations object.
        :param leniency: a floating point value between [0,1] defining the leniency of the character spans to
            count as different. A value of zero considers only exact tuple matches, while a positive value
            treats an annotation as present in other if other has an annotation with the same label lying
            within this annotation's span widened by :code:`ceil(leniency * (end - start))` on either side.
        :return: A set of annotations from this object that appear in both Annotations objects
        :raises ValueError: if other is not an Annotations object or leniency is outside [0,1]
        """
        if not isinstance(other, Annotations):
            raise ValueError("An Annotations object is required as an argument.")
        if leniency == 0:
            return set(self.annotations) & set(other.annotations)
        return self._lenient_matches(other, leniency)

    def compute_ambiguity(self, other):
        """
        Finds spans in other that overlap a span of this object but carry a different label.
        If other comprises a model's predictions, this indicates where the model fails to disambiguate
        between entities. For a full analysis, compute a confusion matrix.

        :param other: Another Annotations object.
        :return: a dictionary mapping each annotation tuple of this object to the list of all
            differently-labelled annotation tuples in other that overlap it
        :raises ValueError: if other is not an Annotations object
        """
        if not isinstance(other, Annotations):
            raise ValueError("An Annotations object is required as an argument.")

        ambiguity_dict = {}
        for label, start, end, text in self.annotations:
            for c_label, c_start, c_end, c_text in other.annotations:
                if label == c_label:
                    continue
                overlap = max(0, min(end, c_end) - max(c_start, start))
                if overlap != 0:
                    # Accumulate every conflict; plain assignment would keep only the last one.
                    ambiguity_dict.setdefault((label, start, end, text), []).append(
                        (c_label, c_start, c_end, c_text)
                    )
        return ambiguity_dict

    def compute_confusion_matrix(self, other, entities, leniency=0):
        """
        Computes a confusion matrix representing span level ambiguity between this annotation and the
        argument annotation. An annotation in other is ambiguous if it overlaps with a span in this
        Annotations object but does not have the same entity label (see compute_ambiguity; leniency does
        not apply to this part). The main diagonal counts the annotations of this object returned by
        :code:`intersection(other, leniency)`.

        Annotations whose label is not listed in entities are ignored.

        :param other: Another Annotations object.
        :param entities: a list of entities to use in computing matrix ambiguity.
        :param leniency: leniency to utilize when computing matching entities. This is the same definition
            of leniency as in intersection.
        :return: a square matrix with dimension len(entities) where matrix[i][j] is the number of times
            entities[i] in this annotation was labelled entities[j] in other.
        :raises ValueError: if other is not an Annotations object or entities is not a list
        """
        if not isinstance(other, Annotations):
            raise ValueError("An Annotations object is required as an argument.")
        if not isinstance(entities, list):
            raise ValueError("entities must be a list of entities, but is %s" % repr(entities))

        entity_encoding = {entity: i for i, entity in enumerate(entities)}
        # Create 2-d array of len(entities) ** 2
        confusion_matrix = [[0 for _ in range(len(entities))] for _ in range(len(entities))]

        ambiguity_dict = self.compute_ambiguity(other)
        intersection = self.intersection(other, leniency=leniency)

        # Compute all off diagonal scores
        for gold_annotation, conflicts in ambiguity_dict.items():
            row = entity_encoding.get(gold_annotation[0])
            if row is None:
                continue
            for ambiguous_annotation in conflicts:
                col = entity_encoding.get(ambiguous_annotation[0])
                if col is not None:
                    confusion_matrix[row][col] += 1

        # Compute diagonal scores (correctly predicted entities with correct spans)
        for matching_annotation in intersection:
            idx = entity_encoding.get(matching_annotation[0])
            if idx is not None:
                confusion_matrix[idx][idx] += 1

        return confusion_matrix

    def compute_counts(self):
        """
        Computes counts of each entity type in this annotation.

        :return: a Counter of the entity counts
        """
        return Counter(e.tag for e in self.annotations)

    def __str__(self):
        return str(self.annotations)

    def __len__(self):
        return len(self.annotations)

    def __iter__(self):
        return iter(self.annotations)

    def __or__(self, other):
        """
        Creates an Annotations object containing annotations from two instances

        :param other: the other Annotations instance
        :return: a new Annotations object containing entities from both
        """
        if not isinstance(other, Annotations):
            return NotImplemented
        new_entities = list(set(self.annotations) | set(other.annotations))
        new_annotations = Annotations(new_entities, source_text_path=self.source_text_path or other.source_text_path)
        # NOTE: the literal string 'None' (not the None singleton) is kept for backward compatibility.
        new_annotations.ann_path = 'None'
        return new_annotations

    def __ior__(self, other):
        """
        Merges the annotations of another instance into this one in place

        :param other: the other Annotations instance
        :return: this Annotations object
        """
        if not isinstance(other, Annotations):
            return NotImplemented
        self.annotations = list(set(self.annotations) | set(other.annotations))
        # NOTE: the literal string 'None' (not the None singleton) is kept for backward compatibility.
        self.ann_path = 'None'
        return self