diff --git a/flair/datasets/__init__.py b/flair/datasets/__init__.py index 98eeb46d2..c298b1cd1 100755 --- a/flair/datasets/__init__.py +++ b/flair/datasets/__init__.py @@ -91,6 +91,7 @@ from .treebanks import UniversalDependenciesCorpus from .treebanks import UniversalDependenciesDataset from .treebanks import UD_ENGLISH +from .treebanks import UD_ESTONIAN from .treebanks import UD_GERMAN from .treebanks import UD_GERMAN_HDT from .treebanks import UD_DUTCH diff --git a/flair/datasets/treebanks.py b/flair/datasets/treebanks.py index 1707cdc71..5607f8333 100755 --- a/flair/datasets/treebanks.py +++ b/flair/datasets/treebanks.py @@ -252,6 +252,33 @@ def __init__(self, base_path: Union[str, Path] = None, in_memory: bool = True, s super(UD_ENGLISH, self).__init__(data_folder, in_memory=in_memory, split_multiwords=split_multiwords) +class UD_ESTONIAN(UniversalDependenciesCorpus): + def __init__(self, base_path: Union[str, Path] = None, in_memory: bool = True, split_multiwords: bool = True): + + if type(base_path) == str: + base_path: Path = Path(base_path) + + # this dataset name + dataset_name = self.__class__.__name__.lower() + + # default dataset folder is the cache root + if not base_path: + base_path = Path(flair.cache_root) / "datasets" + data_folder = base_path / dataset_name + + # download data if necessary + web_path = "https://raw.githubusercontent.com/UniversalDependencies/UD_Estonian-EDT/master" + cached_path(f"{web_path}/et_edt-ud-dev.conllu", Path("datasets") / dataset_name) + cached_path( + f"{web_path}/et_edt-ud-test.conllu", Path("datasets") / dataset_name + ) + cached_path( + f"{web_path}/et_edt-ud-train.conllu", Path("datasets") / dataset_name + ) + + super(UD_ESTONIAN, self).__init__(data_folder, in_memory=in_memory, split_multiwords=split_multiwords) + + class UD_GERMAN(UniversalDependenciesCorpus): def __init__(self, base_path: Union[str, Path] = None, in_memory: bool = True, split_multiwords: bool = True):