To showcase the power of PyTorch dynamic graphs, we will implement a very strange model: a fully-connected ReLU network that, on each forward pass, randomly chooses a number between 1 and 4 and uses that many hidden layers. The first hidden layer always comes from the input layer; the remaining 0 to 3 innermost hidden layers are computed by reusing the same middle-layer weights multiple times.

In [1]:
import random
import torch

In [4]:
class DynamicNet(torch.nn.Module):
    """Fully-connected ReLU network whose depth is re-drawn on every call.

    The network owns three linear layers: an input projection, a single
    hidden-to-hidden layer, and an output projection. On each forward pass
    the hidden-to-hidden layer is applied a random number of times (0 to 3),
    always with the same weights.
    """

    def __init__(self, D_in, H, D_out):
        """Build the three linear layers.

        Args:
            D_in: Number of input features.
            H: Number of hidden features.
            D_out: Number of output features.
        """
        super().__init__()
        self.input_linear = torch.nn.Linear(D_in, H)
        self.middle_linear = torch.nn.Linear(H, H)
        self.output_linear = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """Compute predictions for ``x`` using a randomly chosen depth.

        Because the computation graph is rebuilt on every forward pass,
        ordinary Python control flow (loops, conditionals) can decide the
        shape of the network at call time. Applying the same Module several
        times within one graph is safe; the repeated applications simply
        share one set of weights.

        Args:
            x: Input tensor accepted by ``input_linear``.

        Returns:
            The output of ``output_linear`` applied to the last hidden
            representation.
        """
        # torch.clamp(input, min, max) clamps every element of ``input`` into
        # the interval [min, max] and returns the result as a new tensor:
        #       | min, if x_i < min
        # y_i = | x_i, if min <= x_i <= max
        #       | max, if x_i > max
        # With only min=0 given, this acts as a ReLU.
        hidden = torch.clamp(self.input_linear(x), min=0)

        # Draw how many times (0, 1, 2 or 3) the shared middle layer is reused.
        extra_layers = random.randint(0, 3)
        for _ in range(extra_layers):
            hidden = torch.clamp(self.middle_linear(hidden), min=0)

        return self.output_linear(hidden)

In [5]:
# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Construct our model by instantiating the class defined above
model = DynamicNet(D_in, H, D_out)

# Construct our loss function and an Optimizer. Training this strange model with
# vanilla stochastic gradient descent is tough, so we use momentum.
# reduction='sum' adds up the squared errors instead of averaging them; it is
# the replacement for the deprecated size_average=False argument.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
for t in range(500):
    # Forward pass: Compute predicted y by passing x to the model.
    # The model picks a new random depth on every call, so the loss can
    # jump noticeably from one iteration to the next.
    y_pred = model(x)

    # Compute and print loss
    loss = criterion(y_pred, y)
    print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    # Gradients are cleared first so this step uses only the current
    # iteration's gradients.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 758.2650146484375
1 695.9061889648438
2 598.3555297851562
3 693.276611328125
4 684.016845703125
5 674.0927124023438
6 676.2586669921875
7 671.2991333007812
8 638.9191284179688
9 622.5925903320312
10 292.0539245605469
11 647.0540771484375
12 241.882080078125
13 671.4503784179688
14 169.3603973388672
15 134.16317749023438
16 104.68383026123047
17 654.9441528320312
18 79.7642593383789
19 575.865234375
20 559.1033325195312
21 534.1851196289062
22 89.49536895751953
23 572.39697265625
24 346.7315979003906
25 85.11773681640625
26 81.42572784423828
27 407.7875061035156
28 61.00379943847656
29 45.2071533203125
30 262.22796630859375
31 22.13348388671875
32 323.6727294921875
33 414.5010986328125
34 37.43296432495117
35 38.86087417602539
36 452.8277282714844
37 25.702402114868164
38 190.16111755371094
39 365.7681579589844
40 241.8499298095703
41 196.21559143066406
42 93.80511474609375
43 322.7909240722656
44 256.92578125
45 50.94523620605469
46 34.35666275024414
47 168.03700256347656
48 43.46895

395 15.373851776123047
396 1.321668028831482
397 5.129409313201904
398 13.80821704864502
399 8.840275764465332
400 3.8235158920288086
401 1.8559694290161133
402 10.840008735656738
403 0.5184151530265808
404 34.48851013183594
405 11.289118766784668
406 3.086751937866211
407 32.7764778137207
408 4.381278038024902
409 15.30043888092041
410 4.711716175079346
411 2.0136280059814453
412 5.85380220413208
413 7.872400283813477
414 15.625263214111328
415 6.362409591674805
416 2.6256206035614014
417 3.1110317707061768
418 5.753883361816406
419 10.0475492477417
420 4.173224449157715
421 3.4579968452453613
422 6.340757369995117
423 7.252662658691406
424 0.9920048713684082
425 9.579207420349121
426 2.619746208190918
427 2.473119020462036
428 7.991682529449463
429 4.940432548522949
430 3.9363062381744385
431 4.350082874298096
432 5.487261772155762
433 5.166940212249756
434 7.376772880554199
435 10.858755111694336
436 1.6926426887512207
437 2.700580358505249
438 1.1736829280853271
439 9.7647428512573