Skip to content

Commit

Permalink
Merge pull request #14 from dribnet/model_dir
Browse files Browse the repository at this point in the history
add --model_dir option to scripts
  • Loading branch information
hardmaru committed Feb 23, 2017
2 parents 3d3114f + 3d8307e commit 140c69b
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
6 changes: 4 additions & 2 deletions sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,19 @@
help='number of strokes to sample')
parser.add_argument('--scale_factor', type=int, default=10,
help='factor to scale down by for svg output. smaller means bigger output')
parser.add_argument('--model_dir', type=str, default='save',
help='directory to save model to')
sample_args = parser.parse_args()

with open(os.path.join('save', 'config.pkl'), 'rb') as f:
with open(os.path.join(sample_args.model_dir, 'config.pkl'), 'rb') as f:
saved_args = pickle.load(f)

model = Model(saved_args, True)
sess = tf.InteractiveSession()
#saver = tf.train.Saver(tf.all_variables())
saver = tf.train.Saver()

ckpt = tf.train.get_checkpoint_state('save')
ckpt = tf.train.get_checkpoint_state(sample_args.model_dir)
print("loading model: ", ckpt.model_checkpoint_path)

saver.restore(sess, ckpt.model_checkpoint_path)
Expand Down
9 changes: 7 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ def main():
help='number of epochs')
parser.add_argument('--save_every', type=int, default=500,
help='save frequency')
parser.add_argument('--model_dir', type=str, default='save',
help='directory to save model to')
parser.add_argument('--grad_clip', type=float, default=10.,
help='clip gradients at this value')
parser.add_argument('--learning_rate', type=float, default=0.005,
Expand All @@ -43,7 +45,10 @@ def main():
def train(args):
data_loader = DataLoader(args.batch_size, args.seq_length, args.data_scale)

with open(os.path.join('save', 'config.pkl'), 'wb') as f:
if args.model_dir != '' and not os.path.exists(args.model_dir):
os.makedirs(args.model_dir)

with open(os.path.join(args.model_dir, 'config.pkl'), 'wb') as f:
pickle.dump(args, f)

model = Model(args)
Expand Down Expand Up @@ -72,7 +77,7 @@ def train(args):
e,
train_loss, valid_loss, end - start))
if (e * data_loader.num_batches + b) % args.save_every == 0 and ((e * data_loader.num_batches + b) > 0):
checkpoint_path = os.path.join('save', 'model.ckpt')
checkpoint_path = os.path.join(args.model_dir, 'model.ckpt')
saver.save(sess, checkpoint_path, global_step = e * data_loader.num_batches + b)
print("model saved to {}".format(checkpoint_path))

Expand Down

0 comments on commit 140c69b

Please sign in to comment.