Skip to content

Commit

Permalink
Fix bug: all variables should be passed to the function.
Browse files Browse the repository at this point in the history
  • Loading branch information
taineleau committed Sep 19, 2016
1 parent eb38963 commit 223dc19
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions CIFAR10/cifar10_L2.lua
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ local function L2_norm_create(params, initHyper)
end


local function L2_norm(params)
local function L2_norm(params, params_l2)
-- print(params_l2, params[1])
local penalty = torch.sum(torch.cmul(params[1], params_l2[1]))
-- local penalty = torch.sum(params[1])
Expand Down Expand Up @@ -114,9 +114,10 @@ local function init(iter)
local Lossf = grad.nn.MSECriterion()

-- Training objective for one sample: MSE loss plus an L2 penalty on the
-- model weights.
-- params: table with fields
--   p1 — L2 hyper-parameter tensors (weight of the penalty per parameter)
--   p2 — model weight tensors fed to modelf
-- x, y: input sample and target.
-- Returns the scalar loss value.
local function fTrain(params, x, y)
   -- NOTE(review): debug prints from the original commit; remove once the
   -- parameter packing (p1/p2) is confirmed correct.
   print(params.p1)
   print(params.p2)
   local prediction = modelf(params.p2, x)
   -- Both the weights and the L2 hyper-params must be passed: calling
   -- L2_norm with a single argument would index a nil params_l2 inside it.
   -- (The stale one-argument call left over from the diff was removed.)
   local penalty = L2_norm(params.p2, params.p1)
   return Lossf(prediction, y) + penalty
end

Expand Down Expand Up @@ -176,7 +177,12 @@ local function train_meta(iter)
return inputs:cuda(), t_:cuda()
end

local grads, loss = dfTrain(params, makesample(inputs, targets))
local p = {
p1 = params_l2,
p2 = params
}

local grads, loss = dfTrain(p, makesample(inputs, targets))

for i = 1, #grads do
params[i] = params[i] + opt.learningRate * grads[i]
Expand Down

0 comments on commit 223dc19

Please sign in to comment.