forked from jfsantos/seq2seq
-
Notifications
You must be signed in to change notification settings - Fork 0
/
GRU.lua
51 lines (39 loc) · 1.38 KB
/
GRU.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
require 'nn';
require 'nngraph';
require 'Recurrent';
require 'LinearZeroBias';
local GRU, parent = torch.class('nn.GRU','nn.Recurrent')

--- Gated Recurrent Unit cell, expressed as an nngraph module and handed to
-- nn.Recurrent.
--
-- The graph maps {x, prev_h} to the next hidden state h:
--   z  = sigmoid(P_z([prev_h, x]))
--   r  = sigmoid(P_r([prev_h, x]))
--   h~ = tanh(P_h([r .* prev_h, x]))
--   h  = (1 - z) .* prev_h + z .* h~
-- where [a, b] is nn.JoinTable(1,1) concatenation and each P_* is an
-- nn.LinearZeroBias(diminput + dimoutput, dimoutput) projection.
-- NOTE(review): LinearZeroBias is a project module; presumably a Linear
-- layer whose bias is held at zero -- confirm in LinearZeroBias.lua.
--
-- Selected graph nodes are kept on the instance (x, prev_h, z, zh_, v1, v2, h)
-- along with the graph itself (gru).
--
-- @param diminput  size of the input x (required)
-- @param dimoutput size of the hidden state / output (required)
function GRU:__init(diminput,dimoutput)
   assert(diminput ~= nil, "diminput must be specified")
   assert(dimoutput ~= nil, "dimoutput must be specified")
   self.diminput = diminput
   self.dimoutput = dimoutput

   -- Width of the [hidden, input] concatenation that every projection reads.
   local joinedWidth = diminput + dimoutput

   -- Apply a fresh projection (joinedWidth -> dimoutput) to `node`, then
   -- the already-constructed nonlinearity `squash`. Projections are created
   -- in the same order as before: update gate, reset gate, candidate.
   local function project(squash, node)
      return squash(nn.LinearZeroBias(joinedWidth, dimoutput)(node))
   end

   -- Graph inputs, in the order expected by gModule: {x, prev_h}.
   local input = nn.Identity()()
   local hidden = nn.Identity()()
   self.x = input
   self.prev_h = hidden

   -- Gates computed from [prev_h, x].
   local concat = nn.JoinTable(1,1)({hidden, input})
   local update = project(nn.Sigmoid(), concat)   -- z
   local reset = project(nn.Sigmoid(), concat)    -- r

   -- Candidate state computed from [r .* prev_h, x].
   local resetHidden = nn.CMulTable()({reset, hidden})
   local candidateIn = nn.JoinTable(1,1)({resetHidden, input})
   local candidate = project(nn.Tanh(), candidateIn)   -- h~

   -- Interpolate between the old state and the candidate.
   local takeNew = nn.CMulTable()({update, candidate})        -- z .* h~
   local negUpdate = nn.MulConstant(-1)(update)
   local keepRatio = nn.AddConstant(1)(negUpdate)              -- 1 - z
   local takeOld = nn.CMulTable()({keepRatio, hidden})        -- (1 - z) .* prev_h
   local nextHidden = nn.CAddTable()({takeOld, takeNew})      -- h

   self.z = update
   self.zh_ = takeNew
   self.v1 = keepRatio
   self.v2 = takeOld
   self.h = nextHidden

   local graph = nn.gModule({input, hidden}, {nextHidden})
   self.gru = graph
   parent.__init(self, graph, dimoutput)
end
--- Backward pass for the input gradient, delegated to nn.Recurrent.
-- The wrapped graph has a single output (h). When gradOutput is a table,
-- only its first element is forwarded. Any other value is forwarded unchanged.
-- @param input      same input that was given to forward
-- @param gradOutput gradient w.r.t. the output, or a table whose first entry
--                   is that gradient
-- @return whatever nn.Recurrent.updateGradInput returns
function GRU:updateGradInput(input,gradOutput)
   local grad = gradOutput
   if type(grad) == 'table' then
      grad = grad[1]
   end
   return parent.updateGradInput(self, input, grad)
end