In [1]:
%matplotlib inline


PyTorch: nn
-----------

A fully-connected ReLU network with one hidden layer, trained to predict y from x
by minimizing squared Euclidean distance.

This implementation uses the nn package from PyTorch to build the network.
PyTorch autograd makes it easy to define computational graphs and take gradients,
but raw autograd can be a bit too low-level for defining complex neural networks;
this is where the nn package can help. The nn package defines a set of Modules,
each of which you can think of as a neural network layer that produces output
from input and may have some trainable weights.



In [3]:
import torch

# Seed PyTorch's global random number generator before anything random is
# created. Both the random data below and the initial weights of the Linear
# layers are drawn from it, so without a seed every "Restart & Run All" would
# print a different loss curve.
torch.manual_seed(0)

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Use the nn package to define our model as a sequence of layers. nn.Sequential
# is a Module which contains other Modules, and applies them in sequence to
# produce its output. Each Linear Module computes output from input using a
# linear function, and holds internal Tensors for its weight and bias.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)

# The nn package also contains definitions of popular loss functions; in this
# case we use MSELoss. Note that reduction='sum' makes it return the *sum* of
# the squared errors over all elements (the squared Euclidean distance between
# y_pred and y) rather than their mean.
loss_fn = torch.nn.MSELoss(reduction='sum')

learning_rate = 1e-4
for t in range(500):
    # Forward pass: compute predicted y by passing x to the model. Module objects
    # override the __call__ operator so you can call them like functions. When
    # doing so you pass a Tensor of input data to the Module and it produces
    # a Tensor of output data.
    y_pred = model(x)

    # Compute and print loss. We pass Tensors containing the predicted and true
    # values of y, and the loss function returns a Tensor containing the
    # loss; .item() extracts it as a plain Python number for printing.
    loss = loss_fn(y_pred, y)
    print(t, loss.item())

    # Zero the gradients before running the backward pass, so that gradients
    # from the previous iteration are not accumulated into this one.
    model.zero_grad()

    # Backward pass: compute gradient of the loss with respect to all the learnable
    # parameters of the model. Internally, the parameters of each Module are stored
    # in Tensors with requires_grad=True, so this call will compute gradients for
    # all learnable parameters in the model.
    loss.backward()

    # Update the weights using gradient descent. Each parameter is a Tensor, so
    # we can access its gradients like we did before. The update is wrapped in
    # torch.no_grad() so the in-place modification of the parameters is not
    # itself tracked by autograd.
    with torch.no_grad():
        for param in model.parameters():
            param -= learning_rate * param.grad

0 675.6571655273438
1 626.1762084960938
2 582.8814697265625
3 544.5223999023438
4 510.0816650390625
5 479.0755310058594
6 450.67352294921875
7 424.5599670410156
8 400.3929443359375
9 377.9338684082031
10 356.85772705078125
11 337.0431823730469
12 318.3936767578125
13 300.80218505859375
14 284.1739196777344
15 268.319091796875
16 253.26461791992188
17 239.0445556640625
18 225.5977020263672
19 212.82461547851562
20 200.67532348632812
21 189.16921997070312
22 178.23272705078125
23 167.82298278808594
24 157.95590209960938
25 148.55653381347656
26 139.68324279785156
27 131.3024444580078
28 123.38134002685547
29 115.89131164550781
30 108.81800079345703
31 102.15617370605469
32 95.86717987060547
33 89.94039916992188
34 84.37422180175781
35 79.11396026611328
36 74.16951751708984
37 69.53067016601562
38 65.18250274658203
39 61.098915100097656
40 57.276912689208984
41 53.68800735473633
42 50.33268356323242
43 47.18873596191406
44 44.246116638183594
45 41.4908447265625
46 38.91278076171875
47 36.

371 4.4308566430117935e-05
372 4.2899988329736516e-05
373 4.153487680014223e-05
374 4.0217582863988355e-05
375 3.893955363309942e-05
376 3.7703997804783285e-05
377 3.650517828646116e-05
378 3.534924690029584e-05
379 3.422874942771159e-05
380 3.314372224849649e-05
381 3.2092251785798e-05
382 3.107808515778743e-05
383 3.0093480745563284e-05
384 2.9140646802261472e-05
385 2.82179971691221e-05
386 2.7327059797244146e-05
387 2.6464687834959477e-05
388 2.5628165531088598e-05
389 2.4820219550747424e-05
390 2.403554753982462e-05
391 2.327842048543971e-05
392 2.2543181330547668e-05
393 2.1830892364960164e-05
394 2.1143603589734994e-05
395 2.0478251826716587e-05
396 1.98340458155144e-05
397 1.9208928279113024e-05
398 1.8605061995913275e-05
399 1.801879989216104e-05
400 1.7452461179345846e-05
401 1.6902946299524046e-05
402 1.6370702724088915e-05
403 1.585843892826233e-05
404 1.5360748875536956e-05
405 1.4879036825732328e-05
406 1.4410356925509404e-05
407 1.39588719321182e-05
408 1.351994251308497