Merge pull request #696 from zalandoresearch/GH-673
GH-673: add support for new German HDT Universal Dependencies dataset
Alan Akbik committed May 2, 2019
2 parents 81cf1b5 + 602b31f commit 5cbc518
Showing 1 changed file with 44 additions and 4 deletions.
48 changes: 44 additions & 4 deletions flair/data_fetcher.py
@@ -77,6 +77,9 @@ class NLPTask(Enum):
     # Language isolates
     UD_BASQUE = "ud_basque"
 
+    # recent Universal Dependencies
+    UD_GERMAN_HDT = "ud_german_hdt"
+
     # other datasets
     ONTONER = "ontoner"
     FASHION = "fashion"
@@ -379,7 +382,7 @@ def load_classification_corpus(
         test_file=None,
         dev_file=None,
         use_tokenizer: bool = True,
-        max_tokens_per_doc=-1
+        max_tokens_per_doc=-1,
     ) -> TaggedCorpus:
         """
         Helper function to get a TaggedCorpus from text classification-formatted task data
@@ -424,19 +427,25 @@ def load_classification_corpus(
         sentences_train: List[
             Sentence
         ] = NLPTaskDataFetcher.read_text_classification_file(
-            train_file, use_tokenizer=use_tokenizer, max_tokens_per_doc=max_tokens_per_doc
+            train_file,
+            use_tokenizer=use_tokenizer,
+            max_tokens_per_doc=max_tokens_per_doc,
         )
         sentences_test: List[
             Sentence
         ] = NLPTaskDataFetcher.read_text_classification_file(
-            test_file, use_tokenizer=use_tokenizer, max_tokens_per_doc=max_tokens_per_doc
+            test_file,
+            use_tokenizer=use_tokenizer,
+            max_tokens_per_doc=max_tokens_per_doc,
         )
 
         if dev_file is not None:
             sentences_dev: List[
                 Sentence
             ] = NLPTaskDataFetcher.read_text_classification_file(
-                dev_file, use_tokenizer=use_tokenizer, max_tokens_per_doc=max_tokens_per_doc
+                dev_file,
+                use_tokenizer=use_tokenizer,
+                max_tokens_per_doc=max_tokens_per_doc,
             )
         else:
             sentences_dev: List[Sentence] = [
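
The hunks above only reflow the argument lists of load_classification_corpus to one argument per line with trailing commas; behavior is unchanged. For reference, a hedged call sketch (the data folder and file names are hypothetical; the files are FastText-style "__label__<class> <text>" lines):

    from pathlib import Path

    from flair.data_fetcher import NLPTaskDataFetcher

    # Hypothetical paths and values; max_tokens_per_doc=512 truncates each
    # document to its first 512 tokens, while the default of -1 keeps all.
    corpus = NLPTaskDataFetcher.load_classification_corpus(
        Path("resources/tasks/my_task"),
        train_file="train.txt",
        test_file="test.txt",
        dev_file="dev.txt",
        use_tokenizer=True,
        max_tokens_per_doc=512,
    )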
@@ -1282,3 +1291,34 @@ def download_dataset(task: NLPTask):
             f"{ud_path}UD_Basque-BDT/master/eu_bdt-ud-train.conllu",
             Path("datasets") / task.value,
         )
+
+        if task == NLPTask.UD_GERMAN_HDT:
+            cached_path(
+                f"{ud_path}UD_German-HDT/dev/de_hdt-ud-dev.conllu",
+                Path("datasets") / task.value,
+            )
+            cached_path(
+                f"{ud_path}UD_German-HDT/dev/de_hdt-ud-test.conllu",
+                Path("datasets") / task.value,
+            )
+            cached_path(
+                f"{ud_path}UD_German-HDT/dev/de_hdt-ud-train-a.conllu",
+                Path("datasets") / task.value / "original",
+            )
+            cached_path(
+                f"{ud_path}UD_German-HDT/dev/de_hdt-ud-train-b.conllu",
+                Path("datasets") / task.value / "original",
+            )
+            data_path = Path(flair.cache_root) / "datasets" / task.value
+
+            train_filenames = ["de_hdt-ud-train-a.conllu", "de_hdt-ud-train-b.conllu"]
+
+            new_train_file: Path = data_path / "de_hdt-ud-train-all.conllu"
+
+            if not new_train_file.is_file():
+                with open(new_train_file, "wt") as f_out:
+                    for train_filename in train_filenames:
+                        with open(
+                            data_path / "original" / train_filename, "rt"
+                        ) as f_in:
+                            f_out.write(f_in.read())
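
With this change, the German HDT corpus loads like any other UD task; a minimal usage sketch, assuming the module's usual NLPTaskDataFetcher.load_corpus entry point (the download logic above runs automatically on first use, and training data is read from the concatenated de_hdt-ud-train-all.conllu):

    from flair.data_fetcher import NLPTaskDataFetcher, NLPTask

    # First call fetches dev/test plus train parts a and b, merges the two
    # train files, and loads train/dev/test splits into a TaggedCorpus.
    corpus = NLPTaskDataFetcher.load_corpus(NLPTask.UD_GERMAN_HDT)
    print(corpus)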
