diff --git a/requirements.txt b/requirements.txt index f16aacc2..8e407161 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -Jinja2==2.10.3 +Jinja2==2.11.3 numpy==1.17.4 tensorboardX==1.6 torch==1.4.0 diff --git a/spert/sampling.py b/spert/sampling.py index 26386b2c..0256c6a0 100644 --- a/spert/sampling.py +++ b/spert/sampling.py @@ -19,12 +19,26 @@ def create_train_sample(doc, neg_entity_count: int, neg_rel_count: int, max_span pos_entity_sizes.append(len(e.tokens)) # positive relations - pos_rels, pos_rel_spans, pos_rel_types, pos_rel_masks = [], [], [], [] + + # collect relations between entity pairs + entity_pair_relations = dict() for rel in doc.relations: - s1, s2 = rel.head_entity.span, rel.tail_entity.span + pair = (rel.head_entity, rel.tail_entity) + if pair not in entity_pair_relations: + entity_pair_relations[pair] = [] + entity_pair_relations[pair].append(rel) + + # build positive relation samples + pos_rels, pos_rel_spans, pos_rel_types, pos_rel_masks = [], [], [], [] + for pair, rels in entity_pair_relations.items(): + head_entity, tail_entity = pair + s1, s2 = head_entity.span, tail_entity.span pos_rels.append((pos_entity_spans.index(s1), pos_entity_spans.index(s2))) pos_rel_spans.append((s1, s2)) - pos_rel_types.append(rel.relation_type) + + pair_rel_types = [r.relation_type.index for r in rels] + pair_rel_types = [int(t in pair_rel_types) for t in range(1, rel_type_count)] + pos_rel_types.append(pair_rel_types) pos_rel_masks.append(create_rel_mask(s1, s2, context_size)) # negative entities @@ -50,14 +64,10 @@ def create_train_sample(doc, neg_entity_count: int, neg_rel_count: int, max_span for i1, s1 in enumerate(pos_entity_spans): for i2, s2 in enumerate(pos_entity_spans): - rev = (s2, s1) - rev_symmetric = rev in pos_rel_spans and pos_rel_types[pos_rel_spans.index(rev)].symmetric - # do not add as negative relation sample: # neg. relations from an entity to itself # entity pairs that are related according to gt - # entity pairs whose reverse exists as a symmetric relation in gt - if s1 != s2 and (s1, s2) not in pos_rel_spans and not rev_symmetric: + if s1 != s2 and (s1, s2) not in pos_rel_spans: neg_rel_spans.append((s1, s2)) # sample negative relations @@ -65,7 +75,7 @@ def create_train_sample(doc, neg_entity_count: int, neg_rel_count: int, max_span neg_rels = [(pos_entity_spans.index(s1), pos_entity_spans.index(s2)) for s1, s2 in neg_rel_spans] neg_rel_masks = [create_rel_mask(*spans, context_size) for spans in neg_rel_spans] - neg_rel_types = [0] * len(neg_rel_spans) + neg_rel_types = [(0,) * (rel_type_count-1)] * len(neg_rel_spans) # merge entity_types = pos_entity_types + neg_entity_types @@ -73,7 +83,7 @@ def create_train_sample(doc, neg_entity_count: int, neg_rel_count: int, max_span entity_sizes = pos_entity_sizes + list(neg_entity_sizes) rels = pos_rels + neg_rels - rel_types = [r.index for r in pos_rel_types] + neg_rel_types + rel_types = pos_rel_types + neg_rel_types rel_masks = pos_rel_masks + neg_rel_masks assert len(entity_masks) == len(entity_sizes) == len(entity_types) @@ -105,23 +115,18 @@ def create_train_sample(doc, neg_entity_count: int, neg_rel_count: int, max_span if rels: rels = torch.tensor(rels, dtype=torch.long) rel_masks = torch.stack(rel_masks) - rel_types = torch.tensor(rel_types, dtype=torch.long) + rel_types = torch.tensor(rel_types, dtype=torch.float32) rel_sample_masks = torch.ones([rels.shape[0]], dtype=torch.bool) else: # corner case handling (no pos/neg relations) rels = torch.zeros([1, 2], dtype=torch.long) - rel_types = torch.zeros([1], dtype=torch.long) + rel_types = torch.zeros([1, rel_type_count-1], dtype=torch.float32) rel_masks = torch.zeros([1, context_size], dtype=torch.bool) rel_sample_masks = torch.zeros([1], dtype=torch.bool) - # relation types to one-hot encoding - rel_types_onehot = torch.zeros([rel_types.shape[0], rel_type_count], dtype=torch.float32) - rel_types_onehot.scatter_(1, rel_types.unsqueeze(1), 1) - rel_types_onehot = rel_types_onehot[:, 1:] # all zeros for 'none' relation - return dict(encodings=encodings, context_masks=context_masks, entity_masks=entity_masks, entity_sizes=entity_sizes, entity_types=entity_types, - rels=rels, rel_masks=rel_masks, rel_types=rel_types_onehot, + rels=rels, rel_masks=rel_masks, rel_types=rel_types, entity_sample_masks=entity_sample_masks, rel_sample_masks=rel_sample_masks)