Use lazy loading instead of one-off loading #28

Merged
merged 28 commits on May 10, 2019
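
This PR replaces one-off loading of entire data files with chunked lazy loading: problem.py now reads training data through a generator that yields fixed-size chunks of lines, and dictionary building and encoding consume those chunks one at a time (see the problem.py diff below). A minimal sketch of the pattern, assuming a hypothetical read_in_chunks helper rather than the project's actual function names:

def read_in_chunks(file_path, chunk_size=1000000, skip_header=False):
    # Yield lists of at most chunk_size stripped lines instead of reading the whole file at once.
    with open(file_path, "r", encoding="utf-8") as f:
        if skip_header:
            f.readline()
        chunk = []
        for line in f:
            line = line.rstrip()
            if not line:
                break
            chunk.append(line)
            if len(chunk) == chunk_size:
                yield chunk
                chunk = []
        if chunk:
            yield chunk

# Callers iterate chunk by chunk, so peak memory is bounded by chunk_size:
# for chunk in read_in_chunks("train.tsv", skip_header=True):
#     process(chunk)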
28 commits
674526d
Add new config about knowledge distillation for query binary classifier
Apr 25, 2019
59d6318
remove inferenced result in knowledge distillation for query binary c…
chengfx Apr 26, 2019
b4c110e
Add AUC.py in tools folder
chengfx Apr 26, 2019
891f43a
Add test_data_path into conf_kdqbc_bilstmattn_cnn.json
chengfx Apr 28, 2019
8b1d100
Modify AUC.py
chengfx Apr 28, 2019
333bd98
Rename AUC.py into calculate_AUC.py
chengfx Apr 28, 2019
b6523a7
Merge branch 'master' into dev/fecheng
chengfx Apr 28, 2019
74976c2
Modify test&calculate AUC commands for Knowledge Distillation for Que…
chengfx Apr 28, 2019
936d9fe
Merge branch 'master' into dev/fecheng
chengfx Apr 28, 2019
8c6e61b
Add cpu_thread_num parameter in conf.training_params
chengfx Apr 29, 2019
69c0bca
Rename cpu_thread_num into cpu_num_workers
chengfx Apr 29, 2019
fb11aba
update comments in ModelConf.py
chengfx Apr 29, 2019
bbfcde2
Add cup_num_workers in model_zoo/advanced/conf.json
chengfx Apr 29, 2019
153acd3
Add the description of cpu_num_workers in Tutorial.md
chengfx Apr 29, 2019
4c9380c
fix conflict
chengfx Apr 29, 2019
2ae9d4a
Merge branch 'master' into dev/fecheng
chengfx May 6, 2019
cff4cd3
Update inference speed of compressed model
chengfx May 6, 2019
cf534ce
Add ProcessorsScheduler Class
chengfx May 7, 2019
37d09d5
Merge branch 'master' into dev/fecheng
chengfx May 7, 2019
17b8447
Add license in ProcessorScheduler.py
chengfx May 7, 2019
e087427
use lazy loading instead of one-off loading
chengfx May 8, 2019
1fb0440
merge master
chengfx May 8, 2019
05ddcf8
Remove Debug Info in problem.py
chengfx May 8, 2019
af6ea60
use open instead of codecs.open
chengfx May 9, 2019
535649e
Merge branch 'master' into dev/fecheng
chengfx May 9, 2019
fb4e47b
update the inference of build dictionary for classification
chengfx May 9, 2019
bab7f54
update typo
chengfx May 10, 2019
576b88d
update typo
chengfx May 10, 2019
8 changes: 4 additions & 4 deletions core/CellDict.py
@@ -66,23 +66,23 @@ def has_cell(self, cell):

return cell in self.cell_id_map

def build(self, docs, threshold, max_vocabulary_num=800000):

def update(self, docs):
for doc in docs:
# add type judge
if isinstance(doc, list):
self.cell_doc_count.update(set(doc))
else:
self.cell_doc_count.update([doc])


def build(self, threshold, max_vocabulary_num=800000):
# restrict the vocabulary size to prevent embedding dict weight size
self.cell_doc_count = Counter(dict(self.cell_doc_count.most_common(max_vocabulary_num)))
for cell in self.cell_doc_count:

if cell not in self.cell_id_map and self.cell_doc_count[cell] >= threshold:
id = len(self.cell_id_map)
self.cell_id_map[cell] = id
self.id_cell_map[id] = cell
self.cell_doc_count = Counter()

def id_unk(self, cell):
"""
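To make per-chunk processing possible, CellDict's old one-shot build(docs, ...) is split into an incremental update(docs), called once per chunk to accumulate document counts, and a final build(threshold, max_vocabulary_num), called once afterwards to freeze the vocabulary. A hedged usage sketch of the two-phase API; the import path and the no-argument constructor are assumptions for illustration:

from core.CellDict import CellDict   # import path assumed

corpus_chunks = [                    # chunks of tokenized documents, e.g. produced by the lazy loader
    [["hello", "world"], ["hello"]],
    [["lazy", "loading", "hello"]],
]

word_dict = CellDict()               # constructor arguments assumed
for chunk in corpus_chunks:
    word_dict.update(chunk)          # phase 1: accumulate document counts chunk by chunk
word_dict.build(threshold=1, max_vocabulary_num=800000)   # phase 2: freeze the vocabulary once
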
232 changes: 122 additions & 110 deletions problem.py
@@ -10,7 +10,6 @@
nltk.download('punkt', quiet=True)
nltk.download('stopwords', quiet=True)
from utils.BPEEncoder import BPEEncoder
import codecs
import os
import pickle as pkl
from utils.common_utils import load_from_pkl, dump_to_pkl
@@ -94,16 +93,21 @@ def output_target_num(self):
else:
return None

def get_data_list_from_file(self, fin, file_with_col_header):
data_list = list()
for index, line in enumerate(fin):
if file_with_col_header and index == 0:
continue
line = line.rstrip()
if not line:
break
data_list.append(line)
return data_list
def get_data_generator_from_file(self, file_path, file_with_col_header, chunk_size=1000000):
with open(file_path, "r", encoding='utf-8') as f:
if file_with_col_header:
f.readline()
data_list = list()
for index, line in enumerate(f):
line = line.rstrip()
if not line:
break
data_list.append(line)
if (index + 1) % chunk_size == 0:
yield data_list
data_list = list()
if len(data_list) > 0:
yield data_list

def build_training_data_list(self, training_data_list, file_columns, input_types, answer_column_name, bpe_encoder=None):
docs = dict() # docs of each type of input
@@ -165,33 +169,35 @@ def build_training_data_list(self, training_data_list, file_columns, input_types
pass
return docs, target_docs, cnt_legal, cnt_illegal

def build_training_multi_processor(self, training_data_list, cpu_num_workers, file_columns, input_types, answer_column_name, bpe_encoder=None):
scheduler = ProcessorsScheduler(cpu_num_workers)
func_args = (training_data_list, file_columns, input_types, answer_column_name, bpe_encoder)
res = scheduler.run_data_parallel(self.build_training_data_list, func_args)

docs = dict() # docs of each type of input
target_docs = []
cnt_legal = 0
cnt_illegal = 0
for (index, j) in res:
#logging.info("collect proccesor %d result" % index)
tmp_docs, tmp_target_docs, tmp_cnt_legal, tmp_cnt_illegal = j.get()
if len(docs) == 0:
docs = tmp_docs
else:
for key, value in tmp_docs.items():
docs[key].extend(value)
if len(target_docs) == 0:
target_docs = tmp_target_docs
else:
for single_type in tmp_target_docs:
target_docs[single_type].extend(tmp_target_docs[single_type])
# target_docs.extend(tmp_target_docs)
cnt_legal += tmp_cnt_legal
cnt_illegal += tmp_cnt_illegal
def build_training_multi_processor(self, training_data_generator, cpu_num_workers, file_columns, input_types, answer_column_name, bpe_encoder=None):
for data in training_data_generator:
# multi-Processing
scheduler = ProcessorsScheduler(cpu_num_workers)
func_args = (data, file_columns, input_types, answer_column_name, bpe_encoder)
res = scheduler.run_data_parallel(self.build_training_data_list, func_args)
# aggregate
docs = dict() # docs of each type of input
target_docs = []
cnt_legal = 0
cnt_illegal = 0
for (index, j) in res:
#logging.info("collect proccesor %d result" % index)
tmp_docs, tmp_target_docs, tmp_cnt_legal, tmp_cnt_illegal = j.get()
if len(docs) == 0:
docs = tmp_docs
else:
for key, value in tmp_docs.items():
docs[key].extend(value)
if len(target_docs) == 0:
target_docs = tmp_target_docs
else:
for single_type in tmp_target_docs:
target_docs[single_type].extend(tmp_target_docs[single_type])
# target_docs.extend(tmp_target_docs)
cnt_legal += tmp_cnt_legal
cnt_illegal += tmp_cnt_illegal

return docs, target_docs, cnt_legal, cnt_illegal
yield docs, target_docs, cnt_legal, cnt_illegal

def build(self, training_data_path, file_columns, input_types, file_with_col_header, answer_column_name, word2vec_path=None, word_emb_dim=None,
format=None, file_type=None, involve_all_words=None, file_format="tsv", show_progress=True,
@@ -245,38 +251,47 @@ def build(self, training_data_path, file_columns, input_types, file_with_col_hea
bpe_encoder = None

self.file_column_num = len(file_columns)
with open(training_data_path, "r", encoding='utf-8') as f:
progress = self.get_data_list_from_file(f, file_with_col_header)
docs, target_docs, cnt_legal, cnt_illegal = self.build_training_multi_processor(progress, cpu_num_workers, file_columns, input_types, answer_column_name, bpe_encoder=bpe_encoder)

logging.info("Corpus imported: %d legal lines, %d illegal lines." % (cnt_legal, cnt_illegal))

if word2vec_path and involve_all_words is True:
logging.info("Getting pre-trained embeddings...")
word_emb_dict = load_embedding(word2vec_path, word_emb_dim, format, file_type, with_head=False, word_set=None)
self.input_dicts['word'].build([list(word_emb_dict.keys())], max_vocabulary_num=len(word_emb_dict), threshold=0)
progress = self.get_data_generator_from_file(training_data_path, file_with_col_header)
preprocessed_data_generator= self.build_training_multi_processor(progress, cpu_num_workers, file_columns, input_types, answer_column_name, bpe_encoder=bpe_encoder)

# update symbol universe
total_cnt_legal, total_cnt_illegal = 0, 0
for docs, target_docs, cnt_legal, cnt_illegal in tqdm(preprocessed_data_generator):
total_cnt_legal += cnt_legal
total_cnt_illegal += cnt_illegal

# input_type
for input_type in input_types:
self.input_dicts[input_type].update(docs[input_type])

# problem_type
if ProblemTypes[self.problem_type] == ProblemTypes.classification or \
ProblemTypes[self.problem_type] == ProblemTypes.sequence_tagging:
self.output_dict.update(list(target_docs.values())[0])
elif ProblemTypes[self.problem_type] == ProblemTypes.regression or \
ProblemTypes[self.problem_type] == ProblemTypes.mrc:
pass
logging.info("Corpus imported: %d legal lines, %d illegal lines." % (total_cnt_legal, total_cnt_illegal))

# build dictionary
for input_type in input_types:
if input_type != 'word':
self.input_dicts[input_type].build(docs[input_type], max_vocabulary_num=max_vocabulary, threshold=word_frequency)
else:
self.input_dicts[input_type].build(docs[input_type], max_vocabulary_num=max_vocabulary, threshold=word_frequency)
logging.info("%d types in %s" % (self.input_dicts[input_type].cell_num(), input_type))
if ProblemTypes[self.problem_type] == ProblemTypes.classification:
self.output_dict.build(list(target_docs.values())[0], threshold=0)
elif ProblemTypes[self.problem_type] == ProblemTypes.sequence_tagging:
self.output_dict.build(list(target_docs.values())[0], threshold=0)
elif ProblemTypes[self.problem_type] == ProblemTypes.regression or \
ProblemTypes[self.problem_type] == ProblemTypes.mrc:
pass

self.input_dicts[input_type].build(threshold=word_frequency, max_vocabulary_num=max_vocabulary)
logging.info("%d types in %s column" % (self.input_dicts[input_type].cell_num(), input_type))
if self.output_dict:
logging.info("%d types in target" % (self.output_dict.cell_num()))

logging.debug("Cell dict built")
self.output_dict.build(threshold=0)
logging.info("%d types in target column" % (self.output_dict.cell_num()))
logging.debug("training data dict built")

# embedding
word_emb_matrix = None
if word2vec_path:
if not involve_all_words:
logging.info("Getting pre-trained embeddings...")
logging.info("Getting pre-trained embeddings...")
word_emb_dict = None
if involve_all_words is True:
word_emb_dict = load_embedding(word2vec_path, word_emb_dim, format, file_type, with_head=False, word_set=None)
self.input_dicts['word'].update([list(word_emb_dict.keys())])
self.input_dicts['word'].build(threshold=0, max_vocabulary_num=len(word_emb_dict))
else:
word_emb_dict = load_embedding(word2vec_path, word_emb_dim, format, file_type, with_head=False, word_set=self.input_dicts['word'].cell_id_map.keys())

for word in word_emb_dict:
@@ -285,6 +300,7 @@ def build(self, training_data_path, file_columns, input_types, file_with_col_hea

assert loaded_emb_dim == word_emb_dim, "The dimension of defined word embedding is inconsistent with the pretrained embedding provided!"

logging.info("constructing embedding table")
if self.input_dicts['word'].with_unk:
word_emb_dict['<unk>'] = np.random.random(size=word_emb_dim)
if self.input_dicts['word'].with_pad:
@@ -302,57 +318,54 @@ def build(self, training_data_path, file_columns, input_types, file_with_col_hea
logging.info("word embedding matrix shape:(%d, %d); unknown word count: %d;" %
(len(word_emb_matrix), len(word_emb_matrix[0]), unknown_word_count))
logging.info("Word embedding loaded")
else:
word_emb_matrix = None

return word_emb_matrix

def encode_data_multi_processor(self, data_list, cpu_num_workers, file_columns, input_types, object_inputs,
def encode_data_multi_processor(self, data_generator, cpu_num_workers, file_columns, input_types, object_inputs,
answer_column_name, min_sentence_len, extra_feature, max_lengths=None, fixed_lengths=None, file_format="tsv", bpe_encoder=None):
def judge_dict(obj):
return True if isinstance(obj, dict) else False

scheduler = ProcessorsScheduler(cpu_num_workers)
func_args = (data_list, file_columns, input_types, object_inputs,
answer_column_name, min_sentence_len, extra_feature, max_lengths, fixed_lengths, file_format, bpe_encoder)
res = scheduler.run_data_parallel(self.encode_data_list, func_args)

data = dict()
cnt_legal, cnt_illegal = 0, 0
output_data = dict()
lengths = dict()
target = dict()
cnt_legal = 0
cnt_illegal = 0

for (index, j) in res:
# logging.info("collect proccesor %d result"%index)
tmp_data, tmp_lengths, tmp_target, tmp_cnt_legal, tmp_cnt_illegal = j.get()

if len(data) == 0:
data = tmp_data
else:
for branch in tmp_data:
for input_type in data[branch]:
data[branch][input_type].extend(tmp_data[branch][input_type])
if len(lengths) == 0:
lengths = tmp_lengths
else:
for branch in tmp_lengths:
if judge_dict(tmp_lengths[branch]):
for type_branch in tmp_lengths[branch]:
lengths[branch][type_branch].extend(tmp_lengths[branch][type_branch])
else:
lengths[branch].extend(tmp_lengths[branch])
if not tmp_target:
target = None
else:
if len(target) == 0:
target = tmp_target
for data in tqdm(data_generator):
scheduler = ProcessorsScheduler(cpu_num_workers)
func_args = (data, file_columns, input_types, object_inputs,
answer_column_name, min_sentence_len, extra_feature, max_lengths, fixed_lengths, file_format, bpe_encoder)
res = scheduler.run_data_parallel(self.encode_data_list, func_args)

for (index, j) in res:
# logging.info("collect proccesor %d result"%index)
tmp_data, tmp_lengths, tmp_target, tmp_cnt_legal, tmp_cnt_illegal = j.get()

if len(output_data) == 0:
output_data = tmp_data
else:
for single_type in tmp_target:
target[single_type].extend(tmp_target[single_type])
cnt_legal += tmp_cnt_legal
cnt_illegal += tmp_cnt_illegal
for branch in tmp_data:
for input_type in output_data[branch]:
output_data[branch][input_type].extend(tmp_data[branch][input_type])
if len(lengths) == 0:
lengths = tmp_lengths
else:
for branch in tmp_lengths:
if judge_dict(tmp_lengths[branch]):
for type_branch in tmp_lengths[branch]:
lengths[branch][type_branch].extend(tmp_lengths[branch][type_branch])
else:
lengths[branch].extend(tmp_lengths[branch])
if not tmp_target:
target = None
else:
if len(target) == 0:
target = tmp_target
else:
for single_type in tmp_target:
target[single_type].extend(tmp_target[single_type])
cnt_legal += tmp_cnt_legal
cnt_illegal += tmp_cnt_illegal

return data, lengths, target, cnt_legal, cnt_illegal
return output_data, lengths, target, cnt_legal, cnt_illegal

def encode_data_list(self, data_list, file_columns, input_types, object_inputs, answer_column_name, min_sentence_len,
extra_feature, max_lengths=None, fixed_lengths=None, file_format="tsv", bpe_encoder=None):
@@ -661,9 +674,8 @@ def encode(self, data_path, file_columns, input_types, file_with_col_header, obj
else:
bpe_encoder = None

with open(data_path, 'r', encoding='utf-8') as fin:
progress = self.get_data_list_from_file(fin, file_with_col_header)
data, lengths, target, cnt_legal, cnt_illegal = self.encode_data_multi_processor(progress, cpu_num_workers,
progress = self.get_data_generator_from_file(data_path, file_with_col_header)
data, lengths, target, cnt_legal, cnt_illegal = self.encode_data_multi_processor(progress, cpu_num_workers,
file_columns, input_types, object_inputs, answer_column_name, min_sentence_len, extra_feature, max_lengths,
fixed_lengths, file_format, bpe_encoder=bpe_encoder)
logging.info("%s: %d legal samples, %d illegal samples" % (data_path, cnt_legal, cnt_illegal))
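Both build_training_multi_processor and encode_data_multi_processor now follow the same per-chunk pattern: each chunk from the generator is split across worker processes by ProcessorsScheduler, and the per-worker dict-of-lists results are merged back before the next chunk is processed. A self-contained sketch of that merge step; merge_worker_outputs is a hypothetical name, not the project's API:

from collections import defaultdict

def merge_worker_outputs(worker_outputs):
    # Merge a list of {key: list} partial results into one dict, preserving worker order.
    merged = defaultdict(list)
    for partial in worker_outputs:
        for key, values in partial.items():
            merged[key].extend(values)
    return dict(merged)

# Example: two workers each return docs for the same input type.
print(merge_worker_outputs([{"word": ["a", "b"]}, {"word": ["c"]}]))   # {'word': ['a', 'b', 'c']}
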
3 changes: 1 addition & 2 deletions tools/calculate_AUC.py
@@ -2,14 +2,13 @@
# Licensed under the MIT license.

import argparse
import codecs
from sklearn.metrics import roc_auc_score

def read_tsv(params):
prediction, label = [], []
predict_index, label_index = int(params.predict_index), int(params.label_index)
min_column_num = max(predict_index, label_index) + 1
with codecs.open(params.input_file, mode='r', encoding='utf-8') as f:
with open(params.input_file, mode='r', encoding='utf-8') as f:
for index, line in enumerate(f):
if params.header and index == 0:
continue
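For reference, the tool above simply feeds the parsed label and prediction columns to scikit-learn's roc_auc_score. A minimal standalone equivalent; the file name and column indices are assumptions:

from sklearn.metrics import roc_auc_score

labels, scores = [], []
with open("predictions.tsv", "r", encoding="utf-8") as f:
    for line in f:
        cols = line.rstrip("\n").split("\t")
        labels.append(int(cols[0]))      # gold label column (index assumed)
        scores.append(float(cols[1]))    # predicted score column (index assumed)

print("AUC: %.4f" % roc_auc_score(labels, scores))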