Flake8

lvapeab committed Mar 24, 2021
1 parent c118daa commit 8e58c1d
Showing 8 changed files with 11 additions and 16 deletions.
4 changes: 1 addition & 3 deletions active_learning_nmt.py
@@ -674,9 +674,7 @@ def parse_args():
         'state_below_index': -1,
         'output_text_index': 0,
         'apply_tokenization': params.get('APPLY_TOKENIZATION', False),
-        'tokenize_f': eval('dataset.' +
-                           params.get('TOKENIZATION_METHOD', 'tokenize_none')),
-
+        'tokenize_f': eval('dataset.' + params.get('TOKENIZATION_METHOD', 'tokenize_none')),
         'apply_detokenization': params.get('APPLY_DETOKENIZATION', True),
         'detokenize_f': eval('dataset.' + params.get('DETOKENIZATION_METHOD',
                                                      'detokenize_none')),
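The fix joins the `tokenize_f` continuation onto one line and drops the stray blank line inside the dict literal. As a side note, when `TOKENIZATION_METHOD` names a plain method of `dataset`, the `eval('dataset.' + ...)` lookup is equivalent to a `getattr` call; a minimal sketch, assuming the `dataset` and `params` objects from the surrounding script:

    # eval-free equivalent of eval('dataset.' + name) for a bare attribute name
    tokenize_name = params.get('TOKENIZATION_METHOD', 'tokenize_none')
    tokenize_f = getattr(dataset, tokenize_name)  # e.g. dataset.tokenize_none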
3 changes: 1 addition & 2 deletions interactive_char_nmt_simulation.py
@@ -269,8 +269,6 @@ def interactive_simulation():
     # Get word2index and index2word dictionaries
     index2word_y = dataset.vocabulary[params['OUTPUTS_IDS_DATASET'][0]]['idx2words']
     word2index_y = dataset.vocabulary[params['OUTPUTS_IDS_DATASET'][0]]['words2idx']
-    index2word_x = dataset.vocabulary[params['INPUTS_IDS_DATASET'][0]]['idx2words']
-    word2index_x = dataset.vocabulary[params['INPUTS_IDS_DATASET'][0]]['words2idx']
     unk_id = dataset.extra_words['<unk>']
 
     # Initialize counters
@@ -579,5 +577,6 @@ def interactive_simulation():
     fsrc.close()
     ftrans.close()
 
+
 if __name__ == "__main__":
     interactive_simulation()
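Both edits here address standard flake8 checks: the deleted `index2word_x`/`word2index_x` assignments were never read (pyflakes F841), and the added blank line gives top-level code the two blank lines pycodestyle expects after a function body (E305). A minimal self-contained sketch that reproduces both warnings:

    def interactive_simulation():
        index2word_x = {}  # F841: local variable 'index2word_x' is assigned to but never used
        return 0

    if __name__ == "__main__":  # E305: expected 2 blank lines after class or function definition, found 1
        interactive_simulation()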
5 changes: 1 addition & 4 deletions interactive_nmt_simulation.py
@@ -1,6 +1,5 @@
 # -*- coding: utf-8 -*-
 from __future__ import print_function
-import time
 import argparse
 import ast
 import codecs
@@ -271,8 +270,6 @@ def interactive_simulation():
     # Get word2index and index2word dictionaries
     index2word_y = dataset.vocabulary[params['OUTPUTS_IDS_DATASET'][0]]['idx2words']
     word2index_y = dataset.vocabulary[params['OUTPUTS_IDS_DATASET'][0]]['words2idx']
-    index2word_x = dataset.vocabulary[params['INPUTS_IDS_DATASET'][0]]['idx2words']
-    word2index_x = dataset.vocabulary[params['INPUTS_IDS_DATASET'][0]]['words2idx']
     unk_id = dataset.extra_words['<unk>']
 
     # Initialize counters
@@ -501,7 +498,6 @@ def interactive_simulation():
                     mouse_actions_sentence += 1
                 if checked_index_h - last_checked_index > 1:
                     mouse_actions_sentence += 1
-                last_correct_pos = checked_index_h
                 keystrokes_sentence += new_word_len
                 # Substitution
                 new_word_indices = [word2index_y.get(word, unk_id) for word in new_words]
@@ -702,5 +698,6 @@ def interactive_simulation():
     fsrc.close()
     ftrans.close()
 
+
 if __name__ == "__main__":
     interactive_simulation()
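This file collects three of the same cleanups: `import time` was unused (pyflakes F401, `'time' imported but unused`), and `index2word_x`, `word2index_x`, and `last_correct_pos` were assigned but never read (F841). The `last_correct_pos` deletion is the one worth double-checking, since F841 is a purely static check: it fires only when no later statement reads the name, as in this sketch:

    def flagged(checked_index_h):
        last_correct_pos = checked_index_h  # F841: assigned but never read afterwards
        return checked_index_h

    def clean(checked_index_h):
        last_correct_pos = checked_index_h  # read below, so no warning
        return last_correct_pos + 1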
1 change: 1 addition & 0 deletions main.py
@@ -36,6 +36,7 @@ def parse_args():
 
     return parser.parse_args()
 
+
 if __name__ == "__main__":
     args = parse_args()
     parameters = load_parameters()
4 changes: 2 additions & 2 deletions nmt_keras/online_models.py
@@ -170,8 +170,8 @@ def build_online_models(models, params):
             y_true = Input(name="y_true", batch_shape=tuple([None, None, None]))
             y_pred = nmt_model.model.outputs[0]
             inputs = [y_true, y_pred, hyp1, preds_h1, weight1, weight2]
-            losses = [Lambda(eval(loss), output_shape=(None,),
-                             name=loss, supports_masking=False)(inputs) for loss in params['LOSS']]
+            _ = [Lambda(eval(loss), output_shape=(None,),
+                        name=loss, supports_masking=False)(inputs) for loss in params['LOSS']]
 
             trainer_model = Model(inputs=nmt_model.model.inputs + [state_below_h1] + [y_true, weight1, weight2],
                                   outputs=loss_out)
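The comprehension is executed for its side effect: each `Lambda(...)(inputs)` call builds a graph node feeding `loss_out`, and the returned list was never read, so the `losses` binding triggered F841. Rebinding to `_` is the conventional Python signal for a deliberately discarded value. A minimal sketch of the idiom, with illustrative names:

    side_effects = []

    def register(x):
        side_effects.append(x)

    def build(xs):
        # before: created = [register(x) for x in xs]  -> F841, 'created' never read
        _ = [register(x) for x in xs]  # after: side effects kept, result discarded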
7 changes: 3 additions & 4 deletions nmt_keras/training.py
@@ -45,8 +45,7 @@ def train_model(params, load_dataset=None):
     else:
         logger.info('Updating dataset.')
         dataset = loadDataset(
-            params['DATASET_STORE_PATH'] + '/Dataset_' + params['DATASET_NAME'] + '_' + params['SRC_LAN'] +
-            params['TRG_LAN'] + '.pkl')
+            params['DATASET_STORE_PATH'] + '/Dataset_' + params['DATASET_NAME'] + '_' + params['SRC_LAN'] + params['TRG_LAN'] + '.pkl')
 
     epoch_offset = 0 if dataset.len_train == 0 else int(params['RELOAD'] * params['BATCH_SIZE'] / dataset.len_train)
     params['EPOCH_OFFSET'] = params['RELOAD'] if params['RELOAD_EPOCH'] else epoch_offset
@@ -110,8 +109,8 @@ def train_model(params, load_dataset=None):
     nmt_model.setParams(params)
     nmt_model.setOptimizer()
     if params.get('EPOCH_OFFSET') is None:
-        params['EPOCH_OFFSET'] = params['RELOAD'] if params['RELOAD_EPOCH'] else \
-            int(params['RELOAD'] * params['BATCH_SIZE'] / dataset.len_train)
+        params['EPOCH_OFFSET'] = \
+            params['RELOAD'] if params['RELOAD_EPOCH'] else int(params['RELOAD'] * params['BATCH_SIZE'] / dataset.len_train)
 
     # Store configuration as pkl
     dict2pkl(params, params['STORE_PATH'] + '/config')
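Both hunks in this file only move the line break; with flake8's `max-line-length` raised above the default 79 (a common project setting, assumed here), the longer joined lines pass E501. PEP 8 itself prefers implicit continuation inside parentheses over the backslash kept in the second hunk; a sketch of that alternative, using names from the surrounding code:

    # parenthesized continuation instead of a backslash (PEP 8's preference)
    params['EPOCH_OFFSET'] = (
        params['RELOAD'] if params['RELOAD_EPOCH']
        else int(params['RELOAD'] * params['BATCH_SIZE'] / dataset.len_train)
    )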
2 changes: 1 addition & 1 deletion utils/compute_cvr_centroid_corpus.py
@@ -5,7 +5,6 @@
 import numpy as np
 import os
 from keras_wrapper.extra.read_write import file2list, numpy2file
-from nltk import ngrams
 
 logging.basicConfig(level=logging.DEBUG, format='[%(asctime)s] %(message)s', datefmt='%d/%m/%Y %H:%M:%S')
 logger = logging.getLogger(__name__)
@@ -54,6 +53,7 @@ def compute_centroid(args):
     print ("Finished. Storing n-gram counts in %s " % args.dest)
     numpy2file(args.dest, centroid)
 
+
 if __name__ == "__main__":
     args = parse_args()
     assert args.sentence_mode in ['average', 'sum'], 'Unknown sentence-mode: "%s"' % args.sentence_mode
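To verify cleanups like the unused `from nltk import ngrams` import (F401) without leaving Python, flake8 exposes a documented legacy API; a minimal sketch, where the `max_line_length` override is an assumption rather than this repository's actual configuration:

    # check one file programmatically; report.total_errors counts remaining violations
    from flake8.api import legacy as flake8

    style_guide = flake8.get_style_guide(max_line_length=120)  # assumed setting
    report = style_guide.check_files(['utils/compute_cvr_centroid_corpus.py'])
    print('violations:', report.total_errors)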
1 change: 1 addition & 0 deletions utils/compute_ngrams_corpus.py
@@ -41,6 +41,7 @@ def extract_n_grams(args):
     print ("Finished. Storing n-gram counts in %s " % args.dest)
     dict2pkl(n_gram_counts, args.dest)
 
+
 if __name__ == "__main__":
     args = parse_args()
     extract_n_grams(args)