In [None]:
%matplotlib inline


PyTorch: optim
--------------

A fully-connected ReLU network with one hidden layer, trained to predict y from x
by minimizing squared Euclidean distance.

This implementation uses the nn package from PyTorch to build the network.

Rather than manually updating the weights of the model as we have been doing,
we use the optim package to define an Optimizer that will update the weights
for us. The optim package defines many optimization algorithms that are commonly
used for deep learning, including SGD+momentum, RMSProp, Adam, etc.



In [1]:
import torch

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Use the nn package to define our model and loss function.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)
loss_fn = torch.nn.MSELoss(reduction='sum')

# Use the optim package to define an Optimizer that will update the weights of
# the model for us. Here we will use Adam; the optim package contains many other
# optimization algoriths. The first argument to the Adam constructor tells the
# optimizer which Tensors it should update.
learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for t in range(500):
    # Forward pass: compute predicted y by passing x to the model.
    y_pred = model(x)

    # Compute and print loss.
    loss = loss_fn(y_pred, y)
    print(t, loss.item())

    # Before the backward pass, use the optimizer object to zero all of the
    # gradients for the variables it will update (which are the learnable
    # weights of the model). This is because by default, gradients are
    # accumulated in buffers( i.e, not overwritten) whenever .backward()
    # is called. Checkout docs of torch.autograd.backward for more details.
    optimizer.zero_grad()

    # Backward pass: compute gradient of the loss with respect to model
    # parameters
    loss.backward()

    # Calling the step function on an Optimizer makes an update to its
    # parameters
    optimizer.step()

0 658.5504150390625
1 641.2799072265625
2 624.4940185546875
3 608.194091796875
4 592.3651123046875
5 577.0455932617188
6 562.234130859375
7 547.8652954101562
8 534.07568359375
9 520.7279663085938
10 507.6936950683594
11 494.9717712402344
12 482.62823486328125
13 470.6606140136719
14 459.0511169433594
15 447.7821960449219
16 436.84576416015625
17 426.2340087890625
18 415.9100646972656
19 405.9193115234375
20 396.1959228515625
21 386.7435302734375
22 377.5851135253906
23 368.6575012207031
24 359.96728515625
25 351.5531921386719
26 343.35833740234375
27 335.38580322265625
28 327.6744689941406
29 320.17071533203125
30 312.856689453125
31 305.7014465332031
32 298.6844482421875
33 291.82159423828125
34 285.1192932128906
35 278.5886535644531
36 272.1812438964844
37 265.9128112792969
38 259.77734375
39 253.7689666748047
40 247.88961791992188
41 242.12828063964844
42 236.47244262695312
43 230.92088317871094
44 225.4736328125
45 220.14918518066406
46 214.92869567871094
47 209.8229522705078
48 20

371 0.0005396429332904518
372 0.0005099581321701407
373 0.0004818574816454202
374 0.00045522963046096265
375 0.0004300199216231704
376 0.0004061473300680518
377 0.00038353531272150576
378 0.0003621435898821801
379 0.00034188610152341425
380 0.00032271453528665006
381 0.0003045838384423405
382 0.0002874249767046422
383 0.00027119467267766595
384 0.0002558462438173592
385 0.00024132401449605823
386 0.000227599564823322
387 0.00021462318545673043
388 0.0002023654233198613
389 0.00019077658362220973
390 0.00017982487042900175
391 0.00016947892436292022
392 0.00015970951062627137
393 0.00015047499618958682
394 0.00014176640252117068
395 0.00013353489339351654
396 0.0001257620460819453
397 0.00011842397361760959
398 0.00011150423233630136
399 0.0001049732236424461
400 9.881155710900202e-05
401 9.300142846768722e-05
402 8.752199209993705e-05
403 8.234706183429807e-05
404 7.747014751657844e-05
405 7.287124753929675e-05
406 6.85396371409297e-05
407 6.445442704716697e-05
408 6.0605965700233355e-