In [1]:
%matplotlib inline


PyTorch: Control Flow + Weight Sharing
--------------------------------------

To showcase the power of `PyTorch dynamic graphs`, 

we will implement a very strange model: 

a fully-connected ReLU network that on each forward pass randomly chooses
a number between 1 and 4 

and has that many hidden layers, 

reusing the same weights multiple times to compute the innermost hidden layers.



In [3]:
import random
import torch


class DynamicNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        """
        In the constructor we construct three nn.Linear instances that we will use
        in the forward pass.
        """
        super(DynamicNet, self).__init__()
        self.input_linear  = torch.nn.Linear(D_in, H)
        self.middle_linear = torch.nn.Linear(H, H)
        self.output_linear = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        For the forward pass of the model, 
        we randomly choose either 0, 1, 2, or 3
        and reuse the middle_linear Module that many times to compute hidden layer
        representations.

        Since each forward pass builds a dynamic computation graph, we can use normal
        Python control-flow operators like loops or conditional statements when
        defining the forward pass of the model.

        Here we also see that it is perfectly safe to reuse the same Module many
        times when defining a computational graph. This is a big improvement from Lua
        Torch, where each Module could be used only once.
        """
        h_relu = self.input_linear(x).clamp(min=0)
        # Multiple Middle_Linear 0 times, or 1, or 3 times
        for _ in range(random.randint(0, 3)):
            h_relu = self.middle_linear(h_relu).clamp(min=0)
        y_pred = self.output_linear(h_relu)
        return y_pred

In [4]:
# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)



In [5]:

# Construct our model by instantiating the class defined above
model = DynamicNet(D_in, H, D_out)

# Construct our loss function and an Optimizer. Training this strange model with
# vanilla stochastic gradient descent is tough, so we use momentum

criterion = torch.nn.MSELoss() # reduction='sum'
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
for t in range(500):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x)

    # Compute and print loss
    loss = criterion(y_pred, y)
    print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 1.0531384944915771
1 1.0986714363098145
2 1.053126335144043
3 1.0520378351211548
4 1.0520329475402832
5 1.0618231296539307
6 1.053080439567566
7 1.0530635118484497
8 1.0979418754577637
9 1.0977392196655273
10 1.051985263824463
11 1.0616248846054077
12 1.0615708827972412
13 1.0614978075027466
14 1.052924394607544
15 1.0529009103775024
16 1.0528767108917236
17 1.0958902835845947
18 1.051896333694458
19 1.0518838167190552
20 1.0951356887817383
21 1.0518563985824585
22 1.0608599185943604
23 1.094226360321045
24 1.0938503742218018
25 1.0933845043182373
26 1.0517877340316772
27 1.0923449993133545
28 1.0517594814300537
29 1.0912578105926514
30 1.0603508949279785
31 1.0525166988372803
32 1.0517010688781738
33 1.0892071723937988
34 1.0524446964263916
35 1.0600279569625854
36 1.0877869129180908
37 1.051628828048706
38 1.0523440837860107
39 1.0523184537887573
40 1.051585078239441
41 1.0515708923339844
42 1.051555871963501
43 1.0522117614746094
44 1.0848308801651
45 1.0515111684799194
46 1.08415

382 1.0343787670135498
383 1.0343071222305298
384 1.034220576286316
385 1.0444053411483765
386 1.0443860292434692
387 1.033950686454773
388 1.044343113899231
389 1.0443191528320312
390 1.0336850881576538
391 1.044269323348999
392 1.0442438125610352
393 1.0334243774414062
394 0.9847644567489624
395 1.044164776802063
396 0.9843646883964539
397 1.0441110134124756
398 1.0469725131988525
399 1.0440572500228882
400 1.0440301895141602
401 1.0440013408660889
402 1.0328171253204346
403 0.982917308807373
404 1.0439143180847168
405 1.0326379537582397
406 1.0325672626495361
407 1.0324804782867432
408 1.0323821306228638
409 1.0468568801879883
410 1.04376220703125
411 1.0437383651733398
412 0.9814281463623047
413 0.9812291264533997
414 1.0318472385406494
415 1.0436410903930664
416 0.980462908744812
417 0.9801542162895203
418 0.9797749519348145
419 1.0314886569976807
420 1.04674232006073
421 1.0435093641281128
422 0.9782465696334839
423 1.0434656143188477
424 1.031169056892395
425 0.9771747589111328
