In [1]:
%matplotlib inline


PyTorch: Control Flow + Weight Sharing
--------------------------------------

To showcase the power of PyTorch dynamic graphs, we will implement a very strange
model: a fully-connected ReLU network that, on each forward pass, randomly chooses
an integer between 1 and 4 and uses that many hidden layers. The first hidden
layer has its own weights; each of the remaining 0 to 3 hidden layers reuses one
shared set of weights.



In [2]:
import random
import torch


class DynamicNet(torch.nn.Module):
    """Fully-connected ReLU network whose depth is re-drawn on every forward pass."""

    def __init__(self, D_in, H, D_out):
        """
        Create the three nn.Linear layers used by ``forward``.

        Args:
            D_in: size of each input sample.
            H: size of every hidden representation.
            D_out: size of each output sample.
        """
        super().__init__()
        # Construction order is kept (input, middle, output) so that parameter
        # initialisation consumes the torch RNG in the same sequence.
        self.input_linear = torch.nn.Linear(D_in, H)
        self.middle_linear = torch.nn.Linear(H, H)
        self.output_linear = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        Map ``x`` to a prediction through a randomly chosen number of layers.

        The input layer is always applied. Then an integer drawn uniformly from
        {0, 1, 2, 3} decides how many times the single ``middle_linear`` module
        is applied in sequence, each time followed by a ReLU (a clamp at zero).
        Because it is the same module every time, those extra hidden layers all
        share one set of weights.

        The computation graph is rebuilt on each call, so ordinary Python
        control flow (here a ``while`` loop) is enough to express the varying
        depth.

        Returns:
            The output of ``output_linear`` applied to the last hidden
            representation.
        """
        hidden = torch.clamp(self.input_linear(x), min=0)

        # Exactly one draw per forward pass; randint is inclusive on both ends.
        num_shared_layers = random.randint(0, 3)
        applied = 0
        while applied < num_shared_layers:
            hidden = torch.clamp(self.middle_linear(hidden), min=0)
            applied += 1

        return self.output_linear(hidden)


# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Seed both sources of randomness before anything is created so that
# Restart & Run All reproduces the same run (on a given PyTorch version and
# platform): `torch` drives the random data and the weight initialisation,
# `random` drives the per-step depth chosen inside DynamicNet.forward.
# Any previously saved output was produced without a seed; re-run the cell to
# refresh it.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Construct our model by instantiating the class defined above
model = DynamicNet(D_in, H, D_out)

# Construct our loss function and an Optimizer. Training this strange model with
# vanilla stochastic gradient descent is tough, so we use momentum
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
for t in range(500):
    # Forward pass: Compute predicted y by passing x to the model.
    # Each call draws a new depth, so consecutive losses come from networks
    # of different depth and the printed trace is expected to be jumpy.
    y_pred = model(x)

    # Compute and print loss
    loss = criterion(y_pred, y)
    print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 694.7438354492188
1 640.669677734375
2 642.233642578125
3 564.689697265625
4 639.2024536132812
5 456.2197265625
6 636.8248291015625
7 354.3739013671875
8 305.3263244628906
9 636.8751831054688
10 636.2811279296875
11 629.2474975585938
12 627.80029296875
13 633.27587890625
14 631.6547241210938
15 590.4172973632812
16 627.1930541992188
17 607.0570678710938
18 599.1259765625
19 123.79695892333984
20 110.33484649658203
21 567.9360961914062
22 457.57025146484375
23 596.3828125
24 72.873046875
25 508.3536071777344
26 486.9886779785156
27 545.8179931640625
28 520.1388549804688
29 486.2393493652344
30 78.65131378173828
31 326.7291259765625
32 298.5826110839844
33 70.68470001220703
34 320.8783264160156
35 301.7637634277344
36 217.37481689453125
37 197.7802734375
38 229.95204162597656
39 194.2117462158203
40 235.8970489501953
41 221.7454833984375
42 87.56327056884766
43 130.15658569335938
44 130.7443084716797
45 216.62374877929688
46 113.77384948730469
47 166.87298583984375
48 170.1840057373047

418 1.3045860528945923
419 0.49829035997390747
420 0.7573560476303101
421 0.201497420668602
422 1.0189800262451172
423 0.15743224322795868
424 1.685482382774353
425 0.6143415570259094
426 0.4768216013908386
427 0.4995113015174866
428 1.089442491531372
429 1.4085839986801147
430 0.4712297022342682
431 0.6377395391464233
432 1.7197670936584473
433 1.3959394693374634
434 0.3086584508419037
435 0.4414238929748535
436 0.1919858306646347
437 0.5208927392959595
438 0.45591598749160767
439 1.41886305809021
440 0.6342697739601135
441 0.42917391657829285
442 0.5206900238990784
443 0.822360634803772
444 0.2855653762817383
445 1.9233498573303223
446 0.47684141993522644
447 0.7856736779212952
448 0.7662190198898315
449 1.0224645137786865
450 0.4116176664829254
451 0.34033897519111633
452 0.6427425742149353
453 1.2764562368392944
454 0.46870389580726624
455 0.8088362812995911
456 0.1690785139799118
457 1.1954450607299805
458 0.551815927028656
459 0.5066419839859009
460 0.9917063117027283
461 0.70450