Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reorg logical flow in train #37

Merged
merged 38 commits into from
May 22, 2019
Merged
Show file tree
Hide file tree
Changes from 34 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
674526d
Add new config about knowledge distillation for query binary classifier
Apr 25, 2019
59d6318
remove inferenced result in knowledge distillation for query binary c…
chengfx Apr 26, 2019
b4c110e
Add AUC.py in tools folder
chengfx Apr 26, 2019
891f43a
Add test_data_path into conf_kdqbc_bilstmattn_cnn.json
chengfx Apr 28, 2019
8b1d100
Modify AUC.py
chengfx Apr 28, 2019
333bd98
Rename AUC.py into calculate_AUC.py
chengfx Apr 28, 2019
b6523a7
Merge branch 'master' into dev/fecheng
chengfx Apr 28, 2019
74976c2
Modify test&calculate AUC commands for Knowledge Distillation for Que…
chengfx Apr 28, 2019
936d9fe
Merge branch 'master' into dev/fecheng
chengfx Apr 28, 2019
8c6e61b
Add cpu_thread_num parameter in conf.training_params
chengfx Apr 29, 2019
69c0bca
Rename cpu_thread_num into cpu_num_workers
chengfx Apr 29, 2019
fb11aba
update comments in ModelConf.py
chengfx Apr 29, 2019
bbfcde2
Add cpu_num_workers in model_zoo/advanced/conf.json
chengfx Apr 29, 2019
153acd3
Add the description of cpu_num_workers in Tutorial.md
chengfx Apr 29, 2019
4c9380c
fix conflict
chengfx Apr 29, 2019
2ae9d4a
Merge branch 'master' into dev/fecheng
chengfx May 6, 2019
cff4cd3
Update inference speed of compressed model
chengfx May 6, 2019
cf534ce
Add ProcessorsScheduler Class
chengfx May 7, 2019
37d09d5
Merge branch 'master' into dev/fecheng
chengfx May 7, 2019
17b8447
Add license in ProcessorScheduler.py
chengfx May 7, 2019
e087427
use lazy loading instead of one-off loading
chengfx May 8, 2019
1fb0440
merge master
chengfx May 8, 2019
05ddcf8
Remove Debug Info in problem.py
chengfx May 8, 2019
af6ea60
use open instead of codecs.open
chengfx May 9, 2019
535649e
Merge branch 'master' into dev/fecheng
chengfx May 9, 2019
fb4e47b
update the inference of build dictionary for classification
chengfx May 9, 2019
a3a0c25
add md5 function in common_utils.py
chengfx May 9, 2019
889aa91
add merge_encode_* function
chengfx May 10, 2019
bab7f54
update typo
chengfx May 10, 2019
576b88d
update typo
chengfx May 10, 2019
91440a5
reorg the logical flow in train.py
chengfx May 11, 2019
5a747e6
Merge branch 'add_encoding_cache' into dev/fecheng
chengfx May 11, 2019
229622b
merge master
chengfx May 11, 2019
49a32fe
remove dummy comments in problem.py
chengfx May 11, 2019
627a80f
enumerate problem types in problem.py
chengfx May 15, 2019
fbf780d
remove data_encoding.py
chengfx May 16, 2019
c735b45
Modify comment and remove debug code
chengfx May 18, 2019
d6566d4
merge master
chengfx May 22, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@
*.pyc
*.cache*
dataset/GloVe/
dataset/20_newsgroups/
models/
9 changes: 7 additions & 2 deletions ModelConf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import copy
import torch
import logging
import shutil

from losses.BaseLossConf import BaseLossConf
#import traceback
Expand Down Expand Up @@ -60,8 +61,8 @@ def load_from_file(self, conf_path):
self.tool_version = self.get_item(['tool_version'])
self.language = self.get_item(['language'], default='english').lower()
self.problem_type = self.get_item(['inputs', 'dataset_type']).lower()
if ProblemTypes[self.problem_type] == ProblemTypes.sequence_tagging:
self.tagging_scheme = self.get_item(['inputs', 'tagging_scheme'], default=None, use_default=True)
#if ProblemTypes[self.problem_type] == ProblemTypes.sequence_tagging:
self.tagging_scheme = self.get_item(['inputs', 'tagging_scheme'], default=None, use_default=True)

