add option to normalize gradients
jcjohnson committed Sep 9, 2015
1 parent 7ca5aa4 commit 0c5d5d5
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions neural_style.lua
@@ -22,6 +22,7 @@ cmd:option('-content_weight', 5e0)
 cmd:option('-style_weight', 1e2)
 cmd:option('-tv_weight', 1e-3)
 cmd:option('-num_iterations', 1000)
+cmd:option('-normalize_gradients', false)
 cmd:option('-init', 'random', 'random|image')
 
 -- Output options
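
The new flag is declared through torch.CmdLine like the existing options, so passing the bare -normalize_gradients switch on the command line flips the boolean from its false default to true. A minimal sketch of that behavior, assuming the script's usual cmd:parse(arg) call (not part of this commit):

-- Sketch of how the boolean flag is parsed; assumes the standard
-- torch.CmdLine setup used by neural_style.lua (not part of this commit).
require 'torch'
local cmd = torch.CmdLine()
cmd:option('-normalize_gradients', false)
local params = cmd:parse(arg)
-- `th neural_style.lua -normalize_gradients` yields params.normalize_gradients == true;
-- omitting the switch leaves it false.
print(params.normalize_gradients)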
@@ -104,7 +105,8 @@ local function main(params)
      end
      if i == content_layers[next_content_idx] then
        local target = net:forward(content_image_caffe):clone()
-       local loss_module = nn.ContentLoss(params.content_weight, target):float()
+       local norm = params.normalize_gradients
+       local loss_module = nn.ContentLoss(params.content_weight, target, norm):float()
        if params.gpu >= 0 then
          loss_module:cuda()
        end
@@ -121,7 +123,8 @@
        local target = gram:forward(target_features)
        target:div(target_features:nElement())
        local weight = params.style_weight * style_layer_weights[next_style_idx]
-       local loss_module = nn.StyleLoss(weight, target):float()
+       local norm = params.normalize_gradients
+       local loss_module = nn.StyleLoss(weight, target, norm):float()
        if params.gpu >= 0 then
          loss_module:cuda()
        end
@@ -254,10 +257,11 @@ end
 -- Define an nn Module to compute content loss in-place
 local ContentLoss, parent = torch.class('nn.ContentLoss', 'nn.Module')
 
-function ContentLoss:__init(strength, target)
+function ContentLoss:__init(strength, target, normalize)
   parent.__init(self)
   self.strength = strength
   self.target = target
+  self.normalize = normalize or false
   self.loss = 0
   self.crit = nn.MSECriterion()
 end
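
Because the constructor guards the new argument with `normalize or false`, existing two-argument calls keep working unchanged. A quick hypothetical sanity check, assuming the ContentLoss class above has already been evaluated in the session (not part of this commit):

-- Hypothetical sanity check; assumes nn is loaded and the ContentLoss
-- class defined above has been evaluated (not part of this commit).
require 'nn'
local target = torch.rand(3, 8, 8)
local with_norm = nn.ContentLoss(5e0, target, true)
print(with_norm.normalize)                  -- true
local legacy = nn.ContentLoss(5e0, target)  -- old two-argument call
print(legacy.normalize)                     -- false: nil falls through `normalize or false`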
@@ -276,6 +280,9 @@ function ContentLoss:updateGradInput(input, gradOutput)
   if input:nElement() == self.target:nElement() then
     self.gradInput = self.crit:backward(input, self.target)
   end
+  if self.normalize then
+    self.gradInput:div(torch.norm(self.gradInput, 1) + 1e-8)
+  end
   self.gradInput:mul(self.strength)
   self.gradInput:add(gradOutput)
   return self.gradInput
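
Dividing by the L1 norm rescales the gradient so its absolute values sum to roughly one before the strength multiplier is applied, which decouples the effective step size from the raw magnitude of the layer activations; the 1e-8 term only guards against division by zero. A standalone numeric illustration (not part of the commit):

-- Standalone illustration of the normalization step (not part of the commit).
require 'torch'
local g = torch.Tensor{2, -4, 6, -8}
print(torch.norm(g, 1))         -- 20, the sum of absolute values
g:div(torch.norm(g, 1) + 1e-8)  -- in-place division, matching the module code
print(g)                        -- 0.1 -0.2 0.3 -0.4; abs values now sum to 1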
@@ -298,8 +305,9 @@ end
 -- Define an nn Module to compute style loss in-place
 local StyleLoss, parent = torch.class('nn.StyleLoss', 'nn.Module')
 
-function StyleLoss:__init(strength, target)
+function StyleLoss:__init(strength, target, normalize)
   parent.__init(self)
+  self.normalize = normalize or false
   self.strength = strength
   self.target = target
   self.loss = 0
@@ -322,6 +330,9 @@ function StyleLoss:updateGradInput(input, gradOutput)
   local dG = self.crit:backward(self.G, self.target)
   dG:div(input:nElement())
   self.gradInput = self.gram:backward(input, dG)
+  if self.normalize then
+    self.gradInput:div(torch.norm(self.gradInput, 1) + 1e-8)
+  end
   self.gradInput:mul(self.strength)
   self.gradInput:add(gradOutput)
   return self.gradInput
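
For reference, the option is passed like any other neural_style.lua flag. The image options in this hypothetical invocation come from the repository's README rather than from this diff:

th neural_style.lua -content_image <content.jpg> -style_image <style.jpg> -normalize_gradients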
