From dc013c335516e97ef3df3966f6f38c821dc762b9 Mon Sep 17 00:00:00 2001 From: Flyer Cheng Date: Thu, 23 May 2019 07:55:17 +0800 Subject: [PATCH] Reorg logical flow in train (#37) --- .gitignore | 2 + ModelConf.py | 9 +- data_encoding.py | 61 ------- predict.py | 25 +-- problem.py | 122 +++++++++----- test.py | 25 +-- train.py | 371 ++++++++++++++++++++++-------------------- utils/common_utils.py | 21 +++ 8 files changed, 317 insertions(+), 319 deletions(-) delete mode 100644 data_encoding.py diff --git a/.gitignore b/.gitignore index d9acd66..d7627e4 100644 --- a/.gitignore +++ b/.gitignore @@ -2,5 +2,7 @@ *~ *.pyc *.cache* +*.vs* dataset/GloVe/ +dataset/20_newsgroups/ models/ diff --git a/ModelConf.py b/ModelConf.py index ad67f74..7b8b429 100644 --- a/ModelConf.py +++ b/ModelConf.py @@ -10,6 +10,7 @@ import copy import torch import logging +import shutil from losses.BaseLossConf import BaseLossConf #import traceback @@ -60,8 +61,8 @@ def load_from_file(self, conf_path): self.tool_version = self.get_item(['tool_version']) self.language = self.get_item(['language'], default='english').lower() self.problem_type = self.get_item(['inputs', 'dataset_type']).lower() - if ProblemTypes[self.problem_type] == ProblemTypes.sequence_tagging: - self.tagging_scheme = self.get_item(['inputs', 'tagging_scheme'], default=None, use_default=True) + #if ProblemTypes[self.problem_type] == ProblemTypes.sequence_tagging: + self.tagging_scheme = self.get_item(['inputs', 'tagging_scheme'], default=None, use_default=True) if self.mode == 'normal': self.use_cache = self.get_item(['inputs', 'use_cache'], True) @@ -519,3 +520,7 @@ def check_version_compat(self, nb_version, conf_version): if not (nb_version_split[0] == conf_version_split[0] and nb_version_split[1] == conf_version_split[1]): raise ConfigurationError('The NeuronBlocks version is %s, but the configuration version is %s, please update your configuration to %s.%s.X' % (nb_version, conf_version, nb_version_split[0], nb_version_split[1])) + def back_up(self, params): + shutil.copy(params.conf_path, self.save_base_dir) + logging.info('Configuration file is backed up to %s' % (self.save_base_dir)) + diff --git a/data_encoding.py b/data_encoding.py deleted file mode 100644 index 57a3623..0000000 --- a/data_encoding.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT license. 
- -# add the project root to python path -import os -from settings import ProblemTypes, version - -import argparse -import logging - -from ModelConf import ModelConf -from problem import Problem -from utils.common_utils import log_set, dump_to_pkl, load_from_pkl - -def main(params, data_path, save_path): - conf = ModelConf("cache", params.conf_path, version, params) - - if ProblemTypes[conf.problem_type] == ProblemTypes.sequence_tagging: - problem = Problem(conf.problem_type, conf.input_types, conf.answer_column_name, - source_with_start=True, source_with_end=True, source_with_unk=True, source_with_pad=True, - target_with_start=True, target_with_end=True, target_with_unk=True, target_with_pad=True, same_length=True, - with_bos_eos=conf.add_start_end_for_seq, tagging_scheme=conf.tagging_scheme, tokenizer=conf.tokenizer, - remove_stopwords=conf.remove_stopwords, DBC2SBC=conf.DBC2SBC, unicode_fix=conf.unicode_fix) - elif ProblemTypes[conf.problem_type] == ProblemTypes.classification \ - or ProblemTypes[conf.problem_type] == ProblemTypes.regression: - problem = Problem(conf.problem_type, conf.input_types, conf.answer_column_name, - source_with_start=True, source_with_end=True, source_with_unk=True, source_with_pad=True, - target_with_start=False, target_with_end=False, target_with_unk=False, target_with_pad=False, same_length=True, - with_bos_eos=conf.add_start_end_for_seq, tokenizer=conf.tokenizer, remove_stopwords=conf.remove_stopwords, - DBC2SBC=conf.DBC2SBC, unicode_fix=conf.unicode_fix) - - if os.path.isfile(conf.problem_path): - problem.load_problem(conf.problem_path) - logging.info("Cache loaded!") - logging.debug("Cache loaded from %s" % conf.problem_path) - else: - raise Exception("Cache does not exist!") - - data, length, target = problem.encode(data_path, conf.file_columns, conf.input_types, conf.file_with_col_header, - conf.object_inputs, conf.answer_column_name, conf.min_sentence_len, - extra_feature=conf.extra_feature,max_lengths=conf.max_lengths, file_format='tsv', - cpu_num_workers=conf.cpu_num_workers) - if not os.path.isdir(os.path.dirname(save_path)): - os.makedirs(os.path.dirname(save_path)) - dump_to_pkl({'data': data, 'length': length, 'target': target}, save_path) - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Data encoding') - parser.add_argument("data_path", type=str) - parser.add_argument("save_path", type=str) - parser.add_argument("--conf_path", type=str, default='conf.json', help="configuration path") - parser.add_argument("--debug", type=bool, default=False) - parser.add_argument("--force", type=bool, default=False) - - log_set('encoding_data.log') - - params, _ = parser.parse_known_args() - - if params.debug is True: - import debugger - main(params, params.data_path, params.save_path) \ No newline at end of file diff --git a/predict.py b/predict.py index 4df23df..3e96f11 100644 --- a/predict.py +++ b/predict.py @@ -14,27 +14,10 @@ def main(params): conf = ModelConf('predict', params.conf_path, version, params, mode=params.mode) - - if ProblemTypes[conf.problem_type] == ProblemTypes.sequence_tagging: - problem = Problem(conf.problem_type, conf.input_types, None, - source_with_start=True, source_with_end=True, source_with_unk=True, source_with_pad=True, - target_with_start=True, target_with_end=True, target_with_unk=True, target_with_pad=True, same_length=True, - with_bos_eos=conf.add_start_end_for_seq, tagging_scheme=conf.tagging_scheme, tokenizer=conf.tokenizer, - remove_stopwords=conf.remove_stopwords, DBC2SBC=conf.DBC2SBC, 
unicode_fix=conf.unicode_fix) - elif ProblemTypes[conf.problem_type] == ProblemTypes.classification \ - or ProblemTypes[conf.problem_type] == ProblemTypes.regression: - problem = Problem(conf.problem_type, conf.input_types, None, - source_with_start=True, source_with_end=True, source_with_unk=True, source_with_pad=True, - target_with_start=False, target_with_end=False, target_with_unk=False, target_with_pad=False, same_length=True, - with_bos_eos=conf.add_start_end_for_seq, tokenizer=conf.tokenizer, remove_stopwords=conf.remove_stopwords, - DBC2SBC=conf.DBC2SBC, unicode_fix=conf.unicode_fix) - elif ProblemTypes[conf.problem_type] == ProblemTypes.mrc: - problem = Problem(conf.problem_type, conf.input_types, - source_with_start=True, source_with_end=True, source_with_unk=True, source_with_pad=True, - target_with_start=False, target_with_end=False, target_with_unk=False, target_with_pad=False, - same_length=False, with_bos_eos=False, tokenizer=conf.tokenizer, - remove_stopwords=conf.remove_stopwords, DBC2SBC=conf.DBC2SBC, unicode_fix=conf.unicode_fix) - + problem = Problem('predict', conf.problem_type, conf.input_types, None, + with_bos_eos=conf.add_start_end_for_seq, tagging_scheme=conf.tagging_scheme, tokenizer=conf.tokenizer, + remove_stopwords=conf.remove_stopwords, DBC2SBC=conf.DBC2SBC, unicode_fix=conf.unicode_fix) + if os.path.isfile(conf.saved_problem_path): problem.load_problem(conf.saved_problem_path) logging.info("Problem loaded!") diff --git a/problem.py b/problem.py index 21c633b..2610da8 100644 --- a/problem.py +++ b/problem.py @@ -26,10 +26,8 @@ import torch.nn as nn class Problem(): - def __init__(self, problem_type, input_types, answer_column_name=None, lowercase=False, - source_with_start=True, source_with_end=True, source_with_unk=True, - source_with_pad=True, target_with_start=False, target_with_end=False, - target_with_unk=True, target_with_pad=True, same_length=True, with_bos_eos=True, + + def __init__(self, phase, problem_type, input_types, answer_column_name=None, lowercase=False, with_bos_eos=True, tagging_scheme=None, tokenizer="nltk", remove_stopwords=False, DBC2SBC=True, unicode_fix=True): """ @@ -50,9 +48,24 @@ def __init__(self, problem_type, input_types, answer_column_name=None, lowercase same_length: with_bos_eos: whether to add bos and eos when encoding """ - self.lowercase = lowercase - self.input_dicts = dict() + # init + source_with_start, source_with_end, source_with_unk, source_with_pad, \ + target_with_start, target_with_end, target_with_unk, target_with_pad, \ + same_length = (True, ) * 9 + if ProblemTypes[problem_type] == ProblemTypes.sequence_tagging: + pass + elif \ + ProblemTypes[problem_type] == ProblemTypes.classification or \ + ProblemTypes[problem_type] == ProblemTypes.regression: + target_with_start, target_with_end, target_with_unk, target_with_pad, same_length = (False, ) * 5 + if phase != 'train': + same_length = True + elif ProblemTypes[problem_type] == ProblemTypes.mrc: + target_with_start, target_with_end, target_with_unk, target_with_pad, same_length = (False, ) * 5 + with_bos_eos = False + + self.lowercase = lowercase self.problem_type = problem_type self.tagging_scheme = tagging_scheme self.with_bos_eos = with_bos_eos @@ -65,6 +78,7 @@ def __init__(self, problem_type, input_types, answer_column_name=None, lowercase self.target_with_unk = target_with_unk self.target_with_pad = target_with_pad + self.input_dicts = dict() for input_type in input_types: self.input_dicts[input_type] = CellDict(with_unk=source_with_unk, 
with_pad=source_with_pad, with_start=source_with_start, with_end=source_with_end) @@ -245,6 +259,10 @@ def build(self, data_path_list, file_columns, input_types, file_with_col_header, Returns: """ + # parameter check + if not word2vec_path: + word_emb_dim, format, file_type, involve_all_words = None, None, None, None + if 'bpe' in input_types: try: bpe_encoder = BPEEncoder(input_types['bpe']['bpe_path']) @@ -324,51 +342,65 @@ def build(self, data_path_list, file_columns, input_types, file_with_col_header, return word_emb_matrix - def encode_data_multi_processor(self, data_generator, cpu_num_workers, file_columns, input_types, object_inputs, - answer_column_name, min_sentence_len, extra_feature, max_lengths=None, fixed_lengths=None, file_format="tsv", bpe_encoder=None): + @staticmethod + def _merge_encode_data(dest_dict, src_dict): + if len(dest_dict) == 0: + dest_dict = src_dict + else: + for branch in src_dict: + for input_type in dest_dict[branch]: + dest_dict[branch][input_type].extend(src_dict[branch][input_type]) + return dest_dict + + @staticmethod + def _merge_encode_lengths(dest_dict, src_dict): def judge_dict(obj): return True if isinstance(obj, dict) else False - cnt_legal, cnt_illegal = 0, 0 - output_data = dict() - lengths = dict() - target = dict() - for data in tqdm(data_generator): + + if len(dest_dict) == 0: + dest_dict = src_dict + else: + for branch in src_dict: + if judge_dict(src_dict[branch]): + for type_branch in src_dict[branch]: + dest_dict[branch][type_branch].extend(src_dict[branch][type_branch]) + else: + dest_dict[branch].extend(src_dict[branch]) + return dest_dict + + @staticmethod + def _merge_target(dest_dict, src_dict): + if not src_dict: + return src_dict + + if len(dest_dict) == 0: + dest_dict = src_dict + else: + for single_type in src_dict: + dest_dict[single_type].extend(src_dict[single_type]) + return dest_dict + + def encode_data_multi_processor(self, data_generator, cpu_num_workers, file_columns, input_types, object_inputs, + answer_column_name, min_sentence_len, extra_feature, max_lengths=None, fixed_lengths=None, file_format="tsv", bpe_encoder=None): + + + for data in data_generator: scheduler = ProcessorsScheduler(cpu_num_workers) func_args = (data, file_columns, input_types, object_inputs, answer_column_name, min_sentence_len, extra_feature, max_lengths, fixed_lengths, file_format, bpe_encoder) res = scheduler.run_data_parallel(self.encode_data_list, func_args) + output_data, lengths, target = dict(), dict(), dict() + cnt_legal, cnt_illegal = 0, 0 for (index, j) in res: # logging.info("collect proccesor %d result"%index) tmp_data, tmp_lengths, tmp_target, tmp_cnt_legal, tmp_cnt_illegal = j.get() - - if len(output_data) == 0: - output_data = tmp_data - else: - for branch in tmp_data: - for input_type in output_data[branch]: - output_data[branch][input_type].extend(tmp_data[branch][input_type]) - if len(lengths) == 0: - lengths = tmp_lengths - else: - for branch in tmp_lengths: - if judge_dict(tmp_lengths[branch]): - for type_branch in tmp_lengths[branch]: - lengths[branch][type_branch].extend(tmp_lengths[branch][type_branch]) - else: - lengths[branch].extend(tmp_lengths[branch]) - if not tmp_target: - target = None - else: - if len(target) == 0: - target = tmp_target - else: - for single_type in tmp_target: - target[single_type].extend(tmp_target[single_type]) + output_data = self._merge_encode_data(output_data, tmp_data) + lengths = self._merge_encode_lengths(lengths, tmp_lengths) + target = self._merge_target(target, tmp_target) cnt_legal += 
tmp_cnt_legal cnt_illegal += tmp_cnt_illegal - - return output_data, lengths, target, cnt_legal, cnt_illegal + yield output_data, lengths, target, cnt_legal, cnt_illegal def encode_data_list(self, data_list, file_columns, input_types, object_inputs, answer_column_name, min_sentence_len, extra_feature, max_lengths=None, fixed_lengths=None, file_format="tsv", bpe_encoder=None): @@ -678,9 +710,19 @@ def encode(self, data_path, file_columns, input_types, file_with_col_header, obj bpe_encoder = None progress = self.get_data_generator_from_file([data_path], file_with_col_header) - data, lengths, target, cnt_legal, cnt_illegal = self.encode_data_multi_processor(progress, cpu_num_workers, + encoder_generator = self.encode_data_multi_processor(progress, cpu_num_workers, file_columns, input_types, object_inputs, answer_column_name, min_sentence_len, extra_feature, max_lengths, fixed_lengths, file_format, bpe_encoder=bpe_encoder) + + data, lengths, target = dict(), dict(), dict() + cnt_legal, cnt_illegal = 0, 0 + for temp_data, temp_lengths, temp_target, temp_cnt_legal, temp_cnt_illegal in tqdm(encoder_generator): + data = self._merge_encode_data(data, temp_data) + lengths = self._merge_encode_lengths(lengths, temp_lengths) + target = self._merge_target(target, temp_target) + cnt_legal += temp_cnt_legal + cnt_illegal += temp_cnt_illegal + logging.info("%s: %d legal samples, %d illegal samples" % (data_path, cnt_legal, cnt_illegal)) return data, lengths, target diff --git a/test.py b/test.py index b4527f6..a1ef8c8 100644 --- a/test.py +++ b/test.py @@ -16,27 +16,10 @@ def main(params): conf = ModelConf("test", params.conf_path, version, params, mode=params.mode) - - if ProblemTypes[conf.problem_type] == ProblemTypes.sequence_tagging: - problem = Problem(conf.problem_type, conf.input_types, conf.answer_column_name, - source_with_start=True, source_with_end=True, source_with_unk=True, source_with_pad=True, - target_with_start=True, target_with_end=True, target_with_unk=True, target_with_pad=True, same_length=True, - with_bos_eos=conf.add_start_end_for_seq, tagging_scheme=conf.tagging_scheme, tokenizer=conf.tokenizer, - remove_stopwords=conf.remove_stopwords, DBC2SBC=conf.DBC2SBC, unicode_fix=conf.unicode_fix) - elif ProblemTypes[conf.problem_type] == ProblemTypes.classification \ - or ProblemTypes[conf.problem_type] == ProblemTypes.regression: - problem = Problem(conf.problem_type, conf.input_types, conf.answer_column_name, - source_with_start=True, source_with_end=True, source_with_unk=True, source_with_pad=True, - target_with_start=False, target_with_end=False, target_with_unk=False, target_with_pad=False, same_length=True, - with_bos_eos=conf.add_start_end_for_seq, tokenizer=conf.tokenizer, remove_stopwords=conf.remove_stopwords, - DBC2SBC=conf.DBC2SBC, unicode_fix=conf.unicode_fix) - elif ProblemTypes[conf.problem_type] == ProblemTypes.mrc: - problem = Problem(conf.problem_type, conf.input_types, conf.answer_column_name, - source_with_start=True, source_with_end=True, source_with_unk=True, source_with_pad=True, - target_with_start=False, target_with_end=False, target_with_unk=False, target_with_pad=False, - same_length=False, with_bos_eos=False, tokenizer=conf.tokenizer, - remove_stopwords=conf.remove_stopwords, DBC2SBC=conf.DBC2SBC, unicode_fix=conf.unicode_fix) - + problem = Problem("test", conf.problem_type, conf.input_types, conf.answer_column_name, + with_bos_eos=conf.add_start_end_for_seq, tagging_scheme=conf.tagging_scheme, tokenizer=conf.tokenizer, + remove_stopwords=conf.remove_stopwords, 
DBC2SBC=conf.DBC2SBC, unicode_fix=conf.unicode_fix) + if os.path.isfile(conf.saved_problem_path): problem.load_problem(conf.saved_problem_path) logging.info("Problem loaded!") diff --git a/train.py b/train.py index 4d169ab..fedbf5d 100644 --- a/train.py +++ b/train.py @@ -21,202 +21,207 @@ from LearningMachine import LearningMachine +class Cache: + def __init__(self): + self.dictionary_invalid = True + self.embedding_invalid = True + self.encoding_invalid = True + + def _check_dictionary(self, conf, params): + # init status + self.dictionary_invalid = True + self.embedding_invalid = True + + # cache_conf + cache_conf = None + cache_conf_path = os.path.join(conf.cache_dir, 'conf_cache.json') + if os.path.isfile(cache_conf_path): + params_cache = copy.deepcopy(params) + try: + cache_conf = ModelConf('cache', cache_conf_path, version, params_cache) + except Exception as e: + cache_conf = None + if cache_conf is None or not self._verify_conf(cache_conf, conf): + return False + + # problem + if not os.path.isfile(conf.problem_path): + return False + + # embedding + if conf.emb_pkl_path: + if not os.path.isfile(conf.emb_pkl_path): + return False + self.embedding_invalid = False + + self.dictionary_invalid = False + return True + + def _check_encoding(self, conf): + self.encoding_invalid = False + return True + + def check(self, conf, params): + # dictionary + if not self._check_dictionary(conf, params): + self._renew_cache(params, conf.cache_dir) + return + # encoding + if not self._check_encoding(conf): + self._renew_cache(params, conf.cache_dir) -def verify_cache(cache_conf, cur_conf): - """ To verify if the cache is appliable to current configuration - - Args: - cache_conf (ModelConf): - cur_conf (ModelConf): - - Returns: - - """ - if cache_conf.tool_version != cur_conf.tool_version: - return False - - attribute_to_cmp = ['file_columns', 'object_inputs', 'answer_column_name', 'input_types'] - - flag = True - for attr in attribute_to_cmp: - if not (hasattr(cache_conf, attr) and hasattr(cur_conf, attr) and getattr(cache_conf, attr) == getattr(cur_conf, attr)): - logging.error('configuration %s is inconsistent with the old cache' % attr) - flag = False - return flag - - -def main(params): - conf = ModelConf("train", params.conf_path, version, params, mode=params.mode) - - shutil.copy(params.conf_path, conf.save_base_dir) - logging.info('Configuration file is backed up to %s' % (conf.save_base_dir)) - - if ProblemTypes[conf.problem_type] == ProblemTypes.sequence_tagging: - problem = Problem(conf.problem_type, conf.input_types, conf.answer_column_name, - source_with_start=True, source_with_end=True, source_with_unk=True, source_with_pad=True, - target_with_start=True, target_with_end=True, target_with_unk=True, target_with_pad=True, same_length=True, - with_bos_eos=conf.add_start_end_for_seq, tagging_scheme=conf.tagging_scheme, tokenizer=conf.tokenizer, - remove_stopwords=conf.remove_stopwords, DBC2SBC=conf.DBC2SBC, unicode_fix=conf.unicode_fix) - elif ProblemTypes[conf.problem_type] == ProblemTypes.classification \ - or ProblemTypes[conf.problem_type] == ProblemTypes.regression: - problem = Problem(conf.problem_type, conf.input_types, conf.answer_column_name, - source_with_start=True, source_with_end=True, source_with_unk=True, source_with_pad=True, - target_with_start=False, target_with_end=False, target_with_unk=False, target_with_pad=False, - same_length=False, with_bos_eos=conf.add_start_end_for_seq, tokenizer=conf.tokenizer, - remove_stopwords=conf.remove_stopwords, DBC2SBC=conf.DBC2SBC, 
unicode_fix=conf.unicode_fix) - elif ProblemTypes[conf.problem_type] == ProblemTypes.mrc: - problem = Problem(conf.problem_type, conf.input_types, conf.answer_column_name, - source_with_start=True, source_with_end=True, source_with_unk=True, source_with_pad=True, - target_with_start=False, target_with_end=False, target_with_unk=False, target_with_pad=False, - same_length=False, with_bos_eos=False, tokenizer=conf.tokenizer, - remove_stopwords=conf.remove_stopwords, DBC2SBC=conf.DBC2SBC, unicode_fix=conf.unicode_fix) - - cache_load_flag = False - if not conf.pretrained_model_path: - # first time training, load cache if appliable - if conf.use_cache: - cache_conf_path = os.path.join(conf.cache_dir, 'conf_cache.json') - if os.path.isfile(cache_conf_path): - params_cache = copy.deepcopy(params) - ''' - for key in vars(params_cache): - setattr(params_cache, key, None) - params_cache.mode = params.mode - ''' - try: - cache_conf = ModelConf('cache', cache_conf_path, version, params_cache) - except Exception as e: - cache_conf = None - if cache_conf is None or verify_cache(cache_conf, conf) is not True: - logging.info('Found cache that is ineffective') - if params.mode == 'philly' or params.force is True: - renew_option = 'yes' - else: - renew_option = input('There exists ineffective cache %s for old models. Input "yes" to renew cache and "no" to exit. (default:no): ' % os.path.abspath(conf.cache_dir)) - if renew_option.lower() != 'yes': - exit(0) - else: - shutil.rmtree(conf.cache_dir) - time.sleep(2) # sleep 2 seconds since the deleting is asynchronous - logging.info('Old cache is deleted') - else: - logging.info('Found cache that is appliable to current configuration...') - - elif os.path.isdir(conf.cache_dir): - renew_option = input('There exists ineffective cache %s for old models. Input "yes" to renew cache and "no" to exit. (default:no): ' % os.path.abspath(conf.cache_dir)) - if renew_option.lower() != 'yes': - exit(0) - else: - shutil.rmtree(conf.cache_dir) - time.sleep(2) # Sleep 2 seconds since the deleting is asynchronous - logging.info('Old cache is deleted') - - if not os.path.exists(conf.cache_dir): - os.makedirs(conf.cache_dir) - shutil.copy(params.conf_path, os.path.join(conf.cache_dir, 'conf_cache.json')) - - # first time training, load problem from cache, and then backup the cache to model_save_dir/.necessary_cache/ - if conf.use_cache and os.path.isfile(conf.problem_path): + def load(self, conf, problem, emb_matrix): + # load dictionary when (not finetune) and (cache valid) + if not conf.pretrained_model_path and not self.dictionary_invalid: problem.load_problem(conf.problem_path) - if conf.emb_pkl_path is not None: - if os.path.isfile(conf.emb_pkl_path): - emb_matrix = np.array(load_from_pkl(conf.emb_pkl_path)) - cache_load_flag = True - else: - if params.mode == 'normal': - renew_option = input('The cache is invalid because the embedding matrix does not exist in the cache directory. Input "yes" to renew cache and "no" to exit. (default:no): ') - if renew_option.lower() != 'yes': - exit(0) - else: - # by default, renew cache - renew_option = 'yes' - else: - emb_matrix = None - cache_load_flag = True - if cache_load_flag: - logging.info("Cache loaded!") - - if cache_load_flag is False: - logging.info("Preprocessing... 
Depending on your corpus size, this step may take a while.") - # modify train_data_path to [train_data_path, valid_data_path, test_data_path] - # remember the test_data may be None - data_path_list = [conf.train_data_path, conf.valid_data_path, conf.test_data_path] - if conf.pretrained_emb_path: - emb_matrix = problem.build(data_path_list, conf.file_columns, conf.input_types, conf.file_with_col_header, - conf.answer_column_name, word2vec_path=conf.pretrained_emb_path, - word_emb_dim=conf.pretrained_emb_dim, format=conf.pretrained_emb_type, - file_type=conf.pretrained_emb_binary_or_text, involve_all_words=conf.involve_all_words_in_pretrained_emb, - show_progress=True if params.mode == 'normal' else False, cpu_num_workers = conf.cpu_num_workers, - max_vocabulary=conf.max_vocabulary, word_frequency=conf.min_word_frequency) - else: - emb_matrix = problem.build(data_path_list, conf.file_columns, conf.input_types, conf.file_with_col_header, - conf.answer_column_name, word2vec_path=None, word_emb_dim=None, format=None, - file_type=None, involve_all_words=conf.involve_all_words_in_pretrained_emb, - show_progress=True if params.mode == 'normal' else False, cpu_num_workers = conf.cpu_num_workers, - max_vocabulary=conf.max_vocabulary, word_frequency=conf.min_word_frequency) - + if not self.embedding_invalid: + emb_matrix = np.array(load_from_pkl(conf.emb_pkl_path)) + logging.info('[Cache] loading dictionary successfully') + + if not self.encoding_invalid: + pass + return problem, emb_matrix + + def save(self, conf, params, problem, emb_matrix): + if not os.path.exists(conf.cache_dir): + os.makedirs(conf.cache_dir) + shutil.copy(params.conf_path, os.path.join(conf.cache_dir, 'conf_cache.json')) + if self.dictionary_invalid: if conf.mode == 'philly' and conf.emb_pkl_path.startswith('/hdfs/'): with HDFSDirectTransferer(conf.problem_path, with_hdfs_command=True) as transferer: transferer.pkl_dump(problem.export_problem(conf.problem_path, ret_without_save=True)) else: problem.export_problem(conf.problem_path) - if conf.use_cache: - logging.info("Cache saved to %s" % conf.problem_path) - if emb_matrix is not None and conf.emb_pkl_path is not None: - if conf.mode == 'philly' and conf.emb_pkl_path.startswith('/hdfs/'): - with HDFSDirectTransferer(conf.emb_pkl_path, with_hdfs_command=True) as transferer: - transferer.pkl_dump(emb_matrix) - else: - dump_to_pkl(emb_matrix, conf.emb_pkl_path) - logging.info("Embedding matrix saved to %s" % conf.emb_pkl_path) - else: - logging.debug("Cache saved to %s" % conf.problem_path) + logging.info("[Cache] problem is saved to %s" % conf.problem_path) + if emb_matrix is not None and conf.emb_pkl_path is not None: + if conf.mode == 'philly' and conf.emb_pkl_path.startswith('/hdfs/'): + with HDFSDirectTransferer(conf.emb_pkl_path, with_hdfs_command=True) as transferer: + transferer.pkl_dump(emb_matrix) + else: + dump_to_pkl(emb_matrix, conf.emb_pkl_path) + logging.info("Embedding matrix saved to %s" % conf.emb_pkl_path) + + if self.encoding_invalid: + pass - # Back up the problem.pkl to save_base_dir/.necessary_cache. 
During test phase, we would load cache from save_base_dir/.necessary_cache/problem.pkl
+    def back_up(self, conf, problem):
         cache_bakup_path = os.path.join(conf.save_base_dir, 'necessary_cache/')
         logging.debug('Prepare dir: %s' % cache_bakup_path)
         prepare_dir(cache_bakup_path, True, allow_overwrite=True, clear_dir_if_exist=True)
-        shutil.copy(conf.problem_path, cache_bakup_path)
+        problem.export_problem(cache_bakup_path+'problem.pkl')
         logging.debug("Problem %s is backed up to %s" % (conf.problem_path, cache_bakup_path))
-        if problem.output_dict:
-            logging.debug("Problem target cell dict: %s" % (problem.output_dict.cell_id_map))
-
-        if params.make_cache_only:
-            logging.info("Finish building cache!")
+    def _renew_cache(self, params, cache_path):
+        if not os.path.exists(cache_path):
             return
+        logging.info('Found cache that is ineffective')
+        renew_option = 'yes'
+        if params.mode != 'philly' and params.force is not True:
+            renew_option = input('There exists an ineffective cache %s for old models. Input "yes" to renew the cache or "no" to exit. (default: no): ' % os.path.abspath(cache_path))
+        if renew_option.lower() != 'yes':
+            exit(0)
+        else:
+            shutil.rmtree(cache_path)
+            time.sleep(2)  # sleep 2 seconds since the deletion is asynchronous
+            logging.info('Old cache is deleted')
+
+    def _verify_conf(self, cache_conf, cur_conf):
+        """ Verify whether the cache is applicable to the current configuration
+
+        Args:
+            cache_conf (ModelConf):
+            cur_conf (ModelConf):
+
+        Returns:
+
+        """
+        if cache_conf.tool_version != cur_conf.tool_version:
+            return False
+
+        attribute_to_cmp = ['file_columns', 'object_inputs', 'answer_column_name', 'input_types']
+
+        flag = True
+        for attr in attribute_to_cmp:
+            if not (hasattr(cache_conf, attr) and hasattr(cur_conf, attr) and getattr(cache_conf, attr) == getattr(cur_conf, attr)):
+                logging.error('configuration %s is inconsistent with the old cache' % attr)
+                flag = False
+        return flag
 
-        vocab_info = dict() # include input_type's vocab_size & init_emd_matrix
-        vocab_sizes = problem.get_vocab_sizes()
-        for input_cluster in vocab_sizes:
-            vocab_info[input_cluster] = dict()
-            vocab_info[input_cluster]['vocab_size'] = vocab_sizes[input_cluster]
-            # add extra info for char_emb
-            if input_cluster.lower() == 'char':
-                for key, value in conf.input_types[input_cluster].items():
-                    if key != 'cols':
-                        vocab_info[input_cluster][key] = value
-            if input_cluster == 'word' and emb_matrix is not None:
-                vocab_info[input_cluster]['init_weights'] = emb_matrix
-            else:
-                vocab_info[input_cluster]['init_weights'] = None
-
-        lm = LearningMachine('train', conf, problem, vocab_info=vocab_info, initialize=True, use_gpu=conf.use_gpu)
-    else:
-        # when finetuning, load previous saved problem
+def main(params):
+    # init
+    conf = ModelConf("train", params.conf_path, version, params, mode=params.mode)
+    problem = Problem("train", conf.problem_type, conf.input_types, conf.answer_column_name,
+                      with_bos_eos=conf.add_start_end_for_seq, tagging_scheme=conf.tagging_scheme, tokenizer=conf.tokenizer,
+                      remove_stopwords=conf.remove_stopwords, DBC2SBC=conf.DBC2SBC, unicode_fix=conf.unicode_fix)
+    if conf.pretrained_model_path:
+        ### when finetuning, load the previously saved problem
         problem.load_problem(conf.saved_problem_path)
-        lm = LearningMachine('train', conf, problem, vocab_info=None, initialize=False, use_gpu=conf.use_gpu)
+
+    # cache verification
+    emb_matrix = None
+    cache = Cache()
+    if conf.use_cache:
+        ## check
+        cache.check(conf, params)
+        ## load
+        problem, emb_matrix = cache.load(conf, problem, emb_matrix)
+
+    # data
preprocessing
+    ## build the dictionary when (not in finetune mode) and (the cache is unused or invalid)
+    if (not conf.pretrained_model_path) and ((not conf.use_cache) or cache.dictionary_invalid):
+        logging.info("Preprocessing... Depending on your corpus size, this step may take a while.")
+        # modify train_data_path to [train_data_path, valid_data_path, test_data_path]
+        # remember the test_data may be None
+        data_path_list = [conf.train_data_path, conf.valid_data_path, conf.test_data_path]
+        emb_matrix = problem.build(data_path_list, conf.file_columns, conf.input_types, conf.file_with_col_header,
+                                   conf.answer_column_name, word2vec_path=conf.pretrained_emb_path,
+                                   word_emb_dim=conf.pretrained_emb_dim, format=conf.pretrained_emb_type,
+                                   file_type=conf.pretrained_emb_binary_or_text, involve_all_words=conf.involve_all_words_in_pretrained_emb,
+                                   show_progress=True if params.mode == 'normal' else False, cpu_num_workers=conf.cpu_num_workers,
+                                   max_vocabulary=conf.max_vocabulary, word_frequency=conf.min_word_frequency)
+
+    ## encode raw data when the cache is not used
+    if not conf.use_cache:
+        pass
+
+    # environment preparation
+    ## cache save
+    if conf.use_cache:
+        cache.save(conf, params, problem, emb_matrix)
+
+    if params.make_cache_only:
+        if conf.use_cache:
+            logging.info("Finished building cache!")
+        else:
+            logging.info('Please set the parameter "use_cache" to true')
+        return
+
+    ## back up the configuration file and the problem.pkl to save_base_dir.
+    ## During the test phase, the cache is loaded from save_base_dir/.necessary_cache/problem.pkl
+    conf.back_up(params)
+    cache.back_up(conf, problem)
+    if problem.output_dict:
+        logging.debug("Problem target cell dict: %s" % (problem.output_dict.cell_id_map))
+
+    # train phase
+    ## init
+    ### model
+    vocab_info, initialize = None, False
+    if not conf.pretrained_model_path:
+        vocab_info, initialize = get_vocab_info(conf, problem, emb_matrix), True
+
+    lm = LearningMachine('train', conf, problem, vocab_info=vocab_info, initialize=initialize, use_gpu=conf.use_gpu)
+    if conf.pretrained_model_path:
+        logging.info('Loading the pretrained model: %s...' % conf.pretrained_model_path)
+        lm.load_model(conf.pretrained_model_path)
+    ### loss
     if len(conf.metrics_post_check) > 0:
         for metric_to_chk in conf.metrics_post_check:
             metric, target = metric_to_chk.split('@')
             if not problem.output_dict.has_cell(target):
                 raise Exception("The target %s of %s does not exist in the training data." % (target, metric_to_chk))
-
-    if conf.pretrained_model_path:
-        logging.info('Loading the pretrained model: %s...'
% conf.pretrained_model_path)
-        lm.load_model(conf.pretrained_model_path)
-
     loss_conf = conf.loss
     loss_conf['output_layer_id'] = conf.output_layer_id
     loss_conf['answer_column_name'] = conf.answer_column_name
@@ -225,11 +230,13 @@ def main(params):
     if conf.use_gpu is True:
         loss_fn.cuda()
 
+    ### optimizer
     optimizer = eval(conf.optimizer_name)(lm.model.parameters(), **conf.optimizer_params)
 
+    ## train
     lm.train(optimizer, loss_fn)
 
-    # test the best model with the best model saved
+    ## test with the best model saved
     lm.load_model(conf.model_save_path)
     if conf.test_data_path is not None:
         test_path = conf.test_data_path
@@ -241,6 +248,22 @@ def main(params):
     else:
         lm.test(loss_fn, test_path)
 
+def get_vocab_info(conf, problem, emb_matrix):
+    vocab_info = dict() # include each input_type's vocab_size & init_emb_matrix
+    vocab_sizes = problem.get_vocab_sizes()
+    for input_cluster in vocab_sizes:
+        vocab_info[input_cluster] = dict()
+        vocab_info[input_cluster]['vocab_size'] = vocab_sizes[input_cluster]
+        # add extra info for char_emb
+        if input_cluster.lower() == 'char':
+            for key, value in conf.input_types[input_cluster].items():
+                if key != 'cols':
+                    vocab_info[input_cluster][key] = value
+        if input_cluster == 'word' and emb_matrix is not None:
+            vocab_info[input_cluster]['init_weights'] = emb_matrix
+        else:
+            vocab_info[input_cluster]['init_weights'] = None
+    return vocab_info
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description='Training')
diff --git a/utils/common_utils.py b/utils/common_utils.py
index cc41834..704989e 100644
--- a/utils/common_utils.py
+++ b/utils/common_utils.py
@@ -9,6 +9,7 @@
 import time
 import tempfile
 import subprocess
+import hashlib
 
 def log_set(log_path, console_level='INFO', console_detailed=False, disable_log_file=False):
     """
@@ -216,3 +217,23 @@ def prepare_dir(path, is_dir, allow_overwrite=False, clear_dir_if_exist=False, e
         overwrite_option = input('The file %s already exists, input "yes" to allow us to overwrite it or "no" to exit. (default:no): ' % path)
         if overwrite_option.lower() != 'yes':
             exit(0)
+
+def md5(file_paths, chunk_size=1024*1024*1024):
+    """ Calculate the md5 of a list of files.
+
+    Args:
+        file_paths: an iterable of file paths; the files are hashed in order, as if concatenated into one file
+        chunk_size: read buffer size in bytes; the default is 1GB
+    Returns:
+        the hex digest of the md5
+
+    """
+    hasher = hashlib.md5()
+    for path in file_paths:
+        with open(path, 'rb') as fin:
+            while True:
+                data = fin.read(chunk_size)
+                if not data:
+                    break
+                hasher.update(data)
+    return hasher.hexdigest()
\ No newline at end of file
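
Note on the Problem refactor above: the constructor now derives the source/target
dictionary flags from the phase and the problem type, instead of train.py, test.py
and predict.py each spelling out a dozen booleans. A minimal before/after sketch
for a classification run, using only names that appear in this diff:

    # before this patch: every flag was passed explicitly at each call site
    problem = Problem(conf.problem_type, conf.input_types, conf.answer_column_name,
                      source_with_start=True, source_with_end=True, source_with_unk=True, source_with_pad=True,
                      target_with_start=False, target_with_end=False, target_with_unk=False, target_with_pad=False,
                      same_length=True, with_bos_eos=conf.add_start_end_for_seq, tokenizer=conf.tokenizer,
                      remove_stopwords=conf.remove_stopwords, DBC2SBC=conf.DBC2SBC, unicode_fix=conf.unicode_fix)

    # after: the phase ('train', 'test' or 'predict') and conf.problem_type
    # select those flags inside Problem.__init__
    problem = Problem('train', conf.problem_type, conf.input_types, conf.answer_column_name,
                      with_bos_eos=conf.add_start_end_for_seq, tagging_scheme=conf.tagging_scheme,
                      tokenizer=conf.tokenizer, remove_stopwords=conf.remove_stopwords,
                      DBC2SBC=conf.DBC2SBC, unicode_fix=conf.unicode_fix)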
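The three _merge_* helpers factored out of encode_data_multi_processor share one
pattern: the first chunk seeds the dict, and every later chunk extends it list by
list. A small worked example for _merge_encode_data (the 'query'/'word' keys are
illustrative only, not fixed by this patch):

    merged = dict()
    chunk1 = {'query': {'word': [[1, 2]]}}
    chunk2 = {'query': {'word': [[3, 4, 5]]}}
    merged = Problem._merge_encode_data(merged, chunk1)  # empty dest: chunk1 seeds it
    merged = Problem._merge_encode_data(merged, chunk2)  # extends each inner list
    assert merged == {'query': {'word': [[1, 2], [3, 4, 5]]}}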
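Cache._check_encoding is still a stub in this patch: it unconditionally marks the
encoding cache valid. The new utils.common_utils.md5 helper gives the tool a way to
fingerprint input files for exactly this kind of check; for example,
md5([conf.train_data_path]) hashes the training file in 1GB chunks. Below is a
hypothetical completion of the stub along those lines; the stored-checksum
bookkeeping (the encoding_md5_path attribute and the pickled digest) is invented
for illustration and is not part of this patch:

    from utils.common_utils import md5, load_from_pkl  # md5 is added by this patch

    def _check_encoding(self, conf):
        # hypothetical: compare a stored fingerprint of the raw data files against
        # the current files; any mismatch invalidates the encoding cache
        data_paths = [p for p in (conf.train_data_path, conf.valid_data_path, conf.test_data_path) if p]
        if not os.path.isfile(conf.encoding_md5_path):  # assumed attribute, not in this patch
            return False
        if load_from_pkl(conf.encoding_md5_path) != md5(data_paths):
            return False
        self.encoding_invalid = False
        return True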