if self.mode == 'normal':
self.use_cache = self.get_item(['inputs', 'use_cache'], True)
Expand Down Expand Up @@ -519,3 +520,7 @@ def check_version_compat(self, nb_version, conf_version):
if not (nb_version_split[0] == conf_version_split[0] and nb_version_split[1] == conf_version_split[1]):
raise ConfigurationError('The NeuronBlocks version is %s, but the configuration version is %s, please update your configuration to %s.%s.X' % (nb_version, conf_version, nb_version_split[0], nb_version_split[1]))

def back_up(self, params):
    """Back up the configuration file into the model save directory.

    Args:
        params: parsed command-line arguments; only ``params.conf_path``
            (path of the configuration file) is read here.
    """
    backup_dir = self.save_base_dir
    shutil.copy(params.conf_path, backup_dir)
    logging.info('Configuration file is backed up to %s', backup_dir)

25 changes: 4 additions & 21 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,10 @@

def main(params):
conf = ModelConf('predict', params.conf_path, version, params, mode=params.mode)

if ProblemTypes[conf.problem_type] == ProblemTypes.sequence_tagging:
problem = Problem(conf.problem_type, conf.input_types, None,
source_with_start=True, source_with_end=True, source_with_unk=True, source_with_pad=True,
target_with_start=True, target_with_end=True, target_with_unk=True, target_with_pad=True, same_length=True,
with_bos_eos=conf.add_start_end_for_seq, tagging_scheme=conf.tagging_scheme, tokenizer=conf.tokenizer,
remove_stopwords=conf.remove_stopwords, DBC2SBC=conf.DBC2SBC, unicode_fix=conf.unicode_fix)
elif ProblemTypes[conf.problem_type] == ProblemTypes.classification \
or ProblemTypes[conf.problem_type] == ProblemTypes.regression:
problem = Problem(conf.problem_type, conf.input_types, None,
source_with_start=True, source_with_end=True, source_with_unk=True, source_with_pad=True,
target_with_start=False, target_with_end=False, target_with_unk=False, target_with_pad=False, same_length=True,
with_bos_eos=conf.add_start_end_for_seq, tokenizer=conf.tokenizer, remove_stopwords=conf.remove_stopwords,
DBC2SBC=conf.DBC2SBC, unicode_fix=conf.unicode_fix)
elif ProblemTypes[conf.problem_type] == ProblemTypes.mrc:
problem = Problem(conf.problem_type, conf.input_types,
source_with_start=True, source_with_end=True, source_with_unk=True, source_with_pad=True,
target_with_start=False, target_with_end=False, target_with_unk=False, target_with_pad=False,
same_length=False, with_bos_eos=False, tokenizer=conf.tokenizer,
remove_stopwords=conf.remove_stopwords, DBC2SBC=conf.DBC2SBC, unicode_fix=conf.unicode_fix)

problem = Problem('predict', conf.problem_type, conf.input_types, None,
with_bos_eos=conf.add_start_end_for_seq, tagging_scheme=conf.tagging_scheme, tokenizer=conf.tokenizer,
remove_stopwords=conf.remove_stopwords, DBC2SBC=conf.DBC2SBC, unicode_fix=conf.unicode_fix)

if os.path.isfile(conf.saved_problem_path):
problem.load_problem(conf.saved_problem_path)
logging.info("Problem loaded!")
Expand Down
117 changes: 77 additions & 40 deletions problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,8 @@
import torch.nn as nn

