From df7e762b10727eae46a737a00a4b9a3dfeead598 Mon Sep 17 00:00:00 2001 From: Stefan Schweter Date: Mon, 11 Mar 2024 14:17:48 +0100 Subject: [PATCH 1/2] datasets: update MasakhaPOS dataset --- flair/datasets/sequence_labeling.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/flair/datasets/sequence_labeling.py b/flair/datasets/sequence_labeling.py index 0d0c78491..965207052 100644 --- a/flair/datasets/sequence_labeling.py +++ b/flair/datasets/sequence_labeling.py @@ -4887,11 +4887,13 @@ def __init__( "ibo", "kin", "lug", + "luo", "mos", "pcm", "nya", "sna", "swa", + "tsn", "twi", "wol", "xho", @@ -4936,5 +4938,5 @@ def __init__( corpora.append(corp) super().__init__( corpora, - name="africa-pos-" + "-".join(languages), + name="masakha-pos-" + "-".join(languages), ) From 0e96656c4f807672c75a2bb009499e2d06f2b1d6 Mon Sep 17 00:00:00 2001 From: Stefan Schweter Date: Mon, 11 Mar 2024 14:18:06 +0100 Subject: [PATCH 2/2] tests: update test cases for MasakhaPOS dataset (adding two new languages) --- tests/test_datasets.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 102206c6f..52fec1c5e 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -822,11 +822,13 @@ def test_masakha_pos_corpus(tasks_base_path): "ibo", "kin", "lug", + "luo", "mos", "pcm", "nya", "sna", "swa", + "tsn", "twi", "wol", "xho", @@ -835,7 +837,7 @@ def test_masakha_pos_corpus(tasks_base_path): ], } - africa_pos_stats = { + masakha_pos_stats = { "v1": { "bam": {"train": 775, "dev": 154, "test": 619}, "bbj": {"train": 750, "dev": 149, "test": 599}, @@ -845,11 +847,13 @@ def test_masakha_pos_corpus(tasks_base_path): "ibo": {"train": 803, "dev": 160, "test": 642}, "kin": {"train": 757, "dev": 151, "test": 604}, "lug": {"train": 733, "dev": 146, "test": 586}, + "luo": {"train": 758, "dev": 151, "test": 606}, "mos": {"train": 757, "dev": 151, "test": 604}, "pcm": {"train": 752, "dev": 150, "test": 600}, "nya": {"train": 728, "dev": 145, "test": 582}, "sna": {"train": 747, "dev": 149, "test": 596}, "swa": {"train": 693, "dev": 138, "test": 553}, + "tsn": {"train": 754, "dev": 150, "test": 602}, "twi": {"train": 785, "dev": 157, "test": 628}, "wol": {"train": 782, "dev": 156, "test": 625}, "xho": {"train": 752, "dev": 150, "test": 601}, @@ -865,7 +869,7 @@ def check_number_sentences(reference: int, actual: int, split_name: str, languag for language in supported_languages[version]: corpus = flair.datasets.MASAKHA_POS(languages=language, version=version) - gold_stats = africa_pos_stats[version][language] + gold_stats = masakha_pos_stats[version][language] check_number_sentences(len(corpus.train), gold_stats["train"], "train", language, version) check_number_sentences(len(corpus.dev), gold_stats["dev"], "dev", language, version)