In [None]:
%matplotlib inline


PyTorch: Control Flow + Weight Sharing
--------------------------------------

To showcase the power of PyTorch dynamic graphs, we will implement a very strange
model: a fully-connected ReLU network that on each forward pass randomly chooses
a number between 1 and 4 and has that many hidden layers, reusing the same
weights multiple times to compute the innermost hidden layers.



In [1]:
import random
import torch


class DynamicNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        """
        In the constructor we construct three nn.Linear instances that we will use
        in the forward pass.
        """
        super(DynamicNet, self).__init__()
        self.input_linear = torch.nn.Linear(D_in, H)
        self.middle_linear = torch.nn.Linear(H, H)
        self.output_linear = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        For the forward pass of the model, we randomly choose either 0, 1, 2, or 3
        and reuse the middle_linear Module that many times to compute hidden layer
        representations.

        Since each forward pass builds a dynamic computation graph, we can use normal
        Python control-flow operators like loops or conditional statements when
        defining the forward pass of the model.

        Here we also see that it is perfectly safe to reuse the same Module many
        times when defining a computational graph. This is a big improvement from Lua
        Torch, where each Module could be used only once.
        """
        h_relu = self.input_linear(x).clamp(min=0)
        for _ in range(random.randint(0, 3)):
            h_relu = self.middle_linear(h_relu).clamp(min=0)
        y_pred = self.output_linear(h_relu)
        return y_pred


# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Construct our model by instantiating the class defined above
model = DynamicNet(D_in, H, D_out)

# Construct our loss function and an Optimizer. Training this strange model with
# vanilla stochastic gradient descent is tough, so we use momentum
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
for t in range(500):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x)

    # Compute and print loss
    loss = criterion(y_pred, y)
    print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 601.462158203125
1 599.4974365234375
2 599.0643920898438
3 599.2340087890625
4 591.6246337890625
5 623.9635009765625
6 596.884521484375
7 546.104248046875
8 553.7937622070312
9 591.5114135742188
10 403.1564636230469
11 588.6131591796875
12 525.2078857421875
13 290.3096008300781
14 253.80894470214844
15 213.43417358398438
16 489.55194091796875
17 473.6482238769531
18 118.17691040039062
19 428.1708068847656
20 588.7614135742188
21 371.7599182128906
22 535.6546630859375
23 306.4186096191406
24 270.7079772949219
25 549.8812255859375
26 440.6129150390625
27 180.75926208496094
28 477.8786315917969
29 154.19459533691406
30 408.09002685546875
31 366.59735107421875
32 321.2332763671875
33 222.98228454589844
34 216.57080078125
35 550.6864624023438
36 164.7527313232422
37 277.43017578125
38 95.72271728515625
39 95.51760864257812
40 243.2257080078125
41 171.6375732421875
42 237.1130828857422
43 213.39817810058594
44 104.47303009033203
45 133.98678588867188
46 130.688232421875
47 101.033309936523

378 1.5393837690353394
379 1.1100828647613525
380 0.42510223388671875
381 0.31866565346717834
382 0.9195457696914673
383 0.8368210792541504
384 0.7337167859077454
385 1.7325667142868042
386 1.450669765472412
387 0.8411374688148499
388 1.581304907798767
389 0.8536913990974426
390 0.4170280992984772
391 0.7576215863227844
392 0.33012837171554565
393 0.8963790535926819
394 0.27647364139556885
395 3.16617488861084
396 0.7681730389595032
397 0.9552274346351624
398 0.4090581238269806
399 4.112884998321533
400 1.6990383863449097
401 0.27506792545318604
402 0.4065823554992676
403 11.44956111907959
404 1.0284479856491089
405 0.242122620344162
406 3.8965373039245605
407 0.4545843303203583
408 5.350459575653076
409 0.2139841765165329
410 0.2713792026042938
411 2.1561696529388428
412 1.8721356391906738
413 2.49051570892334
414 1.9384654760360718
415 0.7226355671882629
416 1.0024816989898682
417 4.382483959197998
418 1.0362496376037598
419 1.2084519863128662
420 1.6997195482254028
421 1.21067631244