In [1]:
%matplotlib inline


PyTorch: nn
-----------

A fully-connected ReLU network with one hidden layer, trained to predict y from x
by minimizing squared Euclidean distance.

This implementation uses the nn package from PyTorch to build the network.
PyTorch autograd makes it easy to define computational graphs and take gradients,
but raw autograd can be a bit too low-level for defining complex neural networks;
this is where the nn package can help. The nn package defines a set of Modules,
which you can think of as a neural network layer that has produces output from
input and may have some trainable weights.



In [2]:
import torch
from torch.autograd import Variable

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs, and wrap them in Variables.
x = Variable(torch.randn(N, D_in))
y = Variable(torch.randn(N, D_out), requires_grad=False)

# Use the nn package to define our model as a sequence of layers. nn.Sequential
# is a Module which contains other Modules, and applies them in sequence to
# produce its output. Each Linear Module computes output from input using a
# linear function, and holds internal Variables for its weight and bias.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)

# The nn package also contains definitions of popular loss functions; in this
# case we will use Mean Squared Error (MSE) as our loss function.
loss_fn = torch.nn.MSELoss(size_average=False)

learning_rate = 1e-4
for t in range(500):
    # Forward pass: compute predicted y by passing x to the model. Module objects
    # override the __call__ operator so you can call them like functions. When
    # doing so you pass a Variable of input data to the Module and it produces
    # a Variable of output data.
    y_pred = model(x)

    # Compute and print loss. We pass Variables containing the predicted and true
    # values of y, and the loss function returns a Variable containing the
    # loss.
    loss = loss_fn(y_pred, y)
    print(t, loss.data[0])

    # Zero the gradients before running the backward pass.
    model.zero_grad()

    # Backward pass: compute gradient of the loss with respect to all the learnable
    # parameters of the model. Internally, the parameters of each Module are stored
    # in Variables with requires_grad=True, so this call will compute gradients for
    # all learnable parameters in the model.
    loss.backward()

    # Update the weights using gradient descent. Each parameter is a Variable, so
    # we can access its data and gradients like we did before.
    for param in model.parameters():
        param.data -= learning_rate * param.grad.data

0 672.3374633789062
1 622.8316040039062
2 579.5445556640625
3 541.4376831054688
4 507.4725646972656
5 477.0326232910156
6 449.6767883300781
7 424.53094482421875
8 401.20355224609375
9 379.6190490722656
10 359.6780700683594
11 340.9647521972656
12 323.36737060546875
13 306.7540283203125
14 290.9610900878906
15 275.8394775390625
16 261.40625
17 247.704345703125
18 234.6118621826172
19 222.13772583007812
20 210.2431640625
21 198.9040985107422
22 188.09007263183594
23 177.8012237548828
24 167.99868774414062
25 158.65406799316406
26 149.75830078125
27 141.28176879882812
28 133.23516845703125
29 125.59796142578125
30 118.3651351928711
31 111.5326156616211
32 105.07860565185547
33 99.00236511230469
34 93.26808166503906
35 87.85501861572266
36 82.76030731201172
37 77.95792388916016
38 73.43431854248047
39 69.17705535888672
40 65.17559051513672
41 61.42422103881836
42 57.9016227722168
43 54.58761978149414
44 51.47725296020508
45 48.551536560058594
46 45.80130386352539
47 43.2195930480957
48 40.

404 0.00019783942843787372
405 0.0001927333214553073
406 0.00018776686920318753
407 0.00018293235916644335
408 0.0001782204199116677
409 0.0001736285339575261
410 0.00016915662854444236
411 0.00016480610065627843
412 0.00016056756430771202
413 0.00015644283848814666
414 0.0001524184044683352
415 0.0001485020329710096
416 0.0001446915848646313
417 0.00014097595703788102
418 0.00013735622633248568
419 0.0001338325091637671
420 0.00013040473277214915
421 0.0001270573993679136
422 0.00012380363477859646
423 0.00012063657777616754
424 0.00011754750448744744
425 0.00011454063496785238
426 0.00011161617294419557
427 0.00010876158194150776
428 0.00010597630171105266
429 0.00010327166819479316
430 0.00010063287481898442
431 9.80604236247018e-05
432 9.555993165122345e-05
433 9.311562462244183e-05
434 9.074102126760408e-05
435 8.842781971907243e-05
436 8.617312414571643e-05
437 8.39749991428107e-05
438 8.18328044260852e-05
439 7.974783511599526e-05
440 7.771500531816855e-05
441 7.5736636063084e-0