In [None]:
%matplotlib inline


PyTorch: nn
-----------

A fully-connected ReLU network with one hidden layer, trained to predict y from x
by minimizing squared Euclidean distance.

This implementation uses the nn package from PyTorch to build the network.
PyTorch autograd makes it easy to define computational graphs and take gradients,
but raw autograd can be a bit too low-level for defining complex neural networks;
this is where the nn package can help. The nn package defines a set of Modules,
which you can think of as a neural network layer that has produces output from
input and may have some trainable weights.



In [1]:
import torch

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Use the nn package to define our model as a sequence of layers. nn.Sequential
# is a Module which contains other Modules, and applies them in sequence to
# produce its output. Each Linear Module computes output from input using a
# linear function, and holds internal Tensors for its weight and bias.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)

# The nn package also contains definitions of popular loss functions; in this
# case we will use Mean Squared Error (MSE) as our loss function.
loss_fn = torch.nn.MSELoss(size_average=False)

learning_rate = 1e-4
for t in range(500):
    # Forward pass: compute predicted y by passing x to the model. Module objects
    # override the __call__ operator so you can call them like functions. When
    # doing so you pass a Tensor of input data to the Module and it produces
    # a Tensor of output data.
    y_pred = model(x)

    # Compute and print loss. We pass Tensors containing the predicted and true
    # values of y, and the loss function returns a Tensor containing the
    # loss.
    loss = loss_fn(y_pred, y)
    print(t, loss.item())

    # Zero the gradients before running the backward pass.
    model.zero_grad()

    # Backward pass: compute gradient of the loss with respect to all the learnable
    # parameters of the model. Internally, the parameters of each Module are stored
    # in Tensors with requires_grad=True, so this call will compute gradients for
    # all learnable parameters in the model.
    loss.backward()

    # Update the weights using gradient descent. Each parameter is a Tensor, so
    # we can access and gradients like we did before.
    with torch.no_grad():
        for param in model.parameters():
            param -= learning_rate * param.grad

0 624.5420532226562
1 575.4607543945312
2 533.3724365234375
3 496.6181640625
4 464.06512451171875
5 434.9016418457031
6 408.6252746582031
7 384.82550048828125
8 363.0043029785156
9 342.8042907714844
10 323.99005126953125
11 306.28839111328125
12 289.5701599121094
13 273.80474853515625
14 258.9190979003906
15 244.8496551513672
16 231.4934539794922
17 218.7944793701172
18 206.7321014404297
19 195.296630859375
20 184.4940185546875
21 174.24571228027344
22 164.484375
23 155.18238830566406
24 146.3510284423828
25 137.985595703125
26 130.05169677734375
27 122.52167510986328
28 115.41114807128906
29 108.67439270019531
30 102.31623840332031
31 96.31143188476562
32 90.64936065673828
33 85.31594848632812
34 80.28216552734375
35 75.53770446777344
36 71.06278228759766
37 66.85110473632812
38 62.89144515991211
39 59.16945266723633
40 55.66108322143555
41 52.36299514770508
42 49.24822235107422
43 46.31705856323242
44 43.561988830566406
45 40.976131439208984
46 38.55004119873047
47 36.2721061706543
4

429 1.7277012375416234e-05
430 1.6778238205006346e-05
431 1.629449070605915e-05
432 1.58248221850954e-05
433 1.536920535727404e-05
434 1.492761839472223e-05
435 1.4496726180368569e-05
436 1.4079943866818212e-05
437 1.3676625712832902e-05
438 1.3280951861815993e-05
439 1.290038380830083e-05
440 1.252920992556028e-05
441 1.2169395631644875e-05
442 1.182022606371902e-05
443 1.1479996828711592e-05
444 1.1150746104249265e-05
445 1.0829498023667838e-05
446 1.051957588060759e-05
447 1.0217261660727672e-05
448 9.92474815575406e-06
449 9.639484233048279e-06
450 9.364402103528846e-06
451 9.095394489122555e-06
452 8.834705113258678e-06
453 8.581064321333542e-06
454 8.335506208823062e-06
455 8.09655921329977e-06
456 7.865864063205663e-06
457 7.639318937435746e-06
458 7.421425380016444e-06
459 7.209064733615378e-06
460 7.001757239777362e-06
461 6.801632935093949e-06
462 6.606752322113607e-06
463 6.418799785024021e-06
464 6.234369720914401e-06
465 6.056281563360244e-06
466 5.883396170247579e-06
467 