In [1]:
%matplotlib inline


PyTorch: Control Flow + Weight Sharing
--------------------------------------

To showcase the power of PyTorch dynamic graphs, we will implement a very strange
model: a fully-connected ReLU network that on each forward pass randomly chooses
a number between 1 and 4 and has that many hidden layers, reusing the same
weights multiple times to compute the innermost hidden layers.



In [2]:
import random
import torch


class DynamicNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        """
        In the constructor we construct three nn.Linear instances that we will use
        in the forward pass.
        """
        super(DynamicNet, self).__init__()
        self.input_linear = torch.nn.Linear(D_in, H)
        self.middle_linear = torch.nn.Linear(H, H)
        self.output_linear = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        For the forward pass of the model, we randomly choose either 0, 1, 2, or 3
        and reuse the middle_linear Module that many times to compute hidden layer
        representations.

        Since each forward pass builds a dynamic computation graph, we can use normal
        Python control-flow operators like loops or conditional statements when
        defining the forward pass of the model.

        Here we also see that it is perfectly safe to reuse the same Module many
        times when defining a computational graph. This is a big improvement from Lua
        Torch, where each Module could be used only once.
        """
        h_relu = self.input_linear(x).clamp(min=0)
        for _ in range(random.randint(0, 3)):
            h_relu = self.middle_linear(h_relu).clamp(min=0)
        y_pred = self.output_linear(h_relu)
        return y_pred


# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Construct our model by instantiating the class defined above
model = DynamicNet(D_in, H, D_out)

# Construct our loss function and an Optimizer. Training this strange model with
# vanilla stochastic gradient descent is tough, so we use momentum
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
for t in range(500):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x)

    # Compute and print loss
    loss = criterion(y_pred, y)
    print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 600.1221923828125
1 601.9895629882812
2 599.015380859375
3 598.0311279296875
4 647.0963745117188
5 600.8802490234375
6 602.1266479492188
7 593.64892578125
8 592.5841674804688
9 591.44482421875
10 366.234619140625
11 589.305908203125
12 301.593017578125
13 571.1676025390625
14 237.90194702148438
15 585.5845336914062
16 591.4974975585938
17 589.412841796875
18 139.79920959472656
19 119.44674682617188
20 577.551025390625
21 577.5701293945312
22 564.8087768554688
23 63.78001403808594
24 569.930419921875
25 493.82391357421875
26 528.2279663085938
27 453.348388671875
28 548.6288452148438
29 538.7023315429688
30 362.1465759277344
31 510.0733947753906
32 421.2705993652344
33 125.29044342041016
34 368.3860168457031
35 338.86920166015625
36 123.6346664428711
37 362.0947265625
38 94.48916625976562
39 71.5792007446289
40 163.116943359375
41 32.9329719543457
42 204.5526123046875
43 22.57729721069336
44 23.000648498535156
45 225.56671142578125
46 111.49029541015625
47 29.153820037841797
48 30.9058

379 0.6808908581733704
380 1.297163724899292
381 0.7125924229621887
382 0.27360621094703674
383 0.6888591647148132
384 0.6604183912277222
385 0.6137325167655945
386 0.17890872061252594
387 0.15163713693618774
388 0.5889298319816589
389 0.7424818277359009
390 0.09165921807289124
391 0.9930313229560852
392 0.9496101140975952
393 0.520941436290741
394 0.10068042576313019
395 0.7033364176750183
396 0.6237083077430725
397 0.5395421385765076
398 0.5678409337997437
399 0.40789464116096497
400 0.3451990783214569
401 0.32091763615608215
402 0.28625354170799255
403 1.5556929111480713
404 0.6959434747695923
405 0.37532246112823486
406 0.3114350140094757
407 0.2841632664203644
408 1.1926965713500977
409 0.35058867931365967
410 0.14432823657989502
411 0.9482675790786743
412 0.12392844259738922
413 0.8012407422065735
414 0.11921138316392899
415 0.6766843199729919
416 0.6195272207260132
417 0.7797030210494995
418 0.8561292886734009
419 0.7854558229446411
420 0.5690597891807556
421 0.686340868473053
4