Skip to content

Commit

Permalink
Fix map error
Browse files Browse the repository at this point in the history
  • Loading branch information
lvapeab committed Jan 3, 2020
1 parent 951ff77 commit c1c3eee
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions nmt_keras/apply_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,12 @@ def sample_ensemble(args, params):
else:
glossary = None

if model_weights is not None and model_weights != []:
if model_weights:
assert len(model_weights) == len(models), 'You should give a weight to each model. You gave %d models and %d weights.' % (len(models), len(model_weights))
model_weights = map(float, model_weights)
model_weights = list(map(float, model_weights))
if len(model_weights) > 1:
logger.info('Giving the following weights to each model: %s' % str(model_weights))

for s in args.splits:
# Apply model predictions
params_prediction['predict_on_sets'] = [s]
Expand Down Expand Up @@ -122,7 +123,7 @@ def sample_ensemble(args, params):
verbose=args.verbose)
# Apply detokenization function if needed
if params.get('APPLY_DETOKENIZATION', False):
predictions = map(detokenize_function, predictions)
predictions = list(map(detokenize_function, predictions))

if args.n_best:
n_best_predictions = []
Expand All @@ -139,7 +140,7 @@ def sample_ensemble(args, params):
verbose=args.verbose)
# Apply detokenization function if needed
if params.get('APPLY_DETOKENIZATION', False):
pred = map(detokenize_function, pred)
pred = list(map(detokenize_function, pred))

n_best_sample_score.append([i, pred, n_best_score])
n_best_predictions.append(n_best_sample_score)
Expand Down Expand Up @@ -196,9 +197,9 @@ def score_corpus(args, params):
extra_vars['tokenize_f'] = eval('dataset.' + params['TOKENIZATION_METHOD'])

model_weights = args.weights
if model_weights is not None and model_weights != []:
if model_weights:
assert len(model_weights) == len(models), 'You should give a weight to each model. You gave %d models and %d weights.' % (len(models), len(model_weights))
model_weights = map(float, model_weights)
model_weights = list(map(float, model_weights))
if len(model_weights) > 1:
logger.info('Giving the following weights to each model: %s' % str(model_weights))

Expand Down

0 comments on commit c1c3eee

Please sign in to comment.