In [1]:
%matplotlib inline


PyTorch: optim
--------------

A fully-connected ReLU network with one hidden layer, trained to predict y from x
by minimizing squared Euclidean distance.

This implementation uses the nn package from PyTorch to build the network.

Rather than manually updating the weights of the model as we have been doing,
we use the optim package to define an Optimizer that will update the weights
for us. The optim package defines many optimization algorithms that are commonly
used for deep learning, such as SGD with momentum, RMSProp, and Adam.



In [2]:
import torch

# Seed the global RNG before anything random happens so that the synthetic
# data, the initial weights, and therefore the printed loss curve are the same
# on every "Restart Kernel -> Run All".
SEED = 0
torch.manual_seed(SEED)

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Number of optimization steps performed by the training loop below.
NUM_STEPS = 500

# Create random Tensors to hold inputs and outputs.
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Use the nn package to define our model and loss function.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)
# reduction='sum' adds up the squared errors over all elements instead of
# averaging them.
loss_fn = torch.nn.MSELoss(reduction='sum')

# Use the optim package to define an Optimizer that will update the weights of
# the model for us. Here we will use Adam; the optim package contains many other
# optimization algorithms. The first argument to the Adam constructor tells the
# optimizer which Tensors it should update.
learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for t in range(NUM_STEPS):
    # Forward pass: compute predicted y by passing x to the model.
    y_pred = model(x)

    # Compute and print loss.
    loss = loss_fn(y_pred, y)
    print(t, loss.item())

    # Before the backward pass, use the optimizer object to zero all of the
    # gradients for the variables it will update (which are the learnable
    # weights of the model). This is because, by default, gradients are
    # accumulated in buffers (i.e., not overwritten) whenever .backward()
    # is called. Check out the docs of torch.autograd.backward for more details.
    optimizer.zero_grad()

    # Backward pass: compute gradient of the loss with respect to model
    # parameters.
    loss.backward()

    # Calling the step function on an Optimizer makes an update to its
    # parameters.
    optimizer.step()

0 707.6091918945312
1 689.8565063476562
2 672.56494140625
3 655.7291870117188
4 639.2657470703125
5 623.2619018554688
6 607.6333618164062
7 592.470458984375
8 577.7386474609375
9 563.454345703125
10 549.6069946289062
11 536.0805053710938
12 523.0216674804688
13 510.43487548828125
14 498.33026123046875
15 486.5951232910156
16 475.1317138671875
17 463.95831298828125
18 453.0721740722656
19 442.46478271484375
20 432.1418151855469
21 422.0572814941406
22 412.2323303222656
23 402.6678771972656
24 393.4327392578125
25 384.4163513183594
26 375.6009826660156
27 366.96990966796875
28 358.52813720703125
29 350.26068115234375
30 342.2342834472656
31 334.3739929199219
32 326.7049865722656
33 319.2134704589844
34 311.88470458984375
35 304.7474365234375
36 297.7886657714844
37 290.9989013671875
38 284.3433532714844
39 277.82763671875
40 271.4323425292969
41 265.1755676269531
42 259.0511779785156
43 253.05239868164062
44 247.17807006835938
45 241.42262268066406
46 235.78857421875
47 230.2621307373047

365 0.0007857330492697656
366 0.0007496534381061792
367 0.0007151767495088279
368 0.000682228768710047
369 0.0006507717771455646
370 0.0006207136902958155
371 0.0005920098628848791
372 0.0005645871860906482
373 0.0005384132964536548
374 0.0005134022212587297
375 0.0004895354504697025
376 0.0004667250905185938
377 0.0004449661064427346
378 0.0004241768328938633
379 0.0004043379449285567
380 0.0003854058450087905
381 0.0003673039609566331
382 0.00035005767131224275
383 0.00033358432119712234
384 0.0003178729675710201
385 0.00030286883702501655
386 0.0002885417779907584
387 0.00027489085914567113
388 0.00026185810565948486
389 0.0002494225336704403
390 0.00023755909933242947
391 0.00022624089615419507
392 0.0002154769899789244
393 0.0002051514748018235
394 0.00019533470913302153
395 0.00018597710004542023
396 0.00017705025675240904
397 0.00016853683337103575
398 0.000160425144713372
399 0.00015268338029272854
400 0.00014530683984048665
401 0.00013827667862642556
402 0.00013156638306099921