In [24]:
import torch
import torch.nn as nn
from torch.autograd import Variable

torch.manual_seed(7)

<torch._C.Generator at 0x113459450>

In [25]:
class ForwardNet(nn.Module):
    def __init__(self):
        super(ForwardNet, self).__init__()
        self.fc1 = nn.Linear(10, 10)
        self.fc2 = nn.Linear(10, 10)
        self.fc3 = nn.Linear(10, 10)
        self.fc4 = nn.Linear(10, 10)
        self.layers = nn.ModuleList([self.fc1, self.fc2, self.fc3, self.fc4])

    def forward(self, x):
        self.output = []
        self.input = []
        for layer in self.layers:
            # detach from previous history
            x = Variable(x.data, requires_grad=True)
            self.input.append(x)
            # compute output
            x = layer(x)
            # add to list of outputs
            self.output.append(x)
        return x

    def backward(self, g):
        for i, output in reversed(list(enumerate(self.output))):
            if i == (len(self.output) - 1):
                # for last node, use g
                output.backward(g)
            else:
                output.backward(self.input[i+1].grad.data)
            print(i, self.input[i].grad.data.sum())

In [26]:
model.eval()

ForwardNet(
  (fc1): Linear(in_features=10, out_features=10, bias=True)
  (fc2): Linear(in_features=10, out_features=10, bias=True)
  (fc3): Linear(in_features=10, out_features=10, bias=True)
  (fc4): Linear(in_features=10, out_features=10, bias=True)
  (layers): ModuleList(
    (0): Linear(in_features=10, out_features=10, bias=True)
    (1): Linear(in_features=10, out_features=10, bias=True)
    (2): Linear(in_features=10, out_features=10, bias=True)
    (3): Linear(in_features=10, out_features=10, bias=True)
  )
)

In [27]:
model = ForwardNet()
inp = Variable(torch.randn(4, 10))
output = model(inp)
gradients = torch.randn(*output.size())
model.backward(gradients)

3 tensor(0.6882)
2 tensor(-1.4815)
1 tensor(0.3181)
0 tensor(-0.3078)


In [28]:
class FeedbackNet(nn.Module):
    def __init__(self, batch_size):
        super(FeedbackNet, self).__init__()
        self.layers = nn.ModuleList([
            nn.Linear(10, 10),
            nn.Linear(10, 10),
            nn.Linear(10, 10),
            nn.Linear(10, 10),
        ])
        self.z = []
        for i, layer in list(enumerate(self.layers)):
            z = torch.ones(batch_size, 10)
            self.z.append(z)

    def forward(self, x):
        self.output = []
        self.input = []
        for i, layer in list(enumerate(self.layers)):
            # detach from previous history
            x = Variable(x.data, requires_grad=True)
            self.input.append(x)
            # compute output
            x = layer(x)
            # multiply by the hidden gate
            x = x * self.z[i]
            # add to list of outputs
            self.output.append(x)
        return x

    def backward(self, g):
        for i, output in reversed(list(enumerate(self.output))):
            if i == (len(self.output) - 1):
                # for last node, use g
                output.backward(g)
            else:
                output.backward(self.input[i+1].grad.data)
            alpha = self.input[i].grad
            self.z[i] = (alpha > 0).float()
            self.input[i].grad = self.z[i] * alpha
            print(i, self.input[i].grad.data.sum())

In [29]:
model.eval()

ForwardNet(
  (fc1): Linear(in_features=10, out_features=10, bias=True)
  (fc2): Linear(in_features=10, out_features=10, bias=True)
  (fc3): Linear(in_features=10, out_features=10, bias=True)
  (fc4): Linear(in_features=10, out_features=10, bias=True)
  (layers): ModuleList(
    (0): Linear(in_features=10, out_features=10, bias=True)
    (1): Linear(in_features=10, out_features=10, bias=True)
    (2): Linear(in_features=10, out_features=10, bias=True)
    (3): Linear(in_features=10, out_features=10, bias=True)
  )
)

In [31]:
model = FeedbackNet(4)
inp = Variable(torch.randn(4, 10))
gradients = torch.randn(*output.size())
for iter in range(5):
    output = model(inp)
    model.backward(gradients)

3 tensor(7.6773)
2 tensor(2.1005)
1 tensor(0.9308)
0 tensor(0.3725)
3 tensor(5.5707)
2 tensor(0.7880)
1 tensor(0.1020)
0 tensor(1.00000e-02 *
       4.6935)
3 tensor(6.3577)
2 tensor(1.3364)
1 tensor(0.2458)
0 tensor(1.00000e-02 *
       7.1084)
3 tensor(6.5020)
2 tensor(1.3364)
1 tensor(0.3185)
0 tensor(1.00000e-02 *
       6.4023)
3 tensor(6.5934)
2 tensor(1.0495)
1 tensor(0.2015)
0 tensor(1.00000e-02 *
       6.6785)
