From d0a8bd5f6526b583949a3c90e62f763873aeaef9 Mon Sep 17 00:00:00 2001 From: Kishaloy Halder Date: Mon, 22 Mar 2021 13:50:38 +0100 Subject: [PATCH] GH-2174: fixed IMDB data splitting logic --- flair/datasets/document_classification.py | 24 +++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/flair/datasets/document_classification.py b/flair/datasets/document_classification.py index 310495745..b4f006f38 100644 --- a/flair/datasets/document_classification.py +++ b/flair/datasets/document_classification.py @@ -745,21 +745,25 @@ def __init__(self, base_path: Path = Path(base_path) # this dataset name - dataset_name = self.__class__.__name__.lower() + '_v3' - - if rebalance_corpus: - dataset_name = dataset_name + '-rebalanced' + dataset_name = self.__class__.__name__.lower() + '_v4' # default dataset folder is the cache root if not base_path: base_path = Path(flair.cache_root) / "datasets" - data_folder = base_path / dataset_name # download data if necessary imdb_acl_path = "http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz" + + if rebalance_corpus: + dataset_name = dataset_name + '-rebalanced' + data_folder = base_path / dataset_name data_path = Path(flair.cache_root) / "datasets" / dataset_name - data_file = data_path / "train.txt" - if not data_file.is_file(): + train_data_file = data_path / "train.txt" + test_data_file = data_path / "test.txt" + + if train_data_file.is_file()==False or (rebalance_corpus==False and test_data_file.is_file()==False): + [os.remove(file_path) for file_path in [train_data_file, test_data_file] if file_path.is_file()] + cached_path(imdb_acl_path, Path("datasets") / dataset_name) import tarfile @@ -783,7 +787,11 @@ def __init__(self, if f"{dataset}/{label}" in m.name ], ) - with open(f"{data_path}/train.txt", "at") as f_p: + data_file = train_data_file + if rebalance_corpus==False and dataset=="test": + data_file = test_data_file + + with open(data_file, "at") as f_p: current_path = data_path / "aclImdb" / dataset / label for file_name in current_path.iterdir(): if file_name.is_file() and file_name.name.endswith(