Skip to content

Commit

Permalink
Reorg logical flow in train (#37)
Browse files Browse the repository at this point in the history
  • Loading branch information
chengfx authored and ljshou committed May 22, 2019
1 parent 8c82e17 commit dc013c3
Show file tree
Hide file tree
Showing 8 changed files with 317 additions and 319 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Expand Up @@ -2,5 +2,7 @@
*~
*.pyc
*.cache*
*.vs*
dataset/GloVe/
dataset/20_newsgroups/
models/
9 changes: 7 additions & 2 deletions ModelConf.py
Expand Up @@ -10,6 +10,7 @@
import copy
import torch
import logging
import shutil

from losses.BaseLossConf import BaseLossConf
#import traceback
Expand Down Expand Up @@ -60,8 +61,8 @@ def load_from_file(self, conf_path):
self.tool_version = self.get_item(['tool_version'])
self.language = self.get_item(['language'], default='english').lower()
self.problem_type = self.get_item(['inputs', 'dataset_type']).lower()
if ProblemTypes[self.problem_type] == ProblemTypes.sequence_tagging:
self.tagging_scheme = self.get_item(['inputs', 'tagging_scheme'], default=None, use_default=True)
#if ProblemTypes[self.problem_type] == ProblemTypes.sequence_tagging:
self.tagging_scheme = self.get_item(['inputs', 'tagging_scheme'], default=None, use_default=True)

if self.mode == 'normal':
self.use_cache = self.get_item(['inputs', 'use_cache'], True)
Expand Down Expand Up @@ -519,3 +520,7 @@ def check_version_compat(self, nb_version, conf_version):
if not (nb_version_split[0] == conf_version_split[0] and nb_version_split[1] == conf_version_split[1]):
raise ConfigurationError('The NeuronBlocks version is %s, but the configuration version is %s, please update your configuration to %s.%s.X' % (nb_version, conf_version, nb_version_split[0], nb_version_split[1]))

def back_up(self, params):
shutil.copy(params.conf_path, self.save_base_dir)
logging.info('Configuration file is backed up to %s' % (self.save_base_dir))

61 changes: 0 additions & 61 deletions data_encoding.py

This file was deleted.

25 changes: 4 additions & 21 deletions predict.py
Expand Up @@ -14,27 +14,10 @@

def main(params):
conf = ModelConf('predict', params.conf_path, version, params, mode=params.mode)

if ProblemTypes[conf.problem_type] == ProblemTypes.sequence_tagging:
problem = Problem(conf.problem_type, conf.input_types, None,
source_with_start=True, source_with_end=True, source_with_unk=True, source_with_pad=True,
target_with_start=True, target_with_end=True, target_with_unk=True, target_with_pad=True, same_length=True,
with_bos_eos=conf.add_start_end_for_seq, tagging_scheme=conf.tagging_scheme, tokenizer=conf.tokenizer,
remove_stopwords=conf.remove_stopwords, DBC2SBC=conf.DBC2SBC, unicode_fix=conf.unicode_fix)
elif ProblemTypes[conf.problem_type] == ProblemTypes.classification \
or ProblemTypes[conf.problem_type] == ProblemTypes.regression:
problem = Problem(conf.problem_type, conf.input_types, None,
source_with_start=True, source_with_end=True, source_with_unk=True, source_with_pad=True,
target_with_start=False, target_with_end=False, target_with_unk=False, target_with_pad=False, same_length=True,
with_bos_eos=conf.add_start_end_for_seq, tokenizer=conf.tokenizer, remove_stopwords=conf.remove_stopwords,
DBC2SBC=conf.DBC2SBC, unicode_fix=conf.unicode_fix)
elif ProblemTypes[conf.problem_type] == ProblemTypes.mrc:
problem = Problem(conf.problem_type, conf.input_types,
source_with_start=True, source_with_end=True, source_with_unk=True, source_with_pad=True,
target_with_start=False, target_with_end=False, target_with_unk=False, target_with_pad=False,
same_length=False, with_bos_eos=False, tokenizer=conf.tokenizer,
remove_stopwords=conf.remove_stopwords, DBC2SBC=conf.DBC2SBC, unicode_fix=conf.unicode_fix)

problem = Problem('predict', conf.problem_type, conf.input_types, None,
with_bos_eos=conf.add_start_end_for_seq, tagging_scheme=conf.tagging_scheme, tokenizer=conf.tokenizer,
remove_stopwords=conf.remove_stopwords, DBC2SBC=conf.DBC2SBC, unicode_fix=conf.unicode_fix)

if os.path.isfile(conf.saved_problem_path):
problem.load_problem(conf.saved_problem_path)
logging.info("Problem loaded!")
Expand Down
122 changes: 82 additions & 40 deletions problem.py
Expand Up @@ -26,10 +26,8 @@
import torch.nn as nn

class Problem():
def __init__(self, problem_type, input_types, answer_column_name=None, lowercase=False,
source_with_start=True, source_with_end=True, source_with_unk=True,
source_with_pad=True, target_with_start=False, target_with_end=False,
target_with_unk=True, target_with_pad=True, same_length=True, with_bos_eos=True,

def __init__(self, phase, problem_type, input_types, answer_column_name=None, lowercase=False, with_bos_eos=True,
tagging_scheme=None, tokenizer="nltk", remove_stopwords=False, DBC2SBC=True, unicode_fix=True):
"""
Expand All @@ -50,9 +48,24 @@ def __init__(self, problem_type, input_types, answer_column_name=None, lowercase
same_length:
with_bos_eos: whether to add bos and eos when encoding
"""
self.lowercase = lowercase

self.input_dicts = dict()
# init
source_with_start, source_with_end, source_with_unk, source_with_pad, \
target_with_start, target_with_end, target_with_unk, target_with_pad, \
same_length = (True, ) * 9
if ProblemTypes[problem_type] == ProblemTypes.sequence_tagging:
pass
elif \
ProblemTypes[problem_type] == ProblemTypes.classification or \
ProblemTypes[problem_type] == ProblemTypes.regression:
target_with_start, target_with_end, target_with_unk, target_with_pad, same_length = (False, ) * 5
if phase != 'train':
same_length = True
elif ProblemTypes[problem_type] == ProblemTypes.mrc:
target_with_start, target_with_end, target_with_unk, target_with_pad, same_length = (False, ) * 5
with_bos_eos = False

self.lowercase = lowercase
self.problem_type = problem_type
self.tagging_scheme = tagging_scheme
self.with_bos_eos = with_bos_eos
Expand All @@ -65,6 +78,7 @@ def __init__(self, problem_type, input_types, answer_column_name=None, lowercase
self.target_with_unk = target_with_unk
self.target_with_pad = target_with_pad

self.input_dicts = dict()
for input_type in input_types:
self.input_dicts[input_type] = CellDict(with_unk=source_with_unk, with_pad=source_with_pad,
with_start=source_with_start, with_end=source_with_end)
Expand Down Expand Up @@ -245,6 +259,10 @@ def build(self, data_path_list, file_columns, input_types, file_with_col_header,
Returns:
"""
# parameter check
if not word2vec_path:
word_emb_dim, format, file_type, involve_all_words = None, None, None, None

if 'bpe' in input_types:
try:
bpe_encoder = BPEEncoder(input_types['bpe']['bpe_path'])
Expand Down Expand Up @@ -324,51 +342,65 @@ def build(self, data_path_list, file_columns, input_types, file_with_col_header,

return word_emb_matrix

def encode_data_multi_processor(self, data_generator, cpu_num_workers, file_columns, input_types, object_inputs,
answer_column_name, min_sentence_len, extra_feature, max_lengths=None, fixed_lengths=None, file_format="tsv", bpe_encoder=None):
@staticmethod
def _merge_encode_data(dest_dict, src_dict):
if len(dest_dict) == 0:
dest_dict = src_dict
else:
for branch in src_dict:
for input_type in dest_dict[branch]:
dest_dict[branch][input_type].extend(src_dict[branch][input_type])
return dest_dict

@staticmethod
def _merge_encode_lengths(dest_dict, src_dict):
def judge_dict(obj):
return True if isinstance(obj, dict) else False
cnt_legal, cnt_illegal = 0, 0
output_data = dict()
lengths = dict()
target = dict()
for data in tqdm(data_generator):

if len(dest_dict) == 0:
dest_dict = src_dict
else:
for branch in src_dict:
if judge_dict(src_dict[branch]):
for type_branch in src_dict[branch]:
dest_dict[branch][type_branch].extend(src_dict[branch][type_branch])
else:
dest_dict[branch].extend(src_dict[branch])
return dest_dict

@staticmethod
def _merge_target(dest_dict, src_dict):
if not src_dict:
return src_dict

if len(dest_dict) == 0:
dest_dict = src_dict
else:
for single_type in src_dict:
dest_dict[single_type].extend(src_dict[single_type])
return dest_dict

def encode_data_multi_processor(self, data_generator, cpu_num_workers, file_columns, input_types, object_inputs,
answer_column_name, min_sentence_len, extra_feature, max_lengths=None, fixed_lengths=None, file_format="tsv", bpe_encoder=None):


for data in data_generator:
scheduler = ProcessorsScheduler(cpu_num_workers)
func_args = (data, file_columns, input_types, object_inputs,
answer_column_name, min_sentence_len, extra_feature, max_lengths, fixed_lengths, file_format, bpe_encoder)
res = scheduler.run_data_parallel(self.encode_data_list, func_args)

output_data, lengths, target = dict(), dict(), dict()
cnt_legal, cnt_illegal = 0, 0
for (index, j) in res:
# logging.info("collect proccesor %d result"%index)
tmp_data, tmp_lengths, tmp_target, tmp_cnt_legal, tmp_cnt_illegal = j.get()

if len(output_data) == 0:
output_data = tmp_data
else:
for branch in tmp_data:
for input_type in output_data[branch]:
output_data[branch][input_type].extend(tmp_data[branch][input_type])
if len(lengths) == 0:
lengths = tmp_lengths
else:
for branch in tmp_lengths:
if judge_dict(tmp_lengths[branch]):
for type_branch in tmp_lengths[branch]:
lengths[branch][type_branch].extend(tmp_lengths[branch][type_branch])
else:
lengths[branch].extend(tmp_lengths[branch])
if not tmp_target:
target = None
else:
if len(target) == 0:
target = tmp_target
else:
for single_type in tmp_target:
target[single_type].extend(tmp_target[single_type])
output_data = self._merge_encode_data(output_data, tmp_data)
lengths = self._merge_encode_lengths(lengths, tmp_lengths)
target = self._merge_target(target, tmp_target)
cnt_legal += tmp_cnt_legal
cnt_illegal += tmp_cnt_illegal

return output_data, lengths, target, cnt_legal, cnt_illegal
yield output_data, lengths, target, cnt_legal, cnt_illegal

def encode_data_list(self, data_list, file_columns, input_types, object_inputs, answer_column_name, min_sentence_len,
extra_feature, max_lengths=None, fixed_lengths=None, file_format="tsv", bpe_encoder=None):
Expand Down Expand Up @@ -678,9 +710,19 @@ def encode(self, data_path, file_columns, input_types, file_with_col_header, obj
bpe_encoder = None

progress = self.get_data_generator_from_file([data_path], file_with_col_header)
data, lengths, target, cnt_legal, cnt_illegal = self.encode_data_multi_processor(progress, cpu_num_workers,
encoder_generator = self.encode_data_multi_processor(progress, cpu_num_workers,
file_columns, input_types, object_inputs, answer_column_name, min_sentence_len, extra_feature, max_lengths,
fixed_lengths, file_format, bpe_encoder=bpe_encoder)

data, lengths, target = dict(), dict(), dict()
cnt_legal, cnt_illegal = 0, 0
for temp_data, temp_lengths, temp_target, temp_cnt_legal, temp_cnt_illegal in tqdm(encoder_generator):
data = self._merge_encode_data(data, temp_data)
lengths = self._merge_encode_lengths(lengths, temp_lengths)
target = self._merge_target(target, temp_target)
cnt_legal += temp_cnt_legal
cnt_illegal += temp_cnt_illegal

logging.info("%s: %d legal samples, %d illegal samples" % (data_path, cnt_legal, cnt_illegal))
return data, lengths, target

Expand Down
25 changes: 4 additions & 21 deletions test.py
Expand Up @@ -16,27 +16,10 @@

def main(params):
conf = ModelConf("test", params.conf_path, version, params, mode=params.mode)

if ProblemTypes[conf.problem_type] == ProblemTypes.sequence_tagging:
problem = Problem(conf.problem_type, conf.input_types, conf.answer_column_name,
source_with_start=True, source_with_end=True, source_with_unk=True, source_with_pad=True,
target_with_start=True, target_with_end=True, target_with_unk=True, target_with_pad=True, same_length=True,
with_bos_eos=conf.add_start_end_for_seq, tagging_scheme=conf.tagging_scheme, tokenizer=conf.tokenizer,
remove_stopwords=conf.remove_stopwords, DBC2SBC=conf.DBC2SBC, unicode_fix=conf.unicode_fix)
elif ProblemTypes[conf.problem_type] == ProblemTypes.classification \
or ProblemTypes[conf.problem_type] == ProblemTypes.regression:
problem = Problem(conf.problem_type, conf.input_types, conf.answer_column_name,
source_with_start=True, source_with_end=True, source_with_unk=True, source_with_pad=True,
target_with_start=False, target_with_end=False, target_with_unk=False, target_with_pad=False, same_length=True,
with_bos_eos=conf.add_start_end_for_seq, tokenizer=conf.tokenizer, remove_stopwords=conf.remove_stopwords,
DBC2SBC=conf.DBC2SBC, unicode_fix=conf.unicode_fix)
elif ProblemTypes[conf.problem_type] == ProblemTypes.mrc:
problem = Problem(conf.problem_type, conf.input_types, conf.answer_column_name,
source_with_start=True, source_with_end=True, source_with_unk=True, source_with_pad=True,
target_with_start=False, target_with_end=False, target_with_unk=False, target_with_pad=False,
same_length=False, with_bos_eos=False, tokenizer=conf.tokenizer,
remove_stopwords=conf.remove_stopwords, DBC2SBC=conf.DBC2SBC, unicode_fix=conf.unicode_fix)

problem = Problem("test", conf.problem_type, conf.input_types, conf.answer_column_name,
with_bos_eos=conf.add_start_end_for_seq, tagging_scheme=conf.tagging_scheme, tokenizer=conf.tokenizer,
remove_stopwords=conf.remove_stopwords, DBC2SBC=conf.DBC2SBC, unicode_fix=conf.unicode_fix)

if os.path.isfile(conf.saved_problem_path):
problem.load_problem(conf.saved_problem_path)
logging.info("Problem loaded!")
Expand Down

0 comments on commit dc013c3

Please sign in to comment.