In [1]:
%matplotlib inline


PyTorch: Control Flow + Weight Sharing
--------------------------------------

To showcase the power of PyTorch's dynamic graphs, we will implement a very strange
model: a fully-connected ReLU network that on each forward pass randomly chooses
a number between 1 and 4 (inclusive) and has that many hidden layers. The first
hidden layer has its own weights; the remaining 0 to 3 innermost hidden layers
are all computed by reusing the same weights multiple times.



In [2]:
import random
import torch
from torch.autograd import Variable


class DynamicNet(torch.nn.Module):
    """
    Fully-connected ReLU network whose depth is re-sampled at random on every
    forward pass, with a single shared Linear layer reused for the extra layers.
    """

    def __init__(self, D_in, H, D_out):
        """
        Create the three nn.Linear layers used by forward(): D_in -> H on the way
        in, H -> H for the repeated middle step, and H -> D_out on the way out.
        """
        super().__init__()
        self.input_linear = torch.nn.Linear(D_in, H)
        self.middle_linear = torch.nn.Linear(H, H)
        self.output_linear = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        Run x through input_linear, then through middle_linear a random number of
        times (0, 1, 2, or 3, drawn fresh on every call), then through
        output_linear. A ReLU (clamp at zero) follows every layer except the last.

        The computation graph is rebuilt on each call, so ordinary Python control
        flow such as the loop below is all that is needed to vary the depth.

        Applying the same middle_linear Module several times within one graph is
        safe; every application shares that Module's weights.
        """
        hidden = torch.clamp(self.input_linear(x), min=0)

        # Draw how many extra hidden layers this particular pass will use.
        remaining = random.randint(0, 3)
        while remaining > 0:
            hidden = torch.clamp(self.middle_linear(hidden), min=0)
            remaining -= 1

        return self.output_linear(hidden)


# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Seed both random number generators so that Restart & Run All reproduces the
# same run: torch's RNG drives the random data and the weight initialisation,
# while Python's `random` drives the per-pass depth choice in DynamicNet.forward.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

# Create random Tensors to hold inputs and outputs. Since PyTorch 0.4 plain
# Tensors take part in autograd directly, so no Variable wrapper is needed;
# Tensors created this way do not require gradients by default.
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Construct our model by instantiating the class defined above
model = DynamicNet(D_in, H, D_out)

# Construct our loss function and an Optimizer. Training this strange model with
# vanilla stochastic gradient descent is tough, so we use momentum.
# reduction='sum' sums the squared errors over all elements; it replaces the
# deprecated size_average=False argument.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
for t in range(500):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x)

    # Compute and print loss. The loss is a 0-dim Tensor, so it cannot be
    # indexed with [0]; .item() extracts its value as a Python float.
    loss = criterion(y_pred, y)
    print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 682.3436889648438
1 638.6103515625
2 638.0169067382812
3 644.0669555664062
4 632.6254272460938
5 626.1472778320312
6 621.534423828125
7 459.9142761230469
8 612.0087890625
9 600.3358764648438
10 603.5641479492188
11 355.0858154296875
12 573.718505859375
13 591.21533203125
14 275.8436279296875
15 616.2935180664062
16 574.5927734375
17 508.6520080566406
18 606.9266357421875
19 161.63377380371094
20 534.8299560546875
21 586.9065551757812
22 114.10981750488281
23 564.3153686523438
24 469.3837585449219
25 86.1523208618164
26 336.72705078125
27 311.8858947753906
28 75.80751037597656
29 255.9938507080078
30 70.91107940673828
31 418.7335510253906
32 309.9340515136719
33 172.48699951171875
34 154.83633422851562
35 73.3981704711914
36 67.63911437988281
37 106.04136657714844
38 294.93475341796875
39 89.09300231933594
40 83.44585418701172
41 47.617244720458984
42 42.5150146484375
43 60.630455017089844
44 33.72962951660156
45 28.238750457763672
46 40.467002868652344
47 287.151123046875
48 307.2386

391 1.636651635169983
392 4.771994590759277
393 3.5596563816070557
394 5.888912200927734
395 0.688564121723175
396 1.3596407175064087
397 8.49829387664795
398 1.4878398180007935
399 0.9919367432594299
400 2.4244227409362793
401 0.6474348306655884
402 0.6501108407974243
403 2.761887788772583
404 1.8283811807632446
405 1.0153234004974365
406 1.2449439764022827
407 1.0114943981170654
408 2.080554246902466
409 0.5980648994445801
410 0.9889600276947021
411 2.4442126750946045
412 1.4858626127243042
413 1.868059515953064
414 2.3188750743865967
415 0.8402381539344788
416 1.586879014968872
417 1.7843743562698364
418 0.927699625492096
419 0.9928481578826904
420 1.3261089324951172
421 2.6886277198791504
422 0.7150315642356873
423 1.8665865659713745
424 3.0115714073181152
425 0.8044310808181763
426 1.0024043321609497
427 0.8530642986297607
428 5.0351972579956055
429 0.21429581940174103
430 3.80885648727417
431 1.0814388990402222
432 2.936898708343506
433 3.610461473464966
434 3.0358517169952393
43