diff --git a/flair/datasets/document_classification.py b/flair/datasets/document_classification.py index c00b7cffd..afebb7975 100644 --- a/flair/datasets/document_classification.py +++ b/flair/datasets/document_classification.py @@ -37,6 +37,7 @@ def __init__( memory_mode: str = "partial", label_name_map: Dict[str, str] = None, skip_labels: List[str] = None, + allow_examples_without_labels=False, encoding: str = 'utf-8', ): """ @@ -55,6 +56,7 @@ def __init__( if full corpus and all embeddings fits into memory for speedups during training. Otherwise use 'partial' and if even this is too much for your memory, use 'disk'. :param label_name_map: Optionally map label names to different schema. + :param allow_examples_without_labels: set to True to allow Sentences without label in the corpus. :param encoding: Default is 'uft-8' but some datasets are in 'latin-1 :return: a Corpus with annotated train, dev and test data """ @@ -73,6 +75,7 @@ def __init__( memory_mode=memory_mode, label_name_map=label_name_map, skip_labels=skip_labels, + allow_examples_without_labels=allow_examples_without_labels, encoding=encoding, ) @@ -87,6 +90,7 @@ def __init__( memory_mode=memory_mode, label_name_map=label_name_map, skip_labels=skip_labels, + allow_examples_without_labels=allow_examples_without_labels, encoding=encoding, ) if test_file is not None else None @@ -101,6 +105,7 @@ def __init__( memory_mode=memory_mode, label_name_map=label_name_map, skip_labels=skip_labels, + allow_examples_without_labels=allow_examples_without_labels, encoding=encoding, ) if dev_file is not None else None @@ -125,6 +130,7 @@ def __init__( memory_mode: str = "partial", label_name_map: Dict[str, str] = None, skip_labels: List[str] = None, + allow_examples_without_labels=False, encoding: str = 'utf-8', ): """ @@ -143,6 +149,7 @@ def __init__( if full corpus and all embeddings fits into memory for speedups during training. Otherwise use 'partial' and if even this is too much for your memory, use 'disk'. :param label_name_map: Optionally map label names to different schema. + :param allow_examples_without_labels: set to True to allow Sentences without label in the Dataset. :param encoding: Default is 'uft-8' but some datasets are in 'latin-1 :return: list of sentences """ @@ -169,6 +176,7 @@ def __init__( self.truncate_to_max_tokens = truncate_to_max_tokens self.filter_if_longer_than = filter_if_longer_than self.label_name_map = label_name_map + self.allow_examples_without_labels = allow_examples_without_labels self.path_to_file = path_to_file @@ -176,7 +184,7 @@ def __init__( line = f.readline() position = 0 while line: - if "__label__" not in line or (" " not in line and "\t" not in line): + if ("__label__" not in line and not allow_examples_without_labels) or (" " not in line and "\t" not in line): position = f.tell() line = f.readline() continue @@ -219,7 +227,7 @@ def __init__( text = line[l_len:].strip() # if so, add to indices - if text and label: + if text and (label or allow_examples_without_labels): if self.memory_mode == 'partial': self.lines.append(line) @@ -257,7 +265,7 @@ def _parse_line_to_sentence( if self.truncate_to_max_chars > 0: text = text[: self.truncate_to_max_chars] - if text and labels: + if text and (labels or self.allow_examples_without_label): sentence = Sentence(text, use_tokenizer=tokenizer) for label in labels: diff --git a/tests/resources/tasks/multi_class_negative_examples/dev.txt b/tests/resources/tasks/multi_class_negative_examples/dev.txt new file mode 100644 index 000000000..a0fb632f6 --- /dev/null +++ b/tests/resources/tasks/multi_class_negative_examples/dev.txt @@ -0,0 +1,5 @@ +__label__apple apple +__label__tv tv +__label__guitar guitar +__label__apple __label__tv apple tv + dev example without labels diff --git a/tests/resources/tasks/multi_class_negative_examples/test.txt b/tests/resources/tasks/multi_class_negative_examples/test.txt new file mode 100644 index 000000000..bd3a8b3f5 --- /dev/null +++ b/tests/resources/tasks/multi_class_negative_examples/test.txt @@ -0,0 +1,6 @@ +__label__guitar guitar +__label__apple apple +__label__tv tv +__label__apple __label__tv apple tv +__label__apple __label__guitar apple tv +test example without labels diff --git a/tests/resources/tasks/multi_class_negative_examples/train.txt b/tests/resources/tasks/multi_class_negative_examples/train.txt new file mode 100644 index 000000000..f3ba00135 --- /dev/null +++ b/tests/resources/tasks/multi_class_negative_examples/train.txt @@ -0,0 +1,8 @@ +__label__tv tv +__label__apple __label__tv apple tv +__label__apple apple +__label__tv tv +__label__apple __label__tv apple tv +__label__guitar guitar +__label__guitar guitar +train example without labels diff --git a/tests/test_data.py b/tests/test_data.py index 6a2d162e0..9c3e07721 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -704,6 +704,22 @@ def test_tagged_corpus_downsample(): assert 3 == len(corpus.train) +def test_classification_corpus_multi_labels_without_negative_examples(tasks_base_path): + corpus = flair.datasets.ClassificationCorpus(tasks_base_path / "multi_class_negative_examples", + allow_examples_without_labels=False) + assert len(corpus.train) == 7 + assert len(corpus.dev) == 4 + assert len(corpus.test) == 5 + + +def test_classification_corpus_multi_labels_with_negative_examples(tasks_base_path): + corpus = flair.datasets.ClassificationCorpus(tasks_base_path / "multi_class_negative_examples", + allow_examples_without_labels=True) + assert len(corpus.train) == 8 + assert len(corpus.dev) == 5 + assert len(corpus.test) == 6 + + def test_spans(): sentence = Sentence("Zalando Research is located in Berlin .")