In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable

In [3]:
class RNN(nn.Module):
    """Simple tanh recurrent network unrolled along the first axis of its input.

    Every step computes ``h_next = tanh(W2(h_prev) + W1(x_step))``. The
    recurrence starts from ``h0``, a learnable vector initialised to zeros.
    """

    def __init__(self, size):
        """Build the two ``size -> size`` linear maps and the initial state.

        Args:
            size: number of features of both the inputs and the hidden state.
        """
        super(RNN, self).__init__()
        # W1 acts on the input and carries the bias; W2 acts on the previous
        # hidden state and has no bias.
        self.W1 = nn.Linear(size, size)
        self.W2 = nn.Linear(size, size, bias=False)
        self.h0 = nn.Parameter(torch.zeros(size), requires_grad=True)

    def forward(self, xs):
        """Run the recurrence over ``xs`` one slice at a time.

        Args:
            xs: input indexed along its first axis; ``xs[i]`` is fed to ``W1``.

        Returns:
            A list of ``xs.size(0) + 1`` hidden states whose first entry is
            ``h0`` itself and whose last entry is the final state.
        """
        states = [self.h0]
        for step in range(xs.size(0)):
            pre_activation = self.W2(states[-1]) + self.W1(xs[step])
            states.append(torch.tanh(pre_activation))
        return states

In [6]:
# Unroll a 5-unit RNN over a constant all-ones sequence of T steps.
rnn = RNN(5)
T = 100
x = Variable(torch.ones(T, 5), requires_grad=True)
hs = rnn(x)

# Backpropagate from the final hidden state only, then compare how much
# gradient reaches the last input step versus the very first one.
loss = hs[-1].sum()
loss.backward()
print(x.grad[T - 1])
print(x.grad[0])

Variable containing:
 0.6130
 0.5908
-0.1896
-0.5053
-0.0959
[torch.FloatTensor of size 5]

Variable containing:
1.00000e-44 *
  1.5414
  1.6816
 -0.5605
 -1.9618
  0.1401
[torch.FloatTensor of size 5]



In [96]:
class Gated(nn.Module):
    """Recurrent network whose state is a gated blend of old state and candidate.

    Every step computes a candidate ``h = tanh(W1(h_prev) + W2(x_step))`` and
    then ``h_next = (1 - t) * h_prev + t * h``, where ``t`` is the gate.
    """

    def __init__(self, size, gate=0.1):
        """Build the candidate layers, the gate layers and the initial state.

        Args:
            size: number of features of both the inputs and the hidden state.
            gate: constant mixing weight ``t`` applied at every step. The
                default of 0.1 is the value that used to be hard-coded.
                Pass ``None`` to compute the gate from the data instead, as
                ``sigmoid(gate_W1(h_prev) + gate_W2(x_step))``.
        """
        super(Gated, self).__init__()
        self.W1 = nn.Linear(size, size)
        self.W2 = nn.Linear(size, size, bias=False)
        self.activation = nn.Tanh()

        # The gate layers are always created so the set of parameters does not
        # depend on `gate`; they are only used when `gate` is None.
        self.gate_W1 = nn.Linear(size, size)
        self.gate_W2 = nn.Linear(size, size, bias=False)
        self.gate_activation = nn.Sigmoid()
        self.gate = gate

        # torch.Tensor(size) would allocate uninitialised memory, so the
        # starting state could hold arbitrary values; start from zeros.
        self.h0 = nn.Parameter(torch.zeros(size))

    def forward(self, xs):
        """Run the gated recurrence over ``xs`` one slice at a time.

        Args:
            xs: input indexed along its first axis; ``xs[i]`` is fed to ``W2``.

        Returns:
            A list of ``xs.size(0) + 1`` hidden states whose first entry is
            ``h0`` itself and whose last entry is the final state.
        """
        hs = [self.h0]
        for i in range(xs.size(0)):
            h = self.activation(self.W1(hs[i]) + self.W2(xs[i]))
            if self.gate is None:
                t = self.gate_activation(self.gate_W1(hs[i]) + self.gate_W2(xs[i]))
            else:
                t = self.gate
            hs.append((1 - t) * hs[i] + t * h)
        return hs

In [103]:
# Repeat the gradient experiment with the gated network on a constant
# all-ones sequence of 100 steps.
rnn = Gated(5)
a = Variable(torch.ones(100, 5), requires_grad=True)
hs = rnn(a)

# Backpropagate from the final hidden state only, then compare how much
# gradient reaches the last input step versus the very first one.
loss = hs[-1].sum()
loss.backward()
print(a.grad[a.size(0) - 1])
print(a.grad[0])

Variable containing:
1.00000e-02 *
  1.8279
 -3.2543
  1.1446
 -4.7897
  4.0270
[torch.FloatTensor of size 5]

Variable containing:
1.00000e-05 *
 -0.4706
 -0.3314
 -0.6020
 -1.5595
  0.0797
[torch.FloatTensor of size 5]

