Commit

Match MKW's search output
lvapeab committed Apr 14, 2020
1 parent 588621b commit c268aa7
Showing 1 changed file with 26 additions and 25 deletions.
51 changes: 26 additions & 25 deletions nmt_keras/apply_model.py
@@ -99,39 +99,37 @@ def sample_ensemble(args, params):
     for s in args.splits:
         # Apply model predictions
         params_prediction['predict_on_sets'] = [s]
-        beam_searcher = BeamSearchEnsemble(models, dataset, params_prediction,
-                                           model_weights=model_weights, n_best=args.n_best, verbose=args.verbose)
-        if args.n_best:
-            predictions, n_best = beam_searcher.predictBeamSearchNet()[s]
-        else:
-            predictions = beam_searcher.predictBeamSearchNet()[s]
-            n_best = None
+        beam_searcher = BeamSearchEnsemble(models,
+                                           dataset,
+                                           params_prediction,
+                                           model_weights=model_weights,
+                                           n_best=args.n_best,
+                                           verbose=args.verbose)
+        predictions = beam_searcher.predictBeamSearchNet()[s]
+        samples = predictions['samples']
+        alphas = predictions['alphas'] if params_prediction['pos_unk'] else None

         if params_prediction['pos_unk']:
-            samples = predictions[0]
-            alphas = predictions[1]
             sources = [x.strip() for x in open(args.text, 'r').read().split('\n')]
             sources = sources[:-1] if len(sources[-1]) == 0 else sources
         else:
-            samples = predictions
-            alphas = None
             heuristic = None
             sources = None

-        predictions = decode_predictions_beam_search(samples,
-                                                     index2word_y,
-                                                     glossary=glossary,
-                                                     alphas=alphas,
-                                                     x_text=sources,
-                                                     heuristic=heuristic,
-                                                     mapping=mapping,
-                                                     verbose=args.verbose)
+        decoded_predictions = decode_predictions_beam_search(samples,
+                                                             index2word_y,
+                                                             glossary=glossary,
+                                                             alphas=alphas,
+                                                             x_text=sources,
+                                                             heuristic=heuristic,
+                                                             mapping=mapping,
+                                                             verbose=args.verbose)
         # Apply detokenization function if needed
         if params.get('APPLY_DETOKENIZATION', False):
-            predictions = list(map(detokenize_function, predictions))
+            decoded_predictions = list(map(detokenize_function, decoded_predictions))

         if args.n_best:
             n_best_predictions = []
-            for i, (n_best_preds, n_best_scores, n_best_alphas) in enumerate(n_best):
+            for i, (n_best_preds, n_best_scores, n_best_alphas) in enumerate(predictions['n_best']):
                 n_best_sample_score = []
                 for n_best_pred, n_best_score, n_best_alpha in zip(n_best_preds, n_best_scores, n_best_alphas):
                     pred = decode_predictions_beam_search([n_best_pred],
@@ -153,13 +151,13 @@ def sample_ensemble(args, params):
         if args.dest is not None:
             filepath = args.dest  # results file
             if params.get('SAMPLING_SAVE_MODE', 'list'):
-                list2file(filepath, predictions)
+                list2file(filepath, decoded_predictions)
                 if args.n_best:
                     nbest2file(filepath + '.nbest', n_best_predictions)
             else:
                 raise Exception('Only "list" is allowed in "SAMPLING_SAVE_MODE"')
         else:
-            list2stdout(predictions)
+            list2stdout(decoded_predictions)
             if args.n_best:
                 logger.info('Storing n-best sentences in ./' + s + '.nbest')
                 nbest2file('./' + s + '.nbest', n_best_predictions)
@@ -239,7 +237,10 @@ def score_corpus(args, params):
         params_prediction['output_min_length_depending_on_x_factor'] = params.get('MINLEN_GIVEN_X_FACTOR', 2)
         params_prediction['attend_on_output'] = params.get('ATTEND_ON_OUTPUT',
                                                            'transformer' in params['MODEL_TYPE'].lower())
-        beam_searcher = BeamSearchEnsemble(models, dataset, params_prediction, model_weights=model_weights,
+        beam_searcher = BeamSearchEnsemble(models,
+                                           dataset,
+                                           params_prediction,
+                                           model_weights=model_weights,
                                            verbose=args.verbose)
         scores = beam_searcher.scoreNet()[s]

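For context, a minimal sketch of the consumption pattern this commit moves to: predictBeamSearchNet() now returns, per split, a single dict carrying the 1-best output and (when requested) the n-best lists, instead of a bare list or a separate (predictions, n_best) pair. The 'samples', 'alphas' and 'n_best' keys are taken from the diff above; the split name and searcher setup are assumed for illustration.

    # Hypothetical usage sketch; beam_searcher built as in the diff, with n_best=True.
    predictions = beam_searcher.predictBeamSearchNet()['val']  # 'val' is an assumed split name

    samples = predictions['samples']  # 1-best hypotheses, one entry per source sentence
    alphas = predictions['alphas']    # attention weights, consumed when pos_unk is enabled

    # The n-best lists now travel in the same dict rather than in a second return value.
    for n_best_preds, n_best_scores, n_best_alphas in predictions['n_best']:
        best = n_best_preds[0]        # hypotheses for one sentence, ordered best-first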
