From 13ad496fba4bb9015ef8b59b71bc0029b6a566f6 Mon Sep 17 00:00:00 2001 From: Tanay Soni Date: Thu, 19 Mar 2020 12:11:17 +0100 Subject: [PATCH] Skip a dict chunk if less than 2 documents are present --- farm/data_handler/data_silo.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/farm/data_handler/data_silo.py b/farm/data_handler/data_silo.py index 53ab183ea..11a5d6966 100644 --- a/farm/data_handler/data_silo.py +++ b/farm/data_handler/data_silo.py @@ -16,7 +16,7 @@ from tqdm import tqdm from farm.data_handler.dataloader import NamedDataLoader -from farm.data_handler.processor import Processor +from farm.data_handler.processor import Processor, BertStyleLMProcessor from farm.data_handler.utils import grouper, stream_grouper from farm.utils import MLFlowLogger as MlLogger from farm.utils import log_ascii_workers, calc_chunksize @@ -608,6 +608,8 @@ def __iter__(self): batch = [] for datasets, tensor_names in results: + if not datasets: + continue self.tensor_names = tensor_names for ds in datasets: batch.append(ds) @@ -626,6 +628,10 @@ def _dataset_from_chunk(self, chunk): :return: PyTorch Dataset """ dicts = [d[1] for d in chunk] + # need at least 2 documents to sample random sentences from + if len(dicts) < 2 and type(self.processor) == BertStyleLMProcessor: + logger.info("Skipping a dict chunk as it contains less than 2 documents ...") + return None, None indices = [x[0] for x in chunk] datasets, tensor_names = self.processor.dataset_from_dicts(dicts=dicts, indices=indices) return datasets, tensor_names