Skip to content

Commit

Permalink
Finish replacing initialisations
Browse files Browse the repository at this point in the history
  • Loading branch information
Kaixhin authored and gcr committed Jan 15, 2016
1 parent f16f517 commit b738d66
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 16 deletions.
10 changes: 7 additions & 3 deletions residual-layers.lua
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,16 @@ function addResidualLayer2(input, nChannels, nOutChannels, stride)
-- The first layer does the downsampling and the striding
local net = cudnn.SpatialConvolution(nChannels, nOutChannels,
3,3, stride,stride, 1,1)
:init('weight', nninit.kaiming, {gain = 'relu'})(input)
net = cudnn.SpatialBatchNormalization(nOutChannels)(net)
:init('weight', nninit.kaiming, {gain = 'relu'})
:init('bias', nninit.constant, 0)(input)
net = cudnn.SpatialBatchNormalization(nOutChannels)
:init('weight', nninit.normal, 1.0, 0.002)
:init('bias', nninit.constant, 0)(net)
net = cudnn.ReLU(true)(net)
net = cudnn.SpatialConvolution(nOutChannels, nOutChannels,
3,3, 1,1, 1,1)
:init('weight', nninit.kaiming, {gain = 'relu'})(net)
:init('weight', nninit.kaiming, {gain = 'relu'})
:init('bias', nninit.constant, 0)(net)
-- Should we put Batch Normalization here? I think not, because
-- BN would force the output to have unit variance, which breaks the residual
-- property of the network.
Expand Down
13 changes: 0 additions & 13 deletions train-cifar.lua
Original file line number Diff line number Diff line change
Expand Up @@ -82,19 +82,6 @@ if opt.loadFrom == "" then
model = nn.gModule({input}, {model})
model:cuda()
--print(#model:forward(torch.randn(100, 3, 32,32):cuda()))

-- Weight initialisation pass. The callback handed to model:apply looks at
-- the torch.type name of the module it receives and re-initialises its
-- parameters according to which kind of layer the name indicates.
model:apply(function(m)
  local kind = torch.type(m)

  if kind:find('Convolution') then
    -- Zero-mean normal weights with std = sqrt(2 / (nInputPlane * kW * kH)),
    -- and a zeroed bias.
    local fanIn = m.nInputPlane * m.kW * m.kH
    m.weight:normal(0.0, math.sqrt(2 / fanIn))
    m.bias:fill(0)
    return
  end

  if kind:find('BatchNormalization') then
    -- Either parameter may be absent on this module, so each is guarded:
    -- scale is drawn tightly around 1, shift is zeroed.
    if m.weight then m.weight:normal(1.0, 0.002) end
    if m.bias then m.bias:fill(0) end
  end
end)

else
print("Loading model from "..opt.loadFrom)
cutorch.setDevice(1)
Expand Down

0 comments on commit b738d66

Please sign in to comment.