Relation extraction for DrugProt [WIP] #2340

Merged (2 commits, Jul 26, 2021)
7 changes: 6 additions & 1 deletion flair/data.py
@@ -1403,7 +1403,12 @@ def make_label_dictionary(self, label_type: str) -> Dictionary:

from flair.datasets import DataLoader

data = ConcatDataset([self.train, self.test])
datasets = [self.train]
if self.test is not None:
datasets.append(self.test)

data = ConcatDataset(datasets)

loader = DataLoader(data, batch_size=1)

log.info("Computing label dictionary. Progress:")
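This guard matters for corpora such as DrugProt below, which ship without a test split. A minimal sketch of the fixed behaviour (hypothetical labels; the `SentenceDataset` import path is assumed):

```python
from flair.data import Corpus, Sentence, SentenceDataset

# a one-sentence training set with a hypothetical sentence-level label
sentence = Sentence("Aspirin inhibits COX1 .")
sentence.add_label("topic", "pharmacology")

# dev/test stay None because sampling of missing splits is switched off
corpus = Corpus(
    train=SentenceDataset([sentence]),
    dev=None,
    test=None,
    sample_missing_splits=False,
)

# previously this raised, because ConcatDataset([train, None]) cannot be built;
# with the None check above, the test split is only appended when it exists
label_dict = corpus.make_label_dictionary(label_type="topic")
print(label_dict)
```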
3 changes: 3 additions & 0 deletions flair/datasets/__init__.py
@@ -261,3 +261,6 @@

# Expose all relation extraction datasets
from .relation_extraction import SEMEVAL_2010_TASK_8
from .relation_extraction import TACRED
from .relation_extraction import CoNLL04
from .relation_extraction import DrugProt
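
With the export in place, the new corpus loads like any other flair dataset. A usage sketch (the first call downloads the DrugProt release and converts it to CoNLL-U Plus; the scispacy model `en_core_sci_sm` must be installed for the default sentence splitter):

```python
from flair.datasets import DrugProt

corpus = DrugProt()

print(corpus)           # train and dev only, no test split is sampled
print(corpus.train[0])  # first sentence with entity and relation annotations
```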
21 changes: 19 additions & 2 deletions flair/datasets/conllu.py
@@ -2,7 +2,7 @@
from pathlib import Path
from typing import List, Union, Optional, Sequence, Dict, Tuple

from flair.data import Sentence, Corpus, Token, FlairDataset, Span, RelationLabel
from flair.data import Sentence, Corpus, Token, FlairDataset, Span, RelationLabel, SpanLabel
from flair.datasets.base import find_train_dev_test_files
import conllu

@@ -53,6 +53,7 @@ def __init__(
fields: Optional[Sequence[str]] = None,
field_parsers: Optional[Dict[str, conllu._FieldParserType]] = None,
metadata_parsers: Optional[Dict[str, conllu._MetadataParserType]] = None,
sample_missing_splits: bool = True,
):
"""
Instantiates a Corpus from CoNLL-U (Plus) column-formatted task data
@@ -103,7 +104,8 @@ def __init__(
else None
)

super(CoNLLUCorpus, self).__init__(train, dev, test, name=str(data_folder))
super(CoNLLUCorpus, self).__init__(train, dev, test, name=str(data_folder),
sample_missing_splits=sample_missing_splits)


class CoNLLUDataset(FlairDataset):
@@ -203,6 +205,9 @@ def token_list_to_sentence(self, token_list: conllu.TokenList) -> Sentence:
if "ner" in conllu_token:
token.add_label("ner", conllu_token["ner"])

if "ner-2" in conllu_token:
token.add_label("ner-2", conllu_token["ner-2"])

if "lemma" in conllu_token:
token.add_label("lemma", conllu_token["lemma"])

@@ -226,4 +231,16 @@ def token_list_to_sentence(self, token_list: conllu.TokenList) -> Sentence:

sentence.add_complex_label("relation", RelationLabel(value=label, head=head, tail=tail))

# determine all NER label types in sentence and add all NER spans as sentence-level labels
ner_label_types = []
for token in sentence.tokens:
for annotation in token.annotation_layers.keys():
if annotation.startswith("ner") and annotation not in ner_label_types:
ner_label_types.append(annotation)

for label_type in ner_label_types:
spans = sentence.get_spans(label_type)
for span in spans:
sentence.add_complex_label("entity", label=SpanLabel(span=span, value=span.tag, score=span.score))

return sentence
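
The second NER column keeps overlapping mentions apart, and every span from either column is additionally attached to the sentence as an `entity` label, which is what the relation extractor reads. A sketch of how the loaded annotations look, assuming the DrugProt corpus added further below:

```python
from flair.datasets import DrugProt

corpus = DrugProt()
sentence = corpus.train[0]

# token-level tags: overlapping mentions are split across "ner" and "ner-2"
for token in sentence:
    print(token.text, token.get_labels("ner"), token.get_labels("ner-2"))

# sentence-level span labels under "entity", one per NER span from either layer
for entity_label in sentence.get_labels("entity"):
    print(entity_label.span.text, entity_label.value, entity_label.score)

# relations between those spans are stored as sentence-level "relation" labels
for relation_label in sentence.get_labels("relation"):
    print(relation_label)
```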
212 changes: 211 additions & 1 deletion flair/datasets/relation_extraction.py
@@ -2,15 +2,22 @@
import re
import io
import os
import bisect
from pathlib import Path
from typing import List, Union, Optional, Sequence, Dict, Any, Tuple
from typing import List, Union, Sequence, Dict, Any, Tuple, Set
from collections import defaultdict

import flair
import json
import gdown
import conllu
from flair.file_utils import cached_path
from flair.datasets.conllu import CoNLLUCorpus
from flair.tokenization import (
SentenceSplitter,
SciSpacySentenceSplitter,
)
from flair.data import Sentence

log = logging.getLogger("flair")

@@ -453,3 +460,206 @@ def _src_token_list_to_token_list(self, src_token_list):
}

return conllu.TokenList(tokens=token_dicts, metadata=metadata)


class DrugProt(CoNLLUCorpus):
def __init__(
self,
base_path: Union[str, Path] = None,
in_memory: bool = True,
sentence_splitter: SentenceSplitter = None,
):
if type(base_path) == str:
base_path: Path = Path(base_path)

self.sentence_splitter = (
sentence_splitter if sentence_splitter else SciSpacySentenceSplitter()
)

# this dataset name
dataset_name = self.__class__.__name__.lower()

# default dataset folder is the cache root
if not base_path:
base_path = flair.cache_root / "datasets"
data_folder = base_path / dataset_name

drugprot_url = (
"https://zenodo.org/record/5042151/files/drugprot-gs-training-development.zip"
)
data_file = data_folder / "drugprot-train.conllu"

if not data_file.is_file():
source_data_folder = data_folder / "original"
cached_path(drugprot_url, source_data_folder)
self.extract_and_convert_to_conllu(
data_file=source_data_folder / "drugprot-gs-training-development.zip",
data_folder=data_folder,
)

super(DrugProt, self).__init__(
data_folder,
in_memory=in_memory,
sample_missing_splits=False,
)

def extract_and_convert_to_conllu(self, data_file, data_folder):
import zipfile

splits = ["training", "development"]
target_filenames = ["drugprot-train.conllu", "drugprot-dev.conllu"]

with zipfile.ZipFile(data_file) as zip_file:
for split, target_filename in zip(splits, target_filenames):
pmid_to_entities = defaultdict(dict)
pmid_to_relations = defaultdict(set)

with zip_file.open(f"drugprot-gs-training-development/{split}/drugprot_{split}_entities.tsv") as entites_file:
for line in io.TextIOWrapper(entites_file, encoding="utf-8"):
fields = line.strip().split("\t")
pmid, ent_id, ent_type, start, end, mention = fields
pmid_to_entities[pmid][ent_id] = (
ent_type, int(start), int(end), mention)

with zip_file.open(f"drugprot-gs-training-development/{split}/drugprot_{split}_relations.tsv") as relations_file:
for line in io.TextIOWrapper(relations_file, encoding="utf-8"):
fields = line.strip().split("\t")
pmid, rel_type, arg1, arg2 = fields
ent1 = arg1.split(":")[1]
ent2 = arg2.split(":")[1]
pmid_to_relations[pmid].add((rel_type, ent1, ent2))

tokenlists: List[conllu.TokenList] = []
with zip_file.open(f"drugprot-gs-training-development/{split}/drugprot_{split}_abstracs.tsv") as abstracts_file:
for line in io.TextIOWrapper(abstracts_file, encoding="utf-8"):
fields = line.strip().split("\t")
pmid, title, abstract = fields
title_sentences = self.sentence_splitter.split(title)
abstract_sentences = self.sentence_splitter.split(abstract)

tokenlists.extend(self.drugprot_document_to_tokenlists(pmid=pmid,
title_sentences=title_sentences,
abstract_sentences=abstract_sentences,
abstract_offset=len(title) + 1,
entities=pmid_to_entities[pmid],
relations=pmid_to_relations[pmid]))

target_file_path = Path(data_folder) / target_filename
with open(target_file_path, mode="w", encoding="utf-8") as target_file:
# write CoNLL-U Plus header
target_file.write("# global.columns = id form ner ner-2\n")

for tokenlist in tokenlists:
target_file.write(tokenlist.serialize())

# for source_file_path, target_filename in zip(source_file_paths, target_filenames):
# with zip_file.open(source_file_path, mode="r") as source_file:

# target_file_path = Path(data_folder) / target_filename
# with open(target_file_path, mode="w", encoding="utf-8") as target_file:
# # write CoNLL-U Plus header
# target_file.write("# global.columns = id form ner\n")

# for example in json.load(source_file):
# token_list = self._tacred_example_to_token_list(example)
# target_file.write(token_list.serialize())
def char_spans_to_token_spans(self, char_spans, token_offsets):
token_starts = [s[0] for s in token_offsets]
token_ends = [s[1] for s in token_offsets]

token_spans = []
for char_start, char_end in char_spans:
token_start = bisect.bisect_right(token_ends, char_start)
token_end = bisect.bisect_left(token_starts, char_end)
token_spans.append((token_start, token_end))

return token_spans

def has_overlap(self, a, b):
if a is None or b is None:
return False

return max(0, min(a[1], b[1]) - max(a[0], b[0])) > 0

def drugprot_document_to_tokenlists(self,
pmid: str,
title_sentences: List[Sentence],
abstract_sentences: List[Sentence],
abstract_offset: int,
entities: Dict[str, Tuple[str, int, int, str]],
relations: Set[Tuple[str, str, str]]
) -> List[conllu.TokenList]:
tokenlists: List[conllu.TokenList] = []
sentence_id = 1
for offset, sents in [(0, title_sentences), (abstract_offset, abstract_sentences)]:
for sent in sents:
sent_char_start = sent.start_pos + offset
sent_char_end = sent.end_pos + offset

entities_in_sent = set()
for entity_id, (_, char_start, char_end, _) in entities.items():
if sent_char_start <= char_start and char_end <= sent_char_end:
entities_in_sent.add(entity_id)

entity_char_spans = [(entities[entity_id][1], entities[entity_id][2]) for entity_id in entities_in_sent]

token_offsets = [(sent.start_pos + token.start_pos + offset, sent.start_pos + token.end_pos + offset) for token in sent.tokens]
entity_token_spans = self.char_spans_to_token_spans(entity_char_spans, token_offsets)

tags_1 = ["O"] * len(sent)
tags_2 = ["O"] * len(sent)
entity_id_to_token_idx = {}
prev_entity_span = None
for entity_id, entity_span in sorted(zip(entities_in_sent, entity_token_spans), key=lambda x: x[1][0]):
entity_id_to_token_idx[entity_id] = entity_span

overlap = self.has_overlap(prev_entity_span, entity_span)

tags = tags_2 if overlap else tags_1

tag = entities[entity_id][0]
token_start, token_end = entity_span
for i in range(token_start, token_end):
if i == token_start:
prefix = "B-"
else:
prefix = "I-"

tags[i] = prefix + tag

prev_entity_span = entity_span

token_dicts = []
for i, (token, tag_1, tag_2) in enumerate(zip(sent, tags_1, tags_2)):
token_dicts.append({
"id": str(i + 1),
"form": token.text,
"ner": tag_1,
"ner-2": tag_2
})

relations_in_sent = []
for relation, ent1, ent2 in [r for r in relations if {r[1], r[2]} <= entities_in_sent]:
subj_start = entity_id_to_token_idx[ent1][0]
subj_end = entity_id_to_token_idx[ent1][1]
obj_start = entity_id_to_token_idx[ent2][0]
obj_end = entity_id_to_token_idx[ent2][1]
relations_in_sent.append((subj_start, subj_end, obj_start, obj_end, relation))

metadata = {
"text": sent.to_original_text(),
"doc_id": pmid,
"sentence_id": str(sentence_id),
"relations": "|".join(
[
";".join([str(subj_start + 1), str(subj_end), str(obj_start + 1), str(obj_end), relation])
for subj_start, subj_end, obj_start, obj_end, relation in relations_in_sent
]
),
}

tokenlists.append(conllu.TokenList(tokens=token_dicts, metadata=metadata))

sentence_id += 1

return tokenlists
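
For reference, the character-to-token mapping above produces half-open token spans via two binary searches; the `relations` metadata line then serialises each span 1-indexed with an inclusive end (`subj_start + 1` through `subj_end`). A standalone sketch of the mapping with made-up offsets:

```python
import bisect

# token character offsets for: "Aspirin strongly inhibits COX1"
token_offsets = [(0, 7), (8, 16), (17, 25), (26, 30)]
token_starts = [start for start, _ in token_offsets]  # [0, 8, 17, 26]
token_ends = [end for _, end in token_offsets]        # [7, 16, 25, 30]

def to_token_span(char_start: int, char_end: int):
    # index of the first token whose end lies beyond char_start ...
    start = bisect.bisect_right(token_ends, char_start)
    # ... up to (exclusive) the first token starting at or after char_end
    end = bisect.bisect_left(token_starts, char_end)
    return start, end

print(to_token_span(0, 7))    # (0, 1)  -> "Aspirin"
print(to_token_span(26, 30))  # (3, 4)  -> "COX1"
print(to_token_span(17, 30))  # (2, 4)  -> "inhibits COX1"
```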
32 changes: 25 additions & 7 deletions flair/models/relation_extractor_model.py
@@ -1,5 +1,5 @@
import logging
from typing import List, Union
from typing import List, Union, Dict, Optional, Tuple

import torch
import torch.nn as nn
@@ -19,6 +19,7 @@ def __init__(
label_type: str = None,
span_label_type: str = None,
use_gold_spans: bool = False,
use_entity_pairs: List[Tuple[str, str]] = None,
pooling_operation: str = "first_last",
dropout_value: float = 0.0,
**classifierargs,
@@ -43,6 +44,11 @@ def __init__(
self.dropout_value = dropout_value
self.dropout = torch.nn.Dropout(dropout_value)

if use_entity_pairs is not None:
self.use_entity_pairs = set(use_entity_pairs)
else:
self.use_entity_pairs = None

relation_representation_length = 2 * token_embeddings.embedding_length
if self.pooling_operation == 'first_last':
relation_representation_length *= 2
@@ -74,21 +80,31 @@ def forward_pass(self,
relation_label: RelationLabel = relation_label
relation_dict[create_position_string(relation_label.head, relation_label.tail)] = relation_label

# get all entities
spans = sentence.get_spans(self.span_label_type)
# get all entity spans
span_labels = sentence.get_labels(self.span_label_type)

# get embedding for each entity
span_embeddings = []
for span in spans:
for span_label in span_labels:
span: Span = span_label.span
if self.pooling_operation == "first":
span_embeddings.append(span.tokens[0].get_embedding())
if self.pooling_operation == "first_last":
span_embeddings.append(torch.cat([span.tokens[0].get_embedding(), span.tokens[-1].get_embedding()]))

# go through cross product of entities, for each pair concat embeddings
for span, embedding in zip(spans, span_embeddings):
for span_2, embedding_2 in zip(spans, span_embeddings):
if span == span_2: continue
for span_label, embedding in zip(span_labels, span_embeddings):
span = span_label.span

for span_label_2, embedding_2 in zip(span_labels, span_embeddings):
span_2 = span_label_2.span

if span == span_2:
continue

if (self.use_entity_pairs is not None
and (span_label.value, span_label_2.value) not in self.use_entity_pairs):
continue

position_string = create_position_string(span, span_2)

@@ -137,6 +153,7 @@ def _get_state_dict(self):
"loss_weights": self.loss_weights,
"pooling_operation": self.pooling_operation,
"dropout_value": self.dropout_value,
"use_entity_pairs": self.use_entity_pairs,
}
return model_state

@@ -150,6 +167,7 @@ def _init_model_with_state_dict(state):
loss_weights=state["loss_weights"],
pooling_operation=state["pooling_operation"],
dropout_value=state["dropout_value"],
use_entity_pairs=state["use_entity_pairs"],
)
model.load_state_dict(state["state_dict"])
return model
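The new `use_entity_pairs` filter prunes the cross product of entity spans before classification. A training sketch under assumptions: the class is `RelationExtractor` as exported by `flair.models`, the entity type names follow the DrugProt annotations (`CHEMICAL`, `GENE-Y`, `GENE-N`), and `label_dictionary` is accepted via `**classifierargs`:

```python
from flair.datasets import DrugProt
from flair.embeddings import TransformerWordEmbeddings
from flair.models import RelationExtractor

corpus = DrugProt()
label_dict = corpus.make_label_dictionary(label_type="relation")

extractor = RelationExtractor(
    token_embeddings=TransformerWordEmbeddings("bert-base-uncased", fine_tune=True),
    label_dictionary=label_dict,
    label_type="relation",
    # read entity spans from the sentence-level "entity" labels created by the loader
    span_label_type="entity",
    # only score chemical -> gene candidate pairs, skip all other combinations
    use_entity_pairs=[("CHEMICAL", "GENE-Y"), ("CHEMICAL", "GENE-N")],
    pooling_operation="first_last",
)
```

From there, training would proceed with flair's usual `ModelTrainer` loop.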
2 changes: 1 addition & 1 deletion flair/tokenization.py
@@ -338,7 +338,7 @@ def combined_rule_prefixes() -> List[str]:
infix_re = spacy.util.compile_infix_regex(infixes)

self.model = spacy.load(
"en_core_sci_sm", disable=["tagger", "ner", "parser", "textcat"]
"en_core_sci_sm", disable=["tagger", "ner", "parser", "textcat", "lemmatizer"]
)
self.model.tokenizer.prefix_search = prefix_re.search
self.model.tokenizer.infix_finditer = infix_re.finditer
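Adding `lemmatizer` to the disabled pipes keeps the scispacy model load lean, presumably because only tokenisation and sentence boundaries are needed here. A quick sketch of the splitter in isolation (requires `en_core_sci_sm` to be installed):

```python
from flair.tokenization import SciSpacySentenceSplitter

splitter = SciSpacySentenceSplitter()
sentences = splitter.split("Aspirin inhibits COX1. It also binds COX2.")
for sentence in sentences:
    # each Sentence carries its character offset into the original text
    print(sentence.start_pos, sentence.to_original_text())
```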