## Dynamic Graph and Weight Sharing in PyTorch
A fully connected ReLU network that on each forward pass chooses a random number between 1 and 4 and uses that many hidden layers, reusing the same weights multiple times to compute the innermost hidden layers.

In [22]:
import random
import torch
from torch.autograd import Variable
from torch import nn

In [45]:
class DynamicNet(nn.Module):
    """
    Fully connected ReLU network whose depth is re-sampled on every forward pass.

    The input layer and the output layer are each applied exactly once; in
    between, a single shared hidden-to-hidden layer is applied a random number
    of times, so the same weights are reused within one computation graph.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, max_middle_layers=3):
        """
        We construct three nn.Linear instances that we will use in the forward pass.

        Args:
            input_dim: number of features of each input sample.
            hidden_dim: width of every hidden layer.
            output_dim: number of features of each output sample.
            max_middle_layers: inclusive upper bound on how many times
                middle_linear is applied per forward pass. The default of 3
                gives between 1 and 4 hidden layers in total.

        Raises:
            ValueError: if max_middle_layers is negative.
        """
        super(DynamicNet, self).__init__()
        # Fail at construction time rather than on the first forward pass,
        # where random.randint would reject the empty range.
        if max_middle_layers < 0:
            raise ValueError(
                "max_middle_layers must be non-negative, got %r" % (max_middle_layers,)
            )
        self.max_middle_layers = max_middle_layers
        self.input_linear = nn.Linear(input_dim, hidden_dim)
        self.middle_linear = nn.Linear(hidden_dim, hidden_dim)
        self.output_linear = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """
        Randomly choose an integer between 0 and max_middle_layers (inclusive)
        and reuse the middle_linear Module that many times to compute hidden
        layer representations.

        Since each forward pass builds a dynamic computation graph, we can use normal
        Python control-flow operators like loops or conditional statements when defining
        the forward pass of the model.
        Here we also see that it is perfectly safe to reuse the same Module many times
        when defining a computational graph.

        Args:
            x: input passed to input_linear.

        Returns:
            The result of output_linear applied to the last hidden representation.
        """
        # clamp(min=0) is the ReLU non-linearity.
        h_relu = self.input_linear(x).clamp(min=0)
        # The depth is drawn from Python's global `random` generator on each call.
        for _ in range(random.randint(0, self.max_middle_layers)):
            h_relu = self.middle_linear(h_relu).clamp(min=0)
        y_pred = self.output_linear(h_relu)
        return y_pred

In [46]:
# Problem dimensions: number of samples per batch, then the feature sizes of
# the input, hidden and output layers.
batch_size, input_dim, hidden_dim, output_dim = 64, 1000, 100, 10

In [47]:
# Create random tensors to hold inputs and outputs.
# Plain tensors take part in autograd directly (the Variable wrapper is
# deprecated since PyTorch 0.4); neither needs gradients, which is the
# default for torch.randn.
x = torch.randn(batch_size, input_dim)
y = torch.randn(batch_size, output_dim)

In [48]:
# Build the network from the dimensions chosen above.
model = DynamicNet(
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    output_dim=output_dim,
)

In [49]:
# Define our loss function and optimizer.
# reduction='sum' adds up the squared errors over all elements instead of
# averaging them; it replaces the deprecated size_average=False argument.
criterion = torch.nn.MSELoss(reduction='sum')
# Plain SGD with momentum over every parameter of the model.
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)

In [50]:
# Run 500 optimisation steps on the single fixed batch (x, y).
for t in range(500):
    # Forward pass: compute predictions and the loss against the targets.
    y_pred = model(x)
    loss = criterion(y_pred, y)
    # The loss is a 0-dim tensor; .item() extracts it as a Python number.
    # (Indexing it with loss.data[0] raises IndexError on PyTorch >= 0.5.)
    print(t, loss.item())
    # Clear gradients left over from the previous step, backpropagate,
    # then update the parameters.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 753.000732421875
1 781.9813842773438
2 753.4469604492188
3 745.1839599609375
4 740.6723022460938
5 602.9732666015625
6 717.0496826171875
7 495.111328125
8 434.4676818847656
9 741.2548217773438
10 718.1215209960938
11 284.787841796875
12 244.49330139160156
13 657.5676879882812
14 169.89157104492188
15 137.93592834472656
16 688.9822387695312
17 94.38672637939453
18 83.29523468017578
19 73.72959899902344
20 716.45947265625
21 638.3557739257812
22 619.5986938476562
23 71.51226806640625
24 69.9638442993164
25 59.480987548828125
26 547.053955078125
27 490.8528137207031
28 450.5050048828125
29 39.074092864990234
30 590.2769165039062
31 429.1327819824219
32 520.2168579101562
33 470.17388916015625
34 338.3823547363281
35 51.85633850097656
36 320.9876403808594
37 51.585113525390625
38 374.6116943359375
39 48.96507263183594
40 493.4587097167969
41 41.06834030151367
42 69.48420715332031
43 78.45011138916016
44 608.5751953125
45 52.18977355957031
46 123.8389663696289
47 345.6250305175781
48 322.7

In [52]:
# Show the module structure, then the number of parameter tensors and the
# size of each of the first six.
print(model)
params = list(model.parameters())
print(len(params))
for idx in range(6):
    print(params[idx].size())

DynamicNet (
  (input_linear): Linear (1000 -> 100)
  (middle_linear): Linear (100 -> 100)
  (output_linear): Linear (100 -> 10)
)
6
torch.Size([100, 1000])
torch.Size([100])
torch.Size([100, 100])
torch.Size([100])
torch.Size([10, 100])
torch.Size([10])