class Problem():
def __init__(self, problem_type, input_types, answer_column_name=None, lowercase=False,
source_with_start=True, source_with_end=True, source_with_unk=True,
source_with_pad=True, target_with_start=False, target_with_end=False,
target_with_unk=True, target_with_pad=True, same_length=True, with_bos_eos=True,

def __init__(self, phase, problem_type, input_types, answer_column_name=None, lowercase=False, with_bos_eos=True,
tagging_scheme=None, tokenizer="nltk", remove_stopwords=False, DBC2SBC=True, unicode_fix=True):
"""

Expand All @@ -50,9 +48,19 @@ def __init__(self, problem_type, input_types, answer_column_name=None, lowercase
same_length:
with_bos_eos: whether to add bos and eos when encoding
"""
self.lowercase = lowercase

self.input_dicts = dict()
# init
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Readability is not good, I think. I strongly suggest enumerating every task.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @woailaosang, the readability work here is focused on the train logical flow; I did less work on the other modules. Yes, enumerating every task is a good idea, but I think we first need to optimize the code to reduce the repeated code and logic. Otherwise you have to modify every place whenever you want to make a change.

source_with_start, source_with_end, source_with_unk, source_with_pad, \
target_with_start, target_with_end, target_with_unk, target_with_pad, \
same_length = (True, ) * 9
if ProblemTypes[problem_type] != ProblemTypes.sequence_tagging:
target_with_start, target_with_end, target_with_unk, target_with_pad, same_length = (False, ) * 5
if phase != 'train':
same_length = True
if ProblemTypes[problem_type] == ProblemTypes.mrc:
with_bos_eos = False

self.lowercase = lowercase
self.problem_type = problem_type
self.tagging_scheme = tagging_scheme
self.with_bos_eos = with_bos_eos
Expand All @@ -65,6 +73,7 @@ def __init__(self, problem_type, input_types, answer_column_name=None, lowercase
self.target_with_unk = target_with_unk
self.target_with_pad = target_with_pad

self.input_dicts = dict()
for input_type in input_types:
self.input_dicts[input_type] = CellDict(with_unk=source_with_unk, with_pad=source_with_pad,
with_start=source_with_start, with_end=source_with_end)
Expand Down Expand Up @@ -242,6 +251,10 @@ def build(self, training_data_path, file_columns, input_types, file_with_col_hea
Returns:

"""
# parameter check
if not word2vec_path:
word_emb_dim, format, file_type, involve_all_words = None, None, None, None

if 'bpe' in input_types:
try:
bpe_encoder = BPEEncoder(input_types['bpe']['bpe_path'])
Expand Down Expand Up @@ -321,51 +334,65 @@ def build(self, training_data_path, file_columns, input_types, file_with_col_hea

return word_emb_matrix

def encode_data_multi_processor(self, data_generator, cpu_num_workers, file_columns, input_types, object_inputs,
answer_column_name, min_sentence_len, extra_feature, max_lengths=None, fixed_lengths=None, file_format="tsv", bpe_encoder=None):
@staticmethod
def _merge_encode_data(dest_dict, src_dict):
if len(dest_dict) == 0:
dest_dict = src_dict
else:
for branch in src_dict:
for input_type in dest_dict[branch]:
dest_dict[branch][input_type].extend(src_dict[branch][input_type])
return dest_dict

@staticmethod
def _merge_encode_lengths(dest_dict, src_dict):
def judge_dict(obj):
return True if isinstance(obj, dict) else False
cnt_legal, cnt_illegal = 0, 0
output_data = dict()
lengths = dict()
target = dict()
for data in tqdm(data_generator):

if len(dest_dict) == 0:
dest_dict = src_dict
else:
for branch in src_dict:
if judge_dict(src_dict[branch]):
for type_branch in src_dict[branch]:
dest_dict[branch][type_branch].extend(src_dict[branch][type_branch])
else:
dest_dict[branch].extend(src_dict[branch])
return dest_dict

@staticmethod
def _merge_target(dest_dict, src_dict):
if not src_dict:
return src_dict

if len(dest_dict) == 0:
dest_dict = src_dict
else:
for single_type in src_dict:
dest_dict[single_type].extend(src_dict[single_type])
return dest_dict

def encode_data_multi_processor(self, data_generator, cpu_num_workers, file_columns, input_types, object_inputs,
answer_column_name, min_sentence_len, extra_feature, max_lengths=None, fixed_lengths=None, file_format="tsv", bpe_encoder=None):


for data in data_generator:
scheduler = ProcessorsScheduler(cpu_num_workers)
func_args = (data, file_columns, input_types, object_inputs,
answer_column_name, min_sentence_len, extra_feature, max_lengths, fixed_lengths, file_format, bpe_encoder)
res = scheduler.run_data_parallel(self.encode_data_list, func_args)

output_data, lengths, target = dict(), dict(), dict()
cnt_legal, cnt_illegal = 0, 0
for (index, j) in res:
# logging.info("collect proccesor %d result"%index)
tmp_data, tmp_lengths, tmp_target, tmp_cnt_legal, tmp_cnt_illegal = j.get()

if len(output_data) == 0:
output_data = tmp_data
else:
for branch in tmp_data:
for input_type in output_data[branch]:
output_data[branch][input_type].extend(tmp_data[branch][input_type])
if len(lengths) == 0:
lengths = tmp_lengths
else:
for branch in tmp_lengths:
if judge_dict(tmp_lengths[branch]):
for type_branch in tmp_lengths[branch]:
lengths[branch][type_branch].extend(tmp_lengths[branch][type_branch])
else:
lengths[branch].extend(tmp_lengths[branch])
if not tmp_target:
target = None
else:
if len(target) == 0:
target = tmp_target
else:
for single_type in tmp_target:
target[single_type].extend(tmp_target[single_type])
output_data = self._merge_encode_data(output_data, tmp_data)
lengths = self._merge_encode_lengths(lengths, tmp_lengths)
target = self._merge_target(target, tmp_target)
cnt_legal += tmp_cnt_legal
cnt_illegal += tmp_cnt_illegal

return output_data, lengths, target, cnt_legal, cnt_illegal
yield output_data, lengths, target, cnt_legal, cnt_illegal

def encode_data_list(self, data_list, file_columns, input_types, object_inputs, answer_column_name, min_sentence_len,
extra_feature, max_lengths=None, fixed_lengths=None, file_format="tsv", bpe_encoder=None):
Expand Down Expand Up @@ -675,9 +702,19 @@ def encode(self, data_path, file_columns, input_types, file_with_col_header, obj
bpe_encoder = None

progress = self.get_data_generator_from_file(data_path, file_with_col_header)
data, lengths, target, cnt_legal, cnt_illegal = self.encode_data_multi_processor(progress, cpu_num_workers,
encoder_generator = self.encode_data_multi_processor(progress, cpu_num_workers,
file_columns, input_types, object_inputs, answer_column_name, min_sentence_len, extra_feature, max_lengths,
fixed_lengths, file_format, bpe_encoder=bpe_encoder)

data, lengths, target = dict(), dict(), dict()
cnt_legal, cnt_illegal = 0, 0
for temp_data, temp_lengths, temp_target, temp_cnt_legal, temp_cnt_illegal in tqdm(encoder_generator):
data = self._merge_encode_data(data, temp_data)
lengths = self._merge_encode_lengths(lengths, temp_lengths)
target = self._merge_target(target, temp_target)
cnt_legal += temp_cnt_legal
cnt_illegal += temp_cnt_illegal

logging.info("%s: %d legal samples, %d illegal samples" % (data_path, cnt_legal, cnt_illegal))
return data, lengths, target

Expand Down
25 changes: 4 additions & 21 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,10 @@

def main(params):
conf = ModelConf("test", params.conf_path, version, params, mode=params.mode)

if ProblemTypes[conf.problem_type] == ProblemTypes.sequence_tagging:
problem = Problem(conf.problem_type, conf.input_types, conf.answer_column_name,
source_with_start=True, source_with_end=True, source_with_unk=True, source_with_pad=True,
target_with_start=True, target_with_end=True, target_with_unk=True, target_with_pad=True, same_length=True,
with_bos_eos=conf.add_start_end_for_seq, tagging_scheme=conf.tagging_scheme, tokenizer=conf.tokenizer,
remove_stopwords=conf.remove_stopwords, DBC2SBC=conf.DBC2SBC, unicode_fix=conf.unicode_fix)
elif ProblemTypes[conf.problem_type] == ProblemTypes.classification \
or ProblemTypes[conf.problem_type] == ProblemTypes.regression:
problem = Problem(conf.problem_type, conf.input_types, conf.answer_column_name,
source_with_start=True, source_with_end=True, source_with_unk=True, source_with_pad=True,
target_with_start=False, target_with_end=False, target_with_unk=False, target_with_pad=False, same_length=True,
with_bos_eos=conf.add_start_end_for_seq, tokenizer=conf.tokenizer, remove_stopwords=conf.remove_stopwords,
DBC2SBC=conf.DBC2SBC, unicode_fix=conf.unicode_fix)
elif ProblemTypes[conf.problem_type] == ProblemTypes.mrc:
problem = Problem(conf.problem_type, conf.input_types, conf.answer_column_name,
source_with_start=True, source_with_end=True, source_with_unk=True, source_with_pad=True,
target_with_start=False, target_with_end=False, target_with_unk=False, target_with_pad=False,
same_length=False, with_bos_eos=False, tokenizer=conf.tokenizer,
remove_stopwords=conf.remove_stopwords, DBC2SBC=conf.DBC2SBC, unicode_fix=conf.unicode_fix)

problem = Problem("test", conf.problem_type, conf.input_types, conf.answer_column_name,
with_bos_eos=conf.add_start_end_for_seq, tagging_scheme=conf.tagging_scheme, tokenizer=conf.tokenizer,
remove_stopwords=conf.remove_stopwords, DBC2SBC=conf.DBC2SBC, unicode_fix=conf.unicode_fix)

if os.path.isfile(conf.saved_problem_path):
problem.load_problem(conf.saved_problem_path)
logging.info("Problem loaded!")
Expand Down
Loading