Commit f72ca0a: Updates from master

lvapeab committed Mar 24, 2021
1 parent 8e58c1d commit f72ca0a
Showing 7 changed files with 133 additions and 128 deletions.
data_engine/prepare_data.py (137 changes: 67 additions, 70 deletions)
@@ -1,4 +1,5 @@
import logging
import os
from keras_wrapper.dataset import Dataset, saveDataset, loadDataset

logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(message)s', datefmt='%d/%m/%Y %H:%M:%S')
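
The new `import os` supports a path-handling change made throughout this file: paths are now built with `os.path.join` instead of being concatenated with '/'. A minimal sketch of the pattern, with made-up placeholder values:

    import os

    # Hypothetical placeholders standing in for base_path and a params['TEXT_FILES'] entry.
    base_path = '/data/corpus'
    text_file = 'training.'
    trg_lan = 'en'

    # Previous style in this file:
    old_path = base_path + '/' + text_file + trg_lan
    # Style introduced by this commit:
    new_path = os.path.join(base_path, text_file + trg_lan)

    assert old_path == new_path == '/data/corpus/training.en'

os.path.join also avoids a doubled separator when base_path already ends with '/', which plain concatenation does not.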
@@ -28,9 +29,6 @@ def update_dataset_from_file(ds,
:return: Dataset object with the processed data
"""

logging.info("<<< Updating Dataset instance " + ds.name + " ... >>>")

if splits is None:
splits = ['val']

@@ -80,6 +78,7 @@ def update_dataset_from_file(ds,
min_occ=params.get('MIN_OCCURRENCES_INPUT_VOCAB', 0),
bpe_codes=params.get('BPE_CODES_PATH', None),
overwrite_split=True)

if compute_state_below and output_text_filename is not None:
# INPUT DATA
ds.setInput(output_text_filename,
@@ -113,7 +112,7 @@ def update_dataset_from_file(ds,

# If we had multiple references per sentence
if recompute_references:
keep_n_captions(ds, repeat=1, n=1, set_names=params['EVAL_ON_SETS'])
prepare_references(ds, repeat=1, n=1, set_names=params['EVAL_ON_SETS'])

return ds
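
Later in this commit, score_corpus (in nmt_keras/apply_model.py) calls update_dataset_from_file to swap in new source and target files before scoring. A hedged usage sketch with made-up file names; the real values come from command-line arguments and the stored config:

    from keras_wrapper.dataset import loadDataset
    from keras_wrapper.extra.read_write import pkl2dict

    # Hypothetical paths, for illustration only.
    params = pkl2dict('trained_model/config.pkl')
    dataset = loadDataset('datasets/Dataset_example_esen.pkl')
    dataset = update_dataset_from_file(dataset,
                                       'data/newstest.es',            # new source sentences
                                       params,
                                       splits=['test'],
                                       output_text_filename='data/newstest.en',  # target side
                                       compute_state_below=True)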

@@ -138,9 +137,10 @@ def build_dataset(params):

# OUTPUT DATA
# Load the train, val and test splits of the target-language sentences (outputs). Each file contains one sentence per line.
ds.setOutput(base_path + '/' + params['TEXT_FILES']['train'] + params['TRG_LAN'],
ds.setOutput(os.path.join(base_path, params['TEXT_FILES']['train'] + params['TRG_LAN']),
'train',
type=params.get('OUTPUTS_TYPES_DATASET', ['dense-text'] if 'sparse' in params['LOSS'] else ['text'])[0],
type=params.get('OUTPUTS_TYPES_DATASET',
['dense-text'] if 'sparse' in params['LOSS'] else ['text'])[0],
id=params['OUTPUTS_IDS_DATASET'][0],
tokenization=params.get('TOKENIZATION_METHOD', 'tokenize_none'),
build_vocabulary=True,
@@ -152,15 +152,10 @@ def build_dataset(params):
min_occ=params.get('MIN_OCCURRENCES_OUTPUT_VOCAB', 0),
bpe_codes=params.get('BPE_CODES_PATH', None),
label_smoothing=params.get('LABEL_SMOOTHING', 0.))
if params.get('ALIGN_FROM_RAW', True) and not params.get('HOMOGENEOUS_BATCHES', False):
ds.setRawOutput(base_path + '/' + params['TEXT_FILES']['train'] + params['TRG_LAN'],
'train',
type='file-name',
id='raw_' + params['OUTPUTS_IDS_DATASET'][0])

for split in ['val', 'test']:
if params['TEXT_FILES'].get(split) is not None:
ds.setOutput(base_path + '/' + params['TEXT_FILES'][split] + params['TRG_LAN'],
ds.setOutput(os.path.join(base_path, params['TEXT_FILES'][split] + params['TRG_LAN']),
split,
type='text', # The type of the references should always be 'text'
id=params['OUTPUTS_IDS_DATASET'][0],
@@ -171,82 +166,80 @@ def build_dataset(params):
max_words=params.get('OUTPUT_VOCABULARY_SIZE', 0),
bpe_codes=params.get('BPE_CODES_PATH', None),
label_smoothing=0.)
if params.get('ALIGN_FROM_RAW', True) and not params.get('HOMOGENEOUS_BATCHES', False):
ds.setRawOutput(base_path + '/' + params['TEXT_FILES'][split] + params['TRG_LAN'],
split,
type='file-name',
id='raw_' + params['OUTPUTS_IDS_DATASET'][0])

# INPUT DATA
# We must ensure that the 'train' split is the first (for building the vocabulary)
for split in ['train', 'val', 'test']:
if params['TEXT_FILES'].get(split) is not None:
if split == 'train':
build_vocabulary = True
for split in params['TEXT_FILES']:
build_vocabulary = split == 'train'
ds.setInput(os.path.join(base_path, params['TEXT_FILES'][split] + params['SRC_LAN']),
split,
type=params.get('INPUTS_TYPES_DATASET', ['text', 'text'])[0],
id=params['INPUTS_IDS_DATASET'][0],
pad_on_batch=params.get('PAD_ON_BATCH', True),
tokenization=params.get('TOKENIZATION_METHOD', 'tokenize_none'),
build_vocabulary=build_vocabulary,
fill=params.get('FILL', 'end'),
max_text_len=params.get('MAX_INPUT_TEXT_LEN', 70),
max_words=params.get('INPUT_VOCABULARY_SIZE', 0),
min_occ=params.get('MIN_OCCURRENCES_INPUT_VOCAB', 0),
bpe_codes=params.get('BPE_CODES_PATH', None))

if len(params['INPUTS_IDS_DATASET']) > 1:
if 'train' in split:
ds.setInput(os.path.join(base_path, params['TEXT_FILES'][split] + params['TRG_LAN']),
split,
type=params.get('INPUTS_TYPES_DATASET', ['text', 'text'])[1],
id=params['INPUTS_IDS_DATASET'][1],
required=False,
tokenization=params.get('TOKENIZATION_METHOD', 'tokenize_none'),
pad_on_batch=params.get('PAD_ON_BATCH', True),
build_vocabulary=params['OUTPUTS_IDS_DATASET'][0],
offset=1,
fill=params.get('FILL', 'end'),
max_text_len=params.get('MAX_OUTPUT_TEXT_LEN', 70),
max_words=params.get('OUTPUT_VOCABULARY_SIZE', 0),
bpe_codes=params.get('BPE_CODES_PATH', None))
if params.get('TIE_EMBEDDINGS', False):
ds.merge_vocabularies([params['INPUTS_IDS_DATASET'][1], params['INPUTS_IDS_DATASET'][0]])
else:
build_vocabulary = False
ds.setInput(base_path + '/' + params['TEXT_FILES'][split] + params['SRC_LAN'],
split,
type=params.get('INPUTS_TYPES_DATASET', ['text', 'text'])[0],
id=params['INPUTS_IDS_DATASET'][0],
pad_on_batch=params.get('PAD_ON_BATCH', True),
tokenization=params.get('TOKENIZATION_METHOD', 'tokenize_none'),
build_vocabulary=build_vocabulary,
fill=params.get('FILL', 'end'),
max_text_len=params.get('MAX_INPUT_TEXT_LEN', 70),
max_words=params.get('INPUT_VOCABULARY_SIZE', 0),
min_occ=params.get('MIN_OCCURRENCES_INPUT_VOCAB', 0),
bpe_codes=params.get('BPE_CODES_PATH', None))

if len(params['INPUTS_IDS_DATASET']) > 1:
if 'train' in split:
ds.setInput(base_path + '/' + params['TEXT_FILES'][split] + params['TRG_LAN'],
split,
type=params.get('INPUTS_TYPES_DATASET', ['text', 'text'])[1],
id=params['INPUTS_IDS_DATASET'][1],
required=False,
tokenization=params.get('TOKENIZATION_METHOD', 'tokenize_none'),
pad_on_batch=params.get('PAD_ON_BATCH', True),
build_vocabulary=params['OUTPUTS_IDS_DATASET'][0],
offset=1,
fill=params.get('FILL', 'end'),
max_text_len=params.get('MAX_OUTPUT_TEXT_LEN', 70),
max_words=params.get('OUTPUT_VOCABULARY_SIZE', 0),
bpe_codes=params.get('BPE_CODES_PATH', None))
if params.get('TIE_EMBEDDINGS', False):
ds.merge_vocabularies([params['INPUTS_IDS_DATASET'][1], params['INPUTS_IDS_DATASET'][0]])
else:
ds.setInput(None,
split,
type='ghost',
id=params['INPUTS_IDS_DATASET'][-1],
required=False)
if params.get('ALIGN_FROM_RAW', True) and not params.get('HOMOGENEOUS_BATCHES', False):
ds.setRawInput(base_path + '/' + params['TEXT_FILES'][split] + params['SRC_LAN'],
split,
type='file-name',
id='raw_' + params['INPUTS_IDS_DATASET'][0])
ds.setInput(None,
split,
type='ghost',
id=params['INPUTS_IDS_DATASET'][-1],
required=False)
if params.get('ALIGN_FROM_RAW', True) and not params.get('HOMOGENEOUS_BATCHES', False):
ds.setRawInput(os.path.join(base_path, params['TEXT_FILES'][split] + params['SRC_LAN']),
split,
type='file-name',
id='raw_' + params['INPUTS_IDS_DATASET'][0])
if params.get('POS_UNK', False):
if params.get('HEURISTIC', 0) > 0:
ds.loadMapping(params['MAPPING'])

# If we had multiple references per sentence
keep_n_captions(ds, repeat=1, n=1, set_names=params['EVAL_ON_SETS'])
# Prepare references
prepare_references(ds,
repeat=1,
n=1,
set_names=params['EVAL_ON_SETS'])

# We have finished loading the dataset; now we can store it for future use
saveDataset(ds, params['DATASET_STORE_PATH'])

else:
# We can easily recover it with a single line
ds = loadDataset(params['DATASET_STORE_PATH'] + '/Dataset_' + params['DATASET_NAME'] + '_' + params['SRC_LAN'] + params['TRG_LAN'] + '.pkl')
ds = loadDataset(
os.path.join(params['DATASET_STORE_PATH'],
'Dataset_' + params['DATASET_NAME'] + '_' + params['SRC_LAN'] + params['TRG_LAN'] + '.pkl'))

# If we had multiple references per sentence
keep_n_captions(ds, repeat=1, n=1, set_names=params['EVAL_ON_SETS'])
# Prepare references
prepare_references(ds,
repeat=1,
n=1,
set_names=params['EVAL_ON_SETS'])

return ds


def keep_n_captions(ds, repeat, n=1, set_names=None):
def prepare_references(ds, repeat, n=1, set_names=None):
"""
Keeps only n captions per image and stores the rest in dictionaries for a later evaluation
:param ds: Dataset object
@@ -310,3 +303,7 @@ def keep_n_captions(ds, repeat, n=1, set_names=None):
setattr(ds, 'len_' + s, new_len)

logger.info('Samples reduced to ' + str(new_len) + ' in ' + s + ' set.')


# Backwards compatibility:
keep_n_captions = prepare_references
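
The helper is renamed from keep_n_captions to prepare_references, and the module-level alias above keeps existing imports working. A small sketch of what the alias guarantees, assuming data_engine is importable as a package (the commented call is illustrative; set names depend on EVAL_ON_SETS in the config):

    from data_engine.prepare_data import keep_n_captions, prepare_references

    # Both names are bound to the same function object after this commit.
    assert keep_n_captions is prepare_references

    # Illustrative call, mirroring how build_dataset uses it:
    # prepare_references(ds, repeat=1, n=1, set_names=['val'])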
data_engine/rebuild_dataset_from_config.py (4 changes: 2 additions, 2 deletions)
@@ -4,6 +4,8 @@
import argparse
import logging
import ast
from prepare_data import build_dataset
from keras_wrapper.extra.read_write import pkl2dict

logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(message)s', datefmt='%d/%m/%Y %H:%M:%S')
logger = logging.getLogger(__name__)
@@ -27,7 +29,6 @@ def parse_args():
params = load_parameters()
else:
logger.info("Loading parameters from %s" % str(args.config))
from keras_wrapper.extra.read_write import pkl2dict
params = pkl2dict(args.config)
try:
for arg in args.changes:
@@ -44,5 +45,4 @@ def parse_args():
print ('Error processing arguments: (', k, ",", v, ")")
exit(2)
params['REBUILD_DATASET'] = True
from prepare_data import build_dataset
dataset = build_dataset(params)
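
The change in this file only hoists build_dataset and pkl2dict to module-level imports. For context, the surrounding args.changes loop applies KEY=VALUE overrides to the loaded parameters with ast.literal_eval; a hedged sketch of that pattern (the repository's exact parsing and error handling may differ):

    import ast

    def apply_overrides(params, changes):
        """Apply 'KEY=VALUE' command-line overrides to a parameter dict."""
        for arg in changes:
            k, v = arg.split('=', 1)
            try:
                # Parse numbers, booleans, lists, etc.; keep plain strings as-is.
                params[k] = ast.literal_eval(v)
            except (ValueError, SyntaxError):
                params[k] = v
        return params

    # Example with hypothetical keys:
    # apply_overrides(params, ['MAX_INPUT_TEXT_LEN=80', 'REBUILD_DATASET=True'])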
nmt_keras/apply_model.py (55 changes: 29 additions, 26 deletions)
@@ -99,41 +99,37 @@ def sample_ensemble(args, params):
for s in args.splits:
# Apply model predictions
params_prediction['predict_on_sets'] = [s]
beam_searcher = BeamSearchEnsemble(models, dataset, params_prediction,
beam_searcher = BeamSearchEnsemble(models,
dataset,
params_prediction,
model_weights=model_weights,
n_best=args.n_best,
verbose=args.verbose)
if args.n_best:
predictions, n_best = beam_searcher.predictBeamSearchNet()[s]
else:
predictions = beam_searcher.predictBeamSearchNet()[s]
n_best = None
predictions = beam_searcher.predictBeamSearchNet()[s]
samples = predictions['samples']
alphas = predictions['alphas'] if params_prediction['pos_unk'] else None

if params_prediction['pos_unk']:
samples = predictions[0]
alphas = predictions[1]
sources = [x.strip() for x in open(args.text, 'r').read().split('\n')]
sources = sources[:-1] if len(sources[-1]) == 0 else sources
else:
samples = predictions
alphas = None
heuristic = None
sources = None

predictions = decode_predictions_beam_search(samples,
index2word_y,
glossary=glossary,
alphas=alphas,
x_text=sources,
heuristic=heuristic,
mapping=mapping,
verbose=args.verbose)
decoded_predictions = decode_predictions_beam_search(samples,
index2word_y,
glossary=glossary,
alphas=alphas,
x_text=sources,
heuristic=heuristic,
mapping=mapping,
verbose=args.verbose)
# Apply detokenization function if needed
if params.get('APPLY_DETOKENIZATION', False):
predictions = list(map(detokenize_function, predictions))
decoded_predictions = list(map(detokenize_function, decoded_predictions))

if args.n_best:
n_best_predictions = []
for i, (n_best_preds, n_best_scores, n_best_alphas) in enumerate(n_best):
for i, (n_best_preds, n_best_scores, n_best_alphas) in enumerate(predictions['n_best']):
n_best_sample_score = []
for n_best_pred, n_best_score, n_best_alpha in zip(n_best_preds, n_best_scores, n_best_alphas):
pred = decode_predictions_beam_search([n_best_pred],
@@ -155,13 +151,13 @@ def sample_ensemble(args, params):
if args.dest is not None:
filepath = args.dest # results file
if params.get('SAMPLING_SAVE_MODE', 'list'):
list2file(filepath, predictions)
list2file(filepath, decoded_predictions)
if args.n_best:
nbest2file(filepath + '.nbest', n_best_predictions)
else:
raise Exception('Only "list" is allowed in "SAMPLING_SAVE_MODE"')
else:
list2stdout(predictions)
list2stdout(decoded_predictions)
if args.n_best:
logger.info('Storing n-best sentences in ./' + s + '.nbest')
nbest2file('./' + s + '.nbest', n_best_predictions)
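
With this refactor, predictBeamSearchNet()[split] is consumed as a dictionary: the new code reads predictions['samples'], predictions['alphas'], and, when n-best output is requested, predictions['n_best']. A minimal sketch of that access pattern (the helper itself is hypothetical; the key names are taken from the changed lines):

    def collect_predictions(beam_searcher, split, params_prediction, want_n_best=False):
        """Illustrative helper mirroring the dict-based access used in sample_ensemble."""
        predictions = beam_searcher.predictBeamSearchNet()[split]
        samples = predictions['samples']
        alphas = predictions['alphas'] if params_prediction['pos_unk'] else None
        n_best = predictions['n_best'] if want_n_best else None
        return samples, alphas, n_best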
@@ -194,8 +190,12 @@ def score_corpus(args, params):
logger.info("Using an ensemble of %d models" % len(args.models))
models = [loadModel(m, -1, full_path=True) for m in args.models]
dataset = loadDataset(args.dataset)
dataset = update_dataset_from_file(dataset, args.source, params, splits=args.splits,
output_text_filename=args.target, compute_state_below=True)
dataset = update_dataset_from_file(dataset,
args.source,
params,
splits=args.splits,
output_text_filename=args.target,
compute_state_below=True)

params['INPUT_VOCABULARY_SIZE'] = dataset.vocabulary_len[params['INPUTS_IDS_DATASET'][0]]
params['OUTPUT_VOCABULARY_SIZE'] = dataset.vocabulary_len[params['OUTPUTS_IDS_DATASET'][0]]
@@ -241,7 +241,10 @@ def score_corpus(args, params):
params_prediction['output_min_length_depending_on_x_factor'] = params.get('MINLEN_GIVEN_X_FACTOR', 2)
params_prediction['attend_on_output'] = params.get('ATTEND_ON_OUTPUT',
'transformer' in params['MODEL_TYPE'].lower())
beam_searcher = BeamSearchEnsemble(models, dataset, params_prediction, model_weights=model_weights,
beam_searcher = BeamSearchEnsemble(models,
dataset,
params_prediction,
model_weights=model_weights,
verbose=args.verbose)
scores = beam_searcher.scoreNet()[s]
