In [1]:
# -*- coding: utf-8 -*-
import random
import torch
from torch.autograd import Variable


class DynamicNet(torch.nn.Module):
    """Fully connected ReLU network whose depth is re-drawn on every call.

    The network owns three linear layers: ``D_in -> H``, ``H -> H`` and
    ``H -> D_out``. The ``H -> H`` layer is applied a random number of times
    (0 to 3, inclusive) on each forward pass, always with the same weights.
    """

    def __init__(self, D_in, H, D_out):
        """
        Build the three nn.Linear layers used by ``forward``.

        Args:
            D_in: size of the last dimension of the input.
            H: width of the hidden representation.
            D_out: size of the last dimension of the output.
        """
        super().__init__()
        self.input_linear = torch.nn.Linear(D_in, H)
        self.middle_linear = torch.nn.Linear(H, H)
        self.output_linear = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        Map ``x`` to a prediction through a randomly chosen number of layers.

        The input layer is applied once, then ``middle_linear`` is applied
        0, 1, 2 or 3 times (drawn with ``random.randint`` on every call), each
        application followed by a clamp at zero (ReLU). The output layer is
        applied last, without a non-linearity.

        Because the computation graph is rebuilt on every call, ordinary
        Python control flow can decide its shape, and the same Module can be
        applied several times within a single graph.
        """
        # Draw the number of extra hidden layers for this call only.
        remaining = random.randint(0, 3)

        hidden = torch.clamp(self.input_linear(x), min=0)
        while remaining:
            # Same weights on every repetition: the module is simply reused.
            hidden = torch.clamp(self.middle_linear(hidden), min=0)
            remaining -= 1
        return self.output_linear(hidden)


# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and targets. Since PyTorch 0.4 plain
# Tensors take part in autograd directly, so no Variable wrapper is needed;
# requires_grad defaults to False, which is what we want for data.
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Construct our model by instantiating the class defined above
model = DynamicNet(D_in, H, D_out)

# Construct our loss function and an Optimizer. Training this strange model with
# vanilla stochastic gradient descent is tough, so we use momentum.
# reduction='sum' replaces the deprecated size_average=False: the squared
# errors are summed rather than averaged.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
for t in range(500):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x)

    # Compute and print loss. The loss is a 0-dim tensor, so it must be read
    # with .item(); indexing it with [0] raises on current PyTorch releases.
    loss = criterion(y_pred, y)
    print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 617.4964599609375
1 586.7207641601562
2 594.462890625
3 584.2737426757812
4 582.7604370117188
5 448.11956787109375
6 581.6450805664062
7 560.4078369140625
8 550.8650512695312
9 326.7133483886719
10 526.0846557617188
11 271.10186767578125
12 497.96246337890625
13 480.2035827636719
14 574.2371826171875
15 171.4656982421875
16 562.2449340820312
17 129.00836181640625
18 567.35595703125
19 91.0340347290039
20 560.6006469726562
21 64.0407485961914
22 520.9205322265625
23 545.0962524414062
24 494.73541259765625
25 474.9066467285156
26 449.0616760253906
27 258.177490234375
28 471.19976806640625
29 351.294189453125
30 410.5562744140625
31 369.94610595703125
32 247.9542694091797
33 232.91111755371094
34 220.1400146484375
35 180.5821533203125
36 195.04095458984375
37 170.70648193359375
38 147.42210388183594
39 130.0604705810547
40 129.25
41 175.88394165039062
42 159.11351013183594
43 110.16151428222656
44 90.27572631835938
45 85.23445129394531
46 140.77098083496094
47 80.3070068359375
48 299.40

375 1.3810803890228271
376 0.38162529468536377
377 1.035822868347168
378 1.039865255355835
379 1.013710379600525
380 0.24925124645233154
381 0.9570943713188171
382 0.2069171965122223
383 1.1371022462844849
384 0.781284749507904
385 0.5542237162590027
386 0.22487430274486542
387 0.7474405765533447
388 0.6208946108818054
389 0.19531965255737305
390 0.5539401769638062
391 0.226301908493042
392 0.40292954444885254
393 1.694028615951538
394 1.4330745935440063
395 0.14397381246089935
396 0.5758671760559082
397 0.14175888895988464
398 0.12704035639762878
399 0.10257995873689651
400 2.051517963409424
401 1.0105328559875488
402 1.0537744760513306
403 1.2365745306015015
404 1.4436901807785034
405 0.42905786633491516
406 0.8277648091316223
407 0.1363527923822403
408 0.9801706671714783
409 0.9845675230026245
410 0.17707951366901398
411 0.6263932585716248
412 0.10676196217536926
413 1.6918052434921265
414 0.08687221258878708
415 0.9948474168777466
416 0.8633851408958435
417 0.8391584753990173
418 0