diff --git a/buffalo/data/base.py b/buffalo/data/base.py index 4ab5c78..f15428a 100644 --- a/buffalo/data/base.py +++ b/buffalo/data/base.py @@ -26,6 +26,7 @@ def __init__(self, opt, *args, **kwargs): self.prepro = getattr(prepro, self.opt.data.value_prepro.name)(self.opt.data.value_prepro) self.value_prepro = self.prepro self.data_type = None + self.temp_file_list = [] @abc.abstractmethod def create_database(self, filename, **kwargs): @@ -166,6 +167,12 @@ def close(self): self.handle = None self.header = None + def temp_file_clear(self): + for path in self.temp_file_list: + if os.path.isfile(path): + os.remove(path) + self.temp_file_list = [] + def _create_database(self, path, **kwargs): # Create database structure if os.path.exists(path): @@ -469,6 +476,7 @@ def is_valid_option(self, opt) -> bool: class DataReader(object): def __init__(self, opt): self.opt = opt + self.temp_file_list = [] def get_main_path(self): return self.opt.input.main @@ -484,6 +492,7 @@ def _get_temporary_id_list_path(self, obj, name): if hasattr(self, field_name): return getattr(self, field_name) tmp_path = aux.get_temporary_file(self.opt.data.tmp_dir) + self.temp_file_list.append(tmp_path) with open(tmp_path, "w") as fout: if isinstance(obj, np.ndarray,) and obj.ndim == 1: fout.write("\n".join(map(str, obj.tolist()))) @@ -493,3 +502,9 @@ def _get_temporary_id_list_path(self, obj, name): raise RuntimeError(f"Unexpected data type for id list: {type(obj)}") setattr(self, field_name, tmp_path) return tmp_path + + def temp_file_clear(self): + for path in self.temp_file_list: + if os.path.isfile(path): + os.remove(path) + self.temp_file_list = [] diff --git a/buffalo/data/mm.py b/buffalo/data/mm.py index ee8e77b..fc64345 100644 --- a/buffalo/data/mm.py +++ b/buffalo/data/mm.py @@ -69,6 +69,7 @@ def get_main_path(self): log.get_logger("MatrixMarketDataReader").debug("creating temporary matrix-market data from numpy-kind array") tmp_path = aux.get_temporary_file(self.opt.data.tmp_dir) +
self.temp_file_list.append(tmp_path) with open(tmp_path, "wb") as fout: if isinstance(main, (np.ndarray,)) and main.ndim == 2: main = scipy.sparse.csr_matrix(main) @@ -172,6 +173,7 @@ def _create_working_data(self, db, source_path, ignore_lines): vali_indexes = [] if "vali" not in db else db["vali"]["indexes"] vali_lines = [] file_path = aux.get_temporary_file(self.opt.data.tmp_dir) + self.temp_file_list.append(file_path) with open(file_path, "w") as w: fin = open(source_path, mode="r") file_size = fin.seek(0, 2) @@ -272,4 +274,6 @@ def create(self) -> h5py.File: if os.path.isfile(self.path): os.remove(self.path) raise + self.reader.temp_file_clear() + self.temp_file_clear() self.logger.info("DB built on %s" % data_path) diff --git a/buffalo/data/stream.py b/buffalo/data/stream.py index eb67af0..49e3130 100644 --- a/buffalo/data/stream.py +++ b/buffalo/data/stream.py @@ -171,6 +171,7 @@ def _build_sppmi(self, db, working_data_path, sppmi_total_lines, k): self.logger.debug("sort working_data") aux.psort(working_data_path, key=1) w_path = aux.get_temporary_file(root=self.opt.data.tmp_dir) + self.temp_file_list.append(w_path) self.logger.debug(f"build sppmi in_parallel. w: {w_path}") num_workers = psutil.cpu_count() nnz = parallel_build_sppmi(working_data_path, w_path, sppmi_total_lines, sz, k, num_workers) @@ -207,7 +208,9 @@ def _create_working_data(self, db, stream_main_path, itemids, warnings.simplefilter("ignore", ResourceWarning) if with_sppmi: w_sppmi = open(aux.get_temporary_file(root=self.opt.data.tmp_dir), "w") + self.temp_file_list.append(w_sppmi.name) file_path = aux.get_temporary_file(root=self.opt.data.tmp_dir) + self.temp_file_list.append(file_path) with open(stream_main_path) as fin, open(file_path, "w") as w: total_index = 0 internal_data_type = self.opt.data.internal_data_type @@ -308,4 +311,5 @@ def create(self) -> h5py.File: if os.path.isfile(self.path): os.remove(self.path) raise + self.temp_file_clear() self.logger.info("DB built on %s" % data_path)