In [None]:
%matplotlib inline


PyTorch: nn
-----------

A fully-connected ReLU network with one hidden layer, trained to predict y from x
by minimizing squared Euclidean distance.

This implementation uses the nn package from PyTorch to build the network.
PyTorch autograd makes it easy to define computational graphs and take gradients,
but raw autograd can be a bit too low-level for defining complex neural networks;
this is where the nn package can help. The nn package defines a set of Modules,
each of which you can think of as a neural network layer that produces output
from input and may have some trainable weights.



In [1]:
import torch

# Seed PyTorch's global RNG before any random Tensor or layer is created, so
# that the random data, the initial weights, and therefore the printed losses
# are repeatable across Restart & Run All (on the same PyTorch build/device).
SEED = 0
torch.manual_seed(SEED)

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Use the nn package to define our model as a sequence of layers. nn.Sequential
# is a Module which contains other Modules, and applies them in sequence to
# produce its output. Each Linear Module computes output from input using a
# linear function, and holds internal Tensors for its weight and bias.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)

# The nn package also contains definitions of popular loss functions. We use
# MSELoss with reduction='sum', so the value returned is the *sum* of squared
# errors over every element of the batch (squared Euclidean distance), not the
# mean that the class name suggests.
loss_fn = torch.nn.MSELoss(reduction='sum')

# Step size for plain gradient descent and the number of update steps to take.
learning_rate = 1e-4
num_steps = 500
for t in range(num_steps):
    # Forward pass: compute predicted y by passing x to the model. Module objects
    # override the __call__ operator so you can call them like functions. When
    # doing so you pass a Tensor of input data to the Module and it produces
    # a Tensor of output data.
    y_pred = model(x)

    # Compute and print loss. We pass Tensors containing the predicted and true
    # values of y, and the loss function returns a Tensor containing the
    # loss; .item() extracts it as a plain Python number for printing.
    loss = loss_fn(y_pred, y)
    print(t, loss.item())

    # Zero the gradients before running the backward pass; otherwise backward()
    # would accumulate into the gradients left over from the previous step.
    model.zero_grad()

    # Backward pass: compute gradient of the loss with respect to all the learnable
    # parameters of the model. Internally, the parameters of each Module are stored
    # in Tensors with requires_grad=True, so this call will compute gradients for
    # all learnable parameters in the model.
    loss.backward()

    # Update the weights using gradient descent. Each parameter is a Tensor, so
    # we can read its gradient from .grad. The update runs under torch.no_grad()
    # so that the in-place modification is not itself tracked by autograd.
    with torch.no_grad():
        for param in model.parameters():
            param -= learning_rate * param.grad

0 769.4638061523438
1 715.1837768554688
2 668.64111328125
3 627.4976806640625
4 591.0120849609375
5 558.0015869140625
6 527.9564208984375
7 500.0458679199219
8 474.05169677734375
9 449.5151672363281
10 426.47735595703125
11 404.79351806640625
12 384.1741027832031
13 364.58099365234375
14 346.0767822265625
15 328.3545227050781
16 311.4040222167969
17 295.15606689453125
18 279.59228515625
19 264.6990661621094
20 250.37939453125
21 236.66104125976562
22 223.5629119873047
23 211.08689880371094
24 199.20252990722656
25 187.8653564453125
26 177.0853271484375
27 166.824462890625
28 157.04901123046875
29 147.779541015625
30 138.98916625976562
31 130.65875244140625
32 122.78919982910156
33 115.3527603149414
34 108.3271255493164
35 101.69461822509766
36 95.45890808105469
37 89.59132385253906
38 84.0953140258789
39 78.92584228515625
40 74.07184600830078
41 69.52830505371094
42 65.26624298095703
43 61.26431655883789
44 57.508689880371094
45 53.99211120605469
46 50.6967658996582
47 47.6112136840820

348 0.00022786676709074527
349 0.0002204205229645595
350 0.00021322345128282905
351 0.000206267082830891
352 0.0001995312049984932
353 0.0001930213184095919
354 0.00018673283921089023
355 0.00018065038602799177
356 0.00017476383072789758
357 0.0001690750359557569
358 0.00016356077685486525
359 0.00015823946159798652
360 0.00015309025184251368
361 0.0001481085055274889
362 0.00014329116675071418
363 0.00013863116328138858
364 0.0001341276802122593
365 0.00012976920697838068
366 0.00012555168359540403
367 0.00012147249071858823
368 0.0001175263460027054
369 0.00011371258733561262
370 0.00011002095561707392
371 0.00010644979920471087
372 0.00010299916175426915
373 9.966258221538737e-05
374 9.643190423958004e-05
375 9.330597094958648e-05
376 9.028323256643489e-05
377 8.735925075598061e-05
378 8.452982001472265e-05
379 8.179531869245693e-05
380 7.914852176327258e-05
381 7.658584945602342e-05
382 7.410722901113331e-05
383 7.171348261181265e-05
384 6.939373270142823e-05
385 6.715585186611861e