Commit

Added finit arg to RNN to initialize forget bias, default = ones.
denizyuret committed Aug 18, 2019
1 parent 3d9d262 commit 9e87f4d
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions src/rnn.jl
@@ -140,6 +140,7 @@ function RNN(inputSize, hiddenSize;
             seed=0, # seed=0 for random init, positive integer for replicability
             winit=xavier,
             binit=zeros,
+            finit=ones, # forget bias for lstm
             usegpu=(gpu()>=0),
             )
    w = dx = dhx = dcx = nothing
@@ -158,7 +159,7 @@
    dropoutDesc = usegpu ? DD(handle=handle,dropout=dropout,seed=seed) : nothing # Need to keep dropoutDesc in RNN so it does not get gc'ed.
    rnnDesc = usegpu ? RD(hiddenSize,numLayers,dropoutDesc,inputMode,direction,mode,algo,dataType) : nothing
    r = RNN(w,h,c,inputSize,hiddenSize,numLayers,dropout,seed,inputMode,direction,mode,algo,dataType,rnnDesc,dropoutDesc,dx,dhx,dcx)
-   r.w = Param(Array{dataType}(undef,1,1,getRNNParamsSize(r)))
+   r.w = Array{dataType}(undef,1,1,getRNNParamsSize(r))
    for a in rnnparams(r; handle=handle, useview=true)
        if a == nothing
            continue
@@ -170,8 +171,14 @@
            error("Invalid RNN param $(summary(a))")
        end
    end
+   if rnnType == :lstm # separate initialization for lstm forget biases
+       for layer in 1:(numLayers*(bidirectional ? 2 : 1)), id in (2,6), param in (2,)
+           a = rnnparam(r, layer, id, param, useview=true, handle=handle)
+           copyto!(a, finit(dataType, size(a)))
+       end
+   end
    # many copyto! ops to gpu is expensive (~20s), so we init on cpu and copy it over once
-   if usegpu; r.w = Param(KnetArray(value(r.w))); end
+   r.w = Param(usegpu ? KnetArray(r.w) : r.w)
    return r
 end
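For context on the new keyword: setting LSTM forget-gate biases to one at initialization is a common heuristic that keeps the forget gate open early in training so gradients can flow through the cell state (see e.g. Jozefowicz et al., 2015). A minimal usage sketch, assuming this commit is applied and Knet's RNN constructor with its rnnType keyword; the variable names are illustrative only:

    using Knet
    r  = RNN(128, 256; rnnType=:lstm)               # forget biases now default to ones
    r0 = RNN(128, 256; rnnType=:lstm, finit=zeros)  # restore the previous all-zeros behavior
    r2 = RNN(128, 256; rnnType=:lstm, finit=(T,d)->fill(T(2),d))  # custom constant bias of 2

Since finit is invoked as finit(dataType, size(a)), any function that takes an element type and a size tuple and returns an array of that shape will work.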

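A note on the indices in the new loop: rnnparam(r, layer, id, param) follows cuDNN's LSTM parameter layout, in which each layer stores eight weight/bias pairs ordered input, forget, new-memory, and output gate, first for the input-side matrices (ids 1-4) and then for the recurrent-side matrices (ids 5-8); param=1 selects the weight matrix and param=2 the bias. Ids 2 and 6 with param=2 therefore pick out exactly the two forget-gate bias vectors of each layer (and of each direction, hence the numLayers*2 bound for bidirectional networks). A hypothetical check on a freshly constructed LSTM r as above:

    bf_input = rnnparam(r, 1, 2, 2)  # layer 1 forget bias, input side
    bf_recur = rnnparam(r, 1, 6, 2)  # layer 1 forget bias, recurrent side
    @assert all(x -> x == 1, Array(bf_input)) && all(x -> x == 1, Array(bf_recur))

The other change in this commit, deferring the Param (and KnetArray) wrapping of r.w until after all writes, keeps initialization on the CPU so only one copy to the GPU is needed, per the comment about many copyto! ops being expensive.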
