In [1]:
%matplotlib inline


PyTorch: Control Flow + Weight Sharing
--------------------------------------

To showcase the power of PyTorch's dynamic graphs, we will implement a very strange
model: a fully-connected ReLU network that, on each forward pass, randomly chooses
a number between 1 and 4 (inclusive) and uses that many hidden layers, reusing the
same weights multiple times to compute the innermost hidden layers.



In [2]:
import random
import torch


class DynamicNet(torch.nn.Module):
    """
    A fully-connected ReLU network whose depth is chosen at random on every
    forward pass.

    The input and output layers are always applied exactly once. Between them,
    the single shared ``middle_linear`` layer is applied a random number of
    times (between 0 and ``max_middle_layers``, inclusive), so every innermost
    hidden layer reuses the same weights.
    """

    def __init__(self, D_in, H, D_out, max_middle_layers=3):
        """
        In the constructor we construct three nn.Linear instances that we will use
        in the forward pass.

        Args:
            D_in: input feature dimension.
            H: hidden dimension, shared by every hidden layer.
            D_out: output feature dimension.
            max_middle_layers: inclusive upper bound on how many times the shared
                middle layer is applied per forward pass. The default of 3
                reproduces the original 0-3 range.

        Raises:
            ValueError: if ``max_middle_layers`` is negative.
        """
        super().__init__()
        if max_middle_layers < 0:
            raise ValueError(
                "max_middle_layers must be non-negative, got {!r}".format(
                    max_middle_layers
                )
            )
        # Plain int attribute: not a parameter, so it is not trained and does
        # not appear in model.parameters().
        self.max_middle_layers = max_middle_layers
        # Layer creation order is kept stable so weight initialization consumes
        # the torch RNG in the same order as before.
        self.input_linear = torch.nn.Linear(D_in, H)
        self.middle_linear = torch.nn.Linear(H, H)
        self.output_linear = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        For the forward pass of the model, we randomly choose a number between 0
        and ``self.max_middle_layers`` (inclusive; 0, 1, 2 or 3 by default) and
        reuse the middle_linear Module that many times to compute hidden layer
        representations.

        Since each forward pass builds a dynamic computation graph, we can use normal
        Python control-flow operators like loops or conditional statements when
        defining the forward pass of the model.

        Here we also see that it is perfectly safe to reuse the same Module many
        times when defining a computational graph. This is a big improvement over Lua
        Torch, where each Module could be used only once.
        """
        # clamp(min=0) is an elementwise ReLU.
        h_relu = self.input_linear(x).clamp(min=0)
        # Depth is drawn from the stdlib `random` module, not torch's RNG, so
        # seeding torch alone does not fix the sequence of depths.
        for _ in range(random.randint(0, self.max_middle_layers)):
            h_relu = self.middle_linear(h_relu).clamp(min=0)
        return self.output_linear(h_relu)


# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Seed both RNGs this example depends on so that Restart & Run All reproduces
# the same data, initial weights and sequence of random depths. `torch` drives
# randn and the nn.Linear weight initialization, while the stdlib `random`
# module picks the number of middle layers inside DynamicNet.forward.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

# Optimization settings.
N_STEPS = 500
LEARNING_RATE = 1e-4
MOMENTUM = 0.9

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Construct our model by instantiating the class defined above
model = DynamicNet(D_in, H, D_out)

# Construct our loss function and an Optimizer. Training this strange model with
# vanilla stochastic gradient descent is tough, so we use momentum.
# reduction='sum' sums the squared error over the whole batch (it is not averaged).
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)
for t in range(N_STEPS):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x)

    # Compute and print loss
    loss = criterion(y_pred, y)
    print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 708.1929931640625
1 704.5569458007812
2 705.789306640625
3 751.10986328125
4 698.5592041015625
5 671.241455078125
6 553.136474609375
7 700.7708740234375
8 700.2088623046875
9 646.3265380859375
10 635.2501831054688
11 336.34881591796875
12 601.8843383789062
13 696.9946899414062
14 257.27532958984375
15 226.91696166992188
16 692.5521850585938
17 677.6451416015625
18 487.1050109863281
19 129.30433654785156
20 113.45774841308594
21 633.6678466796875
22 661.2832641601562
23 77.07271575927734
24 572.71875
25 613.4192504882812
26 585.5654296875
27 317.561279296875
28 293.01025390625
29 259.6499328613281
30 385.6541748046875
31 350.7640380859375
32 181.02394104003906
33 275.06884765625
34 241.32525634765625
35 282.4497985839844
36 188.66250610351562
37 363.694091796875
38 191.25485229492188
39 248.79104614257812
40 164.8061981201172
41 148.4585723876953
42 122.61903381347656
43 107.18975830078125
44 154.9362335205078
45 169.9942626953125
46 99.12210845947266
47 64.58987426757812
48 119.76276

473 0.4922123849391937
474 2.62910795211792
475 1.049784541130066
476 1.331726312637329
477 2.648322820663452
478 2.49495005607605
479 1.3115017414093018
480 0.6890081763267517
481 1.9521290063858032
482 3.130441665649414
483 0.4474015235900879
484 0.3502076268196106
485 1.6983177661895752
486 4.962226390838623
487 0.7290500998497009
488 1.9024465084075928
489 2.539017677307129
490 1.253010869026184
491 0.8676019906997681
492 5.016350746154785
493 0.554508626461029
494 1.6464886665344238
495 0.3892876207828522
496 1.5063093900680542
497 1.5162404775619507
498 2.551485061645508
499 2.1192126274108887
