Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change learning rate and logging criterion #408

Merged
merged 1 commit into from Nov 1, 2018
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 3 additions & 3 deletions chapter_computer-vision/neural-style.md
Expand Up @@ -198,7 +198,7 @@ def train(x, content_y, style_y, ctx, lr, max_epochs, lr_decay_epoch):
% (i, nd.add_n(*content_L).asscalar(),
nd.add_n(*style_L).asscalar(), tv_L.asscalar(),
time.time() - tic))
if i % lr_decay_epoch == 0:
if i % lr_decay_epoch == 0 and i != 0:
trainer.set_learning_rate(trainer.learning_rate * 0.1)
print('change lr to %.1e' % trainer.learning_rate)
return net()
Expand All @@ -213,7 +213,7 @@ content_x, content_y = get_contents(image_shape, ctx)
style_x, style_y = get_styles(image_shape, ctx)

x = content_x
y = train(x, content_y, style_y, ctx, 0.1, 500, 200)
y = train(x, content_y, style_y, ctx, 0.01, 500, 200)
```

因为使用了内容图像作为初始值,所以一开始内容误差远小于样式误差。随着迭代的进行样式误差迅速减少,最终它们值在相近的范围。下面我们将训练好的合成图像保存下来。
Expand All @@ -233,7 +233,7 @@ content_x, content_y = get_contents(image_shape, ctx)
style_x, style_y = get_styles(image_shape, ctx)

x = preprocess(postprocess(y) * 255, image_shape)
z = train(x, content_y, style_y, ctx, 0.1, 300, 100)
z = train(x, content_y, style_y, ctx, 0.01, 300, 100)

gb.plt.imsave('../img/neural-style-2.png', postprocess(z).asnumpy())
```
Expand Down