-
Notifications
You must be signed in to change notification settings - Fork 1
/
fix_ckpoints.py
28 lines (24 loc) · 950 Bytes
/
fix_ckpoints.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# Checkpoint that this script reads its tensors from.
OLD_CHECKPOINT_FILE = "model/train/model.ckpt-2000000"
# Checkpoint that this script writes; a different path from the input, so
# the input checkpoint is not overwritten.
NEW_CHECKPOINT_FILE = "model/train/fixed_model.ckpt-2000000"
import tensorflow as tf
# Checkpoint variable names to rewrite: each key is a name as stored in the
# old checkpoint, each value is the name to store that tensor under in the
# new checkpoint. Variables not listed here keep their names.
#
# Earlier mapping, kept for reference (presumably for checkpoints that still
# carry the older BasicLSTMCell/Linear names -- confirm before re-enabling):
# vars_to_rename = {
#     "lstm/BasicLSTMCell/Linear/Matrix": "lstm/basic_lstm_cell/weights",
#     "lstm/BasicLSTMCell/Linear/Bias": "lstm/basic_lstm_cell/biases",
# }
vars_to_rename = {
    "lstm/basic_lstm_cell/weights": "lstm/basic_lstm_cell/kernel",
    "lstm/basic_lstm_cell/biases": "lstm/basic_lstm_cell/bias",
}
# Rewrite the checkpoint: read every tensor from OLD_CHECKPOINT_FILE, wrap it
# in a fresh variable, and save all of them to NEW_CHECKPOINT_FILE under
# their (possibly renamed) checkpoint names.
new_checkpoint_vars = {}
reader = tf.train.NewCheckpointReader(OLD_CHECKPOINT_FILE)
for old_name in reader.get_variable_to_shape_map():
    # Names missing from vars_to_rename are carried over unchanged.
    new_name = vars_to_rename.get(old_name, old_name)
    if new_name in new_checkpoint_vars:
        # Two source variables would end up under one output name and one
        # of them would be silently dropped; stop instead of writing a
        # lossy checkpoint.
        raise ValueError(
            "checkpoint variable name collision on %r (while mapping %r)"
            % (new_name, old_name)
        )
    # The variable's initial value is the tensor read from the old
    # checkpoint, so running the initializer below reproduces the data.
    new_checkpoint_vars[new_name] = tf.Variable(reader.get_tensor(old_name))

# Built after the loop so that it covers every variable created above.
init = tf.global_variables_initializer()
# Passing a dict makes the Saver store each variable under its dict key
# rather than under the auto-generated variable name.
saver = tf.train.Saver(new_checkpoint_vars)

with tf.Session() as sess:
    sess.run(init)
    saver.save(sess, NEW_CHECKPOINT_FILE)