In [None]:
%matplotlib inline


PyTorch: optim
--------------

A fully-connected ReLU network with one hidden layer, trained to predict y from x
by minimizing squared Euclidean distance.

This implementation uses the nn package from PyTorch to build the network.

Rather than manually updating the weights of the model as we have been doing,
we use the optim package to define an Optimizer that will update the weights
for us. The optim package defines many optimization algorithms that are commonly
used for deep learning, including SGD+momentum, RMSProp, Adam, etc.



In [1]:
import torch

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs.
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Use the nn package to define our model and loss function.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)
# reduction='sum' adds up the squared errors over all elements instead of
# averaging them. It replaces the deprecated size_average=False argument,
# which selects the same behavior but triggers a deprecation warning.
loss_fn = torch.nn.MSELoss(reduction='sum')

# Use the optim package to define an Optimizer that will update the weights of
# the model for us. Here we will use Adam; the optim package contains many other
# optimization algorithms. The first argument to the Adam constructor tells the
# optimizer which Tensors it should update.
learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for t in range(500):
    # Forward pass: compute predicted y by passing x to the model.
    y_pred = model(x)

    # Compute and print loss.
    loss = loss_fn(y_pred, y)
    print(t, loss.item())

    # Before the backward pass, use the optimizer object to zero all of the
    # gradients for the Tensors it will update (which are the learnable
    # weights of the model). This is necessary because, by default, gradients
    # are accumulated in buffers (i.e., not overwritten) whenever .backward()
    # is called. Check out the docs of torch.autograd.backward for more details.
    optimizer.zero_grad()

    # Backward pass: compute the gradient of the loss with respect to the
    # model parameters.
    loss.backward()

    # Calling the step function on an Optimizer makes an update to its
    # parameters.
    optimizer.step()

0 669.8269653320312
1 652.672607421875
2 635.9494018554688
3 619.7540893554688
4 604.0421142578125
5 588.7352905273438
6 573.87353515625
7 559.48779296875
8 545.523681640625
9 531.9373779296875
10 518.7428588867188
11 505.9844665527344
12 493.67633056640625
13 481.72967529296875
14 470.09124755859375
15 458.8467102050781
16 447.99493408203125
17 437.3759460449219
18 427.002197265625
19 416.93243408203125
20 407.11968994140625
21 397.5495300292969
22 388.2081298828125
23 379.12774658203125
24 370.2745056152344
25 361.6503601074219
26 353.2135314941406
27 344.9685363769531
28 336.8570251464844
29 328.912841796875
30 321.18804931640625
31 313.6419982910156
32 306.25830078125
33 299.0367126464844
34 291.9704284667969
35 285.0777893066406
36 278.3537292480469
37 271.76153564453125
38 265.3178405761719
39 259.0311279296875
40 252.8808135986328
41 246.85047912597656
42 240.9495086669922
43 235.16036987304688
44 229.48973083496094
45 223.93301391601562
46 218.4736328125
47 213.12942504882812
4

438 1.7262058804590197e-07
439 1.589274489788295e-07
440 1.4629124223120016e-07
441 1.3454844349780615e-07
442 1.2393556403367256e-07
443 1.1402956090478256e-07
444 1.0482002466005724e-07
445 9.645424370319233e-08
446 8.879194979272143e-08
447 8.163316067566484e-08
448 7.50225339629651e-08
449 6.900004478893607e-08
450 6.351808679028181e-08
451 5.840417927061026e-08
452 5.370356603862092e-08
453 4.933685815444733e-08
454 4.535472797329021e-08
455 4.1696647912203844e-08
456 3.83229554756781e-08
457 3.514663404757812e-08
458 3.237902745922838e-08
459 2.976906543494806e-08
460 2.731848169901241e-08
461 2.5073298104416608e-08
462 2.3046496266942995e-08
463 2.1199756616852028e-08
464 1.9510483895146535e-08
465 1.79079435724816e-08
466 1.650080960757805e-08
467 1.511424940758843e-08
468 1.3851012781174177e-08
469 1.2732757070921252e-08
470 1.1719824222211628e-08
471 1.076122213561348e-08
472 9.914125520538164e-09
473 9.12664788188522e-09
474 8.380312443989624e-09
475 7.68752173030407e-09
476