Skip to content

Commit

Permalink
improved handling of multiple relations between same entity pair during sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
markus-eberts committed Apr 28, 2021
1 parent c1dded7 commit 7b27b7d
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 19 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Jinja2==2.10.3
Jinja2==2.11.3
numpy==1.17.4
tensorboardX==1.6
torch==1.4.0
Expand Down
41 changes: 23 additions & 18 deletions spert/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,26 @@ def create_train_sample(doc, neg_entity_count: int, neg_rel_count: int, max_span
pos_entity_sizes.append(len(e.tokens))

# positive relations
pos_rels, pos_rel_spans, pos_rel_types, pos_rel_masks = [], [], [], []

# collect relations between entity pairs
entity_pair_relations = dict()
for rel in doc.relations:
s1, s2 = rel.head_entity.span, rel.tail_entity.span
pair = (rel.head_entity, rel.tail_entity)
if pair not in entity_pair_relations:
entity_pair_relations[pair] = []
entity_pair_relations[pair].append(rel)

# build positive relation samples
pos_rels, pos_rel_spans, pos_rel_types, pos_rel_masks = [], [], [], []
for pair, rels in entity_pair_relations.items():
head_entity, tail_entity = pair
s1, s2 = head_entity.span, tail_entity.span
pos_rels.append((pos_entity_spans.index(s1), pos_entity_spans.index(s2)))
pos_rel_spans.append((s1, s2))
pos_rel_types.append(rel.relation_type)

pair_rel_types = [r.relation_type.index for r in rels]
pair_rel_types = [int(t in pair_rel_types) for t in range(1, rel_type_count)]
pos_rel_types.append(pair_rel_types)
pos_rel_masks.append(create_rel_mask(s1, s2, context_size))

# negative entities
Expand All @@ -50,30 +64,26 @@ def create_train_sample(doc, neg_entity_count: int, neg_rel_count: int, max_span

for i1, s1 in enumerate(pos_entity_spans):
for i2, s2 in enumerate(pos_entity_spans):
rev = (s2, s1)
rev_symmetric = rev in pos_rel_spans and pos_rel_types[pos_rel_spans.index(rev)].symmetric

# do not add as negative relation sample:
# neg. relations from an entity to itself
# entity pairs that are related according to gt
# entity pairs whose reverse exists as a symmetric relation in gt
if s1 != s2 and (s1, s2) not in pos_rel_spans and not rev_symmetric:
if s1 != s2 and (s1, s2) not in pos_rel_spans:
neg_rel_spans.append((s1, s2))

# sample negative relations
neg_rel_spans = random.sample(neg_rel_spans, min(len(neg_rel_spans), neg_rel_count))

neg_rels = [(pos_entity_spans.index(s1), pos_entity_spans.index(s2)) for s1, s2 in neg_rel_spans]
neg_rel_masks = [create_rel_mask(*spans, context_size) for spans in neg_rel_spans]
neg_rel_types = [0] * len(neg_rel_spans)
neg_rel_types = [(0,) * (rel_type_count-1)] * len(neg_rel_spans)

# merge
entity_types = pos_entity_types + neg_entity_types
entity_masks = pos_entity_masks + neg_entity_masks
entity_sizes = pos_entity_sizes + list(neg_entity_sizes)

rels = pos_rels + neg_rels
rel_types = [r.index for r in pos_rel_types] + neg_rel_types
rel_types = pos_rel_types + neg_rel_types
rel_masks = pos_rel_masks + neg_rel_masks

assert len(entity_masks) == len(entity_sizes) == len(entity_types)
Expand Down Expand Up @@ -105,23 +115,18 @@ def create_train_sample(doc, neg_entity_count: int, neg_rel_count: int, max_span
if rels:
rels = torch.tensor(rels, dtype=torch.long)
rel_masks = torch.stack(rel_masks)
rel_types = torch.tensor(rel_types, dtype=torch.long)
rel_types = torch.tensor(rel_types, dtype=torch.float32)
rel_sample_masks = torch.ones([rels.shape[0]], dtype=torch.bool)
else:
# corner case handling (no pos/neg relations)
rels = torch.zeros([1, 2], dtype=torch.long)
rel_types = torch.zeros([1], dtype=torch.long)
rel_types = torch.zeros([1, rel_type_count-1], dtype=torch.float32)
rel_masks = torch.zeros([1, context_size], dtype=torch.bool)
rel_sample_masks = torch.zeros([1], dtype=torch.bool)

# relation types to one-hot encoding
rel_types_onehot = torch.zeros([rel_types.shape[0], rel_type_count], dtype=torch.float32)
rel_types_onehot.scatter_(1, rel_types.unsqueeze(1), 1)
rel_types_onehot = rel_types_onehot[:, 1:] # all zeros for 'none' relation

return dict(encodings=encodings, context_masks=context_masks, entity_masks=entity_masks,
entity_sizes=entity_sizes, entity_types=entity_types,
rels=rels, rel_masks=rel_masks, rel_types=rel_types_onehot,
rels=rels, rel_masks=rel_masks, rel_types=rel_types,
entity_sample_masks=entity_sample_masks, rel_sample_masks=rel_sample_masks)


Expand Down

0 comments on commit 7b27b7d

Please sign in to comment.